import os
from copy import Error
import json
import numpy as np 
from pathlib import Path
from numpy.core.numeric import indices 
from sklearn.model_selection import train_test_split
import pandas as pd 
import random
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.dataloader import default_collate
from transformers import data
import conllu
import re

def get_senteval_data_preprocessed(fpath, args):
    """
    Load preprocessed SentEval representations and split them into
    train / validation / test sets with stratified sampling.

    Input:
        fpath: path to a file readable by `torch.load`. The loaded object is
          indexed with the keys 'X' (features) and 'y' (labels).
        args: namespace providing `even_distribute`, `train_size_per_class`,
          `val_size_per_class`, `seed` and `representation_gaussian_noise`.
    Return:
        train_data: TensorDataset of (X, y), X as float and y as long tensors.
        val_data: Same as above, for the validation set.
        test_data: Same as above, for the test set. Its size is also derived
          from `args.val_size_per_class`.
        nclasses: int, number of distinct labels found in the loaded data.
    Raises:
        FileNotFoundError: if `fpath` does not exist.
        ValueError: if `args.even_distribute` is false (not supported).
        AssertionError: if the data is too small for the requested splits.
    """
    if not os.path.exists(fpath):
        raise FileNotFoundError(f'{fpath} not exists.')

    # NOTE(review): torch.load relies on pickle; only load files you trust.
    loaded = torch.load(fpath)
    # as_tensor accepts lists, arrays and tensors alike without the
    # copy-construct warning torch.tensor emits for tensor inputs.
    features = torch.as_tensor(loaded['X'], dtype=torch.float)
    labels = torch.as_tensor(loaded['y'], dtype=torch.long)
    nclasses = torch.unique(labels).shape[0]

    if not args.even_distribute:
        raise ValueError("Only supports even distribution.")

    train_total = args.train_size_per_class * nclasses
    heldout_total = args.val_size_per_class * nclasses
    # Validation and test each take `heldout_total` samples, and every split
    # below must leave a non-empty remainder, hence 2x and the strict `<`.
    assert train_total + 2 * heldout_total < len(labels), \
        "train, val and test sizes should add up to be less than the total num in the data!"

    train_x, other_x, train_y, other_y = train_test_split(
        features, labels,
        random_state=args.seed,
        train_size=train_total,
        shuffle=True,
        stratify=labels
    )
    val_x, remain_x, val_y, remain_y = train_test_split(
        other_x, other_y,
        random_state=args.seed,
        train_size=heldout_total,
        shuffle=True,
        stratify=other_y
    )
    test_x, _, test_y, _ = train_test_split(
        remain_x, remain_y,
        random_state=args.seed,
        train_size=heldout_total,
        shuffle=True,
        stratify=remain_y
    )

    # Optional: inject Gaussian noise into the representations (in place).
    noise_std = args.representation_gaussian_noise
    if noise_std > 0:
        for split_x in (train_x, val_x, test_x):
            split_x += torch.normal(0, noise_std, size=split_x.size())

    return (
        TensorDataset(train_x, train_y),
        TensorDataset(val_x, val_y),
        TensorDataset(test_x, test_y),
        nclasses,
    )


