MaxComputeJniWriter.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.maxcompute;

import org.apache.doris.common.jni.JniWriter;
import org.apache.doris.common.jni.vec.VectorTable;
import org.apache.doris.common.maxcompute.MCUtils;

import com.aliyun.odps.Odps;
import com.aliyun.odps.OdpsType;
import com.aliyun.odps.table.configuration.ArrowOptions;
import com.aliyun.odps.table.configuration.ArrowOptions.TimestampUnit;
import com.aliyun.odps.table.configuration.RestOptions;
import com.aliyun.odps.table.configuration.WriterOptions;
import com.aliyun.odps.table.enviroment.Credentials;
import com.aliyun.odps.table.enviroment.EnvironmentSettings;
import com.aliyun.odps.table.write.BatchWriter;
import com.aliyun.odps.table.write.TableBatchWriteSession;
import com.aliyun.odps.table.write.TableWriteSessionBuilder;
import com.aliyun.odps.table.write.WriterAttemptId;
import com.aliyun.odps.table.write.WriterCommitMessage;
import com.aliyun.odps.type.TypeInfo;
import com.google.common.base.Strings;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BaseFixedWidthVector;
import org.apache.arrow.vector.BaseVariableWidthVector;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DateDayVector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TimeStampMilliVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.log4j.Logger;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * MaxComputeJniWriter writes C++ Block data to MaxCompute tables via Storage API (Arrow).
 * Loaded by C++ as: org/apache/doris/maxcompute/MaxComputeJniWriter
 */
public class MaxComputeJniWriter extends JniWriter {
    private static final Logger LOG = Logger.getLogger(MaxComputeJniWriter.class);

    private static final String ACCESS_KEY = "access_key";
    private static final String SECRET_KEY = "secret_key";
    private static final String ENDPOINT = "endpoint";
    private static final String QUOTA = "quota";
    private static final String PROJECT = "project";
    private static final String TABLE = "table";
    private static final String WRITE_SESSION_ID = "write_session_id";
    private static final String BLOCK_ID = "block_id";
    private static final String PARTITION_SPEC = "partition_spec";
    private static final String CONNECT_TIMEOUT = "connect_timeout";
    private static final String READ_TIMEOUT = "read_timeout";
    private static final String RETRY_COUNT = "retry_count";

    private final Map<String, String> params;
    private final String endpoint;
    private final String project;
    private final String tableName;
    private final String quota;
    private String writeSessionId;
    private long blockId;
    private String partitionSpec;
    private int connectTimeout;
    private int readTimeout;
    private int retryCount;

    // Storage API objects
    private TableBatchWriteSession writeSession;
    private BatchWriter<VectorSchemaRoot> batchWriter;
    private BufferAllocator allocator;
    private List<TypeInfo> columnTypeInfos;
    private List<String> columnNames;
    private WriterCommitMessage commitMessage;

    // Statistics
    private long writtenRows = 0;
    private long writtenBytes = 0;

    public MaxComputeJniWriter(int batchSize, Map<String, String> params) {
        super(batchSize, params);
        this.params = params;
        this.endpoint = Objects.requireNonNull(params.get(ENDPOINT), "required property '" + ENDPOINT + "'.");
        this.project = Objects.requireNonNull(params.get(PROJECT), "required property '" + PROJECT + "'.");
        this.tableName = Objects.requireNonNull(params.get(TABLE), "required property '" + TABLE + "'.");
        this.quota = params.getOrDefault(QUOTA, "");
        this.writeSessionId = Objects.requireNonNull(params.get(WRITE_SESSION_ID),
                "required property '" + WRITE_SESSION_ID + "'.");
        this.blockId = Long.parseLong(params.getOrDefault(BLOCK_ID, "0"));
        this.partitionSpec = params.getOrDefault(PARTITION_SPEC, "");
        this.connectTimeout = Integer.parseInt(params.getOrDefault(CONNECT_TIMEOUT, "10"));
        this.readTimeout = Integer.parseInt(params.getOrDefault(READ_TIMEOUT, "120"));
        this.retryCount = Integer.parseInt(params.getOrDefault(RETRY_COUNT, "4"));
    }

