///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// This file is part of ModelBlocks. Copyright 2009, ModelBlocks developers. //
//                                                                           //
//    ModelBlocks is free software: you can redistribute it and/or modify    //
//    it under the terms of the GNU General Public License as published by   //
//    the Free Software Foundation, either version 3 of the License, or      //
//    (at your option) any later version.                                    //
//                                                                           //
//    ModelBlocks is distributed in the hope that it will be useful,         //
//    but WITHOUT ANY WARRANTY; without even the implied warranty of         //
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          //
//    GNU General Public License for more details.                           //
//                                                                           //
//    You should have received a copy of the GNU General Public License      //
//    along with ModelBlocks.  If not, see <http://www.gnu.org/licenses/>.   //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////

#include "nl-cpt.h"
#include "nl-dtree.h"
using namespace std;
//
////////////////////////////////////////////////////////////////////////////////

//// P: part of speech category...
//// it seems that I don't need this since I do have Pos in my HModel
//DiscreteDomain<short> domainP;
//typedef DiscreteDomainRV<short,domainP> P;


//the following arguments do not work well, but can be improved
//-v genmodel/indmodel.pronountrain genmodel/genmodel.pronountrain genmodel/nummodel.pronountrain genmodel/synmodel.pronountrain genmodel/opmodel.train genmodel/posmodel.pronountrain /home/dingcheng/Documents/NLPWorkspace/wsjparse/genmodel/posWordDepmodel.train genmodel/wordDepModel.pro genmodel/wordModel.pro genmodel/349699.pronountrain



////////////////////////////////////////////////////////////////////////////////
//
//  Models
//
////////////////////////////////////////////////////////////////////////////////

//// Preterminal (POS) given constituent category models...
/// Why G can be used here directly: G is a class declared in HHMMLangModel-?.h.
/// HHMMLangModel-?.cpp includes both HHMMLangModel-?.h and TextObsModel.h, so when it is
/// compiled the compiler can find the classes from both files.
/// This is very different from Java, where each file contains only one main class.
/// I had confused Java files with C++ files and therefore failed to understand why G could be
/// used directly here. In effect, it is like referring to a Java class in the same package.
/// Dingcheng's study notes, Sept. 16, 09.


//// Preterminal (POS) given constituent category models...
//typedef HidVarCPT2DModel<Pos,C,LogProb> PgivCModel;

//// Pronoun surface forms used to build pronList.
//// Alternative, longer list (disabled):
//string pronArray[] = {"she", "her", "hers","he","him","his","they","their","theirs","them","it","its","this","that","these","those","whose","whom"};
std::string pronArray[] = {"she", "her", "hers","he","him","his","they","their","theirs","them","it","its"};
//// Number of entries in pronArray.  Derived from the initialiser list so that editing the
//// list can never leave the count stale (a too-large count would add empty-string entries).
const int MAXELEMENT = static_cast<int>(sizeof(pronArray)/sizeof(pronArray[0]));
List<X> pronList;

void fillPronList(){
	pronList = List<X>();
	for(int ii=0;ii<MAXELEMENT;ii++){
		X w(pronArray[ii].c_str());
		pronList.add()= w;
	}
}


//// Generative model of word given tag...
//// Holds four tables which the operator>> overloads below fill from labelled input lines
//// ("X ", "Pw ", "PwDT ", "P ").
class WModel {
private:
	TrainableDTree2DModel<Pos,X,LogProb> modPgivWdt;   // "PwDT " lines; consulted when modW does not contain the word
	RandAccCPT2DModel<Pos,X,LogProb> modPgivWs;        // "Pw " lines; consulted when modW contains the word
	RandAccCPT1DModel<Pos,LogProb> modP;               // "P " lines; divisor in getProb()
	RandAccCPT1DModel<X,LogProb> modW;                 // "X " lines; decides which of the two tables above is used

public:
	//LogProb getProb ( const X& w, const HidVarCPT1DModel<P,LogProb>::IterVal& p ) const {
	//// Returns (table value for (p,w)) * LogProb(-1000) / modP.getProb(p), where the table is
	//// modPgivWs if modW contains w and modPgivWdt otherwise.
	//// Asserts that modP.getProb(p) differs from a default-constructed LogProb.
	//// NOTE(review): LogProb(-1000) presumably stands in for a constant word prior -- confirm.
	LogProb getProb ( const X& w, const Pos::ArrayIterator<LogProb>& p ) const {
		assert(modP.getProb(p)!=LogProb());
		if ( modW.contains(w) ) {
			return modPgivWs.getProb(p,w) * LogProb(-1000) / modP.getProb(p);
		}
		return modPgivWdt.getProb(p,w) * LogProb(-1000) / modP.getProb(p);
	}

