/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark.utils;

import java.io.IOException;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.hadoop.io.Text;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
import org.apache.spark.mllib.util.NumericParser;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.mapred.ReblockBuffer;
import org.apache.sysml.runtime.util.FastStringTokenizer;
import scala.Tuple2;
import scala.collection.JavaConversions;
import scala.collection.Seq;

/**
 * Conversion helpers between Spark data structures (CoordinateMatrix, DataFrames) and
 * SystemML {@link MatrixBlock}s, plus raw byte-array conversions used by the Py4J/NumPy/SciPy bridge.
 *
 * <p>All byte arrays exchanged with Python are interpreted/produced in the platform's native byte order
 * ({@link ByteOrder#nativeOrder()}).
 */
public class RDDConverterUtilsExt {

    /**
     * Converts a CoordinateMatrix into binary-block format.
     *
     * @param sc                the Java Spark context (used to create empty blocks if requested)
     * @param input             the coordinate (i, j, value) matrix
     * @param mcIn              matrix characteristics; dimensions must be known
     * @param outputEmptyBlocks if true and the matrix might have empty blocks, empty blocks are unioned in
     * @return binary-block RDD, with blocks sharing an index merged
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> coordinateMatrixToBinaryBlock(JavaSparkContext sc, CoordinateMatrix input, MatrixCharacteristics mcIn, boolean outputEmptyBlocks) {
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = input.entries().toJavaRDD()
                .mapPartitionsToPair(new MatrixEntryToBinaryBlockFunction(mcIn));
        if (outputEmptyBlocks && mcIn.mightHaveEmptyBlocks()) {
            out = out.union(SparkUtils.getEmptyBlockRDD(sc, mcIn));
        }
        // Each partition produces its own partial blocks; merge those that share a block index.
        return RDDAggregateUtils.mergeByKey(out, false);
    }

    /**
     * {@link SparkContext} variant of
     * {@link #coordinateMatrixToBinaryBlock(JavaSparkContext, CoordinateMatrix, MatrixCharacteristics, boolean)}.
     * The {@code outputEmptyBlocks} flag is forwarded as given.
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> coordinateMatrixToBinaryBlock(SparkContext sc, CoordinateMatrix input, MatrixCharacteristics mcIn, boolean outputEmptyBlocks) {
        return coordinateMatrixToBinaryBlock(new JavaSparkContext(sc), input, mcIn, outputEmptyBlocks);
    }

    /**
     * Selects the given columns (in order) from a DataFrame.
     *
     * @param df      input DataFrame
     * @param columns column names to keep; must contain at least one entry
     * @return the projected DataFrame
     * @throws DMLRuntimeException if {@code columns} is null or empty
     */
    public static Dataset<Row> projectColumns(Dataset<Row> df, ArrayList<String> columns) {
        if (columns == null || columns.isEmpty()) {
            throw new DMLRuntimeException("At least one column must be specified for projection");
        }
        // Dataset.select(String, Seq<String>) takes the first column separately from the rest.
        ArrayList<String> others = new ArrayList<String>(columns.subList(1, columns.size()));
        Seq<String> rest = JavaConversions.asScalaBuffer(others).toList();
        return df.select(columns.get(0), rest);
    }

    /** Dense conversion of a native-order float64 byte array; dimensions must fit into an int. */
    public static MatrixBlock convertPy4JArrayToMB(byte[] data, long rlen, long clen) {
        return convertPy4JArrayToMB(data, rlen, clen, false);
    }

    /** Dense conversion of a native-order float64 byte array. */
    public static MatrixBlock convertPy4JArrayToMB(byte[] data, int rlen, int clen) {
        return convertPy4JArrayToMB(data, rlen, clen, false);
    }

    /** Long-dimension variant of {@link #convertSciPyCOOToMB(byte[], byte[], byte[], int, int, int)}. */
    public static MatrixBlock convertSciPyCOOToMB(byte[] data, byte[] row, byte[] col, long rlen, long clen, long nnz) {
        checkIntDims(rlen, clen);
        return convertSciPyCOOToMB(data, row, col, (int) rlen, (int) clen, (int) nnz);
    }

