import collections
import json
import random
from glob import glob
from typing import List

import conllu
import torch
from conllu import parse_incr
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from transformers import AutoTokenizer


""" collator: convert items to batches """
class CoNLLCollator(object):
    """Convert CoNLL dataset items into padded dependency-parsing batches.

    Args:
        tokenizer_pad_id: id used to pad subword token ids.
        position_pad_id: head value used where no gold head is available
            (BOS/EOS slots, non-integer heads, batch padding). Must be negative.
        label_pad_id: label id used where no known relation is available
            (BOS/EOS slots, relations missing from ``label_dict``, padding).
            Must be negative.
        label_dict: mapping from dependency relation name to integer id.
    """

    def __init__(self, tokenizer_pad_id, position_pad_id, label_pad_id, label_dict):
        self.tokenizer_pad_id = tokenizer_pad_id
        self.position_pad_id = position_pad_id
        self.label_pad_id = label_pad_id
        self.label_dict = label_dict
        # inverse mapping: label id -> relation name
        self.id2label = {idx: label for label, idx in self.label_dict.items()}
        # presumably negative so padding never collides with real head positions / label ids
        assert label_pad_id < 0 and position_pad_id < 0, 'Padding IDs should be negative.'

    @classmethod
    def from_json(cls, tokenizer_pad_id, position_pad_id, label_pad_id, label_dict_path):
        """Build a collator whose label dictionary is read from a JSON file."""
        # JSON text is UTF-8; close the handle deterministically
        with open(label_dict_path, encoding='utf-8') as f:
            label_dict = json.load(f)
        return cls(tokenizer_pad_id, position_pad_id, label_pad_id, label_dict)

    @classmethod
    def from_state_dict(cls, state_dict):
        """Rebuild a collator from the output of :meth:`state_dict`."""
        return cls(**state_dict)

    def state_dict(self):
        """Return the constructor arguments needed to recreate this collator."""
        return {
            'tokenizer_pad_id': self.tokenizer_pad_id,
            'position_pad_id': self.position_pad_id,
            'label_pad_id': self.label_pad_id,
            'label_dict': self.label_dict
        }

    def convert_element(self, item):
        """Attach gold head and label tensors to a ``(token_ids, sent)`` item.

        Heads/labels get one entry per row of ``token_ids``: a padding slot for
        BOS, one entry per word whose ``id`` is an int, and a padding slot for
        EOS. A 2-row ``token_ids`` (the over-length BOS+EOS placeholder, or an
        empty sentence) gets two padding entries.

        Returns:
            ``(token_ids, heads, labels, sent)`` with long ``heads``/``labels``.
        """
        sent_ids, sent = item
        if sent_ids.shape[0] == 2:  # sentence too long, use placeholder
            heads, labels = [self.position_pad_id, self.position_pad_id], [self.label_pad_id, self.label_pad_id]
        else:
            heads, labels = [self.position_pad_id], [self.label_pad_id]
            for w in sent:
                # non-integer ids do not get a token row, so they are skipped here too
                if isinstance(w['id'], int):
                    heads.append(w['head'] if isinstance(w['head'], int) else self.position_pad_id)
                    labels.append(self.label_dict.get(w['deprel'], self.label_pad_id))
            heads.append(self.position_pad_id)
            labels.append(self.label_pad_id)
        return sent_ids, torch.tensor(heads).long(), torch.tensor(labels).long(), sent

    def batch_compress(self, sents, pad_id=None):
        """Stack variably-sized 2-D tensors into one padded 3-D long tensor.

        Args:
            sents: sequence of 2-D tensors.
            pad_id: fill value; defaults to ``tokenizer_pad_id``.

        Returns:
            ``(batch, mask)`` where ``batch`` has shape
            ``(len(sents), max rows, max cols)`` and ``mask`` is ``batch != pad_id``.
        """
        if pad_id is None:
            pad_id = self.tokenizer_pad_id
        max_len = max(sent.shape[0] for sent in sents)
        max_token_len = max(sent.shape[1] for sent in sents)
        batch_shape = (len(sents), max_len, max_token_len)
        token_ids = torch.full(batch_shape, pad_id, dtype=torch.long)
        for i, sent in enumerate(sents):
            token_ids[i, :sent.shape[0], :sent.shape[1]] = sent
        masks = token_ids.ne(pad_id)
        return token_ids, masks

    def __call__(self, batch):
        """Collate converted items into ``(token_ids, masks, heads, labels, info)``."""
        sents, heads, labels, info = zip(*batch)
        token_ids, masks = self.batch_compress(sents)
        heads = pad_sequence(heads, batch_first=True, padding_value=self.position_pad_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=self.label_pad_id)
        return token_ids, masks, heads, labels, info


