"""
    Optionally fine-tune the specified transformer, then gather its results on the inability NLI diagnostic.
    See the argparser below for details on the command line arguments
"""
from posixpath import basename, split
import pandas as pd
from simpletransformers.classification import ClassificationModel
import torch
from sys import argv
from os import path, listdir
import numpy as np
import argparse

# Gold-label strings -> integer class ids used for training and evaluation.
_LABEL_MAP = {'contradiction': 0, 'neutral': 1, 'entailment': 2}
# Column names in the jsonl files -> column names simpletransformers expects.
_COLUMN_MAP = {'gold_label': 'labels', 'sentence1': 'text_a', 'sentence2': 'text_b'}
# Diagnostic file used when args has no `diagnostic_set` attribute (or it is empty);
# resolved relative to the current working directory.
_DEFAULT_DIAGNOSTIC_SET = "inability_innoculation_test_v3.jsonl"


def _model_name_from_path(model_type):
    """Return the last component of a model id/path, ignoring trailing slashes."""
    # normpath strips a trailing separator so "models/bert/" -> "bert", not "".
    return path.basename(path.normpath(model_type))


def _prepare_training_split(data_path):
    """Load a train/dev jsonl file, shuffled, with columns labels/text_a/text_b.

    Rows whose gold_label is not a key of _LABEL_MAP are dropped (and the
    number dropped is printed), since they cannot be mapped to a class id.
    """
    frame = pd.read_json(data_path, lines=True)
    frame = frame[list(_COLUMN_MAP)].rename(columns=_COLUMN_MAP)
    known = frame['labels'].isin(list(_LABEL_MAP))
    if not known.all():
        # Some NLI corpora mark no-consensus examples with a non-class gold_label
        # (e.g. '-'); skip them rather than crashing with a bare KeyError.
        print(f"Dropping {int((~known).sum())} rows with unrecognised gold_label from {data_path}")
        frame = frame[known]
    frame = frame.sample(frac=1).reset_index(drop=True)
    frame['labels'] = frame['labels'].map(_LABEL_MAP)
    return frame


def _load_diagnostic(data_path):
    """Load the diagnostic jsonl file, keeping all columns, in file order.

    gold_label/sentence1/sentence2 are renamed to labels/text_a/text_b and the
    labels are mapped to class ids.

    Raises:
        KeyError: if any gold_label is not a key of _LABEL_MAP.
    """
    diagnostic = pd.read_json(data_path, lines=True, orient="records")
    diagnostic = diagnostic.rename(columns=_COLUMN_MAP)
    unknown = set(diagnostic['labels']) - set(_LABEL_MAP)
    if unknown:
        raise KeyError(
            f"Unrecognised gold_label value(s) in {data_path}: {sorted(map(str, unknown))}"
        )
    diagnostic['labels'] = diagnostic['labels'].map(_LABEL_MAP)
    # Positional alignment with the model outputs relies on a 0..n-1 index.
    return diagnostic.reset_index(drop=True)


def _build_model_args(args, run_name):
    """Return the simpletransformers training/eval configuration dict."""
    return {
        'output_dir': f'./checkpoint_results/{run_name}/',
        'overwrite_output_dir': True,
        'fp16': True,  # uses apex
        # for innoculation experiments, set epochs=40
        'num_train_epochs': args.epochs,
        'reprocess_input_data': True,
        "learning_rate": 1e-5,
        "train_batch_size": args.batch_size,
        "eval_batch_size": args.batch_size,
        "max_seq_length": args.seq_len,  # 175
        "weight_decay": 0.01,
        "do_lower_case": False,
        "evaluate_during_training": False,
        "evaluate_during_training_verbose": False,
        "evaluate_during_training_steps": 15000,
        "use_early_stopping": True,
        "early_stopping_patience": 5,
        "early_stopping_consider_epochs": True,
        "save_steps": -1,  # 15000
        "n_gpu": args.num_gpus,
        "logging_steps": 10,
    }