    /**
     * Builds a MatrixBlock from SciPy COO triplets.
     *
     * @param data native-order float64 values, one per non-zero
     * @param row  native-order 32-bit row indices, one per non-zero
     * @param col  native-order 32-bit column indices, one per non-zero
     * @param rlen number of rows
     * @param clen number of columns
     * @param nnz  number of triplets to read from the three buffers
     * @return the populated block, with non-zeros recomputed and format chosen by {@code examSparsity}
     */
    public static MatrixBlock convertSciPyCOOToMB(byte[] data, byte[] row, byte[] col, int rlen, int clen, int nnz) {
        MatrixBlock mb = new MatrixBlock(rlen, clen, true);
        mb.allocateSparseRowsBlock(false);
        ByteBuffer values = ByteBuffer.wrap(data).order(ByteOrder.nativeOrder());
        ByteBuffer rows = ByteBuffer.wrap(row).order(ByteOrder.nativeOrder());
        ByteBuffer cols = ByteBuffer.wrap(col).order(ByteOrder.nativeOrder());
        for (int k = 0; k < nnz; ++k) {
            double val = values.getDouble();
            int rowIndex = rows.getInt();
            int colIndex = cols.getInt();
            mb.setValue(rowIndex, colIndex, val);
        }
        mb.recomputeNonZeros();
        mb.examSparsity();
        return mb;
    }

    /** Long-dimension variant of {@link #convertPy4JArrayToMB(byte[], int, int, boolean)}. */
    public static MatrixBlock convertPy4JArrayToMB(byte[] data, long rlen, long clen, boolean isSparse) {
        checkIntDims(rlen, clen);
        return convertPy4JArrayToMB(data, (int) rlen, (int) clen, isSparse);
    }

    /** Creates a MatrixBlock of the given shape/format and allocates its storage. */
    public static MatrixBlock allocateDenseOrSparse(int rlen, int clen, boolean isSparse) {
        MatrixBlock ret = new MatrixBlock(rlen, clen, isSparse);
        ret.allocateBlock();
        return ret;
    }

    /**
     * Long-dimension variant of {@link #allocateDenseOrSparse(int, int, boolean)}.
     *
     * @throws DMLRuntimeException if either dimension exceeds {@link Integer#MAX_VALUE}
     */
    public static MatrixBlock allocateDenseOrSparse(long rlen, long clen, boolean isSparse) {
        checkIntDims(rlen, clen);
        // Explicit casts select the int overload (calling with longs would recurse forever).
        return allocateDenseOrSparse((int) rlen, (int) clen, isSparse);
    }

    /** Rejects dimensions that cannot be represented as int in a MatrixBlock. */
    private static void checkIntDims(long rlen, long clen) {
        if (rlen > Integer.MAX_VALUE || clen > Integer.MAX_VALUE) {
            throw new DMLRuntimeException("Dimensions of matrix are too large to be passed via NumPy/SciPy:" + rlen + " X " + clen);
        }
    }

    /** See {@link #copyRowBlocks(MatrixBlock, long, MatrixBlock, long, long, long)}. */
    public static void copyRowBlocks(MatrixBlock mb, int rowIndex, MatrixBlock ret, int numRowsPerBlock, int rlen, int clen) {
        copyRowBlocks(mb, (long) rowIndex, ret, (long) numRowsPerBlock, (long) rlen, (long) clen);
    }

    /** See {@link #copyRowBlocks(MatrixBlock, long, MatrixBlock, long, long, long)}. */
    public static void copyRowBlocks(MatrixBlock mb, long rowIndex, MatrixBlock ret, int numRowsPerBlock, int rlen, int clen) {
        copyRowBlocks(mb, rowIndex, ret, (long) numRowsPerBlock, (long) rlen, (long) clen);
    }

    /** See {@link #copyRowBlocks(MatrixBlock, long, MatrixBlock, long, long, long)}. */
    public static void copyRowBlocks(MatrixBlock mb, int rowIndex, MatrixBlock ret, long numRowsPerBlock, long rlen, long clen) {
        copyRowBlocks(mb, (long) rowIndex, ret, numRowsPerBlock, rlen, clen);
    }

