import os
import pickle
import subprocess
from os import path

import numpy
from nltk import DependencyGraph
from nltk.corpus import ptb

from helpers import Dictionary
from tree_utiles import dtree2ctree, ctree2last, ctree2distance


class CONLL2PKL_Creator(object):
    '''Build pickled training data from Penn Treebank ``.mrg`` files.

    Constructing an instance runs the whole pipeline:

    1. the selected treebank files are concatenated into ``<file_path>.mrg``
       and converted to CoNLL by the Stanford UniversalDependenciesConverter
       (``<file_path>.conll``; an existing file is reused as a cache),
    2. every CoNLL sentence is parsed into an nltk ``DependencyGraph``,
    3. the graphs are preprocessed into word indices and per-token
       ``ctree2last`` values, which are pickled to ``<file_path>.pkl``;
       the dictionary-normalised text goes to ``<file_path>.txt``.

    Preprocessing rules:
    tokens to delete: every token tagged ``-NONE-``; its dependents are
    re-attached to its head.
    '''

    # External resources. Kept as class attributes so that another machine
    # can override them in a subclass without touching the pipeline code.
    PTB_ROOT = '/home/hmwv1114/nltk_data/corpora/ptb/'
    # Handed to java verbatim: java expands classpath wildcards by itself.
    CORENLP_CLASSPATH = '/home/hmwv1114/SciSoft/stanford-corenlp/*'
    CONVERTER_CLASS = 'edu.stanford.nlp.trees.ud.UniversalDependenciesConverter'
    # Receives the converter's stderr; overwritten on every conversion.
    CONVERTER_LOG = 'err.out'

    def __init__(self,
                 mrg,
                 file_path,
                 dict_path,
                 build_dict=False):
        '''Run the conversion and write the resulting pickle(s).

        mrg:        iterable of treebank file ids, each appended to
                    ``PTB_ROOT`` to locate the file.
        file_path:  path prefix of all generated files (``.mrg``,
                    ``.conll``, ``.txt`` and ``.pkl`` are appended).
        dict_path:  pickle of the ``Dictionary``. It is loaded when
                    ``build_dict`` is false and (over)written when true.
        build_dict: build a fresh ``Dictionary`` from this data instead of
                    loading an existing one.
        '''
        pickle_filepath = file_path + '.pkl'
        text_filepath = file_path + '.txt'
        self.build_dict = build_dict

        if build_dict:
            self.dictionary = Dictionary()
        else:
            # NOTE: unpickling executes arbitrary code; dict_path must point
            # to a file produced by a trusted run of this script.
            with open(dict_path, 'rb') as dict_file:
                self.dictionary = pickle.load(dict_file)

        print("loading trees ...")
        trees = self.load_trees(mrg, file_path)

        print("preprocessing ...")
        self.data = self.preprocess(trees, text_filepath)

        with open(pickle_filepath, "wb") as file_data:
            pickle.dump(self.data, file_data)
        if build_dict:
            with open(dict_path, "wb") as file_data:
                pickle.dump(self.dictionary, file_data)

    def preprocess(self, parse_trees, text_filepath):
        '''Turn dependency graphs into index sequences and tree targets.

        parse_trees:   sequence of nltk ``DependencyGraph`` objects; they
                       are modified in place (``-NONE-`` tokens removed).
        text_filepath: file that receives one line per kept sentence, made
                       of the dictionary's surface form of every token.

        Returns ``(sens_idx, lasts)``: two parallel lists with one entry
        per kept sentence. ``sens_idx`` holds the dictionary indices of
        ``['<s>'] + words``; ``lasts`` holds ``[0] + ctree2last(...)``.
        The ``ctree2distance`` values are only used for the depth
        statistics printed at the end, they are not returned.
        Graphs that contain nothing but the root after token deletion are
        skipped. An empty result yields ``([], [])``.
        '''
        sens = []
        trees = []
        lasts_per_sent = []
        depth_per_sent = []

        print('\nConverting trees ...')
        total = len(parse_trees)
        for i, tree in enumerate(parse_trees):
            # modify the tree before subsequent steps
            self._remove_empty_tokens(tree)

            if len(tree.nodes) == 1:
                continue  # skip the null trees (after deletion)

            ordered_nodes = sorted(tree.nodes.values(),
                                   key=lambda v: v['address'])
            # The first node in address order is replaced by the '<s>' marker.
            word_lexs = ['<s>'] + [node['word'] for node in ordered_nodes][1:]
            if self.build_dict:
                for word in word_lexs:
                    self.dictionary.add_word(word)

            ctree = dtree2ctree(tree)
            # A leading 0 aligns both sequences with the '<s>' token.
            distances_sent = [0] + ctree2distance(ctree, idx=0)
            assert len(distances_sent) == len(word_lexs)

            lasts = [0] + ctree2last(ctree, 0)
            assert len(lasts) == len(word_lexs)

            sens.append(word_lexs)
            trees.append(tree)
            lasts_per_sent.append(lasts)
            depth_per_sent.append(max(distances_sent))

            if i % 10 == 0:
                print("Done %d/%d\r" % (i, total), end='')

        if self.build_dict:
            self.dictionary.rebuild_by_freq(thd=2)

        sens_idx = []
        # The context manager closes the file even if a lookup fails.
        with open(text_filepath, "w") as text_file:
            for sent in sens:
                idx = []
                processed_words = []
                for loc, word in enumerate(sent):
                    if self.build_dict:
                        self.dictionary.add_unk(word, loc)
                    index = self.dictionary.get_idx(word, loc)
                    idx.append(index)
                    processed_words.append(self.dictionary.idx2word[index])

                text_file.write(' '.join(processed_words) + '\n')
                sens_idx.append(idx)

        self._report_depth_stats(sens, trees, depth_per_sent)

        return sens_idx, lasts_per_sent

    @staticmethod
    def _remove_empty_tokens(tree):
        '''Delete every ``-NONE-`` token of ``tree`` in place.

        The dependents of a deleted token are appended to the dependents of
        its head (under their own relation) and their ``head`` is updated.
        '''
        # The address range is computed once, before any node is removed.
        for address in range(1, len(tree.nodes)):
            node = tree.nodes[address]
            if node['tag'] != '-NONE-':
                continue

            head = node['head']
            head_node = tree.nodes[head]
            head_node['deps'][node['rel']].remove(address)
            for rel, deps in node['deps'].items():
                for dep in deps:
                    head_node['deps'][rel].append(dep)
                    tree.nodes[dep]['head'] = head
            tree.remove_by_address(address)

    @staticmethod
    def _report_depth_stats(sens, trees, depths):
        '''Print summary statistics of the per-sentence maximum distance.

        ``sens``, ``trees`` and ``depths`` are parallel lists. Nothing but a
        short notice is printed when they are empty, because argmax / max
        are undefined for empty input.
        '''
        if not depths:
            print('No sentences left after preprocessing; '
                  'skipping depth statistics.')
            return

        max_idx = numpy.argmax(depths)
        max_depth = max(depths)
        print('Max sentence: ', ' '.join(sens[max_idx]))
        print('Max depth: ', max_depth)
        print('Mean depth: ', numpy.mean(depths))
        print('Median depth: ', numpy.median(depths))
        # Histogram: number of sentences for each depth 1..max_depth.
        for depth in range(1, max_depth + 1):
            print(depths.count(depth), end='\t')
        print()
        print(trees[max_idx].to_conll(10))

    def load_trees(self, file_ids, output_path):
        '''Return one ``DependencyGraph`` per sentence of the given files.

        file_ids:    treebank file ids, resolved relative to ``PTB_ROOT``.
        output_path: path prefix; ``<output_path>.conll`` is used as a
                     cache and only regenerated (together with
                     ``<output_path>.mrg``) when it does not exist.

        Raises ``RuntimeError`` when the converter exits with a non-zero
        status.
        '''
        in_file = output_path + '.mrg'
        out_file = output_path + '.conll'

        if not path.exists(out_file):
            self._convert_to_conll(file_ids, in_file, out_file)

        with open(out_file, 'r') as conll_file:
            output = conll_file.read().strip()
        # Sentences are separated by a blank line in the converter output.
        return [DependencyGraph(sentence, top_relation_label='root')
                for sentence in output.split('\n\n')]

    def _convert_to_conll(self, file_ids, in_file, out_file):
        '''Concatenate the treebank files and convert them to CoNLL.

        Writes the concatenation to ``in_file`` and the converter's stdout
        to ``out_file``; stderr goes to ``CONVERTER_LOG``.
        '''
        with open(in_file, 'w') as ctree_file:
            for file_id in file_ids:
                with open(self.PTB_ROOT + file_id, 'r') as f_in:
                    content = f_in.read()
                print(content.strip(), file=ctree_file)

        # An argument list (no shell) keeps paths with spaces or shell
        # metacharacters intact; stdout is redirected by subprocess itself.
        command = ['java', '-mx1g', '-cp', self.CORENLP_CLASSPATH,
                   self.CONVERTER_CLASS, '-treeFile', in_file]

        conversion_ok = False
        try:
            with open(out_file, 'w') as conll_file, \
                    open(self.CONVERTER_LOG, 'w') as err_file:
                return_code = subprocess.call(
                    command, stdout=conll_file, stderr=err_file)
            conversion_ok = return_code == 0
        finally:
            # A partial file would be mistaken for a valid cache next run.
            if not conversion_ok and path.exists(out_file):
                os.remove(out_file)

        if not conversion_ok:
            raise RuntimeError(
                'UniversalDependenciesConverter failed with exit status %d '
                'for %s; see %s' % (return_code, in_file, self.CONVERTER_LOG))


