import numpy as np 
import torch
from torch import nn

class SelfSupervisedModule():
    def __init__(self, self_supervised_gamma):
        """Create the keyword extractor and the loss pieces used by this helper.

        Args:
            self_supervised_gamma: Factor applied to the self-supervised
                cross-entropy term before it is added to the main loss
                (see ``add_self_supervised_loss``).
        """
        # Function-scope import: cocoNLP is only required once an instance is
        # actually constructed, not when this module is imported.
        from cocoNLP.config.phrase import rake as rake_module

        # Loss configuration.
        self.self_supervised_gamma = self_supervised_gamma
        self.fct = nn.CrossEntropyLoss()

        # RAKE keyword extractor used by ``extract_keywords_and_masked``.
        self.raker = rake_module.Rake()


    def self_supervised_logits(self, batch_size, labels, chosen_logits, feature_num=32000, model=None):
        """Build the inputs of the "how many keywords were masked" classifier.

        Intended to be called from the host's ``concatenated_forward``
        (original author's note).

        For every sample, keyword spans are located with
        ``extract_keywords_and_masked``. When a sample has keywords, between
        one and three of them (never more than were found) are masked: the
        sequence positions covered by the selected spans are dropped, the
        remaining per-position logits are summed over the sequence axis and
        fed through ``post_module``. When a sample has no keywords, a random
        class is drawn and its one-hot vector is used as a placeholder logit.

        Args:
            batch_size: Chunk size used to split ``labels`` along dim 0. The
                split must yield exactly two chunks; the first one is used
                (presumably chosen-then-rejected ordering -- confirm with the
                caller).
            labels: Label tensor; ``-100`` marks positions that are not part
                of the answer.
            chosen_logits: Logits indexed as ``[sample, position, feature]``.
            feature_num: Size of the feature axis; the summed logits are
                reshaped to ``(-1, feature_num)`` before ``post_module``.
            model: Object exposing ``post_module`` either directly or through
                a ``.module`` wrapper attribute. Defaults to ``self.model``.

        Returns:
            A tuple ``(logits, labels)``: ``logits`` is the stack of the
            per-sample classifier outputs, ``labels`` is a list holding one
            ``numpy`` array of shape ``(1,)`` per sample with the zero-based
            class (number of masked keywords minus one).

        Raises:
            ValueError: If no model is passed and ``self.model`` is missing.
        """
        max_masks = 3  # the classifier distinguishes 1, 2 or 3 masked keywords

        if model is None:
            model = getattr(self, "model", None)
        if model is None:
            raise ValueError(
                "self_supervised_logits needs a model: pass `model=` or set `self.model`"
            )
        # Wrapped models expose the real network as `.module`; plain models
        # carry `post_module` themselves.
        post_module = getattr(model, "module", model).post_module

        # Labels are used to locate the answer tokens inside the logits.
        chosen_labels, _ = labels.split(batch_size, dim=0)

        # NOTE(review): the spans are aligned against the answer-only tokens
        # inside extract_keywords_and_masked, yet they are applied below along
        # the full sequence axis of `chosen_logits` -- confirm that
        # `align_index` returns offsets valid for that axis.
        mask_indexs = self.extract_keywords_and_masked(chosen_logits, self.tokenizer, chosen_labels)

        device = chosen_logits.device
        positions = torch.arange(chosen_logits.size(1), device=device)

        self_supervised_logits = []
        self_supervised_labels = []

        for i, spans in enumerate(mask_indexs):
            if spans:
                # Uniform draw over the mask counts this sample can support.
                limit = min(max_masks, len(spans))
                mask_num = np.random.choice(np.arange(1, limit + 1), 1)

                # Keep a position only if it lies outside EVERY selected span.
                keep = torch.ones_like(positions, dtype=torch.bool)
                for mask_start, mask_end in spans[:int(mask_num[0])]:
                    keep &= (positions < mask_start) | (positions >= mask_end)

                masked_chosen_logits = chosen_logits[i][keep]
                logits_tmp = masked_chosen_logits.sum(0).view(-1, feature_num)

                # Self-supervised head forward pass.
                lm_head_logits = post_module(logits_tmp)
                self_supervised_logits.append(lm_head_logits.squeeze(0))
            else:
                # No keyword found: use a random class and its one-hot vector
                # as placeholder logits so the batch keeps its size.
                mask_num = np.random.choice(np.arange(1, max_masks + 1), 1)
                pseudo_logits = nn.functional.one_hot(
                    torch.tensor(mask_num[0] - 1), max_masks
                ).to(device)
                self_supervised_logits.append(pseudo_logits.float())

            # Classes are zero-based: k masked keywords -> class k - 1.
            self_supervised_labels.append(mask_num - 1)

        return torch.stack(self_supervised_logits), self_supervised_labels


    def add_self_supervised_loss(self, losses, self_supervised_labels, self_supervised_logits):
        """Add the weighted self-supervised cross-entropy term to ``losses``.

        Intended to be called from the host's ``get_batch_loss_metrics``
        (original author's note).

        Args:
            losses: Main loss. The term is added with ``+=``, so a tensor
                argument is updated in place.
            self_supervised_labels: Per-sample class labels as produced by
                ``self_supervised_logits`` (a list of shape ``(1,)`` arrays);
                a flat sequence of integers is accepted as well.
            self_supervised_logits: Classifier logits, one row per sample.

        Returns:
            ``losses`` plus ``self_supervised_gamma`` times the cross-entropy
            between ``self_supervised_logits`` and the labels.
        """
        # The targets must live on the same device as the logits they are
        # compared with.
        targets = torch.as_tensor(np.array(self_supervised_labels), dtype=torch.long)
        targets = targets.reshape(-1).to(self_supervised_logits.device)

        losses += self.self_supervised_gamma * self.fct(self_supervised_logits, targets)
        return losses


    def extract_keywords_and_masked(self, chosen_logits, llm_tokenizer, labels, min_len=5, max_len=20):
        """Locate keyword spans inside the greedily decoded answers.

        For every sample the highest-scoring token of each position is taken
        from ``chosen_logits``, positions labelled ``-100`` are dropped and the
        remaining tokens are decoded to text. Keyword phrases are extracted
        from that text with ``self.raker``; every phrase is re-encoded and
        located among the answer tokens with ``self.align_index``.

        Args:
            chosen_logits: Logits indexed as ``[sample, position, feature]``;
                the maximum is taken over the last axis.
            llm_tokenizer: Tokenizer providing ``batch_decode`` and
                ``encode``. It is used for both directions so that decoding
                and re-encoding share one vocabulary.
            labels: Labels aligned with the first two axes of
                ``chosen_logits``; ``-100`` marks non-answer positions.
            min_len: Lower length bound forwarded to the keyword extractor.
            max_len: Upper length bound forwarded to the keyword extractor.

        Returns:
            One list per sample holding the ``(mask_start, mask_end)`` pairs
            reported by ``self.align_index``, in the order the extractor
            returned the phrases. Phrases whose start is ``None`` are skipped.
        """
        # Greedy decoding: index of the best-scoring token at every position.
        chosen_tokens = torch.max(chosen_logits, dim=-1).indices

        # Keep only the answer part of every sample.
        chosen_ans_tokens = [
            torch.masked_select(sample_tokens, sample_labels.ne(-100))
            for sample_tokens, sample_labels in zip(chosen_tokens, labels)
        ]
        chosen_txt = llm_tokenizer.batch_decode(chosen_ans_tokens)

        mask_indexs_final = []
        for answer_tokens, answer_txt in zip(chosen_ans_tokens, chosen_txt):
            # Extract keywords from the text produced by the model.
            self.raker.extract_keywords_from_sentences([answer_txt], min_len, max_len)

            mask_indexs = []
            for ranked_res in self.raker.get_ranked_phrases_with_scores():
                # Entry 1 of each result is the phrase; remove every run of
                # two spaces from it before re-encoding.
                key_words = "".join(ranked_res[1].split("  "))
                masked_tokens = llm_tokenizer.encode(key_words, add_special_tokens=False)

                # NOTE(review): alignment runs against the answer-only
                # tokens; if `align_index` returns offsets into them, they
                # are not offsets along the full sequence axis of
                # `chosen_logits` -- confirm against the callers.
                mask_start, mask_end = self.align_index(answer_tokens, masked_tokens)
                if mask_start is not None:
                    mask_indexs.append((mask_start, mask_end))

            mask_indexs_final.append(mask_indexs)

        return mask_indexs_final
