/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.common.write;

import java.io.IOException;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.write.PushState;
import org.apache.celeborn.common.write.PushStrategy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tracks push batches that have been sent but not yet removed, grouped by {@code hostAndPushPort},
 * and lets pushing threads wait until the configured in-flight limits are satisfied.
 *
 * <p>Request counts are always tracked. Byte sizes are tracked only when
 * {@code conf.clientPushMaxBytesSizeInFlightEnabled()} is true; otherwise the byte-tracking fields
 * are {@code null}/zero and never read.
 */
public class InFlightRequestTracker {
    private static final Logger logger = LoggerFactory.getLogger(InFlightRequestTracker.class);

    /** Overall budget for waiting on in-flight limits, in milliseconds. */
    private final long waitInflightTimeoutMs;
    /** Sleep interval between two limit checks, in milliseconds. */
    private final long delta;
    private final PushState pushState;
    private final PushStrategy pushStrategy;
    /** Source of ids handed out by {@link #nextBatchId()}. */
    private final AtomicInteger batchId = new AtomicInteger();
    /** In-flight batch ids keyed by hostAndPushPort. */
    private final ConcurrentHashMap<String, Set<Integer>> inflightBatchesPerAddress = JavaUtils.newConcurrentHashMap();
    /** In-flight bytes keyed by hostAndPushPort; {@code null} when byte tracking is disabled. */
    private final ConcurrentHashMap<String, LongAdder> inflightBytesSizePerAddress;
    /** Byte size recorded for each in-flight batch id; {@code null} when byte tracking is disabled. */
    private final ConcurrentHashMap<Integer, Integer> inflightBatchBytesSizes;
    private final int maxInFlightReqsTotal;
    private final boolean maxInFlightBytesSizeEnabled;
    /** Total in-flight byte limit; only meaningful when byte tracking is enabled. */
    private final long maxInFlightBytesSizeTotal;
    /** Per-address in-flight byte limit; only meaningful when byte tracking is enabled. */
    private final long maxInFlightBytesSizePerWorker;
    private final LongAdder totalInflightReqs = new LongAdder();
    /** Total in-flight bytes; {@code null} when byte tracking is disabled. */
    private final LongAdder totalInflightBytes;
    /** Set by {@link #cleanup()}; once true, the wait loops below return {@code false}. */
    private volatile boolean cleaned = false;

    public InFlightRequestTracker(CelebornConf conf, PushState pushState) {
        this.waitInflightTimeoutMs = conf.clientPushLimitInFlightTimeoutMs();
        this.delta = conf.clientPushLimitInFlightSleepDeltaMs();
        this.pushState = pushState;
        this.pushStrategy = PushStrategy.getStrategy(conf);
        this.maxInFlightReqsTotal = conf.clientPushMaxReqsInFlightTotal();
        this.maxInFlightBytesSizeEnabled = conf.clientPushMaxBytesSizeInFlightEnabled();
        if (this.maxInFlightBytesSizeEnabled) {
            this.inflightBytesSizePerAddress = JavaUtils.newConcurrentHashMap();
            this.inflightBatchBytesSizes = JavaUtils.newConcurrentHashMap();
            this.maxInFlightBytesSizeTotal = conf.clientPushMaxBytesSizeInFlightTotal();
            this.maxInFlightBytesSizePerWorker = conf.clientPushMaxBytesSizeInFlightPerWorker();
            this.totalInflightBytes = new LongAdder();
        } else {
            // Byte tracking disabled: every access below is guarded by maxInFlightBytesSizeEnabled.
            this.inflightBytesSizePerAddress = null;
            this.inflightBatchBytesSizes = null;
            this.maxInFlightBytesSizeTotal = 0L;
            this.maxInFlightBytesSizePerWorker = 0L;
            this.totalInflightBytes = null;
        }
    }

    /**
     * Records a batch as in flight to {@code hostAndPushPort}.
     *
     * @param batchId id of the pushed batch
     * @param batchBytesSize size of the batch in bytes (recorded only when byte tracking is enabled)
     * @param hostAndPushPort target address of the push
     */
    public void addBatch(int batchId, int batchBytesSize, String hostAndPushPort) {
        Set<Integer> batchIdSet = getBatchIdSetByAddressPair(hostAndPushPort);
        batchIdSet.add(batchId);
        totalInflightReqs.increment();
        if (maxInFlightBytesSizeEnabled) {
            getBatchBytesSizeByAddressPair(hostAndPushPort).add(batchBytesSize);
            inflightBatchBytesSizes.put(batchId, batchBytesSize);
            totalInflightBytes.add(batchBytesSize);
        }
    }

