import random

from torch.utils.data.dataset import Dataset
from transformers import AutoTokenizer
import csv
import torch
import logging
import numpy as np
import json
from functools import lru_cache
from transformers import BertForMaskedLM, BertTokenizer, set_seed
from transformers.pipelines import Pipeline, FillMaskPipeline
from prompt.modeling_bert import BertForMaskedLM_E
from prompt.modeling_bert import BertForMaskedLM_CO
import itertools
import numpy as np


# Module-level logger. NOTE(review): not referenced by the code visible in this
# file, which reports progress with print() instead.
logger = logging.getLogger(__name__)

class pretrainDataset(torch.utils.data.Dataset):
    """Masked-LM dataset built from a JSON list of ``[text, label]`` pairs.

    ``text`` is a sentence that already contains the tokenizer's mask token
    and ``label`` is the word that fills it; every label must tokenize to
    exactly one token. All samples are padded once, at construction time,
    and stored in ``self.encodings``:

    * ``"input_ids"``: right-padded with ``tokenizer.pad_token_id``;
    * ``"labels"``: ``-100`` everywhere except the mask position of each row,
      which holds the label token id (``-100`` is the default
      ``ignore_index`` of ``torch.nn.CrossEntropyLoss``).

    Args:
        tokenizer: Hugging Face tokenizer providing ``batch_encode_plus``,
            ``pad_token_id``, ``mask_token_id`` and ``padding_side``.
        file_path: path to the UTF-8 JSON file.

    Raises:
        ValueError: if the file holds no samples, a label does not tokenize
            to exactly one token, the tokenizer pads on the left, or a sample
            does not contain exactly one mask token.
    """

    def __init__(self, tokenizer, file_path):
        self.tokenizer = tokenizer
        with open(file_path, 'rt', encoding='utf-8') as f:
            dataset = json.load(f)
        print(f"样本数量{len(dataset)}")
        if not dataset:
            raise ValueError(f"{file_path} contains no samples")

        texts = [data[0] for data in dataset]
        labels = [data[1] for data in dataset]
        # Preview at most 5 samples; datasets may be smaller than that.
        n_preview = min(5, len(dataset))
        print("tokenize前")
        for i in range(n_preview):
            print(texts[i], labels[i])

        input_ids = tokenizer.batch_encode_plus(texts)['input_ids']
        label_ids = tokenizer.batch_encode_plus(labels, add_special_tokens=False)['input_ids']
        # Exactly one target token per sample: _collate writes it at the
        # sample's single mask position.
        for label, ids in zip(labels, label_ids):
            if len(ids) != 1:
                raise ValueError(
                    f"label {label!r} tokenizes to {len(ids)} tokens {ids}, expected exactly 1"
                )
        input_ids = [torch.tensor(e, dtype=torch.long) for e in input_ids]
        label_ids = torch.tensor([ids[0] for ids in label_ids], dtype=torch.long)

        print("tokenize后")
        for i in range(n_preview):
            print(input_ids[i], label_ids[i])

        self.encodings = self._collate(input_ids, label_ids)

    def _collate(self, input_ids, label_ids):
        """Pad the tokenized samples and build the MLM label matrix.

        Args:
            input_ids: list of 1-D ``torch.long`` tensors, one per sample.
            label_ids: 1-D ``torch.long`` tensor with one target id per sample.

        Returns:
            dict with ``"input_ids"`` and ``"labels"``, both of shape
            ``(num_samples, max_length)``.

        Raises:
            ValueError: if the tokenizer does not pad on the right, or if a
                sample does not contain exactly one mask token.
        """
        tokenizer = self.tokenizer
        # Only right padding is supported; checked once rather than per sample.
        if tokenizer.padding_side != "right":
            raise ValueError(f"expected right padding, got {tokenizer.padding_side!r}")

        max_length = max(x.size(0) for x in input_ids)
        print("最大长度:", max_length)
        padd_inputs = input_ids[0].new_full([len(input_ids), max_length], tokenizer.pad_token_id)
        for i, ids in enumerate(input_ids):
            padd_inputs[i, : ids.size(0)] = ids

        # One label per row: every row must have exactly one mask token,
        # otherwise labels would be assigned to the wrong samples.
        is_mask = padd_inputs == tokenizer.mask_token_id
        bad_rows = torch.nonzero(is_mask.sum(dim=1) != 1).flatten().tolist()
        if bad_rows:
            raise ValueError(
                f"every sample needs exactly one mask token; rows {bad_rows[:10]} do not"
            )
        padd_labels = torch.full_like(padd_inputs, -100)
        # Boolean-mask assignment fills positions in row-major order, i.e.
        # one position per row in sample order, matching label_ids.
        padd_labels[is_mask] = label_ids

        batch = {
            "input_ids": padd_inputs,
            "labels": padd_labels,
        }
        print("inputs_ids和labels: ")
        for i in range(min(5, len(padd_inputs))):
            print(batch["input_ids"][i], batch["labels"][i])
        return batch

    def __getitem__(self, idx):
        """Return the padded ``input_ids`` / ``labels`` row for sample ``idx``."""
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings['input_ids'])

