///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// This file is part of ModelBlocks. Copyright 2009, ModelBlocks developers. //
//                                                                           //
//    ModelBlocks is free software: you can redistribute it and/or modify    //
//    it under the terms of the GNU General Public License as published by   //
//    the Free Software Foundation, either version 3 of the License, or      //
//    (at your option) any later version.                                    //
//                                                                           //
//    ModelBlocks is distributed in the hope that it will be useful,         //
//    but WITHOUT ANY WARRANTY; without even the implied warranty of         //
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          //
//    GNU General Public License for more details.                           //
//                                                                           //
//    You should have received a copy of the GNU General Public License      //
//    along with ModelBlocks.  If not, see <http://www.gnu.org/licenses/>.   //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include <climits>
#include <ctime>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
using namespace std;

// Fills dist[0..size-1] with the natural log of a random probability
// distribution: `size` values are drawn with rand(), scaled so they sum to
// one, and stored as log-probabilities.
//
//   dist - output buffer with room for at least `size` floats
//   size - number of outcomes; nothing is written when size <= 0
//
// Exactly `size` calls to rand() are made, so the result is reproducible for
// a given srand() seed. A single draw of exactly zero is stored as log(0).
void getRandomDistribution(float* dist, int size){
    // Draw straight into the output buffer. The previous scratch buffer was a
    // variable-length stack array, which is not standard C++ and can overflow
    // the stack when `size` is large (callers pass the vocabulary size).
    float sum = 0.0;
    for(int i = 0; i < size; i++){
        dist[i] = static_cast<float>(rand()) / static_cast<float>(RAND_MAX);
        sum += dist[i];
    }
    // If every draw was zero the normalization would be 0/0 (NaN); fall back
    // to a uniform distribution instead.
    if(!(sum > 0.0f)){
        for(int i = 0; i < size; i++){
            dist[i] = log(1.0f / static_cast<float>(size));
        }
        return;
    }
    // normalize and log
    for(int i = 0; i < size; i++){
        dist[i] = log(dist[i] / sum);
    }
}

// Define domains and types for all variables
// F - Reduction node - F for historical reasons
// Act - Active category - the S in S/NP
// Awa - Awaited category - the NP in S/NP
// W - Word - observed variable
// Each max* constant is the number of distinct values of that hidden
// variable; valid values are 0 .. max*-1. These constants also fix the
// dimensions of the global count and probability tables.
const int maxf = 30;
const int maxact = 5;
const int maxawa = 20;
// Number of joint (f, act, awa) states, i.e. the width of one trellis column.
const int trellis_size = maxf * maxact * maxawa;

// Packs an (f, act, awa) triple into a single trellis state index, with f as
// the most significant component and awa as the least significant one.
// breakdownIndex() performs the inverse mapping.
int getTrellisIndex(int f, int act, int awa){
    const int f_act = f * maxact + act;
    return f_act * maxawa + awa;
}

// Splits a trellis state index back into its (f, act, awa) components; the
// inverse of getTrellisIndex().
//
//   ind          - state index to decode
//   f, act, awa  - output pointers that receive the decoded components
void breakdownIndex(int ind, int *f, int *act, int *awa){
    const int awa_part = ind % maxawa;
    // Index with the awa component stripped off: f * maxact + act.
    const int f_act = (ind - awa_part) / maxawa;
    *awa = awa_part;
    *act = f_act % maxact;
    *f = ind / (maxawa * maxact);
}