    /**
     * Marks a batch as no longer in flight and releases its recorded bytes.
     *
     * @param batchId id of the batch to remove
     * @param hostAndPushPort address the batch was pushed to
     */
    public void removeBatch(int batchId, String hostAndPushPort) {
        Set<Integer> batchIdSet = inflightBatchesPerAddress.get(hostAndPushPort);
        if (batchIdSet != null) {
            batchIdSet.remove(batchId);
        } else {
            logger.info("Batches of {} in flight is null.", hostAndPushPort);
        }
        // NOTE(review): decremented unconditionally, even if batchId was not tracked for this address.
        totalInflightReqs.decrement();
        if (maxInFlightBytesSizeEnabled) {
            // Unknown batch ids (e.g. after cleanup cleared the map) release 0 bytes.
            Integer recordedBytes = inflightBatchBytesSizes.remove(batchId);
            long releasedBytes = recordedBytes == null ? 0L : recordedBytes;
            LongAdder addressBytes = inflightBytesSizePerAddress.get(hostAndPushPort);
            if (addressBytes != null) {
                addressBytes.add(-releasedBytes);
            }
            totalInflightBytes.add(-releasedBytes);
        }
    }

    /** Forwards a successful push to {@code hostAndPushPort} to the push strategy. */
    public void onSuccess(String hostAndPushPort) {
        pushStrategy.onSuccess(hostAndPushPort);
    }

    /** Forwards a congestion-control signal for {@code hostAndPushPort} to the push strategy. */
    public void onCongestControl(String hostAndPushPort) {
        pushStrategy.onCongestControl(hostAndPushPort);
    }

    /** Returns the live set of in-flight batch ids for {@code hostAndPort}, creating it if absent. */
    public Set<Integer> getBatchIdSetByAddressPair(String hostAndPort) {
        return inflightBatchesPerAddress.computeIfAbsent(hostAndPort, pair -> ConcurrentHashMap.newKeySet());
    }

    /**
     * Returns the live in-flight byte counter for {@code hostAndPort}, creating it if absent.
     * When byte tracking is disabled, a fresh (untracked) zero counter is returned.
     */
    public LongAdder getBatchBytesSizeByAddressPair(String hostAndPort) {
        return maxInFlightBytesSizeEnabled
                ? inflightBytesSizePerAddress.computeIfAbsent(hostAndPort, id -> new LongAdder())
                : new LongAdder();
    }

    /**
     * Applies the push strategy's speed limit, then waits until the in-flight limits for
     * {@code hostAndPushPort} are satisfied, sleeping {@code delta} ms between checks for at most
     * {@code waitInflightTimeoutMs / delta} sleeps. The limits are re-checked after the final sleep,
     * so a limit that clears during that sleep is not reported as a timeout.
     *
     * @param hostAndPushPort target address of the next push
     * @return {@code true} if waiting timed out with the limits still exceeded; {@code false} if the
     *     limits were satisfied or the tracker was cleaned up
     * @throws IOException the failure recorded in {@code pushState.exception}, including an
     *     interruption while waiting
     */
    public boolean limitMaxInFlight(String hostAndPushPort) throws IOException {
        throwIfFailed();
        pushStrategy.limitPushSpeed(pushState, hostAndPushPort);
        int currentMaxReqsInFlight = pushStrategy.getCurrentMaxReqsInFlight(hostAndPushPort);
        Set<Integer> batchIdSet = getBatchIdSetByAddressPair(hostAndPushPort);
        LongAdder batchBytesSize = getBatchBytesSizeByAddressPair(hostAndPushPort);
        long remainingSleeps = waitInflightTimeoutMs / delta;
        boolean timedOut = false;
        try {
            while (true) {
                if (cleaned) {
                    return false;
                }
                if (isWithinMaxInFlight(batchIdSet, batchBytesSize, currentMaxReqsInFlight)) {
                    break;
                }
                if (remainingSleeps <= 0L) {
                    timedOut = true;
                    break;
                }
                throwIfFailed();
                Thread.sleep(delta);
                remainingSleeps--;
            }
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers; the failure is surfaced via pushState below.
            Thread.currentThread().interrupt();
            pushState.exception.set(new CelebornIOException(e));
        }
        if (timedOut) {
            if (totalInflightReqs.sum() > maxInFlightReqsTotal || batchIdSet.size() > currentMaxReqsInFlight) {
                logger.warn(
                        "After waiting for {} ms, there are still {} requests in flight (limit: {}): {} batches for hostAndPushPort {}, which exceeds the current limit {}.",
                        waitInflightTimeoutMs,
                        totalInflightReqs.sum(),
                        maxInFlightReqsTotal,
                        batchIdSet.size(),
                        hostAndPushPort,
                        currentMaxReqsInFlight);
            }
            if (maxInFlightBytesSizeEnabled
                    && (totalInflightBytes.sum() > maxInFlightBytesSizeTotal
                            || batchBytesSize.sum() > maxInFlightBytesSizePerWorker)) {
                logger.warn(
                        "After waiting for {} ms, there are still {} bytes in flight (limit: {}): {} bytes for hostAndPushPort {}, which exceeds the current limit {}.",
                        waitInflightTimeoutMs,
                        totalInflightBytes.sum(),
                        maxInFlightBytesSizeTotal,
                        batchBytesSize.sum(),
                        hostAndPushPort,
                        maxInFlightBytesSizePerWorker);
            }
        }
        throwIfFailed();
        return timedOut;
    }

