import os
import copy
import torch
import math
import torch.nn as nn
from scipy.stats import kendalltau, spearmanr
import torch.nn.functional as F

from itertools import *
from operator import add

from dataclasses import dataclass
from typing import Optional, Dict, Tuple
from torch import Tensor
from transformers import (
    AutoModel,
    PreTrainedModel,
)
from transformers.modeling_outputs import ModelOutput

from config import Arguments
from logger_config import logger
from utils import dist_gather_tensor, select_grouped_indices, full_contrastive_scores_and_labels

# Names of the teacher-assistant (TA) models. The same strings are used as keys
# into the `query` batch dict when the per-TA score tensors are looked up.
ta_models_name = ['cotmae', 'retromae', 'simlm', 'm2dpr']

# Every non-empty subset of the TA models as a tuple, ordered by subset size
# (4 singles, 6 pairs, 4 triples, 1 quadruple). Built with a comprehension so
# that no loop index is left behind in the module namespace.
combination_lists = [
    combo
    for subset_size in range(1, len(ta_models_name) + 1)
    for combo in combinations(ta_models_name, subset_size)
]


@dataclass
class BiencoderOutput(ModelOutput):
    """Result container returned by the ``forward`` methods of the bi-encoder models in this file.

    Every field defaults to ``None``; ``BiencoderModelForInference.forward`` only
    fills in ``q_reps`` and ``p_reps``.
    """
    # Query embeddings as produced by ``_encode`` on the query encoder.
    q_reps: Optional[Tensor] = None
    # Passage embeddings as produced by ``_encode`` on the passage encoder.
    p_reps: Optional[Tensor] = None
    # Training loss computed in ``BiencoderModel.forward``.
    loss: Optional[Tensor] = None
    # Labels for the local queries, as returned by ``_compute_scores``.
    labels: Optional[Tensor] = None
    # Scores of the local queries, truncated to the first
    # ``world_size * n_queries * train_n_passages`` columns.
    scores: Optional[Tensor] = None


