StmtExecutionAction.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.httpv2.rest;

import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.common.Config;
import org.apache.doris.datasource.InternalCatalog;
import org.apache.doris.httpv2.entity.ResponseEntityBuilder;
import org.apache.doris.httpv2.util.ExecutionResultSet;
import org.apache.doris.httpv2.util.StatementSubmitter;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.OriginStatement;
import org.apache.doris.system.SystemInfoService;

import com.google.common.base.Joiner;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;

import java.lang.reflect.Type;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

/**
 * HTTP endpoints for executing a SQL statement and for fetching the create-table
 * statements of the tables referenced by a SQL statement.
 */
@RestController
public class StmtExecutionAction extends RestBaseController {
    private static final Logger LOG = LogManager.getLogger(StmtExecutionAction.class);
    private static StatementSubmitter stmtSubmitter = new StatementSubmitter();

    // Line breaks in the incoming SQL of querySchema are replaced by a single space.
    private static final String NEW_LINE_PATTERN = "[\n\r]";
    private static final String NEW_LINE_REPLACEMENT = " ";

    // Row limit applied when the request body does not carry a usable "limit".
    private static final long DEFAULT_ROW_LIMIT = 1000;
    // NOTE(review): declared but not referenced anywhere in this class, i.e. the requested
    // limit is not capped here -- confirm whether a cap is enforced elsewhere.
    private static final long MAX_ROW_LIMIT = 10000;

    /**
     * Execute a SQL.
     * Request body:
     * {
     * "is_sync": true,  // optional, defaults to true
     * "limit" : 1000,   // optional, defaults to DEFAULT_ROW_LIMIT
     * "stmt" : "select * from tbl1"   // required
     * }
     *
     * <p>Streaming of the result is enabled unless the request header {@code X-Doris-Stream}
     * is set to {@code false} (case-insensitive).
     *
     * @param ns catalog name from the path; the default cluster name is mapped to the internal catalog
     * @param dbName database name from the path
     * @param request the HTTP request
     * @param response the HTTP response, handed to the statement submitter
     * @param body raw JSON request body
     * @return a redirect when {@code needRedirect} asks for one, a bad-request response when the
     *         body is empty or has no {@code stmt}, otherwise the result of {@code executeQuery}
     */
    @RequestMapping(path = "/api/query/{" + NS_KEY + "}/{" + DB_KEY + "}", method = {RequestMethod.POST})
    public Object executeSQL(@PathVariable(value = NS_KEY) String ns, @PathVariable(value = DB_KEY) String dbName,
            HttpServletRequest request, HttpServletResponse response, @RequestBody String body) {
        if (needRedirect(request.getScheme())) {
            return redirectToHttps(request);
        }

        ActionAuthorizationInfo authInfo = checkWithCookie(request, response, false);
        String fullDbName = getFullDbName(dbName);
        if (Config.enable_all_http_auth) {
            checkDbAuth(ConnectContext.get().getCurrentUserIdentity(), fullDbName, PrivPredicate.ADMIN);
        }

        if (ns.equalsIgnoreCase(SystemInfoService.DEFAULT_CLUSTER)) {
            ns = InternalCatalog.INTERNAL_CATALOG_NAME;
        }

        Type type = new TypeToken<StmtRequestBody>() {
        }.getType();
        StmtRequestBody stmtRequestBody = new Gson().fromJson(body, type);

        // Gson returns null for an empty or literal-null document, so guard the object itself
        // before looking at its fields.
        if (stmtRequestBody == null || Strings.isNullOrEmpty(stmtRequestBody.stmt)) {
            return ResponseEntityBuilder.badRequest("Missing statement request body");
        }

        // An explicit JSON null overwrites the field defaults of the boxed fields; fall back to
        // the defaults instead of failing on auto-unboxing.
        boolean isSync = stmtRequestBody.is_sync == null || stmtRequestBody.is_sync;
        long limit = stmtRequestBody.limit == null ? DEFAULT_ROW_LIMIT : stmtRequestBody.limit;
        LOG.info("stmt: {}, isSync:{}, limit: {}", stmtRequestBody.stmt, isSync, limit);

        ConnectContext.get().changeDefaultCatalog(ns);
        ConnectContext.get().setDatabase(fullDbName);

        String streamHeader = request.getHeader("X-Doris-Stream");
        boolean isStream = !("false".equalsIgnoreCase(streamHeader));
        return executeQuery(authInfo, isSync, limit, stmtRequestBody, response, isStream);
    }