    @Override
    public void open() throws IOException {
        try {
            Odps odps = MCUtils.createMcClient(params);
            odps.setDefaultProject(project);
            odps.setEndpoint(endpoint);

            Credentials credentials = Credentials.newBuilder().withAccount(odps.getAccount())
                    .withAppAccount(odps.getAppAccount()).build();

            RestOptions restOptions = RestOptions.newBuilder()
                    .withConnectTimeout(connectTimeout)
                    .withReadTimeout(readTimeout)
                    .withRetryTimes(retryCount).build();

            EnvironmentSettings settings = EnvironmentSettings.newBuilder()
                    .withCredentials(credentials)
                    .withServiceEndpoint(odps.getEndpoint())
                    .withQuotaName(Strings.isNullOrEmpty(quota) ? null : quota)
                    .withRestOptions(restOptions)
                    .build();

            // Restore the write session created by FE
            writeSession = new TableWriteSessionBuilder()
                    .identifier(com.aliyun.odps.table.TableIdentifier.of(project, tableName))
                    .withSessionId(writeSessionId)
                    .withSettings(settings)
                    .buildBatchWriteSession();

            // SDK skips ArrowOptions when restoring session via withSessionId,
            // set it via reflection to avoid NPE in ArrowWriterImpl
            ArrowOptions arrowOptions = ArrowOptions.newBuilder()
                    .withDatetimeUnit(TimestampUnit.MILLI)
                    .withTimestampUnit(TimestampUnit.MILLI)
                    .build();
            java.lang.reflect.Field arrowField = writeSession.getClass()
                    .getSuperclass().getDeclaredField("arrowOptions");
            arrowField.setAccessible(true);
            arrowField.set(writeSession, arrowOptions);

            // Get schema info for type mapping
            com.aliyun.odps.table.DataSchema dataSchema = writeSession.requiredSchema();
            columnTypeInfos = new java.util.ArrayList<>();
            columnNames = new java.util.ArrayList<>();
            for (com.aliyun.odps.Column col : dataSchema.getColumns()) {
                columnTypeInfos.add(col.getTypeInfo());
                columnNames.add(col.getName());
            }

            allocator = new RootAllocator(Long.MAX_VALUE);

            // Create Arrow writer for this block
            WriterOptions writerOptions = WriterOptions.newBuilder()
                    .withSettings(settings)
                    .build();
            batchWriter = writeSession.createArrowWriter(blockId,
                    WriterAttemptId.of(0), writerOptions);

            LOG.info("MaxComputeJniWriter opened: project=" + project + ", table=" + tableName
                    + ", writeSessionId=" + writeSessionId + ", partitionSpec=" + partitionSpec
                    + ", blockId=" + blockId);
        } catch (Exception e) {
            String errorMsg = "Failed to open MaxCompute write session for table " + project + "." + tableName;
            LOG.error(errorMsg, e);
            throw new IOException(errorMsg, e);
        }
    }

    @Override
    protected void writeInternal(VectorTable inputTable) throws IOException {
        int numRows = inputTable.getNumRows();
        int numCols = inputTable.getNumColumns();
        if (numRows == 0) {
            return;
        }

        try {
            Object[][] data = inputTable.getMaterializedData();

            // Get a pre-allocated VectorSchemaRoot from the batch writer
            VectorSchemaRoot root = batchWriter.newElement();
            root.setRowCount(numRows);

            for (int col = 0; col < numCols && col < columnTypeInfos.size(); col++) {
                OdpsType odpsType = columnTypeInfos.get(col).getOdpsType();
                fillArrowVector(root, col, odpsType, data[col], numRows);
            }

            batchWriter.write(root);
            writtenRows += numRows;
        } catch (Exception e) {
            String errorMsg = "Failed to write data to MaxCompute table " + project + "." + tableName;
            LOG.error(errorMsg, e);
            throw new IOException(errorMsg, e);
        }
    }