def _score_diagnostic(diagnostic, model_outputs):
    """Add a float 0/1 'Correct' column to *diagnostic* and return its mean.

    The prediction for row i is the argmax of row i of *model_outputs*.
    """
    predictions = np.argmax(np.asarray(model_outputs), axis=1)
    diagnostic['Correct'] = (predictions == diagnostic['labels'].to_numpy()).astype(float)
    return diagnostic['Correct'].mean()


def DiagTrial(args):
    """Optionally fine-tune a transformer, then score it on the diagnostic set.

    Args:
        args: namespace with the attributes defined by this script's argparser
            (model_name, model_type, epochs, batch_size, seq_len, cuda_device,
            num_gpus, finetune, eval, train_set, dev_set, diag_save_dir). An
            optional `diagnostic_set` attribute overrides the default
            diagnostic file.

    Side effects:
        Trains the model unless ``args.finetune == 'no'``. Unless
        ``args.eval == 'no'``, prints the diagnostic accuracy and writes
        ``challenge_nli_inability_<model>.csv`` into ``args.diag_save_dir``
        (the current directory if that is None). The directory is created if
        it does not exist.
    """
    from os import makedirs

    run_name = _model_name_from_path(args.model_type)
    print(f"Starting {run_name}")

    train = _prepare_training_split(args.train_set)
    dev = _prepare_training_split(args.dev_set)
    diagnostic = _load_diagnostic(
        getattr(args, 'diagnostic_set', None) or _DEFAULT_DIAGNOSTIC_SET
    )

    model = ClassificationModel(
        args.model_name,
        args.model_type,
        num_labels=len(_LABEL_MAP),
        use_cuda=True,
        cuda_device=args.cuda_device,
        args=_build_model_args(args, run_name),
    )

    if args.finetune != 'no':
        model.train_model(train, eval_df=dev)

    if args.eval != 'no':
        _, model_outputs, _ = model.eval_model(diagnostic)
        print(_score_diagnostic(diagnostic, model_outputs))
        save_dir = args.diag_save_dir if args.diag_save_dir is not None else '.'
        makedirs(save_dir, exist_ok=True)
        diagnostic.to_csv(path.join(save_dir, f"challenge_nli_inability_{run_name}.csv"))


def build_arg_parser():
    """Return the command-line parser for this script.

    Flags compared against the string 'no' (--finetune, --eval) disable the
    corresponding step; any other value enables it.
    """
    parser = argparse.ArgumentParser(description='Run a Negation diagnostic experiment using Transformers')
    parser.add_argument('--finetune', action='store', type=str, help="Pass 'no' to skip finetuning; any other value, or omitting the flag, finetunes on --train_set.")
    parser.add_argument('--model_name', action='store', type=str, required=True, help='The name of the class of model architectures (BERT, etc), must be supported by simpletransformers.')
    parser.add_argument('--model_type', action='store', type=str, required=True, help='The specific model to load, either huggingface id or path to local dir.')
    parser.add_argument('--epochs', action='store', type=int, default=5, help='Number of finetuning epochs, ignored if not finetuning.')
    parser.add_argument('--batch_size', action='store', type=int, default=16, help='Finetuning batch size')
    parser.add_argument('--seq_len', action='store', type=int, default=175, help='Max seq length for finetuning.')
    parser.add_argument('--cuda_device', action='store', type=int, required=True, help="Device to run experiment on.")
    parser.add_argument('--num_gpus', action='store', type=int, default=1, help='Number of gpus to use.')
    parser.add_argument('--eval', action='store', type=str, required=True, help="Pass 'no' to skip evaluating on the diagnostic; any other value runs the evaluation.")
    parser.add_argument('--train_set', action='store', type=str, required=True, help="Path to training dataset")
    parser.add_argument('--dev_set', action='store', type=str, required=True, help="Path to dev set.")
    # Default to the current directory; without a default the results path
    # would be built from None ("None/challenge_...csv").
    parser.add_argument('--diag_save_dir', action='store', type=str, default='.', help="The directory where diagnostic results should be saved (default: current directory).")
    return parser


if __name__ == "__main__":
    DiagTrial(build_arg_parser().parse_args())