if __name__ == '__main__':
    # (split name, first file id, last file id); the id ranges are disjoint
    # and compared as plain strings. The split name doubles as the suffix
    # of the generated files.
    split_bounds = (
        ('train', 'WSJ/02/WSJ_0200.MRG', 'WSJ/21/WSJ_2199.MRG'),
        ('dev', 'WSJ/24/WSJ_2400.MRG', 'WSJ/24/WSJ_2499.MRG'),
        ('test', 'WSJ/23/WSJ_2300.MRG', 'WSJ/23/WSJ_2399.MRG'),
    )

    # Assign every treebank file to the split whose range contains its id;
    # files outside all ranges are ignored.
    split_file_ids = {split: [] for split, _, _ in split_bounds}
    for file_id in ptb.fileids():
        for split, first_id, last_id in split_bounds:
            if first_id <= file_id <= last_id:
                split_file_ids[split].append(file_id)
                break

    datapath = 'data/rnng/en_ptb-ud'
    dict_path = datapath + '-dict.pkl'

    # The training split goes first because it builds the dictionary that
    # the other two splits load from dict_path.
    for split, _, _ in split_bounds:
        CONLL2PKL_Creator(split_file_ids[split],
                          file_path=datapath + '-' + split,
                          dict_path=dict_path,
                          build_dict=(split == 'train'))

