MaxComputeJniScanner.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.maxcompute;

import org.apache.doris.common.jni.JniScanner;
import org.apache.doris.common.jni.vec.ColumnType;
import org.apache.doris.common.maxcompute.MCUtils;

import com.aliyun.odps.Odps;
import com.aliyun.odps.table.configuration.CompressionCodec;
import com.aliyun.odps.table.configuration.ReaderOptions;
import com.aliyun.odps.table.configuration.RestOptions;
import com.aliyun.odps.table.enviroment.Credentials;
import com.aliyun.odps.table.enviroment.EnvironmentSettings;
import com.aliyun.odps.table.read.SplitReader;
import com.aliyun.odps.table.read.TableBatchReadSession;
import com.aliyun.odps.table.read.split.InputSplit;
import com.aliyun.odps.table.read.split.impl.IndexedInputSplit;
import com.aliyun.odps.table.read.split.impl.RowRangeInputSplit;
import com.google.common.base.Strings;
import org.apache.arrow.vector.BaseVariableWidthVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.log4j.Logger;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.time.ZoneId;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * MaxCompute JniScanner. BE will read data from the scanner object.
 */
public class MaxComputeJniScanner extends JniScanner {
    static {
        // Set `NullCheckingForGet.NULL_CHECKING_ENABLED` false.
        // We always call isNull() before calling getXXX(), so Arrow's internal
        // per-get null check would be redundant work; this property skips it.
        System.setProperty("arrow.enable_null_check_for_get", "false");
    }

    private static final Logger LOG = Logger.getLogger(MaxComputeJniScanner.class);

    // 256MB byte budget per scanner batch -- limits the C++ Block size at the source.
    // With large rows (e.g. 585KB/row STRING), batch_size=4096 would create ~2.4GB Blocks.
    // The pipeline's AsyncResultWriter queues up to 3 Blocks per instance, and with
    // parallel_pipeline_task_num instances, total queue memory = instances * 3 * block_size.
    // 256MB keeps queue memory manageable: 5 instances * 3 * 256MB = 3.8GB.
    private static final long MAX_BATCH_BYTES = 256 * 1024 * 1024L;

    // Keys of the scanner params map supplied by the BE.
    private static final String ACCESS_KEY = "access_key";
    private static final String SECRET_KEY = "secret_key";
    private static final String ENDPOINT = "endpoint";
    private static final String QUOTA = "quota";
    private static final String PROJECT = "project";
    private static final String TABLE = "table";

    private static final String START_OFFSET = "start_offset";
    private static final String SPLIT_SIZE = "split_size";
    private static final String SESSION_ID = "session_id";
    private static final String SCAN_SERIALIZER = "scan_serializer";
    private static final String TIME_ZONE = "time_zone";

    private static final String CONNECT_TIMEOUT = "connect_timeout";
    private static final String READ_TIMEOUT = "read_timeout";
    private static final String RETRY_COUNT = "retry_count";

    // How the FE expressed the split boundary: by byte-size split index or by row range.
    private enum SplitType {
        BYTE_SIZE,
        ROW_OFFSET
    }

    private SplitType splitType;
    private TableBatchReadSession scan;
    public String sessionId;

    private String project;
    private String table;

    // Reader over the single split assigned to this scanner; null once exhausted.
    private SplitReader<VectorSchemaRoot> currentSplitReader;
    // Current Arrow batch being consumed, and the next row to read within it.
    private VectorSchemaRoot currentBatch = null;
    private int currentBatchRowOffset = 0;
    private MaxComputeColumnValue columnValue;

    // Maps a required column name to its slot index in the output vectors.
    private Map<String, Integer> readColumnsToId;

    private long startOffset = -1L;
    private long splitSize = -1L;
    public EnvironmentSettings settings;
    public ZoneId timeZone;