    /**
     * Waits until no requests are in flight, sleeping {@code delta} ms between checks for at most
     * {@code waitInflightTimeoutMs / delta} sleeps. The count is re-checked after the final sleep.
     *
     * @return {@code true} if waiting timed out with requests still in flight; {@code false} if the
     *     count reached zero or the tracker was cleaned up
     * @throws IOException the failure recorded in {@code pushState.exception}, including an
     *     interruption while waiting
     */
    public boolean limitZeroInFlight() throws IOException {
        throwIfFailed();
        long remainingSleeps = waitInflightTimeoutMs / delta;
        boolean timedOut = false;
        try {
            while (true) {
                if (cleaned) {
                    return false;
                }
                if (totalInflightReqs.sum() == 0L) {
                    break;
                }
                if (remainingSleeps <= 0L) {
                    timedOut = true;
                    break;
                }
                throwIfFailed();
                Thread.sleep(delta);
                remainingSleeps--;
            }
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers; the failure is surfaced via pushState below.
            Thread.currentThread().interrupt();
            pushState.exception.set(new CelebornIOException(e));
        }
        if (timedOut) {
            String perAddress = inflightBatchesPerAddress.entrySet().stream()
                    .filter(entry -> !entry.getValue().isEmpty())
                    .map(entry -> entry.getValue().size() + " batches for hostAndPushPort " + entry.getKey())
                    .collect(Collectors.joining(", ", "[", "]"));
            logger.error(
                    "After waiting for {} ms, there are still {} requests in flight: {}, which exceeds the current limit 0.",
                    waitInflightTimeoutMs,
                    totalInflightReqs.sum(),
                    perAddress);
        }
        throwIfFailed();
        return timedOut;
    }

    /**
     * Returns how many more batches may be pushed to {@code hostAndPushPort} under the push
     * strategy's current per-address limit (negative if the limit is already exceeded).
     */
    public int remainingAllowPushes(String hostAndPushPort) {
        return pushStrategy.getCurrentMaxReqsInFlight(hostAndPushPort)
                - getBatchIdSetByAddressPair(hostAndPushPort).size();
    }

    /** Returns the next batch id (1, 2, 3, ... for a fresh tracker). */
    protected int nextBatchId() {
        return batchId.incrementAndGet();
    }

    /**
     * Marks the tracker as cleaned and clears per-address and per-batch bookkeeping and the push
     * strategy. The total request/byte counters are only logged, not reset.
     */
    public void cleanup() {
        logger.info(
                "Cleanup {} requests and {} batches in flight.",
                totalInflightReqs.sum(),
                inflightBatchesPerAddress.values().stream().mapToInt(Set::size).sum());
        cleaned = true;
        inflightBatchesPerAddress.clear();
        pushStrategy.clear();
        if (maxInFlightBytesSizeEnabled) {
            logger.info("Cleanup {} bytes in flight.", totalInflightBytes.sum());
            inflightBytesSizePerAddress.clear();
            inflightBatchBytesSizes.clear();
        }
    }

    /**
     * Returns whether a push may proceed: either both request-count limits (total and per-address)
     * are satisfied, or, when byte tracking is enabled, both byte-size limits are satisfied.
     */
    private boolean isWithinMaxInFlight(Set<Integer> batchIdSet, LongAdder batchBytesSize, int currentMaxReqsInFlight) {
        boolean requestsWithinLimit = totalInflightReqs.sum() <= maxInFlightReqsTotal
                && batchIdSet.size() <= currentMaxReqsInFlight;
        if (requestsWithinLimit) {
            return true;
        }
        return maxInFlightBytesSizeEnabled
                && totalInflightBytes.sum() <= maxInFlightBytesSizeTotal
                && batchBytesSize.sum() <= maxInFlightBytesSizePerWorker;
    }

    /** Rethrows the failure recorded in {@code pushState.exception}, reading it exactly once. */
    private void throwIfFailed() throws IOException {
        IOException failure = pushState.exception.get();
        if (failure != null) {
            throw failure;
        }
    }
}

