// MaxComputeJniWriter.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.maxcompute;

import org.apache.doris.common.jni.JniWriter;
import org.apache.doris.common.jni.vec.VectorColumn;
import org.apache.doris.common.jni.vec.VectorTable;
import org.apache.doris.common.maxcompute.MCUtils;

import com.aliyun.odps.Odps;
import com.aliyun.odps.OdpsType;
import com.aliyun.odps.table.configuration.ArrowOptions;
import com.aliyun.odps.table.configuration.ArrowOptions.TimestampUnit;
import com.aliyun.odps.table.configuration.CompressionCodec;
import com.aliyun.odps.table.configuration.RestOptions;
import com.aliyun.odps.table.configuration.WriterOptions;
import com.aliyun.odps.table.enviroment.Credentials;
import com.aliyun.odps.table.enviroment.EnvironmentSettings;
import com.aliyun.odps.table.write.BatchWriter;
import com.aliyun.odps.table.write.TableBatchWriteSession;
import com.aliyun.odps.table.write.TableWriteSessionBuilder;
import com.aliyun.odps.table.write.WriterAttemptId;
import com.aliyun.odps.table.write.WriterCommitMessage;
import com.aliyun.odps.type.TypeInfo;
import com.google.common.base.Strings;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BaseFixedWidthVector;
import org.apache.arrow.vector.BaseVariableWidthVector;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DateDayVector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TimeStampMilliVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.log4j.Logger;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * MaxComputeJniWriter writes C++ Block data to MaxCompute tables via Storage API (Arrow).
 * Loaded by C++ as: org/apache/doris/maxcompute/MaxComputeJniWriter
 */
public class MaxComputeJniWriter extends JniWriter {
    private static final Logger LOG = Logger.getLogger(MaxComputeJniWriter.class);

    // Keys of the parameter map handed over from the BE (C++) side.
    private static final String ACCESS_KEY = "access_key";
    private static final String SECRET_KEY = "secret_key";
    private static final String ENDPOINT = "endpoint";
    private static final String QUOTA = "quota";
    private static final String PROJECT = "project";
    private static final String TABLE = "table";
    private static final String WRITE_SESSION_ID = "write_session_id";
    private static final String BLOCK_ID = "block_id";
    private static final String PARTITION_SPEC = "partition_spec";
    private static final String CONNECT_TIMEOUT = "connect_timeout";
    private static final String READ_TIMEOUT = "read_timeout";
    private static final String RETRY_COUNT = "retry_count";
    private static final String MAX_WRITE_BATCH_ROWS = "max_write_batch_rows";

    private final Map<String, String> params;

    // 128MB batch threshold -- controls peak Arrow native memory per batch.
    // Arrow uses sun.misc.Unsafe.allocateMemory() which is invisible to JVM metrics.
    // Each batch temporarily holds ~batchDataBytes of native memory.
    // With 3 concurrent writers, total Arrow native = 3 * 128MB = ~384MB.
    // Using 1GB was too large: 3 writers * 1GB = 3GB invisible native memory.
    private static final long MAX_ARROW_BATCH_BYTES = 128 * 1024 * 1024L;

    // Segmented commit: commit and recreate batchWriter every N rows to prevent
    // MaxCompute SDK native memory accumulation. Without this, the SDK buffers
    // all written data internally, causing process RSS to grow linearly with
    // total data volume until SIGSEGV.
    private static final long ROWS_PER_SEGMENT = 5000;

    // Connection / session identity, parsed once in the constructor.
    private final String endpoint;
    private final String project;
    private final String tableName;
    private final String quota;
    private String writeSessionId;
    private long blockId;
    private long nextBlockId; // For creating new segments with unique blockIds
    private String partitionSpec;
    private int connectTimeout;
    private int readTimeout;
    private int retryCount;
    private int maxWriteBatchRows;

    // Storage API objects
    private TableBatchWriteSession writeSession;
    private BatchWriter<VectorSchemaRoot> batchWriter;
    private BufferAllocator allocator;
    private List<TypeInfo> columnTypeInfos;
    private List<String> columnNames;
    // Collect commit messages from all segments (each batchWriter commit produces one)
    private final List<WriterCommitMessage> commitMessages = new java.util.ArrayList<>();

    // Per-segment row counter (resets after each segment commit)
    private long segmentRows = 0;

    // Writer options cached for creating new batchWriters
    private WriterOptions writerOptions;

    // Statistics
    private long writtenRows = 0;
    private long writtenBytes = 0;

