from http import client
import torch
import numpy as np
import random
from tqdm import tqdm
from flearn.utils.model_utils import train_simple, evaluate, average_weights, train, train_others
from flearn.utils.process_data import PromptDataset, partition, partition_for_score
from data.process import tasks_num_labels
from transformers import RobertaConfig, RobertaTokenizer, AutoConfig
from models.modeling_roberta_lora import RobertaForMaskedLM
from ..utils.fl_score_fuctions import *
import copy
import logging
import os
import time
import math
import pickle
from flearn.utils.model_utils import peace_func



# Maps each supported task name to the key under which its headline metric
# is looked up in the evaluation results.
evaluate_metric = {
    "rte": "acc",
    "sst-2": "acc",
    "cola": "mcc",
    "mrpc": "acc_and_f1",
    "mpqa": "acc",
    "qnli": "acc",
    "subj": "acc",
    "trec": "acc",
    "wnli": "acc",
    "boolq": "acc",
    "mr": "acc",
}


class CentralTraining(object):
    """
    Performs centralized training on the aggregated model; share_percent is the size of the shared dataset.
    """

    def __init__(self, args, share_percent=0, iid=True, unequal=False, result_dir="central"):

        # self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model_config = args[0]
        self.prompt_config = args[1]
        self.fl_config = args[2]

        self.logger = logging.getLogger(__name__)

        self.v = {}

        self.model_config.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_config.data_dir = self.model_config.data_dir + self.model_config.task_name + '/'
        self.model_config.output_dir = self.model_config.output_dir + 'FL/' + self.model_config.task_name + '/'

        # 设置随机种子
        self.reset_seed()



    def reset_seed(self):
        torch.manual_seed(self.model_config.seed)
        torch.cuda.manual_seed_all(self.model_config.seed)
        np.random.seed(self.model_config.seed)
        random.seed(self.model_config.seed)
        torch.backends.cudnn.deterministic = True

    def init_data(self):
        """Build the full training set and the per-client datasets.

        Copies ``local_rank`` from the FL config onto the model config and
        ``data_dir`` from the model config onto the FL config, then stores the
        complete ``PromptDataset`` in ``self.train_dataset`` and the result of
        ``partition_for_score`` in ``self.train_datasets``.
        """
        model_cfg = self.model_config
        fl_cfg = self.fl_config

        model_cfg.local_rank = fl_cfg.local_rank
        task = model_cfg.task_name.lower()
        self.train_dataset = PromptDataset(model_cfg, task, self.tokenizer, data_type="train")

        fl_cfg.data_dir = model_cfg.data_dir
        self.train_datasets = partition_for_score(fl_cfg, self.train_dataset, logger=self.logger)


    def load_model(self):

        if "roberta" in self.model_config.model_name_or_path:
            config = RobertaConfig.from_pretrained(        
                        self.model_config.model_name_or_path,
                        num_labels=tasks_num_labels[self.model_config.task_name],
                        finetuning_task=self.model_config.task_name,
                        cache_dir=self.model_config.cache_dir if self.model_config.cache_dir else None,
                        output_hidden_states=True,
                        output_attentions=True
                    )
            
                
            self.tokenizer = RobertaTokenizer.from_pretrained(
                        self.model_config.model_name_or_path,
                        do_lower_case=self.model_config.do_lower_case,
                        cache_dir=self.model_config.cache_dir if self.model_config.cache_dir else None,        
                    )

                
            config.apply_lora=self.model_config.apply_lora
            config.lora_alpha=self.model_config.lora_alpha
            config.lora_r=self.model_config.lora_r
            config.apply_adapter = self.model_config.apply_adapter
            config.adapter_path = self.model_config.adapter_path
            config.adapter_type = self.model_config.adapter_type
            config.adapter_size = self.model_config.adapter_size
            config.apply_bitfit = self.model_config.apply_bitfit
            # config.prompt_layer_list = [9, 22, 18, 11, 16, 17, 21, 15, 10, 8, 12, 7]

            
            self.model = RobertaForMaskedLM.from_pretrained(
                        self.model_config.model_name_or_path, 
                        config=config, 
                    )
    
    def generate_prompt(self):
        for name, param in self.model.named_parameters():
            if 'lora' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
        # print('total param is {}'.format(all_param))

        all_param = 0
        # train_params = []
        for name, param in self.model.named_parameters():
            if param.requires_grad == True:
                all_param += param.numel()
        print('total param is {}'.format(all_param))


    def data_evaluate_and_score_ours(self):
        # returned train_loaders are list: [score, batch]
        self.train_loaders = []
        self.client_score_dict = {}
        file_name = './sorted_dataset/fl/{}/{}_dataset_{}_fisher.pkl'.format(self.model_config.task_name, \
                        self.fl_config.sort_type, self.fl_config.num_clients)
        if os.path.exists(file_name):
            with open(file_name, 'rb') as file:
                self.train_loaders = pickle.load(file)
        else:
            for client_index, train_dataset in enumerate(self.train_datasets):
                logger.info("prepossing the dataset for client:{}".format(client_index))
                self.train_loaders.append(evaluate_and_sort_ours_fisher((self.model_config, self.fl_config), train_dataset, \
                    self.model, client_index, col_func=self.train_dataset.collate_fn))
            with open(file_name, 'wb') as file:
                pickle.dump(self.train_loaders, file)
        
        for i in range(len(self.train_loaders)):
            totoal_score = 0
            for batch in self.train_loaders[i]:
                totoal_score += batch[0]
            self.client_score_dict[i] = (totoal_score / len(self.train_loaders[i]))
        
        print("the length of the training loaders:", len(self.train_loaders))
    
    def data_evaluate_and_score_voc(self):
        # returned train_loaders are list: [score, batch]
        data = []
        self.sorted_train_dataset = []
        self.client_score_dict = {}
        file_name = './sorted_dataset/fl/{}/{}_dataset_{}.pkl'.format(self.model_config.task_name, \
                        self.fl_config.sort_type, self.fl_config.num_clients)
        if os.path.exists(file_name):
            with open(file_name, 'rb') as file:
                data = pickle.load(file)
                self.sorted_train_dataset = [ele[0] for ele in data]
                for index, ele in enumerate(data):
                    self.client_score_dict[index] = ele[1]
        else:
            # define the global freq
            self.global_dict = {}
            total = 0
            for client_index, train_dataset in enumerate(self.train_datasets):
                self.global_dict, total = evaluate_and_sort_voc_freq(self.global_dict, total, train_dataset)
            for client_index, train_dataset in enumerate(self.train_datasets):
                logger.info("prepossing the dataset for client:{}".format(client_index))
                sorted_dataset, client_score = evaluate_and_sort_voc(self.fl_config, self.model_config, train_dataset, \
                    self.model, self.tokenizer, self.global_dict, total)
                data.append((sorted_dataset, client_score))
                self.sorted_train_dataset.append(sorted_dataset)
                self.client_score_dict[client_index] = client_score
            with open(file_name, 'wb') as file:
                pickle.dump(data, file)
        
        print("the length of the training loaders:", len(self.sorted_train_dataset))
    
    def data_evaluate_and_score_seqreo(self):
        # returned train_loaders are list: [score, batch]
        data = []
        self.sorted_train_dataset = []
        self.client_score_dict = {}
        file_name = './sorted_dataset/fl/{}/{}_dataset_{}.pkl'.format(self.model_config.task_name, \
                        self.fl_config.sort_type, self.fl_config.num_clients)
        if os.path.exists(file_name):
            with open(file_name, 'rb') as file:
                data = pickle.load(file)
                self.sorted_train_dataset = [ele[0] for ele in data]
                for index, ele in enumerate(data):
                    self.client_score_dict[index] = ele[1]
        else:
            for client_index, train_dataset in enumerate(self.train_datasets):
                logger.info("prepossing the dataset for client:{}".format(client_index))
                sorted_dataset, client_score = evaluate_and_sort_seqreo(self.fl_config, self.model_config, train_dataset, \
                    self.model, self.tokenizer)
                data.append((sorted_dataset, client_score))
                # print(client_index)
                self.sorted_train_dataset.append(sorted_dataset)
                self.client_score_dict[client_index] = client_score
            with open(file_name, 'wb') as file:
                pickle.dump(data, file)
        
        print("the length of the training loaders:", len(self.sorted_train_dataset))

    
    def data_evaluate_and_score_oursvoc(self):
        # returned train_loaders are list: [score, batch]
        self.train_loaders = []
        self.client_score_dict = {}
        file_name = './sorted_dataset/fl/{}/{}_dataset_{}.pkl'.format(self.model_config.task_name, \
                        self.fl_config.sort_type, self.fl_config.num_clients)
        if os.path.exists(file_name):
            with open(file_name, 'rb') as file:
                self.train_loaders = pickle.load(file)
        else:
            for client_index, train_dataset in enumerate(self.train_datasets):
                logger.info("prepossing the dataset for client:{}".format(client_index))
                self.train_loaders.append(evaluate_and_sort_ours_voc((self.model_config, self.fl_config), train_dataset, \
                    self.model, self.tokenizer, col_func=self.train_dataset.collate_fn))
            with open(file_name, 'wb') as file:
                pickle.dump(self.train_loaders, file)

        for i in range(len(self.train_loaders)):
            totoal_score = 0
            for batch in self.train_loaders[i]:
                totoal_score += batch[0]
            self.client_score_dict[i] = (totoal_score / len(self.train_loaders[i]))
        
        print("the length of the training loaders:", len(self.train_loaders))
    

    def client_train(self, idxs_users, epoch, training_loss, train_dataloaders, local_weights, time_list, peac_func):
        """
        进行客户端训练
        :param local_v:
        :param local_P:
        :param idxs_users:
        :param global_model:
        :param user_groups:
        :param epoch:
        :param train_dataset:
        :param train_losses:
        :param local_weights:
        :param local_losses:
        :return:
        """
        num_current = 0
        print(len(train_dataloaders))
        for idx in idxs_users:

            num_current += len(train_dataloaders[idx])
        total_loss = 0
        
        ori_trainable_weights = self.model.get_copy_of_trainable_weights()
        for idx in idxs_users:
            start = time.time()
            if self.fl_config.sort_type in ["ours", "oursvoc"]:
                w, loss, _ = train_simple((self.model_config, self.fl_config), train_dataloaders[idx], self.model, epoch)
            elif self.fl_config.sort_type == "vanila":
                w, loss, _ = train((self.model_config, self.fl_config), train_dataloaders[idx], self.model, self.train_dataset.collate_fn)
            else:
                w, loss, _ = train_others((self.model_config, self.fl_config), train_dataloaders[idx], self.model, self.train_dataset.collate_fn, epoch)
            
            local_weights.append([len(train_dataloaders[idx]), copy.deepcopy(w)])
            delta_time = time.time() - start
            time_list.append(delta_time)
            total_loss += loss * len(train_dataloaders[idx])
            print("{}:{:.4f}".format(idx, loss), end=" ")
            self.model.update_trainable_weights_from_dict(ori_trainable_weights)
        return total_loss / num_current
    
   

    def train(self):
        """Run the whole federated training procedure and print the metrics.

        Steps: load the model and tokenizer, build the per-client datasets,
        score/sort each client's data according to ``fl_config.sort_type``,
        freeze all non-LoRA parameters, evaluate once, and then for
        ``fl_config.rounds`` rounds train the selected clients, aggregate
        their weights with ``average_weights`` and evaluate again. The
        collected metrics are printed at the end; nothing is returned.
        """
        # Load the model.
        self.load_model()

        # Load the dataset and the per-client splits.
        self.init_data()

        # Score the difficulty of each client's samples. Sort types not listed
        # here (e.g. "vanila") need no scoring.
        self.model = self.model.to(self.model_config.device)
        sort_type = self.fl_config.sort_type
        if sort_type == "ours":
            self.data_evaluate_and_score_ours()
        elif sort_type == "voc":
            self.data_evaluate_and_score_voc()
        elif sort_type == "seqreo":
            self.data_evaluate_and_score_seqreo()
        elif sort_type == "oursvoc":
            self.data_evaluate_and_score_oursvoc()

        # Freeze everything except the LoRA weights, then count what is left
        # trainable here so the metric is a number regardless of what
        # generate_prompt returns.
        self.generate_prompt()
        num_of_trainable_params = sum(
            param.numel()
            for _, param in self.model.named_parameters()
            if param.requires_grad
        )

        # Per-round metrics.
        train_losses = []
        test_accs = []
        max_times = []
        best_acc = 0
        training_losses = []
        params_list = []

        self.reset_seed()

        # Initial evaluation before any training round.
        eval_configs = (self.model_config, self.prompt_config, self.fl_config)
        test_loss, test_acc = evaluate(eval_configs, self.model, self.tokenizer)
        evaluate_key = evaluate_metric[self.model_config.task_name]

        print("-train loss:{:.4f} -test acc:{}".format(test_loss, test_acc[evaluate_key]))
        lr = self.fl_config.learning_rate
        for epoch in range(self.fl_config.rounds):
            start = time.time()
            local_weights, time_list = [], []
            print(f'\n | Global Training Round : {epoch} |\n')

            # Learning-rate schedule: rounds 0-14 decay linearly from the
            # initial rate (with a slope that would reach half of it after
            # `rounds` rounds); round 15 cuts the current rate to 5% once;
            # later rounds keep that value.
            if epoch < 15:
                self.fl_config.learning_rate = lr - (lr - 0.5*lr) / self.fl_config.rounds * epoch
            elif epoch == 15:
                self.fl_config.learning_rate = self.fl_config.learning_rate * 0.05

            # Select the clients and train them.
            self.model.train()
            idxs_users = np.random.choice(range(self.fl_config.num_clients), self.fl_config.m, replace=False)
            if self.fl_config.server_cl:
                tmp_score_list = [self.client_score_dict[idx] for idx in idxs_users]
                print(tmp_score_list)
                # argsort yields positions within the sampled list; map them
                # back to client ids. Using the positions directly (as before)
                # trained clients 0..m-1 instead of the sampled ones.
                idxs_users = idxs_users[np.argsort(tmp_score_list)]
                epoch_thred = peace_func(self.fl_config, self.fl_config.alpha, self.fl_config.beta, idxs_users, epoch, self.fl_config.rounds)
                print("the peace func:{}, epoch_thred:{}".format(self.fl_config.server_peace_func, epoch_thred))
                idxs_users = idxs_users[:epoch_thred]

            # Pick the per-client data matching the sort type.
            if sort_type in ["ours", "oursvoc"]:
                client_data = self.train_loaders
            elif sort_type == "vanila":
                client_data = self.train_datasets
            else:
                client_data = self.sorted_train_dataset
            training_loss = self.client_train(
                idxs_users, epoch, train_losses, client_data,
                local_weights, time_list, self.fl_config.data_peace_func)

            # Aggregate once (it used to be computed twice) and update the
            # global weights.
            print("use fedavg as aggregation method on server")
            global_weights = average_weights(local_weights)
            self.model.update_trainable_weights_from_dict(global_weights)

            test_loss, test_acc = evaluate(eval_configs, self.model, self.tokenizer)
            round_acc = test_acc[evaluate_key]
            test_accs.append(round_acc)
            max_times.append(sum(time_list))
            train_losses.append(test_loss)
            training_losses.append(training_loss)
            params_list.append(num_of_trainable_params)
            if round_acc > best_acc:
                best_acc = round_acc

            print("epoch{:4d} - loss: {:.4f} - accuracy: {:.4f} - lr: {:.4f} - time: {:.2f}, {},{}".format(epoch, test_loss, round_acc, \
                self.fl_config.learning_rate, time.time() - start, sum(time_list), training_loss))

        # The output directory is created, but the metrics are only printed;
        # nothing is written to it here.
        save_path = self.model_config.output_dir
        os.makedirs(save_path, exist_ok=True)

        res_dict = {"acc":test_accs, "eval_loss": train_losses, "best_acc": best_acc, "training_time":max_times, "training_loss":training_losses, "num_transfer_params":params_list}
        print(res_dict)
        # with open(save_path + '/metrics_{}'.format(self.prompt_config.prompt_type), 'wb') as f:
        #     pickle.dump(res_dict,f)


# if __name__ == "__main__":
#     t = CentralTraining(args, share_percent=10, iid=0, unequal=False, prune_interval=30, prune_rate=0.6, auto_rate=True, auto_mu=False, server_mu=0, client_mu=0)
#     t.train()
