/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.scripts.nn.layers;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Matrix;
import org.apache.sysml.api.mlcontext.Script;

/**
 * MLContext wrapper around the DML script {@code scripts/nn/layers/softmax.dml}.
 *
 * <p>Constructing an instance loads the script text from the classpath and installs it as
 * this script's string. {@link #forward(Object)} and {@link #backward(Object, Object)} each
 * build and execute a separate one-statement {@link Script} that sources the DML file and
 * calls the corresponding DML function. The {@code *__docs} and {@code *__source} methods
 * return the DML documentation header and the full DML function text as plain strings.
 */
public class Softmax extends Script {

    /** Location of the DML file, both as a classpath resource and as the DML {@code source} target. */
    private static final String SCRIPT_PATH = "scripts/nn/layers/softmax.dml";

    /** DML statement that sources the script and invokes its {@code forward} function. */
    private static final String FORWARD_CALL =
            "source('" + SCRIPT_PATH + "') as mlcontextns;probs = mlcontextns::forward(scores);";

    /** DML statement that sources the script and invokes its {@code backward} function. */
    private static final String BACKWARD_CALL =
            "source('" + SCRIPT_PATH + "') as mlcontextns;dscores = mlcontextns::backward(dprobs, scores);";

    /** Signature and documentation comment of the DML {@code forward} function. */
    private static final String FORWARD_DOCS =
            "forward = function(matrix[double] scores)\n"
            + "    return (matrix[double] probs) {\n"
            + "  /*\n"
            + "   * Computes the forward pass for a softmax classifier.  The input\n"
            + "   * has N examples, each with D values that are interpreted as\n"
            + "   * unnormalized, log-probabilities for each of D classes.  The softmax\n"
            + "   * function transforms these values to normalized probabilities across\n"
            + "   * the D classes, for every example.\n"
            + "   *\n"
            + "   * This can be interpreted as a generalization of the sigmoid\n"
            + "   * function to multiple classes.\n"
            + "   *\n"
            + "   *   `probs_ij = e^scores_ij / sum(e^scores_i)`\n"
            + "   *\n"
            + "   * Inputs:\n"
            + "   *  - scores: Inputs, of shape (N, D).\n"
            + "   *\n"
            + "   * Outputs:\n"
            + "   *  - probs: Outputs, of shape (N, D).\n"
            + "   */\n";

    /** Complete DML text of the {@code forward} function: documentation header plus body. */
    private static final String FORWARD_SOURCE =
            FORWARD_DOCS
            + "  # For numerical stability, we subtract the max score of an example from all scores for that\n"
            + "  # example.  This is equivalent to the original formulation:\n"
            + "  # e^scores_i / sum(e^scores_i) == C*e^scores_i / C*sum(e^scores_i)\n"
            + "  #                              == e^(scores_i+log(C)) / sum(e^(scores_i+log(C))\n"
            + "  # set log(C) = -max(scores_i):\n"
            + "  #                              == e^(scores_i-max(scores_i)) / sum(e^(scores_i-max(scores_i))\n"
            + "  scores = scores - rowMaxs(scores)  # numerical stability\n"
            + "  unnorm_probs = exp(scores)  # unnormalized probabilities\n"
            + "  probs = unnorm_probs / rowSums(unnorm_probs)  # normalized probabilities\n"
            + "}\n";

    /** Signature and documentation comment of the DML {@code backward} function. */
    private static final String BACKWARD_DOCS =
            "backward = function(matrix[double] dprobs, matrix[double] scores)\n"
            + "    return (matrix[double] dscores) {\n"
            + "  /*\n"
            + "   * Computes the backward pass for a softmax classifier.\n"
            + "   *\n"
            + "   * Note that dscores_ij has multiple source branches:\n"
            + "   *\n"
            + "   *   ```\n"
            + "   *   dprobs_ij/dscores_ij = probs_ij * (1 - probs_ij)\n"
            + "   *   dprobs_ik/dscores_ij = -probs_ik * probs_ij, for all k != j\n"
            + "   *\n"
            + "   *   dloss/dscores_ij =\n"
            + "   *      (dloss/dprobs_ij * dprobs_ij/dscores_ij)\n"
            + "   *      + sum_{k!=j}(dloss/dprobs_ik * dprobs_ik/dscores_ij)\n"
            + "   *   ```\n"
            + "   *\n"
            + "   * Inputs:\n"
            + "   *  - dprobs: Gradient wrt `probs` from upstream, of shape (N, D).\n"
            + "   *  - scores: Inputs, of shape (N, D).\n"
            + "   *\n"
            + "   * Outputs:\n"
            + "   *  - dscores: Gradient wrt `scores`, of shape (N, D).\n"
            + "   */\n";

