# coding: utf-8
import os
import json
import pickle
import torch

import numpy as np
import constant as C

from torch import nn
from allennlp.modules.elmo import Elmo, batch_to_ids
from collections import defaultdict

# Device used for the tensors and model built in this file: CUDA when
# torch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class HFet(nn.Module):
    """ELMo-based encoder producing one vector per entity mention.

    The ELMo token representations are pooled twice with additive attention:
    once over the mention tokens and once over the context tokens (the latter
    conditioned on the mention vector and on a per-token distance feature).
    The two pooled vectors are concatenated and passed through dropout.
    """

    def __init__(self,
                 label_size,
                 elmo_option,
                 elmo_weight,
                 elmo_dropout=.5,
                 repr_dropout=.2,
                 dist_dropout=.5,
                 latent_size=0,
                 ):
        """Build the ELMo embedder and the attention projections.

        Args:
            label_size: number of labels; only stored on the instance here.
            elmo_option: ELMo options file handed to ``Elmo``.
            elmo_weight: ELMo weight file handed to ``Elmo``.
            elmo_dropout: dropout rate handed to ``Elmo``.
            repr_dropout: dropout rate applied to the final representation.
            dist_dropout: dropout rate applied to the distance feature.
            latent_size: accepted but not used by this class.
        """
        super().__init__()
        self.label_size = label_size

        # A single ELMo output representation is requested.
        self.elmo = Elmo(elmo_option, elmo_weight, 1,
                         dropout=elmo_dropout)
        width = self.elmo.get_output_dim()
        self.elmo_dim = width

        # Attention scores are scalars; the hidden attention layer keeps
        # the ELMo width.
        self.attn_dim = 1
        self.attn_inner_dim = width

        def projection(n_in, n_out):
            # Every attention projection is a bias-free linear map.
            return nn.Linear(n_in, n_out, bias=False)

        # Mention attention
        self.men_attn_linear_m = projection(width, self.attn_inner_dim)
        self.men_attn_linear_o = projection(self.attn_inner_dim, self.attn_dim)
        # Context attention (token, mention and distance terms, then output)
        self.ctx_attn_linear_c = projection(width, self.attn_inner_dim)
        self.ctx_attn_linear_m = projection(width, self.attn_inner_dim)
        self.ctx_attn_linear_d = projection(1, self.attn_inner_dim)
        self.ctx_attn_linear_o = projection(self.attn_inner_dim, self.attn_dim)

        # Dropout on the final representation
        self.repr_dropout = nn.Dropout(p=repr_dropout)
        # Dropout on the relative position (distance) feature
        self.dist_dropout = nn.Dropout(p=dist_dropout)

    @staticmethod
    def _masked_softmax(scores, mask):
        """Softmax over dim 1 after pushing masked-out positions to -10000."""
        blocked = (1.0 - mask.unsqueeze(-1)) * -10000.0
        return (scores + blocked).softmax(1)

    def forward_nn(self, inputs, men_mask, ctx_mask, dist, gathers):
        """Encode every mention of the batch.

        Args:
            inputs: ELMo character ids, one row per sentence.
            men_mask: per-mention mask, 1 on mention tokens.
            ctx_mask: per-mention mask, 1 on context tokens.
            dist: per-mention, per-token distance feature.
            gathers: for each mention, the row of ``inputs`` it belongs to.

        Returns:
            Tensor with one row per mention holding the mention vector
            followed by the context vector (``2 * elmo_dim`` columns).
        """
        # Contextualized token embeddings, one row per sentence ...
        sent_repr = self.elmo(inputs)['elmo_representations'][0]
        # ... re-indexed so that there is one row per mention.
        token_repr = sent_repr.index_select(0, gathers)

        # Attention pooling over the mention tokens.
        men_hidden = self.men_attn_linear_m(token_repr).tanh()
        men_weights = self._masked_softmax(
            self.men_attn_linear_o(men_hidden), men_mask)
        men_repr = (token_repr * men_weights).sum(1)

        # Attention pooling over the context tokens.
        dist = self.dist_dropout(dist)
        token_term = self.ctx_attn_linear_c(token_repr)
        mention_term = self.ctx_attn_linear_m(men_repr.unsqueeze(1))
        distance_term = self.ctx_attn_linear_d(dist.unsqueeze(2))
        ctx_hidden = (token_term + mention_term + distance_term).tanh()
        ctx_weights = self._masked_softmax(
            self.ctx_attn_linear_o(ctx_hidden), ctx_mask)
        ctx_repr = (token_repr * ctx_weights).sum(1)

        # Final representation: [mention ; context] followed by dropout.
        return self.repr_dropout(torch.cat([men_repr, ctx_repr], dim=1))

    def forward(self, inputs, labels, men_mask, ctx_mask, dist, gathers, inst_weights=None):
        """Return the mention representations computed by ``forward_nn``.

        ``labels`` and ``inst_weights`` are accepted but not used.
        """
        return self.forward_nn(inputs, men_mask, ctx_mask, dist, gathers)