def get_olmpics_data(task, args):
    """
    Not implemented yet: this function always raises NotImplementedError.

    The draft kept below the raise is meant to return:
        train_data: TensorDataset of embs, yids, y for MLM tasks. TensorDataset of embs, y for
          QA tasks. embs are question or question-choice pair embeddings. yids are token
          ids of the choices. y are the index of the correct choice.
        val_data: Same as above but for the validation set.
        nclasses: Number of classes. For MLM tasks, it's the size of vocabulary.
          For QA tasks, it's the number of choices for each question.
        task_type: the 'task_type' entry of the loaded training data
          ('mlm', otherwise treated as 'qa').
    Raises:
        NotImplementedError: always, until the test split is returned as well.
    """
    # `NotImplemented` is a non-callable constant; the exception class is
    # `NotImplementedError`.
    raise NotImplementedError("TODO - also get the test data")

    # ------------------------------------------------------------------
    # Unreachable draft, kept for whoever finishes the TODO above.
    # NOTE(review): olmpics_load_preprocessd is declared in this file as
    # (fpath, task, split); the two calls below pass only two arguments and
    # need a file path before this draft can run.
    # ------------------------------------------------------------------
    train_data = olmpics_load_preprocessd(task, 'train')
    val_data = olmpics_load_preprocessd(task, 'dev')
    nclasses = len(train_data['y'].unique())

    assert args.train_size_per_class * nclasses < len(train_data['y']), "train and val sizes should add up to be no more than the total num in the data!"

    # TODO: Should we shuffle the data before selecting subset?
    n_train = args.train_size_per_class * nclasses
    train_subset = {k: d[:n_train] for k, d in train_data.items() if isinstance(d, torch.Tensor)}
    # The whole dev split is used; it was previously (mistakenly) built from
    # train_data even though val_data had been loaded.
    val_subset = {k: d for k, d in val_data.items() if isinstance(d, torch.Tensor)}

    if train_data['task_type'] == 'mlm':
        return TensorDataset(train_subset['embs'], train_subset['yids'], train_subset['y']), \
               TensorDataset(val_subset['embs'], val_subset['yids'], val_subset['y']), \
               train_data['vocab_size'], \
               train_data['task_type']
    else:  # train_data['task_type'] == 'qa'
        return TensorDataset(train_subset['embs'], train_subset['y']), \
               TensorDataset(val_subset['embs'], val_subset['y']), \
               len(train_data['y'].unique()), \
               train_data['task_type']

def get_ud_pos_data(args):
    """
    Load the UD POS data for the language named in `args.task` and cut each
    portion down to the requested number of samples per class.

    The language is the text after the last underscore of `args.task`.
    With `args.even_distribute`, every class gets `args.train_size_per_class`
    training samples and `args.val_size_per_class` samples for both the
    validation and the test portion. Otherwise the per-class quotas come from
    `args.train_sizes_by_class` and `args.val_sizes_by_class` (the latter is
    used for validation and test alike).

    Returns:
        train_data, val_data, test_data: list of {'X': conllu.TokenList, 'loc': int, 'y': int}
        nclasses: int
    """
    language = args.task.rsplit("_", 1)[-1]
    train_pool, val_pool, test_pool, nclasses = ud_pos_loadfile(language)

    if not args.even_distribute:
        train_quota = args.train_sizes_by_class
        val_quota = args.val_sizes_by_class
        test_quota = args.val_sizes_by_class
    else:
        train_quota = [args.train_size_per_class] * nclasses
        val_quota = [args.val_size_per_class] * nclasses
        test_quota = [args.val_size_per_class] * nclasses

    train_data = _select_by_class(train_pool, train_quota)
    val_data = _select_by_class(val_pool, val_quota)
    test_data = _select_by_class(test_pool, test_quota)
    return train_data, val_data, test_data, nclasses


def _select_by_class(data_all, targets, shuffle=False):
    counters = np.array(targets)
    selected_data = []
    # Don't need to shuffle again if train_test_split has shuffled.
    if shuffle:  # Shuffle the source data, so that the samples come from a diverse range of sentences
        random.shuffle(data_all)
    
    for item in data_all:
        label = item['y']
        if counters[label] > 0:
            selected_data.append(item)
            counters[label] -= 1
        else:
            if counters.sum() == 0:
                break 
    # if counters.sum() > 0:
    #    print ("There isn't enough data to select from")
    # Some tags in UD just doesn't have enough data points. It's ok.
    return selected_data 


