import collections
import os
import pickle
import sys
import logging

import torch
import torch.multiprocessing as mp
import torchvision.transforms as transforms
import torch.nn as nn
import torch.distributed as dist
import tqdm
from transformers import BertTokenizer

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from vlpretrain.data import CombinedDataset, CombinedTorchDataset
from vlpretrain.loss import paired_hinge_rank_loss2, binary_classification_loss
from vlpretrain.metric import batchwise_accuracy2, batchwise_recall
from vlpretrain.model import LangModel, VisnModel, JointModel, LANG_MODELS
from vlpretrain.param import parse_args

import random
import numpy as np

# Module-level logger.
# NOTE(review): the code shown here never logs through it; progress and
# results are reported with print() instead.
logger = logging.getLogger(__name__)


def is_port_in_use(port):
    """Report whether something accepts TCP connections on localhost:*port*.

    ``connect_ex`` yields 0 when the connection succeeds and a non-zero
    errno otherwise, so a zero status means the port is already taken.
    """
    import socket
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        err = probe.connect_ex(('localhost', port))
    return not err


def main():
    """Choose a rendezvous port, parse args and spawn one train() per local GPU.

    NOTE(review): MASTER_ADDR is pinned to 127.0.0.1, which only suits a
    single-node run even though world_size is scaled by ``args.nodes``.
    """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    # Probe upward from 9590 until nothing answers on the candidate port.
    candidate = 9590
    while is_port_in_use(candidate):
        candidate += 1
    print("Use port", candidate)
    os.environ['MASTER_PORT'] = str(candidate)

    # Every visible GPU on this node gets its own process.
    args = parse_args()
    local_gpus = torch.cuda.device_count()
    args.gpus = local_gpus
    print("Use gpus ", list(range(local_gpus)))
    args.world_size = local_gpus * args.nodes
    mp.spawn(train, nprocs=local_gpus, args=(args,))


def _load_checkpoint(model, path, device):
    """Load the state dict stored at *path* into *model*.

    Keys containing ``num_batches_tracked`` are skipped, and a leading
    ``module.`` prefix (present when a DistributedDataParallel wrapper was
    saved) is stripped. Key differences between the file and the model are
    printed before ``load_state_dict`` (strict by default) is called.
    """
    state_dict = torch.load(path, map_location=device)
    new_state_dict = {}
    for key, value in state_dict.items():
        if 'num_batches_tracked' in key:
            continue
        if key.startswith("module."):
            key = key[len("module."):]
        new_state_dict[key] = value
    model_keys = set(model.state_dict().keys())
    load_keys = set(new_state_dict.keys())
    print("Keys in model but not in load:")
    for key in sorted(model_keys - load_keys):
        print(key)
    print("Keys in load but not in model:")
    for key in sorted(load_keys - model_keys):
        print(key)
    model.load_state_dict(new_state_dict)


def _pickle_dump(obj, path):
    """Pickle *obj* to *path*, closing the file and replacing *path* atomically.

    Writing to a per-process temp file and then calling ``os.replace`` means a
    concurrent reader (another rank) never sees a half-written pickle.
    """
    tmp_path = '%s.tmp%d' % (path, os.getpid())
    with open(tmp_path, 'wb') as f:
        pickle.dump(obj, f)
    os.replace(tmp_path, path)


def _load_or_build_dataset(path, build):
    """Return the dataset pickled at *path*, or build it with *build()* and cache it."""
    if os.path.exists(path):
        # NOTE: pickle.load can execute arbitrary code; only point the cache
        # directory at files this project wrote itself.
        with open(path, 'rb') as f:
            return pickle.load(f)
    dataset = build()
    _pickle_dump(dataset, path)
    return dataset


