// syntaxsum.cpp
//
// by William M. Darling (c) 2010
//
// This is a C++ class implementing a Gibbs Sampler for our topic-syntax model for multi-document summarization.
//
// Large portions of this code are inspired by the HMMLDA Gibbs Sampler in the Matlab Topic Modeling Toolbox by Mark Steyvers 
// and Tom Griffiths available at http://psiexp.ss.uci.edu/research/programs_data/toolbox.htm. (Used with permission).
//
// For more information please see:
// Griffiths, T., & Steyvers, M. (2004).  Finding Scientific Topics. Proceedings of the National Academy 
// of Sciences, 101 (suppl. 1), 5228-5235.
// and
// Griffiths, T.L., & Steyvers, M.,  Blei, D.M., & Tenenbaum, J.B. (2004). Integrating Topics and Syntax. In: 
// Advances in Neural Information Processing Systems, 17.

#include "syntaxsum.h"
#include <cstdio>
#include <cstdlib>
#include <vector>

using namespace std;

// Default constructor.
//
// Intentionally does nothing: the body is empty, so every member is left to its own
// default construction. Parameters and corpus data are supplied later through setup().
syntaxsum::syntaxsum()
{
	// initialize syntaxsum object...
	
}

// Destructor.
//
// Intentionally does nothing: the body is empty, so no explicit cleanup is performed here
// and members are released by their own destructors.
syntaxsum::~syntaxsum()
{
	// free up memory...
	
}

// Reads a whitespace-separated stream of numbers from the file 'path' into 'out'.
//
// The files use the Matlab Topic Modeling Toolbox convention (indices start at 1), so each
// value is shifted down by one before it is stored; a 0 in the file therefore becomes -1.
// Any previous content of 'out' is discarded.
//
// 'what' is a short human-readable description of the file, used only in error messages.
//
// Returns the largest raw (unshifted) value read, or 0 if no positive value was read.
// Prints a message to stderr and exits with status 1 if the file cannot be opened or if it
// contains a token that is not a number.
template <typename IndexVector>
static double loadIndexStream(const char *path, const char *what, IndexVector &out)
{
	FILE *fp = (path != NULL) ? fopen(path, "r") : NULL;
	if (fp == NULL)
	{
		fprintf(stderr, "Could not open %s file '%s'.\n", what, (path != NULL) ? path : "(null)");
		exit(1);
	}

	out.clear();

	double largest = 0;
	double value = 0;
	int status;
	// Compare against 1 (one successful conversion) rather than EOF: on a malformed token
	// fscanf returns 0 without consuming input, which would otherwise loop forever.
	while ((status = fscanf(fp, "%lf", &value)) == 1)
	{
		out.push_back(static_cast<typename IndexVector::value_type>(value - 1));
		if (value > largest)
			largest = value;
	}
	fclose(fp);

	if (status != EOF)
	{
		fprintf(stderr, "The %s file '%s' contains a token that is not a number.\n", what, path);
		exit(1);
	}

	return largest;
}

// Stores the model hyper-parameters and loads the corpus.
//
//   b   - value stored in beta (smoothing added to the word counts during sampling)
//   g   - value stored in gamma (smoothing added to the class transition counts during sampling)
//   WS  - path of the word index file (one 1-based word index per token; 0 marks a sentence end)
//   DS  - path of the document index file (one 1-based document index per token)
//   nt  - number of topics
//   ns  - number of syntactic states; two extra states are added internally
//   dps - number of documents per document set
//
// On return w and d hold the 0-based word and document index of every token, W is the
// largest word index found in WS, N is the number of tokens and D is one more than the
// 0-based document index of the last token.
//
// Prints a message to stderr and exits with status 1 if a parameter is out of range, a file
// cannot be read, the document file is empty, or the word file has fewer tokens than the
// document file.
void syntaxsum::setup(double b, double g, const char *WS, const char *DS, int nt, int ns, int dps)
{
	// gibbs() divides by docspds and sizes its tables from T and S, so reject values that
	// would make it divide by zero or index empty tables.
	if (nt < 1 || ns < 0 || dps < 1)
	{
		fprintf(stderr, "Invalid parameters: need at least 1 topic, 0 or more syntactic states and at least 1 document per set.\n");
		exit(1);
	}

	// set up parameters
	beta = b;
	gamma = g;
	T = nt;
	S = ns+2;	// numstates+2 because we need one state for the topics and one state for end-of-sentence marker
	docspds = dps;

	// load data (single pass per file; the vectors grow as tokens are read)
	W = static_cast<int>(loadIndexStream(WS, "word index", w));
	loadIndexStream(DS, "document index", d);

	if (d.empty())
	{
		fprintf(stderr, "Document index file '%s' contains no tokens.\n", DS);
		exit(1);
	}

	// The sampler reads w[i] for every i < N, so the word stream must cover every token.
	if (w.size() < d.size())
	{
		fprintf(stderr, "Word index file '%s' has fewer tokens (%lu) than document index file '%s' (%lu).\n",
				WS, static_cast<unsigned long>(w.size()), DS, static_cast<unsigned long>(d.size()));
		exit(1);
	}
	if (w.size() > d.size())
	{
		fprintf(stderr, "Warning: word index file '%s' has more tokens (%lu) than document index file '%s' (%lu); the extra tokens are ignored.\n",
				WS, static_cast<unsigned long>(w.size()), DS, static_cast<unsigned long>(d.size()));
	}

	// # of tokens, # of documents
	N = static_cast<int>(d.size());
	D = d.back()+1;
}

