import os
import time
import torch
import argparse
from glob import glob
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils import load_files, save_pickle, fix_seed, print_model
from model import BAN, HAN, GGNN, HypergraphTransformer, MemNet, HypergraphTransformer_wohe, HypergraphTransformer_qsetkhe, HypergraphTransformer_qhekset
from modules.logger import setup_logger, get_rank
from dataloader import KVQA, PQnPQL, load_PQnPQL_data, FVQA, load_FVQA_data

import math
from torch.optim.lr_scheduler import _LRScheduler

class CosineAnnealingWarmUpRestarts(_LRScheduler):
    """Cosine-annealing LR schedule with linear warm-up and warm restarts.

    Within a cycle of length ``T_i`` the learning rate first rises linearly
    from the optimizer's base LR to ``eta_max`` over ``T_up`` steps, then
    follows a half cosine back down to the base LR. When a cycle ends, the
    cycle counter increments, the peak becomes
    ``base_eta_max * gamma ** cycle`` and the post-warm-up part of the next
    cycle is ``T_mult`` times longer.

    Args:
        optimizer: wrapped optimizer; each param group's initial ``lr`` is the
            floor (``base_lr``) of every cycle.
        T_0: length of the first cycle in scheduler steps (positive int).
        T_mult: growth factor of the cosine part of later cycles (int >= 1).
        eta_max: peak learning rate of the first cycle.
        T_up: number of linear warm-up steps per cycle (int >= 0).
        gamma: multiplicative decay of the peak per completed cycle.
        last_epoch: index of the last epoch; ``-1`` starts fresh.

    Raises:
        ValueError: if ``T_0``, ``T_mult`` or ``T_up`` are out of range or not
            integers.
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        if T_up < 0 or not isinstance(T_up, int):
            raise ValueError("Expected non-negative integer T_up, but got {}".format(T_up))
        self.T_0 = T_0
        self.T_mult = T_mult
        self.base_eta_max = eta_max
        self.eta_max = eta_max
        self.T_up = T_up
        self.T_i = T_0          # length of the current cycle
        self.gamma = gamma
        self.cycle = 0          # number of completed cycles
        self.T_cur = last_epoch  # position inside the current cycle
        # The base class performs an initial step(), which moves T_cur to 0.
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for the current ``T_cur``."""
        if self.T_cur == -1:
            return self.base_lrs
        elif self.T_cur < self.T_up:
            # Linear warm-up from base_lr towards eta_max.
            return [(self.eta_max - base_lr) * self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
        else:
            # Half-cosine decay from eta_max back to base_lr.
            return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur - self.T_up) / (self.T_i - self.T_up))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule and write the new LRs into the optimizer.

        Args:
            epoch: if ``None``, advance by one step. Otherwise jump directly
                to this (absolute) epoch.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.cycle += 1
                self.T_cur = self.T_cur - self.T_i
                self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
        elif epoch >= self.T_0:
            if self.T_mult == 1:
                self.T_cur = epoch % self.T_0
                self.cycle = epoch // self.T_0
            else:
                # NOTE(review): this closed form assumes T_i = T_0 * T_mult**n,
                # whereas the incremental path above keeps T_up fixed while
                # growing the rest; the two only agree when T_up == 0.
                n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                self.cycle = n
                self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                self.T_i = self.T_0 * self.T_mult ** (n)
        else:
            # Jumping back into the first cycle: reset the cycle counter too,
            # otherwise eta_max would keep the decay of a later cycle.
            self.T_i = self.T_0
            self.T_cur = epoch
            self.cycle = 0

        self.eta_max = self.base_eta_max * (self.gamma ** self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
        # Mirror _LRScheduler.step() so get_last_lr() works for this scheduler.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

def eval_epoch(model, loader, args):
    """Run ``model`` once over ``loader`` and collect accuracy counts and outputs.

    Every batch is moved to CUDA and passed to ``model`` as a list of tensors.
    ``batch[-1]`` holds the labels, ``batch[0]`` the questions and ``batch[1]``
    the KG input. For PQ/PQL data, ``batch[2]`` holds the paths and
    ``batch[-2]`` the answer set. Gradients are disabled for the whole pass.

    Args:
        model: network returning per-class scores of shape (batch, n_classes).
        loader: iterable of batches. For ``args.data_name == 'kvqa'`` without
            ``--per_cate``, ``loader.dataset`` must expose ``oshot_ans_idxs``
            and ``zshot_ans_idxs``.
        args: parsed CLI arguments (``data_name`` and ``per_cate`` are read).

    Returns:
        Tuple ``(total_right, total_right_aset, total_num, out_dict)``:

        - ``total_right``: samples whose argmax prediction equals a non-zero
          label. Samples with label 0 never count as right but are included
          in ``total_num``.
        - ``total_right_aset``: top-3 hits on non-zero labels (FVQA), argmax
          predictions contained in the answer set (PQ/PQL), otherwise 0.
        - ``total_num``: number of evaluated samples.
        - ``out_dict``: per-batch numpy arrays under ``'ques'``, ``'kg'``,
          ``'pred'`` and ``'tans'``, plus ``'path'`` and ``'ansset'`` for PQ/PQL.
    """
    model.eval()
    is_pq = 'pq' in args.data_name
    is_fvqa = 'fvqa' in args.data_name
    per_shot = args.data_name == 'kvqa' and not args.per_cate

    total_right = 0
    total_right_aset = 0
    total_num = 0
    # [right, total] counters for the KVQA one-/zero-/multi-shot breakdown.
    shot_stats = {'one': [0, 0], 'zero': [0, 0], 'multi': [0, 0]}

    ques_list, kg_list, pred_ans_list, true_ans_list = [], [], [], []
    path_list, aset_list = [], []

    with torch.no_grad():
        for batch in tqdm(loader):
            batch = [b.cuda() for b in batch]
            labels = batch[-1]

            scores = model(batch)
            _, pred_ans = scores.max(1)

            ques_list.append(batch[0].detach().cpu().numpy())
            kg_list.append(batch[1].detach().cpu().numpy())
            pred_ans_list.append(pred_ans.detach().cpu().numpy())
            true_ans_list.append(labels.detach().cpu().numpy())

            # Label 0 is never counted as correct (same filter as the training
            # loop); presumably a reserved/unknown answer id - TODO confirm.
            valid = labels != 0
            correct = (pred_ans == labels) & valid
            total_right += correct.sum().item()
            total_num += len(labels)

            if is_fvqa:
                # Top-3 accuracy, excluding label 0 to match the training loop.
                _, top3_indices = torch.topk(scores, 3)
                hit = (top3_indices == labels.unsqueeze(1)).any(dim=1) & valid
                total_right_aset += hit.sum().item()

            if is_pq:
                path_list.append(batch[2].detach().cpu().numpy())
                aset = batch[-2]
                aset_list.append(aset.detach().cpu().numpy())
                # A prediction is right if it appears anywhere in its answer set.
                aset_rows = aset.flatten(1) if aset.dim() > 1 else aset.unsqueeze(1)
                in_aset = (aset_rows == pred_ans.unsqueeze(1)).any(dim=1)
                total_right_aset += in_aset.sum().item()

            if per_shot:
                # Per-sample correctness is taken from `correct`, which is
                # aligned with `labels` (unlike the nonzero-filtered view).
                for gt, ok in zip(labels.tolist(), correct.tolist()):
                    if gt in loader.dataset.oshot_ans_idxs:
                        key = 'one'
                    elif gt in loader.dataset.zshot_ans_idxs:
                        key = 'zero'
                    else:
                        key = 'multi'
                    shot_stats[key][0] += int(ok)
                    shot_stats[key][1] += 1

    if per_shot:
        print('## Analysis on few-shot and zero-shot answers')
        for title, key in (('For one-shot questions : ', 'one'),
                           ('For zero-shot questions : ', 'zero'),
                           ('For multi-shot questions : ', 'multi')):
            right, num = shot_stats[key]
            # Empty buckets report NaN instead of raising ZeroDivisionError.
            print(title, right / num if num else float('nan'), right, num)

    out_dict = {
        'ques': ques_list,
        'kg': kg_list,
        'pred': pred_ans_list,
        'tans': true_ans_list,
    }
    if is_pq:
        out_dict['path'] = path_list
        out_dict['ansset'] = aset_list

    return total_right, total_right_aset, total_num, out_dict
        
def inference(model, test_loader, ckpt_path, args, task_idx=-1, res=False):
    """Load ``ckpt_best.pth.tar`` from ``ckpt_path`` into ``model`` and evaluate it.

    Args:
        model: network to evaluate; its weights are overwritten in place.
        test_loader: loader passed to ``eval_epoch``.
        ckpt_path: directory containing ``ckpt_best.pth.tar``.
        args: parsed CLI arguments (``data_name`` selects the metric).
        task_idx: when ``res`` is truthy, selects the output pickle name
            (``-1`` -> ``ckpt_best_out_selected.pkl``).
        res: if truthy, pickle the per-batch outputs of ``eval_epoch``.

    Returns:
        Answer-set accuracy for PQ/PQL data, otherwise top-1 accuracy
        (``total_right / total_num``).
    """
    best_ckpt = os.path.join(ckpt_path, 'ckpt_best.pth.tar')
    checkpoint = torch.load(best_ckpt)

    # Weights saved from a DataParallel/DDP-wrapped model carry a 'module.'
    # prefix; strip it per key so the bare model can load them.
    prefix = 'module.'
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict)
    print("load: %s" % (best_ckpt))

    # Pure evaluation: avoid building autograd graphs.
    with torch.no_grad():
        total_right, total_right_aset, total_num, out_dict = eval_epoch(model, test_loader, args)

    if 'pq' in args.data_name:
        accuracy = total_right_aset / total_num
    else:
        accuracy = total_right / total_num

    if res:
        if task_idx == -1:
            out_file = os.path.join(ckpt_path, 'ckpt_best_out_selected.pkl')
        else:
            out_file = os.path.join(ckpt_path, 'ckpt_best_out_task%s.pkl' % (task_idx))
        save_pickle(out_dict, out_file)

    return accuracy
            