	//// Writes only the decision-tree table (modPgivWdt) to pf, passing sPref through.
	void writeFields ( FILE* pf, string sPref ) { modPgivWdt.writeFields(pf,sPref); }

	//// First half of the "si >> model >> delimiter" idiom: just pairs the input with the model.
	friend pair<StringInput,WModel*> operator>> ( StringInput si, WModel& m ) {
		return pair<StringInput,WModel*>(si,&m);
	}

	//// Second half: tries each labelled reader in turn ("X ", "Pw ", "PwDT ", "P ") and returns
	//// the input position produced by the first one that does not compare equal to NULL;
	//// returns StringInput(NULL) when none of them does.
	friend StringInput operator>> ( pair<StringInput,WModel*> delimbuff, const char* psD ) {
		StringInput si;
		if ( (si=delimbuff.first>>"X "   >>delimbuff.second->modW      >>psD)!=NULL ) return si;
		if ( (si=delimbuff.first>>"Pw "  >>delimbuff.second->modPgivWs >>psD)!=NULL ) return si;
		if ( (si=delimbuff.first>>"PwDT ">>delimbuff.second->modPgivWdt>>psD)!=NULL ) return si;
		if ( (si=delimbuff.first>>"P "   >>delimbuff.second->modP      >>psD)!=NULL ) return si;
		return StringInput(NULL);
	}
};

//// Entity index: a discrete random variable over short values whose domain object is domainIndex.
DiscreteDomain<short> domainIndex;
typedef DiscreteDomainRV<short,domainIndex> EntIndex;


//////////////////////////////////////// "Wrapper" models for individual RVs...
//// Two-dimensional tables with LogProb values, used as members of OModel.
//// The first four take (attribute, X) as template arguments.
typedef RandAccCPT2DModel<Gen,X,LogProb> GenGivWModel;
typedef RandAccCPT2DModel<Num,X,LogProb> NumGivWModel;
typedef RandAccCPT2DModel<Syn,X,LogProb> SynGivWModel;
typedef RandAccCPT2DModel<Pos,X,LogProb> PosGivWModel;
//// NOTE(review): the template arguments here are (X, EntIndex), i.e. the opposite order from the
//// four typedefs above, even though the name reads "Index given W" -- confirm which is intended.
typedef RandAccCPT2DModel<X, EntIndex, LogProb> IndexGivWModel;


//// Observation model: scores an observed word (together with its entity index) against a
//// hidden state S.
////
//// The tables are filled by the operator>> overloads at the bottom of the class, which dispatch
//// on the label at the start of each input line ("Gen_X ", "Num_X ", "Syn_X ", "Pos_X ",
//// "Gender ", "Number ", "Syntax ", "Word ", "Entity ", "Pos ").
class OModel {
private:
	WModel     modWgivP;                      // not loaded by operator>>; only written out by writeFields()
	GenGivWModel mGen_X;                      // "Gen_X " lines
	NumGivWModel mNum_X;                      // "Num_X " lines
	SynGivWModel mSyn_X;                      // "Syn_X " lines
	PosGivWModel mPos_X;                      // "Pos_X " lines
	IndexGivWModel mEntity_X;                 // "Entity " lines; queried as getProb(word, entityIndex)
	RandAccCPT1DModel<Gen,LogProb> mGender;   // "Gender " lines
	RandAccCPT1DModel<Num,LogProb> mNumber;   // "Number " lines
	RandAccCPT1DModel<Syn,LogProb> mSyntax;   // "Syntax " lines
	RandAccCPT1DModel<Pos,LogProb> mPos;      // "Pos " lines
	RandAccCPT1DModel<X,LogProb> mW;          // "Word " lines

