/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.neural;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.Filter;
import java.io.File;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

public final class NeuralUtils {
    /** Static utility class; not instantiable. */
    private NeuralUtils() {
    }

    /**
     * Loads a matrix stored as whitespace-separated text from the file at {@code path}.
     * I/O failures are reported by {@code IOUtils.slurpFileNoExceptions} (presumably as an
     * unchecked exception — confirm against IOUtils).
     *
     * @param path path of the text file
     * @return the parsed matrix, see {@link #convertTextMatrix(String)}
     */
    public static SimpleMatrix loadTextMatrix(String path) {
        return convertTextMatrix(IOUtils.slurpFileNoExceptions(path));
    }

    /**
     * Loads a matrix stored as whitespace-separated text from {@code file}.
     *
     * @param file the text file
     * @return the parsed matrix, see {@link #convertTextMatrix(String)}
     */
    public static SimpleMatrix loadTextMatrix(File file) {
        return convertTextMatrix(IOUtils.slurpFileNoExceptions(file));
    }

    /**
     * Parses a matrix from text: one row per line, values separated by whitespace.
     * Blank lines are ignored; every remaining line must have the same number of values
     * as the first one.
     *
     * @param text the matrix text
     * @return a matrix with one row per non-blank line
     * @throws IllegalArgumentException if {@code text} contains no non-blank lines
     * @throws RuntimeException if a row has a different number of values than the first row
     * @throws NumberFormatException if a value is not a parseable double
     */
    public static SimpleMatrix convertTextMatrix(String text) {
        List<String> lines = CollectionUtils.filterAsList(Arrays.asList(text.split("\n")), new Filter<String>() {
            private static final long serialVersionUID = 1L;

            @Override
            public boolean accept(String s) {
                return s.trim().length() > 0;
            }
        });
        // Without this check an empty input fails with an opaque IndexOutOfBoundsException below.
        if (lines.isEmpty()) {
            throw new IllegalArgumentException("Cannot build a matrix from text with no non-blank lines");
        }
        int numRows = lines.size();
        int numCols = lines.get(0).trim().split("\\s+").length;
        double[][] data = new double[numRows][numCols];
        for (int row = 0; row < numRows; ++row) {
            String[] pieces = lines.get(row).trim().split("\\s+");
            if (pieces.length != numCols) {
                // "row" counts non-blank lines only, starting at 0.
                throw new RuntimeException("Unexpected row length in line " + row
                        + ": expected " + numCols + " values but found " + pieces.length);
            }
            for (int col = 0; col < numCols; ++col) {
                data[row][col] = Double.parseDouble(pieces[col]);
            }
        }
        return new SimpleMatrix(data);
    }

    /**
     * Cosine similarity of two vectors: {@code dot(v1, v2) / (|v1| * |v2|)}.
     * Yields NaN if either vector has zero norm.
     *
     * @param vector1 a row or column vector
     * @param vector2 a vector with the same orientation and length as {@code vector1}
     * @return the cosine of the angle between the vectors
     */
    public static double cosine(SimpleMatrix vector1, SimpleMatrix vector2) {
        return dot(vector1, vector2) / (vector1.normF() * vector2.normF());
    }

    /**
     * Dot product of two vectors. The orientation of {@code vector1} decides how the product
     * is formed; {@code vector2} is assumed to have the same orientation and length.
     *
     * @param vector1 a row (1 x n) or column (n x 1) vector
     * @param vector2 a vector shaped like {@code vector1}
     * @return the scalar dot product
     * @throws IllegalArgumentException if {@code vector1} is neither a row nor a column vector
     */
    public static double dot(SimpleMatrix vector1, SimpleMatrix vector2) {
        if (vector1.numRows() == 1) {
            // Row vectors: (1 x n) * (n x 1) -> 1 x 1.
            return vector1.mult(vector2.transpose()).get(0);
        }
        if (vector1.numCols() == 1) {
            // Column vectors: (1 x n) * (n x 1) -> 1 x 1.
            return vector1.transpose().mult(vector2).get(0);
        }
        // Previously this called System.exit(1), which kills the whole JVM from library code.
        throw new IllegalArgumentException("Error in neural.Utils.dot: vector1 is a matrix "
                + vector1.numRows() + " x " + vector1.numCols());
    }

