import json
import random
import torch
import numpy as np
from tqdm import tqdm

from utils import constant, helper, vocab
from RAMS.classification_code.code.model.tree import Tree, head_to_tree, tree_to_adj

class DataLoader(object):
    """Load RAMS-style JSON examples and serve them as padded tensor batches.

    Every kept example is turned into a 9-tuple
    ``(tokens, head, trigger_positions, arg_positions, dep_path, adj, bert,
    scores, label)``. Examples whose trigger and argument start in different
    sentences are dropped. Outside evaluation mode the examples are shuffled
    once at load time.
    """

    def __init__(self, filename, batch_size, opt, vocab, evaluation=False):
        self.batch_size = batch_size
        self.opt = opt
        self.vocab = vocab
        self.eval = evaluation
        self.label2id = constant.LABEL_TO_ID
        # Optional dict mapping doc_key -> per-token vectors. Nothing populates
        # it here, so every example currently receives all-zero vectors.
        self.bert_vectors = None

        with open(filename, encoding='utf-8') as infile:
            data = json.load(infile)
        self.raw_data = data
        data = self.preprocess(data, vocab, opt)

        if not evaluation:
            random.shuffle(data)
        self.id2label = {v: k for k, v in self.label2id.items()}
        # Gold label strings, in the same (possibly shuffled) order as self.data.
        self.labels = [self.id2label[d[-1]] for d in data]
        self.num_examples = len(data)

        self.data = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
        print("{} batches created for {}".format(len(self.data), filename))

    @staticmethod
    def _sentence_ends(lengths):
        """Return the running totals of ``lengths`` (exclusive sentence end offsets)."""
        from itertools import accumulate
        return list(accumulate(lengths))

    @staticmethod
    def _sentence_index(offset, ends):
        """Return the index of the first end offset greater than ``offset``, or -1."""
        return next((i for i, end in enumerate(ends) if offset < end), -1)

    def preprocess(self, data, vocab, opt):
        """Turn raw JSON examples into model-ready tuples.

        Args:
            data: list of example dicts as loaded from the JSON file. Keys read
                here: 'boundaries', 'trigger_start', 'trigger_end', 'arg_start',
                'arg_end', 'tokens', 'headd', 'dep_path', 'label' and, only when
                ``self.bert_vectors`` is set, 'doc_key'. The dicts are not modified.
            vocab: object exposing ``word2id``, used to map tokens to ids.
            opt: option dict; only ``opt['lower']`` is read here.

        Returns:
            list of 9-tuples ``(token_ids, head, trigger_positions,
            arg_positions, dep_path, adj, bert, scores, label_id)``.

        Raises:
            ValueError: if the trigger or argument start lies at or beyond the
                last sentence end offset.
        """
        bert_vectors = getattr(self, 'bert_vectors', None)
        processed = []
        bads = 0  # examples that fell back to zero vectors
        for d in tqdm(data):
            # 'boundaries' is accumulated into exclusive sentence end offsets;
            # done on a fresh list so self.raw_data is left untouched.
            ends = self._sentence_ends(d['boundaries'])
            trig_sent = self._sentence_index(d['trigger_start'], ends)
            if trig_sent == -1:
                raise ValueError('trigger_start {} is outside sentence boundaries {}'.format(
                    d['trigger_start'], ends))
            arg_sent = self._sentence_index(d['arg_start'], ends)
            if arg_sent == -1:
                raise ValueError('arg_start {} is outside sentence boundaries {}'.format(
                    d['arg_start'], ends))
            # Only intra-sentence trigger/argument pairs are kept.
            if trig_sent != arg_sent:
                continue

            tokens = list(d['tokens'])
            if opt['lower']:
                tokens = [t.lower() for t in tokens]
            tokens = map_to_ids(tokens, vocab.word2id)
            head = [int(x) for x in d['headd']]
            # At least one token must carry head 0 (presumably the tree root).
            assert any(x == 0 for x in head)
            n_tokens = len(tokens)
            trigger_positions = get_positions(d['trigger_start'], d['trigger_end'], n_tokens)
            arg_positions = get_positions(d['arg_start'], d['arg_end'], n_tokens)
            # 0-based parent index per token; a token with head 0 points to itself.
            parents = [h - 1 if h > 0 else i for i, h in enumerate(head)]
            trig_scores = get_dist(d['trigger_start'], parents)
            arg_scores = get_dist(d['arg_start'], parents)
            scores = [max(t, a) for t, a in zip(trig_scores, arg_scores)]
            dep_path = d['dep_path']
            adj = tree_to_adj(n_tokens, head_to_tree(head, n_tokens),
                              directed=False, self_loop=False).reshape(1, n_tokens, n_tokens)[0]
            label = self.label2id[d['label']]

            bert = None if bert_vectors is None else bert_vectors.get(d.get('doc_key'))
            if bert is None:
                # One 768-wide zero vector per raw token.
                bert = [[0] * 768 for _ in d['tokens']]
                bads += 1
            processed.append((tokens, head, trigger_positions, arg_positions,
                              dep_path, adj, bert, scores, label))
        print('Bads: ', bads / len(data) if data else 0.0)
        return processed

    def gold(self):
        """Return the gold label strings in dataset order."""
        return self.labels

    def __len__(self):
        """Return the number of batches."""
        return len(self.data)

    def __getitem__(self, key):
        """Return batch ``key`` as padded tensors.

        Examples inside the batch are sorted by decreasing token length;
        ``orig_idx[j]`` is the position, in the unsorted batch, of the example
        now at row ``j``. Word dropout is applied only when not in evaluation mode.

        Returns:
            (words, masks, head, trigger_positions, arg_positions, dep_path,
             adj, bert, scores, labels, orig_idx)

        Raises:
            TypeError: if ``key`` is not an int.
            IndexError: if ``key`` is out of range.
        """
        if not isinstance(key, int):
            raise TypeError
        if key < 0 or key >= len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        batch = list(zip(*batch))
        assert len(batch) == 9

        lens = [len(x) for x in batch[0]]
        batch, orig_idx = sort_all(batch, lens)

        if not self.eval:
            words = [word_dropout(sent, self.opt['word_dropout']) for sent in batch[0]]
        else:
            words = batch[0]

        words = get_long_tensor(words, batch_size)
        # NOTE(review): treats id 0 as padding; assumes constant.PAD_ID == 0 — confirm.
        masks = torch.eq(words, 0)
        head = get_long_tensor(batch[1], batch_size)
        trigger_positions = get_long_tensor(batch[2], batch_size)
        arg_positions = get_long_tensor(batch[3], batch_size)
        dep_path = get_long_tensor(batch[4], batch_size)
        adj = get_adj_tensor(batch[5], batch_size)
        bert = get_bert_tensor(batch[6], batch_size)
        scores = get_long_tensor(batch[7], batch_size)

        labels = torch.LongTensor(batch[8])

        return (words, masks, head, trigger_positions, arg_positions, dep_path, adj, bert, scores, labels, orig_idx)

    def __iter__(self):
        """Yield every batch in order."""
        for i in range(len(self)):
            yield self[i]