	//// Builds an RV from the part of tagged.getString() that follows the last '_'.
	//// When the string has no '_', find_last_of yields npos and npos+1 wraps to 0, so the
	//// whole string is used (assuming getString() yields a std::string -- the original
	//// inline expression relied on the same behaviour).
	template<class RV, class Tagged>
	static RV afterLastUnderscore ( const Tagged& tagged ) {
		return RV(tagged.getString().substr(tagged.getString().find_last_of("_")+1).c_str());
	}

public:
	//// The observed random variable: a word plus the entity index it is being tested against.
	class DistribModeledWgivG {
	private:
		X wW;
		EntIndex indexW;
		LogProb pr;          // never read or written by this class; kept so the layout is unchanged
	public:

		//// Stores the word and entity index.  The model argument is accepted but not used.
		DistribModeledWgivG& set ( const X& w, const EntIndex & index, const OModel& m ) {  wW=w; indexW=index; return *this; }
		//Prob&   setProb ( const Pos & pos )       { return hcpCache.set(pos); }
		//LogProb getProb ( const Pos & pos ) const { return LogProb(hcpCache.get(pos)); }
		//// Forwards to m.calcProb(*this, s).
		LogProb getProb(const OModel& m, const S & s) const {return m.calcProb(*this, s);}
		X       getW    ( )            const { return wW; }
		EntIndex       getIndex    ( )            const { return indexW; }
	};

	typedef DistribModeledWgivG RandVarType;

	//// Computes the score of observation o under hidden state s.
	////
	//// Cases, in order:
	////  1. word "eos": product of the four *_X tables for the "-" attribute values and mW,
	////     divided by the four priors for "-".
	////  2. Otherwise the value e = mEntity_X.getProb(w,indexW) is combined with the operator
	////     s.first:
	////       (e==0 and op is COPY) or (e==LogProb() and op is not COPY)  -> LogProb()
	////       (e==0 and op is not COPY) or (e==LogProb() and op is COPY)  -> scored below
	////       anything else                                               -> LogProb()
	////     NOTE(review): "e==0" presumably means log-probability 0 (probability one) and
	////     "e==LogProb()" probability zero -- confirm against LogProb's constructors.
	////  3. Scored words that mW contains: a word in pronList with op other than OLD gets
	////     LogProb(); with op OLD or NEW all four attribute tables are used; with any other
	////     op only the Pos table is used.
	////  4. Scored words that mW does not contain are scored as the word "unkword".
	////
	//// Gender/number/syntax values are taken from s.second.first.{second,third,fourth}.first
	//// (text after the last '_'); the part of speech is s.second.second.
	LogProb calcProb ( const OModel::RandVarType& o, const S& s) const{
		// NOTE(review): this rebuilds the global pronList on every call.  Kept as is because
		// other code may rely on the list being populated as a side effect of scoring.
		fillPronList();
		X w = o.getW();
		EntIndex indexW = o.getIndex();
		LogProb pr;   // default-constructed value is what the "no case applies" path returns

		if(w==X("eos")){
			pr=mGen_X.getProb(Gen("-"),w);
			pr*=mNum_X.getProb(Num("-"),w);
			pr*=mSyn_X.getProb(Syn("-"),w);
			pr*=mPos_X.getProb(Pos("-"),w);
			pr*=mW.getProb(w);
			pr/=mGender.getProb(Gen("-"));
			pr/=mNumber.getProb(Num("-"));
			pr/=mSyntax.getProb(Syn("-"));
			pr/=mPos.getProb(Pos("-"));
			return pr;
		}

		// Look the entity table up once instead of once per comparison.
		LogProb prEntity = mEntity_X.getProb(w,indexW);
		const bool entityEqZero    = (prEntity==0);
		const bool entityEqDefault = (prEntity==LogProb());
		const bool isCopy          = (s.first==OP(COPY));

		if((entityEqZero && isCopy) || (entityEqDefault && !isCopy)){
			pr = LogProb();
			return pr;
		}
		if(!((entityEqZero && !isCopy) || (entityEqDefault && isCopy))){
			return pr;   // entity value is neither 0 nor LogProb(): nothing to score
		}

		const bool isOldOrNew = (s.first==OP(OLD) || s.first == OP(NEW));

		if(mW.contains(w)){
			if(pronList.contains(w) && s.first != OP(OLD)){
				pr = LogProb();
			}else if(isOldOrNew){
				Gen gen = afterLastUnderscore<Gen>(s.second.first.second.first);
				Num num = afterLastUnderscore<Num>(s.second.first.third.first);
				Syn syn = afterLastUnderscore<Syn>(s.second.first.fourth.first);
				pr=mGen_X.getProb(gen,w);
				pr*=mNum_X.getProb(num,w);
				pr*=mSyn_X.getProb(syn,w);
				pr*=mPos_X.getProb(s.second.second,w);
				pr*=mW.getProb(w);
				pr/=mGender.getProb(gen);
				pr/=mNumber.getProb(num);
				pr/=mSyntax.getProb(syn);
				// The prior must be looked up with the same Pos value as the Pos_X factor above
				// (and as in every other branch); previously this one re-parsed the tag text.
				pr/=mPos.getProb(s.second.second);
			}else{
				pr=mPos_X.getProb(s.second.second,w);
				pr*=mW.getProb(w);
				pr/=mPos.getProb(s.second.second);
			}
		}
		else{
			// Word not in mW: score it as the placeholder word "unkword".
			X wUnknown = X("unkword");
			pr=mPos_X.getProb(s.second.second,wUnknown);
			if(isOldOrNew){
				// Here the attribute priors are multiplied in (there is no per-word table entry).
				pr*=mGender.getProb(afterLastUnderscore<Gen>(s.second.first.second.first));
				pr*=mNumber.getProb(afterLastUnderscore<Num>(s.second.first.third.first));
				pr*=mSyntax.getProb(afterLastUnderscore<Syn>(s.second.first.fourth.first));
			}
			pr*=mW.getProb(wUnknown);
			pr/=mPos.getProb(s.second.second);
		}
		return pr;
	}

