import numpy as np
import random
import os
import pandas as pd

from sister_terms_similarity.pedersen_similarities import ReaderSynsetOOVCouple
from utility.randomfixedseed import Random

from preprocessing.w2v_preprocessing_embedding import POSAwarePreprocessingWord2VecEmbedding


def save_test(paths, test_size=0.10, test_path='data/definitions_test'):
    """Split every file in ``paths`` into a training part and a test part.

    For each input file the lines are split with ``Random.split`` after the
    seed has been reset to 19, so every file is split from the same seed.
    The training lines are written under ``data/example_reduced_training``
    and the test lines under ``test_path``; in both cases the file keeps its
    name and is placed in a sub-folder named after the input file's parent
    directory.

    :param paths: iterable of input file paths.
    :param test_size: forwarded to ``Random.split``; presumably the fraction
        of lines assigned to the test part (confirm against
        ``utility.randomfixedseed``).
    :param test_path: root directory for the test files.
    :return: list with the paths of the written training files.
    """
    output_paths = []

    reduced_path = 'data/example_reduced_training'
    # makedirs(exist_ok=True) creates missing parents (e.g. 'data') and also
    # repairs a root that exists without its 'n'/'v'/'a' sub-folders.
    for root in (reduced_path, test_path):
        for pos_folder in ('n', 'v', 'a'):
            os.makedirs(os.path.join(root, pos_folder), exist_ok=True)

    print('TEST')
    for path in paths:
        with open(path, 'r') as file:
            input_lines = file.readlines()

        # Re-seeding inside the loop keeps each file's split independent of
        # how many files were processed before it.
        Random.set_seed(19)
        test_lines, train_lines = Random.split(input_lines, test_size)

        parent_name = os.path.basename(os.path.dirname(path))
        file_name = os.path.basename(path)

        output_path = os.path.join(reduced_path, parent_name, file_name)
        # The parent folder name is not guaranteed to be one of n/v/a.
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, 'w') as o:
            o.writelines(train_lines)
        output_paths.append(output_path)

        test_output_path = os.path.join(test_path, parent_name, file_name)
        os.makedirs(os.path.dirname(test_output_path), exist_ok=True)
        print(test_output_path)
        with open(test_output_path, 'w') as o:
            print(parent_name, len(test_lines))
            o.writelines(test_lines)

    print("training: ".upper(), output_paths)
    return output_paths


class POSAwareExampleToNumpy:
    """Accumulate vector examples with their POS tags and dump them to ``.npz``.

    Five parallel lists are kept: ``data``, ``target``, ``target_pos``,
    ``w1_pos`` and ``w2_pos``.
    """

    # Keys every example passed to add_example must provide.
    _REQUIRED_KEYS = ('target', 'data', 'target_pos', 'w1_pos', 'w2_pos')

    def __init__(self, data=None, target=None, target_pos=None, w1_pos=None, w2_pos=None):
        """Start from the given lists, or from empty ones.

        The three POS lists are only accepted together: when any of them is
        ``None`` all three start empty.
        """
        self.data = [] if data is None else data
        self.target = [] if target is None else target

        pos_lists = (target_pos, w1_pos, w2_pos)
        if any(pos_list is None for pos_list in pos_lists):
            pos_lists = ([], [], [])
        self.target_pos, self.w1_pos, self.w2_pos = pos_lists

    def add_example(self, example):
        """Append one example; every field is stored as a numpy array.

        :param example: mapping with the keys ``target``, ``data``,
            ``target_pos``, ``w1_pos`` and ``w2_pos``; ``data`` must have
            exactly two elements.
        :raises Exception: carrying ``example`` when a key is missing or
            ``data`` does not have length 2.
        """
        if any(key not in example for key in self._REQUIRED_KEYS):
            raise Exception(example)
        if len(example['data']) != 2:
            raise Exception(example)

        destinations = (
            (self.data, 'data'),
            (self.target, 'target'),
            (self.target_pos, 'target_pos'),
            (self.w1_pos, 'w1_pos'),
            (self.w2_pos, 'w2_pos'),
        )
        for store, key in destinations:
            store.append(np.array(example[key]))

    def save_numpy_examples(self, path):
        """Write the five lists to ``path`` with ``np.savez``."""
        print(" ".join(["Saving at", path, str(len(self.target)), "examples"]))
        arrays = {
            'data': self.data,
            'target_pos': self.target_pos,
            'target': self.target,
            'w1_pos': self.w1_pos,
            'w2_pos': self.w2_pos,
        }
        np.savez(path, **arrays)


