import pickle
import re
from shutil import copyfile

import numpy
from nltk import DependencyGraph
from nltk.corpus import ptb

from helpers import Dictionary
from tree_utiles import dtree2ctree, ctree2last, ctree2distance


class CONLL2PKL_Creator(object):
    '''Convert one CoNLL-U dependency treebank split into pickled training data.

    For a ``file_path`` given without extension this reads
    ``file_path + '.conllu'``, preprocesses every sentence and writes:

    * ``file_path + '.pkl'``: the pickled ``(sens_idx, distances)`` tuple
      returned by :meth:`preprocess`;
    * ``file_path + '.txt'``: one line per kept sentence, each word replaced
      by ``dictionary.idx2word[dictionary.get_idx(word, loc)]``;
    * ``dict_path`` (only when ``build_dict`` is True): the pickled
      dictionary built from this split.

    Preprocessing rules:
    tokens to delete: every token whose ``tag`` is ``-NONE-``; its dependents
        are re-attached to the deleted token's head. Sentences that are left
        with no tokens after deletion are skipped.
    '''

    def __init__(self,
                 file_path,
                 dict_path,
                 build_dict=False):
        '''
        Args:
            file_path: path prefix of the split; '.conllu', '.pkl' and '.txt'
                are appended to it.
            dict_path: path of the pickled ``Dictionary``. It is read when
                ``build_dict`` is False and (over)written when it is True.
            build_dict: build a fresh dictionary from this split instead of
                loading one from ``dict_path``.
        '''
        pickle_filepath = file_path + '.pkl'
        text_filepath = file_path + '.txt'
        self.build_dict = build_dict

        if build_dict:
            self.dictionary = Dictionary()
        else:
            # NOTE: unpickling can execute arbitrary code; dict_path must be trusted.
            with open(dict_path, 'rb') as dict_file:
                self.dictionary = pickle.load(dict_file)

        print("loading trees ...")
        trees = self.load_trees(file_path)

        print("preprocessing ...")
        self.data = self.preprocess(trees, text_filepath)

        with open(pickle_filepath, "wb") as file_data:
            pickle.dump(self.data, file_data)
        if build_dict:
            with open(dict_path, "wb") as file_data:
                pickle.dump(self.dictionary, file_data)

    def preprocess(self, parse_trees, text_filepath):
        '''Convert dependency trees into index sequences and per-word targets.

        Args:
            parse_trees: list of dependency graphs as returned by
                :meth:`load_trees`. They are modified in place because
                ``-NONE-`` tokens are removed.
            text_filepath: path of the text file that receives one processed
                sentence per line.

        Returns:
            ``(sens_idx, distances)``:

            * ``sens_idx[k]`` is the list of dictionary indices of the k-th
              kept sentence, with ``'<s>'`` first.
            * ``distances[k]`` is ``[0] + ctree2last(ctree, 0)`` for the same
              sentence. Despite the name it holds the ``ctree2last`` values;
              the ``ctree2distance`` values are only used for the depth
              statistics printed at the end.
        '''
        sens = []
        trees = []
        distances = []
        distances_list = []

        print('\nConverting trees ...')
        for i, tree in enumerate(parse_trees):
            # modify the tree before subsequent steps:
            self._remove_none_tokens(tree)

            if len(tree.nodes) == 1:
                continue  # skip the null trees (after deletion)

            # Words ordered by address; the root node (address 0) becomes '<s>'.
            word_lexs = [node['word'] for node in sorted(tree.nodes.values(), key=lambda v: v['address'])]
            word_lexs = ['<s>'] + word_lexs[1:]
            if self.build_dict:
                for word in word_lexs:
                    self.dictionary.add_word(word)

            ctree = dtree2ctree(tree)
            # The leading 0 aligns the list with word_lexs ('<s>' first).
            distances_sent = [0] + ctree2distance(ctree, idx=0)
            assert len(distances_sent) == len(word_lexs)

            lasts = [0] + ctree2last(ctree, 0)
            assert len(lasts) == len(word_lexs)

            sens.append(word_lexs)
            trees.append(tree)
            distances.append(lasts)
            distances_list.append(max(distances_sent))

            if i % 10 == 0:
                print("Done %d/%d\r" % (i, len(parse_trees)), end='')

        if self.build_dict:
            # Rebuild before indexing so the indices below come from the
            # rebuilt (frequency-thresholded) dictionary.
            self.dictionary.rebuild_by_freq(thd=2)

        sens_idx = []
        with open(text_filepath, "w", encoding="utf-8") as text_file:
            for sent in sens:
                idx = []
                processed_words = []
                for loc, word in enumerate(sent):
                    if self.build_dict:
                        self.dictionary.add_unk(word, loc)
                    index = self.dictionary.get_idx(word, loc)
                    idx.append(index)
                    processed_words.append(self.dictionary.idx2word[index])

                text_file.write(' '.join(processed_words) + '\n')
                sens_idx.append(idx)

        self._print_depth_stats(sens, trees, distances_list)

        return sens_idx, distances

    @staticmethod
    def _remove_none_tokens(tree):
        '''Delete every ``-NONE-`` token from ``tree`` in place.

        The deleted token's dependents are re-attached to its head, keeping
        their relation labels. Addresses are not renumbered, so gaps remain.
        '''
        # The range is fixed up front; removed addresses are never revisited.
        for ind in range(1, len(tree.nodes)):
            node = tree.nodes[ind]
            if node['tag'] != '-NONE-':
                continue
            head = node['head']
            tree.nodes[head]['deps'][node['rel']].remove(ind)
            for rel, deps in node['deps'].items():
                for dep in deps:
                    tree.nodes[head]['deps'][rel].append(dep)
                    tree.nodes[dep]['head'] = head
            tree.remove_by_address(ind)

    @staticmethod
    def _print_depth_stats(sens, trees, distances_list):
        '''Print statistics of the per-sentence maximum depths.

        Prints the max, mean and median depth, a histogram of depths
        1..max and the deepest tree in CoNLL format. If no sentence was kept,
        a notice is printed instead.
        '''
        if not distances_list:
            # numpy.argmax raises on an empty sequence; nothing to report.
            print('No non-empty trees; skipping depth statistics.')
            return

        max_idx = numpy.argmax(distances_list)
        max_depth = max(distances_list)
        print('Max sentence: ', ' '.join(sens[max_idx]))
        print('Max depth: ', max_depth)
        print('Mean depth: ', numpy.mean(distances_list))
        print('Median depth: ', numpy.median(distances_list))
        # Number of sentences at each depth 1..max_depth, tab-separated.
        for depth in range(1, max_depth + 1):
            print(distances_list.count(depth), end='\t')
        print()
        print(trees[max_idx].to_conll(10))

    @staticmethod
    def _split_sentences(content):
        '''Split CoNLL-U text into per-sentence blocks of token lines.

        Sentences are separated by one or more blank (or whitespace-only)
        lines. Comment lines (starting with '#') are dropped, and blocks that
        end up empty are skipped.
        '''
        sentences = []
        for block in re.split(r'\n\s*\n', content.strip()):
            token_lines = [line for line in block.split('\n')
                           if line.strip() and not line.startswith('#')]
            if token_lines:
                sentences.append('\n'.join(token_lines))
        return sentences

    def load_trees(self, file_path):
        '''Parse ``file_path + '.conllu'`` into a list of ``DependencyGraph``s.

        Each sentence is parsed with ``top_relation_label='root'``.
        '''
        conllu_path = file_path + '.conllu'
        with open(conllu_path, 'r', encoding='utf-8') as in_file:
            content = in_file.read()

        return [DependencyGraph(sentence, top_relation_label='root')
                for sentence in self._split_sentences(content)]


