import argparse
# hyperparameter
parser = argparse.ArgumentParser()
parser.add_argument("--run_name", type = str)
parser.add_argument("--model_name", type = str)
parser.add_argument("--dataset_name", type = str)
parser.add_argument("--N", type = int)
parser.add_argument("--C", type = int)
parser.add_argument("--R", type = int)
parser.add_argument("--prefix_len", type = int)
parser.add_argument("--aux_layer_num", type = int)
parser.add_argument("--server_lr", type = float)
parser.add_argument("--client_lr", type = float)
parser.add_argument("--init_lr", type = float)
parser.add_argument("--use_init_train", type = int)
parser.add_argument("--use_client_train", type = int)
parser.add_argument("--use_server_train", type = int)



parser.add_argument("--seed", type = int)
parser.add_argument("--device", type = int)
args = parser.parse_args()

run_name = str(args.run_name)
model_name = str(args.model_name)
dataset_name = str(args.dataset_name)
N = int(args.N)
C = int(args.C)
R = int(args.R)
prefix_len = int(args.prefix_len)
aux_layer_num = int(args.aux_layer_num)

init_lr = float(args.init_lr)
server_lr = float(args.server_lr)
client_lr = float(args.client_lr)
use_init_train = int(args.use_init_train)
use_server_train = int(args.use_server_train)
use_client_train = int(args.use_client_train)


seed = int(args.seed)
device = int(args.device)


import os
os.environ["CUDA_VISIBLE_DEVICES"] = str(device)

import torch
torch.manual_seed(seed)

import copy
from datetime import datetime
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GenerationConfig
import torch.nn as nn
import shutil


from data_utils import NLG_Dataset_Manager
from other_utils import set_log, client_choose, generate_prepare
from FedSP_trainer_utils import InitTrainer, ClientTrainer, ServerTrainer
from model_utils import FedSPModel



torch.manual_seed(seed)
if model_name == "gpt2_medium":
    tokenizer_path = "/home/zhujh/20240417_FedQLoRA/model/gpt2_lm_medium/tokenizer"


batch_size = 4
init_gradient_accumulation_steps = 2
client_gradient_accumulation_steps = 4
server_gradient_accumulation_steps = 4


if dataset_name == "e2e":
    dataset_path = "/home/zhujh/20240417_FedQLoRA/data/e2e_NLG"
    max_length = 128
    weight_decay = 0.01
    label_smoothing = 0.1

    server_max_train_step = -1
    server_max_eval_step = -1
    server_max_test_step = -1
elif dataset_name == "dart":
    dataset_path = "/home/zhujh/20240417_FedQLoRA/data/dart_NLG"
    max_length = 128
    weight_decay = 0.01
    label_smoothing = 0.1

    server_max_train_step = -1
    server_max_eval_step = -1
    server_max_test_step = -1


elif dataset_name == "dialogsum":
    dataset_path = "/home/zhujh/20240417_FedQLoRA/data/dialogsum_NLG"
    max_length = 400
    weight_decay = 0.01
    label_smoothing = 0.1

    server_max_train_step = -1
    server_max_eval_step = -1
    server_max_test_step = -1

    batch_size = 2
    init_gradient_accumulation_steps = 4
    client_gradient_accumulation_steps = 8
    server_gradient_accumulation_steps = 8


elif dataset_name == "viggo":
    dataset_path = "/home/zhujh/20240417_FedQLoRA/data/viggo_NLG"
    max_length = 128
    weight_decay = 0.01
    label_smoothing = 0.1

    server_max_train_step = -1
    server_max_eval_step = -1
    server_max_test_step = -1

if use_server_train == 1:
    init_epochs = 1
else:
    init_epochs = 0


init_max_train_step = -1


if use_client_train == 1:
    client_model_epochs = 1
    client_prefix_epochs = 1
else:
    client_model_epochs = 0
    client_prefix_epochs = 0  
client_model_lr = client_lr
client_model_max_train_step = -1
client_prefix_lr = client_lr
client_prefix_max_train_step = -1


if use_server_train == 1:
    server_epochs = 1
else:
    server_epochs = 0







