/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.rnn;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.util.Generics;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

public class RNNUtils {
    private RNNUtils() {
    }

    /**
     * Copies the entries of {@code theta}, in order, into the matrices produced by the given
     * iterators. Each matrix is written through its linear element index
     * ({@code SimpleMatrix.set(int, double)}), which is the same layout that
     * {@code paramsToVector} reads with {@code SimpleMatrix.get(int)}.
     *
     * <p>The iterators are consumed by this call.
     *
     * @param theta    the flat parameter vector to unpack
     * @param matrices iterators over the destination matrices, visited in argument order
     * @throws AssertionError if {@code theta} has fewer entries than the matrices need, or if
     *                        entries of {@code theta} are left over once every matrix is filled
     */
    public static void vectorToParams(double[] theta, Iterator<SimpleMatrix> ... matrices) {
        int index = 0;
        for (Iterator<SimpleMatrix> matrixIterator : matrices) {
            while (matrixIterator.hasNext()) {
                SimpleMatrix matrix = matrixIterator.next();
                int numElements = matrix.getNumElements();
                // Check before writing so a short theta yields a descriptive error rather than an
                // ArrayIndexOutOfBoundsException partway through a matrix.
                if (numElements > theta.length - index) {
                    throw new AssertionError("Theta vector is too short: length " + theta.length
                            + " but at least " + (index + numElements) + " entries are needed");
                }
                for (int i = 0; i < numElements; ++i) {
                    matrix.set(i, theta[index++]);
                }
            }
        }
        if (index != theta.length) {
            throw new AssertionError("Did not entirely use the theta vector");
        }
    }

    /**
     * Flattens the matrices produced by the given iterators into a single vector of length
     * {@code totalSize}, copying each matrix through its linear element index.
     *
     * <p>The iterators are consumed by this call.
     *
     * @param totalSize the exact number of entries the matrices are expected to contain
     * @param matrices  iterators over the source matrices, visited in argument order
     * @return a new array holding every matrix entry, in iteration order
     * @throws AssertionError if the matrices contain more or fewer than {@code totalSize} entries
     */
    public static double[] paramsToVector(int totalSize, Iterator<SimpleMatrix> ... matrices) {
        // Multiplying by exactly 1.0 is lossless in IEEE 754, so this is a plain copy.
        return paramsToVector(1.0, totalSize, matrices);
    }

    /**
     * Flattens the matrices produced by the given iterators into a single vector of length
     * {@code totalSize}, multiplying every entry by {@code scale}.
     *
     * <p>The iterators are consumed by this call.
     *
     * @param scale     factor applied to every copied entry
     * @param totalSize the exact number of entries the matrices are expected to contain
     * @param matrices  iterators over the source matrices, visited in argument order
     * @return a new array holding every scaled matrix entry, in iteration order
     * @throws AssertionError if the matrices contain more or fewer than {@code totalSize} entries
     */
    public static double[] paramsToVector(double scale, int totalSize, Iterator<SimpleMatrix> ... matrices) {
        double[] theta = new double[totalSize];
        int index = 0;
        for (Iterator<SimpleMatrix> matrixIterator : matrices) {
            while (matrixIterator.hasNext()) {
                SimpleMatrix matrix = matrixIterator.next();
                int numElements = matrix.getNumElements();
                // Report an undersized totalSize explicitly instead of overrunning theta.
                if (numElements > totalSize - index) {
                    throw new AssertionError("Theta vector is too short: expected " + totalSize
                            + " but at least " + (index + numElements) + " entries are needed");
                }
                for (int i = 0; i < numElements; ++i) {
                    theta[index++] = matrix.get(i) * scale;
                }
            }
        }
        if (index != totalSize) {
            throw new AssertionError("Did not entirely fill the theta vector: expected " + totalSize + " used " + index);
        }
        return theta;
    }

