import random
import torch
import numpy as np
import torch.functional
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from best.attention import MultiplicativeAttention, MultilinearAttention, InteriorMultilinear, MLPAttention
from collections import OrderedDict

class Predict(nn.Module):
		def __init__(self, opinion, tagset_size, hidden_dim, pair_dim, resolve_author,
						dropout=0.05, attention="", parameterization="rank",
						attention_hyperparam=5, use_entities=True):
				"""Set up the attention modules and the per-mention-type scoring heads.

				Args:
					opinion: task name; entity pairs are only processed when it equals
						"sentiment" and `use_entities` is truthy.
					tagset_size: output width of every scoring head.
					hidden_dim: width of the hidden layer and size of the ranking null vector.
					pair_dim: input width of `hid` and of the `rank_pairs` heads
						(presumably the length of what `combine_pair` returns -- confirm with caller).
					resolve_author: callable invoked as
						`resolve_author(encoded_value=..., post_author=...)`.
					dropout: dropout probability applied to the hidden layer.
					attention: attention flavour; empty, "None" or "none" disables attention.
					parameterization: "classify" selects `classify` in `forward`,
						anything else selects `rank`.
					attention_hyperparam: forwarded to the multilinear attention modules.
					use_entities: whether entity pairs are scored for the sentiment task.
				"""
				super(Predict, self).__init__()

				def build_heads(in_features):
						# One linear scorer per mention type, created in a fixed order so
						# parameter registration order stays stable.
						return nn.Sequential(OrderedDict(
								(name, nn.Linear(in_features, tagset_size))
								for name in ("all", "relation", "event", "entity")
						))

				self.opinion = opinion
				# Normalise every "no attention" spelling to None.
				attention_disabled = not attention or attention in ("None", "none")
				self.attention = None if attention_disabled else attention
				self.parameterization = parameterization
				self.hidden_dim = hidden_dim
				self.pair_dim_ = pair_dim
				self.resolve_author = resolve_author
				self.attention_hyperparam = attention_hyperparam

				# Every attention flavour is instantiated, whichever one is selected.
				self.multiplicative_attention = MultiplicativeAttention(hidden_dim)
				self.multilinear_attention = MultilinearAttention(hidden_dim, attention_hyperparam)
				self.interior_multilinear = InteriorMultilinear(hidden_dim, attention_hyperparam)
				self.mlp_attention = MLPAttention(3 * hidden_dim)

				# Learned stand-in source vector used by `rank` for the "NO_PRED" candidate.
				null_vector = nn.Parameter(torch.Tensor(hidden_dim))
				init.uniform_(null_vector, -0.001, 0.001)
				self.ranking_null = null_vector

				self.hid = nn.Linear(pair_dim, hidden_dim)
				self.out_layers = build_heads(hidden_dim)
				self.rank_pairs = build_heads(pair_dim)

				self.dropout = dropout
				self.dropout_layer = nn.Dropout(p=dropout, inplace=False)
				self.use_entities = use_entities

		def make_H(self, src, trg, encoded_doc):
				H = torch.cat((src,trg),0)
				H = H.unsqueeze(0)
				H = H.expand(len(encoded_doc), 2 * self.hidden_dim)
				H = torch.cat((H, encoded_doc), 1)
				H = H.unsqueeze(0)
				return H

		def get_bounds(self, doc, entity_id, offset, length, author_flag, encoded_doc):
				start, _ = doc.offset_to_flat_tokens(offset, length + 1)
				diff = 18
				if not start:
					return 0, len(encoded_doc) - 1
				if not entity_id or author_flag:
						return max(0, start - diff), max(0, start)
				if entity_id in doc.evaluator_ere.entities:
								entity = doc.evaluator_ere.entities[entity_id]
								mentions = entity.mentions
								for mention in mentions:
										t_start, t_end = doc.offset_to_flat_tokens(mention.offset, mention.length + 1)
										if 0 < start - t_end:
											diff = min(diff, start - t_end)
				return max(0, start - diff), max(0, start)

		def get_attention(self, doc, src, trg, encoded_doc, src_id, trg_id,
								offset, length, ux_src_dict, ux_trg_dict, ux_doc, author_flag):
				if self.attention == "multilinear":
						return self.multilinear_attention(src, trg, encoded_doc,
						 src_id, trg_id, ux_src_dict, ux_trg_dict, ux_doc)
				elif self.attention == "interiormultilinear":
					start, end = self.get_bounds(doc, src_id, offset, length, author_flag, encoded_doc)
					return self.interior_multilinear(src, trg, encoded_doc,
																						 src_id, trg_id, start,
																						 end, ux_src_dict, ux_trg_dict, ux_doc)
				elif self.attention == "multiplicative":
					start, end = self.get_bounds(doc, src_id, offset, length, author_flag, encoded_doc)
					return self.multiplicative_attention(src, trg, encoded_doc,
																						 src_id, trg_id, start,
																						 end, ux_src_dict, ux_trg_dict, ux_doc)
				elif self.attention == "mlp":
						H = self.make_H(src, trg, encoded_doc)
						scores, attention = self.mlp_attention(H, lengths=None)
						scores = scores.squeeze(0)
						attention = attention.squeeze(0)
						return scores, attention, ux_src_dict, ux_trg_dict, ux_doc
				else:
						raise NotImplementedError

		def combine_pair(self, src, trg, attention=None):
				if self.attention:
						return torch.cat((src, trg, attention))
				else:
						return torch.cat((src, trg))

		def _predict_pairs(self, enc_pairs, mention_type, combination=None):
				if not enc_pairs:
						return None
				enc_pairs = torch.stack(enc_pairs)
				hid = torch.tanh(self.hid(enc_pairs))
				hid = self.dropout_layer(hid)
				scores = self.out_layers._modules[mention_type](hid)
				if combination:
						scores = scores + self.out_layers._modules['all'](hid)
						return scores
				else:
						return scores

		def classify(self, doc, encoded_doc, enc_entities, enc_mentions):
				ux_src_dict = {}
				ux_trg_dict = {}
				ux_doc = None
				if(self.opinion == "sentiment") and self.use_entities:
						enc_entity_pairs = []
						for (src_id, trg_m_id) in doc.pairs_entity:
							trg_mention = doc.evaluator_ere.entity_mentions[trg_m_id]
							offset = trg_mention.offset
							length = trg_mention.length
							post_author = doc.get_author(offset)
							enc_src = enc_entities[src_id]
							enc_trg = enc_mentions[trg_m_id]
							author_flag = False
							if hasattr(enc_src, 'lower') or hasattr(enc_trg, 'lower'):
									author_flag = True
							enc_src = self.resolve_author(encoded_value=enc_src, post_author=post_author)
							enc_trg = self.resolve_author(encoded_value=enc_trg, post_author=post_author)
							if self.attention:
									scores, attention, ux_src_dict, ux_trg_dict, ux_doc = self.get_attention(doc,
																																													 enc_src,
																																													 enc_trg,
																																													 encoded_doc,
																																													 src_id,
																																													 trg_m_id,
																																													 offset,
																																													 length,
																																													 ux_src_dict,
																																													 ux_trg_dict,
																																													 ux_doc,
																																													 author_flag)

									enc_entity_pairs.append(self.combine_pair(enc_src, enc_trg, attention))
							else:
									enc_entity_pairs.append(self.combine_pair(enc_src, enc_trg))
				enc_relation_pairs = []
				for (src_id, trg_m_id) in doc.pairs_relation:
						trg_mention = doc.evaluator_ere.relation_mentions[trg_m_id]
						offset, length = doc.relation_mention_to_span(trg_mention)
						post_author = doc.get_author(offset)
						enc_src = enc_entities[src_id]
						author_flag = False
						if hasattr(enc_src, 'lower'):
							author_flag = True
						enc_src = self.resolve_author(encoded_value=enc_src, post_author=post_author)
						enc_trg = enc_mentions[trg_m_id]
						if self.attention:
								scores, attention, ux_src_dict, ux_trg_dict, ux_doc = self.get_attention(doc,
																																												 enc_src,
																																												 enc_trg,
																																												 encoded_doc,
																																												 src_id,
																																												 trg_m_id,
																																												 offset,
																																												 length,
																																												 ux_src_dict,
																																												 ux_trg_dict,
																																												 ux_doc,
																																												 author_flag)

								enc_relation_pairs.append(self.combine_pair(enc_src, enc_trg, attention))
						else:
								enc_relation_pairs.append(self.combine_pair(enc_src, enc_trg))

				enc_event_pairs = []
				for (src_id, trg_m_id) in doc.pairs_event:
						trg_mention = doc.evaluator_ere.event_mentions[trg_m_id]
						offset, length = doc.event_mention_to_span(trg_mention)
						post_author = doc.get_author(offset)
						enc_src = enc_entities[src_id]
						author_flag = False
						if hasattr(enc_src, 'lower'):
							author_flag = True
						enc_src = self.resolve_author(encoded_value=enc_src, post_author=post_author)
						enc_trg = enc_mentions[trg_m_id]
						if self.attention:
								scores, attention, ux_src_dict, ux_trg_dict, ux_doc = self.get_attention(doc,
																																												 enc_src,
																																												 enc_trg,
																																												 encoded_doc,
																																												 src_id,
																																												 trg_m_id,
																																												 offset,
																																												 length,
																																												 ux_src_dict,
																																												 ux_trg_dict,
																																												 ux_doc,
																																												 author_flag)
								enc_event_pairs.append(self.combine_pair(enc_src, enc_trg, attention))
						else:
								enc_event_pairs.append(self.combine_pair(enc_src, enc_trg))
				scores_relation = self._predict_pairs(enc_relation_pairs, mention_type="relation")
				scores_event = self._predict_pairs(enc_event_pairs, mention_type="event")
				if self.opinion == "sentiment" and self.use_entities:
					scores_entity = self._predict_pairs(enc_entity_pairs, mention_type="entity")
					return scores_entity, scores_relation, scores_event
				return scores_relation, scores_event

		def ranker(self, targets, mention_type="all"):
				predictions = []
				for target_m_id in targets:
					src_ids, pairs = targets[target_m_id]
					p = self.rank_pairs._modules[mention_type](targets[target_m_id][1])
					prediction = (nn.functional.softmax(p.view(1,-1), 1)).view(p.size()[0], p.size()[1])
					predictions.append(prediction)
				output = torch.cat(predictions, 0)
				return output

		def rank(self, doc, encoded_doc, enc_entities, enc_mentions):
				ux_src_dict = {}
				ux_trg_dict = {}
				ux_doc = None
				if(self.opinion == "sentiment") and self.use_entities:
						entity_targets = {}
						enc_entity_pairs = []
						for (src_id, trg_m_id) in doc.pairs_entity:
								trg_mention = doc.evaluator_ere.entity_mentions[trg_m_id]
								offset = trg_mention.offset
								length = trg_mention.length
								post_author = doc.get_author(offset)
								enc_src = enc_entities[src_id]
								enc_trg = enc_mentions[trg_m_id]
								author_flag = False
								if hasattr(enc_src, 'lower') or hasattr(enc_trg, 'lower'):
										author_flag = True
								enc_src = self.resolve_author(encoded_value=enc_src,
																							post_author=post_author)
								enc_trg = self.resolve_author(encoded_value=enc_trg,
																							post_author=post_author)
								if self.attention:
										scores, attention, ux_src_dict, ux_trg_dict, ux_doc = self.get_attention(doc,
																																														 enc_src,
																																														 enc_trg,
																																														 encoded_doc,
																																														 src_id,
																																														 trg_m_id,
																																														 offset,
																																														 length,
																																														 ux_src_dict,
																																														 ux_trg_dict,
																																														 ux_doc,
																																														 author_flag)

								if trg_m_id not in entity_targets:
									if not self.attention:
											entity_targets[trg_m_id] = (["NO_PRED"],self.combine_pair(self.ranking_null, enc_trg))
											entity_targets[trg_m_id][0].append(src_id)
											entity_targets[trg_m_id] = (entity_targets[trg_m_id][0], torch.stack([entity_targets[trg_m_id][1],self.combine_pair(enc_src, enc_trg)],0))
									else:
											_, null_attention, _, _, _ = self.get_attention(doc,
																																			self.ranking_null,
																																			enc_trg,
																																			encoded_doc,
																																			None,
																																			trg_m_id,
																																			offset,
																																			length,
																																			ux_src_dict,
																																			ux_trg_dict,
																																			ux_doc,
																																			False)
											entity_targets[trg_m_id] = (["NO_PRED"],self.combine_pair(self.ranking_null, enc_trg, null_attention))
											entity_targets[trg_m_id][0].append(src_id)
											entity_targets[trg_m_id] = (entity_targets[trg_m_id][0], torch.stack([entity_targets[trg_m_id][1],self.combine_pair(enc_src, enc_trg, attention)],0))
								else:
									if self.attention:
											entity_targets[trg_m_id][0].append(src_id)
											entity_targets[trg_m_id] = (entity_targets[trg_m_id][0], torch.cat([entity_targets[trg_m_id][1],(self.combine_pair(enc_src, enc_trg, attention)).unsqueeze(0)],0))

									else:
											entity_targets[trg_m_id][0].append(src_id)
											entity_targets[trg_m_id] = (entity_targets[trg_m_id][0], torch.cat([entity_targets[trg_m_id][1],(self.combine_pair(enc_src, enc_trg)).unsqueeze(0)],0))
				relation_targets = {}
				enc_relation_pairs = []
				for (src_id, trg_m_id) in doc.pairs_relation:
						trg_mention = doc.evaluator_ere.relation_mentions[trg_m_id]
						offset, length = doc.relation_mention_to_span(trg_mention)
						post_author = doc.get_author(offset)
						enc_src = enc_entities[src_id]
						author_flag = False
						if hasattr(enc_src, 'lower'):
								author_flag = True
						enc_src = self.resolve_author(encoded_value=enc_src,
																					post_author=post_author)
						enc_trg = enc_mentions[trg_m_id]
						if self.attention:
								scores, attention, ux_src_dict, ux_trg_dict, ux_doc = self.get_attention(doc,
																																												 enc_src,
																																												 enc_trg,
																																												 encoded_doc,
																																												 src_id,
																																												 trg_m_id,
																																												 offset,
																																												 length,
																																												 ux_src_dict,
																																												 ux_trg_dict,
																																												 ux_doc,
																																												 author_flag)
						if trg_m_id not in relation_targets:
								if not self.attention:
										relation_targets[trg_m_id] = (["NO_PRED"],self.combine_pair(self.ranking_null, enc_trg))
										relation_targets[trg_m_id][0].append(src_id)
										relation_targets[trg_m_id] = (relation_targets[trg_m_id][0], torch.stack([relation_targets[trg_m_id][1],self.combine_pair(enc_src, enc_trg)],0))
								else:
										_, null_attention, _, _, _ = self.get_attention(doc,
																																		self.ranking_null,
																																		enc_trg,
																																		encoded_doc,
																																		None,
																																		trg_m_id,
																																		offset,
																																		length,
																																		ux_src_dict,
																																		ux_trg_dict,
																																		ux_doc,
																																		False)
										relation_targets[trg_m_id] = (["NO_PRED"],self.combine_pair(self.ranking_null, enc_trg, null_attention))
										relation_targets[trg_m_id][0].append(src_id)
										relation_targets[trg_m_id] = (relation_targets[trg_m_id][0], torch.stack([relation_targets[trg_m_id][1],self.combine_pair(enc_src, enc_trg, attention)],0))
						else:
								if self.attention:
										relation_targets[trg_m_id][0].append(src_id)
										relation_targets[trg_m_id] = (relation_targets[trg_m_id][0], torch.cat([relation_targets[trg_m_id][1],(self.combine_pair(enc_src, enc_trg, attention)).unsqueeze(0)],0))

								else:
										relation_targets[trg_m_id][0].append(src_id)
										relation_targets[trg_m_id] = (relation_targets[trg_m_id][0], torch.cat([relation_targets[trg_m_id][1],(self.combine_pair(enc_src, enc_trg)).unsqueeze(0)],0))
				

				event_targets = {}
				enc_event_pairs = []
				for (src_id, trg_m_id) in doc.pairs_event:
						trg_mention = doc.evaluator_ere.event_mentions[trg_m_id]
						offset, length = doc.event_mention_to_span(trg_mention)
						post_author = doc.get_author(offset)
						enc_src = enc_entities[src_id]
						author_flag = False
						if hasattr(enc_src, 'lower'):
								author_flag = True
						enc_src = self.resolve_author(encoded_value=enc_src,
																					post_author=post_author)
						enc_trg = enc_mentions[trg_m_id]
						if self.attention:
								scores, attention, ux_src_dict, ux_trg_dict, ux_doc = self.get_attention(doc,
																																												 enc_src,
																																												 enc_trg,
																																												 encoded_doc,
																																												 src_id,
																																												 trg_m_id,
																																												 offset,
																																												 length,
																																												 ux_src_dict,
																																												 ux_trg_dict,
																																												 ux_doc,
																																												 author_flag)
						if trg_m_id not in event_targets:
								if not self.attention:
										event_targets[trg_m_id] = (["NO_PRED"],self.combine_pair(self.ranking_null, enc_trg))
										event_targets[trg_m_id][0].append(src_id)
										event_targets[trg_m_id] = (event_targets[trg_m_id][0], torch.stack([event_targets[trg_m_id][1],self.combine_pair(enc_src, enc_trg)],0))
								else:
										_, null_attention, _, _, _ = self.get_attention(doc,
																																		self.ranking_null,
																																		enc_trg,
																																		encoded_doc,
																																		None,
																																		trg_m_id,
																																		offset,
																																		length,
																																		ux_src_dict,
																																		ux_trg_dict,
																																		ux_doc,
																																		False)
										event_targets[trg_m_id] = (["NO_PRED"],self.combine_pair(self.ranking_null, enc_trg, null_attention))
										event_targets[trg_m_id][0].append(src_id)
										event_targets[trg_m_id] = (event_targets[trg_m_id][0], torch.stack([event_targets[trg_m_id][1],self.combine_pair(enc_src, enc_trg, attention)],0))
						else:
								if self.attention:
										event_targets[trg_m_id][0].append(src_id)
										event_targets[trg_m_id] = (event_targets[trg_m_id][0], torch.cat([event_targets[trg_m_id][1],(self.combine_pair(enc_src, enc_trg, attention)).unsqueeze(0)],0))

								else:
										event_targets[trg_m_id][0].append(src_id)
										event_targets[trg_m_id] = (event_targets[trg_m_id][0], torch.cat([event_targets[trg_m_id][1],(self.combine_pair(enc_src, enc_trg)).unsqueeze(0)],0))
				scores_relation = self.ranker(relation_targets, mention_type="relation")
				scores_event = self.ranker(event_targets, mention_type="event")
				if self.opinion == "sentiment" and self.use_entities:
					scores_entity = self.ranker(entity_targets, mention_type = "entity")
					return scores_entity, scores_relation, scores_event
				return scores_relation, scores_event

		def forward(self, doc, encoded_doc, enc_entities, enc_mentions):
				if self.parameterization == "classify":
						return self.classify(doc, encoded_doc, enc_entities, enc_mentions)
				else:
						return self.rank(doc, encoded_doc, enc_entities, enc_mentions)