    /**
     * Copies a flat parameter vector into the given matrices, in iteration order, filling
     * each matrix element-by-element using its flat (single-index) element order.
     *
     * @param theta the flat parameter vector; its length must equal the total number of
     *     elements across all matrices
     * @param matrices iterators over the matrices to fill (consumed by this call)
     * @throws AssertionError if {@code theta} is too short or not fully consumed
     */
    @SafeVarargs
    public static void vectorToParams(double[] theta, Iterator<SimpleMatrix>... matrices) {
        int index = 0;
        for (Iterator<SimpleMatrix> matrixIterator : matrices) {
            while (matrixIterator.hasNext()) {
                SimpleMatrix matrix = matrixIterator.next();
                int numElements = matrix.getNumElements();
                if (index + numElements > theta.length) {
                    throw new AssertionError("theta vector of length " + theta.length
                            + " is too short for the given matrices");
                }
                for (int i = 0; i < numElements; ++i) {
                    matrix.set(i, theta[index++]);
                }
            }
        }
        if (index != theta.length) {
            throw new AssertionError("Did not entirely use the theta vector");
        }
    }

    /**
     * Flattens the given matrices, in iteration order, into one parameter vector.
     *
     * @param totalSize the total number of elements across all matrices
     * @param matrices iterators over the matrices to read (consumed by this call)
     * @return a new vector of length {@code totalSize}
     * @throws AssertionError if {@code totalSize} does not match the matrices' element count
     */
    @SafeVarargs
    public static double[] paramsToVector(int totalSize, Iterator<SimpleMatrix>... matrices) {
        // Multiplying by exactly 1.0 leaves every double unchanged.
        return paramsToVector(1.0, totalSize, matrices);
    }

    /**
     * Flattens the given matrices, in iteration order, into one parameter vector, multiplying
     * every element by {@code scale}.
     *
     * @param scale factor applied to every element
     * @param totalSize the total number of elements across all matrices
     * @param matrices iterators over the matrices to read (consumed by this call)
     * @return a new vector of length {@code totalSize}
     * @throws AssertionError if {@code totalSize} does not match the matrices' element count
     */
    @SafeVarargs
    public static double[] paramsToVector(double scale, int totalSize, Iterator<SimpleMatrix>... matrices) {
        double[] theta = new double[totalSize];
        int index = 0;
        for (Iterator<SimpleMatrix> matrixIterator : matrices) {
            while (matrixIterator.hasNext()) {
                SimpleMatrix matrix = matrixIterator.next();
                int numElements = matrix.getNumElements();
                if (index + numElements > totalSize) {
                    throw new AssertionError("totalSize " + totalSize
                            + " is too small for the given matrices");
                }
                for (int i = 0; i < numElements; ++i) {
                    theta[index++] = matrix.get(i) * scale;
                }
            }
        }
        if (index != totalSize) {
            throw new AssertionError("Did not entirely fill the theta vector: expected " + totalSize + " used " + index);
        }
        return theta;
    }

