import logging
import math
import os

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import numpy as np
import random

from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
from .modeling_gpt2 import GPT2PreTrainedModel


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Maps a pretrained GPT-2 model name to the URL of its PyTorch weight file.
# NOTE(review): nothing in the visible part of this module reads this table.
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
    "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin",
    "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-pytorch_model.bin",
    "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-pytorch_model.bin",
}

def gelu(x):
    """Apply the tanh approximation of the GELU activation element-wise.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``
    and returns a tensor with the same shape as ``x``.
    """
    cubic_term = x + 0.044715 * torch.pow(x, 3)
    tanh_arg = math.sqrt(2 / math.pi) * cubic_term
    return 0.5 * x * (1 + torch.tanh(tanh_arg))

class GPT2MultiHeadLMModel(GPT2PreTrainedModel):
    def __init__(self, config, transformer_input, n_facet_all = 2, n_facet=1, n_facet_hidden=1, 
                n_facet_window=0, n_facet_MLP=0, use_proj_bias=False, weight_mode = '', 
                only_commpute_loss=False, softmax_nonlinear='None', 
                efficient_mode='None', masking_ratio=-1, device=None, 
                last_num=0):

        super(GPT2MultiHeadLMModel, self).__init__(config)
        
        self.n_facet_all = n_facet_all
        self.n_facet_effective = n_facet

        self.n_facet = n_facet 
        self.n_facet_hidden = n_facet_hidden #0
        assert n_facet_MLP <= 0 #-1 or 0
        assert n_facet_window <= 0 # 0
        n_facet_window = - n_facet_window
        n_facet_MLP = - n_facet_MLP
        self.n_facet_MLP = n_facet_MLP
        self.n_facet_window = n_facet_window

        self.softmax_nonlinear=softmax_nonlinear
        self.efficient_mode = efficient_mode
        self.masking_ratio = masking_ratio
        self.only_commpute_loss = only_commpute_loss
        self.efficient_mode = efficient_mode
        
        # for multiple input hidden states
        if n_facet_MLP > 0:
            hidden_state_input_ratio = 1 + n_facet_MLP #1 + 1
            self.MLP_linear = nn.Linear(config.n_embd * (n_facet_hidden * (n_facet_window+1) ), config.n_embd * n_facet_MLP) # (hid_dim*2) -> (hid_dim)
            self.MLP_linear_l2 = nn.Linear(config.n_embd * n_facet_MLP, config.n_embd * n_facet_MLP)
        else:            
            hidden_state_input_ratio = n_facet_hidden * (n_facet_window+1) #1 * (0+1)
        print("hidden_state_input_ratio ", hidden_state_input_ratio)
        print("n_facet_all, n_facet, ", self.n_facet_all, self.n_facet)
        print("n_facet_hidden, n_facet_window, n_facet_MLP ", self.n_facet_hidden, self.n_facet_window, self.n_facet_MLP)
        print("Mode: ", self.efficient_mode)

        total_lin_dim = config.n_embd * hidden_state_input_ratio
        small_value = 0.0001
        
        self.project_arr = nn.ModuleList([nn.Linear(total_lin_dim, config.n_embd, bias=use_proj_bias) for i in range(n_facet_all)])
        for i in range(n_facet_all):
            if use_proj_bias:
                self.project_arr[i].bias.data.zero_()
            linear_weights = torch.zeros_like(self.project_arr[i].weight.data)

            # if i!= n_facet - 1:
            #     linear_weights = linear_weights + small_value * (torch.rand((config.n_embd, total_lin_dim)) - 0.5 )
            linear_weights[:,:config.n_embd] = torch.eye(config.n_embd)
            #if i < n_facet:
            #     linear_weights[:,:config.n_embd] = torch.eye(config.n_embd)
            # else:
            #     linear_weights[:,:config.n_embd] = 1e-10 * torch.eye(config.n_embd)
            self.project_arr[i].weight.data = linear_weights

        self.project_emb = nn.Linear(config.n_embd, config.n_embd, bias=use_proj_bias)

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        if len(weight_mode) > 0:
            self.weight_facet_decoder = nn.Linear(config.hidden_size * hidden_state_input_ratio, self.n_facet_effective)
            #self.weight_facet_decoder = nn.Linear(config.hidden_size * n_facet_hidden * (n_facet_window+1), n_facet)
            self.weight_global = nn.Parameter( torch.ones(self.n_facet_effective) )

        self.weight_mode = weight_mode
        self.transformer = transformer_input
        self.vocab_size = config.vocab_size
        self.n_embd = config.n_embd
        self.output_probs = True
        self.c = 100

    def get_facet_emb(self,input_emb, i):
        return self.project_arr[i](input_emb)

    def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # only last token for inputs_ids if past is defined in kwargs
        if "past" in kwargs and kwargs["past"]:
            input_ids = input_ids[:, -1].unsqueeze(-1)
        inputs = {"input_ids": input_ids}
        inputs.update(kwargs)
        return inputs
    # forward() takes two arguments beyond the usual GPT-2 ones: all_input_ids and prev_hidden_states
    def forward(self, input_ids=None, past=None, attention_mask=None, token_type_ids=None, position_ids=None,
                head_mask=None, inputs_embeds=None, labels=None, output_weight=True, eval_recon_top_k=None, 
                vis_simple=False, vis_seq_loss = False, exclude_neg_labels_for_MRR=False, all_input_ids = None, prev_hidden_states = None):

        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        all_hidden_states = transformer_outputs[2]
        #print('tuple size ', len(all_hidden_states), all_hidden_states[-1].size())
        # during evaluation, seq_len =1, so concatenate prev hidden states
        
        if labels is None and prev_hidden_states is not None and prev_hidden_states is not None:
            temp_tuple = tuple()
            for layer, _ in enumerate(prev_hidden_states):
                #print(layer, prev_hidden_states[layer].size(), all_hidden_states[layer].size())
                temp_tuple += (torch.cat((prev_hidden_states[layer], all_hidden_states[layer]), dim=1),)
            input_ids = torch.cat((all_input_ids, input_ids), dim=1)
            all_hidden_states = temp_tuple

        #insert extra token to input_ids
        device = all_hidden_states[0].device
        bsz, seq_len = input_ids.size()

        #check seq_len from hidden size

        ## Multi-input hidden states: generate q_ct from hidden states
        #list of hidden state embeddings taken as input
        hidden_emb_arr = []
        # h_facet_hidden -> H, n_face_window -> W, here 1 and 0
        for i in range(self.n_facet_hidden):
            hidden_states = all_hidden_states[-(i+1)] #i-th hidden-state embedding from the top
            device = hidden_states.device
            hidden_emb_arr.append(hidden_states)
            for j in range(self.n_facet_window):
                bsz, seq_len, hidden_size = hidden_states.size() #bsz -> , seq_len -> , hidden_size -> 768 in GPT-small?
                if j+1 < hidden_states.size(1):
                    shifted_hidden = torch.cat( (torch.zeros( (bsz, (j+1), hidden_size), device = device), hidden_states[:,:-(j+1),:]), dim = 1)
                else:
                    shifted_hidden = torch.zeros( (bsz, hidden_states.size(1), hidden_size), device = device)
                hidden_emb_arr.append(shifted_hidden)
        #hidden_emb_arr -> (W*H, bsz, seq_len, hidden_size)

        #n_facet_MLP -> 1
        if self.n_facet_MLP > 0:
            stacked_hidden_emb_raw_arr = torch.cat(hidden_emb_arr, dim=-1) #(bsz, seq_len, W*H*hidden_size)
            # self.MLP_linear = nn.Linear(config.n_embd * (n_facet_hidden * (n_facet_window+1) ), config.n_embd * n_facet_MLP) -> why +1?
            hidden_emb_MLP = self.MLP_linear(stacked_hidden_emb_raw_arr) #bsz, seq_len, hidden_size
            stacked_hidden_emb_arr = torch.cat([hidden_emb_arr[0], gelu(hidden_emb_MLP)], dim=-1) #bsz, seq_len, 2*hidden_size
        else:
            stacked_hidden_emb_arr = hidden_emb_arr[0]

        #list of linear projects per facet
        projected_emb_arr = []
        #list of final logits per facet
        facet_lm_logits_arr = []
        facet_lm_logits_real_arr = []

        #logits for orig facets
        if self.efficient_mode == 'even_last_2':
            bsz, seq_len, hidden_size = all_hidden_states[-1].size()
            logit_all = torch.empty( (bsz, seq_len, self.vocab_size) , device=all_hidden_states[-1].device )
            n_facet_not_last = self.n_facet_all - (self.n_facet_effective-1) # 6 - (3-1) = 4 -> partitions
            for i in range(n_facet_not_last):
                #projected_emb = self.project_arr[i](stacked_hidden_emb_arr)
                projected_emb = self.get_facet_emb(stacked_hidden_emb_arr,i) #bsz, seq_len, n_embd
                # stacked_hidden_emb_arr -> sz, seq_len, 2*hidden_size
                # same as project_arr? output_dim -> n_embd, 6 linear models, weights are zero for last one
                projected_emb_arr.append(projected_emb) #4 partitions
                #projected_emb_real_arr.append(projected_emb)
                #self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
                if self.lm_head.bias is None:
                    logit_all[:,:,i::n_facet_not_last] = F.linear(projected_emb, self.lm_head.weight[i::n_facet_not_last,:], None)
                else:
                    logit_all[:,:,i::n_facet_not_last] = F.linear(projected_emb, self.lm_head.weight[i::n_facet_not_last,:], self.lm_head.bias[i::n_facet_not_last])
            facet_lm_logits_arr.append(logit_all)
            #last two softmax, project_arr -> L^f
            for i in range(self.n_facet_effective-1):
                projected_emb = self.project_arr[-(i+1)](stacked_hidden_emb_arr)
                #projected_emb = self.project_arr[n_facet_not_last+i](stacked_hidden_emb_arr)
                #projected_emb = self.get_facet_emb(stacked_hidden_emb_arr,n_facet_not_last+i)
                projected_emb_arr.append(projected_emb)
                #projected_emb_real_arr.append(projected_emb)

                facet_lm_logits_arr.append( self.lm_head( projected_emb ) )
        else:
            for i in range(self.n_facet):
            #     #linear projection
                projected_emb = self.get_facet_emb(stacked_hidden_emb_arr, i) #(bsz, seq_len, hidden_dim)
                projected_emb_arr.append(projected_emb) 
                #logits for all tokens in vocab
                lm_logits = self.lm_head(projected_emb) #(bsz, seq_len, vocab_size)
                facet_lm_logits_arr.append(lm_logits)
            

        #logits for n_facet (==n_facet_effective)
        for i in range(self.n_facet):       
            facet_lm_logits_real_arr.append( facet_lm_logits_arr[i] )

        with torch.no_grad():
            if not self.only_commpute_loss:
                stacked_facet_emb = torch.stack(projected_emb_arr, dim=0)
                stacked_facet_emb = stacked_facet_emb / (1e-12 + stacked_facet_emb.norm(dim = -1, keepdim=True))
                pred_mean = stacked_facet_emb.mean(dim = 0, keepdim = True)
                div_raw = (stacked_facet_emb - pred_mean).norm(dim = -1)
                emb_div_arr = - div_raw.mean(dim=1).mean(dim=0)
                emb_div = emb_div_arr.mean()
            if vis_seq_loss or vis_simple:
                seq_len = emb_div_arr.numel()
                num_token_removed = seq_len % self.num_vis_bin
                proper_seq_len = seq_len - num_token_removed
                emb_div_arr_vis = emb_div_arr[:proper_seq_len].view(self.num_vis_bin,-1).mean(dim=-1)            
            if self.masking_ratio > 0:
                num_facet, bsz, seq_len = div_raw.size()
                var_avg_flat = div_raw.mean(0).view(-1)
                var_avg = var_avg_flat.median()
                single_facet_mask = var_avg_flat < var_avg * self.masking_ratio
        
        stacked_facet_lm_logits = torch.stack(facet_lm_logits_arr, dim=0)

        #weight_mode = ''
        weight = None
        if self.weight_mode == 'dynamic':
            weight = self.weight_facet_decoder(stacked_hidden_emb_arr).softmax(dim=-1) #hidden_dim*hidden_input_state_ration -> n_facet_effective
        elif self.weight_mode == 'static':
            weight = self.weight_global.softmax(dim=-1) #torch.ones(n_facet_effective)
        #print(weight)
        prediction_prob = 0

        for i in range(self.n_facet_effective):
            facet_lm_logits = facet_lm_logits_real_arr[i]
            if self.softmax_nonlinear == 'sigsoftmax': #'None' here
                facet_lm_logits_sig = torch.exp(facet_lm_logits - facet_lm_logits.max(dim=-1,keepdim=True)[0]) * (1e-20 + torch.sigmoid(facet_lm_logits))
                facet_lm_logits_softmax = facet_lm_logits_sig / facet_lm_logits_sig.sum(dim=-1,keepdim=True)
            elif self.softmax_nonlinear == 'None':
                facet_lm_logits_softmax = facet_lm_logits.softmax(dim=-1) #softmax over final logits
            if self.weight_mode == 'dynamic':
                prediction_prob += facet_lm_logits_softmax * weight[:,:,i].unsqueeze(-1)
            elif self.weight_mode == 'static':
                prediction_prob += facet_lm_logits_softmax * weight[i]
            else:
                prediction_prob += facet_lm_logits_softmax / self.n_facet_effective #softmax over final logits/1

        outputs = (prediction_prob,) + (stacked_facet_lm_logits, ) + transformer_outputs[1:]
        # outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n            
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens

            # shift_logits = lm_logits[..., :-1, :].contiguous()
            # loss_fct = CrossEntropyLoss(ignore_index = -100)
            # loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            
            loss_fct = torch.nn.NLLLoss(reduction='none')
            shift_prediction_prob = prediction_prob[..., :-1, :].contiguous()
            shift_labels_flat = shift_labels.view(-1)
            loss_raw = loss_fct(torch.log(shift_prediction_prob.view(-1, self.vocab_size)+1e-8), shift_labels_flat)
            loss = loss_raw[shift_labels_flat != -100].mean()

            dist_to_period = None
            top_val_single = None
            top_idx_single = None
            if vis_seq_loss or vis_simple:
                with torch.no_grad():
                    bsz, seq_len, vocab_size = shift_prediction_prob.size()
                    if exclude_neg_labels_for_MRR:
                        shift_labels_MRR = shift_labels.view(-1)
                        good_label_mask = shift_labels_MRR >= 0
                        shift_prediction_prob_MRR = shift_prediction_prob.view(-1,vocab_size)[good_label_mask,:]
                        gt_prob_MRR = torch.gather(shift_prediction_prob_MRR,  dim=-1 , index = shift_labels_MRR[good_label_mask].unsqueeze(dim=-1))
                        gt_rank_MRR = (gt_prob_MRR.expand_as(shift_prediction_prob_MRR) <= shift_prediction_prob_MRR).type(torch.long).sum(dim = -1)
                        seq_len_small = gt_rank_MRR.size(0)

                    else:
                        gt_prob_MRR = torch.gather(shift_prediction_prob,  dim=-1 , index = shift_labels.unsqueeze(dim=-1))
                        gt_rank_MRR = (gt_prob_MRR.expand_as(shift_prediction_prob) <= shift_prediction_prob).type(torch.long).sum(dim = -1)
                        seq_len_small = seq_len
                    num_token_removed = seq_len_small % self.num_vis_bin_loss
                    proper_seq_len = seq_len_small - num_token_removed
                    MRR_raw = (1 / gt_rank_MRR.type(torch.float))
                    if not exclude_neg_labels_for_MRR:
                        MRR_seq = MRR_raw.mean(dim=0)
                    else:
                        MRR_seq = MRR_raw
                    MRR_seq_vis = MRR_seq[:proper_seq_len].view(self.num_vis_bin_loss,-1).mean(dim=-1)
                    MRR_raw = MRR_raw.view(-1)
                    
                    num_token_removed = seq_len % self.num_vis_bin_loss
                    proper_seq_len = seq_len - num_token_removed
                    
                    loss_seq = loss_raw.view(bsz, seq_len).mean(dim=0)
                    loss_seq_vis = loss_seq[:proper_seq_len].view(self.num_vis_bin_loss,-1).mean(dim=-1)
                    
                    if self.period_idx > 0:
                        dist_to_period = torch.zeros( (bsz, seq_len), device = shift_labels.device, dtype = torch.long )
                        for i in range(bsz):
                            period_position = (shift_labels[i,:] == self.period_idx).nonzero(as_tuple=True)[0]
                            num_period = period_position.numel()
                            if num_period > 0:
                                diff_pos = period_position.unsqueeze(dim=-1).expand(num_period, seq_len) - torch.arange(seq_len, device=period_position.device).unsqueeze(dim=0).expand(num_period, seq_len)
                                relative_pos = torch.abs( diff_pos )
                                #dist_to_period[i] = relative_pos.min(dim=0)[0]
                                dist_to_period[i] = torch.gather(diff_pos,0,relative_pos.min(dim=0,keepdim=True)[1])
            if vis_seq_loss:
                with torch.no_grad():
                    div_weighted_sq_raw = None
                    facet_norm = None
                    collapse_difference = None
                    collapse_difference_val = None
                    collapse_difference_inv = None
                    stacked_facet_emb_n = stacked_facet_emb

                    weighted_facet_first = weight.permute(2,0,1) # bsz, seq_len, facet -> facet, bsz, seq_len
                    stacked_facet_emb_norm = stacked_facet_emb_n / stacked_facet_emb_n.norm(dim=-1, keepdim=True)
                    pred_mean_weighted = (stacked_facet_emb_norm * weighted_facet_first.unsqueeze(dim=-1)).sum(dim=0) / weight.sum(dim=-1).unsqueeze(dim=-1)
                    div_norm_weighted = (stacked_facet_emb_norm - pred_mean_weighted).norm(dim = -1) # facet, bsz, seq_len
                    div_weighted_sq_raw = (div_norm_weighted*div_norm_weighted * weighted_facet_first).sum(dim=0) # bsz, seq_len
                    if self.n_facet > 1:
                        collapse_single_prob = self.lm_head(pred_mean_weighted).softmax(dim=-1) #bsz, seq_len, vocab
                        top_val_single, top_idx_single = torch.topk(collapse_single_prob, eval_recon_top_k, dim=-1)
                        top_val_before, top_idx_before = torch.topk(prediction_prob, eval_recon_top_k, dim=-1)
                        #top_val_org = torch.gather(prediction_prob, dim=-1 , index = top_idx)
                        top_val_now = torch.gather(prediction_prob, dim=-1 , index = top_idx_single)
                        top_val_new = torch.gather(collapse_single_prob, dim=-1 , index = top_idx_before)
                        collapse_difference_val = top_val_before.sum(dim = -1) -  top_val_now.sum(dim = -1) #torch.abs(top_val - top_val_now).sum(dim = -1)
                        collapse_difference = top_val_now.sum(dim = -1) / top_val_before.sum(dim = -1)
                        collapse_difference_inv = top_val_new.sum(dim = -1) / top_val_single.sum(dim = -1)
                    else:
                        facet_norm = projected_emb_arr[0].norm(dim=-1)

            if not self.only_commpute_loss:
                with torch.no_grad():
                    lm_logits_max, lm_max_idx = torch.max(stacked_facet_lm_logits.softmax(dim=-1), dim=0)
                    count_best_arr = torch.zeros( (1, self.n_facet_effective), device = lm_max_idx.device)
                    shift_labels = input_ids[..., 1:].contiguous()
                    best_idx = torch.gather(lm_max_idx[..., :-1, :], dim=-1 , index = shift_labels.unsqueeze(dim=-1))
                    have_label_best_idx = best_idx.squeeze(dim=-1)
                    for i in range(self.n_facet_effective):
                        count = torch.sum(have_label_best_idx == i)
                        count_best_arr[0,i] = count

            outputs = (loss,) + outputs
        #print(self.only_commpute_loss, vis_simple, vis_seq_loss, output_weight, labels)
        #count_best_arr = None   
        if self.only_commpute_loss:
            return outputs
        elif vis_simple:
            return outputs, emb_div, count_best_arr, weight, emb_div_arr_vis.view(1,-1), loss_seq_vis.view(1,-1), MRR_seq_vis.view(1,-1), loss_raw, MRR_raw, div_raw.mean(dim=0)

        elif vis_seq_loss:
            return outputs, emb_div, count_best_arr, weight, emb_div_arr_vis.view(1,-1), loss_seq_vis.view(1,-1), MRR_seq_vis.view(1,-1), loss_raw, MRR_raw, div_raw.mean(dim=0), dist_to_period.view(-1), div_weighted_sq_raw, collapse_difference, collapse_difference_inv, collapse_difference_val, facet_norm, top_val_single, top_idx_single
        elif output_weight:
            return outputs, emb_div, count_best_arr, weight
        else:
            return outputs, emb_div, count_best_arr