    /**
     * Builds a scanner from BE-supplied params: parses the required columns and
     * types, the split boundaries, credentials/endpoint settings, and the
     * serialized {@link TableBatchReadSession}.
     *
     * @param batchSize max rows per getNext() call requested by the BE
     * @param params    key/value properties from the BE; see the *_KEY constants
     * @throws NullPointerException     if a required property is missing
     * @throws IllegalArgumentException if the read session cannot be deserialized
     */
    public MaxComputeJniScanner(int batchSize, Map<String, String> params) {
        String[] requiredFields = params.get("required_fields").split(",");
        String[] types = params.get("columns_types").split("#");
        ColumnType[] columnTypes = new ColumnType[types.length];
        for (int i = 0; i < types.length; i++) {
            columnTypes[i] = ColumnType.parseType(requiredFields[i], types[i]);
        }
        initTableInfo(columnTypes, requiredFields, batchSize);

        if (!Strings.isNullOrEmpty(params.get(START_OFFSET))
                && !Strings.isNullOrEmpty(params.get(SPLIT_SIZE))) {
            startOffset = Long.parseLong(params.get(START_OFFSET));
            splitSize = Long.parseLong(params.get(SPLIT_SIZE));
            // splitSize == -1 is the FE's sentinel for a byte-size (indexed) split,
            // in which case startOffset carries the split index instead of a row offset.
            if (splitSize == -1) {
                splitType = SplitType.BYTE_SIZE;
            } else {
                splitType = SplitType.ROW_OFFSET;
            }
        }

        String endpoint = Objects.requireNonNull(params.get(ENDPOINT), "required property '" + ENDPOINT + "'.");
        String quota = Objects.requireNonNull(params.get(QUOTA), "required property '" + QUOTA + "'.");
        String scanSerializer = Objects.requireNonNull(params.get(SCAN_SERIALIZER),
                "required property '" + SCAN_SERIALIZER + "'.");
        project = Objects.requireNonNull(params.get(PROJECT), "required property '" + PROJECT + "'.");
        table = Objects.requireNonNull(params.get(TABLE), "required property '" + TABLE + "'.");
        sessionId = Objects.requireNonNull(params.get(SESSION_ID), "required property '" + SESSION_ID + "'.");
        String timeZoneName = Objects.requireNonNull(params.get(TIME_ZONE), "required property '" + TIME_ZONE + "'.");
        try {
            timeZone = ZoneId.of(timeZoneName);
        } catch (Exception e) {
            // Bad zone id from the BE is not fatal; fall back to the system default.
            LOG.warn(e.getMessage() + " Set timeZoneName = " + timeZoneName + " fail, use systemDefault.");
            timeZone = ZoneId.systemDefault();
        }

        Odps odps = MCUtils.createMcClient(params);
        odps.setDefaultProject(project);
        odps.setEndpoint(endpoint);

        Credentials credentials = Credentials.newBuilder().withAccount(odps.getAccount())
                .withAppAccount(odps.getAppAccount()).build();

        // Tunneling options, with defaults matching the FE side. Units: seconds.
        int connectTimeout = 10; // 10s
        if (!Strings.isNullOrEmpty(params.get(CONNECT_TIMEOUT))) {
            connectTimeout = Integer.parseInt(params.get(CONNECT_TIMEOUT));
        }

        int readTimeout = 120; // 120s
        if (!Strings.isNullOrEmpty(params.get(READ_TIMEOUT))) {
            readTimeout = Integer.parseInt(params.get(READ_TIMEOUT));
        }

        int retryTimes = 4; // 4 times
        if (!Strings.isNullOrEmpty(params.get(RETRY_COUNT))) {
            retryTimes = Integer.parseInt(params.get(RETRY_COUNT));
        }

        RestOptions restOptions = RestOptions.newBuilder()
                .withConnectTimeout(connectTimeout)
                .withReadTimeout(readTimeout)
                .withRetryTimes(retryTimes).build();

        settings = EnvironmentSettings.newBuilder()
                .withCredentials(credentials)
                .withServiceEndpoint(odps.getEndpoint())
                .withQuotaName(quota)
                .withRestOptions(restOptions)
                .build();

        try {
            // The FE serializes the read session (Java serialization + Base64)
            // and ships it to the BE; rebuild it here.
            scan = (TableBatchReadSession) deserialize(scanSerializer);
        } catch (Exception e) {
            String errorMsg = "Failed to deserialize table batch read session.";
            LOG.warn(errorMsg, e);
            throw new IllegalArgumentException(errorMsg, e);
        }
    }


    /**
     * Records the required columns and builds the column-name -&gt; slot-index map
     * used when appending Arrow data to the output vectors.
     */
    @Override
    protected void initTableInfo(ColumnType[] requiredTypes, String[] requiredFields, int batchSize) {
        super.initTableInfo(requiredTypes, requiredFields, batchSize);
        readColumnsToId = new HashMap<>();
        for (int i = 0; i < fields.length; i++) {
            if (!Strings.isNullOrEmpty(fields[i])) {
                readColumnsToId.put(fields[i], i);
            }
        }
    }

