import pickle
import re
from collections import Counter

import nltk
import numpy
from nltk.corpus import ptb

from helpers import Dictionary
from tree_utiles import ctree2distance, ctree2last, tree2list


class MRG2PKL_Creator(object):
    '''Convert Penn Treebank parse trees into pickled model data.

    Constructing an instance does all the work:

    * loads every parsed sentence of the given PTB file ids via
      ``nltk.corpus.ptb``;
    * deletes ``-NONE-`` leaves (empty elements) from each tree;
    * writes a human-readable dump to ``<file_path>.txt``;
    * pickles ``(sens_idx, distances, sens, trees)`` to ``<file_path>.pkl``.

    With ``build_dict=True`` a fresh ``Dictionary`` is built from the data
    and pickled to ``dict_path``; otherwise the dictionary is loaded from
    ``dict_path``.
    '''

    def __init__(self,
                 mrg,
                 file_path,
                 dict_path,
                 build_dict=False):
        '''Run the full conversion for one set of PTB files.

        Args:
            mrg: iterable of PTB file ids, passed to ``ptb.parsed_sents``.
            file_path: output path prefix; ``.pkl`` and ``.txt`` are appended.
            dict_path: pickled ``Dictionary`` to load, or to write when
                ``build_dict`` is true.
            build_dict: build the vocabulary from these files instead of
                loading an existing one.
        '''
        pickle_filepath = file_path + '.pkl'
        text_filepath = file_path + '.txt'
        self.build_dict = build_dict

        if build_dict:
            self.dictionary = Dictionary()
        else:
            # NOTE: unpickling can execute arbitrary code; only load
            # dictionaries produced by this script.
            with open(dict_path, 'rb') as dict_file:
                self.dictionary = pickle.load(dict_file)

        print("loading trees ...")
        trees = self.load_trees(mrg)

        print("preprocessing ...")
        self.data = self.preprocess(trees, text_filepath)

        with open(pickle_filepath, "wb") as file_data:
            pickle.dump(self.data, file_data)
        if build_dict:
            with open(dict_path, "wb") as file_data:
                pickle.dump(self.dictionary, file_data)

    def preprocess(self, parse_trees, text_filepath):
        '''Turn parse trees into dictionary indices and tree summaries.

        Every tree is modified in place (``-NONE-`` leaves are deleted);
        trees with no leaves left are skipped. For each kept sentence the
        stripped tree and its ``tree2list`` form are written to
        *text_filepath*. The sentence as mapped through the dictionary is
        written after the optional ``rebuild_by_freq`` pass.

        Args:
            parse_trees: sequence of ``nltk.Tree`` parses.
            text_filepath: path of the text dump (overwritten).

        Returns:
            Tuple ``(sens_idx, distances, sens, trees)`` with one entry per
            kept sentence: the dictionary indices of its tokens, the
            ``ctree2last`` values (prefixed with 0 for ``<s>``), the token
            list (prefixed with ``<s>``), and the ``tree2list`` result.
        '''
        sens_idx = []
        sens = []
        trees = []
        distances = []  # per sentence: ctree2last values (despite the name)
        depths = []  # per sentence: max of the ctree2distance values

        with open(text_filepath, "w", encoding="utf-8") as text_file:
            print('\nConverting trees ...')
            for i, tree in enumerate(parse_trees):
                self._strip_empty_elements(tree)
                if not tree.leaves():
                    continue  # the tree held only empty elements

                word_lexs = ['<s>'] + [word for word, _ in tree.pos()]
                if self.build_dict:
                    for word in word_lexs:
                        self.dictionary.add_word(word)

                ctree, _, _ = tree2list(tree)

                # The leading 0 stands in for the '<s>' token.
                distances_sent = [0] + ctree2distance(tree, idx=0)
                assert len(distances_sent) == len(word_lexs)

                lasts = [0] + ctree2last(tree, 0)
                assert len(lasts) == len(word_lexs)

                text_file.write(str(tree) + '\n')
                text_file.write(str(ctree) + '\n\n')

                sens.append(word_lexs)
                trees.append(ctree)
                distances.append(lasts)
                depths.append(max(distances_sent))

                if i % 10 == 0:
                    print("Done %d/%d\r" % (i, len(parse_trees)), end='')
            print()  # terminate the '\r' progress line

            if self.build_dict:
                self.dictionary.rebuild_by_freq(thd=2)

            for sent in sens:
                idx = []
                processed_words = []
                for loc, word in enumerate(sent):
                    if self.build_dict:
                        self.dictionary.add_unk(word, loc)
                    index = self.dictionary.get_idx(word, loc)
                    idx.append(index)
                    processed_words.append(self.dictionary.idx2word[index])

                text_file.write(' '.join(processed_words) + '\n')
                sens_idx.append(idx)

        self._print_depth_stats(sens, depths)
        return sens_idx, distances, sens, trees

    @staticmethod
    def _strip_empty_elements(tree):
        '''Delete every ``-NONE-``-tagged leaf from *tree* in place.

        Each such leaf is removed together with the chain of single-child
        ancestors that would otherwise be left empty. Leaves are visited
        right-to-left, so a deletion never shifts the index of a leaf that
        is still to be visited.
        '''
        for leaf_idx in reversed(range(len(tree.leaves()))):
            postn = tree.leaf_treeposition(leaf_idx)
            if tree[postn[:-1]].label() != '-NONE-':
                continue
            # Climb while the parent has this node as its only child.
            parentpos = postn[:-1]
            while parentpos and len(tree[parentpos]) == 1:
                postn = parentpos
                parentpos = postn[:-1]
            del tree[postn]

    @staticmethod
    def _depth_histogram(depths):
        '''Return the counts of depths 1, 2, ..., max(depths) as a list.

        Values below 1 are not counted. An empty input gives an empty list.
        '''
        if not depths:
            return []
        counts = Counter(depths)
        return [counts[depth] for depth in range(1, max(depths) + 1)]

    @classmethod
    def _print_depth_stats(cls, sens, depths):
        '''Print the deepest sentence, summary statistics, and a
        tab-separated histogram of *depths*.'''
        if not depths:
            # numpy.argmax raises on an empty sequence.
            print('No sentences left after preprocessing.')
            return
        max_idx = numpy.argmax(depths)
        print('Max sentence: ', ' '.join(sens[max_idx]))
        print('Max depth: ', max(depths))
        print('Mean depth: ', numpy.mean(depths))
        print('Median depth: ', numpy.median(depths))
        histogram = cls._depth_histogram(depths)
        print(''.join('%d\t' % count for count in histogram))

    def load_trees(self, file_ids):
        '''Return every parsed sentence of the given PTB file ids, in order.'''
        trees = []
        for file_id in file_ids:
            trees.extend(ptb.parsed_sents(file_id))
        return trees