def ud_pos_loadfile(lang):
    """
    Read the train / dev / test CoNLL-U files of one UD treebank and turn
    every token with a known POS tag into one sample.

    Input:
        lang: one of "basque", "english", "finnish", "marathi", "russian",
          "turkish". The files are looked up under "../../data/ud-treebanks-v2.5",
          relative to the current working directory.
    Returns:
        train_data, val_data, test_data: list of {'X': conllu.TokenList, 'loc': int, 'y': int}.
          'X' is the whole sentence, 'loc' the index of the token inside it,
          'y' the integer id of the token's UPOS tag.
        nclasses: int, the number of POS tags kept by get_ud_pos2int.
    """
    ud_path = "../../data/ud-treebanks-v2.5"

    lang_fnames = {
        "basque": "UD_Basque-BDT/eu_bdt-ud-{}.conllu",
        "english": "UD_English-GUM/en_gum-ud-{}.conllu",
        "finnish": "UD_Finnish-FTB/fi_ftb-ud-{}.conllu",
        "marathi": "UD_Marathi-UFAL/mr_ufal-ud-{}.conllu",
        "russian": "UD_Russian-GSD/ru_gsd-ud-{}.conllu",
        "turkish": "UD_Turkish-IMST/tr_imst-ud-{}.conllu"
    }
    # Compiled once instead of being looked up for every token.
    inner_punct = re.compile("[!?/'|-]")

    pos2int = get_ud_pos2int(lang)
    all_data = {"train": [], "dev": [], "test": []}
    for portion in ["train", "dev", "test"]:
        lang_fn = lang_fnames[lang].format(portion)
        # CoNLL-U files are UTF-8; do not depend on the locale's default encoding.
        text = Path(ud_path, lang_fn).read_text(encoding="utf-8")
        sentences = conllu.parse(text)

        for s in sentences:
            # Safety hack: eliminate the sentences containing words like
            # "well-done" and "that's". When there are punctuations, the
            # huggingface tokenizer splits the tokens without prepending "##".
            # (since I'm only considering "##" when aligning the tokens)
            # There are ~1/3 sentences in English and Russian of this type, and
            # around ~1/10 in Basque and Finnish.
            skip_sentence = any(
                len(token["form"]) > 1 and inner_punct.search(token["form"])
                for token in s
            )
            if skip_sentence:
                continue

            # Then add to the data collection: one sample per token whose tag is kept.
            for i, token in enumerate(s):
                if token['upos'] in pos2int:
                    all_data[portion].append({
                        'X': s,
                        'loc': i,
                        'y': pos2int[token['upos']]
                    })

    train_data, val_data, test_data = all_data['train'], all_data['dev'], all_data['test']
    return train_data, val_data, test_data, len(pos2int)


def get_ud_pos2int(lang):
    """
    Build the POS-tag -> integer id mapping for one language.

    The tag list is read from the 'tag_portions' JSON cell of the row of
    "../../data/ud_eda_info.csv" whose `lang` matches and whose `portion` is
    "train". Tags listed in `too_few_samples` for that language are dropped;
    the remaining tags are numbered 0, 1, 2, ... in order of first appearance.

    Input:
        lang: language name, as used in the `lang` column of the csv.
    Return:
        pos2int: dict of {tag (str): id (int)}
    """
    eda_info = pd.read_csv("../../data/ud_eda_info.csv")
    train_rows = eda_info[(eda_info.lang == lang) & (eda_info.portion == "train")]
    tag_portions = json.loads(train_rows.iloc[0].tag_portions)  # List of (str, int)

    # See notebooks/20200601_UD_exploration.ipynb for the detailed procedure to generate ud_eda_info.csv
    # NOTE(review): "SYM:" under finnish has a trailing colon, unlike "SYM"
    # under basque -- possibly a typo; confirm before changing, since it
    # affects which classes are kept.
    too_few_samples = {  # <3 samples in either {train, dev, test}
        "basque": ["SYM", "INTJ"],
        "english": ["INTJ"],
        "finnish": ["SYM:", "X"],
        "marathi": ["INTJ", "PART"],
        "russian": [],
        "turkish": ["X"]
    }

    pos2int = {}
    for entry in tag_portions:
        tag = entry[0]
        if tag in too_few_samples[lang] or tag in pos2int:
            continue
        # The next free id equals the number of tags registered so far.
        pos2int[tag] = len(pos2int)
    return pos2int