def _parse_args():
    """Parse the command-line options of the training script."""
    parser = argparse.ArgumentParser(description="experiments")
    parser.add_argument("--model_name", default="ht")
    parser.add_argument("--data_name", default="kvqa")
    parser.add_argument("--cfg", default="ht")
    parser.add_argument("--exp_name", default="dev")
    # NOTE(review): --inference is not read by main(); args is forwarded to
    # datasets/models, which may consult it - confirm.
    parser.add_argument("--inference", action='store_true')
    parser.add_argument("--per_cate", action='store_true')
    parser.add_argument("--debug", action='store_true')
    parser.add_argument("--schedule", action='store_true')
    parser.add_argument("--selected", action='store_true')
    parser.add_argument("--abl_only_ga", action='store_true')
    parser.add_argument("--abl_only_sa", action='store_true')
    parser.add_argument("--abl_ans_fc", action='store_true')
    # NOTE(review): --batch_size is parsed, but the DataLoaders use
    # MODEL.BATCH_SIZE from the YAML config instead.
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--split_seed", type=int, default=123)
    parser.add_argument("--wd", type=float, default=0.0)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--max_epoch", type=int, default=1000)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--lr", type=float, default=0.0)
    parser.add_argument("--q_opt", type=str, default='org')
    parser.add_argument("--n_hop", type=int, default=1)
    parser.add_argument("--loaded_epoch", type=int, default=0)
    parser.add_argument("--local_rank", default=0, type=int, help="this argument is not used and should be ignored")
    return parser.parse_args()


