AgentBoundedBatchTask.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.task;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.ClientPool;
import org.apache.doris.common.Config;
import org.apache.doris.common.FeConstants;
import org.apache.doris.common.Pair;
import org.apache.doris.common.ThriftUtils;
import org.apache.doris.common.util.TimeUtils;
import org.apache.doris.metric.MetricRepo;
import org.apache.doris.system.Backend;
import org.apache.doris.thrift.BackendService;
import org.apache.doris.thrift.TAgentTaskRequest;
import org.apache.doris.thrift.TNetworkAddress;
import org.apache.doris.thrift.TTaskType;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
/*
* Like AgentBatchTask, but this class is used to submit tasks to BE in a bounded way, to avoid BE OOM.
*/
public class AgentBoundedBatchTask extends AgentBatchTask {
private static final Logger LOG = LogManager.getLogger(AgentBoundedBatchTask.class);
private static final int RPC_MAX_RETRY_TIMES = 3;
private int taskConcurrency;
private Map<Long, Integer> backendIdToConsumedTaskIndex;
private int beUnavailableMaxLostTimeSecond;
/**
* NOTE:
* this class is used to submit tasks to BE in a bounded way,
* and it will automatically add to AgentTaskQueue.
*
* @param batchSize the max number of tasks to submit to BE in one time
* @param taskConcurrency the max number of tasks to submit to BE in one time
*/
public AgentBoundedBatchTask(int batchSize, int taskConcurrency) {
super(batchSize);
this.taskConcurrency = taskConcurrency;
this.backendIdToConsumedTaskIndex = new HashMap<>();
this.beUnavailableMaxLostTimeSecond = Config.agent_task_be_unavailable_heartbeat_timeout_second;
}
@Override
public void addTask(AgentTask agentTask) {
if (agentTask == null) {
return;
}
long backendId = agentTask.getBackendId();
if (backendIdToTasks.containsKey(backendId)) {
List<AgentTask> tasks = backendIdToTasks.get(backendId);
tasks.add(agentTask);
} else {
List<AgentTask> tasks = new ArrayList<>();
tasks.add(agentTask);
backendIdToTasks.put(backendId, tasks);
}
}
@Override
public void run() {
int taskNum = getTaskNum();
LOG.info("begin to submit tasks to BE. total {} tasks, be task concurrency: {}", taskNum, taskConcurrency);
boolean submitFinished = false;
while (getSubmitTaskNum() < taskNum && !submitFinished) {
for (Long backendId : backendIdToTasks.keySet()) {
int consumedTaskIndex = backendIdToConsumedTaskIndex.getOrDefault(backendId, 0);
if (consumedTaskIndex >= backendIdToTasks.get(backendId).size()) {
LOG.info("backend {} has submitted all tasks, taskNum: {}",
backendId, backendIdToTasks.get(backendId).size());
continue;
}
boolean ok = false;
String errMsg = "";
Backend backend = null;
List<AgentTask> tasks = new ArrayList<>();
List<TAgentTaskRequest> agentTaskRequests = new ArrayList<>();
try {
backend = Env.getCurrentSystemInfo().getBackend(backendId);
tasks = this.backendIdToTasks.getOrDefault(backendId, new ArrayList<>());
if (backend == null) {
errMsg = String.format("backend %d is not found", backendId);
throw new RuntimeException(errMsg);
}
if (!backend.isAlive()) {
errMsg = String.format("backend %d is not alive", backendId);
if (System.currentTimeMillis() - backend.getLastUpdateMs()
> beUnavailableMaxLostTimeSecond * 1000) {
errMsg = String.format("backend %d is not alive too long, last update time: %s",
backendId, TimeUtils.longToTimeString(backend.getLastUpdateMs()));
throw new RuntimeException(errMsg);
}
continue;
}
int runningTaskNum = getRunningTaskNum(backendId);
LOG.info("backend {} has {} running tasks, task concurrency: {}",
backendId, runningTaskNum, taskConcurrency);
int index = consumedTaskIndex;
for (; index < tasks.size()
&& index < consumedTaskIndex + taskConcurrency - runningTaskNum; index++) {
agentTaskRequests.add(toAgentTaskRequest(tasks.get(index)));
// add to AgentTaskQueue
AgentTaskQueue.addTask(tasks.get(index));
if (agentTaskRequests.size() >= batchSize) {
submitTasks(backend, agentTaskRequests);
agentTaskRequests.clear();
}
}
submitTasks(backend, agentTaskRequests);
backendIdToConsumedTaskIndex.put(backendId, index);
LOG.info("submit task to backend {} finished, already submitted task num: {}/{}",
backendId, index, tasks.size());
ok = true;
} catch (Exception e) {
LOG.warn("task exec error. backend[{}]", backendId, e);
errMsg = String.format("task exec error: %s. backend[%d]", e.getMessage(), backendId);
if (!agentTaskRequests.isEmpty() && errMsg.contains("Broken pipe")) {
// Log the task binary message size and the max task type, to help debug the
// large thrift message size issue.
List<Pair<TTaskType, Long>> taskTypeAndSize = agentTaskRequests.stream()
.map(req -> Pair.of(req.getTaskType(), ThriftUtils.getBinaryMessageSize(req)))
.collect(Collectors.toList());
Pair<TTaskType, Long> maxTaskTypeAndSize = taskTypeAndSize.stream()
.max((p1, p2) -> Long.compare(p1.value(), p2.value()))
.orElse(null); // taskTypeAndSize is not empty
TTaskType maxType = maxTaskTypeAndSize.first;
long maxSize = maxTaskTypeAndSize.second;
long totalSize = taskTypeAndSize.stream().map(Pair::value).reduce(0L, Long::sum);
LOG.warn("submit {} tasks to backend[{}], total size: {}, max task type: {}, size: {}. msg: {}",
agentTaskRequests.size(), backendId, totalSize, maxType, maxSize, e.getMessage());
}
} finally {
if (!ok) {
submitFinished = true;
LOG.warn("submit task to backend {} failed, errMsg: {}, cancel all tasks", backendId, errMsg);
cancelAllTasks(errMsg);
}
}
}
try {
TimeUnit.SECONDS.sleep(3);
} catch (InterruptedException e) {
String errMsg = "submit task thread is interrupted";
LOG.warn(errMsg, e);
submitFinished = true;
cancelAllTasks(errMsg);
Thread.currentThread().interrupt();
break;
}
}
}
private static void submitTasks(Backend backend, List<TAgentTaskRequest> agentTaskRequests) throws Exception {
long start = System.currentTimeMillis();
if (agentTaskRequests.isEmpty()) {
return;
}
if (LOG.isDebugEnabled()) {
long size = agentTaskRequests.stream()
.map(ThriftUtils::getBinaryMessageSize)
.reduce(0L, Long::sum);
TTaskType firstTaskType = agentTaskRequests.get(0).getTaskType();
LOG.debug("submit {} tasks to backend[{}], total size: {}, first task type: {}",
agentTaskRequests.size(), backend.getId(), size, firstTaskType);
for (TAgentTaskRequest req : agentTaskRequests) {
LOG.debug("send task: type[{}], backend[{}], signature[{}]",
req.getTaskType(), backend.getId(), req.getSignature());
}
}
MetricRepo.COUNTER_AGENT_TASK_REQUEST_TOTAL.increase(1L);
BackendService.Client client = null;
TNetworkAddress address = null;
// create AgentClient
String host = FeConstants.runningUnitTest ? "127.0.0.1" : backend.getHost();
address = new TNetworkAddress(host, backend.getBePort());
long backendId = backend.getId();
boolean ok = false;
for (int attempt = 1; attempt <= RPC_MAX_RETRY_TIMES; attempt++) {
try {
if (client == null) {
// borrow new client when previous client request failed
client = ClientPool.backendPool.borrowObject(address);
}
client.submitTasks(agentTaskRequests);
ok = true;
break;
} catch (Exception e) {
if (attempt == RPC_MAX_RETRY_TIMES) {
LOG.warn("submit task to agent failed. backend[{}], request size: {}, elapsed:{} ms error: {}",
backendId, agentTaskRequests.size(), System.currentTimeMillis() - start,
e.getMessage());
throw e;
} else {
LOG.warn("submit task attempt {} failed, retrying... backend[{}], error: {}",
attempt, backendId, e.getMessage());
try {
Thread.sleep(200);
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
}
}
} finally {
if (ok) {
ClientPool.backendPool.returnObject(address, client);
} else {
ClientPool.backendPool.invalidateObject(address, client);
client = null;
}
}
}
}
private int getSubmitTaskNum() {
return backendIdToConsumedTaskIndex.values().stream()
.mapToInt(Integer::intValue)
.sum();
}
private int getFinishedTaskNum(long backendId) {
int count = 0;
for (AgentTask agentTask : this.backendIdToTasks.get(backendId)) {
if (agentTask.isFinished()) {
count++;
}
}
return count;
}
private int getRunningTaskNum(long backendId) {
int count = 0;
List<AgentTask> tasks = backendIdToTasks.get(backendId);
int consumedTaskIndex = backendIdToConsumedTaskIndex.getOrDefault(backendId, 0);
for (int i = 0; i < consumedTaskIndex; i++) {
if (!tasks.get(i).isFinished) {
count++;
}
}
return count;
}
private void cancelAllTasks(String errMsg) {
for (List<AgentTask> beTasks : backendIdToTasks.values()) {
for (AgentTask task : beTasks) {
task.failedWithMsg(errMsg);
}
}
}
}