    private void fillArrowVector(VectorSchemaRoot root, int colIdx, OdpsType odpsType,
                                  Object[] colData, int numRows) {
        switch (odpsType) {
            case BOOLEAN: {
                BitVector vec = (BitVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, (Boolean) colData[i] ? 1 : 0);
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case TINYINT: {
                TinyIntVector vec = (TinyIntVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, ((Number) colData[i]).byteValue());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case SMALLINT: {
                SmallIntVector vec = (SmallIntVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, ((Number) colData[i]).shortValue());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case INT: {
                IntVector vec = (IntVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, ((Number) colData[i]).intValue());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case BIGINT: {
                BigIntVector vec = (BigIntVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, ((Number) colData[i]).longValue());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case FLOAT: {
                Float4Vector vec = (Float4Vector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, ((Number) colData[i]).floatValue());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case DOUBLE: {
                Float8Vector vec = (Float8Vector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.set(i, ((Number) colData[i]).doubleValue());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case DECIMAL: {
                DecimalVector vec = (DecimalVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        vec.setNull(i);
                    } else {
                        BigDecimal bd = (colData[i] instanceof BigDecimal)
                                ? (BigDecimal) colData[i]
                                : new BigDecimal(colData[i].toString());
                        vec.set(i, bd);
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case STRING:
            case VARCHAR:
            case CHAR: {
                VarCharVector vec = (VarCharVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        vec.setNull(i);
                    } else {
                        byte[] bytes;
                        if (colData[i] instanceof byte[]) {
                            bytes = (byte[]) colData[i];
                        } else {
                            bytes = colData[i].toString().getBytes(StandardCharsets.UTF_8);
                        }
                        vec.setSafe(i, bytes);
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case DATE: {
                DateDayVector vec = (DateDayVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        vec.setNull(i);
                    } else if (colData[i] instanceof LocalDate) {
                        vec.set(i, (int) ((LocalDate) colData[i]).toEpochDay());
                    } else {
                        vec.set(i, (int) LocalDate.parse(colData[i].toString()).toEpochDay());
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case DATETIME:
            case TIMESTAMP: {
                TimeStampMilliVector vec = (TimeStampMilliVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        vec.setNull(i);
                    } else if (colData[i] instanceof LocalDateTime) {
                        long millis = ((LocalDateTime) colData[i])
                                .atZone(ZoneId.systemDefault()).toInstant().toEpochMilli();
                        vec.set(i, millis);
                    } else if (colData[i] instanceof java.sql.Timestamp) {
                        vec.set(i, ((java.sql.Timestamp) colData[i]).getTime());
                    } else {
                        long millis = LocalDateTime.parse(colData[i].toString())
                                .atZone(ZoneId.systemDefault()).toInstant().toEpochMilli();
                        vec.set(i, millis);
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case BINARY: {
                VarBinaryVector vec = (VarBinaryVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        vec.setNull(i);
                    } else if (colData[i] instanceof byte[]) {
                        vec.setSafe(i, (byte[]) colData[i]);
                    } else {
                        vec.setSafe(i, colData[i].toString().getBytes(StandardCharsets.UTF_8));
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
            case ARRAY: {
                ListVector listVec = (ListVector) root.getVector(colIdx);
                listVec.allocateNew();
                FieldVector dataVec = listVec.getDataVector();
                int elemIdx = 0;
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        listVec.setNull(i);
                    } else {
                        List<?> list = (List<?>) colData[i];
                        listVec.startNewValue(i);
                        for (Object elem : list) {
                            writeListElement(dataVec, elemIdx++, elem);
                        }
                        listVec.endValue(i, list.size());
                    }
                }
                listVec.setValueCount(numRows);
                dataVec.setValueCount(elemIdx);
                break;
            }
            case MAP: {
                MapVector mapVec = (MapVector) root.getVector(colIdx);
                mapVec.allocateNew();
                StructVector structVec = (StructVector) mapVec.getDataVector();
                FieldVector keyVec = structVec.getChildrenFromFields().get(0);
                FieldVector valVec = structVec.getChildrenFromFields().get(1);
                int elemIdx = 0;
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        mapVec.setNull(i);
                    } else {
                        Map<?, ?> map = (Map<?, ?>) colData[i];
                        mapVec.startNewValue(i);
                        for (Map.Entry<?, ?> entry : map.entrySet()) {
                            structVec.setIndexDefined(elemIdx);
                            writeListElement(keyVec, elemIdx, entry.getKey());
                            writeListElement(valVec, elemIdx, entry.getValue());
                            elemIdx++;
                        }
                        mapVec.endValue(i, map.size());
                    }
                }
                mapVec.setValueCount(numRows);
                structVec.setValueCount(elemIdx);
                keyVec.setValueCount(elemIdx);
                valVec.setValueCount(elemIdx);
                break;
            }
            case STRUCT: {
                StructVector structVec = (StructVector) root.getVector(colIdx);
                structVec.allocateNew();
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        structVec.setNull(i);
                    } else {
                        structVec.setIndexDefined(i);
                        Map<?, ?> struct = (Map<?, ?>) colData[i];
                        for (FieldVector childVec : structVec.getChildrenFromFields()) {
                            Object val = struct.get(childVec.getName());
                            writeListElement(childVec, i, val);
                        }
                    }
                }
                structVec.setValueCount(numRows);
                for (FieldVector childVec : structVec.getChildrenFromFields()) {
                    childVec.setValueCount(numRows);
                }
                break;
            }
            default: {
                // Fallback: write as VarChar
                VarCharVector vec = (VarCharVector) root.getVector(colIdx);
                vec.allocateNew(numRows);
                for (int i = 0; i < numRows; i++) {
                    if (colData[i] == null) {
                        vec.setNull(i);
                    } else {
                        vec.setSafe(i, colData[i].toString().getBytes(StandardCharsets.UTF_8));
                    }
                }
                vec.setValueCount(numRows);
                break;
            }
        }
    }

    private void writeListElement(FieldVector vec, int idx, Object elem) {
        if (elem == null) {
            if (vec instanceof BaseFixedWidthVector) {
                ((BaseFixedWidthVector) vec).setNull(idx);
            } else if (vec instanceof BaseVariableWidthVector) {
                ((BaseVariableWidthVector) vec).setNull(idx);
            } else if (vec instanceof StructVector) {
                ((StructVector) vec).setNull(idx);
            } else if (vec instanceof MapVector) {
                ((MapVector) vec).setNull(idx);
            } else if (vec instanceof ListVector) {
                ((ListVector) vec).setNull(idx);
            }
            return;
        }
        if (vec instanceof VarCharVector) {
            byte[] bytes = elem instanceof byte[] ? (byte[]) elem
                    : elem.toString().getBytes(StandardCharsets.UTF_8);
            ((VarCharVector) vec).setSafe(idx, bytes);
        } else if (vec instanceof IntVector) {
            ((IntVector) vec).setSafe(idx, ((Number) elem).intValue());
        } else if (vec instanceof BigIntVector) {
            ((BigIntVector) vec).setSafe(idx, ((Number) elem).longValue());
        } else if (vec instanceof Float8Vector) {
            ((Float8Vector) vec).setSafe(idx, ((Number) elem).doubleValue());
        } else if (vec instanceof Float4Vector) {
            ((Float4Vector) vec).setSafe(idx, ((Number) elem).floatValue());
        } else if (vec instanceof SmallIntVector) {
            ((SmallIntVector) vec).setSafe(idx, ((Number) elem).shortValue());
        } else if (vec instanceof TinyIntVector) {
            ((TinyIntVector) vec).setSafe(idx, ((Number) elem).byteValue());
        } else if (vec instanceof BitVector) {
            ((BitVector) vec).setSafe(idx, (Boolean) elem ? 1 : 0);
        } else if (vec instanceof DecimalVector) {
            BigDecimal bd = elem instanceof BigDecimal ? (BigDecimal) elem
                    : new BigDecimal(elem.toString());
            ((DecimalVector) vec).setSafe(idx, bd);
        } else if (vec instanceof StructVector) {
            StructVector structVec = (StructVector) vec;
            structVec.setIndexDefined(idx);
            Map<?, ?> struct = (Map<?, ?>) elem;
            for (FieldVector childVec : structVec.getChildrenFromFields()) {
                writeListElement(childVec, idx, struct.get(childVec.getName()));
            }
        } else if (vec instanceof MapVector) {
            MapVector mapVec = (MapVector) vec;
            StructVector entryVec = (StructVector) mapVec.getDataVector();
            FieldVector keyVec = entryVec.getChildrenFromFields().get(0);
            FieldVector valVec = entryVec.getChildrenFromFields().get(1);
            Map<?, ?> map = (Map<?, ?>) elem;
            int offset = mapVec.startNewValue(idx);
            int j = 0;
            for (Map.Entry<?, ?> entry : map.entrySet()) {
                entryVec.setIndexDefined(offset + j);
                writeListElement(keyVec, offset + j, entry.getKey());
                writeListElement(valVec, offset + j, entry.getValue());
                j++;
            }
            mapVec.endValue(idx, map.size());
            entryVec.setValueCount(offset + j);
            keyVec.setValueCount(offset + j);
            valVec.setValueCount(offset + j);
        } else if (vec instanceof ListVector) {
            ListVector listVec = (ListVector) vec;
            FieldVector dataVec = listVec.getDataVector();
            List<?> list = (List<?>) elem;
            int offset = listVec.startNewValue(idx);
            for (int j = 0; j < list.size(); j++) {
                writeListElement(dataVec, offset + j, list.get(j));
            }
            listVec.endValue(idx, list.size());
            dataVec.setValueCount(offset + list.size());
        } else {
            byte[] bytes = elem.toString().getBytes(StandardCharsets.UTF_8);
            ((VarCharVector) vec).setSafe(idx, bytes);
        }
    }

    @Override
    public void close() throws IOException {
        try {
            if (batchWriter != null) {
                commitMessage = batchWriter.commit();
                batchWriter = null;
            }
            if (allocator != null) {
                allocator.close();
                allocator = null;
            }
            LOG.info("MaxComputeJniWriter closed: writeSessionId=" + writeSessionId
                    + ", partitionSpec=" + partitionSpec
                    + ", writtenRows=" + writtenRows
                    + ", blockId=" + blockId);
        } catch (Exception e) {
            String errorMsg = "Failed to close MaxCompute arrow writer";
            LOG.error(errorMsg, e);
            throw new IOException(errorMsg, e);
        }
    }

    @Override
    public Map<String, String> getStatistics() {
        Map<String, String> stats = new HashMap<>();
        stats.put("mc_partition_spec", partitionSpec != null ? partitionSpec : "");

        // Serialize WriterCommitMessage to Base64
        if (commitMessage != null) {
            try {
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                ObjectOutputStream oos = new ObjectOutputStream(baos);
                oos.writeObject(commitMessage);
                oos.close();
                stats.put("mc_commit_message", Base64.getEncoder().encodeToString(baos.toByteArray()));
            } catch (IOException e) {
                LOG.error("Failed to serialize WriterCommitMessage", e);
            }
        }

        stats.put("counter:WrittenRows", String.valueOf(writtenRows));
        stats.put("bytes:WrittenBytes", String.valueOf(writtenBytes));
        stats.put("timer:WriteTime", String.valueOf(writeTime));
        stats.put("timer:ReadTableTime", String.valueOf(readTableTime));
        return stats;
    }
}