/*
 * Decompiled with CFR 0.152.
 */
package org.apache.uniffle.client.impl;

import java.nio.ByteBuffer;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.uniffle.client.response.DecompressedShuffleBlock;
import org.apache.uniffle.common.BufferSegment;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.util.JavaUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Decompresses shuffle data segments in the background on a fixed-size thread pool.
 *
 * <p>{@link #add(int, ShuffleDataResult)} submits one decompression task per buffer segment and
 * registers the resulting {@link DecompressedShuffleBlock} under {@code (batchIndex, segmentIndex)}.
 * {@link #get(int, int)} hands a registered block out exactly once and removes it from the
 * registry; when the last block of a batch has been handed out the batch entry itself is dropped,
 * so the registry does not grow with the number of batches processed.
 *
 * <p>Registration and removal are both performed inside {@code compute}/{@code computeIfPresent}
 * on the outer map, which {@link ConcurrentHashMap} executes atomically per key.
 */
public class DecompressionWorker {
    private static final Logger LOG = LoggerFactory.getLogger(DecompressionWorker.class);

    /** Pool on which the decompression tasks run. */
    private final ExecutorService executorService;

    /** Pending blocks, keyed by batch index and then by segment index within the batch. */
    private final ConcurrentHashMap<Integer, ConcurrentHashMap<Integer, DecompressedShuffleBlock>> tasks;

    /** Codec used to decompress every segment. */
    private final Codec codec;

    /**
     * Creates a worker backed by a fixed thread pool.
     *
     * @param codec codec used for decompression; must not be {@code null}
     * @param threads number of pool threads; must be greater than 0
     * @throws IllegalArgumentException if {@code codec} is {@code null} or {@code threads <= 0}
     */
    public DecompressionWorker(Codec codec, int threads) {
        if (codec == null) {
            throw new IllegalArgumentException("Codec cannot be null");
        }
        if (threads <= 0) {
            throw new IllegalArgumentException("Threads must be greater than 0");
        }
        this.tasks = JavaUtils.newConcurrentMap();
        this.executorService = Executors.newFixedThreadPool(threads);
        this.codec = codec;
    }

    /**
     * Submits every buffer segment of {@code shuffleDataResult} for asynchronous decompression.
     *
     * <p>Segments are numbered from 0 in the iteration order of
     * {@link ShuffleDataResult#getBufferSegments()}; that number is the {@code segmentIndex}
     * to pass to {@link #get(int, int)}. Registering a segment index that is already present
     * for the same batch replaces the earlier block.
     *
     * @param batchIndex caller-chosen identifier of this batch
     * @param shuffleDataResult source of the compressed data buffer and its segment descriptors
     */
    public void add(int batchIndex, ShuffleDataResult shuffleDataResult) {
        List<BufferSegment> bufferSegments = shuffleDataResult.getBufferSegments();
        ByteBuffer sharedByteBuffer = shuffleDataResult.getDataBuffer();
        LOG.debug(
            "Adding {} segments with batch index:{} to decompression worker",
            bufferSegments.size(),
            batchIndex);

        int index = 0;
        for (BufferSegment bufferSegment : bufferSegments) {
            final int segmentIndex = index++;
            CompletableFuture<ByteBuffer> future = CompletableFuture.supplyAsync(
                () -> decompressSegment(sharedByteBuffer, bufferSegment), this.executorService);
            final DecompressedShuffleBlock block = new DecompressedShuffleBlock(future);

            // Create-or-reuse the batch map and insert the block in one atomic step, so a
            // concurrent get() that drops an emptied batch map can never orphan this block.
            this.tasks.compute(batchIndex, (key, existing) -> {
                ConcurrentHashMap<Integer, DecompressedShuffleBlock> blocks = existing;
                if (blocks == null) {
                    blocks = new ConcurrentHashMap<Integer, DecompressedShuffleBlock>();
                }
                blocks.put(segmentIndex, block);
                return blocks;
            });
        }
    }

    /**
     * Removes and returns the block registered for the given batch and segment.
     *
     * <p>Each block can be retrieved only once. When the retrieved block was the last one left
     * in its batch, the batch entry is removed as well.
     *
     * @param batchIndex batch identifier that was passed to {@link #add(int, ShuffleDataResult)}
     * @param segmentIndex zero-based position of the segment within that batch
     * @return the registered block, or {@code null} if no block is registered for that key
     *     (unknown batch, unknown segment, or already retrieved)
     */
    public DecompressedShuffleBlock get(int batchIndex, int segmentIndex) {
        // One-element holder: lets the remapping lambda hand the removed block back out.
        final DecompressedShuffleBlock[] removed = new DecompressedShuffleBlock[1];
        this.tasks.computeIfPresent(batchIndex, (key, blocks) -> {
            removed[0] = blocks.remove(segmentIndex);
            // Returning null removes the batch entry once it has been fully drained.
            return blocks.isEmpty() ? null : blocks;
        });
        return removed[0];
    }

    /**
     * Initiates an orderly shutdown of the thread pool: tasks that were already submitted still
     * run, but no new tasks are accepted. This method does not wait for them to finish.
     */
    public void close() {
        this.executorService.shutdown();
    }

    /**
     * Decompresses a single segment of the shared buffer into a freshly allocated buffer.
     *
     * <p>The shared buffer is duplicated first, so its own position and limit are not modified.
     * The destination buffer is direct exactly when the source buffer is direct.
     *
     * @param sharedByteBuffer buffer that holds the compressed bytes of all segments
     * @param bufferSegment descriptor giving the offset, compressed length and uncompressed length
     * @return the buffer handed to the codec as decompression destination
     */
    private ByteBuffer decompressSegment(ByteBuffer sharedByteBuffer, BufferSegment bufferSegment) {
        int offset = bufferSegment.getOffset();
        int length = bufferSegment.getLength();
        ByteBuffer buffer = sharedByteBuffer.duplicate();
        buffer.position(offset);
        buffer.limit(offset + length);

        int uncompressedLen = bufferSegment.getUncompressLength();
        ByteBuffer dst = buffer.isDirect()
            ? ByteBuffer.allocateDirect(uncompressedLen)
            : ByteBuffer.allocate(uncompressedLen);
        // NOTE(review): the trailing 0 is presumably the destination offset - confirm against Codec.
        this.codec.decompress(buffer, uncompressedLen, dst, 0);
        return dst;
    }
}