class HFetSentenceDataset(torch.utils.data.Dataset):
    """Entity-typing dataset read from a JSON-lines file.

    Each input line is a JSON object with the keys ``left_context_token``,
    ``right_context_token``, ``mention_span`` and ``y_str``.  Lines are
    converted once, at construction time, into ``{'tokens', 'annotations'}``
    records (see ``processLine``).  Indexing numberizes the selected records
    and returns them as one padded batch of tensors placed on the
    module-level ``device``.
    """

    # Keys every input record must provide.
    _REQUIRED_KEYS = ('left_context_token', 'right_context_token',
                      'mention_span', 'y_str')

    def __init__(self, fn, lbl2id_fn='data/ontology/onto_ontology.txt', tmp_fn='hfetsents.pkl'):
        """Load the label inventory and preprocess every line of ``fn``.

        Args:
            fn: path of the JSON-lines data file.
            lbl2id_fn: path of the label file, one label per line; the line
                position becomes the label id.
            tmp_fn: not used; kept so existing callers keep working.

        Raises:
            AssertionError: if a line cannot be turned into a record (it is
                not a JSON object or lacks a required key).
        """
        super(HFetSentenceDataset, self).__init__()

        # Explicit encoding so decoding does not depend on the platform locale.
        with open(lbl2id_fn, encoding='utf-8') as f:
            self.lbl2id = {label: i for i, label in enumerate(f.read().strip().split('\n'))}

        with open(fn, encoding='utf-8') as f:
            lines = f.read().strip().split('\n')

        self.data = []
        for i, line in enumerate(lines):
            print(i, end='\r')  # in-place progress counter
            # NOTE: no mention_id is passed, so every record gets id '0'.
            newdata = HFetSentenceDataset.processLine(line)
            if not newdata:
                raise AssertionError(
                    'line %d of %s is not a JSON object with the keys %s'
                    % (i + 1, fn, ', '.join(HFetSentenceDataset._REQUIRED_KEYS)))
            self.data.append(newdata)
        print('Finished sentence preprocessing. ')

    def __len__(self):
        """Return the number of preprocessed records."""
        return len(self.data)

    @staticmethod
    def _mask_to_distance(mask, mask_len, decay=.1):
        """Turn a 0/1 mention mask into a per-token proximity feature.

        Mention positions get 0.  The context token adjacent to the mention
        on either side gets 1, and the value drops by ``decay`` per token
        further away, never going below 0.

        Raises:
            ValueError: if ``mask`` contains no 1.
        """
        start = mask.index(1)
        end = mask_len - list(reversed(mask)).index(1)
        dist = [0] * mask_len
        for i in range(start):
            dist[i] = max(0, 1 - (start - i - 1) * decay)
        for i in range(end, mask_len):
            dist[i] = max(0, 1 - (i - end) * decay)
        return dist

    def __getitem__(self, idx):
        """Numberize the selected records and collate them into one batch.

        Args:
            idx: an int, a slice, an iterable of ints, or an integer tensor
                (any shape; it is flattened).

        Returns:
            Tuple ``(char_ids, labels, men_mask, ctx_mask, dist, gathers)``:
            ``char_ids`` has one row per sentence, padded to the longest
            sentence of the batch; ``labels`` is a multi-hot float matrix
            with one row per mention and one column per known label;
            ``men_mask``, ``ctx_mask`` and ``dist`` have one padded row per
            mention; ``gathers`` maps each mention to its sentence row.
            Mention ids are computed by ``numberize`` but not returned.

        Raises:
            ValueError: if the selection is empty.
        """
        if torch.is_tensor(idx):
            # Flatten first so that 0-dim tensors are accepted as well.
            idx = idx.reshape(-1).tolist()

        if isinstance(idx, int):
            records = [self.data[idx]]
        elif isinstance(idx, slice):
            records = self.data[idx]
        else:
            records = [self.data[i] for i in idx]

        if not records:
            raise ValueError('cannot build a batch from an empty selection')

        insts = [HFetSentenceDataset.numberize(record, self.lbl2id)
                 for record in records]
        max_seq_len = max(inst[-1] for inst in insts)
        n_labels = len(self.lbl2id)

        batch_char_ids = []
        batch_labels = []
        batch_men_mask = []
        batch_dist = []
        batch_ctx_mask = []
        batch_gathers = []
        for inst_idx, inst in enumerate(insts):
            char_ids, labels, men_mask, ctx_mask, _men_ids, anno_num, seq_len = inst
            pad_len = max_seq_len - seq_len
            mask_pad = [C.PAD_INDEX] * pad_len

            batch_char_ids.append(char_ids + [[C.PAD_INDEX] * C.ELMO_MAX_CHAR_LEN
                                              for _ in range(pad_len)])
            for label_ids in labels:
                active = set(label_ids)
                batch_labels.append([1 if label_id in active else 0
                                     for label_id in range(n_labels)])
            for mask in men_mask:
                batch_men_mask.append(mask + mask_pad)
                batch_dist.append(
                    HFetSentenceDataset._mask_to_distance(mask, seq_len) + mask_pad)
            for mask in ctx_mask:
                batch_ctx_mask.append(mask + mask_pad)
            # Every mention of this sentence points back at row inst_idx.
            batch_gathers.extend([inst_idx] * anno_num)

        batch_char_ids = torch.tensor(batch_char_ids, dtype=torch.long, device=device)
        batch_labels = torch.tensor(batch_labels, dtype=torch.float, device=device)
        batch_men_mask = torch.tensor(batch_men_mask, dtype=torch.float, device=device)
        batch_ctx_mask = torch.tensor(batch_ctx_mask, dtype=torch.float, device=device)
        batch_gathers = torch.tensor(batch_gathers, dtype=torch.long, device=device)
        batch_dist = torch.tensor(batch_dist, dtype=torch.float, device=device)

        return (batch_char_ids, batch_labels, batch_men_mask, batch_ctx_mask,
                batch_dist, batch_gathers)

    @staticmethod
    def processLine(line, mention_id=0):
        """Convert one JSON line into a ``{'tokens', 'annotations'}`` record.

        Args:
            line: JSON text of one record.
            mention_id: id stored (as a string) on the single annotation.

        Returns:
            A dict whose ``tokens`` are left context + mention + right
            context and whose ``annotations`` hold one entry with the
            mention text, its token span ``[start, end)`` and its labels;
            ``None`` when the line is not a JSON object or lacks a required
            key.

        Raises:
            json.JSONDecodeError: if ``line`` is not valid JSON.
        """
        data_dict = json.loads(line)
        if not isinstance(data_dict, dict):
            return None
        if any(key not in data_dict for key in HFetSentenceDataset._REQUIRED_KEYS):
            return None

        left_context_token = data_dict['left_context_token']
        right_context_token = data_dict['right_context_token']
        mention_span = data_dict['mention_span']
        mention_tokens = mention_span.split(' ')
        start = len(left_context_token)

        return {
            'tokens': left_context_token + mention_tokens + right_context_token,
            'annotations': [{
                'mention_id': str(mention_id),
                'mention': mention_span,
                'start': start,
                'end': start + len(mention_tokens),
                'labels': data_dict['y_str'],
            }],
        }

    @staticmethod
    def numberize(inst, label_stoi):
        """Map one record to ELMo character ids, label ids and masks.

        Args:
            inst: record as produced by ``processLine``.
            label_stoi: mapping from label string to label id; labels that
                are missing from it are dropped.

        Returns:
            Tuple ``(char_ids, labels, men_mask, ctx_mask, men_ids,
            anno_num, seq_len)``; the label, mask and id lists contain one
            entry per annotation.
        """
        tokens = [C.TOK_REPLACEMENT.get(t, t) for t in inst['tokens']]
        seq_len = len(tokens)
        char_ids = batch_to_ids([tokens])[0].tolist()

        labels_nbz, men_mask, ctx_mask, men_ids = [], [], [], []
        annotations = inst['annotations']
        for annotation in annotations:
            start = annotation['start']
            end = annotation['end']
            # Repair the misspelled 'geograpy' before the label lookup.
            labels = [label.replace('geograpy', 'geography')
                      for label in annotation['labels']]

            men_ids.append(annotation['mention_id'])
            labels_nbz.append([label_stoi[label] for label in labels
                               if label in label_stoi])
            men_mask.append([1 if start <= i < end else 0
                             for i in range(seq_len)])
            ctx_mask.append([1 if i < start or i >= end else 0
                             for i in range(seq_len)])
        return (char_ids, labels_nbz, men_mask, ctx_mask, men_ids,
                len(annotations), seq_len)


