/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.cp;

import java.util.Locale;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.data.BasicTensorBlock;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.BooleanObject;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.IntObject;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
import org.apache.sysds.runtime.instructions.cp.StringObject;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.lineage.LineageDedupUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.utils.Explain;

/**
 * Control-program instruction for unary aggregates.
 *
 * <p>Besides regular aggregate unary operators ({@link AUType#DEFAULT}), this instruction type
 * also covers meta-data queries ({@code nrow}, {@code ncol}, {@code length}), variable existence
 * checks ({@code exists}), lineage explanation ({@code lineage}), and count-distinct
 * ({@code uacd}, {@code uacdap}).
 */
public class AggregateUnaryCPInstruction
extends UnaryCPInstruction {
    /** The kind of aggregate this instruction performs; DEFAULT means a regular aggregate operator. */
    private final AUType _type;

    private AggregateUnaryCPInstruction(Operator op, CPOperand in, CPOperand out, AUType type, String opcode, String istr) {
        this(op, in, null, null, out, type, opcode, istr);
    }

    protected AggregateUnaryCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, AUType type, String opcode, String istr) {
        super(CPInstruction.CPType.AggregateUnary, op, in1, in2, in3, out, opcode, istr);
        this._type = type;
    }

    /**
     * Parses an instruction string into an aggregate unary CP instruction.
     *
     * <p>The meta opcodes ({@code nrow, ncol, length, exists, lineage}) and the count-distinct
     * opcodes ({@code uacd, uacdap}) only use the input and output operands. {@code uarimax} and
     * {@code uarimin} additionally read {@code parts[3]} and {@code parts[4]}. Every other opcode
     * reads {@code parts[3]} and is parsed as a basic aggregate unary operator.
     *
     * @param str the instruction string
     * @return the parsed instruction
     */
    public static AggregateUnaryCPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand out = new CPOperand(parts[2]);
        if (opcode.equalsIgnoreCase("nrow") || opcode.equalsIgnoreCase("ncol") || opcode.equalsIgnoreCase("length") || opcode.equalsIgnoreCase("exists") || opcode.equalsIgnoreCase("lineage")) {
            // Locale.ROOT: a locale-sensitive toUpperCase() maps 'i' to a dotted capital I under
            // e.g. Turkish locales, so "exists"/"lineage" would not match the enum constant names.
            AUType type = AUType.valueOf(opcode.toUpperCase(Locale.ROOT));
            return new AggregateUnaryCPInstruction(new SimpleOperator(Builtin.getBuiltinFnObject(opcode)), in1, out, type, opcode, str);
        }
        if (opcode.equalsIgnoreCase("uacd")) {
            return new AggregateUnaryCPInstruction(new SimpleOperator(null), in1, out, AUType.COUNT_DISTINCT, opcode, str);
        }
        if (opcode.equalsIgnoreCase("uacdap")) {
            return new AggregateUnaryCPInstruction(new SimpleOperator(null), in1, out, AUType.COUNT_DISTINCT_APPROX, opcode, str);
        }
        if (opcode.equalsIgnoreCase("uarimax") || opcode.equalsIgnoreCase("uarimin")) {
            AggregateUnaryOperator aggun = InstructionUtils.parseAggregateUnaryRowIndexOperator(opcode, Integer.parseInt(parts[4]), Integer.parseInt(parts[3]));
            return new AggregateUnaryCPInstruction(aggun, in1, out, AUType.DEFAULT, opcode, str);
        }
        AggregateUnaryOperator aggun = InstructionUtils.parseBasicAggregateUnaryOperator(opcode, Integer.parseInt(parts[3]));
        return new AggregateUnaryCPInstruction(aggun, in1, out, AUType.DEFAULT, opcode, str);
    }

    /**
     * Executes this instruction and binds its result to the output variable.
     *
     * @param ec the execution context holding the input and output variables
     * @throws DMLRuntimeException if a required input variable does not exist, the lineage trace is
     *         unavailable, size meta data cannot be determined, or the input data type is not
     *         supported by a regular aggregate
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        String outputName = this.output.getName();
        switch (this._type) {
            case NROW:
            case NCOL:
            case LENGTH:
                processSizeMetaData(ec, outputName);
                break;
            case EXISTS:
                processExists(ec, outputName);
                break;
            case LINEAGE:
                processLineage(ec, outputName);
                break;
            case COUNT_DISTINCT:
            case COUNT_DISTINCT_APPROX:
                processCountDistinct(ec, outputName);
                break;
            default:
                processAggregate(ec, outputName);
                break;
        }
    }

    /** Returns the aggregate kind of this instruction. */
    public AUType getAUType() {
        return this._type;
    }

    /**
     * Handles NROW/NCOL/LENGTH. LENGTH of a list is its length. For matrices and frames the value
     * comes from the data characteristics, which are refreshed by reading the input if needed.
     */
    private void processSizeMetaData(ExecutionContext ec, String outputName) {
        String inName = this.input1.getName();
        if (!ec.getVariables().keySet().contains(inName)) {
            throw new DMLRuntimeException("Variable '" + inName + "' does not exist.");
        }
        Types.DataType dt = this.input1.getDataType();
        // Inputs that are neither a list (for LENGTH) nor a matrix/frame produce -1.
        long rval = -1L;
        if (dt == Types.DataType.LIST && this._type == AUType.LENGTH) {
            rval = ((ListObject) ec.getVariable(inName)).getLength();
        } else if (dt.isMatrix() || dt.isFrame()) {
            DataCharacteristics mc = ec.getDataCharacteristics(inName);
            rval = getSizeMetaData(this._type, mc);
            if (!mc.dimsKnown()) {
                // Unknown dimensions: in single-node mode (or for frames) read the input to
                // materialize its meta data; otherwise report invalid meta data.
                if (DMLScript.getGlobalExecMode() == Types.ExecMode.SINGLE_NODE || dt == Types.DataType.FRAME) {
                    refreshMetaData(ec, inName);
                    mc = ec.getDataCharacteristics(inName);
                    rval = getSizeMetaData(this._type, mc);
                } else {
                    throw new DMLRuntimeException("Invalid meta data returned by '" + this.getOpcode() + "': " + rval + ":" + this.instString);
                }
            }
        }
        ec.setScalarOutput(outputName, new IntObject(rval));
    }

    /**
     * Reads the cacheable variable once so that its meta data is refreshed. The acquired read is
     * released again even if refreshing throws.
     */
    private static void refreshMetaData(ExecutionContext ec, String varName) {
        CacheableData<?> obj = ec.getCacheableData(varName);
        obj.acquireRead();
        try {
            obj.refreshMetaData();
        } finally {
            obj.release();
        }
    }

    /** Handles EXISTS: outputs whether the named variable is present in the context. */
    private void processExists(ExecutionContext ec, String outputName) {
        // A scalar input carries the name to look up; otherwise the operand's own name is checked.
        String varName = this.input1.isScalar() ? ec.getScalarInput(this.input1).getStringValue() : this.input1.getName();
        boolean exists = ec.getVariables().keySet().contains(varName);
        ec.setScalarOutput(outputName, new BooleanObject(exists));
    }

    /** Handles LINEAGE: outputs the explained lineage trace, plus merged dedup blocks if enabled. */
    private void processLineage(ExecutionContext ec, String outputName) {
        LineageItem li = ec.getLineageItem(this.input1);
        if (li == null) {
            throw new DMLRuntimeException("Lineage trace for variable " + this.input1.getName() + " unavailable.");
        }
        String out = Explain.explain(li);
        if (DMLScript.LINEAGE_DEDUP) {
            out = out + LineageDedupUtils.mergeExplainDedupBlocks(ec);
        }
        ec.setScalarOutput(outputName, new StringObject(out));
    }

    /** Handles COUNT_DISTINCT(_APPROX); the matrix input is released even if estimation fails. */
    private void processCountDistinct(ExecutionContext ec, String outputName) {
        String inName = this.input1.getName();
        if (!ec.getVariables().keySet().contains(inName)) {
            throw new DMLRuntimeException("Variable '" + inName + "' does not exist.");
        }
        MatrixBlock input = ec.getMatrixInput(inName);
        int res;
        try {
            CountDistinctOperator op = new CountDistinctOperator(this._type);
            res = LibMatrixCountDistinct.estimateDistinctValues(input, op);
        } finally {
            ec.releaseMatrixInput(inName);
        }
        ec.setScalarOutput(outputName, new IntObject(res));
    }

    /**
     * Handles regular aggregates over a matrix or tensor input. Outputs a scalar when the output
     * operand is scalar, otherwise a matrix or tensor. Inputs are released even if aggregation fails.
     */
    private void processAggregate(ExecutionContext ec, String outputName) {
        AggregateUnaryOperator auOp = (AggregateUnaryOperator) this._optr;
        String inName = this.input1.getName();
        if (this.input1.getDataType() == Types.DataType.MATRIX) {
            MatrixBlock matBlock = ec.getMatrixInput(inName);
            MatrixBlock resultBlock;
            try {
                resultBlock = matBlock.aggregateUnaryOperations(auOp, new MatrixBlock(), matBlock.getNumRows(), new MatrixIndexes(1L, 1L), true);
            } finally {
                ec.releaseMatrixInput(inName);
            }
            if (this.output.getDataType() == Types.DataType.SCALAR) {
                ec.setScalarOutput(outputName, new DoubleObject(resultBlock.getValue(0, 0)));
            } else {
                ec.setMatrixOutput(outputName, resultBlock);
            }
            return;
        }
        if (this.input1.getDataType() == Types.DataType.TENSOR) {
            TensorBlock tensor = ec.getTensorInput(inName);
            BasicTensorBlock resultBlock;
            try {
                BasicTensorBlock basicTensor = tensor.getBasicTensor();
                resultBlock = basicTensor.aggregateUnaryOperations(auOp, new BasicTensorBlock());
            } finally {
                ec.releaseTensorInput(inName);
            }
            if (this.output.getDataType() == Types.DataType.SCALAR) {
                ec.setScalarOutput(outputName, ScalarObjectFactory.createScalarObject(this.input1.getValueType(), resultBlock.get(new int[]{0, 0})));
            } else {
                ec.setTensorOutput(outputName, new TensorBlock(resultBlock));
            }
            return;
        }
        throw new DMLRuntimeException(this.getOpcode() + " only supported on matrix or tensor.");
    }

    /**
     * Extracts the requested size from data characteristics.
     *
     * @param type one of NROW, NCOL, LENGTH
     * @param mc the data characteristics to read
     * @return rows, cols, or rows * cols respectively
     * @throws DMLRuntimeException for any other aggregate type
     */
    private static long getSizeMetaData(AUType type, DataCharacteristics mc) {
        switch (type) {
            case NROW:
                return mc.getRows();
            case NCOL:
                return mc.getCols();
            case LENGTH:
                return mc.getRows() * mc.getCols();
            default:
                throw new DMLRuntimeException("Opcode not applicable: " + type.name());
        }
    }

    /** Aggregate kinds handled by this instruction. */
    public enum AUType {
        NROW,
        NCOL,
        LENGTH,
        EXISTS,
        LINEAGE,
        COUNT_DISTINCT,
        COUNT_DISTINCT_APPROX,
        DEFAULT;

        /** Returns true for every kind except DEFAULT (the regular aggregate operator). */
        public boolean isMeta() {
            return this != DEFAULT;
        }
    }
}

