import torch
import os
from torch import nn
from torch.nn import init
from best.Dict import Dict
from tqdm import tqdm
import flair
from flair.data import Sentence 
from flair.embeddings import ELMoEmbeddings
import pickle


class PretrainedEmbeddings(nn.Module):
    """Produce per-document ELMo token embeddings, cached on disk.

    For every document passed to :meth:`forward`, the module first looks for
    a pickled tensor in ``path_to_embeddings``. When no cache file exists, the
    document's sentences are embedded with flair's ``ELMoEmbeddings`` and the
    resulting tensor is pickled for subsequent runs.

    NOTE: a static GloVe lookup that used to be concatenated with the ELMo
    vectors is currently disabled, so ``glove_file`` is unused. The parameter
    is retained so existing callers keep working.
    """

    # Embedding size whose cache files carry no size suffix
    # ("<doc_id>.pretrained"); every other size is written into the name.
    # NOTE(review): presumably 3072 ELMo dims + 300 GloVe dims -- confirm.
    _UNSUFFIXED_DIM = 3372

    def __init__(self, path_to_embeddings, glove_file, embedding_dim, update_method):
        """Create the ELMo embedder and remember the cache location.

        Args:
            path_to_embeddings: Directory holding the cache files. Must
                support ``/`` with a string operand (e.g. ``pathlib.Path``).
            glove_file: Unused; kept for backward compatibility.
            embedding_dim: Only used here to pick the cache file name.
                NOTE(review): the freshly computed tensors are ELMo-only, so
                their width may differ from this value -- confirm with callers.
            update_method: Stored on the instance; not consulted in this class.
        """
        super().__init__()
        self.elmo = ELMoEmbeddings()
        self.update_method = update_method
        self.path = path_to_embeddings
        self.embedding_dim = embedding_dim

    def _cache_path(self, doc):
        """Return the cache file path for ``doc`` given ``self.embedding_dim``."""
        if self.embedding_dim == self._UNSUFFIXED_DIM:
            return self.path / (doc.doc_id + '.pretrained')
        return self.path / (doc.doc_id + '.' + str(self.embedding_dim) + '.pretrained')

    def _embed_document(self, doc):
        """Embed every sentence of ``doc`` with ELMo and stack the token vectors.

        Each sentence is rebuilt as a space-joined string of its token words
        and handed to flair, so the rows of the result follow flair's own
        tokenisation of that string.
        """
        contextual = []
        for sent in doc.tokenized['sentences']:
            sentence = Sentence(' '.join(tok['word'] for tok in sent['tokens']))
            self.elmo.embed(sentence)
            contextual.extend(tok.embedding for tok in sentence)
        # torch.stack needs at least one tensor, so a document without any
        # tokens fails here rather than yielding an empty result.
        return torch.stack(contextual)

    def forward(self, docs):
        """Return one embedding tensor per document, in input order.

        Args:
            docs: Iterable of documents exposing ``doc_id`` and
                ``tokenized['sentences']`` (each sentence having ``'tokens'``
                with a ``'word'`` entry).

        Returns:
            A list of tensors, one per document, each with
            ``requires_grad`` set to ``False``.
        """
        embeddings = []
        for doc in tqdm(docs):
            doc_path = self._cache_path(doc)
            if os.path.exists(doc_path):
                # SECURITY: unpickling can execute arbitrary code; only point
                # this module at cache directories you trust.
                with open(doc_path, 'rb') as cache_file:
                    doc_embeddings = pickle.load(cache_file)
            else:
                doc_embeddings = self._embed_document(doc)
                # The context manager guarantees the cache is flushed and the
                # handle released before the next document is processed.
                with open(doc_path, 'wb') as cache_file:
                    pickle.dump(doc_embeddings, cache_file)
            embeddings.append(doc_embeddings)
        # The embeddings are used as fixed features; never backprop into them.
        for e in embeddings:
            e.requires_grad = False
        return embeddings
