import pickle
import subprocess
import os

import numpy
from nltk import DependencyGraph
from helpers import Dictionary
from tree_utiles import dtree2ctree, ctree2last, ctree2distance


class CONLL2PKL_Creator(object):
    """Build a pickled dataset from a list of constituency-tree files.

    Constructing an instance runs the whole pipeline:

    1. the input files are concatenated in batches and converted to CoNLL
       dependency graphs by the Stanford ``UniversalDependenciesConverter``
       (``load_trees``);
    2. tokens tagged ``-NONE-`` are removed, every sentence is mapped to
       vocabulary indices and its structure is extracted (``preprocess``);
    3. the result is pickled to ``<file_path>.pkl``; the processed sentences
       are written to ``<file_path>.txt``; when ``build_dict`` is true the
       freshly built dictionary is pickled to ``dict_path``.
    """

    # Number of input files handed to the converter in one JVM invocation.
    _BATCH_SIZE = 1000

    # Converter command without the trailing ``-treeFile <path>`` arguments.
    # The ``*`` in the class path is expanded by the Java launcher itself,
    # so no shell is needed.
    _CONVERTER_CMD = [
        'java', '-mx1g',
        '-cp', '/home/hmwv1114/SciSoft/stanford-corenlp/*',
        'edu.stanford.nlp.trees.ud.UniversalDependenciesConverter',
    ]

    def __init__(self,
                 mrg,
                 file_path,
                 dict_path,
                 build_dict=False):
        """Run the conversion and write the output files.

        Args:
            mrg: iterable of paths of the tree files to convert.
            file_path: output path prefix; ``.mrg``, ``.conll``, ``.txt`` and
                ``.pkl`` files are created from it.
            dict_path: path of the pickled ``Dictionary``. It is read when
                ``build_dict`` is false and (over)written when it is true.
            build_dict: build a new ``Dictionary`` from this data instead of
                loading an existing one.
        """
        pickle_filepath = file_path + '.pkl'
        text_filepath = file_path + '.txt'
        self.build_dict = build_dict

        if build_dict:
            self.dictionary = Dictionary()
        else:
            # NOTE: unpickling executes arbitrary code; only load dictionary
            # files that were produced by a trusted run of this script.
            with open(dict_path, 'rb') as dict_file:
                self.dictionary = pickle.load(dict_file)

        print("loading trees ...")
        trees = self.load_trees(mrg, file_path)

        print("preprocessing ...")
        self.data = self.preprocess(trees, text_filepath)

        with open(pickle_filepath, "wb") as file_data:
            pickle.dump(self.data, file_data)
        if build_dict:
            with open(dict_path, "wb") as file_data:
                pickle.dump(self.dictionary, file_data)

    @staticmethod
    def _remove_empty_tokens(tree):
        """Delete every ``-NONE-`` tagged node of *tree* in place.

        The dependents of a deleted node are re-attached to the node's head
        under their original relation label.
        """
        # The address range is computed once, before any deletion, so that
        # removing nodes does not shorten the scan. Address 0 (the root node
        # of an NLTK DependencyGraph) is never considered for deletion.
        for ind in range(1, len(tree.nodes)):
            node = tree.nodes[ind]
            if node['tag'] != '-NONE-':
                continue

            head = node['head']
            head_deps = tree.nodes[head]['deps']
            head_deps[node['rel']].remove(ind)
            for rel, deps in node['deps'].items():
                for dep in deps:
                    head_deps[rel].append(dep)
                    tree.nodes[dep]['head'] = head
            tree.remove_by_address(ind)

    def preprocess(self, parse_trees, text_filepath):
        """Turn dependency graphs into index sequences and structure labels.

        Args:
            parse_trees: sized sequence of dependency graphs as returned by
                ``load_trees``. The graphs are modified in place.
            text_filepath: file that receives one processed sentence per line
                (words as stored in the dictionary, without ``</s>``).

        Returns:
            A tuple ``(sens_idx, distances)``. ``sens_idx`` holds, for each
            kept sentence, the dictionary indices of its words followed by
            ``</s>``. ``distances`` holds, for each kept sentence, the list
            returned by ``ctree2last`` with one extra trailing entry.
        """
        sens = []
        trees = []
        lasts_per_sent = []   # returned as "distances"
        depth_per_sent = []   # max(ctree2distance) per sentence, stats only

        print('\nConverting trees ...')
        for i, tree in enumerate(parse_trees):
            self._remove_empty_tokens(tree)

            if len(tree.nodes) == 1:
                continue  # only the root is left: nothing to keep

            ordered_nodes = sorted(tree.nodes.values(),
                                   key=lambda v: v['address'])
            # Drop the root entry and close the sentence with the end marker.
            word_lexs = [node['word'] for node in ordered_nodes][1:] + ['</s>']
            if self.build_dict:
                for word in word_lexs:
                    self.dictionary.add_word(word)

            ctree = dtree2ctree(tree)
            # The appended entries belong to the '</s>' token.
            distances_sent = ctree2distance(ctree, idx=0) + [1]
            assert len(distances_sent) == len(word_lexs)

            lasts = ctree2last(ctree, 0)
            lasts = lasts + [-len(lasts)]
            assert len(lasts) == len(word_lexs)

            sens.append(word_lexs)
            trees.append(tree)
            lasts_per_sent.append(lasts)
            depth_per_sent.append(max(distances_sent))

            if i % 10 == 0:
                print("Done %d/%d\r" % (i, len(parse_trees)), end='')

        if self.build_dict:
            self.dictionary.rebuild_by_freq(thd=2)

        sens_idx = []
        with open(text_filepath, "w") as text_file:
            for sent in sens:
                idx = []
                processed_words = []
                for loc, word in enumerate(sent):
                    if self.build_dict:
                        self.dictionary.add_unk(word, loc)
                    index = self.dictionary.get_idx(word, loc)
                    idx.append(index)
                    processed_words.append(self.dictionary.idx2word[index])

                # The trailing '</s>' is not written to the text file.
                text_file.write(' '.join(processed_words[:-1]) + '\n')
                sens_idx.append(idx)

        # The statistics below are undefined for an empty data set
        # (numpy.argmax and max both raise on empty input).
        if depth_per_sent:
            max_idx = numpy.argmax(depth_per_sent)
            max_depth = max(depth_per_sent)
            print('Max sentence: ', ' '.join(sens[max_idx]))
            print('Max depth: ', max_depth)
            print('Mean depth: ', numpy.mean(depth_per_sent))
            print('Median depth: ', numpy.median(depth_per_sent))
            # Histogram: number of sentences for each depth 1..max_depth.
            for depth in range(1, max_depth + 1):
                print(depth_per_sent.count(depth), end='\t')
            print()
            print(trees[max_idx].to_conll(10))
        else:
            print('No non-empty sentences; skipping depth statistics.')

        return sens_idx, lasts_per_sent

    def _convert_batch(self, tree_texts, in_file, out_file):
        """Convert one batch of tree texts to a list of dependency graphs.

        *tree_texts* are written to *in_file* (one per line group), the
        converter is run on that file with its standard output stored in
        *out_file* and its standard error in ``err.out``, and *out_file* is
        parsed into one ``DependencyGraph`` per blank-line separated chunk.
        """
        # The input file must be fully written and closed before the
        # converter process opens it.
        with open(in_file, 'w') as ctree_file:
            for text in tree_texts:
                print(text, file=ctree_file)

        with open(out_file, 'w') as conll_file, \
                open('err.out', 'w') as err_file:
            status = subprocess.call(
                self._CONVERTER_CMD + ['-treeFile', in_file],
                stdout=conll_file, stderr=err_file)
        if status != 0:
            print("warning: converter exited with status %d, see err.out"
                  % status)

        with open(out_file, 'r') as dtree_file:
            output = dtree_file.read().strip()
        return [DependencyGraph(o, top_relation_label='root')
                for o in output.split('\n\n')]

    def load_trees(self, file_ids, output_path):
        """Convert the tree files *file_ids* to dependency graphs.

        Args:
            file_ids: iterable of paths of the tree files to read.
            output_path: path prefix of the intermediate ``.mrg`` (converter
                input) and ``.conll`` (converter output) files. Both are
                overwritten for every batch.

        Returns:
            List of ``DependencyGraph`` objects in input order; empty when
            *file_ids* is empty.

        Raises:
            OSError: if an input file cannot be read or the ``java``
                executable cannot be started.
        """
        trees = []
        in_file = output_path + '.mrg'
        out_file = output_path + '.conll'

        pending = []
        for file_id in file_ids:
            with open(file_id, 'r') as f_in:
                pending.append(f_in.read().strip())

            if len(pending) >= self._BATCH_SIZE:
                trees.extend(self._convert_batch(pending, in_file, out_file))
                pending = []

        if pending:  # remainder that did not fill a whole batch
            trees.extend(self._convert_batch(pending, in_file, out_file))

        return trees