class ExampleWriter:
    """Base class for writers that turn example files into vector datasets.

    It only stores the configuration; subclasses implement
    ``write_w2v_examples``.
    """

    def __init__(self, example_paths, separator, output_path, preprocessor):
        """Store the writer configuration.

        :param example_paths: paths of the example files to read.
        :param separator: field separator of the example files.
        :param output_path: destination of the produced dataset.
        :param preprocessor: object used by subclasses to build examples.
        """
        self.example_paths = example_paths
        self.separator = separator
        self.output_path = output_path
        self.preprocessor = preprocessor

    def write_w2v_examples(self):
        """Write the examples; must be overridden by subclasses.

        :raises NotImplementedError: always, in this base class.
        """
        # NotImplemented is a comparison sentinel and is not callable;
        # NotImplementedError is the exception meant for abstract methods.
        raise NotImplementedError('write_w2v_examples must be implemented by subclasses')


class POSAwareExampleWriter(ExampleWriter):
    """Build POS-aware vector examples from example files and save them.

    NOTE(review): the ``words_in_test`` constructor argument is accepted but
    never stored or used, and the files are always read with a tab separator
    regardless of ``separator`` - confirm whether that is intended.
    """

    def __init__(self, example_paths, separator, output_path, preprocessor: POSAwarePreprocessingWord2VecEmbedding,
                 words_in_test=None):
        super().__init__(example_paths, separator, output_path, preprocessor)

    def write_w2v_examples(self, words_in_test=None):
        """Convert every usable row of every example file and save the result.

        Each file is loaded with pandas as tab-separated text. The first
        record of every file is skipped, as are rows whose ``target`` is in
        ``words_in_test`` and rows for which the preprocessor returns
        ``None``. The collected examples are saved to ``self.output_path``.

        :param words_in_test: container of target words to leave out;
            defaults to an empty dict.
        """
        excluded = {} if words_in_test is None else words_in_test

        collector = POSAwareExampleToNumpy()
        for path in self.example_paths:
            frame = pd.read_csv(path, sep="\t")
            print("=========="+path+"==========")
            print(frame.head())

            records = frame.to_dict(orient='records')
            # The first record is dropped on top of the header line that
            # read_csv already consumed.
            for record in records[1:]:
                if record["target"] in excluded:
                    continue
                words = [record["target"], record["w1"], record["w2"]]
                pos_tags = {'target_pos': record["target_pos"],
                            'w1_pos': record["w1_pos"],
                            'w2_pos': record["w2_pos"]}
                example = self.preprocessor.get_vector_example(words, pos_tags)
                if example is None:
                    continue
                collector.add_example(example)

        collector.save_numpy_examples(self.output_path)


def write_w2v_exaples_from_to(paths, output_path, tagset, save_for_test=False, words_in_test=None):
    """Write the vector examples built from ``paths`` to ``output_path``.

    :param paths: example files to convert.
    :param output_path: destination passed to the example writer.
    :param tagset: forwarded to ``POSAwarePreprocessingWord2VecEmbedding``.
    :param save_for_test: when true, the inputs are first split with
        ``save_test`` (test size 0.20, test files under
        ``data/definitions_test``) and only the training part is converted.
    :param words_in_test: target words to leave out; defaults to an empty
        dict.
    """
    excluded_words = {} if words_in_test is None else words_in_test
    example_paths = paths
    if save_for_test:
        example_paths = save_test(paths, test_size=0.20, test_path='data/definitions_test')

    embedding_preprocessor = POSAwarePreprocessingWord2VecEmbedding(
        "data/pretrained_embeddings/GoogleNews-vectors-negative300.bin",
        binary=True, tagset=tagset
    )
    example_writer = POSAwareExampleWriter(
        example_paths=example_paths,
        separator='\t',
        output_path=output_path,
        preprocessor=embedding_preprocessor,
    )
    example_writer.write_w2v_examples(words_in_test=excluded_words)


def load_dataset_from(path):
    """Load a saved example archive and return its rows in shuffled order.

    The archive must contain the arrays ``data``, ``target``, ``target_pos``,
    ``w1_pos`` and ``w2_pos``. Rows are shuffled together with the
    module-level ``random`` generator, so the five returned sequences stay
    aligned with each other.

    :param path: path of the ``.npz`` archive.
    :return: five tuples ``(data, target, target_pos, w1_pos, w2_pos)``;
        five empty tuples when the archive holds no rows.
    """
    keys = ('data', 'target', 'target_pos', 'w1_pos', 'w2_pos')
    # NOTE: allow_pickle=True can execute arbitrary code while loading;
    # only use this on archives from a trusted source.
    with np.load(path, allow_pickle=True) as data:
        columns = [data[key] for key in keys]

    dataset = list(zip(*columns))
    if not dataset:
        # zip(*[]) yields nothing, so unpacking into five names would fail.
        return (), (), (), (), ()

    random.shuffle(dataset)
    dataset_data, dataset_target, dataset_target_pos, dataset_w1_pos, dataset_w2_pos = zip(*dataset)
    return dataset_data, dataset_target, dataset_target_pos, dataset_w1_pos, dataset_w2_pos