    /**
     * Copies {@code mb} into the row range
     * {@code [rowIndex*numRowsPerBlock, min((rowIndex+1)*numRowsPerBlock - 1, rlen - 1)]}
     * and columns {@code [0, clen - 1]} of {@code ret}. Call {@link #postProcessAfterCopying}
     * after all row blocks have been copied.
     */
    public static void copyRowBlocks(MatrixBlock mb, long rowIndex, MatrixBlock ret, long numRowsPerBlock, long rlen, long clen) {
        long firstRow = rowIndex * numRowsPerBlock;
        long lastRow = Math.min((rowIndex + 1L) * numRowsPerBlock - 1L, rlen - 1L);
        ret.copy((int) firstRow, (int) lastRow, 0, (int) (clen - 1L), mb, false);
    }

    /** Recomputes the non-zero count and re-examines the sparse/dense format of {@code ret}. */
    public static void postProcessAfterCopying(MatrixBlock ret) {
        ret.recomputeNonZeros();
        ret.examSparsity();
    }

    /**
     * Converts a native-order float64 byte array (row-major, rlen*clen values) into a dense MatrixBlock.
     *
     * @throws DMLRuntimeException if {@code isSparse} is requested or the cell count exceeds {@link Integer#MAX_VALUE}
     * @throws java.nio.BufferUnderflowException if {@code data} holds fewer than rlen*clen doubles
     */
    public static MatrixBlock convertPy4JArrayToMB(byte[] data, int rlen, int clen, boolean isSparse) {
        if (isSparse) {
            throw new DMLRuntimeException("Convertion to sparse format not supported");
        }
        // Multiply in long arithmetic; an int product would overflow before this check.
        long limit = (long) rlen * clen;
        if (limit > Integer.MAX_VALUE) {
            throw new DMLRuntimeException("Dense NumPy array of size " + limit + " cannot be converted to MatrixBlock");
        }
        MatrixBlock mb = new MatrixBlock(rlen, clen, isSparse, -1L);
        double[] denseBlock = new double[(int) limit];
        ByteBuffer.wrap(data).order(ByteOrder.nativeOrder()).asDoubleBuffer().get(denseBlock);
        mb.init(denseBlock, rlen, clen);
        mb.recomputeNonZeros();
        mb.examSparsity();
        return mb;
    }

    /**
     * Serializes a MatrixBlock into a native-order float64 byte array (rows*cols values).
     * A sparse block is converted to dense in place first.
     *
     * @throws DMLRuntimeException if the result would exceed the maximum Java array size
     */
    public static byte[] convertMBtoPy4JDenseArr(MatrixBlock mb) {
        final int bytesPerDouble = 8;
        if (mb.isInSparseFormat()) {
            mb.sparseToDense();
        }
        // Multiply in long arithmetic; an int product would overflow before this check.
        long limit = (long) mb.getNumRows() * mb.getNumColumns();
        if (limit > (long) (Integer.MAX_VALUE / bytesPerDouble)) {
            throw new DMLRuntimeException("MatrixBlock of size " + limit + " cannot be converted to dense numpy array");
        }
        byte[] ret = new byte[(int) (limit * bytesPerDouble)];
        double[] denseBlock = mb.getDenseBlockValues();
        if (mb.isEmptyBlock()) {
            // New Java arrays are zero-filled and 0.0 is all-zero bits in any byte order.
            return ret;
        }
        if (denseBlock == null) {
            throw new DMLRuntimeException("Error while dealing with empty blocks.");
        }
        // Bound by the logical cell count: the backing array may be longer than rows*cols.
        int count = (int) Math.min(limit, (long) denseBlock.length);
        ByteBuffer.wrap(ret).order(ByteOrder.nativeOrder()).asDoubleBuffer().put(denseBlock, 0, count);
        return ret;
    }

    /**
     * Appends a non-nullable double column {@code nameOfCol} holding a 1-based row ID
     * (derived from {@code zipWithIndex}).
     */
    public static Dataset<Row> addIDToDataFrame(Dataset<Row> df, SparkSession sparkSession, String nameOfCol) {
        StructField[] oldSchema = df.schema().fields();
        StructField[] newSchema = new StructField[oldSchema.length + 1];
        System.arraycopy(oldSchema, 0, newSchema, 0, oldSchema.length);
        newSchema[oldSchema.length] = DataTypes.createStructField(nameOfCol, DataTypes.DoubleType, false);
        JavaRDD<Row> newRows = df.rdd().toJavaRDD().zipWithIndex().map(new AddRowID());
        return sparkSession.createDataFrame(newRows, new StructType(newSchema));
    }

