/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.service.deploy.worker.storage;

import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.Channel;
import java.io.IOException;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import javax.annotation.concurrent.GuardedBy;
import org.apache.celeborn.common.meta.DiskFileInfo;
import org.apache.celeborn.common.meta.FileInfo;
import org.apache.celeborn.common.meta.MapFileMeta;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.ThreadUtils;
import org.apache.celeborn.service.deploy.worker.memory.MemoryManager;
import org.apache.celeborn.service.deploy.worker.storage.MapPartitionData;
import org.apache.celeborn.service.deploy.worker.storage.segment.SegmentMapPartitionData;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tracks credit-based read streams that are served from map-partition files on this worker.
 *
 * <p>Every registered stream receives a numeric id and a {@link StreamState}. Streams that read the
 * same file share a single {@link MapPartitionData}, which stays in {@code activeMapPartitions}
 * until its last reader is released. Stream clean-up is deferred: ids are pushed onto a
 * {@link DelayQueue} and drained by a single daemon recycler thread that is started lazily the first
 * time a stream is recycled.
 */
public class CreditStreamManager {
    private static final Logger logger = LoggerFactory.getLogger(CreditStreamManager.class);

    /** Next stream id to hand out; seeded with a random offset in the constructor. */
    private final AtomicLong nextStreamId;
    /** Live streams keyed by stream id. */
    private final ConcurrentHashMap<Long, StreamState> streams;
    /** One shared partition reader state per file that currently has at least one stream. */
    private final ConcurrentHashMap<FileInfo, MapPartitionData> activeMapPartitions;
    /** Fetcher executors handed to every MapPartitionData, keyed by a String (presumably mount point — TODO confirm). */
    private final ConcurrentHashMap<String, ExecutorService> storageFetcherPool = JavaUtils.newConcurrentHashMap();
    private final int minReadBuffers;
    private final int maxReadBuffers;
    private final int threadsPerMountPoint;
    private final int minBuffersToTriggerRead;
    /** Streams awaiting clean-up; an element only becomes takeable once its delay has elapsed. */
    private final BlockingQueue<DelayedStreamId> recycleStreamIds = new DelayQueue<DelayedStreamId>();
    /** Lazily created single-thread executor that drains {@link #recycleStreamIds}. */
    @GuardedBy("lock")
    private volatile ExecutorService recycleThread;
    private final Object lock = new Object();