def get_collate_fn(batcher, use_cuda, task):
    """
    Pick the collate function to pass to torch.utils.data.DataLoader.

    Input:
        batcher: callable used for "ud_pos_*" tasks only. It receives a list
          of (item['X'], item['loc']) pairs and its result is converted with
          torch.tensor.
        use_cuda: if true, batches are moved to the "cuda" device, else "cpu".
        task: task name. "ud_pos_*" and "olmpics_*" get their own collate
          functions; every other task uses the SentEval one.
    Return:
        A function mapping a list of dataset items to batch tensors on the
        selected device.
    """
    device = torch.device("cuda" if use_cuda else "cpu")

    def _to_device_stack(tensors):
        return torch.stack(tensors).to(device)

    def _senteval_collate_fn(databatch):
        # Items are tuples of tensors; default_collate batches them field by field.
        collated = default_collate(databatch)
        return tuple(field.to(device) for field in collated)

    def _ud_collate_fn(databatch):
        located = [(sample['X'], sample['loc']) for sample in databatch]
        features = torch.tensor(batcher(located)).float()
        targets = torch.tensor([sample['y'] for sample in databatch]).long()
        return features.to(device), targets.to(device)

    def _olmpics_collate_fn(databatch):
        columns = list(zip(*databatch))
        if len(columns) != 3:  # QA tasks: (embs, y)
            embs, answers = columns
            return _to_device_stack(embs), _to_device_stack(answers)
        # MLM tasks: (embs, yids, y); the two targets travel together.
        embs, choice_ids, answers = columns
        return _to_device_stack(embs), \
                (_to_device_stack(choice_ids), _to_device_stack(answers))

    def cats_collate_fn(databatch):
        # Placeholder, not wired to any task yet.
        pass

    if task.startswith("ud_pos_"):
        return _ud_collate_fn
    if task.startswith("olmpics_"):
        return _olmpics_collate_fn
    return _senteval_collate_fn

def senteval_load_file(filepath="../../data/senteval/subj_number.txt"):
    """
    Load one SentEval probing file and map its string labels to integers.

    Each non-empty line is split on tabs: the second field is the label and
    the last field is the text. The first field is ignored, because all
    portions are loaded here and the train/dev/test splitting is done later.

    Input:
        filepath. e.g., "<repo_dir>/data/senteval/bigram_shift.txt"
    Return:
        task_data: list of {'X': str, 'y': int}. Label ids follow the sorted
          order of the distinct label strings.
        nclasses: int
    Raises:
        IndexError: if a non-empty line has fewer than two tab-separated fields.
    """
    task_data = []
    # The context manager closes the file handle; iterating avoids holding a
    # second copy of all lines in memory.
    with Path(filepath).open(encoding="utf-8") as fin:
        for linestr in fin:
            stripped = linestr.rstrip()
            if not stripped:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            fields = stripped.split("\t")
            task_data.append({'X': fields[-1], 'y': fields[1]})

    # Convert labels str to int
    distinct_labels = sorted({item['y'] for item in task_data})
    tok2label = {label: idx for idx, label in enumerate(distinct_labels)}
    for item in task_data:
        item['y'] = tok2label[item['y']]

    return task_data, len(tok2label)


def olmpics_load_preprocessd(fpath, task, split):
    """Load a preprocessed oLMpics data file.

    Input:
        fpath: path to the preprocessed file (see preprocess_data.py).
        task: task name. Currently unused; kept for the callers' interface.
        split: data split (train or dev). Currently unused; kept for the
          callers' interface.
    Return:
        Whatever object `torch.load` reads from `fpath`.
        NOTE(review): get_olmpics_data indexes the result with keys such as
        'embs', 'y' and 'task_type', so it is presumably a dict -- confirm
        against preprocess_data.py.
    Raises:
        FileNotFoundError: if `fpath` does not exist.
    """
    if not os.path.exists(fpath):
        # Same exception type as get_senteval_data_preprocessed uses for a missing file.
        raise FileNotFoundError(f'Preprocessed file {fpath} not exists. Run preprocess_data.py first.')
    # NOTE(review): torch.load relies on pickle; only load files you trust.
    return torch.load(fpath)


def main():
    """Drop into an interactive ipdb session; for manual testing only."""
    import ipdb
    ipdb.set_trace()


if __name__ == "__main__":
    main()