DorisFunctionRegistry.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/apache/hive/blob/master/hplsql/src/main/java/org/apache/hive/hplsql/functions/HmsFunctionRegistry.java
// and modified by Doris
package org.apache.doris.plsql.functions;
import org.apache.doris.catalog.DatabaseIf;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.util.TimeUtils;
import org.apache.doris.datasource.CatalogIf;
import org.apache.doris.nereids.PLLexer;
import org.apache.doris.nereids.PLParser;
import org.apache.doris.nereids.PLParser.Create_function_stmtContext;
import org.apache.doris.nereids.PLParser.Create_procedure_stmtContext;
import org.apache.doris.nereids.PLParser.Expr_func_paramsContext;
import org.apache.doris.nereids.PLParserBaseVisitor;
import org.apache.doris.nereids.parser.CaseInsensitiveStream;
import org.apache.doris.nereids.trees.plans.commands.info.FuncNameInfo;
import org.apache.doris.plsql.Exec;
import org.apache.doris.plsql.Scope;
import org.apache.doris.plsql.Var;
import org.apache.doris.plsql.metastore.PlsqlMetaClient;
import org.apache.doris.plsql.metastore.PlsqlProcedureKey;
import org.apache.doris.plsql.metastore.PlsqlStoredProcedure;
import org.apache.doris.qe.ConnectContext;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.ParserRuleContext;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Pattern;
public class DorisFunctionRegistry implements FunctionRegistry {
// Executor used to run routine bodies and to manage scopes, the call stack and variables.
private final Exec exec;
// Trace flag, read once from the executor in the constructor.
private final boolean trace;
// Metastore client used to load, store and drop stored procedures.
private final PlsqlMetaClient client;
// Built-in functions; consulted before any user-defined routine.
private final BuiltinFunctions builtinFunctions;
// Parsed routines keyed by the upper-cased "<current db>.<name>" string built by qualified().
// This is a plain HashMap (not a concurrent map).
private final Map<String, ParserRuleContext> cache = new HashMap<>();
/**
 * Creates a registry bound to the given executor, metastore client and built-in function set.
 * The trace flag is captured from the executor once, at construction time.
 *
 * @param e executor used to run routines
 * @param client metastore client used to read and write stored procedures
 * @param builtinFunctions built-in functions that take precedence over stored routines
 */
public DorisFunctionRegistry(Exec e, PlsqlMetaClient client, BuiltinFunctions builtinFunctions) {
    this.builtinFunctions = builtinFunctions;
    this.client = client;
    this.exec = e;
    this.trace = e.getTrace();
}
/**
 * Tells whether a routine with the given name is known, either in the local cache
 * or through the metastore client.
 *
 * @param procedureName name of the routine to look up
 * @return {@code true} if the routine is cached or the metastore returns it
 */
@Override
public boolean exists(FuncNameInfo procedureName) {
    if (isCached(procedureName.toString())) {
        return true;
    }
    return getProc(procedureName).isPresent();
}
/**
 * Drops the stored procedure identified by name, catalog id and database id through
 * the metastore client.
 *
 * @param procedureName name of the routine to drop
 * @throws RuntimeException wrapping any exception raised while dropping
 */
@Override
public void remove(FuncNameInfo procedureName) {
    try {
        client.dropPlsqlStoredProcedure(
                procedureName.getName(),
                procedureName.getCtlId(),
                procedureName.getDbId());
    } catch (Exception cause) {
        throw new RuntimeException("failed to remove procedure", cause);
    }
}
/**
 * Checks whether the local cache holds an entry for {@code name}, after qualifying
 * it with the current database.
 */
private boolean isCached(String name) {
    final String key = qualified(name);
    return cache.containsKey(key);
}
/**
 * Evicts the cache entry for {@code name}, after qualifying it with the current database.
 * Does nothing if no such entry exists.
 */
@Override
public void removeCached(String name) {
    final String key = qualified(name);
    cache.remove(key);
}
/**
 * Builds the cache key for a routine name: the current connection's database,
 * a dot, and the name, all upper-cased.
 */
private String qualified(String name) {
    String currentDb = ConnectContext.get().getDatabase();
    String key = currentDb + "." + name;
    return key.toUpperCase();
}
/**
 * Resolves the full database name for the given catalog and database ids.
 *
 * @return the database's full name, or an empty string when either the catalog
 *         or the database cannot be found
 */
private String getDbName(long catalogId, long dbId) {
    CatalogIf catalog = Env.getCurrentEnv().getCatalogMgr().getCatalog(catalogId);
    if (catalog == null) {
        return "";
    }
    DatabaseIf db = catalog.getDbNullable(dbId);
    if (db == null) {
        return "";
    }
    return db.getFullName();
}
/**
 * Case-insensitive wildcard match of {@code str} against {@code wild}.
 *
 * <p>In {@code wild}, {@code %} matches any sequence of characters (including none) and
 * {@code ?} matches exactly one character. Every other character is matched literally,
 * including characters that are special in regular expressions such as {@code +},
 * {@code (} or {@code [}.
 *
 * @param str value to test
 * @param wild wildcard pattern
 * @return {@code true} if the whole of {@code str} matches the pattern
 */
public boolean like(String str, String wild) {
    // Lower-case both sides up front with a fixed locale so the comparison does not
    // depend on the JVM's default locale.
    String pattern = wild.toLowerCase(Locale.ROOT);
    StringBuilder regex = new StringBuilder(pattern.length() + 16);
    StringBuilder literal = new StringBuilder();
    for (int i = 0; i < pattern.length(); i++) {
        char c = pattern.charAt(i);
        if (c == '%' || c == '?') {
            // Flush the pending literal run, quoted so regex metacharacters stay literal.
            if (literal.length() > 0) {
                regex.append(Pattern.quote(literal.toString()));
                literal.setLength(0);
            }
            regex.append(c == '%' ? ".*" : ".");
        } else {
            literal.append(c);
        }
    }
    if (literal.length() > 0) {
        regex.append(Pattern.quote(literal.toString()));
    }
    return Pattern.matches(regex.toString(), str.toLowerCase(Locale.ROOT));
}
/**
 * Applies an optional wildcard filter to a value.
 *
 * @param value value to test
 * @param filter wildcard pattern; an empty string means "no filter"
 * @return {@code true} if the filter is empty or {@link #like} accepts the value
 */
public boolean applyFilter(String value, String filter) {
    return filter.isEmpty() || like(value, filter);
}
/**
 * Appends one row per stored procedure that passes both filters to {@code columns}.
 *
 * <p>Each row holds, in order: name, catalog id, database id, database name, package name,
 * owner name, create time, modify time and source.
 *
 * @param columns output list that receives the rows
 * @param dbFilter wildcard filter on the database name; empty means no filtering
 * @param procFilter wildcard filter on the procedure name; empty means no filtering
 */
@Override
public void showProcedure(List<List<String>> columns, String dbFilter, String procFilter) {
    Map<PlsqlProcedureKey, PlsqlStoredProcedure> stored = client.getAllPlsqlStoredProcedures();
    for (PlsqlStoredProcedure proc : stored.values()) {
        boolean nameMatches = applyFilter(proc.getName(), procFilter);
        if (!nameMatches) {
            continue;
        }
        // The database name is resolved only for procedures whose name already matched.
        String dbName = getDbName(proc.getCatalogId(), proc.getDbId());
        boolean dbMatches = applyFilter(dbName, dbFilter);
        if (!dbMatches) {
            continue;
        }
        List<String> row = new ArrayList<>();
        row.add(proc.getName());
        row.add(Long.toString(proc.getCatalogId()));
        row.add(Long.toString(proc.getDbId()));
        row.add(dbName);
        row.add(proc.getPackageName());
        row.add(proc.getOwnerName());
        row.add(proc.getCreateTime());
        row.add(proc.getModifyTime());
        row.add(proc.getSource());
        columns.add(row);
    }
}
/**
 * Looks up a stored procedure and, if found, appends a single row containing its
 * name and source to {@code columns}. Nothing is appended when the lookup returns null.
 *
 * @param procedureName name of the routine to show
 * @param columns output list that receives the row
 */
@Override
public void showCreateProcedure(FuncNameInfo procedureName, List<List<String>> columns) {
    PlsqlStoredProcedure proc = client.getPlsqlStoredProcedure(
            procedureName.getName(),
            procedureName.getCtlId(),
            procedureName.getDbId());
    if (proc == null) {
        return;
    }
    List<String> row = new ArrayList<>();
    row.add(proc.getName());
    row.add(proc.getSource());
    columns.add(row);
}
/**
 * Executes the routine with the given name.
 *
 * <p>Resolution order: built-in functions first, then the local cache, then the
 * metastore. A routine loaded from the metastore is parsed, executed and then
 * handed to {@code saveInCache}.
 *
 * @param procedureName name of the routine to run
 * @param ctx actual call parameters
 * @return {@code true} if some routine handled the call, {@code false} if none was found
 */
@Override
public boolean exec(FuncNameInfo procedureName, Expr_func_paramsContext ctx) {
    final String name = procedureName.toString();

    // First look for built-in functions.
    if (builtinFunctions.exec(name, ctx)) {
        return true;
    }

    if (isCached(name)) {
        trace(ctx, "EXEC CACHED FUNCTION " + name);
        ParserRuleContext cachedCtx = cache.get(qualified(name));
        execProcOrFunc(ctx, cachedCtx, name);
        return true;
    }

    Optional<PlsqlStoredProcedure> stored = getProc(procedureName);
    if (!stored.isPresent()) {
        return false;
    }
    trace(ctx, "EXEC HMS FUNCTION " + name);
    ParserRuleContext parsed = parse(stored.get());
    execProcOrFunc(ctx, parsed, name);
    saveInCache(name, parsed);
    return true;
}
/**
 * Execute a stored procedure or function using CALL or EXEC statement passing parameters.
 *
 * <p>Pushes {@code name} on the executor's call stack, evaluates the actual arguments,
 * enters a ROUTINE scope and runs the routine body. After the scope is left, OUT parameter
 * values collected during the call are written back with {@code exec.setVariable}.
 *
 * <p>The call-stack entry and the routine scope are released in a {@code finally} block,
 * so a routine that throws does not leave the executor with a stale frame or scope.
 * OUT parameters are only written back when the call completes normally.
 *
 * @param ctx actual call parameters
 * @param procCtx parsed CREATE FUNCTION / CREATE PROCEDURE statement to run
 * @param name routine name recorded on the call stack
 */
private void execProcOrFunc(Expr_func_paramsContext ctx, ParserRuleContext procCtx, String name) {
    HashMap<String, Var> out = new HashMap<>();
    boolean scopeEntered = false;
    exec.callStackPush(name);
    try {
        // Arguments are evaluated before the routine scope is entered.
        ArrayList<Var> actualParams = getActualCallParameters(ctx);
        exec.enterScope(Scope.Type.ROUTINE);
        scopeEntered = true;
        callWithParameters(ctx, procCtx, out, actualParams);
    } finally {
        // Same order as on the normal path: pop the call stack, then leave the scope.
        exec.callStackPop();
        if (scopeEntered) {
            exec.leaveScope();
        }
    }
    for (Map.Entry<String, Var> i : out.entrySet()) { // Set OUT parameters
        exec.setVariable(i.getKey(), i.getValue());
    }
}
/**
 * Binds the actual parameters to the routine's formal parameters and runs its body.
 *
 * <p>For a function, parameters are bound without an OUT map, then the optional in-place
 * declare block and the single block statement are visited. Anything else is treated as a
 * procedure: parameters are bound with {@code out} and the procedure block is visited.
 *
 * @param ctx actual call parameters
 * @param procCtx parsed CREATE FUNCTION or CREATE PROCEDURE statement
 * @param out map passed to parameter binding for procedures
 * @param actualParams already-evaluated actual parameter values (may be null)
 */
private void callWithParameters(Expr_func_paramsContext ctx, ParserRuleContext procCtx, HashMap<String, Var> out,
        ArrayList<Var> actualParams) {
    if (!(procCtx instanceof Create_function_stmtContext)) {
        Create_procedure_stmtContext procedure = (Create_procedure_stmtContext) procCtx;
        String procedureName = procedure.multipartIdentifier().getText();
        InMemoryFunctionRegistry.setCallParameters(procedureName, ctx, actualParams,
                procedure.create_routine_params(), out, exec);
        exec.visit(procedure.procedure_block());
        return;
    }

    Create_function_stmtContext function = (Create_function_stmtContext) procCtx;
    String functionName = function.multipartIdentifier().getText();
    InMemoryFunctionRegistry.setCallParameters(functionName, ctx, actualParams,
            function.create_routine_params(), null, exec);
    if (function.declare_block_inplace() != null) {
        exec.visit(function.declare_block_inplace());
    }
    exec.visit(function.single_block_stmt());
}
/**
 * Parses the source text of a stored procedure and returns the CREATE FUNCTION context
 * found in it, or otherwise the CREATE PROCEDURE context.
 *
 * @param proc stored procedure whose source is parsed
 * @return the function context if one was visited, else the procedure context
 *         (which is null if neither statement was visited)
 */
private ParserRuleContext parse(PlsqlStoredProcedure proc) {
    CaseInsensitiveStream input = new CaseInsensitiveStream(CharStreams.fromString(proc.getSource()));
    PLParser parser = new PLParser(new CommonTokenStream(new PLLexer(input)));
    ProcedureVisitor collector = new ProcedureVisitor();
    parser.program().accept(collector);
    if (collector.func != null) {
        return collector.func;
    }
    return collector.proc;
}
/**
 * Fetches a stored procedure from the metastore client by name, catalog id and database id.
 *
 * @return the procedure, or an empty Optional when the client returns null
 */
private Optional<PlsqlStoredProcedure> getProc(FuncNameInfo procedureName) {
    PlsqlStoredProcedure stored = client.getPlsqlStoredProcedure(
            procedureName.getName(),
            procedureName.getCtlId(),
            procedureName.getDbId());
    return Optional.ofNullable(stored);
}
/**
 * Evaluates every actual call parameter, in order, and collects the resulting values.
 *
 * @param actual actual call parameters; may be null
 * @return the evaluated values, or {@code null} when {@code actual} or its parameter
 *         list is null
 */
private ArrayList<Var> getActualCallParameters(Expr_func_paramsContext actual) {
    boolean hasParams = actual != null && actual.func_param() != null;
    if (!hasParams) {
        return null;
    }
    final int total = actual.func_param().size();
    ArrayList<Var> evaluated = new ArrayList<>(total);
    int idx = 0;
    while (idx < total) {
        evaluated.add(evalPop(actual.func_param(idx).expr()));
        idx++;
    }
    return evaluated;
}
/**
 * Registers a user-defined function from a CREATE FUNCTION statement.
 *
 * <p>If the name belongs to a built-in function, an informational message is emitted and
 * nothing is stored. Otherwise the statement is passed to {@code saveInCache} and its
 * formatted text is persisted via {@link #save}, forcing replacement when the statement
 * carries REPLACE.
 *
 * @param ctx parsed CREATE FUNCTION statement
 */
@Override
public void addUserFunction(Create_function_stmtContext ctx) {
    FuncNameInfo procedureName = new FuncNameInfo(
            exec.logicalPlanBuilder.visitMultipartIdentifier(ctx.multipartIdentifier()));
    final String name = procedureName.toString();
    if (builtinFunctions.exists(name)) {
        exec.info(ctx, name + " is a built-in function which cannot be redefined.");
        return;
    }
    trace(ctx, "CREATE FUNCTION " + name);
    saveInCache(name, ctx);
    String source = Exec.getFormattedText(ctx);
    boolean replace = ctx.REPLACE() != null;
    save(procedureName, source, replace);
}
/**
 * Registers a user-defined procedure from a CREATE PROCEDURE statement.
 *
 * <p>If the name belongs to a built-in function, an informational message is emitted and
 * nothing is stored. Otherwise the statement is passed to {@code saveInCache} and its
 * formatted text is persisted via {@link #save}, forcing replacement when the statement
 * carries REPLACE.
 *
 * @param ctx parsed CREATE PROCEDURE statement
 */
@Override
public void addUserProcedure(Create_procedure_stmtContext ctx) {
    FuncNameInfo procedureName = new FuncNameInfo(
            exec.logicalPlanBuilder.visitMultipartIdentifier(ctx.multipartIdentifier()));
    final String name = procedureName.toString();
    if (builtinFunctions.exists(name)) {
        exec.info(ctx, name + " is a built-in function which cannot be redefined.");
        return;
    }
    trace(ctx, "CREATE PROCEDURE " + name);
    saveInCache(name, ctx);
    String source = Exec.getFormattedText(ctx);
    boolean replace = ctx.REPLACE() != null;
    save(procedureName, source, replace);
}
/**
 * Persists a routine's source through the metastore client.
 *
 * <p>The modify time is always the current time. The create time is also the current time,
 * except when {@code isForce} is set and a procedure with the same name, catalog and
 * database already exists, in which case its create time is carried over.
 *
 * @param procedureName name of the routine
 * @param source source text to store
 * @param isForce passed through to the client; also enables create-time preservation
 * @throws RuntimeException wrapping any exception raised while saving
 */
@Override
public void save(FuncNameInfo procedureName, String source, boolean isForce) {
    try {
        final String now = TimeUtils.longToTimeString(Calendar.getInstance().getTimeInMillis());
        String createTime = now;
        if (isForce) {
            // Keep the create time of the procedure being replaced, if there is one.
            PlsqlStoredProcedure existing = client.getPlsqlStoredProcedure(
                    procedureName.getName(),
                    procedureName.getCtlId(),
                    procedureName.getDbId());
            if (existing != null) {
                createTime = existing.getCreateTime();
            }
        }
        // TODO support packageName
        client.addPlsqlStoredProcedure(
                procedureName.getName(),
                procedureName.getCtlId(),
                procedureName.getDbId(),
                "",
                ConnectContext.get().getQualifiedUser(),
                source,
                createTime,
                now,
                isForce);
    } catch (Exception cause) {
        throw new RuntimeException("failed to save procedure", cause);
    }
}
/**
 * Placeholder for storing a parsed routine in the local cache.
 * The body is currently empty, so calling this method has no effect.
 */
private void saveInCache(String name, ParserRuleContext procCtx) {
    // TODO, removeCached needs to be synchronized to all Observer FEs.
    // Even if it is always executed on the Master FE, it still has to deal with Master switching.
    // The cache write below stays disabled until that is addressed:
    // cache.put(qualified(name.toUpperCase()), procCtx);
}
/**
 * Visits the given expression with the executor, then pops and returns the value
 * the executor's stack holds afterwards.
 */
private Var evalPop(ParserRuleContext ctx) {
    exec.visit(ctx);
    Var result = exec.stackPop();
    return result;
}
/**
 * Forwards a trace message to the executor when tracing was enabled at construction time;
 * otherwise does nothing.
 */
private void trace(ParserRuleContext ctx, String message) {
    if (!trace) {
        return;
    }
    exec.trace(ctx, message);
}
/**
 * Parse-tree visitor that records the CREATE FUNCTION and CREATE PROCEDURE statements
 * it encounters. Each visit method stores the context and returns without descending
 * into the statement's children; a later statement of the same kind overwrites an
 * earlier one.
 */
private static class ProcedureVisitor extends PLParserBaseVisitor<Void> {
    /** Last CREATE PROCEDURE statement visited, or null if none. */
    Create_procedure_stmtContext proc;
    /** Last CREATE FUNCTION statement visited, or null if none. */
    Create_function_stmtContext func;

    @Override
    public Void visitCreate_function_stmt(Create_function_stmtContext ctx) {
        this.func = ctx;
        return null;
    }

    @Override
    public Void visitCreate_procedure_stmt(Create_procedure_stmtContext ctx) {
        this.proc = ctx;
        return null;
    }
}
}