def sentloader(sentset, batchsize, shuffle):
    """Yield successive batches drawn from ``sentset``.

    Args:
        sentset: object supporting ``len()`` and indexing with a 1-D long
            tensor of positions.
        batchsize: maximum number of positions per batch (at least 1).
        shuffle: when true, positions are shuffled once with ``random``
            before batching.

    Yields:
        ``sentset[batch_positions]`` for each consecutive chunk of at most
        ``batchsize`` positions.  Only the last batch may be smaller, and no
        empty batch is ever produced (an empty ``sentset`` yields nothing).

    Raises:
        ValueError: if ``batchsize`` is smaller than 1.
    """
    import random

    if batchsize < 1:
        raise ValueError('batchsize must be at least 1, got %r' % (batchsize,))

    length = len(sentset)
    inds = list(range(length))
    if shuffle:
        random.shuffle(inds)
    inds = torch.tensor(inds, dtype=torch.long)

    # Stepping by batchsize stops exactly at the end of the data, so there
    # is no trailing empty batch when length is a multiple of batchsize.
    for start in range(0, length, batchsize):
        yield sentset[inds[start:start + batchsize]]

if __name__ == '__main__':
    # Smoke test: encode one shuffled training batch and print the result.

    # Set up the dataset and its batch generator.
    batchsize = 16
    train_sent_dataset = HFetSentenceDataset(
        'data/ontonotes/g_train.json',
        lbl2id_fn='data/ontology/onto_ontology.txt',
    )
    train_loader = sentloader(train_sent_dataset, batchsize, True)

    # Set up the encoder from the ELMo option and weight files.
    n_lbls = 89
    elmo_option = 'elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json'
    elmo_weight = 'elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5'
    sample_encoder = HFet(
        n_lbls, elmo_option, elmo_weight,
        elmo_dropout=.5, repr_dropout=.2, dist_dropout=.2, latent_size=0,
    ).to(device)

    # Run a single batch through the encoder.
    batch_data = next(train_loader)
    outputs = sample_encoder(*batch_data)

    print(outputs)