def split_in(split_test: float, dataset_data, dataset_target, dataset_target_pos, dataset_w1_pos, dataset_w2_pos):
    """Cut five parallel sequences into a test part and a training part.

    The first ``int(len(dataset_data) * split_test)`` items of every sequence
    form the test part and the remaining items the training part.

    :param split_test: fraction of the items assigned to the test part.
    :return: ``(test, train)``, each a 5-tuple of numpy arrays ordered as
        ``(data, target, target_pos, w1_pos, w2_pos)``.
    """
    cut = int(len(dataset_data) * split_test)
    columns = (dataset_data, dataset_target, dataset_target_pos, dataset_w1_pos, dataset_w2_pos)

    test_part = tuple(np.array(column[:cut]) for column in columns)
    train_part = tuple(np.array(column[cut:]) for column in columns)
    return test_part, train_part


def all_files_of(base):
    """Return the paths of the regular files located directly in ``base``.

    Sub-directories are not descended into and are left out of the result.
    Paths are ``base`` joined with each entry name, in ``os.listdir`` order.
    """
    candidates = (os.path.join(base, entry) for entry in os.listdir(base))
    return [candidate for candidate in candidates if os.path.isfile(candidate)]


def write_w2v_example_excluding(test_paths, input_paths, output_path, tagset, save_for_test=False, words_in_test=None):
    """Write vector examples after removing lines that appear in test files.

    The couples of every file in ``test_paths`` are read with
    ``ReaderSynsetOOVCouple.read`` and each couple's ``first`` attribute,
    joined with ``'_'``, becomes an exclusion key. Every input file is then
    copied, without the lines whose third and fourth tab-separated fields
    (joined with ``'_'``) match a key, to::

        <input dir>/no_sister_terms_test/<input dir name>/<stem>_no_sister_terms_test.txt

    Finally the filtered files are passed to ``write_w2v_exaples_from_to``.

    :param test_paths: files holding the test couples to exclude.
    :param input_paths: example files to filter.
    :param output_path: forwarded to ``write_w2v_exaples_from_to``.
    :param tagset: forwarded to ``write_w2v_exaples_from_to``.
    :param save_for_test: forwarded to ``write_w2v_exaples_from_to``.
    :param words_in_test: forwarded to ``write_w2v_exaples_from_to``;
        defaults to an empty dict.
    """
    if words_in_test is None:
        words_in_test = {}

    # Column layout handed to ReaderSynsetOOVCouple.read; the meaning of each
    # index is defined by that reader.
    s1_index = 0
    w1_index = 1

    s2_index = 8
    w2_index = 9

    first_indexes = [2, 4]
    s_pos_index = 4
    w1_pos = 5
    w2_pos = 6
    exclude_first = True

    test_couples = []
    for test_path in test_paths:
        test_couples.extend(ReaderSynsetOOVCouple.read(test_path, s1_index=s1_index, w1_index=w1_index,
                                                       s2_index=s2_index, w2_index=w2_index,
                                                       first_indexes=first_indexes,
                                                       s_pos_index=s_pos_index, w1_pos=w1_pos, w2_pos=w2_pos,
                                                       exclude_first=exclude_first))

    # Only membership is needed, so a set replaces the key -> 1 dictionary.
    test_keys = {'_'.join(couple.first) for couple in test_couples}

    new_input_paths = []
    for path in input_paths:
        (directory, name) = os.path.split(path)

        new_input_dir = os.path.join(directory, 'no_sister_terms_test', os.path.basename(directory))
        # exist_ok avoids failing when only part of the tree already exists,
        # e.g. after an interrupted earlier run.
        os.makedirs(new_input_dir, exist_ok=True)

        new_input_path = os.path.join(new_input_dir, name.split('.')[0] + '_no_sister_terms_test.txt')
        new_input_paths.append(new_input_path)

        # The input is only read, so it is not opened for writing.
        with open(path, 'r') as file:
            lines = file.readlines()
        output_lines = [line for line in lines
                        if '_'.join(line.split('\t')[2:4]) not in test_keys]

        with open(new_input_path, 'w') as new_input:
            print('original ', len(lines))
            print('reduced ', len(output_lines))
            new_input.writelines(output_lines)
    print(new_input_paths)

    write_w2v_exaples_from_to(new_input_paths, output_path, tagset, save_for_test, words_in_test=words_in_test)