    /**
     * Opens an Arrow reader over this scanner's split. On failure, resources are
     * released via {@link #close()} before the IOException is propagated.
     */
    @Override
    public void open() throws IOException {
        try {
            InputSplit split;
            if (splitType == SplitType.BYTE_SIZE) {
                // For byte-size splits, startOffset holds the split index.
                split = new IndexedInputSplit(sessionId, (int) startOffset);
            } else {
                split = new RowRangeInputSplit(sessionId, startOffset, splitSize);
            }

            currentSplitReader = scan.createArrowReader(split, ReaderOptions.newBuilder().withSettings(settings)
                    .withCompressionCodec(CompressionCodec.ZSTD)
                    .withReuseBatch(true)
                    .build());

        } catch (Exception e) {
            String errorMsg = "MaxComputeJniScanner Failed to open table batch read session.";
            LOG.warn(errorMsg, e);
            close();
            throw new IOException(errorMsg, e);
        }
    }

    /**
     * Releases the split reader (best-effort) and resets all per-scan state so
     * the object cannot accidentally serve stale data. Safe to call repeatedly.
     */
    @Override
    public void close() throws IOException {
        if (currentSplitReader != null) {
            try {
                currentSplitReader.close();
            } catch (Exception e) {
                // Best-effort close; the scan is over either way.
                LOG.warn("Failed to close MaxCompute split reader for table " + project + "." + table, e);
            }
        }
        startOffset = -1;
        splitSize = -1;
        currentBatch = null;
        currentBatchRowOffset = 0;
        currentSplitReader = null;
        settings = null;
        scan = null;
        if (readColumnsToId != null) {
            readColumnsToId.clear();
        }
    }

    /**
     * Reads up to batchSize rows into the output vectors.
     *
     * @return number of rows appended; 0 means end of split
     */
    @Override
    protected int getNext() throws IOException {
        if (currentSplitReader == null) {
            return 0;
        }
        columnValue = new MaxComputeColumnValue();
        columnValue.setTimeZone(timeZone);
        int expectedRows = batchSize;
        return readVectors(expectedRows);
    }

    /**
     * Fetches the next Arrow batch from the split reader, or null when the split
     * is exhausted (the reader is closed and cleared in that case).
     */
    private VectorSchemaRoot getNextBatch() throws IOException {
        try {
            if (!currentSplitReader.hasNext()) {
                currentSplitReader.close();
                currentSplitReader = null;
                return null;
            }
            return currentSplitReader.get();
        } catch (Exception e) {
            String errorMsg = "MaxComputeJniScanner readVectors get batch fail";
            LOG.warn(errorMsg, e);
            throw new IOException(errorMsg, e);
        }
    }

    /**
     * Copies rows from Arrow batches into the output vectors, stopping at
     * expectedRows, end-of-split, or the MAX_BATCH_BYTES variable-width budget.
     * A partially consumed Arrow batch is resumed on the next call via
     * currentBatch/currentBatchRowOffset.
     */
    private int readVectors(int expectedRows) throws IOException {
        int curReadRows = 0;
        long accumulatedBytes = 0;
        while (curReadRows < expectedRows) {
            // Stop early if accumulated variable-width bytes approach int32 limit
            if (accumulatedBytes >= MAX_BATCH_BYTES) {
                break;
            }
            if (currentBatch == null) {
                currentBatch = getNextBatch();
                if (currentBatch == null || currentBatch.getRowCount() == 0) {
                    currentBatch = null;
                    break;
                }
                currentBatchRowOffset = 0;
            }
            try {
                int rowsToAppend = Math.min(expectedRows - curReadRows,
                        currentBatch.getRowCount() - currentBatchRowOffset);
                List<FieldVector> fieldVectors = currentBatch.getFieldVectors();

                // Limit rows to avoid int32 overflow in VectorColumn's String byte buffer
                rowsToAppend = limitRowsByVarWidthBytes(
                        fieldVectors, currentBatchRowOffset, rowsToAppend,
                        MAX_BATCH_BYTES - accumulatedBytes);
                if (rowsToAppend <= 0) {
                    break;
                }

                long startTime = System.nanoTime();
                for (FieldVector column : fieldVectors) {
                    Integer readColumnId = readColumnsToId.get(column.getName());
                    if (readColumnId == null) {
                        // Column present in the Arrow batch but not required by BE.
                        continue;
                    }
                    columnValue.reset(column);
                    for (int j = currentBatchRowOffset; j < currentBatchRowOffset + rowsToAppend; j++) {
                        columnValue.setColumnIdx(j);
                        appendData(readColumnId, columnValue);
                    }
                }
                appendDataTime += System.nanoTime() - startTime;

                // Track bytes for the rows just appended
                accumulatedBytes += estimateVarWidthBytes(
                        fieldVectors, currentBatchRowOffset, rowsToAppend);

                currentBatchRowOffset += rowsToAppend;
                curReadRows += rowsToAppend;
                if (currentBatchRowOffset >= currentBatch.getRowCount()) {
                    currentBatch = null;
                    currentBatchRowOffset = 0;
                }
            } catch (Exception e) {
                // BUGFIX: String.format uses %d placeholders; the previous SLF4J-style
                // {} markers were emitted literally and the arguments were dropped.
                String errorMsg = String.format("MaxComputeJniScanner Fail to read arrow data. "
                        + "curReadRows = %d, expectedRows = %d", curReadRows, expectedRows);
                LOG.warn(errorMsg, e);
                throw new RuntimeException(errorMsg, e);
            }
        }
        if (LOG.isDebugEnabled() && curReadRows > 0 && curReadRows < expectedRows) {
            LOG.debug("readVectors: returning " + curReadRows + " rows (limited by byte budget)"
                    + ", totalVarWidthBytes=" + accumulatedBytes
                    + ", expectedRows=" + expectedRows);
        }
        return curReadRows;
    }

