from transformers import RobertaTokenizer, RobertaModel, AutoModelWithLMHead, AutoTokenizer, Trainer, AutoModel, BertLMHeadModel
from datasets.load import load_dataset, load_from_disk
import torch, os, sys, time, random, json, argparse
from rouge_score.rouge_scorer import RougeScorer

from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.utils.data.distributed import DistributedSampler

class QuestionReferenceDensity(torch.nn.Module):
    """Dual-encoder retriever that scores questions against reference passages.

    Two separately loaded ``facebook/contriever-msmarco`` encoders embed the
    questions and the references. A sentence embedding is the mean of the
    non-padding token embeddings, and similarity is a plain dot product.
    """

    def __init__(self):
        """Load the pre-trained question and reference encoders."""
        super().__init__()
        self.question_encoder = AutoModel.from_pretrained("facebook/contriever-msmarco")
        self.reference_encoder = AutoModel.from_pretrained("facebook/contriever-msmarco")

        total = sum(param.nelement() for param in self.parameters())
        print("Number of parameter: %.2fM" % (total / 1e6))

    def mean_pooling(self, token_embeddings, mask):
        """Average the token embeddings of each sequence, ignoring masked positions.

        Args:
            token_embeddings (torch.Tensor): shape ``(batch_size, seq_len, embed_size)``.
            mask (torch.Tensor): shape ``(batch_size, seq_len)``; non-zero / True
                marks a real token that takes part in the average, zero / False
                marks padding (i.e. the usual ``attention_mask`` convention).

        Returns:
            torch.Tensor: shape ``(batch_size, embed_size)``. A row whose mask is
            entirely zero yields a zero vector instead of NaN.
        """
        token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
        # Clamp the divisor so a fully masked row cannot divide by zero.
        token_counts = mask.sum(dim=1).clamp(min=1)[..., None]
        return token_embeddings.sum(dim=1) / token_counts

    def forward(self, question, pos, neg):
        """Score every question against every positive and every negative reference.

        Each argument is a mapping that is unpacked as keyword arguments into the
        corresponding encoder and must contain ``"attention_mask"`` (it is read
        here for pooling).

        Args:
            question (dict): tokenized questions.
            pos (dict): tokenized positive references, one per question.
            neg (dict): tokenized negative references, one per question.

        Returns:
            tuple: ``(l_pos, l_neg)``. ``l_pos[i, j]`` is the temperature-scaled
            dot product between question ``i`` and positive reference ``j``;
            ``l_neg`` is the same against the negative references. With a
            batch of N triples both are ``(N, N)``.
        """
        q = self.question_encoder(**question)
        r_pos = self.reference_encoder(**pos)
        r_neg = self.reference_encoder(**neg)

        # `args` is the module-level argparse namespace created by the script
        # entry point. Scaling the question embedding once scales both logit
        # matrices by 1 / temperature.
        cls_q = self.mean_pooling(q[0], question["attention_mask"]) / args.temp
        cls_r_pos = self.mean_pooling(r_pos[0], pos["attention_mask"])
        cls_r_neg = self.mean_pooling(r_neg[0], neg["attention_mask"])

        l_pos = torch.matmul(cls_q, torch.transpose(cls_r_pos, 0, 1))
        l_neg = torch.matmul(cls_q, torch.transpose(cls_r_neg, 0, 1))

        return l_pos, l_neg

    @staticmethod
    def loss(l_pos, l_neg):
        """Cross-entropy loss over in-batch positives and negatives.

        The two logit matrices are concatenated along the column axis, and the
        target for row ``i`` is column ``i`` (the question's own positive
        reference). Every other column acts as a negative.

        Args:
            l_pos (torch.Tensor): ``(N, N)`` question-vs-positive logits.
            l_neg (torch.Tensor): ``(N, M)`` question-vs-negative logits.

        Returns:
            torch.Tensor: scalar mean cross-entropy.
        """
        logits = torch.cat([l_pos, l_neg], dim=1)
        # Build the targets on the logits' own device rather than on a global
        # setting, so the loss works wherever the inputs live.
        targets = torch.arange(0, len(l_pos), dtype=torch.long, device=logits.device)
        return torch.nn.functional.cross_entropy(logits, targets)

    @staticmethod
    def num_correct(l_pos, l_neg):
        """Count questions whose own positive outscores their own negative.

        Only the diagonals are compared: ``l_pos[i, i] > l_neg[i, i]``.

        Args:
            l_pos (torch.Tensor): ``(N, N)`` question-vs-positive logits.
            l_neg (torch.Tensor): ``(N, N)`` question-vs-negative logits.

        Returns:
            torch.Tensor: 0-dim integer tensor with the number of correct pairs.
        """
        return (torch.diag(l_pos) > torch.diag(l_neg)).sum()

    @staticmethod
    def acc(l_pos, l_neg):
        """Fraction of questions whose own positive outscores their own negative.

        Args:
            l_pos (torch.Tensor): ``(N, N)`` question-vs-positive logits.
            l_neg (torch.Tensor): ``(N, N)`` question-vs-negative logits.

        Returns:
            torch.Tensor: 0-dim tensor in ``[0, 1]``.
        """
        return (torch.diag(l_pos) > torch.diag(l_neg)).sum() / len(l_pos)