class pretrainDataset_an(pretrainDataset):
    """Variant of ``pretrainDataset`` whose target positions are marked by the
    token ids in ``_TARGET_IDS`` (1 or 2) instead of the tokenizer's mask token.

    NOTE(review): in the standard BERT vocabularies ids 1 and 2 appear to be
    ``[unused0]`` / ``[unused1]`` -- confirm against the tokenizer in use.
    """

    # Token ids that mark the position to predict in each sample.
    _TARGET_IDS = (1, 2)

    def __init__(self, tokenizer, file_path):
        super().__init__(tokenizer, file_path)

    def _collate(self, input_ids, label_ids):
        """Pad the samples and label the single target position of each row.

        Args:
            input_ids: list of 1-D ``torch.long`` tensors, one per sample.
            label_ids: 1-D ``torch.long`` tensor with one target id per sample.

        Returns:
            dict with ``"input_ids"`` (right-padded with ``pad_token_id``) and
            ``"labels"`` (``-100`` except at each row's target position).

        Raises:
            ValueError: if the tokenizer does not pad on the right, or if a
                sample does not contain exactly one target token.
        """
        tokenizer = self.tokenizer
        # Only right padding is supported; checked once rather than per sample.
        if tokenizer.padding_side != "right":
            raise ValueError(f"expected right padding, got {tokenizer.padding_side!r}")

        max_length = max(x.size(0) for x in input_ids)
        padd_inputs = input_ids[0].new_full([len(input_ids), max_length], tokenizer.pad_token_id)
        for i, ids in enumerate(input_ids):
            padd_inputs[i, : ids.size(0)] = ids

        is_target = torch.zeros_like(padd_inputs, dtype=torch.bool)
        for target_id in self._TARGET_IDS:
            is_target |= padd_inputs == target_id
        # One label per row: a row with 0 or 2+ targets would shift every
        # following label onto the wrong sample.
        bad_rows = torch.nonzero(is_target.sum(dim=1) != 1).flatten().tolist()
        if bad_rows:
            raise ValueError(
                f"every sample needs exactly one target token; rows {bad_rows[:10]} do not"
            )
        padd_labels = torch.full_like(padd_inputs, -100)
        padd_labels[is_target] = label_ids

        batch = {
            "input_ids": padd_inputs,
            "labels": padd_labels,
        }
        print("inputs_ids和labels: ")
        for i in range(min(5, len(padd_inputs))):
            print(batch["input_ids"][i], batch["labels"][i])
        return batch

