import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,T5ForConditionalGeneration, pipeline
from huggingface_hub import login
import sys

class LLMBackbone(nn.Module):
    """Wrap a Hugging Face LLM (Llama or T5 family) for sentiment classification.

    The concrete model is chosen from ``config.model_path``:

    * a path containing ``llama`` loads an 8-bit ``AutoModelForCausalLM``;
    * a path containing ``t5`` loads a ``T5ForConditionalGeneration``
      (8-bit when the path also contains ``xxl``).

    Attributes read from ``config``: ``model_path``, ``max_length`` (used by
    ``generate``) and ``label_list`` (used by ``evaluate``).

    Running statistics kept across ``evaluate`` calls:
    ``num_extraction_error`` (answers that matched no known phrase),
    ``new_answers`` (answers containing ``'positive than neutral'``) and
    ``unique_answers`` (cleaned answer text -> occurrence count).
    """

    def __init__(self, config):
        """Load the engine and tokenizer selected by ``config.model_path``.

        Raises:
            ValueError: if ``config.model_path`` names neither a Llama nor a
                T5 model, so that the failure is reported here instead of as
                a missing ``engine`` attribute on first use.
        """
        super().__init__()
        self.config = config

        # SECURITY: an access token used to be hard-coded here (and was never
        # used). Secrets must not be committed to source; supply credentials
        # for gated models via the HF_TOKEN environment variable or
        # `huggingface-cli login`, and revoke the token that was exposed.
        model_path = self.config.model_path
        print('='*77)
        print('model_path: ', model_path)
        print('='*77)

        model_key = model_path.lower()
        if 'llama' in model_key:
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
            self.engine = AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=bnb_config,
                device_map="auto",
                torch_dtype=torch.float16,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            # The EOS token doubles as the pad token, and padding goes on the
            # left side of the prompt.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "left"
        elif 't5' in model_key:
            if 'xxl' in model_key:
                # NOTE(review): the XXL branch always loads the hub checkpoint
                # "google/flan-t5-xxl" rather than config.model_path, while
                # the tokenizer below comes from config.model_path -- confirm
                # this is intended.
                # 8-bit loading goes through quantization_config, matching the
                # Llama branch, instead of the bare load_in_8bit kwarg.
                self.engine = T5ForConditionalGeneration.from_pretrained(
                    "google/flan-t5-xxl",
                    device_map="auto",
                    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                )
            else:
                self.engine = T5ForConditionalGeneration.from_pretrained(model_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        else:
            raise ValueError(
                "Unsupported model_path %r: expected it to contain 'llama' or 't5'"
                % (model_path,)
            )

        self.num_extraction_error = 0
        self.new_answers = 0
        self.unique_answers = {}

    def forward(self, **kwargs):
        """Return the engine's loss for one batch.

        Expected keyword arguments: ``input_ids``, ``input_masks``,
        ``output_ids`` and ``output_masks`` (a missing key raises
        ``KeyError``).

        Pad positions in ``output_ids`` are replaced by ``-100`` before being
        passed as ``labels``. The replacement is done on a copy, so the
        caller's ``output_ids`` tensor is left untouched.
        """
        input_ids = kwargs['input_ids']
        input_masks = kwargs['input_masks']
        output_ids = kwargs['output_ids']
        output_masks = kwargs['output_masks']

        # masked_fill returns a new tensor; the original code assigned into
        # output_ids in place and thereby modified the caller's batch.
        labels = output_ids.masked_fill(output_ids == self.tokenizer.pad_token_id, -100)
        output = self.engine(input_ids, attention_mask=input_masks, decoder_input_ids=None,
                             decoder_attention_mask=output_masks, labels=labels)
        loss = output[0]
        return loss

    def generate(self, **kwargs):
        """Generate text for a batch and return it as a list of strings.

        Expected keyword arguments: ``input_ids`` and ``input_masks``.
        Generation is bounded by ``config.max_length``; the literal
        ``<pad>`` and ``</s>`` markers are removed from each decoded string,
        which is then stripped of surrounding whitespace.
        """
        input_ids = kwargs['input_ids']
        input_masks = kwargs['input_masks']
        generated = self.engine.generate(input_ids=input_ids, attention_mask=input_masks,
                                         max_length=self.config.max_length)
        decoded = [self.tokenizer.decode(ids) for ids in generated]
        return [text.replace('<pad>', '').replace('</s>', '').strip() for text in decoded]

    @staticmethod
    def _build_phrase_lists(target):
        """Return ``(label, phrases)`` pairs for ``target`` in matching priority order.

        Matching is substring based, so the bare polarity word at the head of
        each list already covers every longer phrase containing it; the longer
        phrases are retained as a record of the answer formats seen so far.
        """
        phrase_list_neutral = ['neutral',
                               'mixed',
                               'neutral than',
                               '**neutral**',
                               'answer: neutral',
                               'answer is: neutral',
                               'answer: \nneutral',
                               'answer:\nneutral',
                               'answer to: neutral',
                               'a5',
                               'a6',
                               f'sentiment towards the target "{target}" is neutral.',
                               f'sentiment polarity towards the target "{target}" is neutral',
                               f'sentiment polarity is neutral towards the target "{target}"',
                               f'sentiment polarity towards the target "{target}" is **neutral**',
                               f'sentiment polarity towards the target ({target}) is neutral.',
                               f'sentiment polarity towards the {target} is neutral.',
                               'therefore, the sentiment polarity is neutral.',
                               '**neutral**',
                               'therefore, the sentiment polarity is neutral',
                               f'neutral sentiment towards the target ({target})',
                               f'neutral sentiment towards the target ({target})',
                               f'sentiment towards the target "{target}" is also neutral.',
                               'neutral sentiment towards the target',
                               ]
        phrase_list_positive = ['positive',
                                'positive than',
                                '**positive**',
                                'strongly positive',
                                'slightly positive',
                                'answer: positive',
                                'answer: strongly positive',
                                'answer: slightly positive',
                                'answer: \npositive',
                                'answer:\npositive',
                                'answer to: positive',
                                'answer to: slightly positive',
                                'answer is: positive',
                                'a1',
                                'a2',
                                f'sentiment towards the target "{target}" is positive.',
                                f'sentiment polarity towards the target "{target}" is positive',
                                f'sentiment polarity is positive towards the target "{target}".',
                                f'sentiment polarity towards the target "{target}" is **positive**',
                                f'sentiment polarity towards the target ({target}) is positive.',
                                f'sentiment polarity towards the {target} is positive.',
                                'therefore, the sentiment polarity is positive.',
                                '**positive**',
                                'therefore, the sentiment polarity is positive',
                                f'positive sentiment towards the target ({target})',
                                f'positive sentiment towards the target ({target})',
                                f'sentiment towards the target "{target}" is also positive.',
                                'positive sentiment towards the target',
                                ]
        phrase_list_negative = ['negative',
                                'negative than',
                                '**negative**',
                                'strongly negative',
                                'slightly negative',
                                'answer: negative',
                                'answer: strongly negative',
                                'answer: slightly negative',
                                'answer: \nnegative',
                                'answer:\nnegative',
                                'answer to: negative',
                                'answer to: slightly negative',
                                'answer is: negative',
                                'a3',
                                'a4',
                                f'sentiment towards the target "{target}" is negative.',
                                f'sentiment polarity towards the target "{target}" is negative',
                                f'sentiment polarity is negative towards the target "{target}".',
                                f'sentiment polarity towards the target "{target}" is **negative**',
                                f'sentiment polarity towards the target ({target}) is negative.',
                                f'sentiment polarity towards the {target} is negative.',
                                'therefore, the sentiment polarity is negative.',
                                '**negative**',
                                'therefore, the sentiment polarity is negative',
                                f'negative sentiment towards the target ({target})',
                                # A comma was missing after the next entry, which
                                # silently fused it with the one that follows.
                                f'negative sentiment towards the {target}',
                                f'sentiment towards the target "{target}" is also negative.',
                                'negative sentiment towards the target',
                                ]
        # Order matters: neutral is tried first, then positive, then negative.
        return (('neutral', phrase_list_neutral),
                ('positive', phrase_list_positive),
                ('negative', phrase_list_negative))

    @staticmethod
    def _match_polarity(fs, phrase_lists):
        """Return the first label whose phrase list has a substring hit in ``fs``, else ``None``."""
        for label, phrases in phrase_lists:
            if any(phrase in fs for phrase in phrases):
                return label
        return None

    def evaluate(self, **kwargs):
        """Generate short answers for a batch and map them to label indices.

        Expected keyword arguments: ``input_ids``, ``input_masks``,
        ``target`` (one target string per example) and ``input_labels``
        (one true label per example).

        Each decoded answer is cleaned (``<pad>``/``</s>`` removed, stripped,
        lower-cased), truncated to the text before the first ``'. '`` and
        before the first ``'than'``, and matched against the neutral, positive
        and negative phrase lists in that order. Answers matching nothing are
        counted in ``num_extraction_error`` and predicted as ``'neutral'``.
        Labels are converted to indices through ``config.label_list``; a label
        absent from that list maps to ``0``.

        Returns:
            A 4-tuple ``(output, x_analysis, converted_pred_analysis,
            true_answer)``: predicted label indices, cleaned answer strings,
            a second list with the same predicted indices, and the true
            labels converted with ``int``.
        """
        input_ids = kwargs['input_ids']
        input_masks = kwargs['input_masks']
        print('len(input_ids[0]): ', len(input_ids[0]))

        generated = self.engine.generate(input_ids=input_ids, attention_mask=input_masks, max_new_tokens=5)
        dec = [self.tokenizer.decode(ids) for ids in generated]
        label_dict = {w: i for i, w in enumerate(self.config.label_list)}

        targets_ = [target_word.lower() for target_word in kwargs['target']]
        true_labels = kwargs['input_labels']

        output = []
        x_analysis = []
        converted_pred_analysis = []
        true_answer = []

        for ii, w in enumerate(dec):
            if ii % 10 == 0:
                print('='*77)
                print('new_answers: ', self.new_answers)
                print('self.unique_answers: ', self.unique_answers)
                print('='*77)
            x = w.replace('<pad>', '').replace('</s>', '').strip().lower()
            self.unique_answers[x] = self.unique_answers.get(x, 0) + 1

            x_analysis.append(x)
            true_answer.append(int(true_labels[ii]))
            ## Flan-T5
            if 'positive than neutral' in x:
                print('='*77)
                print('x: ', x)
                print('='*77)
                self.new_answers += 1
            fs = x.split('. ')[0].split('than')[0]
            ## Llama (alternative extractions kept for reference)
            # fs = x.split('verifiable evidence in your reasoning')[1].split('answer:')[1].split('than')[0]
            # fs = x.split('verifiable evidence')[1].split('answer:')[1].split('than')[0].split('feedback')[0].split('reasoning')[0]
            # fs = x.split('test 4')[1].split('test 5:')[0].split('feedback')[0].split('reasoning')[0].split('is:')[1].split('than')[0]

            phrase_lists = self._build_phrase_lists(targets_[ii])
            pred = self._match_polarity(fs, phrase_lists)

            if pred is None:
                print('='*77)
                print('Error in answer extraction!')
                # Unparseable answers fall back to the neutral class.
                pred = 'neutral'
                print('fs: ')
                print(fs)
                self.num_extraction_error += 1
                print('='*77)
                print('phrase_list_positive: ', phrase_lists[1][1])

            pred_index = label_dict.get(pred, 0)
            output.append(pred_index)
            converted_pred_analysis.append(pred_index)

            if true_labels[ii] != pred_index:
                print('='*77)
                print('Error analysis')
                print('True label: ', true_labels[ii])
                print('Pred: ', pred)
                print('fs:')
                print(fs)
        print('self.num_extraction_error: ', self.num_extraction_error)
        return output, x_analysis, converted_pred_analysis, true_answer