void syntaxsum::gibbs(int iterations)
{
	int current, prev, preprev, dioffset, wioffset; 
	int S2, S3, doc, word, clas, topic;			// clas has only one s because of the c++ keyword class!
	double betasum, gammasum, max, best, totprob, r; 
	int cdocset, j;
		
	// some constants
	S2 = S * S;
	S3 = S * S * S;
	betasum = (double)W * beta;
	gammasum = (double)S * gamma;

	// create 1-gram, bigram, trigram
	vector<int> first(N);
	vector<int> second(N);
	vector<int> third(N);
		
	// adjust vector sizes
	z.resize(N);			// word topic allocations
	c.resize(N);			// word syntax class allocations
	
	nz.resize(T);			// total allocations to each topic
	nc.resize(S);			// total allocations to each class
		
	nwc.resize(S*W);		// number of words allocated to each class
	nwz.resize(T*W);		// number of words allocated to each topic
	
	// create more housekeeping variables
	vector<double> probs(T+S);
	vector<int> stot(S*S*S);
	vector<int> sp(S*S*S*S);
	
	
	// print out stats
	printf( "Running SyntaxSum Gibbs Sampler for %d iterations...\n", iterations);
	printf("\tnumber of unique words = %d\n", W);
	printf("\tnumber of word tokens = %d\n", N);
	printf("\tnumber of docs = %d\n", D);
	printf("\tnumber of docs per set = %d\n", docspds);
	printf("\tnumber of topics = %d\n", T);
	printf("\tnumber of syntactic states = %d\n", S-2);
	printf("\tbeta = %lf\n", beta);
	printf("\tgamma = %lf\n", gamma);
	
	
	// do random initialization of the markov chain
	printf( "Starting Random initialization...\n" );
	
	current = 0;
	prev = 0;
	preprev = 0;
	first[0] = 0;
	second[0] = 0;
	third[0] = 0;
	
	// for each word -- initialize
	for(int i=0; i < N-1; i++) {	// last one is a sentence marker...
        
		word = w[i];
		
		if (word == -1) 
		{
			// sentence marker
	        	sp[preprev*S3 + prev*S2 + current*S]++;
			first[i]  = current;
			second[i] = prev;
			third[i]  = preprev;
			preprev   = 0;
			prev      = 0;
			current   = 0;
		} 
		else 
		{
			doc = d[i];
			dioffset = doc*T;
			wioffset = word*T;
			
			// depends on number of documents per document set
			cdocset = doc / docspds;
			
			max = 0; 
			best = 0; 
			totprob = 0;

			c[i] = rand() % 2;
			
			for(j = 0; j < T; j++) {
				if(j==cdocset)
					probs[j] = 1.0;
				else
					probs[j] = 0.0;
				
				if (c[i]==1)		// if semantic class
					probs[j] *= ((double)nwz[wioffset + j] + beta) / ((double)nz[j] + betasum);
				totprob += probs[j];
			}
			
			r = ((double)rand() / (double)RAND_MAX) * totprob;
			max = probs[0];
			j = 0;
			while (r>max) {
				j++;
				max += probs[j];
			}
			
			z[i] = j;
			max = 0; 
			best = 0;
			
			if (c[i]==0) {
				probs[1] = ((double)nwz[wioffset + j] + beta) / ((double)nz[j] + betasum) * ((double)sp[preprev*S3 + prev*S2 + current*S + 1] + gamma);
				totprob = probs[1];
				
				for (j = 2; j < S; j++) {
					probs[j] = ((double)nwc[word*S + j] + beta) / ((double)nc[j] + betasum) * ((double)sp[preprev*S3 + prev*S2 + current*S + j] + gamma);
					totprob += probs[j];
				}
				
				r = ((double)rand() / (double)RAND_MAX) * totprob;
				max = probs[1];
				j = 1;
                
				while (r>max) {
					j++;
					max += probs[j];
				}
				c[i] = j;
			}
			
			if (c[i] == 1) {
				topic = z[i];
				nwz[wioffset + topic]++;
				nz[topic]++;
			} 
			else
				nwc[word*S + c[i]]++;
			
			clas = c[i];
			nc[clas]++;
			stot[prev*S2 + current*S + clas]++;
			sp[preprev*S3 + prev*S2 + current*S + clas]++;
			first[i] = current;
			second[i] = prev;
			third[i] = preprev;
			preprev = prev;
			prev = current;
			current = clas;
		}
	}
    
	
	// perform Gibbs Sampling iterations...
	probs[0] = 0;
	
	printf("Beginning Gibbs Sampling iterations...\n");
	
	for(int it=0; it<iterations; it++)
	{
		if ((it % 10)==0)
			printf( "\tIteration %d of %d\n", it, iterations);
		
		current = 0;
		prev = 0;
		preprev = 0;
		
		// go through each word in the corpus -- only up to N-4 because we look ahead and because we don't need the last N (end of sentence marker)
		for(int i=0; i < N-4; i++)
		{
			word = w[i]; 
			
			if (word == -1) // sentence marker
			{
				sp[third[i]*S3 + second[i]*S2 + first[i]*S]--;
				sp[preprev*S3 + prev*S2 + current*S]++;
				first[i] = current;
				second[i] = prev;
				third[i] = preprev;
				current = 0;
				prev = 0;
				preprev = 0;
			} 
			else
			{
				doc = d[i]; 
				dioffset = doc*T;
				wioffset = word*T;
				
				// depends on number of documents per document set
				cdocset = doc / docspds;
				
				sp[third[i]*S3 + second[i]*S2 + first[i]*S + c[i]]--;
				
				if (c[i] == 1)	// if class==1 then we use the topic model
				{
					// get word's currently assigned topic
					topic = z[i];
					
					// remove current counts from markov chain
					nwz[wioffset + topic]--;
					nz[topic]--;
				} 
				else		// else we use the HMM (syntax model)
				{
					// remove current counts from markov chain
					nwc[word*S + c[i]]--;
				}
				
				// decrement counts
				nc[c[i]]--;
				stot[second[i]*S2 + first[i]*S + c[i]]--;
				
				max = 0; 
				best = 0; 
				totprob = 0;
				
				// build probability distribution for topics sampling from markov chain
				for (j = 0; j < T; j++)
				{
					if(cdocset==j)
						probs[j] = 1.0;
					else
						probs[j] = 0.0;
					
					if (c[i] == 1)
						probs[j] *= ((double)nwz[wioffset + j] + beta) / ((double)nz[j] + betasum);
					totprob += probs[j];
				}
				
				// sample from distribution to get new topic
				r = ((double)rand() / (double)(RAND_MAX)) * totprob;
				max = probs[0];
				j = 0;
				while (r>max)
				{
					j++;
					max += probs[j];
				}
				
				// assign new sampled topic
				z[i] = j;
				
				probs[1]=((double)nwz[wioffset+j] + beta) / ((double)nz[j] + betasum)
				*((double)sp[preprev*S3 + prev*S2 + current*S + 1] + gamma)
				*((double)sp[prev*S3 + current*S2 + S + c[i+1]] + gamma) / ((double)stot[prev*S2 + current*S + 1] + gammasum)
				*((double)sp[current*S3 + S2 + c[i+1]*S + c[i+2]] + gamma) / ((double)stot[current*S2 + S + c[i+1]] + gammasum)
				*((double)sp[S3 + c[i+1]*S2 + c[i+2]*S + c[i+3]] + gamma) / ((double)stot[S2 + c[i+1]*S + c[i+2]] + gammasum);
				
				totprob = probs[1];
				
				// build probability distribution for states sampling from markov chain
				for (j = 2; j < S; j++)
				{
					probs[j]=((double)nwc[word*S + j] + beta) / ((double)nc[j] + betasum)
					*((double)sp[preprev*S3 + prev*S2 + current*S + j] + gamma)
					*((double)sp[prev*S3 + current*S2 + j*S + c[i+1]] + gamma) / ((double)stot[prev*S2 + current*S + j] + gammasum)
					*((double)sp[current*S3 + j*S2 + c[i+1]*S + c[i+2]] + gamma) / ((double)stot[current*S2 + j*S + c[i+1]] + gammasum)
					*((double)sp[j*S3 + c[i+1]*S2 + c[i+2]*S + c[i+3]] + gamma) / ((double)stot[j*S2 + c[i+1]*S + c[i+2]] + gammasum);
					totprob += probs[j];
				}
				
				// sample from distribution to get new state
				r = ((double)rand() / (double)RAND_MAX) * totprob;
				max = probs[1];
				j = 1;
				while (r > max)
				{
					j++;
					max += probs[j];
				}
				
				// assign new sampled syntax class
				c[i] = j;
				
				sp[preprev*S3 + prev*S2 + current*S + c[i]]++;
				if(c[i] == 1)
				{
					// set topic and increment counts
					topic = z[i];
					nwz[wioffset + topic]++;
					nz[topic]++;
				} 
				else
				{
					nwc[word*S + c[i]]++;
				}
				
				// increment counts and move everything up one
				nc[c[i]]++;
				stot[prev*S2 + current*S + c[i]]++;
				first[i] = current;
				second[i] = prev;
				third[i] = preprev;
				preprev = prev;
				prev = current;
				current = c[i];
			}
		} 
	}
}

