import os
from itertools import repeat, chain
from multiprocessing import Pool

import spacy
from torchtext import data
from torchtext.vocab import Vectors

from util import timing

# Relative path prefix under which the dataset/ and model/ directories used
# by this module are looked up.
PATH_TO_PUBLIC = '../..'


@timing
def get_spacy():
    """Load and return the spaCy ``en_core_web_sm`` pipeline."""
    pipeline = spacy.load('en_core_web_sm')
    return pipeline


def get_news_args(cfg, field):
    """Build one argument tuple per 20newsgroups topic directory.

    Args:
        cfg: configuration object; ``cfg.data['tasks']`` is the list of task
            names and ``cfg.data['topics'][i]`` the topic names of task ``i``.
        field: passed through unchanged into every tuple.

    Returns:
        A list of ``(directory, field, topic_index, task_index)`` tuples,
        ordered by split ('train' then 'test'), then task, then topic.
    """
    return [
        (f'{PATH_TO_PUBLIC}/dataset/20newsgroups/20news-bydate-{split}/{task}.{topic}', field, topic_idx, task_idx)
        for split in ('train', 'test')
        for task_idx, task in enumerate(cfg.data['tasks'])
        for topic_idx, topic in enumerate(cfg.data['topics'][task_idx])
    ]


def tokenizer(s):
    """Tokenize ``s`` with the module-level spaCy pipeline ``nlp``.

    Returns the token texts as a list, leaving out tokens whose text is
    exactly a single space.
    """
    tokens = []
    for tok in nlp.tokenizer(s):
        text = tok.text
        if text == " ":
            continue
        tokens.append(text)
    return tokens


def load_data_news(args):
    """Turn every file of one directory into a torchtext example.

    The arguments arrive packed in a single tuple (the shape produced by
    ``get_news_args``), which lets the function be used with ``Pool.map``.

    Args:
        args: ``(directory, field, topic_label, task_label)``. The task label
            is unpacked but not used here.

    Returns:
        A list with one ``data.Example`` per directory entry, holding the
        file's full text and ``topic_label``.
    """
    directory, field, topic_label, _task_label = args

    def read_example(fname):
        # Files are decoded as ISO-8859-1.
        with open(os.path.join(directory, fname), encoding="ISO-8859-1") as handle:
            return data.Example.fromlist([handle.read(), topic_label], field)

    return [read_example(fname) for fname in os.listdir(directory)]


def load_data_sentiment(p, field):
    """Read one sentiment file into a list of torchtext examples.

    Each line yields one example: the first character of the line is used as
    the label and the remainder of the line as the text.

    Args:
        p: path of the file to read (decoded as ISO-8859-1).
        field: field list handed to ``data.Example.fromlist``.

    Returns:
        The list of examples, in file order.
    """
    examples = []
    with open(p, encoding="ISO-8859-1") as handle:
        for line in handle:
            label, text = line[0], line[1:]
            examples.append(data.Example.fromlist([text, label], field))
    return examples


@timing
def load_all_data(cfg, paths, field):
    """Sequentially load every dataset selected by ``cfg.data['name']``.

    Args:
        cfg: configuration object; ``cfg.data['name']`` picks the loader.
        paths: files to read for the 'sentiment' dataset (not used for
            '20news').
        field: field list used for the examples and the datasets.

    Returns:
        A list of ``data.Dataset`` objects: one per path for 'sentiment',
        eight for '20news'.

    Raises:
        RuntimeError: if the dataset name is neither 'sentiment' nor '20news'.
    """
    name = cfg.data['name']

    if name == 'sentiment':
        datasets = []
        for p in paths:
            datasets.append(data.Dataset(load_data_sentiment(p, field), field))
        return datasets

    if name == '20news':
        per_directory = [load_data_news(arg) for arg in get_news_args(cfg, field)]
        # Merge each run of 4 consecutive directories into one example list,
        # giving 8 groups. NOTE(review): this presumably corresponds to
        # 2 splits x 4 tasks with 4 topics each -- confirm against the config.
        grouped = []
        for group in range(8):
            merged = []
            for offset in range(4):
                merged.extend(per_directory[group * 4 + offset])
            grouped.append(merged)
        return [data.Dataset(examples, field) for examples in grouped]

    raise RuntimeError('dataset not implemented')


