# Erik McGuire, 2021
import pandas as pd
import json
import os

from zuco_params import args
from zuco_utils import debug
from zuco_dataset import ZuCoDataset

from transformers import set_seed

# Turn off the HuggingFace `tokenizers` library's internal parallelism for this
# process. (Motivation isn't shown here; it is commonly set to silence
# fork-related warnings when DataLoader workers are used — TODO confirm.)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Announce up front when the ZuCo supervision targets are randomized.
randomized = bool(args.random_scores)
if randomized:
    print('Using randomized ZuCo supervision.')

_BEST_RUN_PATH = "data/best_run.json"


def _save_best_run(best, path: str = _BEST_RUN_PATH) -> None:
    """Persist the result of ``trainer.hyperparameter_search`` to *path* as JSON.

    Only ``run_id``, ``objective`` and ``hyperparameters`` are stored, as a JSON
    object, so that ``_apply_best_run`` can look them up by key. (Dumping the
    returned NamedTuple directly would produce a JSON array.) Saving is
    best-effort: failures are reported and do not abort the run.
    """
    record = {
        "run_id": best.run_id,
        "objective": best.objective,
        "hyperparameters": best.hyperparameters,
    }
    try:
        # Serialize first so a non-JSON-able value cannot leave a truncated file.
        payload = json.dumps(record)
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(path, "w") as br:
            br.write(payload)
    except (OSError, TypeError, ValueError) as err:
        print(f"\nCould not save best run to {path}: {err}\n")


def _apply_best_run(trainer, path: str = _BEST_RUN_PATH) -> None:
    """Copy hyperparameters saved by a previous grid search onto ``trainer.args``.

    Prints "No best run found." and leaves ``trainer.args`` untouched when the
    file is missing or unreadable, is not valid JSON, or has no
    ``hyperparameters`` mapping.
    """
    try:
        with open(path) as br:
            best_run = json.load(br)
        hyperparameters = best_run["hyperparameters"]
    except (OSError, ValueError, KeyError, TypeError):
        # OSError: missing/unreadable file; ValueError: invalid JSON;
        # KeyError/TypeError: record is not an object with 'hyperparameters'.
        print("\nNo best run found.\n")
        return
    if not isinstance(hyperparameters, dict):
        print("\nNo best run found.\n")
        return
    for name, value in hyperparameters.items():
        try:
            setattr(trainer.args, name, value)
        except AttributeError as err:
            # e.g. the arguments object rejects assignment; keep the default.
            print(f"\nCould not apply best-run hyperparameter {name!r}: {err}\n")