time = datetime.now().strftime("%Y-%m-%d %H:%M:%S").replace(":", "").replace(" ", "").replace("-", "")
parent_dir = os.path.join("/home", "zhujh", "20240516_FedSP", "[meta]")
meta_dir =  os.path.join(parent_dir, run_name, f"seed_{str(seed)}", f"time_{time}")
save_dir =  meta_dir.replace("[meta]", "save_models")
generation_dir = meta_dir.replace("[meta]", "generation")
log_dir = meta_dir.replace("[meta]", "logs")
logger = set_log(log_dir)
logger.info(f"Args:\n \
              \tmeta:\n \
              \t\trun_name: {run_name}\n\
              \t\ttime: {time}\n\
              \t\tmodel_name: {model_name}\n\
              \t\tdataset_name: {dataset_name}\n\
              \t\tseed: {str(seed)}\n\
              \tdataset:\n\
              \t\tN: {str(N)}\n\
              \t\tC: {str(C)}\n\
              \t\tmax_length: {str(max_length)}\n\
              \tserver:\n\
              \t\tprefix_len: {str(prefix_len)}\n\
              \t\taux_layer_num: {str(aux_layer_num)}\n\
              \t\tbatch_size: {str(batch_size)}\n\
              \t\tweight_decay: {str(weight_decay)}\n\
              \t\tlabel_smoothing: {str(label_smoothing)}\n\
              \t\tR: {str(R)}\n\
              \tinit_train:\n\
              \t\tinit_gradient_accumulation_steps: {str(init_gradient_accumulation_steps)}\n\
              \t\tinit_epochs: {str(init_epochs)}\n\
              \t\tinit_lr: {str(init_lr)}\n\
              \t\tinit_max_train_step:{str(init_max_train_step)}\n\
              \tclient_train:\n\
              \t\tclient_gradient_accumulation_steps:{str(client_gradient_accumulation_steps)}\n\
              \t\tclient_model_epochs: {str(client_model_epochs)}\n\
              \t\tclient_model_lr: {str(client_model_lr)}\n\
              \t\tclient_model_max_train_step:{str(client_model_max_train_step)}\n\
              \t\tclient_prefix_epochs: {str(client_prefix_epochs)}\n\
              \t\tclient_prefix_lr: {str(client_prefix_lr)}\n\
              \t\tclient_prefix_max_train_step:{str(client_prefix_max_train_step)}\n\
              \tserver_train:\n\
              \t\tuse_server_train:{use_server_train}\n\
              \t\tserver_gradient_accumulation_steps:{str(server_gradient_accumulation_steps)}\n\
              \t\tserver_epochs: {str(server_epochs)}\n\
              \t\tserver_lr: {str(server_lr)}\n\
              \t\tserver_max_train_step:{str(server_max_train_step)}\n\
              \t\tserver_max_eval_step:{str(server_max_eval_step)}\n\
              \t\tserver_max_test_step:{str(server_max_test_step)}"
            )

tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)

dataset_manager = NLG_Dataset_Manager(dataset_path, max_length, batch_size, seed, N)
init_trainer = InitTrainer(save_dir, dataset_manager._get_train_loader(0), model_name, 
    weight_decay,
    init_gradient_accumulation_steps, init_epochs, init_lr, init_max_train_step, aux_layer_num,
    logger)
init_trainer._train()

choosen_client_ids, last_r_of_clients = client_choose(R, N, C)


server_loss = [1000]
server_loss_before = [1000]
for r in range(1, R+1):
    os.makedirs(os.path.join(save_dir, str(r)), exist_ok=True)
    for client_id in choosen_client_ids[r]:
        client_trainer = ClientTrainer(
            save_dir, dataset_manager._get_train_loader(client_id), model_name,
            prefix_len, aux_layer_num,
            r, last_r_of_clients[r][client_id], client_id,
            weight_decay, label_smoothing, client_gradient_accumulation_steps,
            client_model_epochs, client_model_lr, client_model_max_train_step,
            client_prefix_epochs, client_prefix_lr, client_prefix_max_train_step,
            logger
        )
        client_trainer._switch_mode("model")
        client_trainer._train("model")
        client_trainer._switch_mode("prompt_encoder")
        client_trainer._train("prompt_encoder")
    server_trainer = ServerTrainer(
        save_dir, dataset_manager._get_train_loader(0), model_name, 
        prefix_len, 
        r, choosen_client_ids[r], 
        weight_decay, label_smoothing, server_gradient_accumulation_steps, 
        server_epochs, server_lr, server_max_train_step, server_max_eval_step, server_max_test_step
    )
    server_loss_before.append(server_trainer._evaluate(dataset_manager._get_valid_loader()))
    logger.info(f"{run_name} server_loss_before round_{str(r)}: {str(server_loss_before)}")
    server_trainer._train(logger)
    server_loss.append(server_trainer._evaluate(dataset_manager._get_valid_loader()))
    logger.info(f"{run_name} server_loss round_{str(r)}: {str(server_loss)}")
    server_trainer._save()

test_data_list, context_pred_refs_dict, context_list = generate_prepare(dataset_manager._get_test_loader(), dataset_path, generation_dir)
best_server_round = server_loss.index(min(server_loss))
logger.info(f"{run_name} best server: {str(best_server_round)}")
server_trainer = ServerTrainer(
    save_dir, dataset_manager._get_train_loader(0), model_name, 
    prefix_len, 
    best_server_round, choosen_client_ids[best_server_round], 
    weight_decay, label_smoothing, server_gradient_accumulation_steps, 
    server_epochs, server_lr, server_max_train_step, server_max_eval_step, server_max_test_step
)
logger.info(f"{run_name} loss of best server: {str(server_trainer._evaluate(dataset_manager._get_valid_loader()))}")
server_trainer._generate(test_data_list, copy.deepcopy(context_pred_refs_dict), context_list, generation_dir, tokenizer)



    
# sudo rm -r /home/zhujh/20240507_FedSP/save_models/t0513_fedsp_test

"""
print(f"r_{r}, client_{client_id}| before model train| model:{client_trainer._model.model.transformer.h[0].attn.c_attn.weight.tolist()[0][0]}")
        print(f"r_{r}, client_{client_id}| before model train| prefix:{client_trainer._model.prefix_encoder.weight.tolist()[0][0]}")


"""