@timing
def load_all_data_parallel(cfg, paths, field):
    """Load every dataset selected by ``cfg.data['name']`` with a process pool.

    Args:
        cfg: configuration object; ``cfg.data['name']`` picks the loader and
            ``cfg.data['processes']`` sets the number of worker processes.
        paths: files to read for the 'sentiment' dataset (not used for
            '20news').
        field: field list used for the examples and the datasets.

    Returns:
        A list of ``data.Dataset`` objects: one per path for 'sentiment',
        eight for '20news'.

    Raises:
        RuntimeError: if the dataset name is neither 'sentiment' nor '20news'.
    """
    name = cfg.data['name']
    # Reject unknown names up front (and actually raise), so no workers are
    # spawned for nothing and the caller sees a meaningful error.
    if name not in ('sentiment', '20news'):
        raise RuntimeError('dataset not implemented')

    # The context manager shuts the pool down once the blocking map/starmap
    # call has returned all results, so worker processes are not leaked.
    with Pool(cfg.data['processes']) as p:
        if name == 'sentiment':
            list_data = p.starmap(load_data_sentiment, zip(paths, repeat(field)))
        else:
            args = get_news_args(cfg, field)
            per_directory = p.map(load_data_news, args)
            # Merge each run of 4 consecutive directories into one example
            # list, giving 8 groups. NOTE(review): this presumably corresponds
            # to 2 splits x 4 tasks with 4 topics each -- confirm against the
            # config.
            list_data = [list(chain(*[per_directory[i*4+j] for j in range(4)])) for i in range(8)]

    return [data.Dataset(examples, field) for examples in list_data]


@timing
def build_vocab(field, all_data, d_name):
    """Build the vocabulary of ``field`` from ``all_data`` with pretrained vectors.

    Args:
        field: the field whose ``build_vocab`` is called.
        all_data: iterable of datasets, unpacked as positional arguments.
        d_name: dataset name, used as suffix of the vector cache location.
    """
    vectors_path = f'{PATH_TO_PUBLIC}/model/glove.840B.300d.txt'
    cache_location = f'.vector_cache_{d_name}'
    pretrained = Vectors(vectors_path, cache=cache_location)
    field.build_vocab(*all_data, vectors=pretrained)