	//// Same result as calcProb(o, s); routed through the observation object.
	LogProb getProb ( const OModel::RandVarType& o, const S & s) const {  return o.getProb(*this, s); }

	//// First half of the "si >> model >> delimiter" idiom: just pairs the input with the model.
	friend pair<StringInput,OModel*> operator>> ( StringInput si, OModel& m ) { return pair<StringInput,OModel*>(si,&m); }

	//// Second half: tries each labelled reader in the order below and returns the input position
	//// produced by the first one that does not compare equal to NULL; returns StringInput(NULL)
	//// when none of them does.
	friend StringInput operator>> ( pair<StringInput,OModel*> delimbuff, const char* psD ) {
		StringInput si;
		if ( (si=delimbuff.first>>"Gen_X "  >>delimbuff.second->mGen_X   >>psD)!=NULL ) return si;
		if ( (si=delimbuff.first>>"Num_X "  >>delimbuff.second->mNum_X   >>psD)!=NULL ) return si;
		if ( (si=delimbuff.first>>"Syn_X "  >>delimbuff.second->mSyn_X   >>psD)!=NULL ) return si;
		if ( (si=delimbuff.first>>"Pos_X "  >>delimbuff.second->mPos_X   >>psD)!=NULL ) return si;
		if ( (si=delimbuff.first>>"Gender " >>delimbuff.second->mGender  >>psD)!=NULL ) return si;
		if ( (si=delimbuff.first>>"Number " >>delimbuff.second->mNumber  >>psD)!=NULL ) return si;
		if ( (si=delimbuff.first>>"Syntax " >>delimbuff.second->mSyntax  >>psD)!=NULL ) return si;
		if ( (si=delimbuff.first>>"Word "   >>delimbuff.second->mW       >>psD)!=NULL ) return si;
		if ( (si=delimbuff.first>>"Entity " >>delimbuff.second->mEntity_X>>psD)!=NULL ) return si;
		if ( (si=delimbuff.first>>"Pos "    >>delimbuff.second->mPos     >>psD)!=NULL ) return si;
		return StringInput(NULL);
	}

	//// Writes only the embedded WModel's fields to pf, passing sPref through.
	void writeFields ( FILE* pf, string sPref ) { modWgivP.writeFields(pf,sPref); }
};