def collect_file_ids(section_dirs, limit=None):
    """Return the paths of the files found in each of *section_dirs*.

    Directory entries are sorted by name because ``os.listdir`` returns them
    in arbitrary order; sorting makes both the order and any *limit*-based
    subset reproducible across runs and file systems.

    Args:
        section_dirs: iterable of directory paths.
        limit: keep at most this many entries per directory; ``None`` keeps
            all of them.

    Returns:
        List of ``os.path.join(directory, entry)`` paths, directory by
        directory.
    """
    file_ids = []
    for section_dir in section_dirs:
        names = sorted(os.listdir(section_dir))
        file_ids.extend(os.path.join(section_dir, name)
                        for name in names[:limit])
    return file_ids


if __name__ == '__main__':
    path_1987 = 'data/LDC2000T43/1987/W7_%03d'
    path_1988 = 'data/LDC2000T43/1988/W8_%03d'
    path_1989 = 'data/LDC2000T43/1989/W9_%03d'

    # Training section selections of increasing size. Only one of them is
    # used per run (see ``train_path`` below); the others are kept as
    # ready-made alternatives.
    train_path_XS = (
        [path_1987 % num for num in [71, 122]]
        + [path_1988 % num for num in [54, 107]]
        + [path_1989 % num for num in [28, 37]]
    )

    train_path_SM = (
        [path_1987 % num for num in [35, 43, 48, 54, 61, 71, 77, 81, 96, 122]]
        + [path_1988 % num for num in [24, 54, 55, 59, 69, 73, 76, 79, 90, 107]]
        + [path_1989 % num for num in [12, 13, 15, 18, 21, 22, 28, 37, 38, 39]]
    )

    train_path_MD = (
        [path_1987 % num for num in [5, 10, 18, 21, 22, 26, 32, 35, 43, 47, 48, 49, 51, 54, 55, 56, 57, 61, 62, 65, 71, 77, 79, 81, 90, 96, 100, 105, 122, 125]]
        + [path_1988 % num for num in [12, 13, 14, 17, 23, 24, 33, 39, 40, 47, 48, 54, 55, 59, 69, 72, 73, 76, 78, 79, 83, 84, 88, 89, 90, 93, 94, 96, 102, 107]]
        + [path_1989 % num for num in range(12, 42)]
    )

    train_path_LG = (
        [path_1987 % num for num in range(3, 128)]
        + [path_1988 % num for num in range(3, 109)]
        + [path_1989 % num for num in range(12, 42)]
    )

    train_path = train_path_SM
    valid_path = ['data/LDC2000T43/1987/W7_001', 'data/LDC2000T43/1988/W8_001', 'data/LDC2000T43/1989/W9_010']
    test_path = ['data/LDC2000T43/1987/W7_002', 'data/LDC2000T43/1988/W8_002', 'data/LDC2000T43/1989/W9_011']

    train_file_ids = collect_file_ids(train_path)
    # Dev and test use only the first 500 / 1000 files of each section.
    valid_file_ids = collect_file_ids(valid_path, limit=500)
    test_file_ids = collect_file_ids(test_path, limit=1000)

    datapath = 'data/bllip-sm/bllip-sm'
    dict_path = datapath + '-dict.pkl'

    # The creator writes its outputs next to ``datapath``; make sure the
    # directory exists before the first file is opened.
    os.makedirs(os.path.dirname(datapath), exist_ok=True)

    # The training run builds the dictionary that the dev and test runs load,
    # so it has to come first.
    CONLL2PKL_Creator(train_file_ids, file_path=datapath + '-train',
                      dict_path=dict_path, build_dict=True)
    CONLL2PKL_Creator(valid_file_ids, file_path=datapath + '-dev',
                      dict_path=dict_path)
    CONLL2PKL_Creator(test_file_ids, file_path=datapath + '-test',
                      dict_path=dict_path)
