
import datetime
import sys

sys.path.append('bart_predict_topic') #
from transformers import (
    HfArgumentParser,
    TrainingArguments,
    set_seed,
    Trainer,
    Seq2SeqTrainer,
    BertTokenizerFast,
    AutoTokenizer,
    T5ForConditionalGeneration,
)
import re
import torch
from torch.cuda.amp import autocast, GradScaler

from time import sleep
from transformers import AutoTokenizer, BartForConditionalGeneration
import torch
import argparse
from torch import distributed as dist
from tqdm import tqdm
import numpy as np
import random

from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from loguru import logger
from predict_target_cross_attention_data_prob.data_utils_bart_topic2_predict_target_cross_attention_prob import ABSADataset

import argparse
import os
import json
from transformers import AdamW, TrainingArguments , AutoModelForCausalLM
from os.path import join
import torch
# from component.datacollator import CaptionCollator
from torch.utils.data import DataLoader
from tqdm import tqdm
from bart_classify_model.classify_model2_predict_target_cross_attention_data_prob import classify_model #_avg_repeat
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
# import myTrain from 




def get_dataset(tokenizer, type_path, args, model = None, embedding_processor=None ):
    """Build the ``ABSADataset`` for the split named by ``type_path``.

    The dataset is read from the fixed ``data/dataset/jsons_en`` directory and
    truncated to ``args.max_seq_length``. ``model`` and ``embedding_processor``
    are accepted for call-site compatibility but are not used here.
    """
    dataset_root = "data/dataset/jsons_en"
    dataset = ABSADataset(
        tokenizer,
        dataset_root,
        data_type=type_path,
        max_len=args.max_seq_length,
    )
    return dataset



# Shape of one generated / gold statement; groups 1, 3 and 4 are what the
# "identified" metric compares (group 2 is ignored there).
_QUAD_PATTERN = re.compile(r'The (.+) is (.+), because (.+) is (.+)')


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


def _unique_statements(texts, require_pattern=False):
    """Split every text on "." and return the distinct, stripped, non-empty pieces.

    Pieces are returned in first-seen order so the result is deterministic.
    When ``require_pattern`` is true, only pieces matching ``_QUAD_PATTERN``
    are kept.
    """
    seen = {}
    for text in texts:
        for piece in text.split("."):
            statement = piece.strip()
            if not statement:
                continue
            if require_pattern and not _QUAD_PATTERN.match(statement):
                continue
            seen.setdefault(statement, None)
    return list(seen)


def _statement_key(statement):
    """Return ``[group1, group3, group4]`` of a statement, or ``None`` if it does not match."""
    matched = _QUAD_PATTERN.match(statement)
    if matched is None:
        return None
    return [matched.group(1), matched.group(3), matched.group(4)]


def _score_predictions(outputs, targets):
    """Count exact and "identified" hits between predicted and gold statement lists.

    Returns ``(micro_right_num, identified_right_num, total_pred_num, total_gt_num)``.
    """
    micro_right_num = 0
    identified_right_num = 0
    total_pred_num = 0
    total_gt_num = 0

    for pred_list, target_list in zip(outputs, targets):
        total_pred_num += len(pred_list)
        total_gt_num += len(target_list)

        # micro F1: the whole statement must match a gold statement exactly.
        gold_statements = set(target_list)
        micro_right_num += sum(1 for pred in pred_list if pred in gold_statements)

        # identified F1: compare only the extracted pattern groups.
        target_keys = []
        for each_targ in target_list:
            key = _statement_key(each_targ)
            if key is None:
                # Gold statement that does not follow the template; report it.
                print("each_targ",each_targ)
                print("targ_group",None)
            else:
                target_keys.append(key)

        for pred in pred_list:
            key = _statement_key(pred)
            # A prediction that does not follow the template can never be a hit.
            if key is not None and key in target_keys:
                identified_right_num += 1

    return micro_right_num, identified_right_num, total_pred_num, total_gt_num


def evaluate(data_loader, model, tokenizer, local_rank = 0):
    """
    Compute scores given the predictions and gold labels

    For every batch the wrapped model first predicts topics
    (``model.module.predict_topic``), then generates one output sequence per
    returned topic attention mask row (``model.module.generate``). Generated
    and gold texts are split on "." into statements; generated statements are
    kept only if they match the "The .. is .., because .. is .." template.

    Args:
        data_loader: sized iterable of batches. A batch is either ``None``
            (skipped) or a mapping with the keys ``input_ids``,
            ``attention_mask``, ``sentence_mask``, ``target_ids``,
            ``reply_as_topic``, ``tagging_label``, ``new_topic_struct`` and
            ``max_sentence_num``.
        model: wrapper exposing the real model as ``model.module``
            (presumably ``DistributedDataParallel`` -- confirm with caller).
        tokenizer: object providing ``decode(ids, skip_special_tokens=True)``.
        local_rank: CUDA device index used for the model and the inputs.

    Returns:
        The "identified" F1 score (0 when nothing could be scored).
    """
    device = torch.device(f'cuda:{local_rank}')
    model.to(device)
    model.eval()

    # One list of statements per evaluated batch.
    outputs, targets = [], []

    predict_topic_all_num = 0
    predict_topic_right_num = 0

    predict_topic_list_all_num = 0
    predict_topic_list_right_num = 0

    print("data_loader",len(data_loader))
    # Inference only: do not build autograd graphs while evaluating.
    with torch.no_grad():
        for batch in tqdm(data_loader):
            if batch is None:
                continue

            # Move inputs to the same device the model was moved to.
            input_ids = batch["input_ids"].to(device, non_blocking=True)
            attention_mask = batch["attention_mask"].to(device, non_blocking=True)
            sentence_mask = batch["sentence_mask"].to(device, non_blocking=True)
            tagging_label = batch["tagging_label"].to(device, non_blocking=True)
            target_ids = batch["target_ids"]
            new_topic_struct = batch["new_topic_struct"]

            # Step 1: predict topics; the returned mask drives generation below.
            all_topic_attention_mask, predict_topic_right_num1, topic_list1 = model.module.predict_topic(
                input_ids=input_ids,
                tokenizer=tokenizer,
                attention_mask=attention_mask,
                sentence_mask=sentence_mask,
                reply_as_topic=batch["reply_as_topic"],
                tagging_label=tagging_label,
                new_topic_struct=new_topic_struct,
                max_sentence_num=batch["max_sentence_num"],
                ancient_list=None,
            )

            print("predict_topic_right_num1",predict_topic_right_num1)
            predict_topic_right_num += predict_topic_right_num1
            # Only labels >= 0 count towards the tagging total.
            predict_topic_all_num += (batch["tagging_label"]  >= 0).int().sum()

            # A predicted topic is right if it contains some gold topic entirely.
            predict_topic_list_all_num += len(new_topic_struct)
            for each_topic in topic_list1:
                predicted_members = set(each_topic)
                if any(set(right_topic).issubset(predicted_members)
                       for right_topic in new_topic_struct):
                    predict_topic_list_right_num += 1

            all_topic_attention_mask = all_topic_attention_mask.to(input_ids.device)
            path_num, _ = all_topic_attention_mask.shape

            # Normalise to 2-D *before* reading the sequence length, otherwise
            # a 1-D input would fail on ``shape[1]``.
            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(dim=0)
                attention_mask = attention_mask.unsqueeze(dim=0)
            seq_len = input_ids.shape[1]

            # Step 2: one copy of the input per topic path.
            input_ids = input_ids.repeat(path_num, 1).reshape(path_num, seq_len)
            attention_mask = attention_mask.repeat(path_num, 1).reshape(path_num, seq_len)

            outs = model.module.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=300,
                no_repeat_ngram_size=20,
                early_stopping=True,
                num_beams=2,
                path_attention_mask=all_topic_attention_mask,
            )

            dec = [tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
            print("",dec)
            target = [tokenizer.decode(ids, skip_special_tokens=True) for ids in target_ids]

            # Merge all paths of this batch into one de-duplicated statement list.
            outputs.append(_unique_statements(dec, require_pattern=True))
            targets.append(_unique_statements(target))

            # Drop GPU references before asking the allocator to release memory.
            del input_ids, attention_mask, sentence_mask, tagging_label
            del all_topic_attention_mask, outs
            torch.cuda.empty_cache()

    print("\nPrint some results to check the sanity of generation method:", '\n', '-'*30)

    micro_right_num, identified_right_num, total_pred_num, total_gt_num = _score_predictions(
        outputs, targets
    )

    # Every ratio is guarded so an empty loader / empty gold set yields 0
    # instead of ZeroDivisionError.
    micro_p = _safe_ratio(micro_right_num, total_pred_num)
    micro_r = _safe_ratio(micro_right_num, total_gt_num)
    micro_f1 = 2* micro_p * micro_r /(micro_p + micro_r)  if micro_p > 0 and micro_r > 0  else 0

    identified_p = _safe_ratio(identified_right_num, total_pred_num)
    identified_r = _safe_ratio(identified_right_num, total_gt_num)
    identified_f1 = 2* identified_p * identified_r /(identified_p + identified_r)  if identified_p > 0 and identified_r > 0  else 0

    print("micro_right_num", micro_right_num )
    print("identified_right_num", identified_right_num )
    print("", total_pred_num )
    print("gt", total_gt_num )
    print("topic", _safe_ratio(predict_topic_right_num, predict_topic_all_num))
    print("topic list ", _safe_ratio(predict_topic_list_right_num, predict_topic_list_all_num))

    print("==================" )
    print("micro_p", micro_p )
    print("micro_r", micro_r )
    print("micro_f1", micro_f1 )

    print("=================" )

    print("identified_p", identified_p )
    print("identified_r", identified_r )
    print("identified_f1", identified_f1 )

    return identified_f1

def seed_everything(seed):
    """Seed torch, numpy and ``random`` with a per-rank seed.

    The effective seed is ``rank * 100000 + seed`` where ``rank`` is the
    torch.distributed rank (0 when the process group is not initialised).

    Raises:
        ValueError: if ``seed`` is 10000 or larger.
    """
    if seed >= 10000:
        raise ValueError("seed number should be less than 10000")

    dist_ready = torch.distributed.is_initialized()
    rank = torch.distributed.get_rank() if dist_ready else 0
    rank_seed = (rank * 100000) + seed

    # Same order as before: torch first, then numpy, then the stdlib RNG.
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(rank_seed)

def seed_torch(seed=1029):
    """Seed every RNG used by the script and force deterministic cuDNN.

    Seeds ``random``, numpy, torch (CPU and all CUDA devices), exports
    ``PYTHONHASHSEED`` and disables cuDNN benchmarking in favour of
    deterministic kernels.
    """
    random.seed(seed)
    # Exported as a string because environment values must be text.
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def _init_fn(worker_id):
    print("worker_id",worker_id)
    np.random.seed(  worker_id)

def main():
#     dist.init_process_group(backend="nccl")  # ，'nccl'
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', default=-1, type=int,
                    help='node rank for distributed training')
    parser.add_argument('--seed', default=42, type=int,
                    help='node rank for distributed training')
    parser.add_argument('--number_gpu', default=3, type=int,
                    help='node rank for distributed training')
    
    parser.add_argument('--cross_attention_layer_num', default=3, type=int,
                    help='node rank for distributed training')
    
    parser.add_argument('--train_batch_size', default=4, type=int,
                    help='node rank for distributed training')
    parser.add_argument('--max_seq_length', default=512, type=int,
                    help='node rank for distributed training')
    parser.add_argument('--output_dir', default="./", type=str,
                    help='node rank for distributed training') 
    
    parser.add_argument('--train_dataset', default="./", type=str,
                    help='node rank for distributed training')    
    parser.add_argument('--model_save_path', default="./", type=str,
                    help='node rank for distributed training')
    parser.add_argument('--model_name_or_path', default="./", type=str,
                    help='node rank for distributed training')
    parser.add_argument("--weight_decay", default=0.0, type=float)
    parser.add_argument("--adam_epsilon", default=1e-8, type=float)
    parser.add_argument("--warmup_steps_rate", default=0.0, type=float) 
    parser.add_argument("--dropout_rate", default=0.0, type=float)
    parser.add_argument("--learning_rate", default=3e-4, type=float)

    parser.add_argument("--tagging_learning_rate", default=3e-4, type=float)
    parser.add_argument("--topic_learning_rate", default=3e-4, type=float)
    parser.add_argument("--attention_learning_rate", default=3e-4, type=float)
    parser.add_argument("--alpha_learning_rate", default=3e-4, type=float)


    parser.add_argument("--tag_ratio", default=3e-4, type=float)
    parser.add_argument("--generation_ratio", default=3e-4, type=float)
    parser.add_argument("--topic_ratio", default=3e-4, type=float)

    parser.add_argument("--warmup_ratio", default=0.01, type=float)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--num_train_epochs", default=30, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--eval_batch_size", default=16, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev/test set.")
    
    parser.add_argument("--use_amp", action='store_true',
                        help="Whether to run eval on the dev/test set.") 
    args = parser.parse_args()
    
    # torch.distributed.init_process_group(backend="nccl")

    torch.distributed.init_process_group(backend='nccl', init_method='env://',world_size=args.number_gpu, rank=args.local_rank, timeout=datetime.timedelta(seconds=5400))
    # os.environ['MASTER_ADDR'] = '127.0.0.1'
    # keep track of whether the current process is the `master` process (totally optional, but I find it useful for data laoding, logging, etc.)
    args.is_master = args.local_rank == 0

    if args.is_master:
        print("",args)

        # 
        namespace_dict = vars(args)

        # 
        for key, value in namespace_dict.items():
            print(f"{key}: {value}")

    # set the device
    args.device = torch.cuda.device(args.local_rank)
    local_rank = args.local_rank
    # 



    
    #
    #
    
    
    seed_everything(args.seed)
    seed_torch(args.seed)
    
    # 
#     if not os.path.exists(training_args.output_dir):
#         os.makedirs(training_args.output_dir)
    os.makedirs(args.output_dir,exist_ok=True)
    # 

    # with open(join(args.output_dir, 'train_args.json'), 'w', encoding='utf8') as f:
    #     args_json = json.dumps(vars(args))
    #     json.dump(args_json, f, indent=2)
        
        
    # training process
    if True:
        print("\n****** Conduct Training ******")
        
        # 
        set_seed(args.seed)
    
        # 
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

        model = classify_model.from_pretrained(args.model_name_or_path, args=args)

        model_size = sum(p.numel() for p in model.parameters())
        print(f"Model size: {model_size} parameters")
        #attention，
        model_state_dict = model.state_dict()
        # print("model_state_dict",model_state_dict.keys())
        # for name in model_state_dict.keys():
        #     if ".0.gat" in name:
        #         print(name)
        # for i in range(model.config.encoder_layers):
        #     model_state_dict[f'model.encoder.layers.{i}.gat.k_proj.weight'] = model_state_dict[f'model.encoder.layers.{i}.self_attn.k_proj.weight']
        #     model_state_dict[f'model.encoder.layers.{i}.gat.k_proj.bias'] = model_state_dict[f'model.encoder.layers.{i}.self_attn.k_proj.bias']
        #     model_state_dict[f'model.encoder.layers.{i}.gat.v_proj.weight'] = model_state_dict[f'model.encoder.layers.{i}.self_attn.v_proj.weight']
        #     model_state_dict[f'model.encoder.layers.{i}.gat.v_proj.bias'] = model_state_dict[f'model.encoder.layers.{i}.self_attn.v_proj.bias']
        #     model_state_dict[f'model.encoder.layers.{i}.gat.q_proj.weight'] = model_state_dict[f'model.encoder.layers.{i}.self_attn.q_proj.weight']
        #     model_state_dict[f'model.encoder.layers.{i}.gat.q_proj.bias'] = model_state_dict[f'model.encoder.layers.{i}.self_attn.q_proj.bias']

        #     model_state_dict[f'model.encoder.layers.{i}.gat.out_proj.weight'] = model_state_dict[f'model.encoder.layers.{i}.self_attn.out_proj.weight']
        #     model_state_dict[f'model.encoder.layers.{i}.gat.out_proj.bias'] = model_state_dict[f'model.encoder.layers.{i}.self_attn.out_proj.bias']
        #     model_state_dict[f'model.encoder.layers.{i}.gat_norm.weight'] = model_state_dict[f'model.encoder.layers.{i}.self_attn_layer_norm.weight']
        #     model_state_dict[f'model.encoder.layers.{i}.gat_norm.bias'] = model_state_dict[f'model.encoder.layers.{i}.self_attn_layer_norm.bias']
        #     model_state_dict[f'model.encoder.layers.{i}.gat_fc1.weight'] = model_state_dict[f'model.encoder.layers.{i}.fc1.weight']
        #     model_state_dict[f'model.encoder.layers.{i}.gat_fc1.bias'] = model_state_dict[f'model.encoder.layers.{i}.fc1.bias']

        #     model_state_dict[f'model.encoder.layers.{i}.gat_fc2.weight'] = model_state_dict[f'model.encoder.layers.{i}.fc2.weight']
        #     model_state_dict[f'model.encoder.layers.{i}.gat_fc2.bias'] = model_state_dict[f'model.encoder.layers.{i}.fc2.bias']
        #     model_state_dict[f'model.encoder.layers.{i}.gat_final_layer_norm.weight'] = model_state_dict[f'model.encoder.layers.{i}.final_layer_norm.weight']
        #     model_state_dict[f'model.encoder.layers.{i}.gat_final_layer_norm.bias'] = model_state_dict[f'model.encoder.layers.{i}.final_layer_norm.bias']

        # model.load_state_dict(model_state_dict)
        # encoder，decoderfinetune
#         if args.freeze_encoder:
#             for name, param in model.encoder.named_parameters():
#                 # encoderdecoder
#                 # ，
#                 if 'embed_tokens' in name and not args.freeze_word_embed:
#                     param.requires_grad = True
#                 # 
#                 else:
#                     param.requires_grad = False
        total = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info("Total training params: %.2fM" % (total / 1e6))
    
    

        #
        
        # GPU ，：
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", local_rank)
        model.to(device)
        
        #
        tokenizer.add_special_tokens({'additional_special_tokens':["<d>"]})
        model.resize_token_embeddings(len(tokenizer))
        #augment_training_data augment_training_data
        train_dataset = get_dataset(tokenizer =tokenizer, type_path=args.train_dataset,  args=args, ) #augment_training_data  train augment_training_data_order_output
        
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, drop_last=True, sampler=train_sampler, num_workers=8,
                                                    worker_init_fn=_init_fn, pin_memory=True) #collate_fn = lambda x: train_dataset.train_collate_fn(x, tokenizer) ,
        print("len(train_loader)", len(train_loader))
        # val_dataset = get_dataset(tokenizer = tokenizer, type_path="valid", args=args, )
        # val_loader = DataLoader(val_dataset, batch_size=1, num_workers=4, worker_init_fn=_init_fn,collate_fn=lambda x: val_dataset.test_collate_fn(x, tokenizer), ) #lambda x: val_dataset.collate_fn(x, tokenizer,args.max_seq_length)
    


        test_dataset = get_dataset(tokenizer = tokenizer, type_path="test", args=args, )
        test_sampler=torch.utils.data.distributed.DistributedSampler(test_dataset)
        test_loader = DataLoader(test_dataset,  batch_size=1, num_workers=8, worker_init_fn=_init_fn, collate_fn=lambda x: test_dataset.test_collate_fn(x, tokenizer) ) #lambda x: val_dataset.collate_fn(x, tokenizer,args.max_seq_length)
        print("len(test_loader)", len(test_loader))
        #model

        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],  output_device=args.local_rank, find_unused_parameters=True, broadcast_buffers=True,)
        

        
        #
        no_decay = [ "bias" , "LayerNorm.weight"] # "bias", "LayerNorm.weight"
        para1 = []
        para2 = []
        para3 = []
        para4 = []
        para5 = []
        para6 = []       
        para7 = []
        para8 = []   
        para13 = []
        para14 = []  
        for n, p in model.module.named_parameters():
#             print("name",n)
            if   "tag" in n:
                
                if any(nd in n for nd in no_decay): #decay
                    para1.append(p)
                    
                else:
                    para2.append(p)
            elif "topic_classify_layer" in n:
                if any(nd in n for nd in no_decay): #decay
                    para7.append(p)
                    
                else:
                    para8.append(p)
            elif  "topic_layer" in n : #attention 
                if any(nd in n for nd in no_decay): #decay
                    para5.append(p)
                    
                else:
                    para6.append(p)

            elif  "alpha" in n or "beta" in n : #attention 
                if any(nd in n for nd in no_decay): #decay
                    para13.append(p)
                    
                else:
                    para14.append(p)
            
            else:
                if any(nd in n for nd in no_decay):
                    para3.append(p)
                    
                else:
                    para4.append(p)
            
        if para1 == [] or para2 == [] or para3 == [] or para4 == []or para5 == []or para6 == []or para7 == []or para8 == []or para13 == []or para14 == []:
            print("")
            
        
        optimizer_grouped_parameters = [
            {
                "params": para4,
                "weight_decay": args.weight_decay,
                "lr":args.learning_rate,
            },

            {
                "params": para3,
                "weight_decay": 0.0,
                "lr":args.learning_rate,
            },

            {
                "params": para2,
                "weight_decay": args.weight_decay,
                "lr":args.tagging_learning_rate, #args.learning_rate, #1e-3,
            },

            {
                "params": para1,
                "weight_decay": 0.0,
                "lr":args.tagging_learning_rate, #args.learning_rate, #1e-3,
            },

            {
                "params": para6,
                "weight_decay": args.weight_decay,
                "lr":args.attention_learning_rate, #args.learning_rate, #1e-3,
            },

            {
                "params": para5,
                "weight_decay": 0.0,
                "lr":args.attention_learning_rate, #args.learning_rate, #1e-3,
            },

            {
                "params": para7,
                "weight_decay": args.weight_decay,
                "lr":args.topic_learning_rate, #args.learning_rate, #1e-3,
            },

            {
                "params": para8,
                "weight_decay": 0.0,
                "lr":args.topic_learning_rate, #args.learning_rate, #1e-3,
            },         

            {
                "params": para13,
                "weight_decay": args.weight_decay,
                "lr":args.alpha_learning_rate, #args.learning_rate, #1e-3,
            },

            {
                "params": para14,
                "weight_decay": 0.0,
                "lr":args.alpha_learning_rate, #args.learning_rate, #1e-3,
            },           


        ]
        # print("optimizer_grouped_parameters",optimizer_grouped_parameters)
        optimizer = AdamW(optimizer_grouped_parameters, eps=args.adam_epsilon,betas=(0.9, 0.98) ) # eps=args.adam_epsilon,betas=(0.9, 0.98)

        
        # optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate )

        t_total = (
            (len(train_loader.dataset) // (args.train_batch_size * args.number_gpu  )) #* max(1, eval(args.n_gpu))
            // args.gradient_accumulation_steps
            * float(args.num_train_epochs)
        )
        print("t_total",t_total)        
        lr_decrease = 1 / t_total  * args.learning_rate
        scheduler = get_linear_schedule_with_warmup(
            optimizer,  num_warmup_steps=args.warmup_ratio*t_total, num_training_steps=t_total
        )
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.num_train_epochs, gamma=0.7)  
        scaler = GradScaler()

        #
        all_step = 0
        max_f1_result = -100
        for epoch in range(args.num_train_epochs):
            model.train()

            #

            # if train_sampler is not None:
            #     train_sampler.set_epoch(epoch)
            print("fuck")
            # if True:
            #     for _, batch in enumerate(train_loader):
            with tqdm(train_loader, unit="batch") as tepoch: # 🌟 1. 
                # print("tepoch",len(tepoch) )
                for step, batch in enumerate( tepoch):              # 🌟 2. 
                    tepoch.set_description(f"Epoch {epoch}") # 🌟 3. 
                # for batch in tqdm(train_loader):
                    
                    # 
                    input_ids = None
                    # attention_mask = None

                    # target_mask = batch["target_mask"].cuda(non_blocking=True)
                    target_ids = None


                    # print("attention_mask",attention_mask[0])

                    path_attention_mask = None

                    sentence_mask = None
                    all_topic_matrix_label = None
                    tagging_label = None
                    # all_sentence_topic_mask = batch["all_sentence_topic_mask"].cuda(non_blocking=True)

                    all_sentence_cross_attention_mask = None
                    all_topic_cross_attention_mask = None
                    device = torch.device(f"cuda:{local_rank}")  # GPU
                    torch.cuda.empty_cache()
                    
                    if local_rank == 0:
                        all_step += 1
                    if batch is None:
                        continue


                    # print("input_ids",batch["input_ids"].shape)
                    # print("input_ids.shape[0]",batch["input_ids"].shape[0])
                    # max_num = batch["input_ids"].shape[0]
                    # max_seq = 8*420
                    # if batch["input_ids"].shape[0] * batch["input_ids"].shape[1] >= max_seq:
                    #     while max_num  * batch["input_ids"].shape[1] >=  max_seq:
                    #         max_num -= 1


                    input_ids = batch["input_ids"].cuda(non_blocking=True)
                    # print("input_ids",input_ids.shape)
                    attention_mask = batch["attention_mask"].cuda(non_blocking=True)

                    # target_mask = batch["target_mask"].cuda(non_blocking=True)
                    target_ids = batch["target_ids"].cuda(non_blocking=True)
                    target_ids[target_ids[:, :] == tokenizer.pad_token_id] = -100

                    # print("attention_mask",attention_mask[0])

                    path_attention_mask = batch["path_attention_mask"].cuda(non_blocking=True)

                    sentence_mask = batch["sentence_mask"].cuda(non_blocking=True)
                    all_topic_matrix_label = batch["all_topic_matrix_label"].cuda(non_blocking=True)
                    tagging_label = batch["tagging_label"].cuda(non_blocking=True)
                    # all_sentence_topic_mask = batch["all_sentence_topic_mask"].cuda(non_blocking=True)

                    all_sentence_cross_attention_mask = batch["all_sentence_cross_attention_mask"].cuda(non_blocking=True)
                    all_topic_cross_attention_mask = batch[ "all_topic_cross_attention_mask"].cuda(non_blocking=True)
                    triplets_num = batch["triplets_num"].cuda(non_blocking=True)
                    all_ancient_list = None #batch["ancient_list"]
                    # ce_weight = batch["ce_weight"].cuda(non_blocking=True)
                    # reply_attention_mask = batch["reply_attention_mask"].cuda(non_blocking=True)
                    # speaker_attention_mask = batch["speaker_attention_mask"].cuda(non_blocking=True)
    #                 return_loss = batch["return_loss"].cuda(non_blocking=True)
                    # print("attention_mask",attention_mask.shape)
                    # print("reply_attention_mask",reply_attention_mask.shape)


                    
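                    # Forward/backward pass: under autocast with the gradient scaler when
                    # args.use_amp is set, plain full-precision otherwise.
                    # NOTE(review): the AMP branch never calls scheduler.step() and does not pass
                    # the cross-attention masks, loss ratios, is_train or triplets_num that the
                    # non-AMP branch passes -- confirm whether that is intended.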
                    if args.use_amp:
                        print("amp")
                        optimizer.zero_grad()
                        with autocast():
                            loss = model(input_ids = input_ids, attention_mask = attention_mask, labels = target_ids, 
                                        path_attention_mask = path_attention_mask,
                                        tokenizer= tokenizer,
                                        sentence_mask=sentence_mask,
                                        all_topic_matrix_label = all_topic_matrix_label,
                                        tagging_label = tagging_label,
                                        all_ancient_list=  all_ancient_list,
                                        # reply_attention_mask=reply_attention_mask,
                                        # speaker_attention_mask=speaker_attention_mask, 
                                        ).loss
                        print("loss",loss.item())
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        # print("amp")



                        # max_num = 9
                        # print(f"{local_rank}")
                        # print(f"{local_rank}batch",batch)
                        loss = model(input_ids = input_ids, attention_mask = attention_mask, labels = target_ids,  #
                                        path_attention_mask = path_attention_mask, #
                                        tokenizer= tokenizer,
                                        sentence_mask=sentence_mask,
                                        all_topic_matrix_label = all_topic_matrix_label,
                                        tagging_label = tagging_label,
                                        all_ancient_list=  all_ancient_list,
                                        # all_sentence_topic_mask = all_sentence_topic_mask,
                                        all_sentence_cross_attention_mask = all_sentence_cross_attention_mask,
                                        all_topic_cross_attention_mask = all_topic_cross_attention_mask,
                                        tag_ratio = args.tag_ratio,
                                        generation_ratio = args.generation_ratio,
                                        topic_ratio = args.topic_ratio,
                                        is_train = True,
                                        triplets_num = triplets_num,
                                        # ce_weight = ce_weight,
                                        # reply_attention_mask=reply_attention_mask,
                                        # speaker_attention_mask=speaker_attention_mask, 
                                        ).loss
                        
                        # torch.distributed.barrier()
                        # torch.distributed.barrier()
                        # print("loss",loss.item())


                        # loss = loss / args.gradient_accumulation_steps  # 
                        # loss.backward()

                        # if (step + 1) % args.gradient_accumulation_steps == 0:
                        #     # update the weights only once every gradient_accumulation_steps batches
                        #     optimizer.step()
                        #     optimizer.zero_grad()
                        #     scheduler.step()

                        optimizer.zero_grad(set_to_none=True) #
                        loss.backward()
                        
        # #                 for name, param in model.named_parameters():
        # #                     if param.grad is None:
        # #                         print(name)

        #                 #
        #                 # optimizer.param_groups[0]['lr'] -= lr_decrease
        #                 # optimizer.param_groups[1]['lr'] -= lr_decrease
                        
                        optimizer.step()
                        # print("lr",optimizer.param_groups[0]['lr'])
                        scheduler.step()
                    tepoch.set_postfix(loss=loss.item())
                    # sleep(0.01)
                    # print(f"{local_rank}")

                    #
                    # input_ids = None
                    # # attention_mask = None

                    # # target_mask = batch["target_mask"].cuda(non_blocking=True)
                    # # target_ids = None


                    # # print("attention_mask",attention_mask[0])

                    # path_attention_mask = None

                    # sentence_mask = None
                    # all_topic_matrix_label = None
                    # tagging_label = None
                    # # all_sentence_topic_mask = batch["all_sentence_topic_mask"].cuda(non_blocking=True)

                    # all_sentence_cross_attention_mask = None
                    # all_topic_cross_attention_mask = None
                    # device = torch.device(f"cuda:{local_rank}")  # GPU
                    # torch.cuda.empty_cache()



            # torch.distributed.barrier()
            if local_rank == 0:

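                # Drop references to the last training batch's tensors and release cached
                # GPU memory before running evaluation on rank 0.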
                input_ids = None
                # attention_mask = None

                # target_mask = batch["target_mask"].cuda(non_blocking=True)
                target_ids = None


                # print("attention_mask",attention_mask[0])

                path_attention_mask = None

                sentence_mask = None
                all_topic_matrix_label = None
                tagging_label = None
                # all_sentence_topic_mask = batch["all_sentence_topic_mask"].cuda(non_blocking=True)

                all_sentence_cross_attention_mask = None
                all_topic_cross_attention_mask = None
                device = torch.device(f"cuda:{local_rank}")  # GPU
                torch.cuda.empty_cache()
                model.eval()
                logger.info("*** start  ***")


                with torch.no_grad():
                    f1_result = evaluate(test_loader, model, tokenizer, local_rank)
                    # if f1_result > max_f1_result:
                    #     max_f1_result = f1_result
                    #     print(f"{epoch}epoch")
                    #     torch.save(model.module.state_dict(), args.model_save_path)

                    # print("test")
                    # f1_result = evaluate(test_loader, model, tokenizer)
            # torch.distributed.barrier()
                    
                for n, p in model.module.named_parameters():
        #             print("name",n)
                    if  "alpha" in n : #attention 
                        print("alpha",p)

                    if  "beta" in n  : #attention 
                        print("beta",p)

        print("all_step",all_step)

        print("Finish training and saving the model!")
        print(f",{local_rank}")


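#     # Disabled: on rank 0, reload the checkpoint from args.model_save_path and evaluate it on test_loader.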
#     if local_rank == 0:

#         logger.info("*** start test ***")
#         # test_dataset = get_dataset(tokenizer = tokenizer, type_path="test", args=args, model= model,)
#         # test_loader = DataLoader(test_dataset, batch_size=8, num_workers=4, worker_init_fn=_init_fn)
    
#         saved_model_path = args.model_save_path  # 
#         model.module.load_state_dict(torch.load(saved_model_path))
#         evaluate(test_loader, model, tokenizer, local_rank)
# #         metrics = trainer.evaluate(test_dataset)
# #         trainer.log_metrics("test", metrics)
# #         trainer.save_metrics("test", metrics)
#     device = torch.device(f"cuda:{local_rank}")  # GPU
#     torch.cuda.empty_cache()
        


# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
    
