import torch

from functools import partial
from torch.utils import data
from torchtext.data import Field, TabularDataset
from transformers import BertTokenizer


class Example(object):
    """Attribute bag copied from another example object, also usable like a dict.

    Every entry of the source object's ``__dict__`` becomes an attribute of the
    new instance; the mapping-style helpers below read and write those same
    attributes (missing keys raise ``KeyError`` for ``[]``/``del``).
    """

    def __init__(self, ex):
        source_attrs = ex.__dict__
        for attr_name, attr_value in source_attrs.items():
            setattr(self, attr_name, attr_value)

    def __getitem__(self, k):
        attrs = vars(self)
        return attrs[k]

    def __setitem__(self, k, v):
        setattr(self, k, v)

    def __delitem__(self, k):
        attrs = vars(self)
        del attrs[k]

    def __contains__(self, k):
        attrs = vars(self)
        return k in attrs

    def get(self, k, default=None):
        """Return attribute *k*, or *default* when it is not set."""
        attrs = vars(self)
        return attrs.get(k, default)


class Batch(object):
    """A mini-batch of examples, built by a DataLoader ``collate_fn``.

    The batch behaves like a read-only sequence of its examples. Calling
    :meth:`build` tokenizes the text-valued fields and stores the resulting
    token ids as attributes on the batch itself (so a field name that collides
    with ``tokenizer``/``device`` would overwrite those attributes).
    """

    def __init__(self, examples, tokenizer, fields, device="cpu"):
        # examples: sequence of objects exposing every name in `fields` as an attribute.
        self._examples = examples
        # tokenizer: callable taking a list of field values and returning a
        # mapping with an 'input_ids' entry (e.g. `encode` bound to a tokenizer).
        self.tokenizer = tokenizer
        self._fields = fields
        self.device = device

    def __iter__(self):
        return iter(self._examples)

    def __len__(self):
        return len(self._examples)

    def __getitem__(self, idx):
        return self._examples[idx]

    @staticmethod
    def _is_text_field(value):
        """Return True if *value* is a ``str`` or a non-empty list/tuple starting with a ``str``."""
        if isinstance(value, str):
            return True
        if isinstance(value, (list, tuple)):
            return bool(value) and isinstance(value[0], str)
        # Scalars (ints, floats, None, ...) are not tokenizable: the tokenizer
        # would return no ids for them and clobber the field with `[]`.
        return False

    def build(self, tensor=False):
        """Tokenize every text field and attach the token ids to the batch.

        The field kind is decided from the first example's value (see
        ``_is_text_field``); non-text fields are left untouched.

        Args:
            tensor: when True, pad the ids with ``pad_`` into a tensor on
                ``self.device`` and also set ``<name>_length`` to the
                per-example lengths returned by ``pad_``.
        """
        for name in self._fields:
            if not self._is_text_field(getattr(self._examples[0], name)):
                continue

            values = [getattr(ex, name) for ex in self._examples]
            input_ids = self.tokenizer(values)['input_ids']

            if tensor:
                # `pad` has no `with_length` parameter; `pad_` returns the lengths too.
                input_ids, lengths = pad_(input_ids, with_length=True, device=self.device)
                setattr(self, f"{name}_length", lengths)
            setattr(self, name, input_ids)


def mask(lengths, device="cpu"):
    """Build a 0/1 padding mask from per-example lengths.

    Args:
        lengths: sequence (or 1-D tensor) of sequence lengths, one per example.
        device: device of the returned tensor.

    Returns:
        A ``torch.long`` tensor of shape ``[len(lengths), max(lengths)]`` whose
        entry ``[i, j]`` is 1 when ``j < lengths[i]`` and 0 otherwise. Empty
        ``lengths`` yields an empty ``[0, 0]`` tensor.
    """
    # as_tensor avoids copying (and warning) when `lengths` is already a tensor.
    lengths = torch.as_tensor(lengths, device=device)
    if lengths.numel() == 0:
        return torch.zeros((0, 0), dtype=torch.long, device=device)

    max_length = lengths.max().item()
    # Build the position grid directly on `device` instead of copying it over.
    positions = torch.arange(max_length, device=device)
    # [1, max_len] < [bs, 1] broadcasts to [bs, max_len].
    return (positions.unsqueeze(0) < lengths.unsqueeze(-1)).long()


def isinstanceof(x, types):
    """``isinstance`` that also looks through single-element tensors.

    Returns True when *x* itself is an instance of *types* (so
    ``isinstanceof(t, torch.Tensor)`` is True for a tensor ``t``), or when *x*
    is a one-element tensor whose Python value (``x.item()``) is an instance of
    *types*. Multi-element tensors that are not matched directly return False
    instead of raising from ``.item()``.
    """
    if isinstance(x, types):
        return True
    if isinstance(x, torch.Tensor):
        # .item() is only defined for tensors holding exactly one element.
        return x.numel() == 1 and isinstance(x.item(), types)
    return False


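# Shapes accepted by pad_ (ragged nested lists, padded to the max length per level):
#   [bs, max_seq]          (a flat [bs] list of scalars is also accepted)
#   [bs, max_s1, max_s2]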
def _to_nested_lists(x):
    """Return a fresh copy of *x* with tensors and tuples converted to (nested) lists."""
    if isinstance(x, torch.Tensor):
        return x.tolist()
    if isinstance(x, (list, tuple)):
        return [_to_nested_lists(item) for item in x]
    return x


def _nesting_depth(x):
    """Return the number of list levels in *x* (0 for a scalar; an empty list counts as 1)."""
    if not isinstance(x, list):
        return 0
    return 1 + max((_nesting_depth(item) for item in x), default=0)


def _conforms(x, depth):
    """Return True if *x* is exactly *depth* list levels deep with int/float leaves."""
    if depth == 0:
        return isinstance(x, (int, float))
    return isinstance(x, list) and all(_conforms(item, depth - 1) for item in x)


def pad_(data, padding=0, with_length=True, device="cpu"):
    """Pad ragged nested sequences into a rectangular tensor.

    Supported inputs (lists, tuples or tensors, possibly mixed):
        * ``[bs]`` of scalars          -> tensor ``[bs]``
        * ``[bs, seq]``                -> tensor ``[bs, max_seq]``
        * ``[bs, s1, s2]``             -> tensor ``[bs, max_s1, max_s2]``

    The input is copied, never modified. Every padded position (at both levels
    of a 3-D input) is filled with *padding*.

    Args:
        data: the nested sequence to pad.
        padding: fill value for padded positions.
        with_length: if True, also return the original lengths.
        device: device of the returned tensor.

    Returns:
        The padded tensor or, when ``with_length`` is True, ``(tensor, lengths)``
        where ``lengths`` is ``bs`` (an int) for 1-D input, a list of row
        lengths for 2-D input, and a ``[bs][max_s1]`` nested list of inner
        lengths (0 for padded rows) for 3-D input. If every row is empty the
        tensor has a zero-size last dimension and torch's default float dtype.

    Raises:
        ValueError: if *data* is empty, deeper than 3 levels, ragged in depth,
            or contains non-numeric leaves.
    """
    rows = _to_nested_lists(data)
    if not isinstance(rows, list) or not rows:
        raise ValueError("cannot pad empty or non-sequence data")

    depth = _nesting_depth(rows)
    if not 1 <= depth <= 3 or not _conforms(rows, depth):
        raise ValueError("data shape must in [bs, m_seq] or [bs, m_sq1, m_sq2]")

    if depth == 1:
        out = torch.tensor(rows, device=device)
        return (out, out.size(0)) if with_length else out

    if depth == 2:
        lengths = [len(row) for row in rows]
        width = max(lengths)
        padded = [row + [padding] * (width - len(row)) for row in rows]
    else:
        max_s1 = max(len(row) for row in rows)
        # default=0 covers batches where every row is empty.
        max_s2 = max((len(seq) for row in rows for seq in row), default=0)
        padded, lengths = [], []
        for row in rows:
            missing = max_s1 - len(row)
            inner = [seq + [padding] * (max_s2 - len(seq)) for seq in row]
            # Fresh list per padded row so no two rows alias the same object.
            inner += [[padding] * max_s2 for _ in range(missing)]
            padded.append(inner)
            lengths.append([len(seq) for seq in row] + [0] * missing)

    out = torch.tensor(padded, device=device)
    return (out, lengths) if with_length else out


def pad(data, padding=0, device="cpu"):
    """Pad *data* into a tensor via ``pad_`` and return only the tensor (no lengths)."""
    padded = pad_(data, padding=padding, with_length=False, device=device)
    return padded


def encode(inputs, tokenizer=None):
    """Tokenize a string, a list of strings, or a list of string lists.

    Args:
        inputs: a ``str``, a list/tuple of ``str``, or a list/tuple of
            lists/tuples of ``str``.
        tokenizer: callable returning a mapping with an ``'input_ids'`` entry
            (e.g. a HuggingFace tokenizer). Required.

    Returns:
        ``{'input_ids': ids}`` in every case. For a string or list of strings,
        ``ids`` is the tokenizer's ``'input_ids'`` for the whole input. For a
        list of string lists, it is one tokenizer result per inner list (empty
        inner lists map to ``[]``). Empty or unsupported inputs give
        ``{'input_ids': []}``.

    Raises:
        ValueError: if ``tokenizer`` is None.
    """
    if tokenizer is None:
        raise ValueError("encode() requires a tokenizer")
    if not inputs:
        # Same dict shape as every other path: callers index ['input_ids'].
        return {'input_ids': []}

    if isinstance(inputs, str) or isinstance(inputs[0], str):
        return {'input_ids': tokenizer(inputs)['input_ids']}

    if isinstance(inputs[0], (list, tuple)):
        # Probe the first non-empty inner sequence; inputs[0] itself may be empty.
        probe = next((seq for seq in inputs if seq), None)
        if probe is not None and isinstance(probe[0], str):
            return {'input_ids': [tokenizer(seq)['input_ids'] if seq else [] for seq in inputs]}

    return {'input_ids': []}


def load_data(config):
    """Build torchtext datasets, a BERT tokenizer and DataLoaders from *config*.

    Config keys read here:
        fields: list/tuple of field names, or dict
            ``{json_key: (attr_name, field_args)}``. A list/tuple is normalised
            *in place* into the dict form with empty ``field_args``.
        dataset: optional dict ``{tag: json_path}``, one TabularDataset per tag.
        ptm_name: name/path passed to ``BertTokenizer.from_pretrained`` (required).
        bs: batch size for every tag except ``'dev'``/``'test'`` (those use 1).
        shuffle: optional (default False); never applied to ``'dev'``/``'test'``.
        device: optional (default ``'cpu'``); read each time a batch is collated.

    Returns:
        dict with keys ``"vocab"``, ``"tokenizer"``, ``"dataset"``,
        ``"iterator"`` (tag -> DataLoader yielding ``Batch``) and ``"config"``.

    Raises:
        NotImplementedError: if ``ptm_name`` is missing from *config*.
    """
    if "ptm_name" not in config:
        # Only the BERT-tokenizer pipeline is implemented; fail before parsing any files.
        raise NotImplementedError

    if isinstance(config['fields'], (list, tuple)):
        config['fields'] = {key: (key, {}) for key in config['fields']}

    # NOTE(review): the per-field args (second tuple item) are not passed to
    # Field; every field is loaded raw (no tokenization, no vocab). Confirm intended.
    tt_fields = {
        key: (name, Field(sequential=False, tokenize=None, use_vocab=False))
        for key, (name, _field_args) in config['fields'].items()
    }

    shuffle = config.get("shuffle", False)
    dataset = {
        tag: TabularDataset(path=filename, format="json", fields=tt_fields)
        for tag, filename in config.get("dataset", {}).items()
    }
    # Attribute names under which each example exposes its fields.
    field_names = {name for name, _ in tt_fields.values()}

    tokenizer = BertTokenizer.from_pretrained(config["ptm_name"])
    # Register [unused0]..[unused99] as additional special tokens.
    specials = [f"[unused{i}]" for i in range(0, 100)]
    tokenizer.add_special_tokens({"additional_special_tokens": specials})
    # Produce bare input_ids only: no added special tokens, token_type_ids or attention_mask.
    tokenizer_ = partial(tokenizer, add_special_tokens=False, return_token_type_ids=False, return_attention_mask=False)
    encoder = partial(encode, tokenizer=tokenizer_)

    iterators = {
        tag: data.DataLoader(
            [Example(ex) for ex in dataset[tag].examples],
            batch_size=(1 if tag in ('dev', 'test') else config['bs']),
            shuffle=(False if tag in ('dev', 'test') else shuffle),
            # The lambda reads config['device'] at collate time, so later edits apply.
            collate_fn=lambda b: Batch(b, tokenizer=encoder, fields=field_names, device=config.get('device', 'cpu')),
        )
        for tag in dataset
    }

    return {
        "vocab": tokenizer.vocab,
        "tokenizer": tokenizer,
        "dataset": dataset,
        "iterator": iterators,
        "config": config,
    }
