/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.tasks;

import com.sun.management.ThreadMXBean;
import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ConcurrentCollections;
import org.opensearch.common.util.concurrent.ConcurrentMapLong;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.tasks.ResourceStats;
import org.opensearch.tasks.ResourceStatsType;
import org.opensearch.tasks.ResourceUsageMetric;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskManager;
import org.opensearch.tasks.ThreadResourceInfo;
import org.opensearch.threadpool.RunnableTaskExecutionListener;
import org.opensearch.threadpool.ThreadPool;

@SuppressForbidden(reason="ThreadMXBean#getThreadAllocatedBytes")
public class TaskResourceTrackingService
implements RunnableTaskExecutionListener {
    private static final Logger logger = LogManager.getLogger(TaskManager.class);
    public static final Setting<Boolean> TASK_RESOURCE_TRACKING_ENABLED = Setting.boolSetting("task_resource_tracking.enabled", true, Setting.Property.Dynamic, Setting.Property.NodeScope);
    public static final String TASK_ID = "TASK_ID";
    private static final ThreadMXBean threadMXBean = (ThreadMXBean)ManagementFactory.getThreadMXBean();
    private final ConcurrentMapLong<Task> resourceAwareTasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency();
    private final ThreadPool threadPool;
    private volatile boolean taskResourceTrackingEnabled;

    @Inject
    public TaskResourceTrackingService(Settings settings, ClusterSettings clusterSettings, ThreadPool threadPool) {
        this.taskResourceTrackingEnabled = TASK_RESOURCE_TRACKING_ENABLED.get(settings);
        this.threadPool = threadPool;
        clusterSettings.addSettingsUpdateConsumer(TASK_RESOURCE_TRACKING_ENABLED, this::setTaskResourceTrackingEnabled);
    }

    public void setTaskResourceTrackingEnabled(boolean taskResourceTrackingEnabled) {
        this.taskResourceTrackingEnabled = taskResourceTrackingEnabled;
    }

    public boolean isTaskResourceTrackingEnabled() {
        return this.taskResourceTrackingEnabled;
    }

    public boolean isTaskResourceTrackingSupported() {
        return threadMXBean.isThreadAllocatedMemorySupported() && threadMXBean.isThreadAllocatedMemoryEnabled();
    }

    public ThreadContext.StoredContext startTracking(Task task) {
        if (!(task.supportsResourceTracking() && this.isTaskResourceTrackingEnabled() && this.isTaskResourceTrackingSupported())) {
            return () -> {};
        }
        logger.debug("Starting resource tracking for task: {}", (Object)task.getId());
        this.resourceAwareTasks.put(task.getId(), task);
        return this.addTaskIdToThreadContext(task);
    }

    public void stopTracking(Task task) {
        logger.debug("Stopping resource tracking for task: {}", (Object)task.getId());
        try {
            if (this.isCurrentThreadWorkingOnTask(task)) {
                this.taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId());
            }
        }
        catch (Exception e) {
            logger.warn("Failed while trying to mark the task execution on current thread completed.", (Throwable)e);
            assert (false);
        }
        finally {
            this.resourceAwareTasks.remove(task.getId());
        }
    }

    public void refreshResourceStats(Task ... tasks) {
        if (!this.isTaskResourceTrackingEnabled() || !this.isTaskResourceTrackingSupported()) {
            return;
        }
        for (Task task : tasks) {
            if (!task.supportsResourceTracking() || !this.resourceAwareTasks.containsKey(task.getId())) continue;
            this.refreshResourceStats(task);
        }
    }

    private void refreshResourceStats(Task resourceAwareTask) {
        try {
            logger.debug("Refreshing resource stats for Task: {}", (Object)resourceAwareTask.getId());
            List<Long> threadsWorkingOnTask = this.getThreadsWorkingOnTask(resourceAwareTask);
            threadsWorkingOnTask.forEach(threadId -> resourceAwareTask.updateThreadResourceStats((long)threadId, ResourceStatsType.WORKER_STATS, this.getResourceUsageMetricsForThread((long)threadId)));
        }
        catch (IllegalStateException e) {
            logger.debug("Resource stats already updated.");
        }
    }

    @Override
    public void taskExecutionStartedOnThread(long taskId, long threadId) {
        block3: {
            try {
                Task task = this.resourceAwareTasks.get(taskId);
                if (task != null) {
                    logger.debug("Task execution started on thread. Task: {}, Thread: {}", (Object)taskId, (Object)threadId);
                    task.startThreadResourceTracking(threadId, ResourceStatsType.WORKER_STATS, this.getResourceUsageMetricsForThread(threadId));
                }
            }
            catch (Exception e) {
                logger.warn((Message)new ParameterizedMessage("Failed to mark thread execution started for task: [{}]", (Object)taskId), (Throwable)e);
                if ($assertionsDisabled) break block3;
                throw new AssertionError();
            }
        }
    }

    @Override
    public void taskExecutionFinishedOnThread(long taskId, long threadId) {
        block3: {
            try {
                Task task = this.resourceAwareTasks.get(taskId);
                if (task != null) {
                    logger.debug("Task execution finished on thread. Task: {}, Thread: {}", (Object)taskId, (Object)threadId);
                    task.stopThreadResourceTracking(threadId, ResourceStatsType.WORKER_STATS, this.getResourceUsageMetricsForThread(threadId));
                }
            }
            catch (Exception e) {
                logger.warn((Message)new ParameterizedMessage("Failed to mark thread execution finished for task: [{}]", (Object)taskId), (Throwable)e);
                if ($assertionsDisabled) break block3;
                throw new AssertionError();
            }
        }
    }

    public Map<Long, Task> getResourceAwareTasks() {
        return Collections.unmodifiableMap(this.resourceAwareTasks);
    }

    private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) {
        ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric(ResourceStats.MEMORY, threadMXBean.getThreadAllocatedBytes(threadId));
        ResourceUsageMetric currentCPUUsage = new ResourceUsageMetric(ResourceStats.CPU, threadMXBean.getThreadCpuTime(threadId));
        return new ResourceUsageMetric[]{currentMemoryUsage, currentCPUUsage};
    }

    private boolean isCurrentThreadWorkingOnTask(Task task) {
        long threadId = Thread.currentThread().getId();
        List threadResourceInfos = task.getResourceStats().getOrDefault(threadId, Collections.emptyList());
        for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) {
            if (!threadResourceInfo.isActive()) continue;
            return true;
        }
        return false;
    }

    private List<Long> getThreadsWorkingOnTask(Task task) {
        ArrayList<Long> activeThreads = new ArrayList<Long>();
        for (List<ThreadResourceInfo> threadResourceInfos : task.getResourceStats().values()) {
            for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) {
                if (!threadResourceInfo.isActive()) continue;
                activeThreads.add(threadResourceInfo.getThreadId());
            }
        }
        return activeThreads;
    }

    private ThreadContext.StoredContext addTaskIdToThreadContext(Task task) {
        ThreadContext threadContext = this.threadPool.getThreadContext();
        ThreadContext.StoredContext storedContext = threadContext.newStoredContext(true, Collections.singletonList(TASK_ID));
        threadContext.putTransient(TASK_ID, task.getId());
        return storedContext;
    }
}