def _build_datasets(model_cfg, args, ckpt_path):
    """Build the dataset splits for ``args.data_name``.

    Returns:
        ``(datasets, max_n_node)``: ``datasets`` maps split names to datasets.
        ``max_n_node`` is the largest ``n_node`` over the KVQA splits, or
        ``None`` for the other datasets.

    Raises:
        ValueError: if ``args.data_name`` matches no known dataset.
    """
    datasets = {}
    max_n_node = None
    if args.data_name == 'kvqa':
        modes = ['train', 'val', 'test']
        for mode in modes:
            # KVQA splits are cached as pickles inside the checkpoint dir.
            fname = ckpt_path + '/%s_cache.pkl' % (mode)
            if os.path.isfile(fname):
                datasets[mode] = load_files(fname)
            else:
                datasets[mode] = KVQA(model_cfg, args, mode)
                save_pickle(datasets[mode], fname)
        max_n_node = max(max(datasets[mode].n_node) for mode in modes)
        for mode in modes:
            datasets[mode].max_n_node = max_n_node
    elif 'fvqa' in args.data_name:
        train, test = load_FVQA_data(model_cfg, args)
        datasets['train'] = FVQA(model_cfg, args, train)
        datasets['test'] = FVQA(model_cfg, args, test)
    elif 'pq' in args.data_name:
        train, val, test = load_PQnPQL_data(model_cfg, args)
        datasets['train'] = PQnPQL(model_cfg, args, train)
        datasets['val'] = PQnPQL(model_cfg, args, val)
        datasets['test'] = PQnPQL(model_cfg, args, test)
    else:
        raise ValueError("Unknown data_name: %s" % args.data_name)
    return datasets, max_n_node