// define count matrices
// Global model tables, one group per hidden variable (F, Act, Awa):
//   *_cpt         - conditional probability table; main stores log values in it
//   *_counts      - count per (context, outcome) pair
//   *_cond_counts - count per context (the normalizer for *_counts)
//   *_smooth      - value every count is reset to before counting
// In *_cpt and *_counts the last index is the outcome and the leading
// indices are the conditioning context, as used in main:
//   f_*   : [prev_awa][f]
//   act_* : [prev_act][prev_awa][f][act]
//   awa_* : [prev_awa][f][act][awa]
// Note: the f_* tables are double while the act_* and awa_* tables are float.
double f_cpt[maxawa][maxf];
double f_counts[maxawa][maxf];
double f_cond_counts[maxawa];
double f_smooth = 0.1;
float act_cpt[maxact][maxawa][maxf][maxact];
float act_counts[maxact][maxawa][maxf][maxact];
float act_cond_counts[maxact][maxawa][maxf];
float act_smooth = 0.1;
float awa_cpt[maxawa][maxf][maxact][maxawa];
float awa_counts[maxawa][maxf][maxact][maxawa];
float awa_cond_counts[maxawa][maxf][maxact];
float awa_smooth = 0.1;

// One cell of the search trellis: the best score found so far for a state at
// one word position, together with the predecessor state that produced it.
struct TrellisNode{
    int back;    // index of the best previous-position state (-1 marks the start cell)
    float prob;  // log-probability score of the best path ending in this state
};

