import os
import pickle
import numpy as np
import torch

from .dataset import BasicDataset, collate_fn, seq2seq_collate_fn, Seq2SeqDataset
from .utils import read_cli, get_config, override_config, train_tag, val_tag, entity_tag, test_tag
from .trainers import ClosureTrainer, Seq2SeqTrainer

from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification
from transformers import TrainingArguments
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Set at import time, before any trainer is constructed below.
# NOTE(review): presumably this keeps the Hugging Face Trainer from reporting
# to Weights & Biases during evaluation — confirm that the installed
# transformers version still honours this variable.
os.environ["WANDB_DISABLED"] = "true"


def run_closure(args):
    """Score a saved DeBERTa-v2 classifier checkpoint on the validation split.

    The checkpoint is read from ``<args['destpath']>/checkpoint-<args['chkpt']>``
    and evaluated through ``ClosureTrainer``; the resulting metrics are printed.

    Args:
        args: Nested configuration mapping. Keys read here are
            ``train.seed``, ``train.per_device_train_batch_size``,
            ``dev.per_device_eval_batch_size``, ``model.wildcard``,
            ``datapath``, ``destpath``, ``chkpt`` and ``result_path``.
            The whole mapping is also handed to ``BasicDataset`` as ``cfg``.
    """
    # Seed NumPy and PyTorch with the same value before anything is loaded.
    rng_seed = args['train']['seed']
    np.random.seed(rng_seed)
    torch.manual_seed(rng_seed)

    tok = DebertaV2Tokenizer.from_pretrained(args['model']['wildcard'])

    # Validation split; swap val_tag for test_tag to score the test split.
    eval_set = BasicDataset(
        os.path.join(args['datapath'], val_tag),
        tok,
        mode='infer',
        cfg=args,
    )

    hf_kwargs = {
        'output_dir': args['destpath'],
        'overwrite_output_dir': True,
        'per_device_train_batch_size': args['train']['per_device_train_batch_size'],
        'per_device_eval_batch_size': args['dev']['per_device_eval_batch_size'],
    }
    hf_args = TrainingArguments(**hf_kwargs)

    checkpoint_dir = os.path.join(
        args['destpath'], "checkpoint-{}".format(args['chkpt'])
    )
    print(f"Loading model from {checkpoint_dir}")
    classifier = DebertaV2ForSequenceClassification.from_pretrained(checkpoint_dir)

    evaluator = ClosureTrainer(
        tokenizer=tok,
        model=classifier,
        args=hf_args,
        eval_dataset=eval_set,
        data_collator=collate_fn,
    )

    # save_results / result_path are forwarded to the project trainer's
    # evaluate(); what gets written there is defined by ClosureTrainer.
    scores = evaluator.evaluate(
        eval_set,
        save_results=True,
        result_path=args['result_path'],
    )
    print(scores)


def run_etypes(args):
    """Score a saved T5 conditional-generation checkpoint on the validation split.

    The checkpoint is read from ``<args['destpath']>/checkpoint-<args['chkpt']>``
    and evaluated through ``Seq2SeqTrainer``; the resulting metrics are printed.

    Args:
        args: Nested configuration mapping. Keys read here are
            ``train.seed``, ``train.per_device_train_batch_size``,
            ``dev.per_device_eval_batch_size``, ``model.wildcard``,
            ``datapath``, ``destpath``, ``chkpt`` and ``result_path``.
            The whole mapping is also handed to ``Seq2SeqDataset`` and
            ``Seq2SeqTrainer`` as ``cfg``.
    """
    # Seed NumPy and PyTorch with the same value before anything is loaded.
    rng_seed = args['train']['seed']
    np.random.seed(rng_seed)
    torch.manual_seed(rng_seed)

    checkpoint_dir = os.path.join(
        args['destpath'], "checkpoint-{}".format(args['chkpt'])
    )
    print(f"Loading model from {checkpoint_dir}")
    generator = T5ForConditionalGeneration.from_pretrained(checkpoint_dir)
    # The tokenizer comes from the base model name, not from the checkpoint dir.
    tok = T5Tokenizer.from_pretrained(args['model']['wildcard'])

    data_root = args['datapath']
    print(data_root)
    # Validation split; swap val_tag for test_tag to score the test split.
    eval_set = Seq2SeqDataset(
        os.path.join(data_root, val_tag),
        tok,
        mode='infer',
        cfg=args,
    )

    hf_kwargs = {
        'output_dir': args['destpath'],
        'overwrite_output_dir': True,
        'per_device_train_batch_size': args['train']['per_device_train_batch_size'],
        'per_device_eval_batch_size': args['dev']['per_device_eval_batch_size'],
    }
    hf_args = TrainingArguments(**hf_kwargs)

    evaluator = Seq2SeqTrainer(
        cfg=args,
        tokenizer=tok,
        model=generator,
        args=hf_args,
        eval_dataset=eval_set,
        data_collator=seq2seq_collate_fn,
    )

    # save_results / result_path are forwarded to the project trainer's
    # evaluate(); what gets written there is defined by Seq2SeqTrainer.
    scores = evaluator.evaluate(
        eval_set,
        save_results=True,
        result_path=args['result_path'],
    )
    print(scores)


if __name__ == "__main__":
    # Entry point: merge command-line options into the file-based config and
    # run the evaluation routine selected by the 'hint' setting.
    cargs = read_cli()

    model_path = cargs['model_path']
    cfg_path = cargs['config']
    args = get_config(cfg_path)
    # Checkpoints are loaded from (and trainer output goes to) the model path.
    args['destpath'] = model_path

    # args.update(cargs) below copies every CLI key into the config, so an
    # option that was not supplied (None) is removed first; otherwise it would
    # overwrite the value coming from the config file with None.
    if cargs['datapath'] is not None:
        args['datapath'] = cargs['datapath']
    else:
        del cargs['datapath']

    # The CLI batch size only overrides the evaluation batch size.
    if cargs['batch_size'] is not None:
        args['dev']['per_device_eval_batch_size'] = cargs['batch_size']
    else:
        del cargs['batch_size']

    args.update(cargs)

    runners = {
        'closure': run_closure,
        'entity_types': run_etypes,
    }
    hint = args['hint']
    if hint not in runners:
        # Previously an unrecognised hint fell through and the script exited
        # successfully without evaluating anything.
        raise ValueError(
            f"Unknown hint {hint!r}; expected one of {sorted(runners)}"
        )
    runners[hint](args)
