MaxComputeJniScanner.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.maxcompute;

import org.apache.doris.common.jni.JniScanner;
import org.apache.doris.common.jni.vec.ColumnType;

import com.aliyun.odps.Odps;
import com.aliyun.odps.account.Account;
import com.aliyun.odps.account.AliyunAccount;
import com.aliyun.odps.table.configuration.CompressionCodec;
import com.aliyun.odps.table.configuration.ReaderOptions;
import com.aliyun.odps.table.configuration.RestOptions;
import com.aliyun.odps.table.enviroment.Credentials;
import com.aliyun.odps.table.enviroment.EnvironmentSettings;
import com.aliyun.odps.table.read.SplitReader;
import com.aliyun.odps.table.read.TableBatchReadSession;
import com.aliyun.odps.table.read.split.InputSplit;
import com.aliyun.odps.table.read.split.impl.IndexedInputSplit;
import com.aliyun.odps.table.read.split.impl.RowRangeInputSplit;
import com.google.common.base.Strings;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.log4j.Logger;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.time.ZoneId;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * MaxCompute JniScanner. BE will read data from the scanner object.
 */
public class MaxComputeJniScanner extends JniScanner {
    static {
        // Set `NullCheckingForGet.NULL_CHECKING_ENABLED` to false.
        // We always call isNull() before calling getXXX(), so disabling this
        // property skips Arrow's redundant null check inside getXXX().
        System.setProperty("arrow.enable_null_check_for_get", "false");
    }

    private static final Logger LOG = Logger.getLogger(MaxComputeJniScanner.class);

    private static final String ACCESS_KEY = "access_key";
    private static final String SECRET_KEY = "secret_key";
    private static final String ENDPOINT = "endpoint";
    private static final String QUOTA = "quota";
    private static final String PROJECT = "project";
    private static final String TABLE = "table";

    private static final String START_OFFSET = "start_offset";
    private static final String SPLIT_SIZE = "split_size";
    private static final String SESSION_ID = "session_id";
    private static final String SCAN_SERIALIZER = "scan_serializer";
    private static final String TIME_ZONE = "time_zone";

    private static final String CONNECT_TIMEOUT = "connect_timeout";
    private static final String READ_TIMEOUT = "read_timeout";
    private static final String RETRY_COUNT = "retry_count";

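    // How the scan session addresses this scanner's portion of the table:
    // BYTE_SIZE splits are identified by a split index, ROW_OFFSET splits by
    // a start row and row count.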
    private enum SplitType {
        BYTE_SIZE,
        ROW_OFFSET
    }

    private SplitType splitType;
    private TableBatchReadSession scan;
    public String sessionId;

    private String project;
    private String table;

    private SplitReader<VectorSchemaRoot> currentSplitReader;
    private MaxComputeColumnValue columnValue;

    private Map<String, Integer> readColumnsToId;

    private long startOffset = -1L;
    private long splitSize = -1L;
    public EnvironmentSettings settings;
    public ZoneId timeZone;

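    /**
     * Builds the scanner from the params map passed down by the BE: connection
     * properties (access_key, secret_key, endpoint, quota), the table identity
     * (project, table), the split location, and a base64-serialized
     * TableBatchReadSession.
     */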
    public MaxComputeJniScanner(int batchSize, Map<String, String> params) {
        String[] requiredFields = params.get("required_fields").split(",");
        String[] types = params.get("columns_types").split("#");
        ColumnType[] columnTypes = new ColumnType[types.length];
        for (int i = 0; i < types.length; i++) {
            columnTypes[i] = ColumnType.parseType(requiredFields[i], types[i]);
        }
        initTableInfo(columnTypes, requiredFields, batchSize);

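        // splitSize == -1 marks an indexed (byte-size) split, in which case
        // startOffset carries the split index; otherwise the pair describes a
        // row range (see open()).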
        if (!Strings.isNullOrEmpty(params.get(START_OFFSET))
                && !Strings.isNullOrEmpty(params.get(SPLIT_SIZE))) {
            startOffset = Long.parseLong(params.get(START_OFFSET));
            splitSize = Long.parseLong(params.get(SPLIT_SIZE));
            if (splitSize == -1) {
                splitType = SplitType.BYTE_SIZE;
            } else {
                splitType = SplitType.ROW_OFFSET;
            }
        }

        String accessKey = Objects.requireNonNull(params.get(ACCESS_KEY), "required property '" + ACCESS_KEY + "'.");
        String secretKey = Objects.requireNonNull(params.get(SECRET_KEY), "required property '" + SECRET_KEY + "'.");
        String endpoint = Objects.requireNonNull(params.get(ENDPOINT), "required property '" + ENDPOINT + "'.");
        String quota = Objects.requireNonNull(params.get(QUOTA), "required property '" + QUOTA + "'.");
        String scanSerializer = Objects.requireNonNull(params.get(SCAN_SERIALIZER),
                "required property '" + SCAN_SERIALIZER + "'.");
        project = Objects.requireNonNull(params.get(PROJECT), "required property '" + PROJECT + "'.");
        table = Objects.requireNonNull(params.get(TABLE), "required property '" + TABLE + "'.");
        sessionId = Objects.requireNonNull(params.get(SESSION_ID), "required property '" + SESSION_ID + "'.");
        String timeZoneName = Objects.requireNonNull(params.get(TIME_ZONE), "required property '" + TIME_ZONE + "'.");
        try {
            timeZone = ZoneId.of(timeZoneName);
        } catch (Exception e) {
            LOG.warn("Failed to set time zone '" + timeZoneName + "', falling back to the system default.", e);
            timeZone = ZoneId.systemDefault();
        }


        Account account = new AliyunAccount(accessKey, secretKey);
        Odps odps = new Odps(account);

        odps.setDefaultProject(project);
        odps.setEndpoint(endpoint);

        Credentials credentials = Credentials.newBuilder().withAccount(odps.getAccount())
                .withAppAccount(odps.getAppAccount()).build();


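        // REST client tuning; the defaults below apply when the BE does not
        // pass the corresponding properties.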
        int connectTimeout = 10; // 10s
        if (!Strings.isNullOrEmpty(params.get(CONNECT_TIMEOUT))) {
            connectTimeout = Integer.parseInt(params.get(CONNECT_TIMEOUT));
        }

        int readTimeout = 120; // 120s
        if (!Strings.isNullOrEmpty(params.get(READ_TIMEOUT))) {
            readTimeout = Integer.parseInt(params.get(READ_TIMEOUT));
        }

        int retryTimes = 4; // 4 times
        if (!Strings.isNullOrEmpty(params.get(RETRY_COUNT))) {
            retryTimes = Integer.parseInt(params.get(RETRY_COUNT));
        }

        RestOptions restOptions = RestOptions.newBuilder()
                .withConnectTimeout(connectTimeout)
                .withReadTimeout(readTimeout)
                .withRetryTimes(retryTimes).build();

        settings = EnvironmentSettings.newBuilder()
                .withCredentials(credentials)
                .withServiceEndpoint(odps.getEndpoint())
                .withQuotaName(quota)
                .withRestOptions(restOptions)
                .build();

        try {
            scan = (TableBatchReadSession) deserialize(scanSerializer);
        } catch (Exception e) {
            String errorMsg = "Failed to deserialize table batch read session.";
            LOG.warn(errorMsg, e);
            throw new IllegalArgumentException(errorMsg, e);
        }
    }


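    // Records each required field's position so that arrow field vectors can be
    // routed back to the matching output column in readVectors().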
    @Override
    protected void initTableInfo(ColumnType[] requiredTypes, String[] requiredFields, int batchSize) {
        super.initTableInfo(requiredTypes, requiredFields, batchSize);
        readColumnsToId = new HashMap<>();
        for (int i = 0; i < fields.length; i++) {
            if (!Strings.isNullOrEmpty(fields[i])) {
                readColumnsToId.put(fields[i], i);
            }
        }
    }

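    // Opens an arrow reader over this scanner's split. ZSTD compression keeps
    // transfer size down; batch reuse is enabled, so each batch must be copied
    // out before the reader advances (see readVectors()).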
    @Override
    public void open() throws IOException {
        try {
            InputSplit split;
            if (splitType == SplitType.BYTE_SIZE) {
                split = new IndexedInputSplit(sessionId, (int) startOffset);
            } else {
                split = new RowRangeInputSplit(sessionId, startOffset, splitSize);
            }

            currentSplitReader = scan.createArrowReader(split, ReaderOptions.newBuilder().withSettings(settings)
                    .withCompressionCodec(CompressionCodec.ZSTD)
                    .withReuseBatch(true)
                    .build());

        } catch (Exception e) {
            String errorMsg = "MaxComputeJniScanner Failed to open table batch read session.";
            LOG.warn(errorMsg, e);
            close();
            throw new IOException(errorMsg, e);
        }
    }

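    // Clears per-scan state. open() also routes failures through here, so the
    // method must tolerate a partially initialized scanner.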
    @Override
    public void close() throws IOException {
        startOffset = -1;
        splitSize = -1;
        settings = null;
        scan = null;
        readColumnsToId.clear();
        // Release the arrow reader when the scan ends before the split is drained.
        if (currentSplitReader != null) {
            try {
                currentSplitReader.close();
            } catch (Exception e) {
                throw new IOException("Failed to close the current split reader.", e);
            } finally {
                currentSplitReader = null;
            }
        }
    }

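    // Called by the BE for each batch; returns the number of rows appended,
    // or 0 once the split is exhausted.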
    @Override
    protected int getNext() throws IOException {
        if (currentSplitReader == null) {
            return 0;
        }
        columnValue = new MaxComputeColumnValue();
        columnValue.setTimeZone(timeZone);
        int expectedRows = batchSize;
        return readVectors(expectedRows);
    }

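    // Pulls arrow batches from the split reader and appends them column by
    // column until expectedRows rows are buffered or the split runs out.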
    private int readVectors(int expectedRows) throws IOException {
        int curReadRows = 0;
        while (curReadRows < expectedRows) {
            try {
                if (!currentSplitReader.hasNext()) {
                    currentSplitReader.close();
                    currentSplitReader = null;
                    break;
                }
            } catch (Exception e) {
                String errorMsg = "MaxComputeJniScanner readVectors hasNext failed.";
                LOG.warn(errorMsg, e);
                throw new IOException(errorMsg, e);
            }

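            // The reader was opened with withReuseBatch(true), so the same
            // VectorSchemaRoot may be recycled across get() calls: copy every
            // value out via appendData() before the next hasNext()/get().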
            try {
                VectorSchemaRoot data = currentSplitReader.get();
                if (data.getRowCount() == 0) {
                    break;
                }

                List<FieldVector> fieldVectors = data.getFieldVectors();
                int batchRows = 0;
                for (FieldVector column : fieldVectors) {
                    Integer readColumnId = readColumnsToId.get(column.getName());
                    batchRows = column.getValueCount();
                    if (readColumnId == null) {
                        continue;
                    }
                    columnValue.reset(column);
                    for (int j = 0; j < batchRows; j++) {
                        columnValue.setColumnIdx(j);
                        appendData(readColumnId, columnValue);
                    }
                }
                curReadRows += batchRows;
            } catch (Exception e) {
                String errorMsg = String.format("MaxComputeJniScanner Fail to read arrow data. "
                        + "curReadRows = {}, expectedRows = {}", curReadRows, expectedRows);
                LOG.warn(errorMsg, e);
                throw new RuntimeException(errorMsg, e);
            }
        }
        return curReadRows;
    }

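    // The read session arrives base64-encoded via standard Java serialization;
    // reverse both steps here.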
    private static Object deserialize(String serializedString) throws IOException, ClassNotFoundException {
        byte[] serializedBytes = Base64.getDecoder().decode(serializedString);
        // try-with-resources ensures the streams are closed even when
        // readObject() throws.
        try (ObjectInputStream objectInputStream =
                new ObjectInputStream(new ByteArrayInputStream(serializedBytes))) {
            return objectInputStream.readObject();
        }
    }
}