import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,T5ForConditionalGeneration, pipeline
from huggingface_hub import login
import sys

class LLMBackbone_agnews(nn.Module):
    """Hugging Face LLM wrapper used for AG News topic classification.

    Depending on ``config.model_path`` the wrapper loads either a causal
    Llama-style model (8-bit quantized) or a T5 encoder-decoder model, together
    with the matching tokenizer.

    The ``config`` object is expected to provide:
        * ``model_path``  -- checkpoint name or path; must contain ``"llama"``
          or ``"t5"`` (case-insensitive).
        * ``max_length``  -- generation length limit used by :meth:`generate`.
        * ``label_list``  -- label names; a label's index in this list is the
          class id returned by :meth:`evaluate`.

    Attributes:
        engine: the loaded Hugging Face model.
        tokenizer: the tokenizer loaded from the same checkpoint.
        num_extraction_error: running count of generations from which no
            label keyword could be extracted (accumulates across calls).
    """

    # Ordered keyword table for answer extraction: the FIRST label whose phrase
    # occurs in the generated text wins, so the order below is significant.
    _LABEL_PHRASES = {
        'world': ['world'],
        'business': ['business'],
        'sports': ['sports'],
        'sci/tech': ['sci/tech', 'tech/sci', 'science and technology',
                     'science/technology', 'technology/science', 'sci'],
    }
    # Label assumed when no keyword is found in the generated text.
    _FALLBACK_LABEL = 'world'

    def __init__(self, config):
        """Load the model and tokenizer selected by ``config.model_path``.

        Raises:
            ValueError: if ``config.model_path`` names neither a Llama nor a
                T5 checkpoint.
        """
        super().__init__()
        self.config = config

        # SECURITY: a Hugging Face access token used to be hard-coded here. It
        # was never used by this class and has been removed; the exposed token
        # should be revoked. Authenticate through `huggingface-cli login` or
        # the HF_TOKEN environment variable instead of committing secrets.
        model_path = self.config.model_path
        print('='*77)
        print('model_path: ', model_path)
        print('='*77)

        lowered_path = model_path.lower()
        if 'llama' in lowered_path:
            self.engine = AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                device_map="auto",
                torch_dtype=torch.float16,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            # The EOS token doubles as the padding token, and prompts are
            # padded on the left so generation continues from the prompt end.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "left"
        elif 't5' in lowered_path:
            if 'xxl' in lowered_path:
                # Load the configured checkpoint (not a hard-coded one) so the
                # model and tokenizer always come from the same place, and
                # request 8-bit weights the same way the Llama branch does.
                self.engine = T5ForConditionalGeneration.from_pretrained(
                    model_path,
                    device_map="auto",
                    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                )
            else:
                self.engine = T5ForConditionalGeneration.from_pretrained(model_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        else:
            # Fail early instead of leaving `engine`/`tokenizer` undefined.
            raise ValueError(
                "Unsupported model_path %r: expected a name containing "
                "'llama' or 't5'." % (model_path,)
            )

        self.num_extraction_error = 0

    @classmethod
    def _extract_label(cls, text):
        """Return the first label whose keyword occurs in ``text``, else None."""
        for label, phrases in cls._LABEL_PHRASES.items():
            if any(phrase in text for phrase in phrases):
                return label
        return None

    def forward(self, **kwargs):
        """Compute the training loss for one batch.

        Keyword Args:
            input_ids: encoder/prompt token ids.
            input_masks: attention mask for ``input_ids``.
            output_ids: target token ids; padding positions are excluded from
                the loss. The tensor passed in is NOT modified.
            output_masks: attention mask for ``output_ids``.

        Returns:
            The first element of the model output (the loss when ``labels``
            are supplied).
        """
        input_ids = kwargs['input_ids']
        input_masks = kwargs['input_masks']
        output_ids = kwargs['output_ids']
        output_masks = kwargs['output_masks']

        # Replace padding with -100 (the label value the loss ignores) on a
        # copy, so the caller's tensor stays usable afterwards.
        pad_id = self.tokenizer.pad_token_id
        labels = output_ids
        if pad_id is not None:
            labels = output_ids.masked_fill(output_ids == pad_id, -100)

        # NOTE(review): `decoder_input_ids` / `decoder_attention_mask` are
        # encoder-decoder arguments; presumably this path is only exercised
        # with the T5 models -- confirm before training a causal LM with it.
        output = self.engine(input_ids, attention_mask=input_masks, decoder_input_ids=None,
                             decoder_attention_mask=output_masks, labels=labels)
        return output[0]

    def generate(self, **kwargs):
        """Generate text for a batch and return the cleaned, decoded strings.

        Keyword Args:
            input_ids: prompt token ids.
            input_masks: attention mask for ``input_ids``.

        Returns:
            A list with one decoded string per generated sequence, with the
            ``<pad>`` and ``</s>`` markers removed and whitespace stripped.
        """
        input_ids = kwargs['input_ids']
        input_masks = kwargs['input_masks']
        generated = self.engine.generate(input_ids=input_ids, attention_mask=input_masks,
                                         max_length=self.config.max_length)
        decoded = [self.tokenizer.decode(ids) for ids in generated]
        return [text.replace('<pad>', '').replace('</s>', '').strip() for text in decoded]

    def evaluate(self, **kwargs):
        """Predict a class id for every sample in the batch.

        Up to five new tokens are generated per sample; the decoded text is
        lower-cased, cut at the first occurrence of ``'than'`` and searched for
        label keywords in the order world, business, sports, sci/tech. When no
        keyword is found the sample is counted in ``num_extraction_error`` and
        ``'world'`` is assumed. Misclassified samples are printed for error
        analysis.

        Keyword Args:
            input_ids: prompt token ids.
            input_masks: attention mask for ``input_ids``.
            input_labels: ground-truth class ids, indexable per sample.

        Returns:
            A list of predicted class ids (indices into ``config.label_list``;
            0 when the predicted label is missing from that list).
        """
        input_ids = kwargs['input_ids']
        input_masks = kwargs['input_masks']
        true_labels = kwargs['input_labels']
        print('len(input_ids[0]): ', len(input_ids[0]))

        generated = self.engine.generate(input_ids=input_ids, attention_mask=input_masks,
                                         max_new_tokens=5)
        # NOTE(review): for decoder-only models the generated sequence usually
        # still contains the prompt, so the keyword search below would also
        # see prompt text -- confirm this is intended for the Llama path.
        decoded = [self.tokenizer.decode(ids) for ids in generated]
        label_dict = {name: idx for idx, name in enumerate(self.config.label_list)}

        predictions = []
        for ii, raw_text in enumerate(decoded):
            cleaned = raw_text.replace('<pad>', '').replace('</s>', '').strip().lower()
            # Only the text before the first 'than' is searched for keywords.
            fs = cleaned.split('than')[0]

            pred = self._extract_label(fs)
            if pred is None:
                print('='*77)
                print('Error in answer extraction!')
                pred = self._FALLBACK_LABEL
                print('fs: ')
                print(fs)
                self.num_extraction_error += 1
                print('='*77)
                print('phrase_list_world: ', self._LABEL_PHRASES['world'])

            pred_id = label_dict.get(pred, 0)
            predictions.append(pred_id)

            if true_labels[ii] != pred_id:
                print('='*77)
                print('Error analysis')
                print('True label: ', true_labels[ii])
                print('Pred: ', pred)
                print('fs:')
                print(fs)

        print('self.num_extraction_error: ', self.num_extraction_error)
        return predictions