    /**
     * Reads a text file of word vectors. Each non-blank line holds a word followed by its vector
     * components, all separated by whitespace.
     *
     * <p>If {@code expectedSize <= 0`, the dimension of the first vector line is adopted as the
     * expected size. Lines with more components than the expected size are truncated, and a
     * warning is printed once. Lines with fewer components are rejected.
     *
     * @param filename     path of the file, read as UTF-8
     * @param expectedSize required vector length, or a value {@code <= 0} to infer it
     * @return a map from each word to a column vector ({@code expectedSize x 1}); later lines
     *         overwrite earlier ones for a repeated word
     * @throws RuntimeException if a line has fewer components than the expected size
     * @throws NumberFormatException if a vector component is not a parseable number
     */
    public static Map<String, SimpleMatrix> readRawWordVectors(String filename, int expectedSize) {
        Map<String, SimpleMatrix> wordVectors = Generics.newHashMap();
        System.err.println("Reading in the word vector file: " + filename);
        boolean warned = false;
        for (String rawLine : IOUtils.readLines(filename, "utf-8")) {
            // Trim so leading whitespace does not produce an empty "word" token that shifts
            // every component by one position.
            String line = rawLine.trim();
            // Blank lines (e.g. a trailing newline) carry no vector. Without this check they
            // would be treated as a zero-dimensional vector and either fail or, when the size
            // is being inferred, fix the dimension at 0.
            if (line.length() == 0) {
                continue;
            }
            String[] lineSplit = line.split("\\s+");
            String word = lineSplit[0];
            int dimOfWords = lineSplit.length - 1;
            if (expectedSize <= 0) {
                expectedSize = dimOfWords;
                System.err.println("Dimensionality of numHid not set.  The length of the word vectors in the given file appears to be " + dimOfWords);
            }
            if (dimOfWords > expectedSize) {
                if (!warned) {
                    warned = true;
                    System.err.println("WARNING: Dimensionality of numHid parameter and word vectors do not match, deleting word vector dimensions to fit!");
                }
                dimOfWords = expectedSize;
            } else if (dimOfWords < expectedSize) {
                throw new RuntimeException("Word vectors file has dimension too small for requested numHid of " + expectedSize);
            }
            SimpleMatrix vector = new SimpleMatrix(dimOfWords, 1);
            for (int i = 0; i < dimOfWords; ++i) {
                // Token 0 is the word itself, so components start at token 1.
                vector.set(i, 0, Double.parseDouble(lineSplit[i + 1]));
            }
            wordVectors.put(word, vector);
        }
        return wordVectors;
    }