def _build_loaders(datasets, model_cfg, args):
    """Return ``(train_loader, val_loader, test_loader)`` for ``datasets``."""
    batch_size = model_cfg['MODEL']['BATCH_SIZE']
    train_loader = DataLoader(datasets['train'], batch_size=batch_size, num_workers=args.num_workers, shuffle=True)
    # FVQA has no validation split here, so model selection uses the test split.
    val_split = 'test' if 'fvqa' in args.data_name else 'val'
    # Evaluation order does not affect accuracy; keep it deterministic.
    val_loader = DataLoader(datasets[val_split], batch_size=batch_size, num_workers=args.num_workers, shuffle=False)
    test_loader = DataLoader(datasets['test'], batch_size=batch_size, num_workers=args.num_workers, shuffle=False)
    return train_loader, val_loader, test_loader


def _build_model(model_cfg, args, max_n_node):
    """Instantiate the model named by ``args.model_name`` on CUDA.

    Raises:
        ValueError: for an unknown model name, or ``ggnn`` without
            ``max_n_node`` (only computed for KVQA).
    """
    # Models whose constructor takes only (model_cfg, args).
    constructors = {
        'ht': HypergraphTransformer,
        'ht_abl_wohe': HypergraphTransformer_wohe,
        'ht_abl_qset_khe': HypergraphTransformer_qsetkhe,
        'ht_abl_qhe_kset': HypergraphTransformer_qhekset,
        'han': HAN,
        'ban': BAN,
        'memnet': MemNet,
    }
    name = args.model_name
    if name in constructors:
        model = constructors[name](model_cfg, args)
    elif name == 'ggnn':
        if max_n_node is None:
            raise ValueError("model 'ggnn' requires max_n_node, which is only computed for data_name 'kvqa'")
        model = GGNN(model_cfg, args, max_n_node)
    elif name == 'gcn':
        # GCN is absent from the module-level imports (the bare reference
        # raised NameError). NOTE(review): assumes model.py defines GCN.
        from model import GCN
        model = GCN(model_cfg, args)
    else:
        raise ValueError("Unknown model_name: %s" % name)
    return model.cuda()


def _train_epoch(model, loader, optimizer, args, e_idx, summary):
    """Train ``model`` for one epoch.

    Args:
        summary: TensorBoard writer for the per-batch loss, or ``None``.

    Returns:
        ``(total_right, total_right_aset, total_num)``. The counts use the
        same definitions as ``eval_epoch``: label 0 is never counted right,
        and the answer-set count is top-3 for FVQA or answer-set membership
        for PQ/PQL.
    """
    model.train()
    total_right = 0
    total_right_aset = 0
    total_num = 0
    for b_idx, batch in enumerate(tqdm(loader)):
        batch = [b.cuda() for b in batch]
        labels = batch[-1]
        scores = model(batch)
        _, pred_ans = scores.max(1)
        loss = F.nll_loss(scores, labels)

        valid = labels != 0
        total_right += ((pred_ans == labels) & valid).sum().item()
        total_num += len(labels)

        if 'fvqa' in args.data_name:
            _, top3_indices = torch.topk(scores, 3)  # top-3 accuracy
            hit = (top3_indices == labels.unsqueeze(1)).any(dim=1) & valid
            total_right_aset += hit.sum().item()

        if 'pq' in args.data_name:
            aset = batch[-2]
            aset_rows = aset.flatten(1) if aset.dim() > 1 else aset.unsqueeze(1)
            total_right_aset += (aset_rows == pred_ans.unsqueeze(1)).any(dim=1).sum().item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if summary is not None:
            summary.add_scalar('loss/train', loss.item(), e_idx * len(loader) + b_idx)
    return total_right, total_right_aset, total_num