def map_to_ids(tokens, vocab):
    """Translate each token through ``vocab``, using ``constant.UNK_ID`` for misses."""
    ids = []
    for tok in tokens:
        if tok in vocab:
            ids.append(vocab[tok])
        else:
            ids.append(constant.UNK_ID)
    return ids

def get_distance(pivot, source, heads, d=0, colors=None):
    """Count parent-pointer hops from ``source`` until ``pivot`` is reached.

    Starting at token ``source``, repeatedly move to ``heads[source]`` and add
    one per hop. Stops when ``pivot`` is reached or when a token already marked
    in ``colors`` is revisited (e.g. a root whose head is itself), so the walk
    always terminates.

    Args:
        pivot: target token index.
        source: starting token index.
        heads: 0-based parent index for every token.
        d: starting count, added to the number of hops.
        colors: per-token visited flags (1 = visited), marked in place. A fresh
            all-zero list of ``len(heads)`` is used when omitted (the old
            mutable ``[]`` default raised IndexError or leaked state).

    Returns:
        ``d`` plus the number of hops taken.
    """
    if colors is None:
        colors = [0] * len(heads)
    # Iterative walk: the recursive form hit Python's recursion limit on long inputs.
    while source != pivot and colors[source] != 1:
        colors[source] = 1
        source = heads[source]
        d += 1
    return d