def main(f_ix: int = 0):
    """Build a ZuCoTrainer from the parsed ``args`` and run the requested stages.

    The stages are gated by flags and run in this order:

    1. Hyperparameter search (``args.train_type == "grid"``) or training.
       Both require ``do_train``.
    2. Evaluation (``do_eval``).
    3. Prediction on ``gist.test_dataset`` (``do_predict``), logged to wandb.
    4. Tokenizer saving (``args.save``), only outside cross-validation.
    5. Attention collection (``args.save_att``).

    Args:
        f_ix: fold number used to suffix the run name when ``args.cv != 0``
            (the cross-validation driver passes 1-based fold numbers);
            unused otherwise.
    """
    if args.att_only:
        print("Updating only attention weights.")

    # Relation-task runs without cross-validation get no eval dataset (see
    # eval_dataset below), so evaluation is switched off for them.
    skip_eval = args.task == "rel" and args.cv == 0

    training_args = ZucoTrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        weight_decay=0.01,
        evaluation_strategy="no" if skip_eval else args.eval_strat,
        save_total_limit=args.save_total_limit,
        eval_steps=args.eval_steps,
        do_predict=args.do_predict,
        do_train=args.do_train,
        do_eval=False if skip_eval else args.do_eval,
        et_lmbda=args.et_lmbda,
        eeg_lmbda=args.eeg_lmbda,
        run_name=args.run_name if args.cv == 0 else f'{args.run_name}_fold_{f_ix}',
        pred_lmbda=args.pred_lmbda,
        seed=args.seed,
        report_to=args.report_to,
        # Fold runs use a fixed save interval instead of args.save_steps.
        save_steps=args.save_steps if args.cv == 0 else 5000,
        overwrite_output_dir=args.overwrite_output_dir,
    )
    set_seed(training_args.seed)

    if args.task != "rel":
        eval_dataset = eval_set
    elif args.cv > 0:
        eval_dataset = gist.zdevset
    else:
        eval_dataset = None

    trainer = ZuCoTrainer(
        model=get_model("model"),
        args=training_args,
        train_dataset=gist.ztrainset if args.zuco_only and not args.filtered else gist.train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        callbacks=[LossCallback],
        model_init=get_model("model_init")
    )

    if training_args.do_train:
        if args.train_type == "grid":
            best = trainer.hyperparameter_search(hp_space=hp_space,
                                                 n_trials=args.n_trials,
                                                 study_name=args.study_name,
                                                 direction='maximize',
                                                 compute_objective=zuco_objective)
            _save_best_run(best)
            print(f"Best run: #{best.run_id}\nObjective: {best.objective}\nHyperparameters: {best.hyperparameters}")
        else:
            if args.load_best:
                _apply_best_run(trainer)
            trainer.train()

    if training_args.do_eval:
        eval_ = trainer.evaluate()
        print(f"\nEvaluate:\n{eval_.items()}")

    if training_args.do_predict:
        if not training_args.do_train:
            # Predict-only runs start their own wandb run, named after the checkpoint.
            wandb.init(
                project=os.getenv("WANDB_PROJECT", "huggingface"),
                config=gist.config,
                name=f'{args.run_name}_{args.chkpt}',
                reinit=True,
            )
        metrics = trainer.predict(gist.test_dataset).metrics
        print(f"\nPredict:\n{metrics.items()}")
        # NOTE(review): assumes args.chkpt is int-convertible even when
        # predicting right after training — TODO confirm.
        metrics["eval_checkpoint"] = int(args.chkpt)
        wandb.log(metrics)

    if not args.cv > 0 and args.save:
        if trainer.is_world_process_zero():
            tokenizer.save_pretrained(training_args.output_dir)

    if args.save_att:
        print("Collecting attentions for eval, test sets.")
        evaluate(trainer, training_args)

if __name__ == "__main__":
    from zuco_train_utils import *
    # Import numpy explicitly so the fold averaging below does not depend on
    # the star import happening to re-export `np`.
    import numpy as np

    if args.cv > 0:
        for ix in range(args.cv):
            print(f"Fold {ix+1}")
            gist.current_fold = ix
            gist.ztrainset = gist.ztrainsets[ix]
            gist.zdevset = gist.zdevsets[ix]
            main(ix + 1)
            wandb.finish()
    else:
        main()

    if args.cv > 0:  # For each step, save each metric averaged over folds
        n_folds = args.cv
        # Average only over evaluation steps that every fold reached; indexing
        # a fold with fewer logged evaluations would raise IndexError.
        n_steps = min(len(gist.cv_scores[f]) for f in range(n_folds))
        avg_cv_scores = {
            step: {
                met: np.mean([gist.cv_scores[f][step][met] for f in range(n_folds)])
                for met in gist.cv_scores[0][step]
            }
            for step in range(n_steps)
        }

        if args.eval_strat == "steps":
            iteration = f"{args.eval_steps} steps"
        elif args.eval_strat == "epoch":
            iteration = f"{args.epochs} epochs"
        else:
            # Any other strategy: use a generic label so the summary below
            # never refers to an undefined name.
            iteration = f"evaluation ({args.eval_strat})"

        avg_cv_scores_df = pd.DataFrame.from_dict(avg_cv_scores, orient='index')
        # Join paths properly: output_dir may or may not end with a separator.
        if args.output_dir:
            os.makedirs(args.output_dir, exist_ok=True)
        avg_cv_scores_df.to_csv(
            os.path.join(args.output_dir, f'avg_cv_scores_{args.run_name}.tsv'),
            sep="\t",
        )

        print(f"Dev scores per {iteration}, averaged over {args.cv} folds:")
        for i, metric_scores in avg_cv_scores.items():
            if not args.eval_strat == "epoch":
                print(f'Iteration {i+1}:')
            for metric, score in metric_scores.items():
                print(f'{metric}: {score:.3f}')
            print("\n")

        # Build the final model on the full training data, without CV.
        args.cv = 0
        gist.ztrainset = gist.ztrainmain
        gist.zdevset = None
        main()