class BiencoderModel(nn.Module):
    def __init__(self, args: Arguments,
                 lm_q: PreTrainedModel,
                 lm_p: PreTrainedModel):
        """Store the two encoders and create the loss functions and the optional pooler.

        Args:
            args: run configuration. ``add_pooler`` and ``out_dimension`` are read
                here; the remaining fields are read by the other methods.
            lm_q: encoder applied to queries.
            lm_p: encoder applied to passages (``build`` may pass the same object
                as ``lm_q`` when the encoder is shared).
        """
        super().__init__()
        self.lm_q = lm_q
        self.lm_p = lm_p
        self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
        # log_target=True: both the input and the target handed to this loss are log-probabilities.
        self.kl_loss_fn = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
        self.args = args
        # Linear projection of size hidden_size -> out_dimension, or a no-op when no pooler is requested.
        self.pooler = nn.Linear(self.lm_q.config.hidden_size, args.out_dimension) if args.add_pooler else nn.Identity()

        # Imported inside the constructor rather than at module top —
        # presumably to avoid a circular import with `trainers`; TODO confirm.
        from trainers import BiencoderTrainer
        # Starts as None; other methods read `self.trainer.state`, so an owner is
        # expected to assign it before the model is used for training.
        self.trainer: Optional[BiencoderTrainer] = None

    def forward(self, query: Dict[str, Tensor] = None,
                passage: Dict[str, Tensor] = None):
        """Encode a batch, score queries against passages and compute the training loss.

        The loss is assembled from up to three parts, driven by ``self.args``:

        * a contrastive cross-entropy term (always present);
        * when ``do_kd_biencoder`` is set, a KL term between the student's scores
          over each query's own passage group and ``query['kd_labels']``, with the
          cross-entropy term weighted by ``kd_cont_loss_weight``;
        * when ``do_multi_kd`` is set, a KL term against the teacher-assistant
          combination picked by ``selection_method`` (``'KL'``, ``'KT'``, ``'SP'``;
          any other value selects by RBO).

        Args:
            query: tokenized queries plus the extra score tensors (``'kd_labels'``
                and one entry per teacher-assistant name) needed by the enabled
                distillation terms.
            passage: tokenized passages.

        Returns:
            ``BiencoderOutput`` with ``loss``, ``q_reps``, ``p_reps``, ``labels`` and
            ``scores`` (truncated to ``world_size * n_queries * train_n_passages``
            columns).
        """
        assert self.args.process_index >= 0

        scores, labels, q_reps, p_reps, all_scores, all_labels = self._compute_scores(query, passage)

        start = self.args.process_index * q_reps.shape[0]

        group_indices = select_grouped_indices(scores=scores,
                                               group_size=self.args.train_n_passages,
                                               start=start * self.args.train_n_passages)

        # The student's log-distribution over each query's passage group is needed by
        # BOTH distillation flavours. Computing it whenever either one is enabled
        # avoids an UnboundLocalError when only `do_multi_kd` is switched on.
        group_log_scores = None
        if self.args.do_kd_biencoder or self.args.do_multi_kd:
            group_scores = torch.gather(input=scores, dim=1, index=group_indices)
            assert group_scores.shape[1] == self.args.train_n_passages
            group_log_scores = torch.log_softmax(group_scores, dim=-1)

        if not self.args.do_kd_biencoder:
            loss = self._contrastive_loss(scores, labels, all_scores, all_labels)
        else:
            kd_log_target = torch.log_softmax(query['kd_labels'], dim=-1)
            kd_loss = self.kl_loss_fn(input=group_log_scores, target=kd_log_target)

            if self.training and self.args.kd_mask_hn:
                # Mask every group member except the first one out of the contrastive term.
                scores = torch.scatter(input=scores, dim=1, index=group_indices[:, 1:], value=float('-inf'))
            ce_loss = self._contrastive_loss(scores, labels, all_scores, all_labels)

            loss = (self.args.kd_cont_loss_weight * ce_loss + kd_loss)

        if self.args.do_multi_kd:
            loss = loss + self._multi_kd_loss(query, group_log_scores)

        total_n_psg = self.args.world_size * q_reps.shape[0] * self.args.train_n_passages

        return BiencoderOutput(loss=loss, q_reps=q_reps, p_reps=p_reps,
                               labels=labels.contiguous(),
                               scores=scores[:, :total_n_psg].contiguous())

    def _contrastive_loss(self, scores: Tensor, labels: Tensor,
                          all_scores: Tensor, all_labels: Tensor) -> Tensor:
        """Cross-entropy over local scores, or over all gathered scores times a scale factor."""
        if not self.args.use_scaled_loss:
            return self.cross_entropy(scores, labels)
        # A non-positive `loss_scale` means "scale by the world size".
        factor = self.args.world_size if self.args.loss_scale <= 0 else self.args.loss_scale
        return self.cross_entropy(all_scores, all_labels) * factor

    def _multi_kd_loss(self, query: Dict[str, Tensor], group_log_scores: Tensor,
                       weight: float = 15.0) -> Tensor:
        """Weighted KL term against the teacher-assistant combination chosen by ``selection_method``."""
        method = self.args.selection_method
        if method == 'KL':
            student_losses, teacher_divergence = self._compute_kd_sequential(query, group_log_scores)
            # Combination whose distribution is closest to the teacher (first one wins on ties).
            best_name = min(teacher_divergence, key=teacher_divergence.get)
            return weight * dict(student_losses)[best_name]

        if method == 'KT':
            best_name, _ = self._select_kendall_tau(query)
        elif method == 'SP':
            best_name, _ = self._select_spearman(query)
        else:
            best_name, _ = self._select_rbo(query)

        target = self._ta_log_target(query, best_name)
        return weight * self.kl_loss_fn(input=group_log_scores, target=target)

    def _ta_log_target(self, query: Dict[str, Tensor], combo_name: str,
                       temperature: float = 10.0) -> Tensor:
        """Log-softmax distillation target for a space-separated teacher-assistant combination name.

        A plain name list (``"a b c"``) averages the named score tensors. A name of
        the form ``"max min a b ..."`` / ``"max mean a b ..."`` reduces the named
        tensors element-wise with min / mean and then overwrites column 0 with the
        element-wise max. Scores are divided by ``temperature`` before the softmax.
        """
        parts = combo_name.split(" ")
        if len(parts) > 2 and parts[0] == "max" and parts[1] in ("min", "mean"):
            stacked = torch.stack([query[name] for name in parts[2:]], dim=0)
            if parts[1] == "min":
                combined = stacked.min(dim=0).values
            else:
                combined = stacked.mean(dim=0)
            combined[:, 0] = stacked.max(dim=0).values[:, 0]
            return torch.log_softmax(combined / temperature, dim=-1)

        # Left-to-right accumulation keeps the floating-point result identical to
        # the spelled-out `(a + b + c) / 30.0` form.
        total = query[parts[0]]
        for name in parts[1:]:
            total = total + query[name]
        return torch.log_softmax(total / (temperature * len(parts)), dim=-1)
                    
    def _select_kendall_tau(self, query: Dict[str, Tensor]):
        """Pick the teacher-assistant combination whose mean scores correlate best with the teacher.

        For every entry of ``combination_lists`` the named score tensors in
        ``query`` are averaged and compared with ``query['kd_labels']`` via
        Kendall's tau.

        Returns:
            ``(name, taus)`` where ``name`` is the space-joined combination with the
            highest tau and ``taus`` maps every combination name to its tau.
        """
        # NOTE(review): the score arrays are handed to `kendalltau` as-is; scipy
        # documents that non-1-D inputs are flattened, which would make this one
        # batch-wide tau rather than a per-query average — confirm that is intended.
        with torch.no_grad():
            reference = query['kd_labels'].clone().detach().cpu().numpy()
            tau_by_combo = {}

            for combo in combination_lists:
                summed = query[combo[0]]
                for extra_name in combo[1:]:
                    summed = summed + query[extra_name]
                # A single model is used unchanged; several models are averaged.
                if len(combo) > 1:
                    summed = summed / float(len(combo))

                candidate = summed.clone().detach().cpu().numpy()
                tau_by_combo[" ".join(combo)] = float(kendalltau(candidate, reference).statistic)

            # Ascending stable sort, last entry wins: the highest tau, and on ties
            # the combination inserted last.
            ordered = sorted(tau_by_combo.items(), key=lambda entry: entry[1])
            min_name = ordered[-1][0]

        return min_name, tau_by_combo
        
    def _select_spearman(self, query: Dict[str, Tensor]):
        """Pick the teacher-assistant combination whose ranking agrees best with the teacher.

        Scores are turned into 1-based rank positions per row; for every entry of
        ``combination_lists`` the Spearman correlation with the teacher's ranks is
        computed row by row and averaged over the rows.

        Returns:
            ``(name, rhos)`` where ``name`` is the space-joined combination with the
            highest mean correlation and ``rhos`` maps every name to its mean.
        """
        def rank_rows(score_matrix):
            # argsort (descending) followed by argsort gives each column's rank position.
            _, order = score_matrix.sort(dim=1, descending=True)
            _, position = order.sort(dim=1)
            return (position + 1).tolist()

        with torch.no_grad():
            teacher_rank = rank_rows(query['kd_labels'])
            rho_by_combo = {}

            for combo in combination_lists:
                summed = query[combo[0]]
                for extra_name in combo[1:]:
                    summed = summed + query[extra_name]
                if len(combo) > 1:
                    summed = summed / float(len(combo))

                ta_rank = rank_rows(summed)
                running = 0.0
                for row, teacher_row in enumerate(teacher_rank):
                    running += spearmanr(teacher_row, ta_rank[row]).statistic
                rho_by_combo[" ".join(combo)] = running / len(teacher_rank)

            # Ascending stable sort, last entry wins (highest value; last inserted on ties).
            ordered = sorted(rho_by_combo.items(), key=lambda entry: entry[1])
            min_name = ordered[-1][0]

        return min_name, rho_by_combo

    def _compute_rbo_score(self,l1, l2, p=0.5):
        if not l1 or not l2:
            return 0
        s1 = set()
        s2 = set()
        max_depth = len(l1)
        score = 0.0
        for d in range(max_depth):
            s1.add(l1[d])
            s2.add(l2[d])
            avg_overlap = len(s1 & s2) / (d + 1)
            score += math.pow(p, d) * avg_overlap
        return (1 - p) * score

    def _select_rbo(self, query: Dict[str, Tensor]):
        """Pick the teacher-assistant combination with the highest mean RBO score against the teacher.

        Scores are turned into 1-based rank positions per row; for every entry of
        ``combination_lists`` ``_compute_rbo_score`` is evaluated row by row against
        the teacher's rank positions and averaged over the rows.

        Returns:
            ``(name, scores)`` where ``name`` is the space-joined combination with
            the highest mean score and ``scores`` maps every name to its mean.
        """
        # NOTE(review): the lists handed to `_compute_rbo_score` are per-column rank
        # positions, not column ids ordered by rank — confirm this is the intended
        # input for the overlap computation.
        def rank_rows(score_matrix):
            # argsort (descending) followed by argsort gives each column's rank position.
            _, order = score_matrix.sort(dim=1, descending=True)
            _, position = order.sort(dim=1)
            return (position + 1).tolist()

        with torch.no_grad():
            teacher_rank = rank_rows(query['kd_labels'])
            rbo_by_combo = {}

            for combo in combination_lists:
                summed = query[combo[0]]
                for extra_name in combo[1:]:
                    summed = summed + query[extra_name]
                if len(combo) > 1:
                    summed = summed / float(len(combo))

                ta_rank = rank_rows(summed)
                running = 0.0
                for row, teacher_row in enumerate(teacher_rank):
                    running += self._compute_rbo_score(l1=teacher_row, l2=ta_rank[row])
                rbo_by_combo[" ".join(combo)] = running / len(teacher_rank)

            # Ascending stable sort, last entry wins (highest value; last inserted on ties).
            ordered = sorted(rbo_by_combo.items(), key=lambda entry: entry[1])
            min_name = ordered[-1][0]
        return min_name, rbo_by_combo
        
    def _compute_kd_sequential(self, query: Dict[str, Tensor],group_log_scores: torch.Tensor):
        """Compute, for every teacher-assistant combination, two KL divergences.

        * student vs. combination: ``KL`` between ``group_log_scores`` and the
          log-softmax of the combination's mean scores divided by 10 (kept in the
          autograd graph so it can be used as a loss);
        * combination vs. teacher: ``KL`` between the log-softmax of the
          combination's mean scores divided by 2 and the teacher target built
          from L2-normalized ``query['kd_labels']`` times 50 (plain floats,
          computed without gradients).

        Args:
            query: batch dict holding ``'kd_labels'`` and one score tensor per
                teacher-assistant name.
            group_log_scores: the student's log-probabilities over each query's
                passage group.

        Returns:
            ``(student_losses, teacher_divergence)``: a list of ``(name, loss)``
            tuples sorted by ascending loss, and a dict mapping every combination
            name to its divergence from the teacher.
        """
        kd_log_target = torch.log_softmax(F.normalize(query['kd_labels'],dim=-1)*50.0, dim=-1)

        def summed_scores(combo):
            # Left-to-right accumulation, matching `(a + b + c)` written out by hand.
            total = query[combo[0]]
            for extra_name in combo[1:]:
                total = total + query[extra_name]
            return total

        # Student-side losses: mean TA score with temperature 10 -> divisor 10 * n.
        multi_kd_dict = {}
        for combo in combination_lists:
            ta_log_target = torch.log_softmax(summed_scores(combo) / (10.0 * len(combo)), dim=-1)
            multi_kd_dict[" ".join(combo)] = self.kl_loss_fn(input=group_log_scores, target=ta_log_target)

        # Teacher-side divergences: mean TA score with temperature 2 -> divisor 2 * n.
        multi_kd_ce_dict = {}
        with torch.no_grad():
            for combo in combination_lists:
                ta_log_probs = torch.log_softmax(summed_scores(combo) / (2.0 * len(combo)), dim=-1)
                multi_kd_ce_dict[" ".join(combo)] = self.kl_loss_fn(input=ta_log_probs, target=kd_log_target).item()

        result_sequential = sorted(multi_kd_dict.items(), key=lambda entry: entry[1])

        return result_sequential, multi_kd_ce_dict

    def _compute_scores(self, query: Dict[str, Tensor] = None,
                        passage: Dict[str, Tensor] = None) -> Tuple:
        """Encode the batch, gather embeddings across processes and score queries against passages.

        Returns:
            ``(scores, labels, q_reps, p_reps, all_scores, all_labels)`` where
            ``all_scores`` / ``all_labels`` come from
            ``full_contrastive_scores_and_labels`` on the gathered embeddings and
            ``scores`` / ``labels`` are the rows belonging to this process.
        """
        q_reps = self._encode(self.lm_q, query)
        p_reps = self._encode(self.lm_p, passage)

        all_q_reps = dist_gather_tensor(q_reps)
        all_p_reps = dist_gather_tensor(p_reps)

        n_local_queries = q_reps.shape[0]
        expected_passages = self.args.world_size * n_local_queries * self.args.train_n_passages
        assert all_p_reps.shape[0] == expected_passages

        all_scores, all_labels = full_contrastive_scores_and_labels(
            query=all_q_reps,
            key=all_p_reps,
            use_all_pairs=self.args.full_contrastive_loss,
        )

        if self.args.l2_normalize:
            # Scores are multiplied by 1 / t; with warm-up the factor is ramped with
            # the global step and never allowed to drop below 1.
            scale = 1 / self.args.t
            if self.args.t_warmup:
                ramp = min(1.0, self.trainer.state.global_step / self.args.warmup_steps)
                scale = max(1.0, scale * ramp)
            all_scores = all_scores * scale

        # Rows of the gathered matrices that belong to this process's queries.
        first_row = self.args.process_index * n_local_queries
        local_query_indices = torch.arange(first_row, first_row + n_local_queries,
                                           dtype=torch.long, device=q_reps.device)
        scores = all_scores.index_select(dim=0, index=local_query_indices)
        labels = all_labels.index_select(dim=0, index=local_query_indices)

        return scores, labels, q_reps, p_reps, all_scores, all_labels

    def _encode(self, encoder: PreTrainedModel, input_dict: dict) -> Optional[torch.Tensor]:
        """Run ``encoder`` on a tokenized batch and return one embedding per sequence.

        The embedding is the first-token vector averaged over
        ``hidden_states[4]``, ``[5]`` and ``[6]``, passed through ``self.pooler``
        and, when ``args.l2_normalize`` is set, L2-normalized.

        Args:
            encoder: model to call; it must return ``hidden_states`` with at
                least seven entries.
            input_dict: encoder inputs, optionally mixed with distillation score
                tensors, which are stripped before the call.

        Returns:
            The contiguous embedding tensor, or ``None`` for an empty/missing batch.
        """
        if not input_dict:
            return None

        # Distillation targets travel in the same dict as the tokenized inputs but
        # are not encoder arguments. `forward` looks the teacher-assistant scores
        # up under the names in `ta_models_name`, so those keys are dropped too.
        non_model_keys = {'kd_labels', 'TA1', 'TA2', 'TA3', 'TA4', *ta_models_name}
        model_inputs = {k: v for k, v in input_dict.items() if k not in non_model_keys}

        outputs = encoder(**model_inputs,
                          output_hidden_states=True,
                          return_dict=True)

        hidden_state = outputs.hidden_states[4]+outputs.hidden_states[5]+outputs.hidden_states[6]

        # First-token vector, averaged over the three summed layers.
        embeds = hidden_state[:, 0]/3.0

        embeds = self.pooler(embeds)

        if self.args.l2_normalize:
            embeds = F.normalize(embeds, p=2, dim=-1)

        return embeds.contiguous()

    @classmethod
    def build(cls, args: Arguments, **hf_kwargs):
        """Load the query and passage encoders and wrap them in a model instance.

        ``args.model_name_or_path`` is treated as a local checkpoint when it is a
        directory and is otherwise passed to ``AutoModel.from_pretrained`` as-is.
        With ``args.share_encoder`` both roles use the same encoder object;
        otherwise a local checkpoint is read from its ``query_model`` /
        ``passage_model`` sub-directories (falling back to the directory itself
        when ``query_model`` is absent) and a non-local one is deep-copied.

        Args:
            args: run configuration.
            **hf_kwargs: forwarded to every ``AutoModel.from_pretrained`` call.
        """
        model_path = args.model_name_or_path

        if not os.path.isdir(model_path):
            # Not a local directory: load once, copy for the passage side unless shared.
            lm_q = AutoModel.from_pretrained(model_path, **hf_kwargs)
            lm_p = lm_q if args.share_encoder else copy.deepcopy(lm_q)
        elif args.share_encoder:
            logger.info(f'loading shared model weight from {model_path}')
            lm_q = AutoModel.from_pretrained(model_path, **hf_kwargs)
            lm_p = lm_q
        else:
            _qry_model_path = os.path.join(model_path, 'query_model')
            _psg_model_path = os.path.join(model_path, 'passage_model')
            if not os.path.exists(_qry_model_path):
                # No per-role sub-directories: both encoders start from the same weights.
                _qry_model_path = _psg_model_path = model_path
            logger.info(f'loading query model weight from {_qry_model_path}')
            lm_q = AutoModel.from_pretrained(_qry_model_path, **hf_kwargs)
            logger.info(f'loading passage model weight from {_psg_model_path}')
            lm_p = AutoModel.from_pretrained(_psg_model_path, **hf_kwargs)

        return cls(args=args, lm_q=lm_q, lm_p=lm_p)

    def save(self, output_dir: str):
        """Write the encoder weights (and the pooler, if any) under ``output_dir``.

        A shared encoder is saved directly into ``output_dir``; separate encoders
        go to its ``query_model`` and ``passage_model`` sub-directories. When
        ``args.add_pooler`` is set the pooler's state dict is stored as
        ``pooler.pt``.
        """
        if self.args.share_encoder:
            self.lm_q.save_pretrained(output_dir)
        else:
            query_dir = os.path.join(output_dir, 'query_model')
            passage_dir = os.path.join(output_dir, 'passage_model')
            os.makedirs(query_dir, exist_ok=True)
            os.makedirs(passage_dir, exist_ok=True)
            self.lm_q.save_pretrained(query_dir)
            self.lm_p.save_pretrained(passage_dir)

        if self.args.add_pooler:
            pooler_file = os.path.join(output_dir, 'pooler.pt')
            torch.save(self.pooler.state_dict(), pooler_file)