    /**
     * Parses the BE-supplied parameter map into typed fields. Mandatory keys
     * (endpoint, project, table, write_session_id) fail fast with a
     * NullPointerException; all other keys fall back to documented defaults.
     */
    public MaxComputeJniWriter(int batchSize, Map<String, String> params) {
        super(batchSize, params);
        this.params = params;
        // Mandatory identity of the target write session.
        this.endpoint = requireParam(params, ENDPOINT);
        this.project = requireParam(params, PROJECT);
        this.tableName = requireParam(params, TABLE);
        this.writeSessionId = requireParam(params, WRITE_SESSION_ID);
        // Optional session attributes with defaults.
        this.quota = params.getOrDefault(QUOTA, "");
        this.partitionSpec = params.getOrDefault(PARTITION_SPEC, "");
        this.blockId = Long.parseLong(params.getOrDefault(BLOCK_ID, "0"));
        // Reserve blockId for the first writer; segment rotations draw fresh ids from here.
        this.nextBlockId = this.blockId + 1;
        // Network tuning knobs (seconds / counts).
        this.connectTimeout = Integer.parseInt(params.getOrDefault(CONNECT_TIMEOUT, "10"));
        this.readTimeout = Integer.parseInt(params.getOrDefault(READ_TIMEOUT, "120"));
        this.retryCount = Integer.parseInt(params.getOrDefault(RETRY_COUNT, "4"));
        this.maxWriteBatchRows = Integer.parseInt(params.getOrDefault(MAX_WRITE_BATCH_ROWS, "4096"));
    }

    /** Fetches a mandatory parameter, failing with the standard "required property" message. */
    private static String requireParam(Map<String, String> p, String key) {
        return Objects.requireNonNull(p.get(key), "required property '" + key + "'.");
    }

    /**
     * Restores the FE-created write session and prepares the first Arrow
     * batch writer. Called once by the JNI framework before any write.
     *
     * @throws IOException if the session cannot be restored or the writer
     *         cannot be created (original cause is preserved).
     */
    @Override
    public void open() throws IOException {
        try {
            // Build the ODPS client from the raw parameter map (credentials etc.).
            Odps odps = MCUtils.createMcClient(params);
            odps.setDefaultProject(project);
            odps.setEndpoint(endpoint);

            Credentials credentials = Credentials.newBuilder().withAccount(odps.getAccount())
                    .withAppAccount(odps.getAppAccount()).build();

            // Timeouts are in the unit expected by RestOptions (seconds per defaults above).
            RestOptions restOptions = RestOptions.newBuilder()
                    .withConnectTimeout(connectTimeout)
                    .withReadTimeout(readTimeout)
                    .withRetryTimes(retryCount).build();

            EnvironmentSettings settings = EnvironmentSettings.newBuilder()
                    .withCredentials(credentials)
                    .withServiceEndpoint(odps.getEndpoint())
                    .withQuotaName(Strings.isNullOrEmpty(quota) ? null : quota)
                    .withRestOptions(restOptions)
                    .build();

            // Restore the write session created by FE (same session id across all writers).
            writeSession = new TableWriteSessionBuilder()
                    .identifier(com.aliyun.odps.table.TableIdentifier.of(project, tableName))
                    .withSessionId(writeSessionId)
                    .withSettings(settings)
                    .buildBatchWriteSession();

            // SDK skips ArrowOptions when restoring session via withSessionId,
            // set it via reflection to avoid NPE in ArrowWriterImpl.
            // NOTE(review): this reaches into a private superclass field and is
            // fragile across SDK upgrades -- re-verify after any SDK version bump.
            ArrowOptions arrowOptions = ArrowOptions.newBuilder()
                    .withDatetimeUnit(TimestampUnit.MILLI)
                    .withTimestampUnit(TimestampUnit.MILLI)
                    .build();
            java.lang.reflect.Field arrowField = writeSession.getClass()
                    .getSuperclass().getDeclaredField("arrowOptions");
            arrowField.setAccessible(true);
            arrowField.set(writeSession, arrowOptions);

            // Get schema info for type mapping (column order defines Arrow vector order).
            com.aliyun.odps.table.DataSchema dataSchema = writeSession.requiredSchema();
            columnTypeInfos = new java.util.ArrayList<>();
            columnNames = new java.util.ArrayList<>();
            for (com.aliyun.odps.Column col : dataSchema.getColumns()) {
                columnTypeInfos.add(col.getTypeInfo());
                columnNames.add(col.getName());
            }

            allocator = new RootAllocator(Long.MAX_VALUE);

            // Cache writer options for creating new batchWriters in segments
            writerOptions = WriterOptions.newBuilder()
                    .withSettings(settings)
                    .withCompressionCodec(CompressionCodec.ZSTD)
                    .build();
            batchWriter = writeSession.createArrowWriter(blockId,
                    WriterAttemptId.of(0), writerOptions);

            LOG.info("MaxComputeJniWriter opened: project=" + project + ", table=" + tableName
                    + ", writeSessionId=" + writeSessionId + ", partitionSpec=" + partitionSpec
                    + ", blockId=" + blockId);
        } catch (Exception e) {
            String errorMsg = "Failed to open MaxCompute write session for table " + project + "." + tableName;
            LOG.error(errorMsg, e);
            throw new IOException(errorMsg, e);
        }
    }

