InitMaterializationContextHook.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.exploration.mv;

import org.apache.doris.catalog.AggStateType;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.MTMV;
import org.apache.doris.catalog.MaterializedIndexMeta;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.mtmv.BaseTableInfo;
import org.apache.doris.mtmv.MTMVCache;
import org.apache.doris.mtmv.MTMVPlanUtil;
import org.apache.doris.mtmv.MTMVUtil;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.PlannerHook;
import org.apache.doris.nereids.hint.Hint;
import org.apache.doris.nereids.hint.UseMvHint;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.qe.ConnectContext;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import org.apache.commons.collections.CollectionUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
* If query rewrite with materialized views is enabled, the materialization contexts
* should be initialized after the plan is rewritten
*/
public class InitMaterializationContextHook implements PlannerHook {
public static final Logger LOG = LogManager.getLogger(InitMaterializationContextHook.class);
public static final InitMaterializationContextHook INSTANCE = new InitMaterializationContextHook();
@Override
public void afterRewrite(NereidsPlanner planner) {
initMaterializationContext(planner.getCascadesContext());
}
@VisibleForTesting
public void initMaterializationContext(CascadesContext cascadesContext) {
if (!cascadesContext.getConnectContext().getSessionVariable().isEnableMaterializedViewRewrite()) {
return;
}
doInitMaterializationContext(cascadesContext);
}
/**
* Init the sync and async materialization contexts for the current query
* @param cascadesContext current cascadesContext in the planner
*/
protected void doInitMaterializationContext(CascadesContext cascadesContext) {
if (cascadesContext.getConnectContext().getSessionVariable().isInDebugMode()) {
LOG.info("MaterializationContext init return because is in debug mode, current queryId is {}",
cascadesContext.getConnectContext().getQueryIdentifier());
return;
}
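// tables referenced by the query; if none were collected there is nothing to rewrite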
Set<TableIf> collectedTables = Sets.newHashSet(cascadesContext.getStatementContext().getTables().values());
if (collectedTables.isEmpty()) {
return;
}
// Create sync materialization context
if (cascadesContext.getConnectContext().getSessionVariable()
.isEnableSyncMvCostBasedRewrite()) {
for (TableIf tableIf : collectedTables) {
if (tableIf instanceof OlapTable) {
for (MaterializationContext context : createSyncMvContexts(
(OlapTable) tableIf, cascadesContext)) {
cascadesContext.addMaterializationContext(context);
}
}
}
}
// Create async materialization context
for (MaterializationContext context : createAsyncMaterializationContext(cascadesContext,
collectedTables)) {
cascadesContext.addMaterializationContext(context);
}
}
private List<MaterializationContext> getMvIdWithUseMvHint(List<MaterializationContext> mtmvCtxs,
UseMvHint useMvHint) {
List<MaterializationContext> hintMTMVs = new ArrayList<>();
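// keep only the contexts whose qualified name is listed in the USE_MV hint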
for (MaterializationContext mtmvCtx : mtmvCtxs) {
List<String> mvQualifier = mtmvCtx.generateMaterializationIdentifier();
if (useMvHint.getUseMvTableColumnMap().containsKey(mvQualifier)) {
hintMTMVs.add(mtmvCtx);
}
}
return hintMTMVs;
}
private List<MaterializationContext> getMvIdWithNoUseMvHint(List<MaterializationContext> mtmvCtxs,
UseMvHint useMvHint) {
List<MaterializationContext> hintMTMVs = new ArrayList<>();
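// if the hint disables all materialized views, mark it as applied and return an empty list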
if (useMvHint.isAllMv()) {
useMvHint.setStatus(Hint.HintStatus.SUCCESS);
return hintMTMVs;
}
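// drop the contexts named in the NO_USE_MV hint and mark each matched entry as applied; keep the rest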
for (MaterializationContext mtmvCtx : mtmvCtxs) {
List<String> mvQualifier = mtmvCtx.generateMaterializationIdentifier();
if (useMvHint.getNoUseMvTableColumnMap().containsKey(mvQualifier)) {
useMvHint.setStatus(Hint.HintStatus.SUCCESS);
useMvHint.getNoUseMvTableColumnMap().put(mvQualifier, true);
} else {
hintMTMVs.add(mtmvCtx);
}
}
return hintMTMVs;
}
/**
* Get the materialization contexts that are allowed by the USE_MV / NO_USE_MV hints
* @param mtmvCtxs candidate materialization contexts which could be used to rewrite the query
* @return the materialization contexts which pass the useMvHint check
*/
public List<MaterializationContext> getMaterializationContextByHint(List<MaterializationContext> mtmvCtxs) {
Optional<UseMvHint> useMvHint = ConnectContext.get().getStatementContext().getUseMvHint("USE_MV");
Optional<UseMvHint> noUseMvHint = ConnectContext.get().getStatementContext().getUseMvHint("NO_USE_MV");
if (!useMvHint.isPresent() && !noUseMvHint.isPresent()) {
return mtmvCtxs;
}
List<MaterializationContext> result = mtmvCtxs;
if (noUseMvHint.isPresent()) {
result = getMvIdWithNoUseMvHint(result, noUseMvHint.get());
}
if (useMvHint.isPresent()) {
result = getMvIdWithUseMvHint(result, useMvHint.get());
}
return result;
}
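/**
* Get the async materialized views that can be used to rewrite a query over the given base tables.
* MTMVs that contain external tables are filtered out unless
* isEnableMaterializedViewRewriteWhenBaseTableUnawareness() is enabled in the session.
*/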
protected Set<MTMV> getAvailableMTMVs(Set<TableIf> usedTables, CascadesContext cascadesContext) {
List<BaseTableInfo> usedBaseTables =
usedTables.stream().map(BaseTableInfo::new).collect(Collectors.toList());
return Env.getCurrentEnv().getMtmvService().getRelationManager()
.getAvailableMTMVs(usedBaseTables, cascadesContext.getConnectContext(),
false, ((connectContext, mtmv) -> {
return MTMVUtil.mtmvContainsExternalTable(mtmv) && (!connectContext.getSessionVariable()
.isEnableMaterializedViewRewriteWhenBaseTableUnawareness());
}));
}
private List<MaterializationContext> createAsyncMaterializationContext(CascadesContext cascadesContext,
Set<TableIf> usedTables) {
Set<MTMV> availableMTMVs;
try {
availableMTMVs = getAvailableMTMVs(usedTables, cascadesContext);
} catch (Exception e) {
LOG.warn(String.format("MaterializationContext getAvailableMTMVs generate fail, current sqlHash is %s",
cascadesContext.getConnectContext().getSqlHash()), e);
return ImmutableList.of();
}
if (CollectionUtils.isEmpty(availableMTMVs)) {
LOG.debug("Enable materialized view rewrite but availableMTMVs is empty, current sqlHash "
+ "is {}", cascadesContext.getConnectContext().getSqlHash());
return ImmutableList.of();
}
List<MaterializationContext> asyncMaterializationContext = new ArrayList<>();
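// build one async materialization context per candidate mtmv from its cached plan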
for (MTMV materializedView : availableMTMVs) {
MTMVCache mtmvCache = null;
try {
mtmvCache = materializedView.getOrGenerateCache(cascadesContext.getConnectContext());
if (mtmvCache == null) {
continue;
}
// For an async materialization context, the cascades context used when the struct info was constructed
// may be different from the current cascadesContext,
// so regenerate the struct info table bitset
StructInfo mvStructInfo = mtmvCache.getStructInfo();
BitSet tableBitSetInCurrentCascadesContext = new BitSet();
mvStructInfo.getRelations().forEach(relation -> tableBitSetInCurrentCascadesContext.set(
cascadesContext.getStatementContext().getTableId(relation.getTable()).asInt()));
asyncMaterializationContext.add(new AsyncMaterializationContext(materializedView,
mtmvCache.getLogicalPlan(), mtmvCache.getOriginalPlan(), ImmutableList.of(),
ImmutableList.of(), cascadesContext,
mtmvCache.getStructInfo().withTableBitSet(tableBitSetInCurrentCascadesContext)));
} catch (Exception e) {
LOG.warn(String.format("MaterializationContext init mv cache generate fail, current queryId is %s",
cascadesContext.getConnectContext().getQueryIdentifier()), e);
}
}
return getMaterializationContextByHint(asyncMaterializationContext);
}
private List<MaterializationContext> createSyncMvContexts(OlapTable olapTable,
CascadesContext cascadesContext) {
int indexNumber = olapTable.getIndexNumber();
List<MaterializationContext> contexts = new ArrayList<>(indexNumber);
long baseIndexId = olapTable.getBaseIndexId();
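// count the key columns of the base schema, used to decide whether a rollup on an aggregate key table
// needs a group by clause when its create mv sql is assembled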
int keyCount = 0;
for (Column column : olapTable.getFullSchema()) {
keyCount += column.isKey() ? 1 : 0;
}
for (Map.Entry<Long, MaterializedIndexMeta> entry : olapTable.getVisibleIndexIdToMeta().entrySet()) {
long indexId = entry.getKey();
String indexName = olapTable.getIndexNameById(indexId);
try {
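// the base index is the table itself; only the other visible indexes (sync mvs and rollups) get a context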
if (indexId != baseIndexId) {
MaterializedIndexMeta meta = entry.getValue();
String createMvSql;
if (meta.getDefineStmt() != null) {
// get the original create mv sql
createMvSql = meta.getDefineStmt().originStmt;
} else {
// it's a rollup index, so the create mv sql needs to be assembled manually
if (olapTable.getKeysType() == KeysType.AGG_KEYS) {
createMvSql = assembleCreateMvSqlForAggTable(olapTable.getQualifiedName(),
indexName, meta.getSchema(false), keyCount);
} else {
createMvSql =
assembleCreateMvSqlForDupOrUniqueTable(olapTable.getQualifiedName(),
indexName, meta.getSchema(false));
}
}
if (createMvSql != null) {
Optional<String> querySql =
new NereidsParser().parseForSyncMv(createMvSql);
if (!querySql.isPresent()) {
LOG.warn(String.format("can't parse %s ", createMvSql));
continue;
}
ConnectContext basicMvContext = MTMVPlanUtil.createBasicMvContext(
cascadesContext.getConnectContext());
basicMvContext.setDatabase(meta.getDbName());
MTMVCache mtmvCache = MTMVCache.from(querySql.get(),
basicMvContext, true,
false, cascadesContext.getConnectContext());
contexts.add(new SyncMaterializationContext(mtmvCache.getLogicalPlan(),
mtmvCache.getOriginalPlan(), olapTable, meta.getIndexId(), indexName,
cascadesContext, mtmvCache.getStatistics()));
} else {
LOG.warn(String.format("can't assemble create mv sql for index ", indexName));
}
}
} catch (Exception exception) {
LOG.warn(String.format("createSyncMvContexts exception, index id is %s, index name is %s, "
+ "table name is %s", entry.getValue(), indexName, olapTable.getQualifiedName()),
exception);
}
}
return getMaterializationContextByHint(contexts);
}
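/**
* Assemble the create mv sql for a rollup index on a duplicate or unique key table, which is a plain
* projection of the rollup columns, e.g. create materialized view r1 as select k1, v1 from db.t
* (r1, k1, v1 and db.t are illustrative names).
*/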
private String assembleCreateMvSqlForDupOrUniqueTable(String baseTableName, String mvName, List<Column> columns) {
StringBuilder createMvSqlBuilder = new StringBuilder();
createMvSqlBuilder.append(String.format("create materialized view %s as select ", mvName));
for (Column col : columns) {
createMvSqlBuilder.append(String.format("%s, ", col.getName()));
}
removeLastTwoChars(createMvSqlBuilder);
createMvSqlBuilder.append(String.format(" from %s", baseTableName));
return createMvSqlBuilder.toString();
}
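/**
* Assemble the create mv sql for a rollup index on an aggregate key table. If the rollup keeps fewer
* key columns than the base index, value columns are wrapped in their aggregate functions and the key
* columns become the group by list, e.g. create materialized view r1 as select k1, sum(v1) from db.t group by k1
* (illustrative names); otherwise all rollup columns are selected directly.
*/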
private String assembleCreateMvSqlForAggTable(String baseTableName, String mvName,
List<Column> columns, int keyCount) {
StringBuilder createMvSqlBuilder = new StringBuilder();
createMvSqlBuilder.append(String.format("create materialized view %s as select ", mvName));
int mvKeyCount = 0;
for (Column column : columns) {
mvKeyCount += column.isKey() ? 1 : 0;
}
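// fewer key columns than the base index means this rollup aggregates rows, so value columns must be
// wrapped in their aggregate functions and the key columns become group by columns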
if (mvKeyCount < keyCount) {
StringBuilder keyColumnsStringBuilder = new StringBuilder();
StringBuilder aggColumnsStringBuilder = new StringBuilder();
for (Column col : columns) {
AggregateType aggregateType = col.getAggregationType();
if (aggregateType != null) {
switch (aggregateType) {
case SUM:
case MAX:
case MIN:
case HLL_UNION:
case BITMAP_UNION:
case QUANTILE_UNION: {
aggColumnsStringBuilder
.append(String.format("%s(%s), ", aggregateType, col.getName()));
break;
}
case GENERIC: {
AggStateType aggStateType = (AggStateType) col.getType();
aggColumnsStringBuilder.append(String.format("%s_union(%s), ",
aggStateType.getFunctionName(), col.getName()));
break;
}
default: {
// mv value columns must not use NONE, REPLACE or REPLACE_IF_NOT_NULL aggregation
LOG.warn(String.format("mv agg column %s must not use aggregate type %s",
col.getName(), aggregateType));
return null;
}
}
} else {
// use column name for key
Preconditions.checkState(col.isKey(),
String.format("%s must be key", col.getName()));
keyColumnsStringBuilder.append(String.format("%s, ", col.getName()));
}
}
Preconditions.checkState(keyColumnsStringBuilder.length() > 0,
"must contain at least one key column in rollup");
if (aggColumnsStringBuilder.length() > 0) {
removeLastTwoChars(aggColumnsStringBuilder);
} else {
removeLastTwoChars(keyColumnsStringBuilder);
}
createMvSqlBuilder.append(keyColumnsStringBuilder);
createMvSqlBuilder.append(aggColumnsStringBuilder);
if (aggColumnsStringBuilder.length() > 0) {
// the key columns form the group by list, so strip the trailing ", " before using it
removeLastTwoChars(keyColumnsStringBuilder);
}
createMvSqlBuilder.append(
String.format(" from %s group by %s", baseTableName, keyColumnsStringBuilder));
} else {
for (Column col : columns) {
createMvSqlBuilder.append(String.format("%s, ", col.getName()));
}
removeLastTwoChars(createMvSqlBuilder);
createMvSqlBuilder.append(String.format(" from %s", baseTableName));
}
return createMvSqlBuilder.toString();
}
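/**
* Strip the trailing ", " separator left by the column-appending loops above.
*/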
private void removeLastTwoChars(StringBuilder stringBuilder) {
if (stringBuilder.length() >= 2) {
stringBuilder.delete(stringBuilder.length() - 2, stringBuilder.length());
}
}
}