class tripleDataset(torch.utils.data.Dataset):
    """Masked-LM dataset of simile sentences read from a CSV file.

    After a header row (skipped), each CSV row holds the sentence text
    followed by three integers: the character lengths of the tenor, the
    attribute and the vehicle. :meth:`mask` replaces one of these spans with
    the tokenizer's mask token; training labels are the original token ids
    at the positions where the masked and original sentences differ.

    Args:
        tokenizer: Hugging Face tokenizer providing ``mask_token``,
            ``pad_token_id``, ``padding_side``, ``tokenize`` and
            ``batch_encode_plus``.
        file_path: path to the UTF-8 CSV file.
        train_ratio: underscore-separated probabilities of masking the
            tenor, attribute and vehicle, e.g. ``"0.2_0.3_0.5"``; they must
            sum to 1. Only consulted for ``lang='en'``.
        lang: ``'zh'`` or ``'en'``.

    Raises:
        AssertionError: if the ``train_ratio`` values do not sum to 1.
        ValueError: for an empty file, an unsupported ``lang``, or a sample
            whose original and masked encodings differ in length.
    """

    def __init__(self, tokenizer, file_path, train_ratio, lang='zh'):
        tav = list(map(float, train_ratio.split('_')))
        # Compare with a tolerance: decimal fractions are inexact in binary
        # floating point (0.7 + 0.2 + 0.1 != 1.0). The 1e-8 bound stays within
        # what np.random.choice accepts for its ``p`` argument.
        assert np.isclose(sum(tav), 1.0, rtol=0.0, atol=1e-8), (train_ratio, tav)
        self.tav = tav
        self.tokenizer = tokenizer

        origin_sentences, masked_sentences = [], []
        with open(file_path, "r", encoding="utf-8", newline="") as f:
            rows = csv.reader(f)
            next(rows, None)  # skip the header row
            for text, *lens in rows:
                tenor_len, attribute_len, vehicle_len = (int(n) for n in lens)
                origin, masked = self.mask(lang, text, tenor_len, attribute_len, vehicle_len)
                origin_sentences.append(origin)
                masked_sentences.append(masked)
        if not origin_sentences:
            raise ValueError(f"{file_path} contains no samples")

        # Preview at most 5 samples; datasets may be smaller than that.
        n_preview = min(5, len(origin_sentences))
        print("原始文本和mask后的文本: ")
        for i in range(n_preview):
            print(origin_sentences[i], masked_sentences[i])

        origin_ids = [
            torch.tensor(e, dtype=torch.long)
            for e in tokenizer.batch_encode_plus(origin_sentences)['input_ids']
        ]
        masked_ids = [
            torch.tensor(e, dtype=torch.long)
            for e in tokenizer.batch_encode_plus(masked_sentences)['input_ids']
        ]

        print("tokenize后, 原始文本和mask后的文本: ")
        for i in range(n_preview):
            print(origin_ids[i], masked_ids[i])

        self.encodings = self._collate(tokenizer, masked_ids, origin_ids)

    def mask(self, lang, text, tenor_len, attribute_len, vehicle_len):
        """Mask one simile component of ``text``.

        Args:
            lang: ``'zh'`` or ``'en'``.
            text: the sentence.
            tenor_len, attribute_len, vehicle_len: character lengths of the
                tenor, attribute and vehicle inside ``text``.

        Returns:
            ``(origin, masked)`` strings. For ``'zh'`` the tenor (the leading
            span) is masked, one mask token per character, and ``origin`` is
            ``text``. For ``'en'`` the component chosen with probabilities
            ``self.tav`` is replaced by a single mask token in ``masked`` and
            by its first subword piece in ``origin``.

        Raises:
            ValueError: for an unsupported ``lang``.
        """
        # Drawn before the language dispatch (as before); only 'en' uses it.
        mask_ele = np.random.choice(['t', 'a', 'v'], p=self.tav)
        if lang == 'zh':
            # Template "{tenor}像{vehicle}一样{attribute}。": the tenor starts
            # the sentence. Only the tenor is masked; mask_ele is ignored.
            chars = list(text)
            masked = list(chars)
            masked[:tenor_len] = [self.tokenizer.mask_token] * tenor_len
            return ''.join(chars), ''.join(masked)
        elif lang == 'en':
            # Template: "The {tenor} is as {attribute} as {vehicle}."
            if mask_ele == 't':
                start, span = len("The "), tenor_len
            elif mask_ele == 'a':
                start, span = len("The  is as ") + tenor_len, attribute_len
            elif mask_ele == 'v':
                start, span = len("The  is as  as ") + tenor_len + attribute_len, vehicle_len
            else:
                raise ValueError(f"unexpected mask element {mask_ele!r}")
            # Only the first subword of the target is predicted: the original
            # sentence keeps just that piece, the masked one a single mask.
            first_piece = self.tokenizer.tokenize(text[start:start + span])[0]
            origin = text[:start] + first_piece + text[start + span:]
            masked = text[:start] + self.tokenizer.mask_token + text[start + span:]
            return origin, masked
        else:
            raise ValueError("语言错误")

    def _collate(self, tokenizer, input_ids, origin_ids):
        """Pad masked/original id pairs and derive the MLM labels.

        Args:
            tokenizer: supplies ``pad_token_id`` and ``padding_side``
                (``"right"``, anything else pads on the left).
            input_ids: list of 1-D ``torch.long`` tensors (masked sentences).
            origin_ids: list of 1-D ``torch.long`` tensors (original
                sentences), each as long as its ``input_ids`` counterpart.

        Returns:
            dict with ``"input_ids"`` (padded masked ids) and ``"labels"``
            (original ids where they differ from the input, ``-100`` elsewhere,
            including padding).

        Raises:
            ValueError: if a masked/original pair differs in length.
        """
        max_length = max(x.size(0) for x in input_ids)
        pad_id = tokenizer.pad_token_id
        padd_inputs = input_ids[0].new_full([len(input_ids), max_length], pad_id)
        padd_labels = input_ids[0].new_full([len(input_ids), max_length], pad_id)
        pad_right = tokenizer.padding_side == "right"
        for i, (masked, origin) in enumerate(zip(input_ids, origin_ids)):
            n = masked.size(0)
            if origin.size(0) != n:
                raise ValueError(
                    f"sample {i}: masked sentence has {n} tokens but original has {origin.size(0)}"
                )
            # Right padding fills from the start, left padding from the end.
            cols = slice(0, n) if pad_right else slice(max_length - n, max_length)
            padd_inputs[i, cols] = masked
            padd_labels[i, cols] = origin
        # Positions where the input equals the original (unmasked tokens and
        # padding) are ignored by the loss.
        padd_labels[padd_labels == padd_inputs] = -100

        batch = {
            "input_ids": padd_inputs,
            "labels": padd_labels,
        }
        print("inputs_ids和labels: ")
        for i in range(min(5, len(padd_inputs))):
            print(batch["input_ids"][i], batch["labels"][i])
        return batch

    def __getitem__(self, idx):
        """Return the padded ``input_ids`` / ``labels`` row for sample ``idx``."""
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings['input_ids'])
    
    
import torch, pickle
class POSVocab():
    """Part-of-speech word lists and matching vocabulary masks.

    Loads pickled word collections for adjectives (``'adj'``) and nouns
    (``'noun'``) and, for each, builds a boolean tensor over
    ``tokenizer.vocab`` that is ``True`` for vocabulary entries present in
    that word list (e.g. to restrict fill-mask predictions to one part of
    speech).

    Attributes:
        tokenizer: tokenizer whose ``vocab`` (token -> id) indexes the masks.
        vocab_file: dict mapping ``'adj'`` / ``'noun'`` to the pickle paths.
        words: dict mapping ``'adj'`` / ``'noun'`` to the unpickled objects.
        mask: dict mapping ``'adj'`` / ``'noun'`` to 1-D bool tensors of
            length ``len(tokenizer.vocab)``.

    Security: ``pickle.load`` can execute arbitrary code; only load trusted
    word-list files.
    """

    def __init__(self, tokenizer, adj_file="词表/adj_4800", noun_file="词表/noun_trankit"):
        self.tokenizer = tokenizer
        self.vocab_file = {'adj': adj_file, 'noun': noun_file}
        self.words = {}
        self.mask = {}
        for pos, path in self.vocab_file.items():
            self.words[pos] = self._load_words(path)
            # Reuse the list just loaded instead of unpickling the file again.
            self.mask[pos] = self._mask_from_words(tokenizer, self.words[pos])

    def getVocabMask(self, tokenizer, vocab_file):
        """Return a bool mask over ``tokenizer.vocab`` for the words pickled in ``vocab_file``.

        Entry ``i`` is ``True`` iff the token with id ``i`` is in the word list.
        """
        return self._mask_from_words(tokenizer, self._load_words(vocab_file))

    @staticmethod
    def _load_words(path):
        """Unpickle and return the word collection stored at ``path`` (trusted files only)."""
        with open(path, 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def _mask_from_words(tokenizer, words):
        """Build the bool vocabulary mask for an already-loaded word collection."""
        words = set(words)  # O(1) membership tests in the loop below
        vocab = tokenizer.vocab
        mask = [False] * len(vocab)
        for token, idx in vocab.items():
            if token in words:
                mask[idx] = True
        return torch.tensor(mask)
    
    
from transformers.pipelines import FillMaskPipeline
from itertools import chain
from tqdm import trange, tqdm
import time

class MetaphorGenerate(FillMaskPipeline):
    """Fill-mask pipeline that ranks candidate words for one simile slot.

    For each element tuple passed to :meth:`generate`, a caller-supplied
    pattern function builds a sentence containing the mask token, the model
    produces logits at the mask position(s), and a caller-supplied scoring
    function turns those logits into per-token scores.

    Args:
        model: masked-LM model; moved to ``device`` on construction.
        tokenizer: matching tokenizer.
        top_k: number of best candidates decoded per element.
        pos_vocab: optional object with a ``mask`` dict of boolean vocabulary
            tensors keyed by ``'noun'`` / ``'adj'`` (such as ``POSVocab``).
            Any falsy value (the default ``False``, or ``None``) disables
            vocabulary filtering.
        device: device the model is moved to.
    """

    def __init__(self, model, tokenizer, top_k=50, pos_vocab=False, device=torch.device("cuda:0")):
        super().__init__(model, tokenizer)
        self.top_k = top_k
        self.mask_token = tokenizer.mask_token
        self.pos_vocab = pos_vocab
        self.device = device
        self.model.to(self.device)

    def generate(self, eles, score_fnt, pattern_fnt):
        """Score the vocabulary for the missing slot of every element.

        Args:
            eles: iterable of element tuples. ``ele[1] is None`` selects the
                noun vocabulary (mode ``'noun'``); otherwise ``ele[2] is None``
                selects ``'adj'``.
            score_fnt: called as ``score_fnt(logits, ele=, pos_vocab=,
                weight=, mode=)``; returns ``(score, extra)`` where ``score``
                supports ``.topk`` and ``extra``, if not None, is printed at
                the top-k columns.
            pattern_fnt: called as ``pattern_fnt(ele, mask_token)``; returns
                ``(sentence, weight)``.

        Returns:
            ``(origin_scores, topk_tokens)``: the full score tensor and the
            decoded top-k tokens for each element, in input order.

        Raises:
            ValueError: if neither ``ele[1]`` nor ``ele[2]`` is None.
        """
        origin_scores, topk_tokens = [], []
        for ele in tqdm(eles):
            # Which slot is being predicted decides the vocabulary filter.
            if ele[1] is None:
                mode = 'noun'
            elif ele[2] is None:
                mode = 'adj'
            else:
                raise ValueError(f"expected ele[1] or ele[2] to be None, got {ele!r}")

            # Build the prompt and run the model.
            # NOTE(review): _parse_and_tokenize / _forward(..., return_tensors=True)
            # follow an older transformers FillMaskPipeline API -- confirm
            # against the pinned transformers version.
            sentence, weight = pattern_fnt(ele, self.tokenizer.mask_token)
            inputs = self._parse_and_tokenize(sentence)
            outputs = self._forward(inputs, return_tensors=True)

            # Logits at the mask position(s).
            masked_index = torch.nonzero(inputs["input_ids"] == self.tokenizer.mask_token_id, as_tuple=True)
            logits = outputs[masked_index[0], masked_index[1], :]

            # Suppress tokens outside the part-of-speech vocabulary. The mask
            # is moved to the logits' device (the model may run on GPU).
            if self.pos_vocab:
                allowed = self.pos_vocab.mask[mode].to(logits.device)
                logits.masked_fill_(~allowed, -1e9)

            score, extra = score_fnt(logits, ele=ele, pos_vocab=self.pos_vocab, weight=weight, mode=mode)
            _, top_ids = score.topk(self.top_k)
            if extra is not None:
                print(extra[:, top_ids])
            # TODO: subword tokens are handled incorrectly here. Presumably
            # decode() merges "##" pieces into the preceding token, so the
            # split may not yield one entry per id -- verify.
            tokens = self.tokenizer.decode(top_ids).split(' ')

            origin_scores.append(score)
            topk_tokens.append(tokens)

        return origin_scores, topk_tokens

    
def getMT(type, model_name):
    """Load a masked-LM model and its tokenizer for the given variant.

    Args:
        type: model variant. ``'rulebased'``, ``'amod'`` and ``'bert'`` load
            a plain ``BertForMaskedLM``; ``'只调整embedding'`` ("only adjust
            the embedding") loads ``BertForMaskedLM_E``; ``'外部unusedtoken'``
            ("external unused tokens") loads ``BertForMaskedLM_CO`` and
            registers ``[unused0]`` .. ``[unused99]`` as additional special
            tokens on the tokenizer.
        model_name: checkpoint name or directory passed to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: if ``type`` is not one of the variants above.
    """
    if type in ('rulebased', 'amod', 'bert'):
        model = BertForMaskedLM.from_pretrained(model_name)
        tokenizer = BertTokenizer.from_pretrained(model_name)
    elif type == '只调整embedding':
        model = BertForMaskedLM_E.from_pretrained(model_name)
        tokenizer = BertTokenizer.from_pretrained(model_name)
    elif type == '外部unusedtoken':
        model = BertForMaskedLM_CO.from_pretrained(model_name)
        unused_tokens = [f"[unused{i}]" for i in range(100)]
        tokenizer = BertTokenizer.from_pretrained(model_name, additional_special_tokens=unused_tokens)
    else:
        # Previously fell through and failed with UnboundLocalError on return.
        raise ValueError(f"unknown model type: {type!r}")
    print(type, model_name)
    return model, tokenizer


class Glove(torch.nn.Module):
    """Word-embedding lookup that reports pairwise cosine similarities.

    Args:
        weight: 2-D float tensor of embedding vectors, row ``i`` for word id
            ``i`` (wrapped with ``torch.nn.Embedding.from_pretrained``, which
            freezes it by default).
        word2id: mapping from word to row index in ``weight``.
        id2word: reverse mapping; stored but not used by this class.
    """

    def __init__(self, weight, word2id, id2word):
        super().__init__()
        self.embedding = torch.nn.Embedding.from_pretrained(weight)
        self.word2id = word2id
        self.id2word = id2word

    def forward(self, x):
        """Return pairwise cosine similarities between the columns of ``x``.

        Args:
            x: integer tensor of word ids of shape ``(batch, n)``, e.g. the
                output of :meth:`encoding`.

        Returns:
            dict mapping every index pair ``(i, j)`` with ``i < j`` to a
            ``(batch,)`` tensor of cosine similarities between the embeddings
            of ``x[:, i]`` and ``x[:, j]``.
        """
        embed = self.embedding(x)
        return {
            (i, j): torch.cosine_similarity(embed[:, i], embed[:, j], dim=1)
            for i, j in itertools.combinations(range(embed.shape[1]), 2)
        }

    def encoding(self, triples):
        """Map each word of each row to its id in ``word2id``.

        Only the last space-separated piece of every entry is looked up, so
        a phrase such as ``"a big cat"`` is encoded as ``"cat"``.

        Args:
            triples: sequence of equal-length sequences of strings.

        Returns:
            ``torch.long`` tensor of shape ``(len(triples), n)``.

        Raises:
            KeyError: if a (last) word is missing from ``word2id``.
        """
        ids = [[self.word2id[word.split(' ')[-1]] for word in words] for words in triples]
        # Explicit integer dtype: going through np.array() produced a float64
        # tensor for empty input and a platform-dependent int width otherwise.
        return torch.tensor(ids, dtype=torch.long)

if __name__ == '__main__':
    # Smoke test: build the pre-training dataset from an already-masked file
    # with bert-large-uncased, cached under ./pretrained_models/.
    model_name = "bert-large-uncased"
    cache_file = "./pretrained_models/bert-large-uncased"
    hub_kwargs = {"cache_dir": cache_file}
    model = BertForMaskedLM.from_pretrained(model_name, **hub_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_name, **hub_kwargs)

    examples = pretrainDataset(tokenizer, "masked_text.txt")
    print(len(examples))