    /**
     * Streams one C++ Block into MaxCompute via Arrow. Data is read directly
     * from the off-heap VectorColumns (no Object[][] materialization for
     * primitive/string columns), keeping the peak heap cost per batch at
     * O(batch_rows * row_size). Rotates the batch writer every
     * ROWS_PER_SEGMENT rows to cap SDK-internal native memory.
     *
     * @param inputTable the off-heap table handed over from the BE
     * @throws IOException wrapping any SDK/Arrow failure
     */
    @Override
    protected void writeInternal(VectorTable inputTable) throws IOException {
        final int totalRows = inputTable.getNumRows();
        final int totalCols = inputTable.getNumColumns();
        if (totalRows == 0) {
            return;
        }

        try {
            for (int start = 0; start < totalRows; ) {
                // Cap by row count first, then shrink so variable-width column
                // bytes stay within budget (avoids Arrow int32 offset overflow).
                int rows = Math.min(maxWriteBatchRows, totalRows - start);
                rows = limitWriteBatchByBytesStreaming(inputTable, totalCols, start, rows);

                VectorSchemaRoot root = batchWriter.newElement();
                try {
                    root.setRowCount(rows);

                    final int cols = Math.min(totalCols, columnTypeInfos.size());
                    for (int c = 0; c < cols; c++) {
                        fillArrowVectorStreaming(root, c, columnTypeInfos.get(c).getOdpsType(),
                                inputTable.getColumn(c), start, rows);
                    }

                    batchWriter.write(root);
                } finally {
                    // Always release the Arrow buffers, even when write() fails.
                    root.close();
                }

                writtenRows += rows;
                segmentRows += rows;
                start += rows;

                // Segmented commit: rotate batchWriter to release SDK native memory.
                if (segmentRows >= ROWS_PER_SEGMENT) {
                    rotateBatchWriter();
                }
            }
        } catch (Exception e) {
            String errorMsg = "Failed to write data to MaxCompute table " + project + "." + tableName;
            LOG.error(errorMsg, e);
            throw new IOException(errorMsg, e);
        }
    }

    /**
     * Commit current batchWriter and create a new one with a fresh blockId.
     * This forces the MaxCompute SDK to flush and release internal native memory
     * buffers that accumulate during writes. Without rotation, the SDK holds all
     * serialized Arrow data in native memory until close(), causing process RSS
     * to grow linearly with total data volume.
     *
     * @throws IOException if commit or re-creation fails; the original cause
     *         is preserved.
     */
    private void rotateBatchWriter() throws IOException {
        try {
            // 1. Commit current batchWriter and save its commit message
            //    (all messages are needed for the final session commit).
            WriterCommitMessage msg = batchWriter.commit();
            commitMessages.add(msg);
            batchWriter = null;

            // 2. Close current allocator to release Arrow native memory
            allocator.close();
            allocator = null;

            // 3. Create new allocator and batchWriter with a new blockId
            // NOTE(review): if createArrowWriter throws here, batchWriter stays
            // null -- confirm close()/abort paths tolerate a null writer.
            long newBlockId = nextBlockId++;
            allocator = new RootAllocator(Long.MAX_VALUE);
            batchWriter = writeSession.createArrowWriter(newBlockId,
                    WriterAttemptId.of(0), writerOptions);

            LOG.info("Rotated batchWriter: oldBlockId=" + blockId + ", newBlockId=" + newBlockId
                    + ", totalCommitMessages=" + commitMessages.size()
                    + ", totalWrittenRows=" + writtenRows);

            blockId = newBlockId;
            segmentRows = 0;
        } catch (Exception e) {
            throw new IOException("Failed to rotate batchWriter for table "
                    + project + "." + tableName, e);
        }
    }


    /** Returns true for ODPS types whose values have a per-row variable byte length. */
    private boolean isVariableWidthType(OdpsType type) {
        switch (type) {
            case STRING:
            case VARCHAR:
            case CHAR:
            case BINARY:
                return true;
            default:
                return false;
        }
    }

