// MaxComputeJniScanner.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.maxcompute;
import org.apache.doris.common.jni.JniScanner;
import org.apache.doris.common.jni.vec.ColumnType;
import org.apache.doris.common.maxcompute.MCUtils;
import com.aliyun.odps.Odps;
import com.aliyun.odps.table.configuration.CompressionCodec;
import com.aliyun.odps.table.configuration.ReaderOptions;
import com.aliyun.odps.table.configuration.RestOptions;
import com.aliyun.odps.table.enviroment.Credentials;
import com.aliyun.odps.table.enviroment.EnvironmentSettings;
import com.aliyun.odps.table.read.SplitReader;
import com.aliyun.odps.table.read.TableBatchReadSession;
import com.aliyun.odps.table.read.split.InputSplit;
import com.aliyun.odps.table.read.split.impl.IndexedInputSplit;
import com.aliyun.odps.table.read.split.impl.RowRangeInputSplit;
import com.google.common.base.Strings;
import org.apache.arrow.vector.BaseVariableWidthVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.log4j.Logger;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.time.ZoneId;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/**
 * MaxCompute JniScanner. BE will read data from the scanner object.
 */
public class MaxComputeJniScanner extends JniScanner {
static {
//Set `NullCheckingForGet.NULL_CHECKING_ENABLED` false.
//We will call isNull() before calling getXXX(), so we can set this parameter
// to skip the repeated check of isNull().
System.setProperty("arrow.enable_null_check_for_get", "false");
}
private static final Logger LOG = Logger.getLogger(MaxComputeJniScanner.class);
// 256MB byte budget per scanner batch ��� limits the C++ Block size at the source.
// With large rows (e.g. 585KB/row STRING), batch_size=4096 would create ~2.4GB Blocks.
// The pipeline's AsyncResultWriter queues up to 3 Blocks per instance, and with
// parallel_pipeline_task_num instances, total queue memory = instances * 3 * block_size.
// 256MB keeps queue memory manageable: 5 instances * 3 * 256MB = 3.8GB.
private static final long MAX_BATCH_BYTES = 256 * 1024 * 1024L;
private static final String ACCESS_KEY = "access_key";
private static final String SECRET_KEY = "secret_key";
private static final String ENDPOINT = "endpoint";
private static final String QUOTA = "quota";
private static final String PROJECT = "project";
private static final String TABLE = "table";
private static final String START_OFFSET = "start_offset";
private static final String SPLIT_SIZE = "split_size";
private static final String SESSION_ID = "session_id";
private static final String SCAN_SERIALIZER = "scan_serializer";
private static final String TIME_ZONE = "time_zone";
private static final String CONNECT_TIMEOUT = "connect_timeout";
private static final String READ_TIMEOUT = "read_timeout";
private static final String RETRY_COUNT = "retry_count";
private enum SplitType {
BYTE_SIZE,
ROW_OFFSET
}
private SplitType splitType;
private TableBatchReadSession scan;
public String sessionId;
private String project;
private String table;
private SplitReader<VectorSchemaRoot> currentSplitReader;
private VectorSchemaRoot currentBatch = null;
private int currentBatchRowOffset = 0;
private MaxComputeColumnValue columnValue;
private Map<String, Integer> readColumnsToId;
private long startOffset = -1L;
private long splitSize = -1L;
public EnvironmentSettings settings;
public ZoneId timeZone;
public MaxComputeJniScanner(int batchSize, Map<String, String> params) {
String[] requiredFields = params.get("required_fields").split(",");
String[] types = params.get("columns_types").split("#");
ColumnType[] columnTypes = new ColumnType[types.length];
for (int i = 0; i < types.length; i++) {
columnTypes[i] = ColumnType.parseType(requiredFields[i], types[i]);
}
initTableInfo(columnTypes, requiredFields, batchSize);
if (!Strings.isNullOrEmpty(params.get(START_OFFSET))
&& !Strings.isNullOrEmpty(params.get(SPLIT_SIZE))) {
startOffset = Long.parseLong(params.get(START_OFFSET));
splitSize = Long.parseLong(params.get(SPLIT_SIZE));
if (splitSize == -1) {
splitType = SplitType.BYTE_SIZE;
} else {
splitType = SplitType.ROW_OFFSET;
}
}
String endpoint = Objects.requireNonNull(params.get(ENDPOINT), "required property '" + ENDPOINT + "'.");
String quota = Objects.requireNonNull(params.get(QUOTA), "required property '" + QUOTA + "'.");
String scanSerializer = Objects.requireNonNull(params.get(SCAN_SERIALIZER),
"required property '" + SCAN_SERIALIZER + "'.");
project = Objects.requireNonNull(params.get(PROJECT), "required property '" + PROJECT + "'.");
table = Objects.requireNonNull(params.get(TABLE), "required property '" + TABLE + "'.");
sessionId = Objects.requireNonNull(params.get(SESSION_ID), "required property '" + SESSION_ID + "'.");
String timeZoneName = Objects.requireNonNull(params.get(TIME_ZONE), "required property '" + TIME_ZONE + "'.");
try {
timeZone = ZoneId.of(timeZoneName);
} catch (Exception e) {
LOG.warn(e.getMessage() + " Set timeZoneName = " + timeZoneName + "fail, use systemDefault.");
timeZone = ZoneId.systemDefault();
}
Odps odps = MCUtils.createMcClient(params);
odps.setDefaultProject(project);
odps.setEndpoint(endpoint);
Credentials credentials = Credentials.newBuilder().withAccount(odps.getAccount())
.withAppAccount(odps.getAppAccount()).build();
int connectTimeout = 10; // 10s
if (!Strings.isNullOrEmpty(params.get(CONNECT_TIMEOUT))) {
connectTimeout = Integer.parseInt(params.get(CONNECT_TIMEOUT));
}
int readTimeout = 120; // 120s
if (!Strings.isNullOrEmpty(params.get(READ_TIMEOUT))) {
readTimeout = Integer.parseInt(params.get(READ_TIMEOUT));
}
int retryTimes = 4; // 4 times
if (!Strings.isNullOrEmpty(params.get(RETRY_COUNT))) {
retryTimes = Integer.parseInt(params.get(RETRY_COUNT));
}
RestOptions restOptions = RestOptions.newBuilder()
.withConnectTimeout(connectTimeout)
.withReadTimeout(readTimeout)
.withRetryTimes(retryTimes).build();
settings = EnvironmentSettings.newBuilder()
.withCredentials(credentials)
.withServiceEndpoint(odps.getEndpoint())
.withQuotaName(quota)
.withRestOptions(restOptions)
.build();
try {
scan = (TableBatchReadSession) deserialize(scanSerializer);
} catch (Exception e) {
String errorMsg = "Failed to deserialize table batch read session.";
LOG.warn(errorMsg, e);
throw new IllegalArgumentException(errorMsg, e);
}
}
@Override
protected void initTableInfo(ColumnType[] requiredTypes, String[] requiredFields, int batchSize) {
super.initTableInfo(requiredTypes, requiredFields, batchSize);
readColumnsToId = new HashMap<>();
for (int i = 0; i < fields.length; i++) {
if (!Strings.isNullOrEmpty(fields[i])) {
readColumnsToId.put(fields[i], i);
}
}
}
@Override
public void open() throws IOException {
try {
InputSplit split;
if (splitType == SplitType.BYTE_SIZE) {
split = new IndexedInputSplit(sessionId, (int) startOffset);
} else {
split = new RowRangeInputSplit(sessionId, startOffset, splitSize);
}
currentSplitReader = scan.createArrowReader(split, ReaderOptions.newBuilder().withSettings(settings)
.withCompressionCodec(CompressionCodec.ZSTD)
.withReuseBatch(true)
.build());
} catch (Exception e) {
String errorMsg = "MaxComputeJniScanner Failed to open table batch read session.";
LOG.warn(errorMsg, e);
close();
throw new IOException(errorMsg, e);
}
}
@Override
public void close() throws IOException {
if (currentSplitReader != null) {
try {
currentSplitReader.close();
} catch (Exception e) {
LOG.warn("Failed to close MaxCompute split reader for table " + project + "." + table, e);
}
}
startOffset = -1;
splitSize = -1;
currentBatch = null;
currentBatchRowOffset = 0;
currentSplitReader = null;
settings = null;
scan = null;
readColumnsToId.clear();
}
@Override
protected int getNext() throws IOException {
if (currentSplitReader == null) {
return 0;
}
columnValue = new MaxComputeColumnValue();
columnValue.setTimeZone(timeZone);
int expectedRows = batchSize;
return readVectors(expectedRows);
}
private VectorSchemaRoot getNextBatch() throws IOException {
try {
if (!currentSplitReader.hasNext()) {
currentSplitReader.close();
currentSplitReader = null;
return null;
}
return currentSplitReader.get();
} catch (Exception e) {
String errorMsg = "MaxComputeJniScanner readVectors get batch fail";
LOG.warn(errorMsg, e);
throw new IOException(e.getMessage(), e);
}
}
private int readVectors(int expectedRows) throws IOException {
int curReadRows = 0;
long accumulatedBytes = 0;
while (curReadRows < expectedRows) {
// Stop early if accumulated variable-width bytes approach int32 limit
if (accumulatedBytes >= MAX_BATCH_BYTES) {
break;
}
if (currentBatch == null) {
currentBatch = getNextBatch();
if (currentBatch == null || currentBatch.getRowCount() == 0) {
currentBatch = null;
break;
}
currentBatchRowOffset = 0;
}
try {
int rowsToAppend = Math.min(expectedRows - curReadRows,
currentBatch.getRowCount() - currentBatchRowOffset);
List<FieldVector> fieldVectors = currentBatch.getFieldVectors();
// Limit rows to avoid int32 overflow in VectorColumn's String byte buffer
rowsToAppend = limitRowsByVarWidthBytes(
fieldVectors, currentBatchRowOffset, rowsToAppend,
MAX_BATCH_BYTES - accumulatedBytes);
if (rowsToAppend <= 0) {
break;
}
long startTime = System.nanoTime();
for (FieldVector column : fieldVectors) {
Integer readColumnId = readColumnsToId.get(column.getName());
if (readColumnId == null) {
continue;
}
columnValue.reset(column);
for (int j = currentBatchRowOffset; j < currentBatchRowOffset + rowsToAppend; j++) {
columnValue.setColumnIdx(j);
appendData(readColumnId, columnValue);
}
}
appendDataTime += System.nanoTime() - startTime;
// Track bytes for the rows just appended
accumulatedBytes += estimateVarWidthBytes(
fieldVectors, currentBatchRowOffset, rowsToAppend);
currentBatchRowOffset += rowsToAppend;
curReadRows += rowsToAppend;
if (currentBatchRowOffset >= currentBatch.getRowCount()) {
currentBatch = null;
currentBatchRowOffset = 0;
}
} catch (Exception e) {
String errorMsg = String.format("MaxComputeJniScanner Fail to read arrow data. "
+ "curReadRows = {}, expectedRows = {}", curReadRows, expectedRows);
LOG.warn(errorMsg, e);
throw new RuntimeException(errorMsg, e);
}
}
if (LOG.isDebugEnabled() && curReadRows > 0 && curReadRows < expectedRows) {
LOG.debug("readVectors: returning " + curReadRows + " rows (limited by byte budget)"
+ ", totalVarWidthBytes=" + accumulatedBytes
+ ", expectedRows=" + expectedRows);
}
return curReadRows;
}
/**
* Limit the number of rows to append so that no single variable-width column
* exceeds the remaining byte budget. This prevents int32 overflow in
* VectorColumn's appendIndex for String/Binary child byte arrays.
*
* Uses Arrow's offset buffer for O(1)-per-row byte size calculation ���
* no data copying involved.
*/
private int limitRowsByVarWidthBytes(List<FieldVector> fieldVectors,
int offset, int maxRows, long remainingBudget) {
if (remainingBudget <= 0) {
return 0;
}
int safeRows = maxRows;
for (FieldVector fv : fieldVectors) {
if (fv instanceof BaseVariableWidthVector) {
BaseVariableWidthVector vec = (BaseVariableWidthVector) fv;
// Find how many rows fit within the budget for THIS column
int rows = findMaxRowsWithinBudget(vec, offset, maxRows, remainingBudget);
safeRows = Math.min(safeRows, rows);
}
}
// Always allow at least 1 row to make progress, even if it exceeds budget
return Math.max(1, safeRows);
}
/**
* Binary search for the maximum number of rows starting at 'offset'
* whose total bytes in the variable-width vector fit within 'budget'.
*/
private int findMaxRowsWithinBudget(BaseVariableWidthVector vec,
int offset, int maxRows, long budget) {
if (maxRows <= 0) {
return 0;
}
// Total bytes for all maxRows
long totalBytes = (long) vec.getOffsetBuffer().getInt((long) (offset + maxRows) * 4)
- (long) vec.getOffsetBuffer().getInt((long) offset * 4);
if (totalBytes <= budget) {
return maxRows;
}
// Binary search for the cutoff point
int lo = 1;
int hi = maxRows - 1;
int startOff = vec.getOffsetBuffer().getInt((long) offset * 4);
while (lo <= hi) {
int mid = lo + (hi - lo) / 2;
long bytes = (long) vec.getOffsetBuffer().getInt((long) (offset + mid) * 4) - startOff;
if (bytes <= budget) {
lo = mid + 1;
} else {
hi = mid - 1;
}
}
// 'hi' is the largest count whose bytes <= budget (could be 0)
return hi;
}
/**
* Estimate total variable-width bytes for the given row range across all columns.
* Returns the max bytes of any single column (since each column has its own
* VectorColumn child buffer and the overflow is per-column).
*/
private long estimateVarWidthBytes(List<FieldVector> fieldVectors,
int offset, int rows) {
long maxColumnBytes = 0;
for (FieldVector fv : fieldVectors) {
if (fv instanceof BaseVariableWidthVector) {
BaseVariableWidthVector vec = (BaseVariableWidthVector) fv;
long bytes = (long) vec.getOffsetBuffer().getInt((long) (offset + rows) * 4)
- (long) vec.getOffsetBuffer().getInt((long) offset * 4);
maxColumnBytes = Math.max(maxColumnBytes, bytes);
}
}
return maxColumnBytes;
}
private static Object deserialize(String serializedString) throws IOException, ClassNotFoundException {
byte[] serializedBytes = Base64.getDecoder().decode(serializedString);
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(serializedBytes);
ObjectInputStream objectInputStream = new ObjectInputStream(byteArrayInputStream);
return objectInputStream.readObject();
}
}