class SentimentDataset(object):
    """Iterators and vocabulary for the multi-task ``mtl-dataset`` files.

    Attributes set by the constructor:
        cfg: the configuration object passed in.
        tasks: ``cfg.data['tasks']``.
        data: loaded datasets, ordered train then test for each task.
        train: dict mapping task name to a shuffled ``BucketIterator`` over
            that task's train dataset.
        test: dict mapping task name to an unshuffled ``BucketIterator`` over
            that task's test dataset.
        all: a shuffled ``BucketIterator`` over the train examples of all
            tasks, each carrying its task index in an extra ``task`` field.
        vocab: vocabulary of the text field.
        word_embeddings: the vectors attached to that vocabulary.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.tasks = self.cfg.data['tasks']

        text_field = data.Field(sequential=True, tokenize=tokenizer, lower=True, fix_length=self.cfg.model.max_sen_len)
        label_field = data.Field(sequential=False, use_vocab=False)
        task_field = data.Field(sequential=False, use_vocab=False)
        fields = [('text', text_field), ('label', label_field)]
        fields_with_task = [('text', text_field), ('label', label_field), ('task', task_field)]

        # Two files per task, train first, so task i sits at indices 2*i
        # (train) and 2*i + 1 (test) of the loaded list.
        paths = []
        for t in self.tasks:
            for s in ('train', 'test'):
                paths.append(f'{PATH_TO_PUBLIC}/dataset/mtl-dataset/{t}.task.{s}')
        self.data = load_all_data_parallel(self.cfg, paths, fields)

        def by_length(example):
            return len(example.text)

        batch_size = self.cfg.data['batch_size']

        self.train = {}
        for idx, task in enumerate(self.tasks):
            self.train[task] = data.BucketIterator(
                self.data[2 * idx],
                batch_size=batch_size,
                sort_key=by_length,
                repeat=False,
                shuffle=True,
            )

        self.test = {}
        for idx, task in enumerate(self.tasks):
            self.test[task] = data.BucketIterator(
                self.data[2 * idx + 1],
                batch_size=batch_size,
                sort_key=by_length,
                repeat=False,
                shuffle=False,
            )

        # Pool the train examples of every task, tagging each with its task index.
        pooled = [
            data.Example.fromlist([example.text, example.label, idx], fields_with_task)
            for idx, _ in enumerate(self.tasks)
            for example in self.data[2 * idx].examples
        ]
        self.all = data.BucketIterator(
            data.Dataset(pooled, fields_with_task),
            batch_size=batch_size * len(self.tasks),
            sort_key=by_length,
            repeat=False,
            shuffle=True,
        )

        build_vocab(text_field, self.data, 'sentiment')
        self.vocab = text_field.vocab
        self.word_embeddings = text_field.vocab.vectors


class NewsDataset(object):
    """Iterators and vocabulary for the 20newsgroups tasks.

    Attributes set by the constructor:
        cfg: the configuration object passed in.
        tasks: ``cfg.data['tasks']``.
        data: datasets returned by ``load_all_data_parallel``.
        train: dict mapping task name to a shuffled ``BucketIterator`` over
            ``data[i]``.
        test: dict mapping task name to an unshuffled ``BucketIterator`` over
            ``data[i + 4]``.
        all: a shuffled ``BucketIterator`` over the train examples of all
            tasks, each carrying its task index in an extra ``task`` field.
        vocab: vocabulary of the text field.
        word_embeddings: the vectors attached to that vocabulary.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.tasks = self.cfg.data['tasks']

        # The text field holds token sequences (sequential=True); label and
        # task are single values used as-is without a vocabulary.
        text_field = data.Field(sequential=True, tokenize=tokenizer, lower=True, fix_length=self.cfg.model.max_sen_len)
        label_field = data.Field(sequential=False, use_vocab=False)
        task_field = data.Field(sequential=False, use_vocab=False)
        fields = [('text', text_field), ('label', label_field)]
        fields_with_task = [('text', text_field), ('label', label_field), ('task', task_field)]

        # No explicit paths: the loader derives the directories from cfg.
        self.data = load_all_data_parallel(self.cfg, None, fields)

        def by_length(example):
            return len(example.text)

        batch_size = self.cfg.data['batch_size']

        self.train = {}
        for idx, task in enumerate(self.tasks):
            self.train[task] = data.BucketIterator(
                self.data[idx],
                batch_size=batch_size,
                sort_key=by_length,
                repeat=False,
                shuffle=True,
            )

        # NOTE(review): the fixed offset of 4 presumably assumes exactly four
        # tasks, with train groups first and test groups after -- confirm
        # against the loader and the config.
        self.test = {}
        for idx, task in enumerate(self.tasks):
            self.test[task] = data.BucketIterator(
                self.data[idx + 4],
                batch_size=batch_size,
                sort_key=by_length,
                repeat=False,
                shuffle=False,
            )

        # Pool the train examples of every task, tagging each with its task index.
        pooled = [
            data.Example.fromlist([example.text, example.label, idx], fields_with_task)
            for idx, _ in enumerate(self.tasks)
            for example in self.data[idx].examples
        ]
        self.all = data.BucketIterator(
            data.Dataset(pooled, fields_with_task),
            batch_size=batch_size * len(self.tasks),
            sort_key=by_length,
            repeat=False,
            shuffle=True,
        )

        build_vocab(text_field, self.data, '20news')
        self.vocab = text_field.vocab
        self.word_embeddings = text_field.vocab.vectors


# Shared spaCy pipeline read by tokenizer(). It is created when this module is
# imported, so every import of the file pays the model-loading cost.
nlp = get_spacy()
if __name__ == '__main__':
    print('running dataloader file')