    /**
     * Limit the number of rows to append so that no single variable-width column
     * exceeds the remaining byte budget. This prevents int32 overflow in
     * VectorColumn's appendIndex for String/Binary child byte arrays.
     *
     * Uses Arrow's offset buffer for O(1)-per-row byte size calculation --
     * no data copying involved.
     */
    private int limitRowsByVarWidthBytes(List<FieldVector> fieldVectors,
            int offset, int maxRows, long remainingBudget) {
        if (remainingBudget <= 0) {
            return 0;
        }
        int safeRows = maxRows;
        for (FieldVector fv : fieldVectors) {
            if (fv instanceof BaseVariableWidthVector) {
                BaseVariableWidthVector vec = (BaseVariableWidthVector) fv;
                // Find how many rows fit within the budget for THIS column
                int rows = findMaxRowsWithinBudget(vec, offset, maxRows, remainingBudget);
                safeRows = Math.min(safeRows, rows);
            }
        }
        // Always allow at least 1 row to make progress, even if it exceeds budget
        return Math.max(1, safeRows);
    }

    /**
     * Binary search for the maximum number of rows starting at 'offset'
     * whose total bytes in the variable-width vector fit within 'budget'.
     * Byte sizes come from the 4-byte Arrow offset buffer, so each probe is O(1).
     */
    private int findMaxRowsWithinBudget(BaseVariableWidthVector vec,
            int offset, int maxRows, long budget) {
        if (maxRows <= 0) {
            return 0;
        }
        // Total bytes for all maxRows
        long totalBytes = (long) vec.getOffsetBuffer().getInt((long) (offset + maxRows) * 4)
                - (long) vec.getOffsetBuffer().getInt((long) offset * 4);
        if (totalBytes <= budget) {
            return maxRows;
        }
        // Binary search for the cutoff point
        int lo = 1;
        int hi = maxRows - 1;
        int startOff = vec.getOffsetBuffer().getInt((long) offset * 4);
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2;
            long bytes = (long) vec.getOffsetBuffer().getInt((long) (offset + mid) * 4) - startOff;
            if (bytes <= budget) {
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        // 'hi' is the largest count whose bytes <= budget (could be 0)
        return hi;
    }

    /**
     * Estimate total variable-width bytes for the given row range across all columns.
     * Returns the max bytes of any single column (since each column has its own
     * VectorColumn child buffer and the overflow is per-column).
     */
    private long estimateVarWidthBytes(List<FieldVector> fieldVectors,
            int offset, int rows) {
        long maxColumnBytes = 0;
        for (FieldVector fv : fieldVectors) {
            if (fv instanceof BaseVariableWidthVector) {
                BaseVariableWidthVector vec = (BaseVariableWidthVector) fv;
                long bytes = (long) vec.getOffsetBuffer().getInt((long) (offset + rows) * 4)
                        - (long) vec.getOffsetBuffer().getInt((long) offset * 4);
                maxColumnBytes = Math.max(maxColumnBytes, bytes);
            }
        }
        return maxColumnBytes;
    }

    /**
     * Decodes a Base64 string and Java-deserializes the contained object.
     * The stream is closed via try-with-resources even when readObject fails.
     * NOTE(review): Java-native deserialization is only safe here because the
     * payload originates from the trusted FE, not from external input.
     */
    private static Object deserialize(String serializedString) throws IOException, ClassNotFoundException {
        byte[] serializedBytes = Base64.getDecoder().decode(serializedString);
        try (ObjectInputStream objectInputStream =
                new ObjectInputStream(new ByteArrayInputStream(serializedBytes))) {
            return objectInputStream.readObject();
        }
    }
}