import argparse

# Hyper-parameters, all supplied on the command line. No option declares a
# default, so any flag that is omitted parses to None.
parser = argparse.ArgumentParser()
for _flag, _flag_type in (
    ("--run_name", str),
    ("--model_name", str),
    ("--dataset_name", str),
    ("--N", int),
    ("--C", int),
    ("--update_method", int),
    ("--adapter_lr", float),
    ("--R", int),
    ("--local_epochs", int),
    ("--lr", float),
    ("--alg", str),
    ("--sigma", float),
    ("--sigma_method", int),
    ("--sigma_theta", float),
    ("--seed", int),
    ("--device", int),
):
    parser.add_argument(_flag, type=_flag_type)
args = parser.parse_args()

# Bind every option to a module-level name. Because of the explicit casts, a
# missing numeric option raises TypeError right here, while a missing string
# option silently becomes the text "None".
run_name = str(args.run_name)
model_name = str(args.model_name)
dataset_name = str(args.dataset_name)
N = int(args.N)
C = int(args.C)
update_method = int(args.update_method)
adapter_lr = float(args.adapter_lr)
R = int(args.R)
local_epochs = int(args.local_epochs)
lr = float(args.lr)
alg = str(args.alg)
sigma = float(args.sigma)
sigma_method = int(args.sigma_method)
sigma_theta = float(args.sigma_theta)
seed = int(args.seed)
device = int(args.device)


import os
# Expose only the GPU index given by --device to this process. This is set
# before `import torch` below, presumably so that CUDA initialisation only
# ever sees the selected device -- keep this ordering.
os.environ["CUDA_VISIBLE_DEVICES"] = str(device)

import torch
# Seed torch's global RNG from --seed for reproducible runs.
torch.manual_seed(seed)

import copy
from datetime import datetime
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GenerationConfig
from tqdm import tqdm
import torch.nn as nn
from peft import (
    LoraConfig,
    get_peft_model,
)

import utils
from data_utils import NLG_Dataset_Manager
from adapter_utils import AdapterManager
from trainer_utils import NLG_Trainer
from noise_utils import add_gauss_noise_to_model





# Re-seed torch with the same --seed used above; kept so the RNG state at this
# point is identical to previous runs.
torch.manual_seed(seed)

# Per-model checkpoint/tokenizer locations and default batch settings.
# The dataset section below may still override batch_size and
# gradient_accumulation_steps (dialogsum does).
if model_name == "gpt2_medium":
    model_path = "/home/zhujh/20240417_FedQLoRA/model/gpt2_lm_medium"
    tokenizer_path = "/home/zhujh/20240417_FedQLoRA/model/gpt2_lm_medium/tokenizer"
    batch_size = 10
    gradient_accumulation_steps = 1
elif model_name == "gpt2_xl":
    model_path = "/home/zhujh/20240417_FedQLoRA/model/gpt2_lm_xl"
    tokenizer_path = "/home/zhujh/20240417_FedQLoRA/model/gpt2_lm_xl/tokenizer"
    # Smaller per-step batch for the XL model, compensated by accumulating
    # gradients over two steps.
    batch_size = 4
    gradient_accumulation_steps = 2
elif model_name == "gpt2":
    model_path = "/home/zhujh/20240417_FedQLoRA/model/gpt2_lm"
    tokenizer_path = "/home/zhujh/20240417_FedQLoRA/model/gpt2_lm/tokenizer"
    batch_size = 10
    gradient_accumulation_steps = 1
else:
    # Fail fast: otherwise model_path, batch_size, ... stay undefined and the
    # script dies much later with an unrelated NameError.
    raise ValueError(
        f"Unknown model_name: {model_name!r} "
        "(expected 'gpt2', 'gpt2_medium' or 'gpt2_xl')"
    )