    /**
     * Shrinks {@code batchRows} so that every variable-width column's estimated
     * byte size for the batch stays within MAX_ARROW_BATCH_BYTES. Offset
     * metadata is read straight from the off-heap VectorColumn, so no row data
     * is materialized on the Java heap. Always returns at least 1.
     */
    private int limitWriteBatchByBytesStreaming(VectorTable inputTable, int numCols,
                                               int rowOffset, int batchRows) {
        final int cols = Math.min(numCols, columnTypeInfos.size());
        for (int c = 0; c < cols; c++) {
            if (!isVariableWidthType(columnTypeInfos.get(c).getOdpsType())) {
                continue;
            }
            batchRows = findMaxRowsForColumnStreaming(
                    inputTable.getColumn(c), rowOffset, batchRows, MAX_ARROW_BATCH_BYTES);
            if (batchRows <= 1) {
                // Cannot shrink further; a single row must always be writable.
                return Math.max(1, batchRows);
            }
        }
        return batchRows;
    }

    /**
     * Finds the largest row count (starting at {@code rowOffset}, capped at
     * {@code maxRows}) whose estimated total byte size fits {@code budget}.
     * Probes by halving until a fitting candidate is found, then binary-searches
     * the exact boundary between that candidate and twice its size. Returns at
     * least 1 even when a single row exceeds the budget.
     */
    private int findMaxRowsForColumnStreaming(VectorColumn vc, int rowOffset, int maxRows, long budget) {
        if (estimateColumnBytesStreaming(vc, rowOffset, maxRows) <= budget) {
            return maxRows;
        }
        for (int candidate = maxRows / 2; candidate >= 1; candidate /= 2) {
            if (estimateColumnBytesStreaming(vc, rowOffset, candidate) > budget) {
                continue;
            }
            // candidate fits but candidate*2 (or maxRows) does not; binary-search
            // the largest fitting count inside [candidate, min(candidate*2, maxRows)].
            int lo = candidate;
            int hi = Math.min(candidate * 2, maxRows);
            while (lo < hi) {
                int mid = lo + (hi - lo + 1) / 2;
                if (estimateColumnBytesStreaming(vc, rowOffset, mid) <= budget) {
                    lo = mid;
                } else {
                    hi = mid - 1;
                }
            }
            return lo;
        }
        return 1;
    }

    /**
     * Sums the serialized byte length of rows [rowOffset, rowOffset + rows) of
     * a variable-width column by subtracting consecutive entries of the
     * off-heap int32 offset array. O(1) per row; allocates no byte[] objects.
     * NOTE(review): assumes offset[i] is the cumulative END offset of row i
     * (row i spans offset[i-1]..offset[i]) -- matches getBytesWithOffset layout,
     * but confirm against VectorColumn's string storage.
     */
    private long estimateColumnBytesStreaming(VectorColumn vc, int rowOffset, int rows) {
        final long offsetAddr = vc.offsetAddress();
        final int end = rowOffset + rows;
        long total = 0;
        for (int row = rowOffset; row < end; row++) {
            if (vc.isNullAt(row)) {
                continue;
            }
            int begin = (row == 0) ? 0
                    : org.apache.doris.common.jni.utils.OffHeap.getInt(null, offsetAddr + 4L * (row - 1));
            int finish = org.apache.doris.common.jni.utils.OffHeap.getInt(null, offsetAddr + 4L * row);
            total += finish - begin;
        }
        return total;
    }