    /**
     * Converts a DataFrame with at most one string column into a DataFrame of ML dense vectors.
     * Up to two matching outer bracket pairs ({@code ()} or {@code []}) are stripped from each string,
     * and the comma-separated content is parsed with {@link NumericParser}. Null cells stay null.
     */
    public static Dataset<Row> stringDataFrameToVectorDataFrame(SparkSession sparkSession, Dataset<Row> inputDF) {
        StructField[] oldSchema = inputDF.schema().fields();
        StructField[] newSchema = new StructField[oldSchema.length];
        for (int i = 0; i < oldSchema.length; ++i) {
            newSchema[i] = DataTypes.createStructField(oldSchema[i].name(), new VectorUDT(), true);
        }

        class StringToVector implements Function<Tuple2<Row, Long>, Row> {
            private static final long serialVersionUID = -4733816995375745659L;

            @Override
            public Row call(Tuple2<Row, Long> arg0) throws Exception {
                Row oldRow = arg0._1();
                int oldNumCols = oldRow.length();
                if (oldNumCols > 1) {
                    throw new DMLRuntimeException("The row must have at most one column");
                }
                ArrayList<Vector> fieldsArr = new ArrayList<Vector>(oldNumCols);
                for (int i = 0; i < oldNumCols; ++i) {
                    Object ci = oldRow.get(i);
                    if (ci == null) {
                        fieldsArr.add(null);
                    } else if (ci instanceof String) {
                        fieldsArr.add(parseVector((String) ci));
                    } else {
                        throw new DMLRuntimeException("Only String is supported");
                    }
                }
                return RowFactory.create(fieldsArr.toArray());
            }

            /** Strips up to two outer bracket pairs and parses the remaining list of doubles. */
            private Vector parseVector(String cis) {
                StringBuilder sb = new StringBuilder(cis.trim());
                // Separate counter: must not touch the caller's column index.
                for (int pass = 0; pass < 2; ++pass) {
                    int last = sb.length() - 1;
                    if (last < 1) {
                        break; // fewer than two chars: no bracket pair left (avoids charAt on empty)
                    }
                    char first = sb.charAt(0);
                    char end = sb.charAt(last);
                    if ((first == '(' && end == ')') || (first == '[' && end == ']')) {
                        sb.deleteCharAt(0);
                        sb.setLength(sb.length() - 1);
                    }
                }
                String ncis = "[" + sb.toString().replaceAll(" *, *", ",") + "]";
                try {
                    double[] doubles = (double[]) NumericParser.parse(ncis);
                    return Vectors.dense(doubles);
                } catch (Exception e) {
                    throw new DMLRuntimeException("Error converting to double array. " + e.getMessage(), e);
                }
            }
        }

        JavaRDD<Row> newRows = inputDF.rdd().toJavaRDD().zipWithIndex().map(new StringToVector());
        return sparkSession.createDataFrame(newRows.rdd(), DataTypes.createStructType(newSchema));
    }

    /** Shared logic that reblocks (i, j, v) cells from a partition into binary blocks. */
    private static class IJVToBinaryBlockFunctionHelper implements Serializable {
        private static final long serialVersionUID = -7952801318564745821L;
        private static final int BUFFER_SIZE = 4000000;
        private final int _bufflen;
        private final long _rlen;
        private final long _clen;
        private final int _brlen;
        private final int _bclen;

        public IJVToBinaryBlockFunctionHelper(MatrixCharacteristics mc) {
            if (!mc.dimsKnown()) {
                throw new DMLRuntimeException("The dimensions need to be known in given MatrixCharacteristics for given input RDD");
            }
            this._rlen = mc.getRows();
            this._clen = mc.getCols();
            this._brlen = mc.getRowsPerBlock();
            this._bclen = mc.getColsPerBlock();
            // The buffer never needs to hold more cells than the whole matrix.
            this._bufflen = (int) Math.min(this._rlen * this._clen, (long) BUFFER_SIZE);
        }

