import os
from random import seed

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel

from models.model_exp2 import BertBEM 
from models.dataset_v2 import  TextDataset, MyCollator


from models.roberta_model import RobertaBEM
from models.trainer import Trainer
from models.optimizer import MyOptimizer

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel

from common_utils import fix_seed 
from build_parser import default_parser
 
def get_data_reader(data_fp, batch_size, args, max_len=512, num_workers=4):
    """Build a DataLoader over a ``TextDataset`` read from ``data_fp``.

    Args:
        data_fp: file passed to ``TextDataset`` as ``json_list_fp``.
        batch_size: per-process batch size.
        args: parsed CLI namespace; only ``args.usr_roberta`` is read here
            (0 -> BERT collation with pad id 0, otherwise RoBERTa with pad id 1).
        max_len: maximum length given to both the dataset and the collator.
            NOTE(review): an earlier comment warned that code downstream may
            fail when max_len < 256 — confirm.
        num_workers: number of DataLoader worker processes.

    Returns:
        ``(dataloader, batch_cnt)``: ``batch_cnt`` is the number of full
        batches each process sees per epoch (``drop_last=True``).
    """
    dataset = TextDataset(json_list_fp=data_fp, max_len=max_len)
    use_roberta = args.usr_roberta != 0
    collate = MyCollator(maxlen=max_len, PAD_ids=1 if use_roberta else 0, use_roberta=use_roberta)

    # DistributedSampler needs an initialised process group. init_device()
    # only creates one when several GPUs are visible, so key off the real
    # process-group state rather than "CUDA is available" (which crashed on
    # single-GPU machines).
    distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size() if distributed else 1
    batch_cnt = (len(dataset) // world_size) // batch_size

    loader_kwargs = dict(
        batch_size=batch_size,
        collate_fn=collate,
        num_workers=num_workers,
        # Pinned host memory only helps host->GPU copies.
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
    )
    if distributed:
        # Each rank sees a disjoint shard; the sampler does the shuffling.
        datasampler = DistributedSampler(dataset)
        dataloader = DataLoader(dataset, sampler=datasampler, shuffle=False, **loader_kwargs)
    else:
        dataloader = DataLoader(dataset, shuffle=True, **loader_kwargs)
    return dataloader, batch_cnt
 
def init_device(local_rank=None):
    """Prepare the CUDA device / DDP process group for this process.

    Args:
        local_rank: GPU index for this process. When ``None`` (the original
            call style) it falls back to ``args.local_rank`` from the
            module-level ``args`` defined by the ``__main__`` block.

    Returns:
        The ``cuda_device`` value handed to the Trainer:
        ``-1`` when CUDA is unavailable, ``0`` when exactly one GPU is
        visible, otherwise the list of all visible GPU indices (in which case
        an NCCL process group has been initialised).
    """
    print(torch.cuda.device_count())
    if not torch.cuda.is_available():
        return -1
    if torch.cuda.device_count() <= 1:
        return 0

    if local_rank is None:
        local_rank = args.local_rank  # module-level global set in __main__
    cuda_device = list(range(torch.cuda.device_count()))
    torch.cuda.set_device(local_rank)
    # Prefer the launcher's WORLD_SIZE (correct for multi-node runs or fewer
    # processes than GPUs); fall back to the local GPU count as before.
    world_size = int(os.environ.get("WORLD_SIZE", len(cuda_device)))
    torch.distributed.init_process_group(backend="nccl", world_size=world_size)
    print("use ddp in training", len(cuda_device))
    return cuda_device

def get_model(args):
    """Construct the BertBEM model described by the parsed CLI ``args``.

    Reads ``sent_encoder_folder``, ``def_encoder_folder``, ``gloss_weight``,
    ``mlm_weight`` and ``usr_roberta`` (a value above 0.5 selects RoBERTa),
    then echoes the two loss weights to stdout.
    """
    encoder_kwargs = dict(
        sent_encoder_folder=args.sent_encoder_folder,
        def_encoder_folder=args.def_encoder_folder,
        gloss_weight=args.gloss_weight,
        mlm_weight=args.mlm_weight,
        use_roberta=args.usr_roberta > 0.5,
    )
    bem = BertBEM(**encoder_kwargs)
    print("mlm loss weight: ", args.mlm_weight)
    print("gloss loss weight: ", args.gloss_weight)
    return bem

## init the model params by BERT/RoBERTa
def load_model_params( args, model, fine_tune=False, model_path=None ):
    model_dict = model.state_dict()
    if fine_tune:
        pt_params = torch.load(model_path, map_location=torch.device('cpu') )
        for k,v in pt_params.items() :
            if k not in model_dict.keys():
                print(k)
                assert False
        model.load_state_dict(pt_params)
        print("use model %s to fine tune the BERT model"%(model_path))
    
    else:
        sent_bert_path = args.sent_encoder_folder + "/pytorch_model.bin"
        def_bert_path = args.def_encoder_folder + "/pytorch_model.bin"
        sent_bert_params = torch.load(sent_bert_path , map_location=torch.device('cpu') )
        def_bert_params =  torch.load(def_bert_path , map_location=torch.device('cpu') ) 
        state_dict = dict()

        if args.usr_roberta == 0:   
            for k, v in sent_bert_params.items():
                enc_key = k.replace( "bert.", "sent_encoder." )
                enc_key = enc_key.replace(".gamma", ".weight").replace(".beta", ".bias")
                if enc_key in model_dict: 
                    state_dict[enc_key] = v
                else:
                    print("sent param not in the model: ", k)
            
                enc_key = k.replace( "bert.", "mlm_encoder." ) 
                enc_key = enc_key.replace(".gamma", ".weight").replace(".beta", ".bias")
                if enc_key in model_dict: 
                    state_dict[enc_key] = v 
                else:
                    print("mlm param not in the model: ", enc_key) 

            # target form : def_encoder.encoder.layer.11.attention.self.query.weight
            ## ori form:    'bert.encoder.layer.11.attention.self.query.weight'
            for k, v in def_bert_params.items():
                enc_key = k.replace( "bert.", "def_encoder." )
                enc_key = enc_key.replace(".gamma", ".weight").replace(".beta", ".bias")
                if enc_key in model_dict: 
                    state_dict[enc_key] = v
                else:
                    print("def param not in the model: ", enc_key) 
        else:
            for k, v in sent_bert_params.items():
                enc_key = k.replace( "roberta.", "sent_encoder." ).replace('lm_head.','cls.')
                enc_key = enc_key.replace(".gamma", ".weight").replace(".beta", ".bias")
                if enc_key in model_dict: 
                    state_dict[enc_key] = v
                else:
                    print("sent param not in the model: ", k)
            for k, v in def_bert_params.items():
                #enc_key = "def_encoder." + k
                enc_key = k.replace( "roberta.", "def_encoder." ) 
                enc_key = enc_key.replace(".gamma", ".weight").replace(".beta", ".bias")
                if enc_key in model_dict: 
                    state_dict[enc_key] = v
                else:
                    print("def param not in the model: ", enc_key)   
        print("params not in pretrain: ")
        for k in model_dict :
            if not k in state_dict :
                print(k)
        model_dict.update(state_dict)
        model.load_state_dict(model_dict)
        print("load model from pretrain parameters")
    return model 
        
def main(args):
    """Run the training pipeline configured by the parsed CLI ``args``.

    Steps: build the model and load its initial weights, set up the device
    (plus DDP when a process group exists), build the train/dev loaders,
    the optimizer and the Trainer, then train.
    """
    # fix_seed()
    # Every DDP rank runs this line; exists()+mkdir() races between ranks,
    # makedirs(exist_ok=True) does not.
    os.makedirs(args.model_dir, exist_ok=True)

    ####============ step1: =============
    ### build model, load params and (when distributed) wrap in DDP
    model = get_model(args)
    model = load_model_params(args, model,
                              fine_tune=args.fine_tune,
                              model_path=args.pretrain_model_path)

    cuda_device = init_device()
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(args.local_rank))
    else:
        device = torch.device("cpu")
    model = model.to(device)
    # DDP requires an initialised process group; init_device() only creates
    # one when several GPUs are visible, so wrapping unconditionally on
    # "CUDA available" crashed on single-GPU machines.
    # NOTE(review): on one GPU the Trainer now receives the bare model —
    # confirm it does not rely on `model.module`.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        # output_device takes a single int / torch.device, not a list.
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank)
    print("Model is set ")

    ####============ step2: =============
    ### read datasets and return dataloaders
    train_data, train_batch_cnt = get_data_reader(data_fp=args.train_set, batch_size=args.batch_size,
                                                  args=args, max_len=args.max_len)
    dev_data, _ = get_data_reader(data_fp=args.dev_set, args=args, batch_size=args.batch_size,
                                  max_len=args.max_len)
    print("finish data loader")

    ####============ step3: =============
    ### build optimizer
    optimizer = MyOptimizer(model.parameters(),
                            train_batch_cnt=train_batch_cnt, train_epoch_cnt=args.n_epoch,
                            cold_lr=args.cold_lr, learning_rate=args.lr)

    ####============ step4: =============
    ### build the trainer and train
    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      scheduler=None,
                      iterator=None,
                      train_dataset=train_data,
                      validation_dataset=dev_data,
                      serialization_dir=args.model_dir,
                      patience=args.patience,
                      num_epochs=args.n_epoch,
                      cuda_device=cuda_device,
                      shuffle=False,
                      cold_step_count=args.cold_steps_count,
                      cold_lr=args.cold_lr,
                      local_rank=args.local_rank,
                      grad_clipping=1.0)
    print("Start training")
    trainer.train()

    print("Model is dumped")

if __name__ == '__main__':
    # Parse the command line and run training. `args` must remain a
    # module-level name: init_device() reads `args.local_rank` from globals.
    args = default_parser().parse_args()
    main(args)

