import argparse
import ast
import logging
import os
import sys
import json
import jsonlines

import pandas as pd
import torch
from tqdm import tqdm

from transformers import BartForConditionalGeneration, MT5ForConditionalGeneration, AutoTokenizer
from transformers import logging as transformers_logging


# Make modules that live in the current working directory importable.
sys.path.append(os.getcwd())  # noqa: E402 # isort:skip

# Root logging goes to INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Also surface INFO-level messages emitted by the transformers library.
transformers_logging.set_verbosity_info()

def _build_input_variants(query, passage):
    """Return the seven query/passage concatenations that are evaluated.

    The order of the returned list is the order in which ``main`` prints
    the per-variant results.
    """
    return [
        query + passage,
        query + ' ' + passage,
        query + passage + '.',
        query + ' ' + passage + '.',
        passage + query,
        passage + ' ' + query,
        passage + '. ' + query,
    ]


def _count_matches(answers, references):
    """Count positions at which the generated answer equals the reference string."""
    return sum(1 for answer, reference in zip(answers, references) if answer == reference)


def main():
    """Compare generated answers with stored predictions for each input format.

    Loads the MT5 checkpoint from a hard-coded path, then, for every file in
    ``./xor_attriqa/in-language/`` (read as JSON lines), builds seven
    different concatenations of ``item['query']`` and
    ``item['passage_in_language']``, generates an answer for each, and prints
    how many generated answers are string-identical to ``item['prediction']``
    (match count, example count, ratio, blank line).
    """
    model_class = MT5ForConditionalGeneration

    checkpoint = '/scratch/p313030/datasets/CORA-models/mGEN_model'

    logger.info("Evaluate the following checkpoints: %s", checkpoint)

    model = model_class.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, local_files_only=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    dir_path = './xor_attriqa/in-language/'

    # Sorted so that the report order is reproducible between runs.
    for fname in sorted(os.listdir(dir_path)):
        path = os.path.join(dir_path, fname)
        if not os.path.isfile(path):
            # Sub-directories cannot be opened as data files.
            continue

        print('*****************************')
        print(fname)

        rows = []
        prediction_list = []
        # Explicit UTF-8: the data is multilingual and must not depend on the
        # platform's default locale encoding.
        with open(path, encoding='utf-8') as f:
            for item in jsonlines.Reader(f):
                rows.append(_build_input_variants(item['query'], item['passage_in_language']))
                prediction_list.append(item['prediction'])

        count = len(rows)
        if count == 0:
            # Nothing to tokenize, and the ratio below would divide by zero.
            logger.warning("No examples found in %s; skipping.", path)
            continue

        # Transpose rows (one per example) into one list per input variant.
        data_lists = [list(column) for column in zip(*rows)]

        for data_list in data_lists:
            # NOTE: the whole file is encoded and generated as a single batch,
            # as in the original script.
            with torch.no_grad():
                inputs_dict = tokenizer.batch_encode_plus(
                    data_list,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                )

                input_ids = inputs_dict.input_ids.to(device)
                attention_mask = inputs_dict.attention_mask.to(device)
                outputs = model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    num_beams=4,
                    min_length=1,
                    max_length=20,
                    early_stopping=False,
                    num_return_sequences=1,
                    # Forbid emitting token id 0 twice in a row. This setting was
                    # written for BART (repeated BOS tokens).
                    # NOTE(review): the model here is MT5 - confirm it is still wanted.
                    bad_words_ids=[[0, 0]],
                    output_scores=False,
                    return_dict_in_generate=False,
                )

            answers = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            num_same = _count_matches(answers, prediction_list)
            print(num_same)
            print(count)
            print(num_same / count)
            print()

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