    /**
     * Logistic sigmoid {@code 1 / (1 + e^-x)}.
     *
     * @param x input value
     * @return a value in [0, 1]
     */
    public static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Softmax over all elements of {@code input}: {@code exp(x_i) / sum_j exp(x_j)}.
     * The input is not modified.
     *
     * @param input the scores
     * @return a new matrix of the same shape whose elements sum to 1
     */
    public static SimpleMatrix softmax(SimpleMatrix input) {
        SimpleMatrix output = new SimpleMatrix(input);
        int numElements = output.getNumElements();
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < numElements; ++i) {
            max = Math.max(max, output.get(i));
        }
        // Subtracting the max is mathematically a no-op for softmax but keeps exp() from
        // overflowing to Infinity (which would make every output NaN). Skip the shift when the
        // max is infinite, where it would itself introduce NaN.
        double shift = Double.isInfinite(max) ? 0.0 : max;
        for (int i = 0; i < numElements; ++i) {
            output.set(i, Math.exp(output.get(i) - shift));
        }
        double sum = output.elementSum();
        return output.scale(1.0 / sum);
    }

    /**
     * Applies the natural logarithm to every element. The input is not modified.
     *
     * @param input the matrix
     * @return a new matrix with {@code log(x)} for each element {@code x}
     */
    public static SimpleMatrix elementwiseApplyLog(SimpleMatrix input) {
        SimpleMatrix output = new SimpleMatrix(input);
        int numElements = output.getNumElements();
        for (int i = 0; i < numElements; ++i) {
            output.set(i, Math.log(output.get(i)));
        }
        return output;
    }

    /**
     * Applies the hyperbolic tangent to every element. The input is not modified.
     *
     * @param input the matrix
     * @return a new matrix with {@code tanh(x)} for each element {@code x}
     */
    public static SimpleMatrix elementwiseApplyTanh(SimpleMatrix input) {
        SimpleMatrix output = new SimpleMatrix(input);
        int numElements = output.getNumElements();
        for (int i = 0; i < numElements; ++i) {
            output.set(i, Math.tanh(output.get(i)));
        }
        return output;
    }

    /**
     * Computes {@code 1 - x^2} for every element. Since {@code d/dz tanh(z) = 1 - tanh(z)^2},
     * the input should already hold tanh activations, not pre-activations.
     *
     * @param input matrix of tanh outputs
     * @return a new matrix with {@code 1 - x*x} for each element {@code x}
     */
    public static SimpleMatrix elementwiseApplyTanhDerivative(SimpleMatrix input) {
        SimpleMatrix ones = new SimpleMatrix(input.numRows(), input.numCols());
        ones.set(1.0);
        return ones.minus(input.elementMult(input));
    }

    /**
     * Stacks column vectors into a single column vector and appends a trailing 1.0
     * (a bias term).
     *
     * @param vectors column vectors (n_k x 1) to stack in order
     * @return a ((sum of n_k) + 1) x 1 vector
     */
    public static SimpleMatrix concatenateWithBias(SimpleMatrix... vectors) {
        int size = 0;
        for (SimpleMatrix vector : vectors) {
            size += vector.numRows();
        }
        SimpleMatrix result = new SimpleMatrix(size + 1, 1);
        int index = 0;
        for (SimpleMatrix vector : vectors) {
            result.insertIntoThis(index, 0, vector);
            index += vector.numRows();
        }
        // After the loop, index == size: the extra last row holds the bias.
        result.set(index, 0, 1.0);
        return result;
    }

    /**
     * Stacks column vectors into a single column vector.
     *
     * @param vectors column vectors (n_k x 1) to stack in order
     * @return a (sum of n_k) x 1 vector
     */
    public static SimpleMatrix concatenate(SimpleMatrix... vectors) {
        int size = 0;
        for (SimpleMatrix vector : vectors) {
            size += vector.numRows();
        }
        SimpleMatrix result = new SimpleMatrix(size, 1);
        int index = 0;
        for (SimpleMatrix vector : vectors) {
            result.insertIntoThis(index, 0, vector);
            index += vector.numRows();
        }
        return result;
    }

    /**
     * Creates a matrix whose elements are independent draws from {@code rand.nextGaussian()}
     * (mean 0, standard deviation 1), filled row by row.
     *
     * @param numRows number of rows
     * @param numCols number of columns
     * @param rand source of randomness
     * @return the new random matrix
     */
    public static SimpleMatrix randomGaussian(int numRows, int numCols, Random rand) {
        SimpleMatrix result = new SimpleMatrix(numRows, numCols);
        for (int i = 0; i < numRows; ++i) {
            for (int j = 0; j < numCols; ++j) {
                result.set(i, j, rand.nextGaussian());
            }
        }
        return result;
    }

    /**
     * Tests whether every element is exactly zero ({@code -0.0} counts as zero; NaN does not).
     *
     * @param matrix the matrix to test
     * @return {@code true} if all elements equal 0.0, including for an empty matrix
     */
    public static boolean isZero(SimpleMatrix matrix) {
        int size = matrix.getNumElements();
        for (int i = 0; i < size; ++i) {
            if (matrix.get(i) != 0.0) {
                return false;
            }
        }
        return true;
    }
}

