import os
import time
import datetime

import torch
import torch.nn as nn
import numpy as np

from tqdm import tqdm
# from termcolor import colored

from dataset.parallel_sampler import ParallelSampler
from train.utils import named_grad_param, grad_param, get_norm
from torch.cuda.amp import autocast as autocast

def _timestamp():
    '''Return the current local time formatted for training log lines.'''
    return datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')


def train(train_data, val_data, model, args):
    '''
        Train the model and use val_data for early stopping.

        Each epoch draws args.train_episodes tasks from train_data and takes
        one optimizer step per task (via train_one). After every epoch the
        mean accuracy over args.val_episodes validation tasks is computed.
        Whenever it beats the best so far, the model's state_dict is saved to
        ./tmp-runs/<timestamp>/<epoch>.bestmodel. The validation accuracy
        also drives a ReduceLROnPlateau scheduler. Training stops once
        validation accuracy has not improved for args.patience consecutive
        epochs. The best checkpoint, if any, is then restored into model.
        If args.save is set, the restored weights and all args attributes
        are also written under ./saved-runs/<timestamp>/.

        Args:
            train_data: data handed to ParallelSampler for training tasks.
            val_data: data handed to ParallelSampler for validation tasks.
            model: module whose forward(support, query) returns (acc, loss).
            args: namespace; this function reads lr, patience, train_epochs,
                train_episodes, val_episodes, notqdm and save.
    '''
    # creating a tmp directory to save the models
    out_dir = os.path.abspath(os.path.join(
        os.path.curdir, "tmp-runs", str(int(time.time() * 1e7))))
    os.makedirs(out_dir, exist_ok=True)

    best_acc = 0
    sub_cycle = 0
    best_path = None

    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Cut the LR by 10x once val acc has plateaued for half of the
    # early-stopping patience. Stepped with the val acc after every epoch.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, 'max', patience=args.patience // 2, factor=0.1, verbose=True)

    print("{}, Start training".format(_timestamp()), flush=True)

    train_gen = ParallelSampler(train_data, args, args.train_episodes)
    train_gen_val = ParallelSampler(train_data, args, args.val_episodes)
    val_gen = ParallelSampler(val_data, args, args.val_episodes)

    for ep in range(args.train_epochs):
        # NOTE(review): autocast also covers loss.backward()/opt.step() inside
        # train_one, and no GradScaler is used. PyTorch's AMP guide recommends
        # autocasting only the forward pass and scaling the loss for fp16.
        with autocast():
            sampled_tasks = train_gen.get_epoch()
            if not args.notqdm:
                sampled_tasks = tqdm(sampled_tasks, total=train_gen.num_episodes,
                                     ncols=80, leave=False, desc='Training on train')

            for task in sampled_tasks:
                # A None task marks the end of the sampled epoch.
                if task is None:
                    break
                train_one(task, model, opt, args)

            # Every 10 epochs, report accuracy on tasks sampled from train_data.
            if ep % 10 == 0:
                acc, std = test(train_data, model, args, args.val_episodes, False,
                                train_gen_val.get_epoch())
                print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f} ".format(
                    _timestamp(), "ep", ep, "train", "acc:", acc, std,
                ), flush=True)

            # Evaluate validation accuracy
            cur_acc, cur_std = test(val_data, model, args, args.val_episodes, False,
                                    val_gen.get_epoch())
            print("{}, {:s} {:2d}, {:s} {:s}{:>7.4f} ± {:>6.4f}, ".format(
                _timestamp(), "ep", ep, "val  ", "acc:", cur_acc, cur_std,
            ), flush=True)

        # Without this call the plateau scheduler never lowers the LR.
        scheduler.step(cur_acc)

        # Update the current best model if val acc is better
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_path = os.path.join(out_dir, str(ep))

            # save current model
            print("{}, Save cur best model to {}".format(_timestamp(), best_path))
            torch.save(model.state_dict(), best_path + '.bestmodel')

            sub_cycle = 0
        else:
            sub_cycle += 1

        # Break if the val acc hasn't improved in the past patience epochs
        if sub_cycle == args.patience:
            break

    print("{}, End of training. Restore the best weights".format(_timestamp()),
          flush=True)

    if best_path is None:
        # No epoch beat the initial best_acc of 0 (or no epoch ran at all),
        # so no checkpoint exists. Keep the current weights rather than
        # crashing on None + '.bestmodel'.
        print("{}, No best model was saved; keeping current weights".format(
            _timestamp()), flush=True)
    else:
        # restore the best saved model
        model.load_state_dict(torch.load(best_path + '.bestmodel'))

    if args.save:
        # save the current model
        out_dir = os.path.abspath(os.path.join(
            os.path.curdir, "saved-runs", str(int(time.time() * 1e7))))
        os.makedirs(out_dir, exist_ok=True)

        best_path = os.path.join(out_dir, 'best')

        print("{}, Save best model to {}".format(_timestamp(), best_path),
              flush=True)

        torch.save(model.state_dict(), best_path + '.bestmodel')

        # Record every args attribute, sorted by name, one "key=value" per line.
        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))


def train_one(task, model, opt, args):
    '''
        Perform one optimisation step on a single (support, query) task.

        The model is switched to training mode and stale gradients are
        cleared. The loss (second element returned by
        model(support, query)) is back-propagated, and the optimizer then
        updates the parameters. args is not used by this function.
    '''
    model.train()
    opt.zero_grad()

    support_set, query_set = task
    _acc, task_loss = model(support_set, query_set)

    task_loss.backward()
    opt.step()


def test(test_data, model, args, num_episodes, verbose=True, sampled_tasks=None):
    '''
        Evaluate the model on a bag of sampled tasks and return the mean
        accuracy and its standard deviation over tasks.

        Args:
            test_data: data used to build a ParallelSampler when
                sampled_tasks is not supplied.
            model: model to evaluate; it is switched to eval mode here.
            args: namespace; reads notqdm and is forwarded to the sampler
                and to test_one.
            num_episodes: number of tasks to sample; also the tqdm total.
            verbose: if True, print a timestamped mean/std summary line.
            sampled_tasks: optional pre-built iterable of (support, query)
                tasks, e.g. a ParallelSampler(...).get_epoch() generator.
                A None task ends iteration early, as in train().

        Returns:
            (mean, std) of the per-task accuracies returned by test_one.
    '''
    model.eval()
    # NOTE(review): autograd stays enabled during evaluation. If the model's
    # forward never needs gradients at test time, torch.no_grad() would save
    # memory. Confirm before adding it.

    if sampled_tasks is None:
        sampled_tasks = ParallelSampler(test_data, args,
                                        num_episodes).get_epoch()

    if not args.notqdm:
        sampled_tasks = tqdm(sampled_tasks, total=num_episodes, ncols=80,
                             leave=False,
                             desc= 'Testing on val' )

    acc = []
    for task in sampled_tasks:
        # Same end-of-epoch sentinel handling as the training loop.
        if task is None:
            break
        acc.append(test_one(task, model, args))

    acc = np.array(acc)
    acc_mean, acc_std = np.mean(acc), np.std(acc)

    if verbose:
        print("{}, {:s} {:>7.4f}, {:s} {:>7.4f}".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            "acc mean",
            acc_mean,
            "std",
            acc_std,
        ), flush=True)

    return acc_mean, acc_std


def test_one(task, model, args):
    '''
        Evaluate the model on one sampled task. Return the accuracy.
    '''
    support, query = task

    acc, loss = model(support, query)

    return acc
