from dataclasses import dataclass
from typing import Any, Optional

import numpy as np
import torch
import torch.nn.functional as F
from transformers.file_utils import ModelOutput
        

@dataclass
class CustomModuleOutput(ModelOutput):
    """
    Output container returned by `CustomModel.forward`.

    `ModelOutput` subclasses have to be dataclasses: the base class builds its
    dict/tuple view from the dataclass fields, and recent `transformers`
    releases raise a `TypeError` when an undecorated subclass is instantiated.
    All fields default to `None`, as `ModelOutput` requires for every field
    after the first.
    """
    # Main training loss (the backbone LM loss, or its branch-specific replacement).
    current_loss: Optional[torch.FloatTensor] = None
    # Logits produced by the backbone for the input batch.
    current_logits: Optional[torch.FloatTensor] = None
    # Scorer-weighted loss; `forward` leaves it at 0.0 outside the 'ours' branch.
    score_loss: Optional[torch.FloatTensor] = None
    # Rescaled KL regulariser; `forward` leaves it at 0.0 when no reference model is used.
    reg_loss: Optional[torch.FloatTensor] = None

    
class CustomModel(torch.nn.Module):
    """Wrap a causal-LM backbone and adapt its loss according to ``privacy_type``.

    ``forward`` always runs ``backbone`` on the batch, then dispatches on
    substrings of ``privacy_type``, tested in this order:

    * ``'ga'``    -- keep the backbone loss (and optionally dump per-token losses).
    * ``'ours'``  -- add a scorer-weighted loss over top-k next-token candidates.
    * ``'dpo'``   -- replace the loss by the LM loss on perturbed (or generated) targets.
    * ``'quark'`` -- replace the loss by the LM loss on freshly generated text.

    Whenever a reference model ``reg`` is supplied, a KL regulariser between
    the backbone and the reference distributions is added and rescaled by
    ``auto_scale``.

    Collaborators, as used by this class (they are defined elsewhere):

    * ``backbone`` -- called as ``backbone(input_ids, attention_mask=..., labels=...)``
      and read through ``.loss`` / ``.logits``; also exposes ``.tokenizer`` and
      ``.generate(...)``.
    * ``scorer``   -- exposes ``score_texts(list_of_str)``; presumably returns one
      numeric score per text (TODO confirm against the scorer implementation).
    * ``reg``      -- called like ``backbone`` and read through ``.logits``.
    """

    # When True, the 'ga' and 'ours' branches print and save per-token diagnostics.
    print_loss_flag: bool = True
    # Number of highest-probability next-token candidates scored in the 'ours' branch.
    top_k: int = 5
    # False: 'dpo' trains on randomly perturbed labels; True: on generated samples.
    dpo_use_generated_targets: bool = False
    # Factor applied to the values written to ./ga_loss.txt.
    # NOTE(review): the reason for the x10 is not visible in this file.
    _GA_DUMP_SCALE: float = 10.0

    def __init__(self, privacy_type: str, model_dict: Optional[dict[str, Any]] = None):
        """Store the sub-models.

        Args:
            privacy_type: Name of the training objective; see the class docstring.
            model_dict: Mapping with a required ``'backbone'`` entry and optional
                ``'scorer'`` and ``'reg'`` entries (missing ones become ``None``).

        Raises:
            ValueError: If ``model_dict`` is ``None``.
            KeyError: If ``model_dict`` has no ``'backbone'`` entry.
        """
        super().__init__()
        if model_dict is None:
            raise ValueError("model_dict must be provided and contain a 'backbone' entry")
        self.privacy_type = privacy_type
        self.backbone = model_dict['backbone']
        self.scorer = model_dict.get('scorer')
        self.reg = model_dict.get('reg')

        self.kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
        # True until the 'dpo' branch has built (and cached) its training targets.
        self.flag = True

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        logits_config: Optional[dict[str, Any]] = None):
        """Run the backbone and compute the losses selected by ``privacy_type``.

        Only ``input_ids``, ``attention_mask`` and ``labels`` are used; the
        remaining parameters are accepted but ignored.

        Returns:
            CustomModuleOutput with ``current_loss``, ``current_logits``,
            ``score_loss`` and ``reg_loss``. The last two stay at the float
            ``0.0`` when the active branch does not compute them.
        """
        outputs = self.backbone(
            input_ids,
            attention_mask=attention_mask,
            labels=labels)
        current_loss, current_logits = outputs.loss, outputs.logits
        score_loss, reg_loss = 0.0, 0.0

        if 'ga' in self.privacy_type:
            if self.reg is not None:
                reg_loss = self._reg_loss(input_ids, attention_mask, labels, current_logits, current_loss)
            # The dump needs labels to compute per-token cross-entropy.
            if self.print_loss_flag and labels is not None:
                self._log_ga_token_losses(current_logits, labels, attention_mask)

        elif 'ours' in self.privacy_type:
            score_loss = self._score_loss(input_ids, attention_mask, current_logits)
            if self.reg is not None:
                reg_loss = self._reg_loss(input_ids, attention_mask, labels, current_logits, score_loss)
            print(f"forward score loss: {score_loss}, reg loss : {reg_loss}")

        elif 'dpo' in self.privacy_type:
            current_loss = self._dpo_loss(input_ids, attention_mask)
            if self.reg is not None:
                reg_loss = self._reg_loss(input_ids, attention_mask, labels, current_logits, current_loss)
            print(f"forward current loss: {current_loss}, reg loss : {reg_loss}")

        elif 'quark' in self.privacy_type:
            new_input_ids, new_attention_mask = self.generated_dpo(input_ids=input_ids, m=10, n=1)
            current_loss = self._lm_loss_on(new_input_ids, new_attention_mask)
            if self.reg is not None:
                reg_loss = self._reg_loss(input_ids, attention_mask, labels, current_logits, current_loss)

        return CustomModuleOutput(current_loss=current_loss,
                                  current_logits=current_logits,
                                  score_loss=score_loss,
                                  reg_loss=reg_loss)

    def _reg_loss(self, input_ids, attention_mask, labels, current_logits, anchor_loss):
        """Return KL(backbone vs. reference model), rescaled relative to ``anchor_loss``."""
        with torch.no_grad():
            reg_logits = self.reg(input_ids, attention_mask=attention_mask, labels=labels).logits
        kl = self.kl_loss(F.log_softmax(current_logits, dim=-1), F.softmax(reg_logits, dim=-1))
        return self.auto_scale(loss=anchor_loss, reg_loss=kl)

    def _lm_loss_on(self, token_ids, mask):
        """Return the backbone LM loss on ``token_ids`` used as both input and labels."""
        # generated_dpo returns (batch, n, length); collapse the leading
        # dimensions so the backbone always receives a 2-D batch.
        flat_ids = token_ids.reshape(-1, token_ids.size(-1))
        flat_mask = mask.reshape(-1, mask.size(-1))
        # NOTE(review): padding tokens are used as labels here, exactly as in the
        # original code; consider replacing them with the loss ignore index.
        return self.backbone(input_ids=flat_ids, attention_mask=flat_mask, labels=flat_ids).loss

    @staticmethod
    def _shifted_mask(attention_mask, like):
        """Return a boolean mask for next-token positions, shaped and placed like ``like``."""
        if attention_mask is None:
            return torch.ones(like.shape, dtype=torch.bool, device=like.device)
        # Drop the first position: position t of the mask lines up with target t + 1.
        return attention_mask[..., 1:].reshape(like.shape).to(like.device).bool()

    @staticmethod
    def _dump_token_losses(token_losses, valid, path, label, scale=1.0):
        """Print and save the row-wise log-softmax of ``token_losses`` over valid positions."""
        token_losses = token_losses.detach().float()
        # Exclude masked positions from the normaliser instead of letting them
        # participate with a value of zero.
        masked = token_losses.masked_fill(~valid, float('-inf'))
        normalised = F.log_softmax(masked, dim=1)
        # Masked entries (and rows with no valid position at all) are reported as 0.
        final_loss = torch.where(valid, normalised, torch.zeros_like(normalised))

        print(f'masked and log-softmax batch {label} is: {final_loss}')

        # Tensors may live on an accelerator; numpy needs them on the CPU.
        np.savetxt(path, final_loss[valid].cpu().numpy() * scale)

    def _log_ga_token_losses(self, current_logits, labels, attention_mask):
        """Dump the per-token cross-entropy of the batch to ``./ga_loss.txt``."""
        with torch.no_grad():
            shift_logits = current_logits[..., :-1, :]
            shift_labels = labels[..., 1:].to(shift_logits.device)
            n_pos = shift_labels.size(-1)
            token_losses = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                reduction="none",
            ).view(-1, n_pos)
            # Positions carrying the ignore index (-100, the cross-entropy
            # default) have no real loss value, so they are left out as well.
            valid = self._shifted_mask(attention_mask, token_losses)
            valid = valid & (shift_labels.reshape(-1, n_pos) != -100)
        self._dump_token_losses(token_losses, valid, "./ga_loss.txt", 'loss',
                                scale=self._GA_DUMP_SCALE)

    def _score_loss(self, input_ids, attention_mask, current_logits):
        """Return the scorer-weighted probability mass of next-token candidates.

        For every position the candidates are the ``top_k`` most likely next
        tokens plus the true next token. Each candidate is appended to its
        prefix, decoded and scored; the loss is ``sum(prob * score)``.

        Raises:
            ValueError: If no scorer was supplied.
        """
        if self.scorer is None:
            raise ValueError("the 'ours' objective requires a 'scorer' entry in model_dict")

        n_pos = input_ids.size(-1) - 1
        if n_pos < 1:
            # A single-token sequence has no next-token prediction to score.
            return current_logits.new_zeros(())

        vocab_size = current_logits.size(-1)
        next_token_probs = F.softmax(current_logits[..., :-1, :].reshape(-1, vocab_size), dim=1)
        # One row per sequence, so that flat position p maps to
        # (p // n_pos, p % n_pos) below.
        prefixes = input_ids[..., :-1].reshape(-1, n_pos)
        truth_labels = input_ids[..., 1:].reshape(-1, 1).to(next_token_probs.device)

        top_k_values, top_k_indices = torch.topk(next_token_probs, min(self.top_k, vocab_size))
        mask_prob = torch.zeros_like(next_token_probs).scatter_(1, top_k_indices, top_k_values)
        mask_prob.scatter_(1, truth_labels, next_token_probs.gather(1, truth_labels))

        flat_pos, token_ids = torch.nonzero(mask_prob > 0, as_tuple=True)
        probs = mask_prob[flat_pos, token_ids]

        rows = torch.div(flat_pos, n_pos, rounding_mode='floor').tolist()
        cols = (flat_pos % n_pos).tolist()
        token_ids = token_ids.to(prefixes.device)
        seqs = [torch.cat((prefixes[row, :col + 1], token_id.view(1)))
                for row, col, token_id in zip(rows, cols, token_ids)]

        with torch.no_grad():
            new_seq_list = self.backbone.tokenizer.batch_decode(seqs)
            bad_prob_list = torch.as_tensor(
                self.scorer.score_texts(new_seq_list), dtype=probs.dtype, device=probs.device)
        # NOTE(review): padding positions contribute to this loss as well (same
        # as the original code); confirm whether they should be masked out.
        score_loss = (probs * bad_prob_list).sum()

        if self.print_loss_flag:
            with torch.no_grad():
                # Per-position sum of prob * score, reusing the scores computed
                # above instead of decoding and scoring every candidate again.
                per_position = torch.zeros(prefixes.numel(), dtype=probs.dtype, device=probs.device)
                per_position.index_add_(0, flat_pos, probs.detach() * bad_prob_list)
                per_position = per_position.view(-1, n_pos)
                valid = self._shifted_mask(attention_mask, per_position)
            self._dump_token_losses(per_position, valid, "./ours_adaptive_loss.txt", 'adaptive loss')

        return score_loss

    def _perturb_token_ids(self, input_ids):
        """Return ``input_ids`` shifted by random offsets in roughly [-100, 100], kept inside the vocabulary."""
        vocab_size = self.backbone.tokenizer.vocab_size
        min_perturb = max(-100, -int(input_ids.min()))
        max_perturb = min(100, vocab_size - int(input_ids.max()) - 1)
        # randint needs low < high; ids at or above vocab_size would otherwise
        # make the upper bound fall below the lower one.
        max_perturb = max(max_perturb, min_perturb)
        perturbations = torch.randint(min_perturb, max_perturb + 1, input_ids.shape,
                                      device=input_ids.device)
        return (input_ids + perturbations).clamp_(0, vocab_size - 1)

    def _dpo_loss(self, input_ids, attention_mask):
        """Return the LM loss on the cached 'dpo' targets, building them on first use."""
        if self.dpo_use_generated_targets:
            if self.flag or getattr(self, 'new_input_ids', None) is None:
                self.new_input_ids, self.new_attention_mask = self.generated_dpo(
                    input_ids=input_ids, m=10, n=3)
                self.flag = False
            return self._lm_loss_on(self.new_input_ids, self.new_attention_mask)

        cached = getattr(self, 'pertur_labels', None)
        # Labels cached for a batch of another shape cannot be paired with this
        # batch, so they are rebuilt instead of failing inside the loss.
        if self.flag or cached is None or cached.shape != input_ids.shape:
            self.pertur_labels = self._perturb_token_ids(input_ids)
            self.flag = False
        return self.backbone(input_ids=input_ids, attention_mask=attention_mask,
                             labels=self.pertur_labels.to(input_ids.device)).loss

    def auto_scale(self, loss, reg_loss):
        """Shrink ``reg_loss`` by powers of 0.95 until it is at most 10% of ``loss``.

        The multiplier is worked out on Python floats (one device
        synchronisation rather than one per step) and applied as a constant,
        so gradients still flow through ``reg_loss``. The result is a new
        tensor; ``reg_loss`` itself is not modified in place.

        Args:
            loss: Reference loss (scalar tensor or number).
            reg_loss: Regularisation loss to rescale (scalar tensor).

        Returns:
            ``reg_loss`` unchanged when it is already small enough, NaN or
            infinite; ``reg_loss * 0`` when ``loss`` is not positive (no
            positive multiplier can reach the target); otherwise
            ``reg_loss * 0.95 ** k`` for the smallest sufficient ``k``.
        """
        # scale ratio
        scale_factor = 0.95

        ceiling = float(loss) * 0.1
        current = float(reg_loss)

        # Also covers NaN on either side: every comparison with NaN is False.
        if not current > ceiling:
            return reg_loss
        # An infinite value cannot be brought down by any finite multiplier.
        if current == float('inf'):
            return reg_loss
        # Repeated shrinking would only converge to zero (or never stop for a
        # negative ceiling), so go there directly while keeping the graph.
        if ceiling <= 0.0:
            return reg_loss * 0.0

        multiplier = 1.0
        while current * multiplier > ceiling:
            multiplier *= scale_factor

        return reg_loss * multiplier

    def generated_dpo(self, input_ids, m, n):
        """Generate rewrites of the batch and keep the ``n`` lowest-scored per input.

        Each input sequence is decoded, wrapped in a fixed instruction prompt
        and used to sample ``m`` continuations from the backbone; the
        continuations are scored with ``scorer.score_texts`` and the ``n`` with
        the smallest scores are returned.

        Args:
            input_ids: Token ids of the source batch.
            m: Number of continuations sampled per input.
            n: Number of continuations kept per input.

        Returns:
            Tuple ``(top_n_ids, top_n_attention_mask)`` of shape
            ``(batch, n, 512)``, on the same device as ``input_ids``.
        """
        tokenizer = self.backbone.tokenizer
        # Follow the batch rather than assuming a CUDA device is present.
        device = input_ids.device
        max_length = 512
        instruction = "Based on the following sentence, generate a safer and non-toxic new sentence: "

        # Decode the input sequence, concatenate the sentences,
        # and encode the new sentence sequences
        source_texts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        prompts = tokenizer.batch_encode_plus(
            [instruction + seq for seq in source_texts],
            max_length=max_length, padding='max_length', truncation=True, return_tensors="pt").to(device)
        prompt_length = prompts['input_ids'].size(1)

        # Generate new samples under the condition of no gradient tracking
        with torch.no_grad():
            # The prompts are padded, so the attention mask is passed explicitly
            # to keep generation from attending to padding.
            # NOTE(review): decoder-only models usually need left padding for
            # generation; the tokenizer's padding side is not visible here.
            generated = self.backbone.generate(prompts['input_ids'],
                                               attention_mask=prompts['attention_mask'],
                                               pad_token_id=tokenizer.pad_token_id,
                                               do_sample=True,
                                               max_new_tokens=512,
                                               num_return_sequences=m, temperature=0.9)
        # Keep only the newly generated part.
        samples = generated[:, prompt_length:]

        # Decode the generated samples
        decoded_samples = tokenizer.batch_decode(samples, skip_special_tokens=True)

        # Score the decoded samples and pick, per input, the n lowest scores
        scores = torch.as_tensor(self.scorer.score_texts(decoded_samples), device=device).reshape(-1, m)
        _, indices = torch.topk(scores, n, dim=-1, largest=False)

        # Gather the top-n scored samples
        new_encode_samples = tokenizer.batch_encode_plus(
            decoded_samples, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt")
        new_input_ids = new_encode_samples['input_ids'].to(device)
        new_attention_mask = new_encode_samples['attention_mask'].to(device)

        gather_index = indices.unsqueeze(-1).expand(-1, -1, new_input_ids.size(1))
        top_n_ids = torch.gather(new_input_ids.view(-1, m, new_input_ids.size(1)), 1, gather_index)
        top_n_attention_mask = torch.gather(
            new_attention_mask.view(-1, m, new_attention_mask.size(1)), 1, gather_index)

        return top_n_ids, top_n_attention_mask