    /** Complete DML text of the {@code backward} function: documentation header plus body. */
    private static final String BACKWARD_SOURCE =
            BACKWARD_DOCS
            + "  scores = scores - rowMaxs(scores)  # numerical stability\n"
            + "  unnorm_probs = exp(scores)  # unnormalized probabilities\n"
            + "  probs = unnorm_probs / rowSums(unnorm_probs)  # normalized probabilities\n"
            + "  # After some cancellation:\n"
            + "  # dscores = dprobs*probs - probs*rowSums(dprobs*probs)\n"
            + "  dtemp = dprobs * probs\n"
            + "  dscores = dtemp - probs*rowSums(dtemp)\n"
            + "}\n";

    /**
     * Loads {@code /scripts/nn/layers/softmax.dml} from the classpath and installs its text as
     * this script's string.
     *
     * @throws NullPointerException if the resource is not present on the classpath
     * @throws UncheckedIOException if reading the resource fails; previously the error was only
     *         printed and the instance was left with an empty or truncated script
     */
    public Softmax() {
        // Resolved through Script.class, exactly as before, so the same class loader is consulted.
        InputStream resource = Objects.requireNonNull(
                Script.class.getResourceAsStream("/" + SCRIPT_PATH),
                "DML resource not found on classpath: /" + SCRIPT_PATH);
        StringBuilder contents = new StringBuilder();
        // Closing the reader also closes the underlying resource stream. The charset is pinned
        // so decoding does not depend on the platform default.
        try (InputStreamReader reader = new InputStreamReader(resource, StandardCharsets.UTF_8)) {
            char[] buffer = new char[1024];
            int count;
            while ((count = reader.read(buffer)) != -1) {
                contents.append(buffer, 0, count);
            }
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read DML resource /" + SCRIPT_PATH, e);
        }
        this.setScriptString(contents.toString());
    }

    /**
     * Executes the DML {@code forward} function of the softmax script.
     *
     * <p>A new {@link Script} is created that sources the DML file, binds the argument to the
     * input variable {@code scores}, and registers {@code probs} as its output.
     *
     * @param scores value bound to the DML input {@code scores}; the DML signature declares it
     *        as {@code matrix[double]} (accepted Java types are determined by {@code Script.in})
     * @return the matrix read from the {@code probs} output of the executed script
     */
    public Matrix forward(Object scores) {
        Script call = new Script(FORWARD_CALL);
        call.in("scores", scores).out("probs");
        MLResults results = call.execute();
        return results.getMatrix("probs");
    }

    /**
     * Returns the signature and documentation comment of the DML {@code forward} function.
     *
     * @return the DML text up to and including the end of the documentation comment
     */
    public String forward__docs() {
        return FORWARD_DOCS;
    }

    /**
     * Returns the complete DML text of the {@code forward} function.
     *
     * @return the DML signature, documentation comment and function body
     */
    public String forward__source() {
        return FORWARD_SOURCE;
    }

    /**
     * Executes the DML {@code backward} function of the softmax script.
     *
     * <p>A new {@link Script} is created that sources the DML file, binds the arguments to the
     * input variables {@code dprobs} and {@code scores}, and registers {@code dscores} as its
     * output.
     *
     * @param dprobs value bound to the DML input {@code dprobs}
     * @param scores value bound to the DML input {@code scores}
     * @return the matrix read from the {@code dscores} output of the executed script
     */
    public Matrix backward(Object dprobs, Object scores) {
        Script call = new Script(BACKWARD_CALL);
        call.in("dprobs", dprobs).in("scores", scores).out("dscores");
        MLResults results = call.execute();
        return results.getMatrix("dscores");
    }

    /**
     * Returns the signature and documentation comment of the DML {@code backward} function.
     *
     * @return the DML text up to and including the end of the documentation comment
     */
    public String backward__docs() {
        return BACKWARD_DOCS;
    }

    /**
     * Returns the complete DML text of the {@code backward} function.
     *
     * @return the DML signature, documentation comment and function body
     */
    public String backward__source() {
        return BACKWARD_SOURCE;
    }
}

