import argparse
import ast
import logging
import os
import sys
import json
import jsonlines

import pandas as pd
import torch
from tqdm import tqdm

from transformers import BartForConditionalGeneration, MT5ForConditionalGeneration, AutoTokenizer
from transformers import logging as transformers_logging


# Make modules located in the current working directory importable.
sys.path.append(os.getcwd())  # noqa: E402 # isort:skip

# Configure the root logger at INFO level, then create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Raise the transformers library's own log verbosity to INFO as well.
transformers_logging.set_verbosity_info()

# Number of prompt templates built (and scored) for every record.
_NUM_PROMPTS = 3


def _build_prompts(item):
    """Return the prompt variants evaluated for one jsonlines record.

    The English passage is used when ``passage_retrieved_language`` is
    ``'en'``; otherwise the in-language passage is used.
    """
    if item['passage_retrieved_language'] == 'en':
        passage = item['passage_en']
    else:
        passage = item['passage_in_language']
    query = item['query']

    # Templates 1 and 2 differ only by the space that follows "<P>:".
    return [
        "Passage: " + passage + " Query: " + query + " Answer: ",
        "<Q>: " + query + " <P>:" + passage,
        "<Q>: " + query + " <P>: " + passage,
    ]


def _generate_answers(model, tokenizer, device, prompts):
    """Generate one decoded answer string per prompt, in prompt order."""
    with torch.no_grad():
        inputs_dict = tokenizer.batch_encode_plus(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True
        )
        outputs = model.generate(
            inputs_dict.input_ids.to(device),
            attention_mask=inputs_dict.attention_mask.to(device),
            num_beams=4,
            min_length=1,
            max_length=20,
            early_stopping=False,
            num_return_sequences=1,
            output_scores=False,
            return_dict_in_generate=False
        )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


def _evaluate_file(path, model, tokenizer, device):
    """Score every record of one jsonlines file against its 'prediction' field.

    Returns a dict with:
        count      -- number of records read
        num        -- per-template count of exact matches
        num_easy1  -- per-template count where the reference is a substring
                      of the generated answer
        num_easy2  -- per-template count where the generated answer is a
                      substring of the reference
        num_easy3  -- per-template count where either containment holds
        pos_case   -- per-template list of exactly matching records
        neg_case   -- per-template list of non-matching records
    """
    stats = {
        'count': 0,
        'num': {i: 0 for i in range(_NUM_PROMPTS)},
        'num_easy1': {i: 0 for i in range(_NUM_PROMPTS)},
        'num_easy2': {i: 0 for i in range(_NUM_PROMPTS)},
        'num_easy3': {i: 0 for i in range(_NUM_PROMPTS)},
        'pos_case': {i: [] for i in range(_NUM_PROMPTS)},
        'neg_case': {i: [] for i in range(_NUM_PROMPTS)},
    }

    # Explicit UTF-8: the records hold multilingual text and must not depend
    # on the platform's default encoding.
    with open(path, encoding='utf-8') as f:
        for item in tqdm(jsonlines.Reader(f)):
            stats['count'] += 1
            answers = _generate_answers(
                model, tokenizer, device, _build_prompts(item)
            )
            reference = item['prediction']

            for i in range(_NUM_PROMPTS):
                answer = answers[i]
                if answer == reference:
                    stats['num'][i] += 1
                    stats['pos_case'][i].append(item)
                else:
                    stats['neg_case'][i].append(item)

                reference_in_answer = reference in answer
                answer_in_reference = answer in reference
                if reference_in_answer:
                    stats['num_easy1'][i] += 1
                if answer_in_reference:
                    stats['num_easy2'][i] += 1
                if reference_in_answer or answer_in_reference:
                    stats['num_easy3'][i] += 1
    return stats


def main(
    checkpoint='/projects/0/prjs0888/datasets/CORA-models/mGEN_model',
    dir_path='./xor_attriqa/in-language/',
    record_dir='./record',
):
    """Evaluate an MT5 checkpoint on every jsonlines file in ``dir_path``.

    For each file, prints the raw and normalized match counts per prompt
    template, then writes the matching / non-matching records of template 1
    to ``pos-<fname>`` / ``neg-<fname>`` inside ``record_dir``.

    Args:
        checkpoint: Path or identifier passed to ``from_pretrained`` for both
            the model and the tokenizer.
        dir_path: Directory holding the input jsonlines files. Sibling
            variants referenced by the original script include
            './xor_attriqa/in-english/', './xor_attriqa/concat-in-language/'
            and './xor_attriqa/concat-in-english/'.
        record_dir: Directory receiving the pos-/neg- record files; created
            if it does not exist.
    """
    model_class = MT5ForConditionalGeneration
    logger.info("Evaluate the following checkpoints: %s", checkpoint)

    model = model_class.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, local_files_only=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Create the output directory up front so a missing directory cannot
    # fail the run only after the whole inference pass has finished.
    os.makedirs(record_dir, exist_ok=True)

    # Index of the prompt template whose records are written to disk.
    recorded_template = 1

    # Sorted so that the order of the printed report is reproducible.
    for fname in sorted(os.listdir(dir_path)):
        src_path = os.path.join(dir_path, fname)
        if not os.path.isfile(src_path):
            # os.listdir also returns sub-directories, which cannot be read.
            continue

        print('*****************************')
        print(fname)

        stats = _evaluate_file(src_path, model, tokenizer, device)
        count = stats['count']
        counters = [
            stats['num'],
            stats['num_easy1'],
            stats['num_easy2'],
            stats['num_easy3'],
        ]

        for counter in counters:
            print(counter)
        print(count)

        if count == 0:
            # An empty file has no ratios to report; avoid dividing by zero.
            logger.warning("No records found in %s; skipping.", src_path)
            print()
            continue

        for counter in counters:
            print({i: hits / count for i, hits in counter.items()})
        print()

        pos_path = os.path.join(record_dir, 'pos-' + fname)
        with jsonlines.open(pos_path, 'w') as writer:
            for record in stats['pos_case'][recorded_template]:
                writer.write(record)

        neg_path = os.path.join(record_dir, 'neg-' + fname)
        with jsonlines.open(neg_path, 'w') as writer:
            for record in stats['neg_case'][recorded_template]:
                writer.write(record)

# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
