# Copyright 2021 Reranker Author. All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import pickle
from typing import Optional, Dict
import os
import logging

import torch
from torch import nn
import torch.distributed as dist
from torch.nn import functional as F2
from transformers import PreTrainedModel, AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPooling
from .arguments import ModelArguments, DataArguments, \
    RerankerTrainingArguments as TrainingArguments

logger = logging.getLogger(__name__)
class FeedForward(nn.Module):
    """Two-layer perceptron that maps a feature vector to a single scalar.

    The network is ``Linear(dim, hidden_dim) -> Tanh -> Linear(hidden_dim, 1)``
    and is applied over the last dimension of the input.
    """

    def __init__(self, dim, hidden_dim):
        """Build the scoring network.

        Args:
            dim: Size of the last dimension of the inputs.
            hidden_dim: Width of the intermediate projection.
        """
        super().__init__()
        # The layer order determines both the parameter names in the
        # state dict ("net.0.*", "net.2.*") and the initialisation order.
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the scores for ``x``; the last dimension becomes size 1."""
        scores = self.net(x)
        return scores

class MatRank(nn.Module):
    """Pairwise "matrix" reranker on top of a Hugging Face encoder.

    Every sequence in a batch is encoded, the hidden state at position 0 is
    taken as its representation, and the representations are regrouped into
    groups of ``data_args.train_group_size``. Inside a group every ordered
    pair ``(i, j)`` is scored by an MLP applied to ``[rep_i ; rep_j]``; the
    column sums and the negated row sums of that score matrix (diagonal
    removed) are used as two sets of per-candidate logits.
    """

    def __init__(self, hf_model: PreTrainedModel, model_args: ModelArguments, data_args: DataArguments,
                 train_args: TrainingArguments, config):
        """Wrap ``hf_model`` and create the pair-scoring head.

        Args:
            hf_model: Encoder that returns ``last_hidden_state``.
            model_args: Stored on the instance; not read by this class.
            data_args: Provides ``train_group_size``.
            train_args: Provides ``local_rank``.
            config: Encoder config; ``hidden_size`` sizes the scoring head.
        """
        super().__init__()
        self.hf_model = hf_model
        self.model_args = model_args
        self.train_args = train_args
        self.data_args = data_args
        self.config = config
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        # The head consumes the concatenation of two representations.
        self.mlp = FeedForward(config.hidden_size * 2, config.hidden_size)
        if train_args.local_rank >= 0:
            self.world_size = dist.get_world_size()

    def forward(self, batch: Dict):
        """Score every candidate of every group in ``batch``.

        Args:
            batch: Keyword inputs for the encoder. An optional ``'labels'``
                entry is removed from the dict (in place) before encoding
                and is required in training mode.

        Returns:
            A ``SequenceClassifierOutput`` whose ``logits`` have shape
            ``(num_groups, 2 * group_size)``: the softmax over column sums
            followed by the softmax over negated row sums. ``loss`` is set
            only in training mode.

        Raises:
            KeyError: In training mode when ``batch`` has no ``'labels'``.
        """
        # Labels are never an encoder input, so strip them in eval mode too;
        # otherwise they would be forwarded to the encoder as a keyword.
        tgt = batch.pop('labels', None)
        if self.training and tgt is None:
            raise KeyError('labels')

        # NOTE(review): train_group_size is used in eval mode as well --
        # confirm that evaluation batches are grouped the same way.
        group_size = self.data_args.train_group_size

        sent_out: BaseModelOutputWithPooling = self.hf_model(**batch, return_dict=True)
        reps = sent_out.last_hidden_state[:, 0]
        reps = reps.contiguous().view(-1, group_size, self.config.hidden_size)

        # pair[b, i, j] = [reps[b, i] ; reps[b, j]]
        left = reps.unsqueeze(2).expand(-1, -1, group_size, -1)
        right = reps.unsqueeze(1).expand(-1, group_size, -1, -1)
        reps_pair = torch.cat([left, right], dim=-1)
        reps_score = self.mlp(reps_pair).squeeze(dim=-1)

        # Remove the self-pair scores (the diagonal) before aggregating.
        reps_score = reps_score - torch.diag_embed(reps_score.diagonal(dim1=1, dim2=2))

        logits_sum1 = torch.sum(reps_score, dim=1)  # column sums
        logits_sum2 = torch.sum(reps_score, dim=2)  # row sums
        # An explicit dim avoids the deprecated implicit-dimension softmax.
        score1 = F2.softmax(logits_sum1, dim=-1)
        score2 = F2.softmax(-logits_sum2, dim=-1)
        score = torch.cat([score1, score2], dim=-1)

        if not self.training:
            return SequenceClassifierOutput(
                logits=score
            )

        # Infer the number of groups from the labels themselves so that a
        # short final batch does not break the reshape. The first label of
        # each group is used as the class index (presumably the position of
        # the positive candidate -- verify against the data collator).
        tgt = tgt.view(-1, group_size)[:, 0]
        loss1 = self.cross_entropy(logits_sum1, tgt)
        loss2 = self.cross_entropy(-logits_sum2, tgt)
        loss = loss1 + loss2
        return SequenceClassifierOutput(
            loss=loss,
            logits=score
        )

    @classmethod
    def from_pretrained(
            cls, model_args: ModelArguments, data_args: DataArguments, train_args: TrainingArguments,
            *args, **kwargs
    ):
        """Build a reranker from a pretrained encoder.

        ``*args`` and ``**kwargs`` are forwarded to
        ``AutoModel.from_pretrained``. If the model path is a directory that
        contains ``model.pt``, those weights are loaded on top
        (non-strictly).

        Args:
            model_args: Passed to the constructor.
            data_args: Passed to the constructor.
            train_args: Passed to the constructor.
            *args: Positional arguments for ``AutoModel.from_pretrained``;
                the first one is the model path.
            **kwargs: Keyword arguments for ``AutoModel.from_pretrained``.
                ``config`` is used when given, otherwise the config of the
                loaded encoder is used.

        Returns:
            The constructed reranker.
        """
        hf_model = AutoModel.from_pretrained(*args, **kwargs)
        path = args[0] if args else kwargs.get('pretrained_model_name_or_path')
        config = kwargs.get('config')
        if config is None:
            config = hf_model.config
        reranker = cls(hf_model, model_args, data_args, train_args, config)
        if path is not None:
            extra_weights = os.path.join(path, 'model.pt')
            if os.path.exists(extra_weights):
                logger.info('loading extra weights from local files')
                # NOTE: torch.load may unpickle arbitrary objects; only load
                # checkpoints from trusted directories.
                model_dict = torch.load(extra_weights, map_location="cpu")
                reranker.load_state_dict(model_dict, strict=False)
        return reranker

    def save_pretrained(self, output_dir: str):
        """Save the encoder, the extra weights and the run arguments.

        The encoder is written through its own ``save_pretrained``; every
        state-dict entry whose key does not start with ``hf_model`` goes to
        ``model.pt``, and the three argument objects go to ``args.pt``.

        Args:
            output_dir: Directory to write into.
        """
        self.hf_model.save_pretrained(output_dir)
        model_dict = self.state_dict()
        # Encoder weights were already written above; keep only the rest.
        for key in [k for k in model_dict if k.startswith('hf_model')]:
            del model_dict[key]
        torch.save(model_dict, os.path.join(output_dir, 'model.pt'))
        torch.save([self.data_args, self.model_args, self.train_args], os.path.join(output_dir, 'args.pt'))

    def dist_gather_tensor(self, t: Optional[torch.Tensor]):
        """Gather ``t`` from all processes and concatenate along dim 0.

        Args:
            t: Tensor to gather, or ``None``.

        Returns:
            ``None`` if ``t`` is ``None``; ``t`` itself when no process
            group is initialised; otherwise the concatenation of every
            process's tensor.
        """
        if t is None:
            return None
        if not (dist.is_available() and dist.is_initialized()):
            # A single process has nothing to gather.
            return t

        all_tensors = [torch.empty_like(t) for _ in range(dist.get_world_size())]
        dist.all_gather(all_tensors, t)
        # Put the local tensor back in its own slot. The slot is the global
        # rank; local_rank is only the index within one node, which is the
        # wrong slot in multi-node runs.
        all_tensors[dist.get_rank()] = t
        return torch.cat(all_tensors, dim=0)

class MatRankSplit(nn.Module):
    """Pairwise "matrix" reranker over candidates split into several parts.

    The batch holds ``group_size * part_num`` rows per group, i.e.
    ``part_num`` consecutive rows per candidate. A 0/1 vector ``part`` marks
    which rows are real; only those rows are encoded. Every ordered pair of
    rows inside a group is scored by a linear layer on the concatenated
    representations, pairs between two parts of the same candidate are
    masked out, and the per-row aggregates are reduced to one score per
    candidate with a max (column sums) or min (row sums) over its parts.
    """

    def __init__(self, hf_model: PreTrainedModel, model_args: ModelArguments, data_args: DataArguments,
                 train_args: TrainingArguments, config):
        """Wrap ``hf_model`` and create the pair-scoring layer.

        Args:
            hf_model: Encoder that returns ``last_hidden_state``.
            model_args: Stored on the instance; not read by this class.
            data_args: Provides the group sizes, part counts and
                ``max_part_num``.
            train_args: Provides ``local_rank``.
            config: Encoder config; ``hidden_size`` sizes the scoring layer.
        """
        super().__init__()
        self.hf_model = hf_model
        self.model_args = model_args
        self.train_args = train_args
        self.data_args = data_args
        self.config = config
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        # The layer consumes the concatenation of two representations.
        self.mlp = nn.Linear(config.hidden_size * 2, 1)
        self._keys_to_ignore_on_save = []
        self._keys_to_ignore_on_load_missing = []
        if train_args.local_rank >= 0:
            self.world_size = dist.get_world_size()
        self.max_part_num = data_args.max_part_num

    def forward(self, batch):
        """Score every candidate of every group in ``batch``.

        Args:
            batch: Dict with ``'input_ids'``, ``'attention_mask'``,
                optionally ``'token_type_ids'``, a 1-D 0/1 tensor ``'part'``
                with one entry per row, and (required in training)
                ``'labels'``. ``'labels'`` and ``'part'`` are removed from
                the dict, and in training ``part`` may be modified in place.

        Returns:
            A ``SequenceClassifierOutput`` whose ``logits`` have shape
            ``(num_groups, 2 * group_size)``: the per-candidate max of the
            column sums followed by the per-candidate min of the row sums.
            ``loss`` is set only in training mode.

        Raises:
            KeyError: In training mode when ``batch`` has no ``'labels'``.
        """
        if self.training:
            group_size = self.data_args.train_group_size
            part_num = self.data_args.part_num
        else:
            group_size = self.data_args.eval_group_size
            part_num = self.data_args.eval_part_num
        rows_per_group = group_size * part_num
        hidden_size = self.config.hidden_size

        tgt = batch.pop('labels', None)
        if self.training and tgt is None:
            # Fail before the encoder pass instead of after it.
            raise KeyError('labels')
        part = batch.pop('part')

        if self.training:
            # Cap the number of encoded rows: drop the last part of every
            # candidate, then the one before it, and so on, until at most
            # max_part_num rows remain. The index bound guarantees
            # termination once every part position has been cleared.
            start_idx = part_num - 1
            while start_idx >= 0 and torch.sum(part) > self.max_part_num:
                part[start_idx::part_num] = 0
                start_idx -= 1

        # Encode only the rows that are marked as present.
        active_rows = (part == 1).nonzero().view(-1)
        batch_select = {
            'input_ids': batch['input_ids'][active_rows],
            'attention_mask': batch['attention_mask'][active_rows],
        }
        if 'token_type_ids' in batch:
            batch_select['token_type_ids'] = batch['token_type_ids'][active_rows]
        sent_out: BaseModelOutputWithPooling = self.hf_model(**batch_select, return_dict=True)
        reps = sent_out.last_hidden_state[:, 0]

        # Scatter the encoded rows back to their original positions; rows
        # that were skipped stay all-zero. The buffer takes dtype and device
        # from the encoder output so half-precision encoders work too.
        reps_full = reps.new_zeros(part.shape[0], hidden_size)
        reps_full[active_rows] = reps

        part = part.view(-1, rows_per_group)
        reps_full = reps_full.contiguous().view(-1, rows_per_group, hidden_size)

        # pair[b, i, j] = [reps_full[b, i] ; reps_full[b, j]]
        left = reps_full.unsqueeze(2).expand(-1, -1, rows_per_group, -1)
        right = reps_full.unsqueeze(1).expand(-1, rows_per_group, -1, -1)
        reps_pair = torch.cat([left, right], dim=-1)
        logits = self.mlp(reps_pair).squeeze(dim=-1)

        # Keep a pair only if both rows are present ...
        logit_mask = part.unsqueeze(1) * part.unsqueeze(2)
        # ... and the two rows belong to different candidates.
        for i in range(group_size):
            lo, hi = i * part_num, (i + 1) * part_num
            logit_mask[:, lo:hi, lo:hi] = 0
        logits_mtx = logits * logit_mask

        # Column sums; absent rows are pushed to -1e6 so the max skips them.
        logits_sum1 = torch.sum(logits_mtx, dim=1)
        logits_sum1 = logits_sum1 + (part - 1) * 1e6
        # Row sums; absent rows are pushed to +1e6 so the min skips them.
        logits_sum2 = torch.sum(logits_mtx, dim=2)
        logits_sum2 = logits_sum2 + (1 - part) * 1e6

        # Reduce the parts of each candidate to a single score.
        logits_1 = logits_sum1.reshape(-1, group_size, part_num).max(-1).values
        logits_2 = logits_sum2.reshape(-1, group_size, part_num).min(-1).values
        score = torch.cat([logits_1, logits_2], dim=-1)

        if not self.training:
            return SequenceClassifierOutput(
                logits=score
            )

        # Infer the number of groups from the labels themselves so that a
        # short final batch does not break the reshape. The first label of
        # each group is used as the class index (presumably the position of
        # the positive candidate -- verify against the data collator).
        tgt = tgt.view(-1, rows_per_group)[:, 0]
        loss1 = self.cross_entropy(logits_1, tgt)
        loss2 = self.cross_entropy(-logits_2, tgt)
        loss = loss1 + loss2
        return SequenceClassifierOutput(
            loss=loss,
            logits=score
        )

    @classmethod
    def from_pretrained(
            cls, model_args: ModelArguments, data_args: DataArguments, train_args: TrainingArguments,
            *args, **kwargs
    ):
        """Build a reranker from a pretrained encoder.

        ``*args`` and ``**kwargs`` are forwarded to
        ``AutoModel.from_pretrained``. If the model path is a directory that
        contains ``model.pt``, those weights are loaded on top
        (non-strictly).

        Args:
            model_args: Passed to the constructor.
            data_args: Passed to the constructor.
            train_args: Passed to the constructor.
            *args: Positional arguments for ``AutoModel.from_pretrained``;
                the first one is the model path.
            **kwargs: Keyword arguments for ``AutoModel.from_pretrained``.
                ``config`` is used when given, otherwise the config of the
                loaded encoder is used.

        Returns:
            The constructed reranker.
        """
        hf_model = AutoModel.from_pretrained(*args, **kwargs)
        path = args[0] if args else kwargs.get('pretrained_model_name_or_path')
        config = kwargs.get('config')
        if config is None:
            config = hf_model.config
        reranker = cls(hf_model, model_args, data_args, train_args, config)
        if path is not None:
            extra_weights = os.path.join(path, 'model.pt')
            if os.path.exists(extra_weights):
                logger.info('loading extra weights from local files')
                # NOTE: torch.load may unpickle arbitrary objects; only load
                # checkpoints from trusted directories.
                model_dict = torch.load(extra_weights, map_location="cpu")
                reranker.load_state_dict(model_dict, strict=False)
        return reranker

    def save_pretrained(self, output_dir: str):
        """Save the encoder, the extra weights and the run arguments.

        The encoder is written through its own ``save_pretrained``; every
        state-dict entry whose key does not start with ``hf_model`` goes to
        ``model.pt``, and the three argument objects go to ``args.pt``.

        Args:
            output_dir: Directory to write into.
        """
        self.hf_model.save_pretrained(output_dir)
        model_dict = self.state_dict()
        # Encoder weights were already written above; keep only the rest.
        for key in [k for k in model_dict if k.startswith('hf_model')]:
            del model_dict[key]
        torch.save(model_dict, os.path.join(output_dir, 'model.pt'))
        torch.save([self.data_args, self.model_args, self.train_args], os.path.join(output_dir, 'args.pt'))

    def dist_gather_tensor(self, t: Optional[torch.Tensor]):
        """Gather ``t`` from all processes and concatenate along dim 0.

        Args:
            t: Tensor to gather, or ``None``.

        Returns:
            ``None`` if ``t`` is ``None``; ``t`` itself when no process
            group is initialised; otherwise the concatenation of every
            process's tensor.
        """
        if t is None:
            return None
        if not (dist.is_available() and dist.is_initialized()):
            # A single process has nothing to gather.
            return t

        all_tensors = [torch.empty_like(t) for _ in range(dist.get_world_size())]
        dist.all_gather(all_tensors, t)
        # Put the local tensor back in its own slot. The slot is the global
        # rank; local_rank is only the index within one node, which is the
        # wrong slot in multi-node runs.
        all_tensors[dist.get_rank()] = t
        return torch.cat(all_tensors, dim=0)