    /**
     * Creates the manager and registers it with the global {@link MemoryManager}.
     *
     * @param minReadBuffers minimum read buffers passed to each MapPartitionData
     * @param maxReadBuffers maximum read buffers passed to each MapPartitionData
     * @param threadsPerMountPoint fetcher threads per mount point passed to each MapPartitionData
     * @param minBuffersToTriggerRead buffer threshold passed to each MapPartitionData
     */
    public CreditStreamManager(int minReadBuffers, int maxReadBuffers, int threadsPerMountPoint, int minBuffersToTriggerRead) {
        // Random starting point for stream ids (presumably to make ids from different manager
        // instances unlikely to collide — TODO confirm intent).
        this.nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000L);
        this.streams = JavaUtils.newConcurrentHashMap();
        this.activeMapPartitions = JavaUtils.newConcurrentHashMap();
        this.minReadBuffers = minReadBuffers;
        this.maxReadBuffers = maxReadBuffers;
        this.threadsPerMountPoint = threadsPerMountPoint;
        this.minBuffersToTriggerRead = minBuffersToTriggerRead;
        // `this` escapes here; every field has already been assigned above.
        MemoryManager.instance().setCreditStreamManager(this);
        logger.debug("Initialize buffer stream manager with {} {} {}", minReadBuffers, maxReadBuffers, threadsPerMountPoint);
    }

    /**
     * Registers a new stream over {@code fileInfo} and grants it its initial credit.
     *
     * <p>The first stream for a file creates the file's {@link MapPartitionData}; later streams
     * reuse it. Stream state and the partition reader are set up inside
     * {@link ConcurrentHashMap#compute}, so they are atomic with respect to that file's entry.
     *
     * @param notifyStreamHandlerCallback invoked with the new stream id before credit is added
     * @param channel channel the stream is served over
     * @param shuffleKey shuffle key recorded in the stream state
     * @param initialCredit credit granted right after registration (ignored if not positive)
     * @param startSubIndex first sub-partition index passed to the partition reader
     * @param endSubIndex last sub-partition index passed to the partition reader
     * @param fileInfo file to read; its file meta is expected to be a {@link MapFileMeta}
     * @return the newly allocated stream id
     * @throws IOException if creating the MapPartitionData for the file fails
     */
    public long registerStream(Consumer<Long> notifyStreamHandlerCallback, Channel channel, String shuffleKey, int initialCredit, int startSubIndex, int endSubIndex, DiskFileInfo fileInfo) throws IOException {
        long streamId = nextStreamId.getAndIncrement();
        logger.debug("Register stream start from {}, streamId: {}, fileInfo: {}", channel.remoteAddress(), streamId, fileInfo);
        // The compute function cannot throw checked exceptions, so a creation failure is
        // smuggled out through this reference and rethrown below.
        AtomicReference<IOException> creationFailure = new AtomicReference<IOException>();
        MapPartitionData mapPartitionData = activeMapPartitions.compute(fileInfo, (k, existing) -> {
            MapPartitionData partitionData = existing;
            if (partitionData == null) {
                try {
                    partitionData = createMapPartitionData(fileInfo);
                } catch (IOException e) {
                    creationFailure.set(e);
                    // Returning null leaves no mapping for this file.
                    return null;
                }
            }
            initializeStreamStateAndPartitionReader(channel, shuffleKey, startSubIndex, endSubIndex, fileInfo, streamId, partitionData);
            return partitionData;
        });
        IOException failure = creationFailure.get();
        if (failure != null) {
            throw failure;
        }
        mapPartitionData.tryRequestBufferOrRead();
        notifyStreamHandlerCallback.accept(streamId);
        addCredit(initialCredit, streamId);
        logger.debug("Register stream streamId: {}, fileInfo: {}", streamId, fileInfo);
        return streamId;
    }

    /**
     * Builds the partition data for a file that has no active streams yet, choosing the segment
     * variant when the file meta reports segment-granularity visibility.
     */
    private MapPartitionData createMapPartitionData(DiskFileInfo fileInfo) throws IOException {
        MapFileMeta fileMeta = (MapFileMeta) fileInfo.getFileMeta();
        if (fileMeta.isSegmentGranularityVisible()) {
            return new SegmentMapPartitionData(minReadBuffers, maxReadBuffers, storageFetcherPool, threadsPerMountPoint, fileInfo, id -> recycleStream((long) id), minBuffersToTriggerRead);
        }
        return new MapPartitionData(minReadBuffers, maxReadBuffers, storageFetcherPool, threadsPerMountPoint, fileInfo, id -> recycleStream((long) id), minBuffersToTriggerRead);
    }

    /** Records the stream's state and sets up its reader on the shared partition data. */
    private void initializeStreamStateAndPartitionReader(Channel channel, String shuffleKey, int startSubIndex, int endSubIndex, FileInfo fileInfo, long streamId, MapPartitionData mapPartitionData) {
        int bufferSize = ((MapFileMeta) fileInfo.getFileMeta()).getBufferSize();
        streams.put(streamId, new StreamState(channel, shuffleKey, bufferSize, mapPartitionData));
        mapPartitionData.setupDataPartitionReader(startSubIndex, endSubIndex, streamId, channel);
    }

    /**
     * Adds reader credit on a best-effort basis: failures are logged with their cause and not
     * propagated. Non-positive credit and a null partition are ignored.
     */
    private void addCredit(MapPartitionData mapPartitionData, int numCredit, long streamId) {
        logger.debug("streamId: {}, add credit: {}", streamId, numCredit);
        try {
            if (mapPartitionData != null && numCredit > 0) {
                mapPartitionData.addReaderCredit(numCredit, streamId);
            }
        } catch (Throwable e) {
            logger.error("streamId: {}, failed to add credit: {}", streamId, numCredit, e);
        }
    }

    /**
     * Forwards a required-segment notification. Only {@link SegmentMapPartitionData} supports it;
     * anything else (including null) is logged and ignored. Failures are logged and rethrown.
     */
    private void notifyRequiredSegment(MapPartitionData mapPartitionData, int requiredSegmentId, long streamId, int subPartitionId) {
        logger.debug("Receive RequiredSegment from client, streamId: {}, requiredSegmentId: {}, subPartitionId: {}", streamId, requiredSegmentId, subPartitionId);
        try {
            if (mapPartitionData instanceof SegmentMapPartitionData) {
                ((SegmentMapPartitionData) mapPartitionData).notifyRequiredSegmentId(requiredSegmentId, streamId, subPartitionId);
            } else {
                logger.warn("Only non-null SegmentMapPartitionData is expected for notifyRequiredSegment.");
            }
        } catch (Throwable e) {
            logger.error("Fail to notify segmentId {} for stream {}.", requiredSegmentId, streamId, e);
            throw e;
        }
    }

    /**
     * Grants {@code numCredit} additional credit to a registered stream.
     * Unknown streams are logged and ignored.
     */
    public void addCredit(int numCredit, long streamId) {
        // Single lookup: a containsKey/get pair could NPE if the stream is cleaned up in between.
        StreamState streamState = streams.get(streamId);
        if (streamState == null) {
            logger.warn("Ignore AddCredit from stream {}, numCredit {}.", streamId, numCredit);
            return;
        }
        addCredit(streamState.getMapPartitionData(), numCredit, streamId);
    }

    /**
     * Tells the stream's partition data which segment the client needs next.
     * Unknown streams are logged and ignored.
     */
    public void notifyRequiredSegment(int requiredSegmentId, long streamId, int subPartitionId) {
        StreamState streamState = streams.get(streamId);
        if (streamState == null) {
            logger.warn("Ignore RequiredSegment from stream {}, subPartition {}, segmentId {}.", streamId, subPartitionId, requiredSegmentId);
            return;
        }
        notifyRequiredSegment(streamState.getMapPartitionData(), requiredSegmentId, streamId, subPartitionId);
    }

    /** Schedules clean-up of every stream associated with {@code channel} (compared by identity). */
    public void connectionTerminated(Channel channel) {
        for (Map.Entry<Long, StreamState> entry : streams.entrySet()) {
            if (entry.getValue().getAssociatedChannel() == channel) {
                logger.info("connection closed, clean streamId: {}", entry.getKey());
                recycleStream(entry.getKey());
            }
        }
    }

    /** Schedules clean-up of a stream the client reported as finished. */
    public void notifyStreamEndByClient(long streamId) {
        recycleStream(streamId);
    }

    /**
     * Enqueues {@code streamId} for delayed clean-up and makes sure the recycler thread is running.
     */
    public void recycleStream(long streamId) {
        recycleStreamIds.add(new DelayedStreamId(streamId));
        startRecycleThread();
    }

    @VisibleForTesting
    public int numRecycleStreams() {
        return recycleStreamIds.size();
    }

    @VisibleForTesting
    public ConcurrentHashMap<Long, StreamState> getStreams() {
        return streams;
    }

    /** Returns the shuffle key of a registered stream, or {@code null} if the stream is unknown. */
    public String getStreamShuffleKey(Long streamId) {
        StreamState streamState = streams.get(streamId);
        return streamState == null ? null : streamState.getShuffleKey();
    }

    /** Starts the recycler thread exactly once. */
    private void startRecycleThread() {
        // Fast path: the executor is never replaced once set, so after start-up callers can skip
        // the lock entirely (field is volatile).
        if (recycleThread != null) {
            return;
        }
        synchronized (lock) {
            if (recycleThread == null) {
                recycleThread = ThreadUtils.newDaemonSingleThreadExecutor("worker-credit-stream-manager-recycler");
                recycleThread.submit(this::runRecycleLoop);
                logger.info("start stream recycle thread");
            }
        }
    }

    /**
     * Body of the recycler thread: forever takes expired ids and cleans them up.
     *
     * <p>Every {@link Throwable}, including {@link InterruptedException}, is logged and the loop
     * continues. This keeps the recycler alive; re-setting the interrupt flag here would instead
     * make {@code take()} fail immediately on every iteration.
     */
    private void runRecycleLoop() {
        while (true) {
            try {
                DelayedStreamId delayedStreamId = recycleStreamIds.take();
                cleanResource(delayedStreamId.streamId);
            } catch (Throwable e) {
                logger.warn(e.getMessage(), e);
            }
        }
    }

    /**
     * Releases a stream's reader and, once its partition has no readers left, closes the
     * partition and drops it from {@code activeMapPartitions}. If the reader cannot be released
     * yet, the stream is re-queued for another delayed attempt. Unknown streams are ignored.
     */
    public void cleanResource(Long streamId) {
        logger.debug("received clean stream: {}", streamId);
        // Single lookup avoids an NPE if the stream disappears between containsKey and get.
        StreamState streamState = streams.get(streamId);
        if (streamState == null) {
            return;
        }
        MapPartitionData mapPartitionData = streamState.getMapPartitionData();
        if (mapPartitionData == null) {
            return;
        }
        if (!mapPartitionData.releaseReader(streamId)) {
            logger.debug("retry clean stream: {}", streamId);
            recycleStreamIds.add(new DelayedStreamId(streamId));
            return;
        }
        streams.remove(streamId);
        if (mapPartitionData.getReaders().isEmpty()) {
            // Re-check inside computeIfPresent: a new stream may have attached to this file
            // between the check above and here.
            activeMapPartitions.computeIfPresent(mapPartitionData.getDiskFileInfo(), (k, v) -> {
                if (v.getReaders().isEmpty()) {
                    v.close();
                    return null;
                }
                return v;
            });
        }
    }

    public long getStreamsCount() {
        return streams.size();
    }

    public int getActiveMapPartitionCount() {
        return activeMapPartitions.size();
    }

    /**
     * A stream id that becomes available from a {@link DelayQueue} a fixed delay after creation.
     * Despite its name, {@code createMillis} holds the expiry time (creation time plus the delay).
     */
    public static class DelayedStreamId implements Delayed {
        /** Delay before a queued stream id can be taken, in milliseconds. */
        private static final long DELAY_MILLIS = 100L;
        private final long createMillis;
        private final long streamId;

        public DelayedStreamId(long streamId) {
            this.createMillis = System.currentTimeMillis() + DELAY_MILLIS;
            this.streamId = streamId;
        }

        @Override
        public long getDelay(TimeUnit unit) {
            return unit.convert(createMillis - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
        }

        /** Returns the wall-clock time (ms) at which this id expires. */
        public long getCreateMillis() {
            return createMillis;
        }

        @Override
        public int compareTo(Delayed o) {
            if (o instanceof DelayedStreamId) {
                return Long.compare(createMillis, ((DelayedStreamId) o).createMillis);
            }
            // Fall back to the generic Delayed ordering instead of failing with ClassCastException.
            return Long.compare(getDelay(TimeUnit.MILLISECONDS), o.getDelay(TimeUnit.MILLISECONDS));
        }

        @Override
        public String toString() {
            return "DelayedStreamId{" + "createMillis=" + createMillis + ", streamId=" + streamId + '}';
        }
    }

    /** Per-stream bookkeeping: owning channel, shuffle key, buffer size, and shared partition data. */
    protected static class StreamState {
        private final Channel associatedChannel;
        private final String shuffleKey;
        private final int bufferSize;
        private final MapPartitionData mapPartitionData;

        public StreamState(Channel associatedChannel, String shuffleKey, int bufferSize, MapPartitionData mapPartitionData) {
            this.associatedChannel = associatedChannel;
            this.shuffleKey = shuffleKey;
            this.bufferSize = bufferSize;
            this.mapPartitionData = mapPartitionData;
        }

        public Channel getAssociatedChannel() {
            return associatedChannel;
        }

        public String getShuffleKey() {
            return shuffleKey;
        }

        public int getBufferSize() {
            return bufferSize;
        }

        public MapPartitionData getMapPartitionData() {
            return mapPartitionData;
        }
    }
}