class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR):
    """Linear warmup followed by linear decay, expressed as an LR multiplier.

    Args:
        optimizer: the wrapped optimizer.
        warmup (int): number of warmup steps.
        total (int): step at which the decay phase reaches ``ratio``.
        ratio (float): multiplier reached at step ``total``.
        last_epoch (int): forwarded unchanged to ``LambdaLR``.
    """

    def __init__(self, optimizer, warmup, total, ratio, last_epoch=-1):
        # The attributes must exist before the base constructor runs, because
        # it is handed `self.lr_lambda`, which reads them.
        self.warmup = warmup
        self.total = total
        self.ratio = ratio
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        """Return the multiplier applied to the base learning rate at ``step``.

        For ``step < warmup`` the factor rises linearly from 0 towards
        ``1 - ratio``. From ``step == warmup`` on it starts at 1.0 and falls
        linearly, hitting ``ratio`` at ``step == total`` and continuing down
        until it is floored at 0.0.

        Args:
            step (int): The current step of training.

        Returns:
            float: The multiplicative learning-rate factor.
        """
        warmup_steps = self.warmup
        if step >= warmup_steps:
            decay_span = float(max(1.0, self.total - warmup_steps))
            factor = 1.0 + (self.ratio - 1) * (step - warmup_steps) / decay_span
            return max(0.0, factor)

        return (1 - self.ratio) * step / float(max(1, warmup_steps))


def move_dict_to_device(obj, device):
    """Move every value of a mapping onto ``device``, in place.

    Each value is replaced by the result of its ``.to(device)`` call, so the
    caller's mapping is updated. The same mapping is also returned, which lets
    the call be used inline.

    Args:
        obj (dict): mapping whose values all expose ``.to(device)``
            (e.g. ``torch.Tensor``).
        device (torch.device | str): target device.

    Returns:
        dict: ``obj`` itself, after its values have been moved.

    Raises:
        AttributeError: if a value has no ``.to`` method.
    """
    for key in obj:
        obj[key] = obj[key].to(device)
    return obj

def collate(data):
    """Tokenize a list of samples into three batched, device-resident encodings.

    Relies on the module-level ``tokenizer`` and on ``args.device``.

    Args:
        data (list[dict]): samples, each providing the string fields
            "question", "positive_reference" and "negative_reference".

    Returns:
        tuple: ``(question, positive_reference, negative_reference)``, each the
        tokenizer output for that field (padded, truncated, PyTorch tensors)
        with every tensor moved to ``args.device``.
    """
    encoded_fields = []
    for field in ("question", "positive_reference", "negative_reference"):
        texts = [sample[field] for sample in data]
        encoding = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
        # Move each tensor of this encoding onto the training device.
        for key in encoding:
            encoding[key] = encoding[key].to(args.device)
        encoded_fields.append(encoding)

    return tuple(encoded_fields)

def eval():
    """Print the pairwise accuracy of the global ``model`` on the evaluation set.

    Switches ``model`` to eval mode, runs every batch of ``eval_loader`` with
    gradients disabled, sums ``model.num_correct`` over the batches and prints
    that total divided by ``len(eval_set)``.

    The model is left in eval mode; callers that keep training must switch it
    back. Note that this function shadows the built-in ``eval`` within this
    module.

    Returns:
        None
    """
    model.eval()
    with torch.no_grad():
        correct = sum(
            model.num_correct(*model(batch_q, batch_pos, batch_neg))
            for batch_q, batch_pos, batch_neg in eval_loader
        )
        accuracy = correct / len(eval_set)
        print("EVALUATION, Acc: %10.6f"%(accuracy))
    
