import os
import torch
import numpy as np
import torch.nn as nn

from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from collections import defaultdict
from src.utils import prompt_for_opinion_inferring, prompt_for_polarity_inferring, prompt_for_polarity_label
import json


class PromptTrainer_agnews:
    """Evaluation driver for a 4-class classifier (labels 0-3).

    The class accumulates predictions and gold labels batch by batch under
    the ``'total'`` key, scores them with accuracy and macro-F1, and keeps a
    history of final evaluation results in ``self.lines``.

    The printed summaries name the labels 0, 1, 2, 3 as World, Sports,
    Business and Scitech respectively.
    """

    # Keys tried, in order, when ranking stored results in ``get_best``.
    # ``report_score_agnews`` never emits 'default', so the metric keys it
    # does emit are used as fallbacks.
    _SELECTION_KEYS = ('default', 'F1_total', 'Acc_total')

    def __init__(self, model, config, train_loader, test_loader) -> None:
        """Store the collaborators and initialise empty accumulators.

        Args:
            model: Object exposing ``eval()``, ``evaluate(**batch)`` and
                ``load_state_dict``.
            config: Object with ``target_dir``, ``save_name`` and ``device``
                attributes.
            train_loader: Stored as-is; not used by the methods of this class.
            test_loader: Loader used by ``final_evaluate``.
        """
        self.model = model
        self.config = config
        self.train_loader, self.test_loader = train_loader, test_loader
        # ``save_name`` is later passed through ``str.format(epoch)``.
        self.save_name = os.path.join(config.target_dir, config.save_name)
        self.final_score = 0
        self.final_res = ''

        self.scores, self.lines = [], []
        self.re_init()

    @staticmethod
    def _to_list(values):
        """Return ``values`` as a plain Python list.

        Objects exposing ``tolist()`` (tensors, arrays) are converted through
        it so that the elements are Python scalars rather than 0-d tensors.
        """
        if hasattr(values, 'tolist'):
            converted = values.tolist()
            # ``tolist()`` on a 0-d tensor/array yields a bare scalar.
            return converted if isinstance(converted, list) else [converted]
        return list(values)

    def evaluate_step(self, dataLoader=None, mode='valid', report_every=1):
        """Run the model over ``dataLoader`` and return the final scores.

        Args:
            dataLoader: Iterable of batch dicts exposing ``data_length``.
                When ``None``, ``self.test_loader`` is used.
            mode: Tag copied into the result dict under ``'mode'``.
            report_every: Print the running scores every this many batches.
                The default of 1 prints after every batch; 0 or ``None``
                disables the running report.

        Returns:
            The dict produced by ``report_score_agnews`` over all batches.
        """
        if dataLoader is None:
            dataLoader = self.test_loader

        # Start from empty accumulators so that results of an earlier
        # evaluation on this trainer do not leak into this one.
        self.re_init()

        self.model.eval()
        print('dataLoader.data_length: ', dataLoader.data_length)
        for i, data in tqdm(enumerate(dataLoader), total=dataLoader.data_length):
            with torch.no_grad():
                output = self.model.evaluate(**data)
                self.add_output_agnews(data, output)
            if report_every and (i + 1) % report_every == 0:
                # Running score over everything accumulated so far.
                result = self.report_score_agnews(mode=mode)
                print('='*77)
                print(result)
                print('='*77)

        result = self.report_score_agnews(mode=mode)
        golds, preds = self.golds['total'], self.preds['total']
        print('-'*77)
        print(f"Gold Labels: World: {golds.count(0)}, Sports: {golds.count(1)}, Business: {golds.count(2)}, Scitech: {golds.count(3)}")
        print(f"Pred. Labels: World: {preds.count(0)}, Sports: {preds.count(1)}, Business: {preds.count(2)}, Scitech: {preds.count(3)}")
        print('-'*77)

        return result

    def final_evaluate(self, epoch=0):
        """Load the checkpoint for ``epoch``, score the test loader, record it.

        The checkpoint path is ``self.save_name.format(epoch)`` and the file
        must contain the model state dict under the ``'model'`` key.

        Returns:
            The result dict, which is also appended to ``self.lines``.
        """
        PATH = self.save_name.format(epoch)
        checkpoint = torch.load(PATH, map_location=self.config.device)
        self.model.load_state_dict(checkpoint['model'])
        self.model.eval()
        res = self.evaluate_step(self.test_loader, mode='test')
        self.add_instance(res)
        return res

    def add_instance(self, res):
        """Append a result dict to the history used by ``get_best``."""
        self.lines.append(res)

    def get_best(self):
        """Return the stored result with the highest selection score.

        Each result is ranked by the first key of ``_SELECTION_KEYS`` it
        contains ('default', then 'F1_total', then 'Acc_total').

        Raises:
            ValueError: If no results have been recorded.
            KeyError: If a stored result has none of the selection keys.
        """
        if not self.lines:
            raise ValueError('get_best() called with no recorded results')

        def selection_score(res):
            for key in self._SELECTION_KEYS:
                if key in res:
                    return res[key]
            raise KeyError(
                f'result has none of the keys {self._SELECTION_KEYS}: {res}'
            )

        best_id = int(np.argmax([selection_score(w) for w in self.lines]))
        return self.lines[best_id]

    def re_init(self):
        """Reset the prediction/gold accumulators."""
        self.preds, self.golds = defaultdict(list), defaultdict(list)
        self.keys = ['total']

    def add_output_agnews(self, data, output):
        """Append one batch of predictions and gold labels.

        Args:
            data: Batch dict; gold labels are read from ``'input_labels'``.
            output: Predicted labels for the batch (list, tensor or array).
        """
        self.preds['total'] += self._to_list(output)
        self.golds['total'] += self._to_list(data['input_labels'])

    def report_score_agnews(self, mode='valid'):
        """Score the accumulated predictions.

        Returns:
            Dict with ``'Acc_total'`` and ``'F1_total'`` (macro-F1 over
            labels 0-3), both as percentages rounded to 3 decimals, plus
            ``'mode'``.
        """
        golds, preds = self.golds['total'], self.preds['total']
        res = {
            'Acc_total': accuracy_score(golds, preds),
            'F1_total': f1_score(golds, preds, labels=[0, 1, 2, 3], average='macro'),
            'mode': mode,
        }
        # Only the float metrics are scaled; 'mode' is passed through.
        return {
            k: round(v * 100, 3) if isinstance(v, float) else v
            for k, v in res.items()
        }