import argparse
import yaml
import torch
from attrdict import AttrDict
import pandas as pd

from src.utils import set_seed, load_params_LLM
from src.loader import MyDataLoader
from src.model import LLMBackbone
from src.engine import PromptTrainer

import time


class Template:
    """Drive one experiment run: build the config, then evaluate or fine-tune.

    The config is loaded from the YAML file at ``args.config``; every
    command-line argument is then copied onto it, so CLI values override
    YAML entries of the same name.
    """

    # Datasets served by MyDataLoader / LLMBackbone / PromptTrainer.
    SEMEVAL_DATASETS = ('laptops', 'restaurants')
    # Reasoning modes that use the one-step prompt pipeline.
    PROMPT_MODES = ('prompt', 'prompt_twice', 'prompt_few_shot')
    # Label names assigned to config.label_list for the AG News dataset.
    AGNEWS_LABELS = ('world', 'sports', 'business', 'sci/tech')

    def __init__(self, args):
        """Build ``self.config`` from the YAML file and the parsed CLI args.

        Args:
            args: parsed ``argparse.Namespace``; must carry at least
                ``config`` (YAML path), ``data_name``, ``seed`` and
                ``cuda_index``. ``model_size`` must come from either the
                YAML file or ``args``.

        Side effects:
            Prints the loaded config and seeds RNGs via ``set_seed``.
        """
        # Use a context manager so the YAML file handle is always closed.
        with open(args.config, 'r', encoding='utf-8') as config_file:
            # An empty YAML file loads as None; fall back to an empty mapping.
            config = AttrDict(yaml.load(config_file, Loader=yaml.FullLoader) or {})
        print('config: ')
        print(config)

        # CLI arguments (including prompt_file) override YAML values.
        for key, value in vars(args).items():
            setattr(config, key, value)
        config.dataname = config.data_name
        if config.dataname == 'agnews':
            config.label_list = list(self.AGNEWS_LABELS)

        print('config.seed: ', config.seed)
        set_seed(config.seed)

        config.device = torch.device(
            'cuda:{}'.format(config.cuda_index) if torch.cuda.is_available() else 'cpu'
        )
        # The trailing '{}' is left as a str.format placeholder, presumably
        # filled in later when a checkpoint is saved -- confirm with the trainer.
        config.save_name = '_'.join(map(str, (config.model_size, config.dataname))) + '_{}.pth.tar'
        self.config = config

    def forward(self):
        """Load data, model and trainer, then run zero-shot eval or fine-tuning.

        In zero-shot mode (``config.zero_shot`` truthy) the test loader is
        evaluated once and the result and wall-clock time are printed.
        Otherwise the trainer is trained and its ``lines`` are printed as a
        DataFrame.

        Raises:
            ValueError: if ``config.reasoning`` is not a supported prompt mode,
                or no data loader exists for the (dataname, reasoning) pair.
        """
        # Validate the mode before loading data and the (expensive) model.
        if self.config.reasoning not in self.PROMPT_MODES:
            raise ValueError('Should choose a correct reasoning mode: prompt.')

        self._load_data()
        self._build_model()

        print(f"Running on the {self.config.data_name} data.")
        print("Choosing prompt one-step infer mode.")
        trainer = self._build_trainer()

        if self.config.zero_shot:
            print("Zero-shot mode for evaluation.")
            start = time.time()
            result = trainer.evaluate_step(self.testLoader, 'test')
            print(result)
            print('-' * 77)
            print('Time took: ', time.time() - start)
            print('-' * 77)
            return

        print("Fine-tuning mode for training.")
        trainer.train()
        print(pd.DataFrame(trainer.lines).to_string())

    def _load_data(self):
        """Create the data loaders for the configured dataset and reasoning mode.

        The loader returns an updated config, which replaces ``self.config``.

        Raises:
            ValueError: if no loader exists for the (dataname, reasoning) pair.
        """
        dataname = self.config.dataname
        few_shot = self.config.reasoning == 'prompt_few_shot'
        if dataname in self.SEMEVAL_DATASETS:
            data_loader = MyDataLoader(self.config)
            loaders, self.config = data_loader.get_data_few_shot() if few_shot else data_loader.get_data()
            self.trainLoader, self.validLoader, self.testLoader = loaders
        elif dataname == 'agnews' and few_shot:
            # NOTE(review): MyDataLoader_agnews is not among this file's visible
            # imports; confirm it is defined/imported somewhere in scope.
            (self.trainLoader, self.testLoader), self.config = MyDataLoader_agnews(self.config).get_data_few_shot()
        else:
            raise ValueError(
                f"No data loader for dataset {dataname!r} with reasoning mode "
                f"{self.config.reasoning!r}."
            )

    def _build_model(self):
        """Instantiate the backbone on ``config.device`` and refresh the config via load_params_LLM."""
        if self.config.dataname in self.SEMEVAL_DATASETS:
            self.model = LLMBackbone(config=self.config).to(self.config.device)
        else:
            # Only 'agnews' reaches here (_load_data rejects anything else).
            # NOTE(review): LLMBackbone_agnews is not among this file's visible imports.
            self.model = LLMBackbone_agnews(config=self.config).to(self.config.device)
        self.config = load_params_LLM(self.config, self.model, self.trainLoader)

    def _build_trainer(self):
        """Return the prompt trainer matching the configured dataset."""
        if self.config.dataname in self.SEMEVAL_DATASETS:
            return PromptTrainer(self.model, self.config, self.trainLoader, self.validLoader, self.testLoader)
        # Only 'agnews' reaches here; its loader provides no validation split.
        # NOTE(review): PromptTrainer_agnews is not among this file's visible imports.
        return PromptTrainer_agnews(self.model, self.config, self.trainLoader, self.testLoader)


def build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--cuda_index', default=0, type=int,
                        help='index of the CUDA device to use when CUDA is available')
    parser.add_argument('-r', '--reasoning', default='prompt_few_shot', choices=['prompt', 'prompt_few_shot'],
                        help='with one-step prompt or multi-step reasoning')
    parser.add_argument('-z', '--zero_shot', action='store_true', default=True,
                        help='run in zero-shot evaluation mode (the default)')
    # '-z' defaults to True, so without this switch the fine-tuning path
    # could never be selected from the command line.
    parser.add_argument('--fine_tune', dest='zero_shot', action='store_false',
                        help='run fine-tuning instead of zero-shot evaluation')
    parser.add_argument('-d', '--data_name', default='laptops', choices=['restaurants', 'laptops', 'agnews'],
                        help='semeval data name')
    parser.add_argument('-f', '--config', default='./config/config.yaml', help='config file')
    parser.add_argument('--seed', default=42, help='seed', type=int)
    parser.add_argument('--add_phrase', default=None, type=str)
    parser.add_argument('--few_shot_example_indices',
                        type=lambda s: [int(item) for item in s.split(',')],
                        default=None,
                        help='comma-separated indices of the few-shot examples, e.g. 0,3,5')
    parser.add_argument('--prompt_file', default='./src/few_shot_prompts.txt', type=str)
    return parser


def main():
    """Parse arguments, echo the prompt file, and run the experiment."""
    parser = build_parser()
    print('parser: ', parser)
    args = parser.parse_args()
    print('args: ', args)
    template = Template(args)

    print('=' * 77)
    print('prompt used: ')
    # Read with the same explicit encoding used for the YAML config.
    with open(args.prompt_file, encoding='utf-8') as prompt_file:
        print(prompt_file.read())
    print('=' * 77)
    template.forward()


if __name__ == '__main__':
    main()