def _build_optimizer(args, model, steps_per_epoch):
    """Create the optimizer selected by ``args.optim``.

    Returns ``(optimizer, scheduler)``:
      * ``'bert'``: transformers ``AdamW`` (weight decay 0.01, except 0.0 for
        biases and ``LayerNorm.weight``) plus a linear warmup schedule over
        ``steps_per_epoch * args.epochs`` steps.
      * otherwise: ``args.optimizer`` over all trainable parameters (with
        momentum 0.9 for ``'sgd'``); ``scheduler`` is ``None``.
    """
    if args.optim == 'bert':
        from transformers import AdamW, get_linear_schedule_with_warmup
        no_decay = ["bias", "LayerNorm.weight"]
        params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in params if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.01,
            },
            {
                "params": [p for n, p in params if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=1e-8)
        t_total = steps_per_epoch * args.epochs
        warmup_steps = int(t_total * args.warmup_ratio)
        print("Train for %d steps and warm up for %d steps" % (t_total, warmup_steps))
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
        )
        return optimizer, scheduler

    trainable = [p for p in model.parameters() if p.requires_grad]
    if args.optim == 'sgd':
        return args.optimizer(trainable, args.lr, momentum=0.9), None
    return args.optimizer(trainable, args.lr), None


def train(gpu, args):
    """Run distributed vision-language pre-training in one spawned process.

    Args:
        gpu: local GPU index passed in by ``mp.spawn``; the global rank is
            ``args.nr * args.gpus + gpu``.
        args: parsed command-line namespace from ``parse_args``.

    Every rank trains. Only local GPU 0 dumps the pre-processing objects,
    shows progress, runs validation and writes snapshots.
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device('cuda', gpu)
    rank = args.nr * args.gpus + gpu
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=args.world_size,
        rank=rank
    )
    print(args)

    # Models
    lang_layers = [-int(x) for x in args.lang_layers.split(',')]     # The layers concatenated as the output.
    lang_model = LangModel(args, args.dim, arch=args.lang, layers=lang_layers,
                           pretrained=args.lang_pretrained, finetuning=args.lang_finetune,
                           bertonly=args.bertonly, normalize=args.normalize)
    visn_model = VisnModel(args, args.dim, arch=args.visn,
                           pretrained=args.visn_pretrained, finetuning=args.visn_finetune,
                           bertonly=args.bertonly, normalize=args.normalize)
    # The use of joint model would help synchronization in distributed learning.
    model = JointModel(args, lang_model, visn_model)

    # DDP below is built with broadcast_buffers=False, so make sure the only
    # buffers are batch-norm statistics and position ids.
    for name, _ in model.named_buffers():
        assert 'bn' in name or 'downsample' in name or "position_ids" in name, name

    if args.load is not None:
        _load_checkpoint(model, args.load, device)

    # Pre-processing Hyper-Params
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    valid_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])
    _, tokenizer_cls, weight = LANG_MODELS[args.lang]
    tokenizer = tokenizer_cls.from_pretrained(weight)
    max_len = args.max_len

    # Dump the pre-processing objs for future feature extractions.
    if gpu == 0:
        _pickle_dump(tokenizer, os.path.join(args.output, 'tokenizer.pkl'))
        _pickle_dump(valid_transform, os.path.join(args.output, 'img_transform.pkl'))

    # Data Sets: cached as pickles keyed by args.dataname. args.snap_dir, when
    # provided, overrides the historical hard-coded cache directory.
    snap_dir = getattr(args, 'snap_dir', None) or '/home/woojeong2/vok_pretraining/snap/vlpretrain'
    train_set = _load_or_build_dataset(
        os.path.join(snap_dir, 'trainset_' + args.dataname + '.pkl'),
        lambda: CombinedDataset(args, args.train_imgs, args.train_langs,
                                tiny=args.tiny, fast=args.fast),
    )
    train_tset = CombinedTorchDataset(
        args, train_set, train_transform, tokenizer, max_len
    )
    print("GPU %d: load %d data in training." % (gpu, len(train_set)))

    valid_set = _load_or_build_dataset(
        os.path.join(snap_dir, 'validset_' + args.dataname + '.pkl'),
        lambda: CombinedDataset(args, args.valid_imgs, args.valid_langs,
                                tiny=args.tiny, fast=args.fast, val=True),
    )
    print("GPU %d: load %d data in validation." % (gpu, len(valid_set)))
    valid_tset = CombinedTorchDataset(
        args, valid_set, valid_transform, tokenizer, max_len, val=True
    )
    print()

    # Data Loader
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_tset,
        num_replicas=args.world_size,
        rank=rank,
        shuffle=True,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_tset,
        batch_size=args.batch_size,
        shuffle=False,          # Will be shuffled in the sampler.
        num_workers=max(args.num_workers // args.world_size, 1),
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset=valid_tset,
        batch_size=32,             # Fix batch_size to have stable batchwise evaluations.
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    # With drop_last=True every validation batch is full, so this is exact.
    valid_size = len(valid_loader) * valid_loader.batch_size

    optimizer, scheduler = _build_optimizer(args, model, len(train_loader))

    # The training losses are returned by the model's forward pass; criterion
    # is only forwarded to valid(). An unknown args.loss yields None instead
    # of an UnboundLocalError at the first validation.
    criterion = {
        'binary': binary_classification_loss,
        'hinge': paired_hinge_rank_loss2,
    }.get(args.loss)
    torch.cuda.set_device(gpu)
    model.cuda(gpu)

    if args.fp16:
        try:
            from apex import amp
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError as e:
            print(e)
            print("Please install apex library")
            return
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
        # By default, apex DDP does not broadcast the buffers.
        model = DDP(model)
    else:
        # Note that we disallow broadcast of buffers here to reduce communication cost.
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[gpu],
            find_unused_parameters=True,
            broadcast_buffers=False,
        )

    if args.test_only or args.load:     # Test the loading performance
        if gpu == 0:
            print("Test: GPU %d will test %d data in %d iterations." %
                  (gpu, valid_size, len(valid_loader)))
            results = valid(args, model, criterion, valid_loader)
            print("Initial test results:")
            for key, value in results.items():
                print('\t%s: %0.4f' % (key, value))
        if args.test_only:
            sys.exit()

    best_valid_loss = 9595.
    total_step = 0          # number of optimizer steps completed so far
    total_loss = total_loss_vl = total_loss_mlm = 0.
    valid_step_size = 1000  # steps between logging / validation / snapshot rounds
    for epoch in range(args.epochs):
        # Without this, DistributedSampler replays the same permutation every epoch.
        train_sampler.set_epoch(epoch)
        if gpu == 0:
            print("Training of Epoch %d: GPU %d will process %d data in %d iterations." %
                  (epoch, gpu, len(train_loader) * args.batch_size, len(train_loader)))

        for _, lang_input, visn_input, general_text_input in tqdm.tqdm(train_loader, disable=(gpu != 0)):
            lang_input = tuple(x.cuda(non_blocking=True) for x in lang_input)
            visn_input = tuple(x.cuda(non_blocking=True) for x in visn_input)
            general_text_input = tuple(x.cuda(non_blocking=True) for x in general_text_input)
            neg_lang_input = None
            if args.lmperturb or args.preprocessed_lmperturb:
                # In perturbation mode lang_input carries four tensors; the last
                # two form the negative input (presumably ids/mask — confirm
                # against CombinedTorchDataset).
                neg_lang_input = (lang_input[2], lang_input[3])
                lang_input = (lang_input[0], lang_input[1])

            # Forward pass
            model.zero_grad()
            loss_vl, _, text_loss, _, _ = model(lang_input, visn_input, neg_lang_input, None, general_text_input)
            # NOTE(review): valid() ranks checkpoints by
            # losshp * loss_vl + (1 - losshp) * text_loss, a different
            # weighting from this training objective.
            loss = loss_vl + args.losshp * text_loss
            total_loss += loss.item()
            total_loss_vl += loss_vl.item()
            total_loss_mlm += text_loss.item()

            # Backward + gradient clipping (on amp master params under fp16)
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            # Count the step before the periodic check so every window averages
            # exactly valid_step_size steps and "StepN" means N steps done.
            total_step += 1

            if gpu == 0 and total_step % valid_step_size == 0:
                snap_name = 'Step' + str(total_step)
                print("GPU 0: Save snapshot to ", os.path.join(args.output, snap_name))
                save_model(args, snap_name, model.module.lang_model.model, tokenizer)
                print("GPU %d Step %d: Total Training Loss %0.4f" % (gpu, total_step, total_loss / valid_step_size))
                print("loss_vl", total_loss_vl / valid_step_size, "text_loss", total_loss_mlm / valid_step_size)
                total_loss = total_loss_vl = total_loss_mlm = 0.

                print()
                print("Validation: GPU %d will process %d data in %d iterations." %
                      (gpu, valid_size, len(valid_loader)))
                results = valid(args, model, criterion, valid_loader, use_tqdm=True)
                for key, value in results.items():
                    print('\t%s: %0.4f' % (key, value))
                if results['loss'] < best_valid_loss:
                    best_valid_loss = results['loss']
                    snap_path = os.path.join(args.output, 'BEST.pth')
                    print("GPU 0: Save snapshot to ", snap_path)
                    torch.save(model.module.state_dict(), snap_path)
                    torch.save(model.module, snap_path + '.model')
                    torch.save(model.module.lang_model.model.state_dict(), snap_path + '_lang')
                    torch.save(model.module.lang_model.model, snap_path + '.model_lang')
                    save_model(args, 'best', model.module.lang_model.model, tokenizer)
                print("BEST valid loss %0.4f" % best_valid_loss)
                print()


def valid(args, model, criterion, valid_loader, use_tqdm=True):
    """Evaluate *model* on *valid_loader* and return metrics averaged per batch.

    Args:
        args: namespace; reads ``losshp``, ``lmperturb`` and
            ``preprocessed_lmperturb``.
        model: the joint model. Its forward returns
            ``(loss_vl, caption_loss, text_loss, lang_output, visn_output)``.
        criterion: unused here; kept so existing call sites keep working.
        valid_loader: yields ``(uid, lang_input, visn_input, general_text_input)``.
        use_tqdm: wrap the loader in a tqdm progress bar.

    Returns:
        A ``defaultdict`` with keys ``loss``, ``loss_vl``, ``loss_mlm`` and
        ``acc``, each summed over batches and divided by ``len(valid_loader)``.
        ``loss`` is ``losshp * loss_vl + (1 - losshp) * text_loss``. If the
        loader yields no batches, the dict is empty (missing keys read as 0).

    The model is switched to eval mode for the pass and always switched back to
    train mode afterwards, even if evaluation raises.
    """
    model.eval()
    results = collections.defaultdict(lambda: 0)
    iterator = tqdm.tqdm(valid_loader) if use_tqdm else valid_loader
    try:
        for _, lang_input, visn_input, general_text_input in iterator:
            lang_input = tuple(x.cuda(non_blocking=True) for x in lang_input)
            visn_input = tuple(x.cuda(non_blocking=True) for x in visn_input)
            general_text_input = tuple(x.cuda(non_blocking=True) for x in general_text_input)
            neg_lang_input = None
            if args.lmperturb or args.preprocessed_lmperturb:
                # In perturbation mode the last two tensors form the negative input.
                neg_lang_input = (lang_input[2], lang_input[3])
                lang_input = (lang_input[0], lang_input[1])

            with torch.no_grad():
                loss_vl, _, text_loss, lang_output, visn_output = model(
                    lang_input, visn_input, neg_lang_input, None, general_text_input)
                results['loss'] += args.losshp * (loss_vl.item()) + (1 - args.losshp) * text_loss.item()
                results['loss_vl'] += loss_vl.item()
                results['loss_mlm'] += text_loss.item()
                results['acc'] += batchwise_accuracy2(lang_output, visn_output).item()

        # Average every metric over the number of batches.
        num_batches = len(valid_loader)
        for key in results:
            results[key] = results[key] / num_batches
    finally:
        model.train()

    return results

def save_model(args, name, model, tokenizer):
    """Write a checkpoint directory ``os.path.join(args.output, name)``.

    The model (unwrapped from a ``.module`` container such as DDP, if present)
    and the tokenizer are written with ``save_pretrained``. The full ``args``
    namespace is stored alongside them as ``training_args.bin``.
    """
    target_dir = os.path.join(args.output, name)
    os.makedirs(target_dir, exist_ok=True)

    # Distributed/parallel wrappers keep the real model under `.module`.
    unwrapped = getattr(model, "module", model)
    unwrapped.save_pretrained(target_dir)
    tokenizer.save_pretrained(target_dir)

    torch.save(args, os.path.join(target_dir, "training_args.bin"))
    # logger.info("Saving optimizer and scheduler states to %s", output_dir)


# Script entry point: main() picks a rendezvous port and spawns one train()
# process per visible GPU.
if __name__ == "__main__":
    main()