    /**
     * Get all create table stmt of a SQL
     *
     * <p>Line breaks in {@code sql} are replaced by spaces before the statement is logged and
     * analyzed.
     *
     * @param ns catalog name from the path; the default cluster name is mapped to the internal catalog
     * @param dbName database name from the path
     * @param request the HTTP request
     * @param response the HTTP response
     * @param sql plain text of sql
     * @return plain text of create table stmts, or a redirect when {@code needRedirect} asks for one
     */
    @RequestMapping(path = "/api/query_schema/{" + NS_KEY + "}/{" + DB_KEY + "}", method = {RequestMethod.POST})
    public Object querySchema(@PathVariable(value = NS_KEY) String ns, @PathVariable(value = DB_KEY) String dbName,
            HttpServletRequest request, HttpServletResponse response, @RequestBody String sql) {
        if (needRedirect(request.getScheme())) {
            return redirectToHttps(request);
        }

        checkWithCookie(request, response, false);

        if (ns.equalsIgnoreCase(SystemInfoService.DEFAULT_CLUSTER)) {
            ns = InternalCatalog.INTERNAL_CATALOG_NAME;
        }
        if (StringUtils.isNotBlank(sql)) {
            sql = sql.replaceAll(NEW_LINE_PATTERN, NEW_LINE_REPLACEMENT);
        }
        LOG.info("sql: {}", sql);
        ConnectContext.get().changeDefaultCatalog(ns);
        ConnectContext.get().setDatabase(getFullDbName(dbName));
        return getSchema(sql);
    }

    /**
     * Execute a query through the shared statement submitter and wait for its result.
     *
     * @param authInfo user name and password used to build the statement context
     * @param isSync whether the caller waits for the result; {@code false} is rejected
     * @param limit row limit passed to the statement context
     * @param stmtRequestBody parsed request body carrying the statement text
     * @param response the HTTP response, passed to the statement context
     * @param isStream whether the statement context is created in stream mode
     * @return {@code null} in stream mode once the statement has finished, otherwise an OK
     *         response with the result, or an error response on failure or for async requests
     */
    private ResponseEntity executeQuery(ActionAuthorizationInfo authInfo, boolean isSync, long limit,
            StmtRequestBody stmtRequestBody, HttpServletResponse response, boolean isStream) {
        // Reject unsupported async requests before submitting, so that a statement is never
        // executed (or handed the response) for a request that is answered with an error.
        if (!isSync) {
            return ResponseEntityBuilder.okWithCommonError("Not support async query execution");
        }

        StatementSubmitter.StmtContext stmtCtx = new StatementSubmitter.StmtContext(stmtRequestBody.stmt,
                authInfo.fullUserName, authInfo.password, limit, isStream, response, "");
        Future<ExecutionResultSet> future = stmtSubmitter.submit(stmtCtx);

        try {
            ExecutionResultSet resultSet = future.get();
            // if use stream response, we not need to response an object.
            if (isStream) {
                return null;
            }
            return ResponseEntityBuilder.ok(resultSet.getResult());
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            LOG.warn("failed to execute stmt", e);
            return ResponseEntityBuilder.okWithCommonError("Failed to execute sql: " + e.getMessage());
        } catch (ExecutionException e) {
            LOG.warn("failed to execute stmt", e);
            return ResponseEntityBuilder.okWithCommonError("Failed to execute sql: " + e.getMessage());
        }
    }

    /**
     * Build the create-table statements of every table scanned by {@code sql}.
     *
     * <p>The statement is parsed and planned with explain level {@code ANALYZED_PLAN} under a
     * temporary {@link StatementContext}; the statement context that was installed on the
     * current {@link ConnectContext} before the call is restored afterwards.
     *
     * @param sql a single SQL statement
     * @return the create-table statements, separated by a blank line
     */
    @NotNull
    private String getSchema(String sql) {
        LogicalPlan unboundPlan = new NereidsParser().parseSingle(sql);
        try (StatementContext statementContext = new StatementContext(ConnectContext.get(),
                new OriginStatement(sql, 0))) {
            StatementContext originalContext = ConnectContext.get().getStatementContext();
            try {
                ConnectContext.get().setStatementContext(statementContext);
                NereidsPlanner planner = new NereidsPlanner(statementContext);
                planner.planWithLock(unboundPlan, PhysicalProperties.ANY, ExplainCommand.ExplainLevel.ANALYZED_PLAN);
                LogicalPlan logicalPlan = (LogicalPlan) planner.getCascadesContext().getRewritePlan();

                List<String> createStmts = PlanUtils.getLogicalScanFromRootPlan(logicalPlan).stream().map(plan -> {
                    TableIf tbl = plan.getTable();
                    List<String> createTableStmts = Lists.newArrayList();
                    Env.getDdlStmt(tbl, createTableStmts, null, null, false, true, -1L);
                    // NOTE(review): assumes getDdlStmt always adds at least one statement -- confirm.
                    return createTableStmts.get(0);
                }).collect(Collectors.toList());
                return Joiner.on("\n\n").join(createStmts);
            } finally {
                // Always put the previous statement context back, even when planning fails.
                ConnectContext.get().setStatementContext(originalContext);
            }
        }
    }

    /**
     * JSON request body of {@code executeSQL}, populated by Gson. Field names match the JSON keys.
     */
    private static class StmtRequestBody {
        public Boolean is_sync = true; // CHECKSTYLE IGNORE THIS LINE
        // Row limit; may become null when the JSON carries an explicit null.
        public Long limit = DEFAULT_ROW_LIMIT;
        // The statement to execute; required.
        public String stmt;
    }
}