ScannerLoader.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.common.classloader;
import org.apache.doris.common.jni.utils.ExpiringMap;
import org.apache.doris.common.jni.utils.Log4jOutputStream;
import org.apache.doris.common.jni.utils.UdfClassCache;
import com.google.common.collect.Streams;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.io.UncheckedIOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
/**
* BE will load scanners by JNI call, and then the JniConnector on BE will get scanner class by getLoadedClass.
*/
public class ScannerLoader {
public static final Logger LOG = Logger.getLogger(ScannerLoader.class);
private static final Map<String, Class<?>> loadedClasses = new HashMap<>();
private static final ExpiringMap<String, UdfClassCache> udfLoadedClasses = new ExpiringMap<>();
private static final String CLASS_SUFFIX = ".class";
private static final String LOAD_PACKAGE = "org.apache.doris";
/**
* Load all classes from $DORIS_HOME/lib/java_extensions/*
*/
public void loadAllScannerJars() {
redirectStdStreamsToLog4j();
String basePath = System.getenv("DORIS_HOME");
File library = new File(basePath, "/lib/java_extensions/");
// TODO: add thread pool to load each scanner
listFiles(library).stream().filter(File::isDirectory).forEach(sd -> {
JniScannerClassLoader classLoader = new JniScannerClassLoader(sd.getName(), buildClassPath(sd),
this.getClass().getClassLoader());
try (ThreadClassLoaderContext ignored = new ThreadClassLoaderContext(classLoader)) {
loadJarClassFromDir(sd, classLoader);
}
});
}
private void redirectStdStreamsToLog4j() {
Logger outLogger = Logger.getLogger("stdout");
PrintStream logPrintStream = new PrintStream(new Log4jOutputStream(outLogger, Level.INFO));
System.setOut(logPrintStream);
Logger errLogger = Logger.getLogger("stderr");
PrintStream errorPrintStream = new PrintStream(new Log4jOutputStream(errLogger, Level.ERROR));
System.setErr(errorPrintStream);
}
public static UdfClassCache getUdfClassLoader(String functionSignature) {
return udfLoadedClasses.get(functionSignature);
}
public static synchronized void cacheClassLoader(String functionSignature, UdfClassCache classCache,
long expirationTime) {
LOG.info("Cache UDF for: " + functionSignature);
udfLoadedClasses.put(functionSignature, classCache, expirationTime * 60 * 1000L);
}
public synchronized void cleanUdfClassLoader(String functionSignature) {
LOG.info("cleanUdfClassLoader for: " + functionSignature);
udfLoadedClasses.remove(functionSignature);
}
/**
* Get loaded class for JNI scanners
* @param className JNI scanner class name
* @return scanner class object
* @throws ClassNotFoundException JNI scanner class not found
*/
public Class<?> getLoadedClass(String className) throws ClassNotFoundException {
String loadedClassName = getPackagePathName(className);
if (loadedClasses.containsKey(loadedClassName)) {
return loadedClasses.get(loadedClassName);
} else {
throw new ClassNotFoundException("JNI scanner has not been loaded or no such class: " + className);
}
}
private static List<URL> buildClassPath(File path) {
return listFiles(path).stream()
.map(ScannerLoader::classFileUrl)
.collect(Collectors.toList());
}
private static URL classFileUrl(File file) {
try {
return file.toURI().toURL();
} catch (MalformedURLException e) {
throw new UncheckedIOException(e);
}
}
public static List<File> listFiles(File library) {
try (DirectoryStream<Path> directoryStream = Files.newDirectoryStream(library.toPath())) {
return Streams.stream(directoryStream)
.map(Path::toFile)
.sorted()
.collect(Collectors.toList());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public static void loadJarClassFromDir(File dir, JniScannerClassLoader classLoader) {
listFiles(dir).forEach(file -> {
Enumeration<JarEntry> entryEnumeration;
List<String> loadClassNames = new ArrayList<>();
try {
try (JarFile jar = new JarFile(file)) {
entryEnumeration = jar.entries();
while (entryEnumeration.hasMoreElements()) {
JarEntry entry = entryEnumeration.nextElement();
String className = entry.getName();
if (!className.endsWith(CLASS_SUFFIX)) {
continue;
}
className = className.substring(0, className.length() - CLASS_SUFFIX.length());
String packageClassName = getPackagePathName(className);
if (needToLoad(packageClassName)) {
loadClassNames.add(packageClassName);
}
}
}
for (String className : loadClassNames) {
loadedClasses.putIfAbsent(className, classLoader.loadClass(className));
}
} catch (Exception e) {
throw new RuntimeException(e.getMessage(), e);
}
});
}
private static String getPackagePathName(String className) {
return className.replace("/", ".");
}
private static boolean needToLoad(String className) {
return className.contains(LOAD_PACKAGE) && !className.contains("$");
}
}