# Per-dataset settings: (data directory, max sequence length in tokens).
_DATASET_SETTINGS = {
    "e2e": ("/home/zhujh/20240417_FedQLoRA/data/e2e_NLG", 128),
    "dart": ("/home/zhujh/20240417_FedQLoRA/data/dart_NLG", 128),
    "dialogsum": ("/home/zhujh/20240417_FedQLoRA/data/dialogsum_NLG", 400),
    "viggo": ("/home/zhujh/20240417_FedQLoRA/data/viggo_NLG", 128),
}
if dataset_name not in _DATASET_SETTINGS:
    # Fail fast instead of hitting a NameError on max_length further down.
    raise ValueError(
        f"Unknown dataset_name: {dataset_name!r} "
        f"(expected one of {sorted(_DATASET_SETTINGS)})"
    )
dataset_path, max_length = _DATASET_SETTINGS[dataset_name]

# Optimisation settings shared by every dataset.
weight_decay = 0.01
label_smoothing = 0.1

# Step caps for the train / eval / test phases, handed to NLG_Trainer.
# NOTE(review): -1 presumably means "no cap" -- confirm in trainer_utils.
max_train_step = -1
max_eval_step = -1
max_test_step = -1

if dataset_name == "dialogsum":
    # Overrides the model section's batch settings regardless of model,
    # presumably to fit the longer 400-token sequences in memory. Must run
    # after the model section.
    batch_size = 5
    gradient_accumulation_steps = 2


# LoRA adapter hyper-parameters (fed into peft.LoraConfig).
lora_rank, lora_alpha = 1, 32
target_modules = ["c_attn", "c_proj", "c_fc"]
lora_dropout, lora_bias = 0.05, "none"

# Decoding settings for the final test-set generation: deterministic beam
# search (no sampling).
num_beams, do_sample = 10, False
no_repeat_ngram_size, length_penalty = 4, 0.9
generation_length = 64



# Run timestamp as a compact YYYYMMDDHHMMSS string (no separators).
time = datetime.now().strftime("%Y%m%d%H%M%S")

# All outputs share one layout and differ only in the top-level folder:
#   /home/zhujh/20240417_FedQLoRA/<kind>/<run_name>/<alg>/seed_<seed>/time_<time>
# Only the first "[meta]" (the placeholder) is substituted, so a run_name or
# alg that happens to contain "[meta]" is left intact.
meta_dir = f"/home/zhujh/20240417_FedQLoRA/[meta]/{run_name}/{alg}/seed_{str(seed)}/time_{time}"
save_dir = meta_dir.replace("[meta]", "save_models", 1)
generation_dir = meta_dir.replace("[meta]", "generation", 1)
log_dir = meta_dir.replace("[meta]", "logs", 1)
logger = utils.set_log(log_dir)

# Log the full configuration as an indented "section: key: value" tree.
# The message is assembled from parts so that no source-code indentation
# leaks into the log record (the previous backslash-continued f-string
# embedded every continuation line's leading spaces).
_arg_sections = (
    ("meta", (
        ("run_name", run_name),
        ("time", time),
        ("model_name", model_name),
        ("dataset_name", dataset_name),
        ("alg", alg),
        ("seed", seed),
    )),
    ("dataset", (
        ("N", N),
        ("C", C),
        ("max_length", max_length),
    )),
    ("server", (
        ("update_method", update_method),
        ("adapter_lr", adapter_lr),
        ("sigma", sigma),
        ("sigma_method", sigma_method),
        ("sigma_theta", sigma_theta),
    )),
    ("train", (
        ("R", R),
        ("local_epochs", local_epochs),
        ("lr", lr),
        ("batch_size", batch_size),
        ("gradient_accumulation_steps", gradient_accumulation_steps),
        ("label_smoothing", label_smoothing),
        ("weight_decay", weight_decay),
    )),
    ("early_stop", (
        ("max_train_step", max_train_step),
        ("max_eval_step", max_eval_step),
        ("max_test_step", max_test_step),
    )),
    ("lora", (
        ("lora_rank", lora_rank),
        ("lora_alpha", lora_alpha),
        ("target_modules", target_modules),
        ("lora_dropout", lora_dropout),
        ("lora_bias", lora_bias),
    )),
    ("generation", (
        ("num_beams", num_beams),
        ("do_sample", do_sample),
        ("no_repeat_ngram_size", no_repeat_ngram_size),
        ("length_penalty", length_penalty),
        ("generation_length", generation_length),
    )),
)
_arg_lines = ["Args:"]
for _section, _entries in _arg_sections:
    _arg_lines.append(f"\t{_section}:")
    _arg_lines.extend(f"\t\t{_key}: {_value!s}" for _key, _value in _entries)