        /** Parses a space-separated "row col value" line; returns null for lines starting with '%'. */
        public Tuple2<MatrixIndexes, MatrixCell> textToMatrixCell(Text txt) {
            String strVal = txt.toString();
            if (strVal.startsWith("%")) {
                return null;
            }
            FastStringTokenizer st = new FastStringTokenizer(' ');
            st.reset(strVal);
            long row = st.nextLong();
            long col = st.nextLong();
            double val = st.nextDouble();
            return new Tuple2<MatrixIndexes, MatrixCell>(new MatrixIndexes(row, col), new MatrixCell(val));
        }

        /** Wraps a MatrixEntry as an (indexes, cell) pair. */
        public Tuple2<MatrixIndexes, MatrixCell> matrixEntryToMatrixCell(MatrixEntry entry) {
            return new Tuple2<MatrixIndexes, MatrixCell>(new MatrixIndexes(entry.i(), entry.j()), new MatrixCell(entry.value()));
        }

        /**
         * Drains the partition iterator {@code arg0}, converting each element with {@code converter},
         * and returns the binary blocks produced by the reblock buffer.
         */
        Iterable<Tuple2<MatrixIndexes, MatrixBlock>> convertToBinaryBlock(Object arg0, RDDConverterTypes converter) throws Exception {
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<Tuple2<MatrixIndexes, MatrixBlock>>();
            ReblockBuffer rbuff = new ReblockBuffer(this._bufflen, this._rlen, this._clen, this._brlen, this._bclen);
            Iterator<?> iter = (Iterator<?>) arg0;
            while (iter.hasNext()) {
                Tuple2<MatrixIndexes, MatrixCell> cell;
                switch (converter) {
                    case MATRIXENTRY_TO_MATRIXCELL:
                        cell = this.matrixEntryToMatrixCell((MatrixEntry) iter.next());
                        break;
                    case TEXT_TO_MATRIX_CELL:
                        cell = this.textToMatrixCell((Text) iter.next());
                        break;
                    default:
                        throw new Exception("Invalid converter for IJV data:" + converter.toString());
                }
                if (cell == null) {
                    continue; // skipped input line (e.g. comment)
                }
                if (rbuff.getSize() >= rbuff.getCapacity()) {
                    flushBufferToList(rbuff, ret);
                }
                rbuff.appendCell(cell._1().getRowIndex(), cell._1().getColumnIndex(), cell._2().getValue());
            }
            flushBufferToList(rbuff, ret);
            return ret;
        }

        /** Moves all buffered cells into binary blocks appended to {@code ret}. */
        private static void flushBufferToList(ReblockBuffer rbuff, ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret) throws IOException, DMLRuntimeException {
            rbuff.flushBufferToBinaryBlocks().stream().map(b -> SparkUtils.fromIndexedMatrixBlock(b)).forEach(b -> ret.add((Tuple2<MatrixIndexes, MatrixBlock>) b));
        }
    }

    /** Partition function converting MatrixEntry cells to binary blocks. */
    private static class MatrixEntryToBinaryBlockFunction implements PairFlatMapFunction<Iterator<MatrixEntry>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 4907483236186747224L;
        private final IJVToBinaryBlockFunctionHelper helper;

        public MatrixEntryToBinaryBlockFunction(MatrixCharacteristics mc) {
            this.helper = new IJVToBinaryBlockFunctionHelper(mc);
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<MatrixEntry> arg0) throws Exception {
            return this.helper.convertToBinaryBlock(arg0, RDDConverterTypes.MATRIXENTRY_TO_MATRIXCELL).iterator();
        }
    }

    /** Appends the zipWithIndex index + 1, as a Double, to each row. */
    public static class AddRowID implements Function<Tuple2<Row, Long>, Row> {
        private static final long serialVersionUID = -3733816995375745659L;

        @Override
        public Row call(Tuple2<Row, Long> arg0) throws Exception {
            Row oldRow = arg0._1();
            int oldNumCols = oldRow.length();
            Object[] fields = new Object[oldNumCols + 1];
            for (int i = 0; i < oldNumCols; ++i) {
                fields[i] = oldRow.get(i);
            }
            fields[oldNumCols] = Double.valueOf(arg0._2() + 1L);
            return RowFactory.create(fields);
        }
    }

    /** Supported element conversions for IJV (cell-wise) inputs. */
    public static enum RDDConverterTypes {
        TEXT_TO_MATRIX_CELL,
        MATRIXENTRY_TO_MATRIXCELL;
    }
}