def get_dist(pivot, heads):
    """Return, per token, ``len(heads)`` minus its ``get_distance`` hop count to ``pivot``.

    Each token gets its own fresh all-zero visited list.
    """
    n = len(heads)
    return [n - get_distance(pivot, tok, heads, d=0, colors=[0] * n) for tok in range(n)]

def get_positions(start_idx, end_idx, length):
    """Return relative offsets of every token to the span [start_idx, end_idx].

    Tokens before the span get negative offsets, tokens inside get 0 and
    tokens after get 1, 2, ...
    """
    before = list(range(-start_idx, 0))
    inside = [0] * (end_idx - start_idx + 1)
    after = list(range(1, length - end_idx))
    return before + inside + after

def get_bert_tensor(tokens_list, batch_size, hidden_size=768):
    """Pad per-token vectors into a float tensor of shape (batch_size, max_len, hidden_size).

    Args:
        tokens_list: sequence of per-example vector lists; each holds one
            ``hidden_size``-long vector per token.
        batch_size: number of rows to allocate.
        hidden_size: vector width; defaults to 768, the previously hard-coded value.

    Returns:
        torch.FloatTensor with unused positions filled with ``constant.PAD_ID``.
    """
    token_len = max(len(x) for x in tokens_list)
    tokens = torch.FloatTensor(batch_size, token_len, hidden_size).fill_(constant.PAD_ID)
    for i, s in enumerate(tokens_list):
        tokens[i, :len(s), :] = torch.FloatTensor(s)
    return tokens

def get_long_tensor(tokens_list, batch_size):
    """Right-pad integer sequences into a (batch_size, max_len) int64 CPU tensor.

    Positions past each sequence's length hold ``constant.PAD_ID``.
    """
    max_len = max(len(seq) for seq in tokens_list)
    padded = torch.full((batch_size, max_len), constant.PAD_ID, dtype=torch.long, device='cpu')
    for row, seq in enumerate(tokens_list):
        padded[row, :len(seq)] = torch.LongTensor(seq)
    return padded

def get_adj_tensor(tokens_list, batch_size):
    """Pad square per-example matrices into a (batch_size, max_len, max_len) float32 CPU tensor.

    Each matrix is written into the top-left corner of its slot; the rest is
    filled with ``constant.PAD_ID``.
    """
    max_len = max(len(mat) for mat in tokens_list)
    padded = torch.full((batch_size, max_len, max_len), constant.PAD_ID,
                        dtype=torch.float32, device='cpu')
    for row, mat in enumerate(tokens_list):
        size = len(mat)
        padded[row, :size, :size] = torch.FloatTensor(mat)
    return padded

def sort_all(batch, lens):
    """Reorder every field of ``batch`` by decreasing example length.

    Args:
        batch: sequence of parallel fields, each indexable by example position.
        lens: length of each example, used as the sort key.

    Returns:
        ``(sorted_fields, orig_idx)``: ``sorted_fields`` is a list of lists
        reordered consistently, and ``orig_idx[j]`` is the original position of
        the example now at position ``j``. Equal lengths are ordered by larger
        original index first. An empty batch yields empty fields instead of
        raising IndexError.
    """
    orig_idx = sorted(range(len(lens)), key=lambda i: (lens[i], i), reverse=True)
    return [[field[i] for i in orig_idx] for field in batch], orig_idx

def word_dropout(tokens, dropout):
    """Replace non-UNK token ids with ``constant.UNK_ID`` with probability ``dropout``.

    One ``np.random.random()`` draw is made per id that is not already UNK.
    """
    result = []
    for tok in tokens:
        if tok != constant.UNK_ID and np.random.random() < dropout:
            tok = constant.UNK_ID
        result.append(tok)
    return result