    /**
     * Computes the logistic sigmoid {@code 1 / (1 + e^-x)}.
     *
     * @param x the input value
     * @return the sigmoid of {@code x}, in the range [0, 1]
     */
    public static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Computes the softmax over all entries of {@code input}. The normalization is over the
     * whole matrix, not per row or per column. The input is not modified.
     *
     * <p>The maximum entry is subtracted before exponentiating. This gives the same result
     * mathematically but avoids overflow to {@code Infinity} (and the resulting {@code NaN})
     * for large inputs.
     *
     * @param input the scores to normalize
     * @return a new matrix of the same shape whose entries are non-negative and sum to 1
     */
    public static SimpleMatrix softmax(SimpleMatrix input) {
        SimpleMatrix output = new SimpleMatrix(input);
        int numElements = output.getNumElements();
        if (numElements == 0) {
            return output;
        }
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < numElements; ++i) {
            max = Math.max(max, output.get(i));
        }
        double sum = 0.0;
        for (int i = 0; i < numElements; ++i) {
            double shifted = Math.exp(output.get(i) - max);
            output.set(i, shifted);
            sum += shifted;
        }
        double inverseSum = 1.0 / sum;
        for (int i = 0; i < numElements; ++i) {
            output.set(i, output.get(i) * inverseSum);
        }
        return output;
    }

    /**
     * Applies the natural logarithm to every entry. Following {@link Math#log(double)}, zero
     * entries become {@code -Infinity} and negative entries become {@code NaN}.
     *
     * @param input the matrix to transform; it is not modified
     * @return a new matrix of the same shape holding {@code log(input)}
     */
    public static SimpleMatrix elementwiseApplyLog(SimpleMatrix input) {
        SimpleMatrix output = new SimpleMatrix(input);
        for (int i = 0; i < output.numRows(); ++i) {
            for (int j = 0; j < output.numCols(); ++j) {
                output.set(i, j, Math.log(output.get(i, j)));
            }
        }
        return output;
    }

    /**
     * Applies the hyperbolic tangent to every entry.
     *
     * @param input the matrix to transform; it is not modified
     * @return a new matrix of the same shape holding {@code tanh(input)}
     */
    public static SimpleMatrix elementwiseApplyTanh(SimpleMatrix input) {
        SimpleMatrix output = new SimpleMatrix(input);
        for (int i = 0; i < output.numRows(); ++i) {
            for (int j = 0; j < output.numCols(); ++j) {
                output.set(i, j, Math.tanh(output.get(i, j)));
            }
        }
        return output;
    }

    /**
     * Computes {@code 1 - x*x} for every entry {@code x} of {@code input}.
     *
     * <p>Since {@code d/dz tanh(z) = 1 - tanh(z)^2}, this is the tanh derivative expressed in
     * terms of the tanh output. NOTE(review): callers presumably pass already-activated values
     * ({@code tanh(z)}), not pre-activations; confirm at call sites.
     *
     * @param input the matrix to transform; it is not modified
     * @return a new matrix of the same shape holding {@code 1 - input .* input}
     */
    public static SimpleMatrix elementwiseApplyTanhDerivative(SimpleMatrix input) {
        SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
        int numElements = output.getNumElements();
        for (int i = 0; i < numElements; ++i) {
            double value = input.get(i);
            output.set(i, 1.0 - value * value);
        }
        return output;
    }

    /**
     * Stacks the given matrices vertically into one column and appends a final entry of
     * {@code 1.0} (a bias term).
     *
     * @param vectors the matrices to stack, top to bottom; each is expected to be a column
     * @return a new {@code (sum of rows + 1) x 1} matrix whose last entry is {@code 1.0}
     */
    public static SimpleMatrix concatenateWithBias(SimpleMatrix ... vectors) {
        int size = 0;
        for (SimpleMatrix vector : vectors) {
            size += vector.numRows();
        }
        // One extra row for the bias entry.
        SimpleMatrix result = new SimpleMatrix(size + 1, 1);
        int index = 0;
        for (SimpleMatrix vector : vectors) {
            result.insertIntoThis(index, 0, vector);
            index += vector.numRows();
        }
        result.set(index, 0, 1.0);
        return result;
    }

    /**
     * Stacks the given matrices vertically into one column.
     *
     * @param vectors the matrices to stack, top to bottom; each is expected to be a column
     * @return a new {@code (sum of rows) x 1} matrix
     */
    public static SimpleMatrix concatenate(SimpleMatrix ... vectors) {
        int size = 0;
        for (SimpleMatrix vector : vectors) {
            size += vector.numRows();
        }
        SimpleMatrix result = new SimpleMatrix(size, 1);
        int index = 0;
        for (SimpleMatrix vector : vectors) {
            result.insertIntoThis(index, 0, vector);
            index += vector.numRows();
        }
        return result;
    }

    /**
     * Creates a matrix filled with independent draws from {@link Random#nextGaussian()}
     * (mean 0, standard deviation 1), filled in row-major order.
     *
     * @param numRows number of rows
     * @param numCols number of columns
     * @param rand    source of randomness; it advances by {@code numRows * numCols} draws
     * @return a new {@code numRows x numCols} matrix
     */
    public static SimpleMatrix randomGaussian(int numRows, int numCols, Random rand) {
        SimpleMatrix result = new SimpleMatrix(numRows, numCols);
        for (int i = 0; i < numRows; ++i) {
            for (int j = 0; j < numCols; ++j) {
                result.set(i, j, rand.nextGaussian());
            }
        }
        return result;
    }
}