if __name__ == '__main__':
    # Convert UD English-PUD into pickled data using an existing dictionary.
    data_path = 'data/dependency/UD_English-PUD/en_pud-ud'
    dict_path = data_path + '-dict.pkl'

    # Start from a copy of the PTB-UD dictionary. build_dict stays False, so
    # the copy is only read, presumably to keep word indices consistent with
    # the PTB-UD data (confirm).
    copyfile('data/rnng/en_ptb-ud-dict.pkl', dict_path)

    # Only the test split is converted; add '-dev' / '-train' here to
    # process the other splits with the same dictionary.
    for split_suffix in ('-test',):
        CONLL2PKL_Creator(file_path=data_path + split_suffix,
                          dict_path=dict_path)

# class Corpus(object):
#     def __init__(self, path):
#         dict_file_name = os.path.join(path, 'dict.pkl')
#         if os.path.exists(dict_file_name):
#             self.dictionary = cPickle.load(open(dict_file_name, 'rb'))
#         else:
#             self.dictionary = Dictionary()
#             self.add_words(train_file_ids)
#             # self.add_words(valid_file_ids)
#             # self.add_words(test_file_ids)
#             self.dictionary.rebuild_by_freq()
#             cPickle.dump(self.dictionary, open(dict_file_name, 'wb'))
#
#         self.train, self.train_sens, self.train_trees = self.tokenize(train_file_ids)
#         self.valid, self.valid_sens, self.valid_trees = self.tokenize(valid_file_ids)
#         self.test, self.test_sens, self.test_trees = self.tokenize(test_file_ids)
#         self.rest, self.rest_sens, self.rest_trees = self.tokenize(rest_file_ids)
#
#     def filter_words(self, tree):
#         words = []
#         for w, tag in tree.pos():
#             if tag in word_tags:
#                 w = w.lower()
#                 w = re.sub('[0-9]+', 'N', w)
#                 # if tag == 'CD':
#                 #     w = 'N'
#                 words.append(w)
#         return words
#
#     def add_words(self, file_ids):
#         # Add words to the dictionary
#         for id in file_ids:
#             sentences = ptb.parsed_sents(id)
#             for sen_tree in sentences:
#                 words = self.filter_words(sen_tree)
#                 words = ['<s>'] + words + ['</s>']
#                 for word in words:
#                     self.dictionary.add_word(word)
#
#     def tokenize(self, file_ids):
#
#         def tree2list(tree):
#             if isinstance(tree, nltk.Tree):
#                 if tree.label() in word_tags:
#                     return tree.leaves()[0]
#                 else:
#                     root = []
#                     for child in tree:
#                         c = tree2list(child)
#                         if c != []:
#                             root.append(c)
#                     if len(root) > 1:
#                         return root
#                     elif len(root) == 1:
#                         return root[0]
#             return []
#
#         sens_idx = []
#         sens = []
#         trees = []
#         for id in file_ids:
#             sentences = ptb.parsed_sents(id)
#             for sen_tree in sentences:
#                 words = self.filter_words(sen_tree)
#                 words = ['<s>'] + words + ['</s>']
#                 # if len(words) > 50:
#                 #     continue
#                 sens.append(words)
#                 idx = []
#                 for word in words:
#                     idx.append(self.dictionary[word])
#                 sens_idx.append(torch.LongTensor(idx))
#                 trees.append(tree2list(sen_tree))
#
#         return sens_idx, sens, trees