def save(name):
    """Write both encoders of the global ``model`` under ``log_dir/name``.

    The question encoder goes to ``log_dir/name/query_encoder`` and the
    reference encoder to ``log_dir/name/reference_encoder``, each through its
    own ``save_pretrained`` call. ``log_dir`` is created first if missing.

    Args:
        name (str): checkpoint name, used as a sub-directory of ``log_dir``.

    Returns:
        None
    """
    os.makedirs(log_dir, exist_ok=True)
    checkpoint_dir = os.path.join(log_dir, name)
    targets = (
        ("query_encoder", model.question_encoder),
        ("reference_encoder", model.reference_encoder),
    )
    for subdir, encoder in targets:
        encoder.save_pretrained(os.path.join(checkpoint_dir, subdir))

def train(max_epoch = 10, eval_step = 200, save_step = 400, print_step = 50):
    """Run the training loop with periodic logging, evaluation and checkpoints.

    Uses the module-level ``model``, ``opt``, ``scheduler`` and
    ``train_loader``. The step counter is global across epochs.

    Args:
        max_epoch (int, optional): number of passes over ``train_loader``. Defaults to 10.
        eval_step (int, optional): call ``eval()`` every this many steps. Defaults to 200.
        save_step (int, optional): save a checkpoint every this many steps. Defaults to 400.
        print_step (int, optional): log loss and accuracy every this many steps. Defaults to 50.

    Returns:
        None
    """
    global_step = 0
    for epoch_idx in range(max_epoch):
        print("EPOCH %d"%epoch_idx)
        for question, positive, negative in train_loader:
            global_step += 1
            # eval() leaves the model in eval mode, so re-enable train mode
            # at the top of every iteration.
            model.train()
            opt.zero_grad()

            logits = model(question, positive, negative)
            batch_loss = model.loss(*logits)

            if global_step % print_step == 0:
                batch_acc = model.acc(*logits)
                print("Step %4d, Loss, Acc: %10.6f, %10.6f"%(global_step, batch_loss, batch_acc))

            batch_loss.backward()
            opt.step()
            scheduler.step()
            model.zero_grad()

            if global_step % eval_step == 0:
                eval()
            if global_step % save_step == 0:
                save("step-%d"%(global_step))

        # End-of-epoch checkpoint, tagged with both step and epoch.
        save("step-%d-epoch-%d"%(global_step, epoch_idx))

if __name__ == "__main__":
    # ---- command-line options -------------------------------------------
    parser = argparse.ArgumentParser()
    # Training-loop cadence.
    parser.add_argument("--max_epoch", type=int, default=3)
    parser.add_argument("--eval_step", type=int, default=40)
    parser.add_argument("--save_step", type=int, default=40)
    parser.add_argument("--print_step", type=int, default=40)
    # Model / optimisation settings.
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--temp", type=float, default=0.05)
    parser.add_argument("--train_batch_size", type=int, default=64)
    parser.add_argument("--eval_batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-6)
    # Learning-rate schedule (see WarmupLinearScheduler).
    parser.add_argument("--warmup", type=int, default=100)
    parser.add_argument("--total", type=int, default=1000)
    parser.add_argument("--ratio", type=float, default=0.0)
    # Paths.
    parser.add_argument("--save_dir", type=str, default="./retriever_runs")
    parser.add_argument("--train_data_dir", type=str, required=True)

    # `args` is read as a module-level global by the model and helpers above.
    args = parser.parse_args()

    # Each run writes into its own timestamped directory (local time).
    run_stamp = time.strftime("%Y%m%d-%H%M%S")
    log_dir = os.path.join(args.save_dir, run_stamp)

    # ---- data -----------------------------------------------------------
    train_set = load_from_disk(os.path.join(args.train_data_dir, "train"))
    eval_set = load_from_disk(os.path.join(args.train_data_dir, "eval"))

    tokenizer = AutoTokenizer.from_pretrained("facebook/contriever-msmarco")
    # NOTE(review): neither loader shuffles; confirm the on-disk training set
    # is already in the intended order.
    train_loader = DataLoader(train_set, batch_size=args.train_batch_size, collate_fn=collate)
    eval_loader = DataLoader(eval_set, batch_size=args.eval_batch_size, collate_fn=collate)

    # ---- model, optimizer, schedule -------------------------------------
    model = QuestionReferenceDensity().to(args.device)
    opt = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)
    scheduler_args = dict(warmup=args.warmup, total=args.total, ratio=args.ratio)
    scheduler = WarmupLinearScheduler(opt, **scheduler_args)
    # Module-level alias of the temperature; not read elsewhere in this file.
    temp = args.temp

    train(
        max_epoch=args.max_epoch,
        eval_step=args.eval_step,
        save_step=args.save_step,
        print_step=args.print_step,
    )