logger.info("\n".join(_arg_lines))


# Load the pretrained GPT-2 LM and pass it through
# noise_utils.add_gauss_noise_to_model with the --sigma* settings (presumably
# perturbing the base weights with Gaussian noise -- see noise_utils).
# Keep statement order: the noise step and the LoRA initialisation inside
# get_peft_model presumably draw from torch's seeded RNG.
base_model = GPT2LMHeadModel.from_pretrained(model_path, device_map = "auto")
base_model = add_gauss_noise_to_model(base_model, model_name, sigma, sigma_method, sigma_theta)
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
# Pad with token id 0. NOTE(review): GPT-2's tokenizer presumably has no pad
# token by default; padded label positions are expected to be masked via the
# loss's ignore_index (-1) -- confirm in data_utils.
tokenizer.pad_token_id = 0
# Beam-search decoding settings used for the final generation step.
generation_config = GenerationConfig(
    num_beams = num_beams,
    do_sample = do_sample,
    no_repeat_ngram_size = no_repeat_ngram_size,
    length_penalty = length_penalty,
)
# LoRA adapters on the attention (c_attn), projection (c_proj) and MLP (c_fc)
# modules listed in target_modules.
lora_config = LoraConfig(
    r = lora_rank,
    lora_alpha = lora_alpha,
    target_modules = target_modules,
    lora_dropout = lora_dropout,
    bias = lora_bias,
)
lora_model = get_peft_model(base_model, lora_config)
# Report how many parameters remain trainable after LoRA wrapping.
lora_model.print_trainable_parameters()

# Data partitioned across N clients, plus the helper that extracts,
# aggregates and redistributes LoRA adapters for the chosen algorithm.
dataset_manager = NLG_Dataset_Manager(dataset_path, max_length, batch_size, seed, N)
adapter_manager = AdapterManager(alg, adapter_lr, update_method)

# Prepare test-set references for the final generation step, and draw up
# front the clients that participate in each of the R rounds (C per round,
# presumably -- see utils.client_choose).
test_data_list, context_pred_refs_dict, context_list = utils.generate_prepare(dataset_manager._get_test_loader(), dataset_path, generation_dir)
choosen_client_ids = utils.client_choose(R, N, C)

# Per-element loss (no reduction); positions labelled -1 are ignored.
# reduction="none" is the supported spelling of the deprecated reduce=False
# and yields the same unreduced loss. Reduction is presumably done inside
# NLG_Trainer.
loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduction="none", label_smoothing=label_smoothing)
trainer = NLG_Trainer(
    lora_model, tokenizer, adapter_manager,
    lr, local_epochs, gradient_accumulation_steps,
    weight_decay, loss_fct,
    max_train_step, max_eval_step, max_test_step,
    generation_config, generation_length, generation_dir
)

# Validation-loss history. Index 0 is the loss before any training; index
# k >= 1 is the loss after round k, matching the checkpoint dir save_dir/<k>.
server_loss = []
distributed_loss = []


def _save_adapter(adapter, round_tag, filename):
    """Write ``adapter`` with ``torch.save`` to ``save_dir/<round_tag>/<filename>``.

    The round directory is created if it does not exist yet.
    """
    round_dir = os.path.join(save_dir, str(round_tag))
    os.makedirs(round_dir, exist_ok=True)
    torch.save(adapter, os.path.join(round_dir, filename))