class BiencoderModelForInference(BiencoderModel):
    """Inference-only variant of ``BiencoderModel``: encodes inputs, computes no loss."""

    def __init__(self, args: Arguments,
                 lm_q: PreTrainedModel,
                 lm_p: PreTrainedModel):
        """Store the encoders and the optional pooler.

        The parent constructor is bypassed on purpose (``nn.Module.__init__`` is
        called directly), so no loss functions or trainer slot are created.
        """
        nn.Module.__init__(self)
        self.args = args
        self.lm_q = lm_q
        self.lm_p = lm_p
        if args.add_pooler:
            self.pooler = nn.Linear(self.lm_q.config.hidden_size, args.out_dimension)
        else:
            self.pooler = nn.Identity()

    def forward(self, query: Dict[str, Tensor] = None,
                passage: Dict[str, Tensor] = None):
        """Return the query and passage embeddings, computed without gradient tracking."""
        with torch.no_grad():
            return BiencoderOutput(q_reps=self._encode(self.lm_q, query),
                                   p_reps=self._encode(self.lm_p, passage))

    @classmethod
    def build(cls, args: Arguments, **hf_kwargs):
        """Load encoders (and pooler weights, when present) for inference.

        A local checkpoint directory containing a ``query_model`` sub-directory
        is loaded as two separate encoders; anything else is loaded once and
        used for both roles. ``pooler.pt`` inside ``args.model_name_or_path`` is
        loaded into the pooler when it exists; when it does not,
        ``args.add_pooler`` must be false.

        Args:
            args: run configuration.
            **hf_kwargs: forwarded to every ``AutoModel.from_pretrained`` call.
        """
        model_name_or_path = args.model_name_or_path
        is_local_dir = os.path.isdir(model_name_or_path)
        _qry_model_path = os.path.join(model_name_or_path, 'query_model')
        _psg_model_path = os.path.join(model_name_or_path, 'passage_model')

        if is_local_dir and os.path.exists(_qry_model_path):
            logger.info(f'found separate weight for query/passage encoders')
            logger.info(f'loading query model weight from {_qry_model_path}')
            lm_q = AutoModel.from_pretrained(_qry_model_path, **hf_kwargs)
            logger.info(f'loading passage model weight from {_psg_model_path}')
            lm_p = AutoModel.from_pretrained(_psg_model_path, **hf_kwargs)
        else:
            # Tied weights: one encoder object serves both roles.
            if is_local_dir:
                logger.info(f'try loading tied weight')
                logger.info(f'loading model weight from {model_name_or_path}')
            else:
                logger.info(f'try loading tied weight {model_name_or_path}')
            lm_q = AutoModel.from_pretrained(model_name_or_path, **hf_kwargs)
            lm_p = lm_q

        model = cls(args=args, lm_q=lm_q, lm_p=lm_p)

        pooler_path = os.path.join(args.model_name_or_path, 'pooler.pt')
        if not os.path.exists(pooler_path):
            assert not args.add_pooler
            logger.info('No pooler will be loaded')
            return model

        logger.info('loading pooler weights from local files')
        state_dict = torch.load(pooler_path, map_location="cpu")
        model.pooler.load_state_dict(state_dict)
        return model