""" soft collator: convert soft training instances to batches, can be initialized with hard collator """
class SoftCollator(CoNLLCollator):
    """Collator for soft (parallel-text) training items.

    Shares ``CoNLLCollator``'s constructor and state format, so it can be
    built from a hard collator's ``state_dict()`` via ``from_state_dict``.
    Items are ``(sent_ids, sent_ids_para, alignment_pairs)`` triples.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def from_state_dict(cls, state_dict):
        """Rebuild a collator from ``state_dict()`` output (hard or soft)."""
        return cls(**state_dict)

    def convert_element(self, item):
        """Replace the alignment pair list with a dense 0/1 float matrix.

        The matrix has one row per row of ``sent_ids`` and one column per row
        of ``sent_ids_para``. Pair ``(x, y)`` sets entry ``(x + 1, y + 1)``;
        the +1 presumably skips the leading BOS row/column.
        """
        src_ids, para_ids, pairs = item
        n_src, n_para = src_ids.shape[0], para_ids.shape[0]
        matrix = torch.zeros(n_src, n_para)
        for src_word, para_word in pairs:
            matrix[src_word + 1, para_word + 1] = 1
        return src_ids, para_ids, matrix

    def __call__(self, batch):
        """Pad both sides and the alignment matrices into batch tensors.

        Returns ``(src_ids, src_mask, para_ids, para_mask, alignments)``. The
        alignment batch is padded with 0. Because ``batch_compress`` fills a
        long tensor, it comes back integer-typed.
        """
        src_side, para_side, matrices = zip(*batch)
        src_ids, src_mask = self.batch_compress(src_side)
        para_ids, para_mask = self.batch_compress(para_side)
        align_batch = self.batch_compress(matrices, 0)[0]
        return src_ids, src_mask, para_ids, para_mask, align_batch


""" dataset definition """
class CoNLLDataset(Dataset):
    """In-memory dataset of tokenized CoNLL sentences.

    For each sentence, the forms of tokens whose ``id`` is an int are wrapped
    in the tokenizer's BOS/EOS token (falling back to CLS/SEP). They are then
    tokenized as a list, one row per word. Non-integer ids are skipped;
    presumably these are conllu multiword ranges / empty nodes.

    If the total subword count (attention-mask sum) exceeds
    ``tokenizer.model_max_length``, a BOS+EOS-only placeholder is stored
    instead. This keeps dataset indices aligned with the input sentences.

    Items are ``(token_ids, sent)`` until :meth:`convert` rewrites them.
    """

    def __init__(self, tokenizer, sents):
        super(CoNLLDataset, self).__init__()
        self.tokenizer = tokenizer
        limit = tokenizer.model_max_length
        start = tokenizer.bos_token or tokenizer.cls_token
        end = tokenizer.eos_token or tokenizer.sep_token
        encode_kwargs = dict(padding=True, return_tensors='pt', add_special_tokens=False)
        self.data = list()
        for sent in tqdm(sents):
            words = [tok['form'] for tok in sent if isinstance(tok['id'], int)]
            encoded = self.tokenizer([start] + words + [end], **encode_kwargs)
            if encoded['attention_mask'].sum().item() > limit:
                # over-length: keep a BOS+EOS-only placeholder in its slot
                encoded = self.tokenizer([start] + [end], **encode_kwargs)
            self.data.append((encoded['input_ids'], sent))

    def convert(self, collator):
        """Rewrite every item in place with ``collator.convert_element``."""
        for pos, element in enumerate(self.data):
            self.data[pos] = collator.convert_element(element)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class TextDataset(CoNLLDataset):
    """CoNLLDataset variant for raw-text input that has no gold trees.

    After :meth:`convert`, items carry heads/labels filled entirely with the
    collator's padding ids. They can then be batched by ``CoNLLCollator``
    like gold data.
    """

    def __init__(self, tokenizer, sents):
        super(TextDataset, self).__init__(tokenizer, sents)

    def convert(self, collator):
        """Replace each item with ``(token_ids, heads, labels, sent)``.

        ``heads``/``labels`` have exactly one padding entry per row of
        ``token_ids``. That row count is BOS + words + EOS, or 2 for the
        over-length placeholder. Taking the length from ``token_ids`` keeps it
        consistent even when ``sent`` contains rows that were not tokenized
        (non-integer ids). Those rows would be over-counted by ``len(sent)``.
        """
        for i, item in enumerate(self.data):
            # item[0] / item[-1] (not 2-tuple unpacking) so a repeated call on
            # already-converted 4-tuples still works
            n_rows = item[0].shape[0]
            heads = torch.full((n_rows,), collator.position_pad_id, dtype=torch.long)
            labels = torch.full((n_rows,), collator.label_pad_id, dtype=torch.long)
            self.data[i] = (item[0], heads, labels, item[-1])


""" soft dataset definition """
class SoftDataset(Dataset):
    """In-memory dataset of (sentence, parallel sentence, alignment) triples.

    ``sents[i]`` and ``parallel_sents[i]`` are word lists. ``alignments[i]``
    lists ``(x, y)`` word-index pairs linking them. Each side is wrapped in
    BOS/EOS (falling back to CLS/SEP) and tokenized as a list of words.

    With ``reverse_lang`` the two sides are swapped and every pair becomes
    ``[y, x]``. If either side exceeds ``tokenizer.model_max_length``
    subwords, both sides become the same BOS+EOS placeholder with no links.
    """

    def __init__(self, tokenizer, sents, parallel_sents, alignments, reverse_lang):
        super(SoftDataset, self).__init__()
        assert len(sents) == len(parallel_sents)
        self.tokenizer = tokenizer
        limit = tokenizer.model_max_length
        start = tokenizer.bos_token or tokenizer.cls_token
        end = tokenizer.eos_token or tokenizer.sep_token
        encode_kwargs = dict(padding=True, return_tensors='pt', add_special_tokens=False)
        self.data = list()
        for idx, words in enumerate(tqdm(sents)):
            src = self.tokenizer([start] + words + [end], **encode_kwargs)
            para = self.tokenizer([start] + parallel_sents[idx] + [end], **encode_kwargs)
            fits = (src['attention_mask'].sum().item() <= limit
                    and para['attention_mask'].sum().item() <= limit)
            if not fits:
                placeholder = self.tokenizer([start] + [end], **encode_kwargs)['input_ids']
                self.data.append((placeholder, placeholder, []))
            elif reverse_lang:
                flipped = [[y, x] for x, y in alignments[idx]]
                self.data.append((para['input_ids'], src['input_ids'], flipped))
            else:
                self.data.append((src['input_ids'], para['input_ids'], alignments[idx]))

    def convert(self, collator):
        """Rewrite every item in place with ``collator.convert_element``."""
        for pos, element in enumerate(self.data):
            self.data[pos] = collator.convert_element(element)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


""" length based sampler for efficient training """
class SequenceLengthSampler(Sampler[int]):
    """Yield dataset indices ordered by sequence length, shortest first.

    The length of an item is ``item[0].shape[0]`` (rows of its token-id
    tensor). Indices that share a length are reshuffled in place on every
    iteration, so order within a length bucket changes between epochs while
    the buckets themselves stay in ascending order.
    """

    def __init__(self, data_source):
        # The base class does not use data_source. Passing it triggers a
        # deprecation warning on recent PyTorch, and None is accepted by old
        # and new versions alike.
        super(SequenceLengthSampler, self).__init__(None)
        self.data_source = data_source
        self.length2ids = collections.defaultdict(list)
        for i, item in enumerate(self.data_source):
            length = item[0].shape[0]
            self.length2ids[length].append(i)

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        ids = list()
        for key in sorted(self.length2ids):
            random.shuffle(self.length2ids[key])
            ids.extend(self.length2ids[key])
        return iter(ids)


""" length based soft sampler for efficient training """
class SoftSequenceLengthSampler(Sampler[int]):
    """Yield indices of paired items ordered by the longer side's length.

    The length of an item is ``max(item[0].shape[0], item[1].shape[0])``.
    Indices that share a length are reshuffled in place on every iteration,
    and the buckets are emitted in ascending length order.
    """

    def __init__(self, data_source):
        # The base class does not use data_source. Passing it triggers a
        # deprecation warning on recent PyTorch, and None is accepted by old
        # and new versions alike.
        super(SoftSequenceLengthSampler, self).__init__(None)
        self.data_source = data_source
        self.length2ids = collections.defaultdict(list)
        for i, item in enumerate(self.data_source):
            length = max(item[0].shape[0], item[1].shape[0])
            self.length2ids[length].append(i)

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        ids = list()
        for key in sorted(self.length2ids):
            random.shuffle(self.length2ids[key])
            ids.extend(self.length2ids[key])
        return iter(ids)


class SequenceLengthBatchSampler(Sampler[List[int]]):
    """Greedily pack indices from ``sampler`` into length-budgeted batches.

    Indices are taken in the order the sampler yields them. Each index is
    added to the running batch unless that would push the summed item length
    (``sampler.data_source[idx][0].shape[0]``) above ``batch_size``; in that
    case the running batch is closed first. An item that alone exceeds the
    budget gets a batch of its own. Batch order is shuffled on every
    iteration.

    Args:
        sampler: index sampler exposing a ``data_source`` attribute.
        batch_size: budget on the summed item lengths per batch (not a count
            of items).
    """

    def __init__(self, sampler, batch_size):
        super(SequenceLengthBatchSampler, self).__init__(None)
        self.sampler = sampler
        self.batch_size = batch_size
        # Batch count for __len__; computed by one pass over the sampler.
        # NOTE(review): exact only if the sequence of item lengths the sampler
        # yields is the same every epoch (true for SequenceLengthSampler,
        # which sorts by this same length).
        self.length = len(self._pack())

    def _pack(self):
        """Run one pass over the sampler and return the packed batches."""
        batches = list()
        current_batch = list()
        current_bsz = 0
        for idx in self.sampler:
            length = self.sampler.data_source[idx][0].shape[0]
            # only close a non-empty batch; otherwise an over-budget first
            # item would emit an empty batch that cannot be collated
            if current_batch and current_bsz + length > self.batch_size:
                batches.append(current_batch)
                current_batch, current_bsz = list(), 0
            current_batch.append(idx)
            current_bsz += length
        if current_batch:
            batches.append(current_batch)
        return batches

    def __iter__(self):
        batches = self._pack()
        random.shuffle(batches)
        return iter(batches)

    def __len__(self):
        return self.length


""" data environment: dataset + collator """
class CoNLLDatasetCollection(object):
    """Tokenized dependency-parsing splits plus the collator that batches them.

    Args:
        path_template: glob pattern with a ``{split}`` placeholder. All files
            matching a split's pattern are read, in sorted filename order,
            and concatenated.
        model_name: tokenizer name for ``AutoTokenizer.from_pretrained``;
            ``'criss'`` loads the ``'facebook/mbart-large-cc25'`` tokenizer.
        label_dict_path: JSON file mapping relation names to ids. Required
            when ``collator_state_dict`` is not given.
        encoding: text encoding of the input files.
        position_pad_id: head padding id for a newly built collator.
        label_pad_id: label padding id for a newly built collator.
        collator_state_dict: if given, the collator is restored from it, and
            ``label_dict_path`` and the padding ids are not used.
        mode: ``'train'`` loads train/dev/test. ``'evaluate'`` and
            ``'predict'`` load test only. Any other value raises ``Exception``.
        input_format: ``'conll'`` reads CoNLL-U with ``parse_incr``. Any other
            value reads whitespace-tokenized text, one sentence per line;
            ``'txt'`` is only accepted with ``mode='predict'``.
        n_examples: cap on sentences per split. It applies to every split
            except ``'test'``, and to ``'test'`` as well in predict mode. For
            text input it also caps the lines read per file.
        shuffle: shuffle the train split before it is truncated.
        **kwargs: accepted and ignored.
    """

    def __init__(
                self, path_template, model_name, label_dict_path=None, encoding='utf-8',
                position_pad_id=-1, label_pad_id=-1, collator_state_dict=None, mode='train', input_format='conll', n_examples=50000,
                shuffle=False, **kwargs
            ):
        self.path_template = path_template
        self.model_name = model_name
        self.encoding = encoding
        self.tokenizer = AutoTokenizer.from_pretrained(model_name if model_name != 'criss' else 'facebook/mbart-large-cc25')
        self.dataset = collections.defaultdict()
        self.mode = mode
        self.input_format = input_format
        self.n_examples = n_examples
        self.shuffle = shuffle
        self.splits = self.__get_splits__(mode)
        if input_format == 'txt':
            assert mode == 'predict', 'Text input only available for prediction'
        for split in self.splits:
            path_template_split = self.path_template.format(split=split)
            sents = list()
            # sorted: glob order depends on the filesystem, and the truncation
            # below would otherwise keep a platform-dependent subset
            for filename in sorted(glob(path_template_split)):
                sents += self._read_sentences(filename, input_format, encoding, n_examples)
            if split == 'train' and shuffle:
                random.shuffle(sents)
            if split != 'test' or mode == 'predict':
                sents = sents[:n_examples]
            dataset_cls = CoNLLDataset if input_format == 'conll' else TextDataset
            self.dataset[split] = dataset_cls(self.tokenizer, sents)
        if collator_state_dict is not None:
            self.collator = CoNLLCollator.from_state_dict(collator_state_dict)
        else:
            assert label_dict_path is not None, 'label_dict_path is required without collator_state_dict'
            self.collator = CoNLLCollator.from_json(self.tokenizer.pad_token_id, position_pad_id, label_pad_id, label_dict_path)
        for split in self.splits:
            self.dataset[split].convert(self.collator)

    @staticmethod
    def _read_sentences(filename, input_format, encoding, n_examples):
        """Read one file into a list of ``conllu.TokenList`` sentences.

        CoNLL input is parsed with ``parse_incr``. Text input turns each of the
        first ``n_examples`` lines into a sentence whose tokens have a 1-based
        int ``id``, the word as ``form``, and ``'_'`` in every other field.
        """
        with open(filename, encoding=encoding) as f:
            if input_format == 'conll':
                return list(parse_incr(f))
            sents = list()
            for line_no, line in enumerate(f):
                if line_no == n_examples:
                    break
                words = line.strip().split()
                sents.append(conllu.TokenList(
                    [
                        {
                            'id': j + 1,
                            'form': w,
                            'lemma': '_',
                            'upos': '_',
                            'xpos': '_',
                            'feats': '_',
                            'head': '_',
                            'deprel': '_',
                            'deps': '_',
                            'misc': '_'
                        } for j, w in enumerate(words)
                    ]
                ))
            return sents

    @staticmethod
    def __get_splits__(mode):
        """Return the split names loaded for ``mode``."""
        if mode == 'train':
            return ['train', 'dev', 'test']
        elif mode in ['evaluate', 'predict']:
            return ['test']
        else:
            raise Exception(f'Mode {mode} not supported.')

    def __getitem__(self, index):
        return self.dataset[index]

    @classmethod
    def from_state_dict(cls, state_dict):
        """Rebuild a collection (re-reading its files) from :meth:`state_dict` output."""
        return cls(**state_dict)

    def state_dict(self):
        """Return the constructor arguments needed to rebuild this collection."""
        return {
            'path_template': self.path_template,
            'model_name': self.model_name,
            'encoding': self.encoding,
            'collator_state_dict': self.collator.state_dict(),
            'mode': self.mode,
            'input_format': self.input_format,
            'n_examples': self.n_examples,
            'shuffle': self.shuffle
        }


class SoftDatasetCollection(object):
    """Parallel-text (soft) training splits plus a CoNLL test split.

    Args:
        path_template: glob pattern, optionally containing ``{split}``.
            - For the ``'test'`` split, matching files must end in
              ``conll``/``conllu``.
            - For the other splits, files must end in ``json``/``jsonl`` and
              hold one JSON array per line:
              ``[words, parallel_words, ..., {name: alignment_pairs}]``.
        model_name: tokenizer name for ``AutoTokenizer.from_pretrained``.
        collator_state_dict: state used to build each split's collator:
            ``SoftCollator`` for JSON splits, ``CoNLLCollator`` for test.
        encoding: text encoding of the input files.
        mode: ``'train'`` loads train/dev/test; ``'evaluate'``/``'predict'``
            load test only.
        reverse_lang: swap the two sides of every parallel pair.
        n_examples_train: cap on pairs per non-test split.
        shuffle: shuffle each non-test split before truncation.
    """

    def __init__(
                self, path_template, model_name, collator_state_dict, encoding='utf-8', mode='train', reverse_lang=False,
                n_examples_train=int(1e10), shuffle=False
            ):
        self.path_template = path_template
        self.model_name = model_name
        self.encoding = encoding
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.dataset = collections.defaultdict()
        self._collator = collections.defaultdict()
        self.mode = mode
        self.reverse_lang = reverse_lang
        self.n_examples_train = n_examples_train
        self.shuffle = shuffle
        self.splits = self.__get_splits__(mode)
        for split in self.splits:
            path_template_split = self.path_template.format(split=split)
            if split == 'test':
                sents = list()
                # sorted: glob order depends on the filesystem
                for filename in sorted(glob(path_template_split)):
                    assert filename.endswith('conll') or filename.endswith('conllu')
                    with open(filename, encoding=encoding) as f:
                        sents += list(parse_incr(f))
                self.dataset[split] = CoNLLDataset(self.tokenizer, sents)
                self._collator[split] = CoNLLCollator.from_state_dict(collator_state_dict)
            else:
                sents_info = self._read_parallel(path_template_split, encoding)
                if shuffle:
                    random.shuffle(sents_info)
                sents_info = sents_info[:n_examples_train]
                # unzip into three parallel tuples; zip(*[]) has nothing to unpack
                if sents_info:
                    sents, sents_para, aligns = zip(*sents_info)
                else:
                    sents, sents_para, aligns = (), (), ()
                self.dataset[split] = SoftDataset(self.tokenizer, sents, sents_para, aligns, reverse_lang)
                self._collator[split] = SoftCollator.from_state_dict(collator_state_dict)
            self.dataset[split].convert(self._collator[split])
        # Only 'train' mode loads a 'train' split; evaluate/predict only have
        # 'test', so fall back to its collator instead of raising KeyError.
        self.collator = self._collator['train'] if 'train' in self._collator else self._collator['test']

    @staticmethod
    def _read_parallel(pattern, encoding):
        """Read ``(words, parallel_words, alignment)`` triples from JSON-lines files."""
        sents_info = list()
        for filename in sorted(glob(pattern)):
            assert filename.endswith('json') or filename.endswith('jsonl')
            with open(filename, encoding=encoding) as f:
                for line in f:
                    if not line.strip():
                        continue  # tolerate blank lines
                    item = json.loads(line)
                    keys = list(item[-1].keys())
                    # a single alignment source is used as-is; with several,
                    # the 'inter' entry is taken (presumably the intersection
                    # of directional alignments -- verify against the producer)
                    key = keys[0] if len(keys) == 1 else 'inter'
                    sents_info.append((item[0], item[1], item[-1][key]))
        return sents_info

    @staticmethod
    def __get_splits__(mode):
        """Return the split names loaded for ``mode``."""
        if mode == 'train':
            return ['train', 'dev', 'test']
        elif mode in ['evaluate', 'predict']:
            return ['test']
        else:
            raise Exception(f'Mode {mode} not supported.')

    def __getitem__(self, index):
        return self.dataset[index]

    @classmethod
    def from_state_dict(cls, state_dict):
        """Rebuild a collection (re-reading its files) from :meth:`state_dict` output."""
        return cls(**state_dict)

    def state_dict(self):
        """Return the constructor arguments needed to rebuild this collection."""
        return {
            'path_template': self.path_template,
            'model_name': self.model_name,
            'collator_state_dict': self.collator.state_dict(),
            'encoding': self.encoding,
            'mode': self.mode,
            'reverse_lang': self.reverse_lang,
            'n_examples_train': self.n_examples_train,
            'shuffle': self.shuffle
        }


if __name__ == '__main__':
    # Ad-hoc debugging harness: builds both collections from local paths and
    # drops into an IPython shell for inspection.
    from IPython import embed
    from torch.utils.data import DataLoader

    # hard (CoNLL) dataset collection -- used here to obtain a collator
    datasetc = CoNLLDatasetCollection(
        '../../data/ud-treebanks-v2.2/selected/en/en-{split}.conll', 'xlm-roberta-base', 'metadata/ud26.simp.json'
    )
    collator = datasetc.collator

    # soft (parallel text + alignment) dataset collection, reusing the hard
    # collator's state
    datasetc = SoftDatasetCollection('../../data/wikimatrix/aligns/br-en.json', 'xlm-roberta-base', collator.state_dict())
    # inspect individual converted training items
    for item in datasetc['train']:
        embed(using=False)

    # inspect batches of the soft training split, collated by its SoftCollator
    dataloader = DataLoader(datasetc['train'], batch_size=32, shuffle=False, collate_fn=datasetc.collator)
    for batch in dataloader:
        embed(using=False)