// Public entry point for inference: seeds the C random number generator from the current
// wall-clock time and then hands over to gibbs() for the requested number of iterations.
void syntaxsum::run(int iterations)
{
	// a fresh time-based seed makes every run draw a different random sequence
	const time_t now = time(NULL);
	srand(now);

	gibbs(iterations);
}

void syntaxsum::write()
{
	// make topic file
	FILE *fp = fopen("WO", "r");

	if(fp == NULL)
	{
		fprintf(stderr, "Could not find vocabulary file 'WO'. It must be in this directory.\n");
		exit(1);
	}

	vector<string> words;
	char word[255];
	while(fscanf(fp,"%s",word)!=EOF)
		words.push_back(word);
	fclose(fp);
	
	// make lda style zeta matrix
	fp = fopen("zeta", "w");

	if(fp == NULL)
	{
		fprintf(stderr, "Could not create output files.\n");
		exit(1);
	}

	for(int j=0; j<T; j++)
	{
		for(int i=0; i<W; i++)
		{
			fprintf(fp, "%lf ", ((double)nwz[j+i*T] + beta) / ((double)nz[j] + (double)(beta*W)));
		}
		fprintf(fp, "\n");
	}
	fclose(fp);
	
	
	// make "matlab-style" files for use with python summarization scripts
	fp = fopen("WP.txt", "w");
	for(int j=0; j<T; j++)
	{
		fprintf(fp, "TOPIC_%d 0\n\n", j+1);
		for(int i=0; i<W; i++)
		{
			fprintf(fp, "%s %lf\n", words[i].c_str(), ((double)nwz[j+i*T] + beta) / ((double)nz[j] + (double)(beta*W)));
		}
		fprintf(fp, "\n\n");
	}
	fclose(fp);
	
	fp = fopen("MP.txt", "w");
	for(int j=0; j<S; j++)
	{
		fprintf(fp, "CLASS_%d 0\n\n", j+1);
		for(int i=0; i<W; i++)
		{
			fprintf(fp, "%s %lf\n", words[i].c_str(), ((double)nwc[j+i*S] + gamma) / ((double)nc[j] + (double)(gamma*W)));
		}
		fprintf(fp, "\n\n");
	}
	fclose(fp);	
}