if __name__ == '__main__':
    import os

    # WSJ section split: train 02-21, dev 24, test 23. The bounds compare
    # file ids as strings, which works because the section and file numbers
    # in the ids are zero-padded to a fixed width.
    train_file_ids = []
    valid_file_ids = []
    test_file_ids = []
    for file_id in ptb.fileids():
        if 'WSJ/02/WSJ_0200.MRG' <= file_id <= 'WSJ/21/WSJ_2199.MRG':
            train_file_ids.append(file_id)
        elif 'WSJ/24/WSJ_2400.MRG' <= file_id <= 'WSJ/24/WSJ_2499.MRG':
            valid_file_ids.append(file_id)
        elif 'WSJ/23/WSJ_2300.MRG' <= file_id <= 'WSJ/23/WSJ_2399.MRG':
            test_file_ids.append(file_id)

    datapath = 'data/ptb/ptb'
    dict_path = datapath + '-dict.pkl'
    # Every output file below is opened for writing directly, so make sure
    # the target directory exists first.
    os.makedirs(os.path.dirname(datapath), exist_ok=True)

    # The training split must run first: it builds and saves the dictionary
    # that the dev and test splits then load from dict_path.
    MRG2PKL_Creator(train_file_ids, file_path=datapath + '-train',
                    dict_path=dict_path, build_dict=True)
    MRG2PKL_Creator(valid_file_ids, file_path=datapath + '-dev',
                    dict_path=dict_path)
    MRG2PKL_Creator(test_file_ids, file_path=datapath + '-test',
                    dict_path=dict_path)