# Round 0: evaluate AND checkpoint the untrained adapters. Without the
# checkpoint, best-round selection crashes on torch.load whenever the
# untrained adapter has the lowest validation loss (always the case if R == 0).
server_adapter = adapter_manager._get_adapter_from_model(lora_model)
distributed_adapter = adapter_manager._get_distributed_adapter(server_adapter)
server_loss.append(trainer._evaluate(server_adapter, dataset_manager._get_valid_loader()))
_save_adapter(server_adapter, 0, "server_adapter.pth")
if "FedQLoRA" in alg:
    distributed_loss.append(trainer._evaluate(distributed_adapter, dataset_manager._get_valid_loader()))
    _save_adapter(distributed_adapter, 0, "distributed_adapter.pth")


# `round_idx` rather than `round` so the builtin round() is not shadowed.
for round_idx in range(R):
    # Each selected client trains locally, starting from the adapter the
    # server distributed; the results are aggregated into the new server
    # adapter, which is then redistributed for the next round.
    client_adapters = [
        trainer._train(distributed_adapter, dataset_manager._get_train_loader(client_id))
        for client_id in choosen_client_ids[round_idx]
    ]
    contributed_adapter = adapter_manager._get_contributed_adapter(client_adapters)
    server_adapter = adapter_manager._get_server_adapter(server_adapter, distributed_adapter, contributed_adapter, round_idx, R)
    distributed_adapter = adapter_manager._get_distributed_adapter(server_adapter)

    server_loss.append(trainer._evaluate(server_adapter, dataset_manager._get_valid_loader()))
    logger.info(f"{run_name} server_loss round_{str(round_idx)}: {str(server_loss)}")
    _save_adapter(server_adapter, round_idx + 1, "server_adapter.pth")

    if "FedQLoRA" in alg:
        distributed_loss.append(trainer._evaluate(distributed_adapter, dataset_manager._get_valid_loader()))
        logger.info(f"{run_name} distributed_loss round_{str(round_idx)}: {str(distributed_loss)}")
        _save_adapter(distributed_adapter, round_idx + 1, "distributed_adapter.pth")





# Pick the round with the lowest validation loss and run test-set generation
# with that round's checkpoint.
def _load_best_and_generate(losses, kind, log_prefix):
    """Select, reload and decode with the best checkpoint of one adapter kind.

    ``losses[i]`` is the validation loss recorded for checkpoint
    ``save_dir/<i>/<kind>_adapter.pth``. The first index holding the minimum
    wins. The reloaded adapter is re-evaluated on the validation set, and then
    used for generation under the tag ``best_<kind>``.

    NOTE(review): ``losses[0]`` is the pre-training evaluation; selecting it
    requires a ``save_dir/0`` checkpoint to exist.

    Returns ``(best_round, best_adapter)``.
    """
    best_round = losses.index(min(losses))
    logger.info(f"{log_prefix}best {kind}: {str(best_round)}")
    checkpoint = os.path.join(save_dir, str(best_round), f"{kind}_adapter.pth")
    best_adapter = torch.load(checkpoint)
    reloaded_loss = trainer._evaluate(best_adapter, dataset_manager._get_valid_loader())
    logger.info(f"{log_prefix}loss of best {kind}: {str(reloaded_loss)}")
    # Generation may mutate the references dict, so each call gets a fresh copy.
    refs = copy.deepcopy(context_pred_refs_dict)
    trainer._generate(best_adapter, test_data_list, refs, context_list, f"best_{kind}")
    return best_round, best_adapter


best_server_round, best_server_adapter = _load_best_and_generate(server_loss, "server", f"{run_name} ")

if "FedQLoRA" in alg:
    best_distributed_round, best_distributed_adapter = _load_best_and_generate(distributed_loss, "distributed", "")