def main():
    """Train a model, keep the best validation checkpoint, and report test accuracy.

    Reads ``configs/<cfg>.yaml`` and writes logs, TensorBoard events (unless
    ``--debug``), dataset caches and ``ckpt_best.pth.tar`` under the
    directories named in the config's ``RES`` section.
    """
    args = _parse_args()

    config_file = "configs/%s.yaml" % (args.cfg)
    print(config_file)
    model_cfg = load_files(config_file)

    fix_seed(model_cfg['MODEL']['SEED'])

    # TensorBoard output is disabled in debug mode; every use checks for None.
    summary = None
    if not args.debug:
        summary = SummaryWriter(model_cfg["RES"]["TB"] + args.exp_name)

    log_path = model_cfg["RES"]["LOG"] + args.exp_name
    os.makedirs(log_path, exist_ok=True)
    ckpt_path = model_cfg["RES"]["CKPT"] + args.exp_name
    os.makedirs(ckpt_path, exist_ok=True)

    logger = setup_logger(args.exp_name, log_path, get_rank())
    logger.info(model_cfg['MODEL'])
    logger.info(args)

    # ------------ Data and model -------------------------------------------
    datasets, max_n_node = _build_datasets(model_cfg, args, ckpt_path)
    train_loader, val_loader, test_loader = _build_loaders(datasets, model_cfg, args)

    model = _build_model(model_cfg, args, max_n_node)
    print_model(model, logger)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    lr_scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=150, T_mult=1, eta_max=0.001, T_up=10, gamma=0.5)

    best_ckpt = os.path.join(ckpt_path, 'ckpt_best.pth.tar')
    best_acc = 0.
    start_epoch = 0
    if args.loaded_epoch != 0:
        checkpoint = torch.load(best_ckpt)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch_idx'] + 1
        # Older checkpoints lack these keys. Restoring best_acc keeps a resumed
        # run from overwriting the best checkpoint with a worse one.
        best_acc = checkpoint.get('best_acc', best_acc)
        if 'lr_scheduler' in checkpoint:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        logger.info('%s is loaded!', best_ckpt)

    # ------------ Training loop ---------------------------------------------
    for e_idx in range(start_epoch, start_epoch + args.max_epoch):
        total_right, total_right_aset, total_num = _train_epoch(
            model, train_loader, optimizer, args, e_idx, summary)

        if args.schedule:
            lr_scheduler.step()

        tr_accu = total_right / total_num
        tr_accu_aset = total_right_aset / total_num
        if summary is not None:
            summary.add_scalar('accu/train', tr_accu, e_idx)
            summary.add_scalar('lr/train', optimizer.param_groups[0]['lr'], e_idx)
        if 'pq' in args.data_name:  # for PQ and PQL
            if summary is not None:
                summary.add_scalar('accu_aset/train', tr_accu_aset, e_idx)
            logger.info('epoch %i train accuracy : %f, %i/%i / %f, %i/%i ' % (e_idx, tr_accu, total_right, total_num, tr_accu_aset, total_right_aset, total_num))
        else:
            logger.info('epoch %i train accuracy : %f, %i/%i' % (e_idx, tr_accu, total_right, total_num))

        with torch.no_grad():
            total_right_val, total_right_aset_val, total_num_val, _ = eval_epoch(model, val_loader, args)

        # Model selection and checkpointing run in every mode, so the final
        # inference always has a checkpoint to load.
        val_acc = total_right_val / total_num_val
        val_acc_aset = total_right_aset_val / total_num_val
        if summary is not None:
            summary.add_scalar('accu/val', val_acc, e_idx)
        if 'pq' in args.data_name:
            if summary is not None:
                summary.add_scalar('accu_aset/val', val_acc_aset, e_idx)
            logger.info('epoch %i val accuracy : %f, %i/%i / %f, %i/%i' % (e_idx, val_acc, total_right_val, total_num_val, val_acc_aset, total_right_aset_val, total_num_val))
            # PQ/PQL select the model by answer-set accuracy.
            val_acc = val_acc_aset
        else:
            logger.info('epoch %i val accuracy : %f, %i/%i' % (e_idx, val_acc, total_right_val, total_num_val))

        if val_acc >= best_acc:
            best_acc = val_acc
            torch.save({'epoch_idx': e_idx,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'best_acc': best_acc},
                       best_ckpt)
        logger.info('## Current VAL Best : %f' % (best_acc))

    test_acc_final = inference(model, test_loader, ckpt_path, args)
    logger.info('test accuracy (final) : %f' % (test_acc_final))
    if summary is not None:
        summary.add_scalar('accu_final/test', test_acc_final, 0)
        # Flush pending events to disk.
        summary.close()
    
if __name__ == "__main__":
    # Script entry point: main() returns None, so the process exits with status 0.
    raise SystemExit(main())