int main(int argc, char** argv){
    int num_iters = 2;
    if(argc==2) num_iters = atoi(argv[1]);
    vector<vector<int> > sentences;
    map<string,int> word_to_int;
    map<int,string> int_to_word;
    word_to_int["__fake__"] = 0;
    int_to_word[0] = "__fake__";
    int num_words = 1;
    srand(time(0));  // seed random number generator (comment out for repeatability)
    
    cerr << "Reading in sentences..." << endl;
    // read in all the sentences and store (for repeated iterations)
    string line;
    while(getline(cin,line)){
        vector<int> sent;
        sent.push_back(0);
        string word;
        stringstream ss(line);
        while(ss >> word){
            if(word_to_int.count(word) == 0){
                int_to_word[num_words] = word;
                word_to_int[word] = num_words++;
            }
            sent.push_back(word_to_int[word]);
        }
        sentences.push_back(sent);
    }
    cerr << "Finished reading all input data -- found : " << sentences.size() << " sentences." << endl;
    // Now we can declare arrays for w
    int maxw = num_words;
    float w_counts[maxawa][maxw];
    float w_cond_counts[maxawa];
    float w_cpt[maxawa][maxw];
    float w_smooth = 0.0;
    // initialize variable names
     
    cerr << "Initializing CPTS to random values..." << endl;
    // initialize all cpts and smooth count values
    for(int prev_awa = 0; prev_awa < maxawa; prev_awa++){
        float dist[maxf];
        getRandomDistribution(dist, maxf);
        for(int f = 0; f < maxf; f++){
            f_cpt[prev_awa][f] = dist[f];
            f_counts[prev_awa][f] = f_smooth;
        }
        f_cond_counts[prev_awa] = f_smooth * maxf;
    }

    for(int prev_act = 0; prev_act < maxact; prev_act++){
        for(int prev_awa = 0; prev_awa < maxawa; prev_awa++){
            for(int cur_f = 0; cur_f < maxf; cur_f++){
                float dist[maxact];
                getRandomDistribution(dist,maxact);
                for(int act = 0; act < maxact; act++){
                    act_cpt[prev_act][prev_awa][cur_f][act] = dist[act];
                    act_counts[prev_act][prev_awa][cur_f][act] = act_smooth;
                }
                act_cond_counts[prev_act][prev_awa][cur_f] = act_smooth * maxact;
            }
        }
    }

    for(int prev_awa = 0; prev_awa < maxawa; prev_awa++){
        for(int cur_f = 0; cur_f < maxf; cur_f++){
            for(int cur_act = 0; cur_act < maxact; cur_act++){
                float dist[maxawa];
                getRandomDistribution(dist,maxawa);
                for(int cur_awa = 0; cur_awa < maxawa; cur_awa++){
                    awa_cpt[prev_awa][cur_f][cur_act][cur_awa] = dist[cur_awa];
                    awa_counts[prev_awa][cur_f][cur_act][cur_awa] = awa_smooth;
                }
                awa_cond_counts[prev_awa][cur_f][cur_act] = awa_smooth * maxawa;
            }
        }
    }
    
    for(int prev_awa = 0; prev_awa < maxawa; prev_awa++){
        float dist[maxw];
        getRandomDistribution(dist,maxw);
        for(int w = 0; w < maxw; w++){
            w_cpt[prev_awa][w] = dist[w];
            w_counts[prev_awa][w] = w_smooth;
        }
        w_cond_counts[prev_awa] = w_smooth * maxw;
    }

    cerr << "Beginning EM Iterations" << endl;
    // indices we'll use repeatedly
    int old_f, old_act, old_awa;
    int f,act,awa;
    // do EM iterations
    for(int iter = 0; iter < num_iters; iter++){
        float total_prob=0;
        cerr << "Iteration " << iter << "... ";
        // expectation
        for(unsigned int i = 0; i < sentences.size(); i++){
        	if(i && i%10000 == 0){
        		cerr << i <<"...";
        	}
            vector<int> sentence = sentences[i];
            int max_state = 0;
            // initialize the trellis
            TrellisNode trellis[sentence.size()][trellis_size];
            trellis[0][0].prob = 0.0;
            trellis[0][0].back = -1;
            for(int n = 1; n < trellis_size; n++){
                trellis[0][n].prob = INT_MIN;
            }
            // iterate over words
            for(unsigned int j = 1; j < sentence.size(); j++){
                int word = sentence[j];
                // iterate over previous states
                for(int ind = 0; ind < trellis_size; ind++){
                    TrellisNode last = trellis[j-1][ind];
                    breakdownIndex(ind, &old_f, &old_act, &old_awa);
                    // iterate over future states
                    for(f = 0; f < maxf; f++){
                        float p_f_last = last.prob + f_cpt[old_awa][f];
                        for(act = 0; act < maxact; act++){
                            float p_act_f_last = p_f_last + act_cpt[old_act][old_awa][f][act];
                            for(awa = 0; awa < maxawa; awa++){
                                int new_ind = getTrellisIndex(f,act,awa);
                                float p_all = p_act_f_last +
                                              awa_cpt[old_awa][f][act][awa] +
                                              w_cpt[awa][word];
                                // if this is the first step, or a max, track it
                                if((ind==0) || p_all > trellis[j][new_ind].prob){
                                    trellis[j][new_ind].prob = p_all;
                                    trellis[j][new_ind].back = ind;
                                }
                                if(j == sentence.size()-1 && p_all > trellis[j][max_state].prob){
                                    max_state = new_ind;
                                }
                            }
                        }
                    }
                }
            }
            // get most likely sequence 
            // tally transitions
            int state = max_state;
            int prev_state, word;
            total_prob += trellis[sentence.size()-1][max_state].prob;
            for(int j = sentence.size()-1; j >= 1; j--){
                prev_state = trellis[j][state].back;
                breakdownIndex(prev_state, &old_f, &old_act, &old_awa);
                breakdownIndex(state, &f, &act, &awa);
                word = sentence[j];
                f_cond_counts[old_awa]+= 1;
                f_counts[old_awa][f]+= 1;
                act_cond_counts[old_act][old_awa][f]+= 1;
                act_counts[old_act][old_awa][f][act]+= 1;
                awa_cond_counts[old_awa][f][act]+=1;
                awa_counts[old_awa][f][act][awa]+=1;
                w_cond_counts[awa]+=1;
                w_counts[awa][word]+=1;
                state = prev_state;
            }
        }
        
        // maximization -- calculate probs for cpt (and output)
        cerr << endl;
        cout << "#Total probability of this iteration:" << total_prob << endl;
//        cerr << "Re-weighting cpts..." << endl;
        if(iter==(num_iters-1))cout << "################## F_CPT ##################" << endl;
        for(int prev_awa = 0; prev_awa < maxawa; prev_awa++){
        	//if(iter==(num_iters-1))cout << prev_awa <<" : ";
            for(int f = 0; f < maxf; f++){
                f_cpt[prev_awa][f] = log(f_counts[prev_awa][f] / f_cond_counts[prev_awa]);
                if(iter==(num_iters-1)){
                	cout << "F AWA" << prev_awa << " : F" << f << " = " << exp(f_cpt[prev_awa][f]) << endl;
                }
                f_counts[prev_awa][f] = f_smooth;
            }
            //if(iter==(num_iters-1))cout << endl;
            f_cond_counts[prev_awa] = f_smooth * maxf;
        }

        if(iter==(num_iters-1))cout << "################ ACT_CPT ####################" << endl;
        for(int prev_act = 0; prev_act < maxact; prev_act++){
            for(int prev_awa = 0; prev_awa < maxawa; prev_awa++){
                for(int cur_f = 0; cur_f < maxf; cur_f++){
                	//if(iter==(num_iters-1))cout << prev_act <<","<<prev_awa<<","<<cur_f<< " : ";
                    for(int act = 0; act < maxact; act++){
                        act_cpt[prev_act][prev_awa][cur_f][act] = log(act_counts[prev_act][prev_awa][cur_f][act] / act_cond_counts[prev_act][prev_awa][cur_f]);
                        if(iter==(num_iters-1)){
                        	cout << "ACT ACT" << prev_act << " AWA" << prev_awa << " F" << cur_f << " : ACT" << act << " = " << exp(act_cpt[prev_act][prev_awa][cur_f][act]) << endl;
                        }
                        act_counts[prev_act][prev_awa][cur_f][act] = act_smooth;
                    }
                    //if(iter==(num_iters-1))cout << endl;
                    act_cond_counts[prev_act][prev_awa][cur_f] = act_smooth * maxact;
                }
            }
        }

        if(iter==(num_iters-1))cout << "################# AWA_CPT #####################" << endl;
        for(int prev_awa = 0; prev_awa < maxawa; prev_awa++){
            for(int cur_f = 0; cur_f < maxf; cur_f++){
                for(int cur_act = 0; cur_act < maxact; cur_act++){
                	//if(iter==(num_iters-1))cout << prev_awa <<","<<cur_f<<","<<cur_act<< " : ";
                    for(int cur_awa = 0; cur_awa < maxawa; cur_awa++){
                        awa_cpt[prev_awa][cur_f][cur_act][cur_awa] = log(awa_counts[prev_awa][cur_f][cur_act][cur_awa] / awa_cond_counts[prev_awa][cur_f][cur_act]);
                        if(iter==(num_iters-1)){
                        	cout << "AWA AWA" << prev_awa << " F" << cur_f << " ACT" << cur_act << " : AWA" << cur_awa << " = " << exp(awa_cpt[prev_awa][cur_f][cur_act][cur_awa]) << endl;
                        }
                        awa_counts[prev_awa][cur_f][cur_act][cur_awa] = awa_smooth;
                    }
                    //if(iter==(num_iters-1))cout << endl;
                    awa_cond_counts[prev_awa][cur_f][cur_act] = awa_smooth * maxawa;
                }
            }
        }
        
        if(iter==(num_iters-1)) cout << "############### W_CPT #######################" << endl;
        for(int prev_awa = 0; prev_awa < maxawa; prev_awa++){
        	//if(iter==(num_iters-1))cout << prev_awa << " : ";
            for(int w = 0; w < maxw; w++){
                w_cpt[prev_awa][w] = log(w_counts[prev_awa][w] / w_cond_counts[prev_awa]);
                if(iter==(num_iters-1) && w_counts[prev_awa][w] > 0.0){
                    cout << "W AWA" << prev_awa <<  " : " << int_to_word[w] << " = " << exp(w_cpt[prev_awa][w]) << endl;
            	}
            	w_counts[prev_awa][w] = w_smooth;
            }
            //if(iter==(num_iters-1))cout << endl;
            w_cond_counts[prev_awa] = w_smooth * maxw;
        }
        
    }
}