    /**
     * Fill an Arrow vector by reading data directly from a VectorColumn,
     * one row at a time. For STRING columns, this reads bytes directly
     * (getBytesWithOffset) instead of creating String objects, eliminating
     * the String -> byte[] double-copy that caused heap exhaustion.
     *
     * @param root      Arrow batch whose column {@code colIdx} is filled
     * @param colIdx    target column index in {@code root}
     * @param odpsType  destination ODPS type, selects the Arrow vector class
     * @param vc        off-heap source column
     * @param rowOffset first source row to copy
     * @param numRows   number of rows to copy
     */
    private void fillArrowVectorStreaming(VectorSchemaRoot root, int colIdx, OdpsType odpsType,
                                          VectorColumn vc, int rowOffset, int numRows) {
        switch (odpsType) {
            case BOOLEAN: {
                BitVector vec = (BitVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (vc.isNullAt(rowOffset + i)) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, vc.getBoolean(rowOffset + i) ? 1 : 0);
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case TINYINT: {
                TinyIntVector vec = (TinyIntVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (vc.isNullAt(rowOffset + i)) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, vc.getByte(rowOffset + i));
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case SMALLINT: {
                SmallIntVector vec = (SmallIntVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (vc.isNullAt(rowOffset + i)) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, vc.getShort(rowOffset + i));
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case INT: {
                IntVector vec = (IntVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (vc.isNullAt(rowOffset + i)) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, vc.getInt(rowOffset + i));
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case BIGINT: {
                BigIntVector vec = (BigIntVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (vc.isNullAt(rowOffset + i)) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, vc.getLong(rowOffset + i));
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case FLOAT: {
                Float4Vector vec = (Float4Vector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (vc.isNullAt(rowOffset + i)) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, vc.getFloat(rowOffset + i));
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case DOUBLE: {
                Float8Vector vec = (Float8Vector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (vc.isNullAt(rowOffset + i)) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, vc.getDouble(rowOffset + i));
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case DECIMAL: {
                DecimalVector vec = (DecimalVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (vc.isNullAt(rowOffset + i)) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, vc.getDecimal(rowOffset + i));
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case STRING:
            case VARCHAR:
            case CHAR: {
                // KEY FIX: Read bytes directly from off-heap, no String creation.
                // Previously: getMaterializedData -> String[] -> toString().getBytes() -> Arrow
                // Now: getBytesWithOffset() -> byte[] -> Arrow (1 copy instead of 3)
                VarCharVector vec = (VarCharVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (vc.isNullAt(rowOffset + i)) {
                        vec.setNull(i);
                    } else {
                        byte[] bytes = vc.getBytesWithOffset(rowOffset + i);
                        vec.setSafe(i, bytes);
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case DATE: {
                // Arrow DateDay stores days since the Unix epoch.
                DateDayVector vec = (DateDayVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (vc.isNullAt(rowOffset + i)) {
                        vec.setNull(i);
                    } else {
                        LocalDate date = vc.getDate(rowOffset + i);
                        vec.set(i, (int) date.toEpochDay());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case DATETIME:
            case TIMESTAMP: {
                // Session's ArrowOptions were forced to MILLI in open(), so a
                // milli-precision vector is used for both datetime types.
                // NOTE(review): conversion uses the JVM default zone -- confirm
                // this matches the timezone semantics expected by the session.
                TimeStampMilliVector vec = (TimeStampMilliVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (vc.isNullAt(rowOffset + i)) {
                        vec.setNull(i);
                    } else {
                        LocalDateTime dt = vc.getDateTime(rowOffset + i);
                        long millis = dt.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli();
                        vec.set(i, millis);
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case BINARY: {
                VarBinaryVector vec = (VarBinaryVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (vc.isNullAt(rowOffset + i)) {
                        vec.setNull(i);
                    } else {
                        byte[] bytes = vc.getBytesWithOffset(rowOffset + i);
                        vec.setSafe(i, bytes);
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            default: {
                // For complex types (ARRAY, MAP, STRUCT) and other types,
                // fall back to object-based materialization for this column only.
                Object[] colData = vc.getObjectColumn(rowOffset, rowOffset + numRows);
                fillArrowVector(root, colIdx, odpsType, colData, 0, numRows);
                break;
            }
        }
    }

    private void fillArrowVector(VectorSchemaRoot root, int colIdx, OdpsType odpsType,
                                  Object[] colData, int rowOffset, int numRows) {
        switch (odpsType) {
            case BOOLEAN: {
                BitVector vec = (BitVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, (Boolean) colData[rowOffset + i] ? 1 : 0);
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case TINYINT: {
                TinyIntVector vec = (TinyIntVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, ((Number) colData[rowOffset + i]).byteValue());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case SMALLINT: {
                SmallIntVector vec = (SmallIntVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, ((Number) colData[rowOffset + i]).shortValue());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case INT: {
                IntVector vec = (IntVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, ((Number) colData[rowOffset + i]).intValue());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case BIGINT: {
                BigIntVector vec = (BigIntVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, ((Number) colData[rowOffset + i]).longValue());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case FLOAT: {
                Float4Vector vec = (Float4Vector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, ((Number) colData[rowOffset + i]).floatValue());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case DOUBLE: {
                Float8Vector vec = (Float8Vector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, ((Number) colData[rowOffset + i]).doubleValue());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case DECIMAL: {
                DecimalVector vec = (DecimalVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        vec.setNull(i);
                    } else {
                        BigDecimal bd = (colData[rowOffset + i] instanceof BigDecimal)
                                ? (BigDecimal) colData[rowOffset + i]
                                : new BigDecimal(colData[rowOffset + i].toString());
                        vec.set(i, bd);
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case STRING:
            case VARCHAR:
            case CHAR: {
                VarCharVector vec = (VarCharVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        vec.setNull(i);
                    } else {
                        byte[] bytes;
                        if (colData[rowOffset + i] instanceof byte[]) {
                            bytes = (byte[]) colData[rowOffset + i];
                        } else {
                            bytes = colData[rowOffset + i].toString().getBytes(StandardCharsets.UTF_8);
                        }
                        vec.setSafe(i, bytes);
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case DATE: {
                DateDayVector vec = (DateDayVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        vec.setNull(i);
                    } else if (colData[rowOffset + i] instanceof LocalDate) {
                        vec.set(i, (int) ((LocalDate) colData[rowOffset + i]).toEpochDay());
                    } else {
                        vec.set(i, (int) LocalDate.parse(colData[rowOffset + i].toString()).toEpochDay());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case DATETIME:
            case TIMESTAMP: {
                TimeStampMilliVector vec = (TimeStampMilliVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        vec.setNull(i);
                    } else if (colData[rowOffset + i] instanceof LocalDateTime) {
                        long millis = ((LocalDateTime) colData[rowOffset + i])
                                .atZone(ZoneId.systemDefault()).toInstant().toEpochMilli();
                        vec.set(i, millis);
                    } else if (colData[rowOffset + i] instanceof java.sql.Timestamp) {
                        vec.set(i, ((java.sql.Timestamp) colData[rowOffset + i]).getTime());
                    } else {
                        long millis = LocalDateTime.parse(colData[rowOffset + i].toString())
                                .atZone(ZoneId.systemDefault()).toInstant().toEpochMilli();
                        vec.set(i, millis);
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case BINARY: {
                VarBinaryVector vec = (VarBinaryVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        vec.setNull(i);
                    } else if (colData[rowOffset + i] instanceof byte[]) {
                        vec.setSafe(i, (byte[]) colData[rowOffset + i]);
                    } else {
                        vec.setSafe(i, colData[rowOffset + i].toString().getBytes(StandardCharsets.UTF_8));
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case ARRAY: {
                ListVector listVec = (ListVector) root.getVector(colIdx);
                listVec.allocateNew();
                FieldVector dataVec = listVec.getDataVector();
                int elemIdx = 0;
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        listVec.setNull(i);
                    } else {
                        List<?> list = (List<?>) colData[rowOffset + i];
                        listVec.startNewValue(i);
                        for (Object elem : list) {
                            writeListElement(dataVec, elemIdx++, elem);
                        }
                        listVec.endValue(i, list.size());
                    }
                }
                listVec.setValueCount(numRows);
                dataVec.setValueCount(elemIdx);
                break;
            }
            case MAP: {
                MapVector mapVec = (MapVector) root.getVector(colIdx);
                mapVec.allocateNew();
                StructVector structVec = (StructVector) mapVec.getDataVector();
                FieldVector keyVec = structVec.getChildrenFromFields().get(0);
                FieldVector valVec = structVec.getChildrenFromFields().get(1);
                int elemIdx = 0;
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        mapVec.setNull(i);
                    } else {
                        Map<?, ?> map = (Map<?, ?>) colData[rowOffset + i];
                        mapVec.startNewValue(i);
                        for (Map.Entry<?, ?> entry : map.entrySet()) {
                            structVec.setIndexDefined(elemIdx);
                            writeListElement(keyVec, elemIdx, entry.getKey());
                            writeListElement(valVec, elemIdx, entry.getValue());
                            elemIdx++;
                        }
                        mapVec.endValue(i, map.size());
                    }
                }
                mapVec.setValueCount(numRows);
                structVec.setValueCount(elemIdx);
                keyVec.setValueCount(elemIdx);
                valVec.setValueCount(elemIdx);
                break;
            }
            case STRUCT: {
                StructVector structVec = (StructVector) root.getVector(colIdx);
                structVec.allocateNew();
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        structVec.setNull(i);
                    } else {
                        structVec.setIndexDefined(i);
                        Map<?, ?> struct = (Map<?, ?>) colData[rowOffset + i];
                        for (FieldVector childVec : structVec.getChildrenFromFields()) {
                            Object val = struct.get(childVec.getName());
                            writeListElement(childVec, i, val);
                        }
                    }
                }
                structVec.setValueCount(numRows);
                for (FieldVector childVec : structVec.getChildrenFromFields()) {
                    childVec.setValueCount(numRows);
                }
                break;
            }
            default: {
                // Fallback: write as VarChar
                VarCharVector vec = (VarCharVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[rowOffset + i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.setSafe(i, colData[rowOffset + i].toString().getBytes(StandardCharsets.UTF_8));
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
        }
    }

    /**
     * Writes one (possibly nested) element into an Arrow child vector at
     * position {@code idx}, dispatching on the concrete vector type.
     *
     * <p>Used for ARRAY elements, MAP keys/values and STRUCT fields. A null
     * element marks the slot null; non-null elements are coerced from their
     * Java representation, falling back to {@code toString()} parsing when the
     * runtime type does not match the vector (mirrors the column-level
     * coercion in the top-level type switch).
     *
     * @param vec  target child vector
     * @param idx  element index within {@code vec}
     * @param elem value to write, may be null
     */
    private void writeListElement(FieldVector vec, int idx, Object elem) {
        if (elem == null) {
            if (vec instanceof BaseFixedWidthVector) {
                ((BaseFixedWidthVector) vec).setNull(idx);
            } else if (vec instanceof BaseVariableWidthVector) {
                ((BaseVariableWidthVector) vec).setNull(idx);
            } else if (vec instanceof StructVector) {
                ((StructVector) vec).setNull(idx);
            } else if (vec instanceof MapVector) {
                // MapVector extends ListVector, so this check must precede it.
                ((MapVector) vec).setNull(idx);
            } else if (vec instanceof ListVector) {
                ((ListVector) vec).setNull(idx);
            }
            return;
        }
        if (vec instanceof VarCharVector) {
            byte[] bytes = elem instanceof byte[] ? (byte[]) elem
                    : elem.toString().getBytes(StandardCharsets.UTF_8);
            ((VarCharVector) vec).setSafe(idx, bytes);
        } else if (vec instanceof VarBinaryVector) {
            // BINARY element: must be handled before the VarChar fallback,
            // casting a VarBinaryVector to VarCharVector would throw CCE.
            byte[] bytes = elem instanceof byte[] ? (byte[]) elem
                    : elem.toString().getBytes(StandardCharsets.UTF_8);
            ((VarBinaryVector) vec).setSafe(idx, bytes);
        } else if (vec instanceof IntVector) {
            ((IntVector) vec).setSafe(idx, ((Number) elem).intValue());
        } else if (vec instanceof BigIntVector) {
            ((BigIntVector) vec).setSafe(idx, ((Number) elem).longValue());
        } else if (vec instanceof Float8Vector) {
            ((Float8Vector) vec).setSafe(idx, ((Number) elem).doubleValue());
        } else if (vec instanceof Float4Vector) {
            ((Float4Vector) vec).setSafe(idx, ((Number) elem).floatValue());
        } else if (vec instanceof SmallIntVector) {
            ((SmallIntVector) vec).setSafe(idx, ((Number) elem).shortValue());
        } else if (vec instanceof TinyIntVector) {
            ((TinyIntVector) vec).setSafe(idx, ((Number) elem).byteValue());
        } else if (vec instanceof BitVector) {
            ((BitVector) vec).setSafe(idx, (Boolean) elem ? 1 : 0);
        } else if (vec instanceof DecimalVector) {
            BigDecimal bd = elem instanceof BigDecimal ? (BigDecimal) elem
                    : new BigDecimal(elem.toString());
            ((DecimalVector) vec).setSafe(idx, bd);
        } else if (vec instanceof DateDayVector) {
            // DATE element inside a complex type: same coercion as the
            // column-level DATE case (epoch days).
            int days = elem instanceof LocalDate
                    ? (int) ((LocalDate) elem).toEpochDay()
                    : (int) LocalDate.parse(elem.toString()).toEpochDay();
            ((DateDayVector) vec).setSafe(idx, days);
        } else if (vec instanceof TimeStampMilliVector) {
            // DATETIME/TIMESTAMP element inside a complex type: same coercion
            // as the column-level DATETIME case (epoch millis in the system
            // default zone).
            long millis;
            if (elem instanceof LocalDateTime) {
                millis = ((LocalDateTime) elem)
                        .atZone(ZoneId.systemDefault()).toInstant().toEpochMilli();
            } else if (elem instanceof java.sql.Timestamp) {
                millis = ((java.sql.Timestamp) elem).getTime();
            } else {
                millis = LocalDateTime.parse(elem.toString())
                        .atZone(ZoneId.systemDefault()).toInstant().toEpochMilli();
            }
            ((TimeStampMilliVector) vec).setSafe(idx, millis);
        } else if (vec instanceof StructVector) {
            StructVector structVec = (StructVector) vec;
            structVec.setIndexDefined(idx);
            Map<?, ?> struct = (Map<?, ?>) elem;
            // Fill each child field by name; absent keys write null recursively.
            for (FieldVector childVec : structVec.getChildrenFromFields()) {
                writeListElement(childVec, idx, struct.get(childVec.getName()));
            }
        } else if (vec instanceof MapVector) {
            MapVector mapVec = (MapVector) vec;
            StructVector entryVec = (StructVector) mapVec.getDataVector();
            FieldVector keyVec = entryVec.getChildrenFromFields().get(0);
            FieldVector valVec = entryVec.getChildrenFromFields().get(1);
            Map<?, ?> map = (Map<?, ?>) elem;
            int offset = mapVec.startNewValue(idx);
            int j = 0;
            for (Map.Entry<?, ?> entry : map.entrySet()) {
                entryVec.setIndexDefined(offset + j);
                writeListElement(keyVec, offset + j, entry.getKey());
                writeListElement(valVec, offset + j, entry.getValue());
                j++;
            }
            mapVec.endValue(idx, map.size());
            entryVec.setValueCount(offset + j);
            keyVec.setValueCount(offset + j);
            valVec.setValueCount(offset + j);
        } else if (vec instanceof ListVector) {
            ListVector listVec = (ListVector) vec;
            FieldVector dataVec = listVec.getDataVector();
            List<?> list = (List<?>) elem;
            int offset = listVec.startNewValue(idx);
            for (int j = 0; j < list.size(); j++) {
                writeListElement(dataVec, offset + j, list.get(j));
            }
            listVec.endValue(idx, list.size());
            dataVec.setValueCount(offset + list.size());
        } else {
            // Last-resort fallback: stringify, matching the column-level
            // default case. Only safe for VarChar-backed vectors.
            byte[] bytes = elem.toString().getBytes(StandardCharsets.UTF_8);
            ((VarCharVector) vec).setSafe(idx, bytes);
        }
    }

    /**
     * Commits the final open segment and releases Arrow resources.
     *
     * <p>The allocator is torn down even when the commit fails; the first
     * failure encountered wins and is rethrown wrapped in an IOException
     * after the close summary has been logged.
     *
     * @throws IOException if committing the writer or closing the allocator failed
     */
    @Override
    public void close() throws IOException {
        Exception pending = null;
        try {
            // Flush the last segment's writer so its rows become part of the commit.
            if (batchWriter != null) {
                try {
                    commitMessages.add(batchWriter.commit());
                } catch (Exception e) {
                    pending = e;
                    LOG.warn("Failed to commit batch writer for table " + project + "." + tableName, e);
                } finally {
                    batchWriter = null;
                }
            }
        } finally {
            // Release Arrow buffers regardless of the commit outcome.
            BufferAllocator doomed = allocator;
            allocator = null;
            if (doomed != null) {
                try {
                    doomed.close();
                } catch (Exception e) {
                    LOG.warn("Failed to close Arrow allocator (possible memory leak)", e);
                    if (pending == null) {
                        pending = e;
                    }
                }
            }
        }
        LOG.info("MaxComputeJniWriter closed: writeSessionId=" + writeSessionId
                + ", partitionSpec=" + partitionSpec
                + ", writtenRows=" + writtenRows
                + ", totalSegments=" + commitMessages.size()
                + ", blockId=" + blockId);
        if (pending != null) {
            throw new IOException("Failed to close MaxCompute arrow writer", pending);
        }
    }

    /**
     * Returns writer statistics handed back to the backend through JNI.
     *
     * <p>Includes the partition spec, the Base64-encoded Java serialization of
     * all accumulated {@code WriterCommitMessage}s (one per committed segment),
     * and counters following the existing {@code counter:}/{@code bytes:}/
     * {@code timer:} key convention.
     *
     * @return mutable map of statistic name to string value
     */
    @Override
    public Map<String, String> getStatistics() {
        Map<String, String> stats = new HashMap<>();
        stats.put("mc_partition_spec", partitionSpec != null ? partitionSpec : "");

        // Serialize all WriterCommitMessages (one per segment) as a List object.
        if (!commitMessages.isEmpty()) {
            // try-with-resources: the original code leaked the streams when
            // writeObject threw before the explicit close().
            try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
                    ObjectOutputStream oos = new ObjectOutputStream(baos)) {
                // Serialize the entire list as one object to avoid mixing
                // writeInt/writeObject which causes OptionalDataException.
                oos.writeObject(new java.util.ArrayList<>(commitMessages));
                // Flush before draining the buffer so toByteArray() sees all bytes.
                oos.flush();
                stats.put("mc_commit_message", Base64.getEncoder().encodeToString(baos.toByteArray()));
            } catch (IOException e) {
                // Best effort: on failure the commit message key is simply omitted.
                LOG.error("Failed to serialize WriterCommitMessages", e);
            }
        }

        stats.put("counter:WrittenRows", String.valueOf(writtenRows));
        stats.put("bytes:WrittenBytes", String.valueOf(writtenBytes));
        stats.put("timer:WriteTime", String.valueOf(writeTime));
        stats.put("timer:ReadTableTime", String.valueOf(readTableTime));
        return stats;
    }
}