/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.codegen.template;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.hops.codegen.template.TemplateBase;
import org.apache.sysml.hops.codegen.template.TemplateCell;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.Pair;

public class TemplateMultiAgg
extends TemplateCell {
    public TemplateMultiAgg() {
        super(TemplateBase.TemplateType.MAGG, TemplateBase.CloseType.OPEN_VALID);
    }

    public TemplateMultiAgg(TemplateBase.CloseType ctype) {
        super(TemplateBase.TemplateType.MAGG, ctype);
    }

    @Override
    public boolean open(Hop hop) {
        return false;
    }

    @Override
    public boolean fuse(Hop hop, Hop input) {
        return false;
    }

    @Override
    public boolean merge(Hop hop, Hop input) {
        return false;
    }

    @Override
    public TemplateBase.CloseType close(Hop hop) {
        return TemplateBase.CloseType.CLOSED_INVALID;
    }

    @Override
    public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
        CPlanMemoTable.MemoTableEntry multiAgg = memo.getBest(hop.getHopID(), TemplateBase.TemplateType.MAGG);
        ArrayList<Hop> roots = new ArrayList<Hop>();
        for (int i = 0; i < 3; ++i) {
            if (!multiAgg.isPlanRef(i)) continue;
            roots.add(memo._hopRefs.get(multiAgg.input(i)));
        }
        Hop.resetVisitStatus(roots);
        HashSet<Hop> inHops = new HashSet<Hop>();
        HashMap<Long, CNode> tmp = new HashMap<Long, CNode>();
        for (Hop root : roots) {
            super.rConstructCplan(root, memo, tmp, inHops, compileLiterals);
        }
        Hop.resetVisitStatus(roots);
        Hop shared = this.getSparseSafeSharedInput(roots, inHops);
        Hop[] sinHops = (Hop[])inHops.stream().filter(h -> !h.getDataType().isScalar() || !((CNode)tmp.get(h.getHopID())).isLiteral()).sorted(new TemplateCell.HopInputComparator(shared)).toArray(Hop[]::new);
        ArrayList<CNode> inputs = new ArrayList<CNode>();
        for (Hop in : sinHops) {
            inputs.add(tmp.get(in.getHopID()));
        }
        ArrayList<CNode> outputs = new ArrayList<CNode>();
        ArrayList<Hop.AggOp> aggOps = new ArrayList<Hop.AggOp>();
        for (Hop root : roots) {
            CNode node = tmp.get(root.getHopID());
            if (node instanceof CNodeData && ((CNodeData)inputs.get(0)).getHopID() != ((CNodeData)node).getHopID()) {
                node = new CNodeUnary(node, roots.get(0).getDim2() == 1L ? CNodeUnary.UnaryType.LOOKUP_R : CNodeUnary.UnaryType.LOOKUP_RC);
            }
            outputs.add(node);
            aggOps.add(TemplateUtils.getAggOp(root));
        }
        CNodeMultiAgg tpl = new CNodeMultiAgg(inputs, outputs);
        tpl.setAggOps(aggOps);
        tpl.setSparseSafe(this.isSparseSafe(roots, sinHops[0], tpl.getOutputs(), tpl.getAggOps(), true));
        tpl.setRootNodes(roots);
        tpl.setBeginLine(hop.getBeginLine());
        return new Pair<Hop[], CNodeTpl>(sinHops, tpl);
    }

    private Hop getSparseSafeSharedInput(ArrayList<Hop> roots, HashSet<Hop> inHops) {
        Set<Hop> tmp = inHops.stream().filter(h -> h.getDataType().isMatrix()).collect(Collectors.toSet());
        for (Hop root : roots) {
            root.resetVisitStatus();
            HashSet<Hop> inputs = new HashSet<Hop>();
            this.rCollectSparseSafeInputs(root, inHops, inputs);
            tmp.removeIf(h -> !inputs.contains(h));
        }
        Hop.resetVisitStatus(roots);
        return tmp.isEmpty() ? null : tmp.toArray(new Hop[0])[0];
    }

    private void rCollectSparseSafeInputs(Hop current, HashSet<Hop> inHops, HashSet<Hop> sparseInputs) {
        if (current.isVisited() || !HopRewriteUtils.isBinary(current, Hop.OpOp2.MULT) && !HopRewriteUtils.isAggUnaryOp(current, Hop.AggOp.SUM, Hop.AggOp.SUM_SQ)) {
            return;
        }
        for (Hop c : current.getInput()) {
            if (!inHops.contains(c)) {
                this.rCollectSparseSafeInputs(c, inHops, sparseInputs);
                continue;
            }
            if (!c.dimsKnown(true) || !MatrixBlock.evalSparseFormatInMemory(c.getDim1(), c.getDim2(), c.getNnz())) continue;
            sparseInputs.add(c);
        }
        current.setVisited();
    }
}

