import collections
from http import client
from typing import DefaultDict, OrderedDict
import torch
import numpy as np
import random
from tqdm import tqdm
from flearn.utils.model_utils import evaluate, average_weights, generate_mask, train_personalize_with_prune_growth
from flearn.utils.process_data import PromptDataset, partition
from data.process import tasks_num_labels
from transformers import RobertaConfig, RobertaTokenizer
from models.modeling_roberta_lora import RobertaForMaskedLM
from ..utils.fl_score_fuctions import *
import copy
import os
import time
import math
import pickle
from flearn.utils.model_utils import peace_func
import logging



# Name of the evaluation metric associated with each task, keyed by task name.
evaluate_metric = {
    "rte": "acc",
    "sst-2": "acc",
    "cola": "mcc",
    "mrpc": "acc_and_f1",
    "mpqa": "acc",
    "qnli": "acc",
    "subj": "acc",
    "trec": "acc",
    "wnli": "acc",
    "boolq": "acc",
    "mr": "acc",
}


class CentralTraining(object):
    """
    对于聚合后的模型，进行中心化的训练，share_percent 是共享数据集的大小
    """

    def __init__(self, args, share_percent=0, iid=True, unequal=False, result_dir="central"):
        """Store the configuration, derive paths/device and pick the layer order.

        Args:
            args: Configuration namespace. This constructor reads
                ``select_layer_num``, ``data_dir``, ``task_name``,
                ``output_dir``, ``select_method`` and ``seed``. It is mutated
                in place: ``device`` is set, and ``data_dir`` / ``output_dir``
                are extended with the task name, so constructing twice with
                the same ``args`` object appends the suffixes twice.
            share_percent: Accepted but not used by this constructor.
            iid: Accepted but not used by this constructor.
            unequal: Accepted but not used by this constructor.
            result_dir: Accepted but not used by this constructor.
        """
        # self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.args = args
        self.general_layer_num = self.args.select_layer_num

        self.args.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.args.data_dir = self.args.data_dir + self.args.task_name + '/'
        self.args.output_dir = self.args.output_dir + 'FL/' + self.args.task_name + '/'

        self.logger = logging.getLogger(__name__)
        if self.args.select_method == 'ours':
            # Alternative orderings left commented out by the original author.
            # self.layer_index_list = [2,0,1,3,4,23,7,6,5,8,9,10,11,12,13,14,15,18,16,22,17,20,21,19][::-1]
            self.layer_index_list = [8,9,23,13,5,10,7,20,0,2,6,12,21,22,4,3,11,1,15,14,19,16,17,18][::-1]
            # self.layer_index_list = [0,5,12,10,22,21,23,6,8,2,7,3,9,4,15,13,11,20,17,1,14,19,16,18][::-1]
        else:
            self.layer_index_list = list(range(24))

        if self.args.select_method == 'random':
            # Shuffle with a private generator seeded from args.seed so the
            # random ordering is reproducible. Previously the global generator
            # was used here *before* it had been seeded, so two runs with the
            # same seed could select different layers.
            random.Random(self.args.seed).shuffle(self.layer_index_list)
        print(self.layer_index_list)

        # Seed the global random number generators.
        self.reset_seed()



    def reset_seed(self):
        torch.manual_seed(self.args.seed)
        torch.cuda.manual_seed_all(self.args.seed)
        np.random.seed(self.args.seed)
        random.seed(self.args.seed)
        torch.backends.cudnn.deterministic = True

    def init_data(self):
        """Build the train and dev prompt datasets and partition them.

        Requires ``self.tokenizer`` to be set already. Stores the datasets in
        ``self.train_dataset`` / ``self.eval_dataset`` and the three values
        returned by ``partition`` in ``self.train_loaders``,
        ``self.test_loaders`` and ``self.n_sample_list``.
        """
        args = self.args
        # Self-assignment kept as in the original; it only fails if
        # ``local_rank`` is missing from the configuration.
        args.local_rank = args.local_rank

        task = args.task_name.lower()
        self.train_dataset = PromptDataset(args, task, self.tokenizer, data_type="train")
        self.eval_dataset = PromptDataset(args, task, self.tokenizer, data_type='dev')

        partitioned = partition(args, self.train_dataset, self.eval_dataset, self.logger)
        self.train_loaders, self.test_loaders, self.n_sample_list = partitioned


    def load_model(self):
        """Load the RoBERTa config, tokenizer and masked-LM model.

        Sets ``self.tokenizer`` and ``self.model``. The LoRA / adapter / BitFit
        options from ``self.args`` are copied onto the config before the model
        is instantiated.

        Raises:
            ValueError: If ``self.args.model_name_or_path`` does not contain
                ``"roberta"``. Previously this case returned silently and left
                ``self.model`` / ``self.tokenizer`` undefined.
        """
        args = self.args
        if "roberta" not in args.model_name_or_path:
            raise ValueError(
                f"Unsupported model {args.model_name_or_path!r}: "
                "load_model only handles RoBERTa checkpoints"
            )

        cache_dir = args.cache_dir if args.cache_dir else None

        config = RobertaConfig.from_pretrained(
            args.model_name_or_path,
            num_labels=tasks_num_labels[args.task_name],
            finetuning_task=args.task_name,
            cache_dir=cache_dir,
            output_hidden_states=True,
            output_attentions=True,
        )

        self.tokenizer = RobertaTokenizer.from_pretrained(
            args.model_name_or_path,
            do_lower_case=args.do_lower_case,
            cache_dir=cache_dir,
        )

        # Forward the parameter-efficient tuning options to the model config.
        config.apply_lora = args.apply_lora
        config.lora_alpha = args.lora_alpha
        config.lora_r = args.lora_r
        config.apply_adapter = args.apply_adapter
        config.adapter_path = args.adapter_path
        config.adapter_type = args.adapter_type
        config.adapter_size = args.adapter_size
        config.apply_bitfit = args.apply_bitfit

        self.model = RobertaForMaskedLM.from_pretrained(
            args.model_name_or_path,
            config=config,
        )
    
    def generate_prompt(self, transfer_layer_index_list):
        self.train_parameters_name = list()
        self.transfer_parameters_name = list()
        self.server_weights = {}
        for name, param in self.model.named_parameters():
            if 'lora' in name:
                layer_index = int(name.split('.')[3])
                if layer_index in transfer_layer_index_list:
                    self.transfer_parameters_name.append(name)
                self.train_parameters_name.append(name)
                param.requires_grad = True
                self.server_weights[name] = copy.deepcopy(param.data)
            else:
                param.requires_grad = False

        all_param = 0
        for name, param in self.model.named_parameters():
            if param.requires_grad == True:
                all_param += param.numel()
        print('total param is {}'.format(all_param))
    
    def generate_weights_for_clients(self):
        """Fill ``self.client_weights`` with one weight snapshot per client.

        Each of the ``self.args.num_clients`` entries is an independent deep
        copy of ``self.model.get_copy_of_trainable_weights()``.
        """
        self.client_weights = []
        self.client_weights.extend(
            copy.deepcopy(self.model.get_copy_of_trainable_weights())
            for _ in range(self.args.num_clients)
        )



    def data_evaluate_and_score_ours(self):
        """Score each client's data with ``evaluate_and_sort_ours``, using a pickle cache.

        Populates ``self.train_loaders`` (one sequence per client whose items
        carry a score at index 0) and ``self.client_score_dict`` (client index
        -> mean of those scores). If the cache file for the current task /
        sort type / client count exists it is loaded; otherwise the result is
        computed and written to that file.

        NOTE(review): reads ``self.train_datasets``, which is not assigned by
        any method visible in this class — confirm where it is populated.
        """
        self.train_loaders = []
        self.client_score_dict = {}
        file_name = './sorted_dataset/fl/{}/{}_dataset_{}.pkl'.format(
            self.args.task_name, self.args.sort_type, self.args.num_clients)

        if os.path.exists(file_name):
            # NOTE: pickle.load can execute arbitrary code; only use cache
            # files that were written by this program.
            with open(file_name, 'rb') as file:
                self.train_loaders = pickle.load(file)
        else:
            for client_index, train_dataset in enumerate(self.train_datasets):
                # Use the instance logger; no module-level ``logger`` is defined here.
                self.logger.info("prepossing the dataset for client:{}".format(client_index))
                self.train_loaders.append(evaluate_and_sort_ours(
                    (self.args, self.args), train_dataset, self.model, client_index,
                    col_func=self.train_dataset.collate_fn))
            # Create the cache directory first so the expensive scoring above
            # is not lost to a missing-directory error.
            os.makedirs(os.path.dirname(file_name), exist_ok=True)
            with open(file_name, 'wb') as file:
                pickle.dump(self.train_loaders, file)

        for client_index, loader in enumerate(self.train_loaders):
            total_score = sum(batch[0] for batch in loader)
            self.client_score_dict[client_index] = total_score / len(loader)

        print("the length of the training loaders:", len(self.train_loaders))
    
    def data_evaluate_and_score_voc(self):
        """Sort and score each client's data with ``evaluate_and_sort_voc``, using a pickle cache.

        Populates ``self.sorted_train_dataset`` (one entry per client) and
        ``self.client_score_dict`` (client index -> score). The cache file
        stores a list of ``(sorted_dataset, client_score)`` pairs. When no
        cache exists, ``self.global_dict`` and a running total are first
        accumulated over all clients with ``evaluate_and_sort_voc_freq`` and
        then passed to the per-client scoring.

        NOTE(review): reads ``self.train_datasets``, which is not assigned by
        any method visible in this class — confirm where it is populated.
        """
        data = []
        self.sorted_train_dataset = []
        self.client_score_dict = {}
        file_name = './sorted_dataset/fl/{}/{}_dataset_{}.pkl'.format(
            self.args.task_name, self.args.sort_type, self.args.num_clients)

        if os.path.exists(file_name):
            # NOTE: pickle.load can execute arbitrary code; only use cache
            # files that were written by this program.
            with open(file_name, 'rb') as file:
                data = pickle.load(file)
            self.sorted_train_dataset = [ele[0] for ele in data]
            for index, ele in enumerate(data):
                self.client_score_dict[index] = ele[1]
        else:
            # First pass: accumulate the statistics shared by all clients.
            self.global_dict = {}
            total = 0
            for train_dataset in self.train_datasets:
                self.global_dict, total = evaluate_and_sort_voc_freq(self.global_dict, total, train_dataset)

            # Second pass: sort and score each client's dataset.
            for client_index, train_dataset in enumerate(self.train_datasets):
                # Use the instance logger; no module-level ``logger`` is defined here.
                self.logger.info("prepossing the dataset for client:{}".format(client_index))
                sorted_dataset, client_score = evaluate_and_sort_voc(
                    self.args, self.args, train_dataset,
                    self.model, self.tokenizer, self.global_dict, total)
                data.append((sorted_dataset, client_score))
                self.sorted_train_dataset.append(sorted_dataset)
                self.client_score_dict[client_index] = client_score

            # Create the cache directory first so the work above is not lost
            # to a missing-directory error.
            os.makedirs(os.path.dirname(file_name), exist_ok=True)
            with open(file_name, 'wb') as file:
                pickle.dump(data, file)

        print("the length of the training loaders:", len(self.sorted_train_dataset))
    
    def data_evaluate_and_score_seqreo(self):
        """Sort and score each client's data with ``evaluate_and_sort_seqreo``, using a pickle cache.

        Populates ``self.sorted_train_dataset`` (one entry per client) and
        ``self.client_score_dict`` (client index -> score). The cache file
        stores a list of ``(sorted_dataset, client_score)`` pairs; it is
        loaded when present and written after computing otherwise.

        NOTE(review): reads ``self.train_datasets``, which is not assigned by
        any method visible in this class — confirm where it is populated.
        """
        data = []
        self.sorted_train_dataset = []
        self.client_score_dict = {}
        file_name = './sorted_dataset/fl/{}/{}_dataset_{}.pkl'.format(
            self.args.task_name, self.args.sort_type, self.args.num_clients)

        if os.path.exists(file_name):
            # NOTE: pickle.load can execute arbitrary code; only use cache
            # files that were written by this program.
            with open(file_name, 'rb') as file:
                data = pickle.load(file)
            self.sorted_train_dataset = [ele[0] for ele in data]
            for index, ele in enumerate(data):
                self.client_score_dict[index] = ele[1]
        else:
            for client_index, train_dataset in enumerate(self.train_datasets):
                # Use the instance logger; no module-level ``logger`` is defined here.
                self.logger.info("prepossing the dataset for client:{}".format(client_index))
                sorted_dataset, client_score = evaluate_and_sort_seqreo(
                    self.args, self.args, train_dataset, self.model, self.tokenizer)
                data.append((sorted_dataset, client_score))
                self.sorted_train_dataset.append(sorted_dataset)
                self.client_score_dict[client_index] = client_score

            # Create the cache directory first so the work above is not lost
            # to a missing-directory error.
            os.makedirs(os.path.dirname(file_name), exist_ok=True)
            with open(file_name, 'wb') as file:
                pickle.dump(data, file)

        print("the length of the training loaders:", len(self.sorted_train_dataset))

    
    def data_evaluate_and_score_oursvoc(self):
        """Score each client's data with ``evaluate_and_sort_ours_voc``, using a pickle cache.

        Populates ``self.train_loaders`` (one sequence per client whose items
        carry a score at index 0) and ``self.client_score_dict`` (client index
        -> mean of those scores). If the cache file for the current task /
        sort type / client count exists it is loaded; otherwise the result is
        computed and written to that file.

        NOTE(review): reads ``self.train_datasets``, which is not assigned by
        any method visible in this class — confirm where it is populated.
        """
        self.train_loaders = []
        self.client_score_dict = {}
        file_name = './sorted_dataset/fl/{}/{}_dataset_{}.pkl'.format(
            self.args.task_name, self.args.sort_type, self.args.num_clients)

        if os.path.exists(file_name):
            # NOTE: pickle.load can execute arbitrary code; only use cache
            # files that were written by this program.
            with open(file_name, 'rb') as file:
                self.train_loaders = pickle.load(file)
        else:
            for client_index, train_dataset in enumerate(self.train_datasets):
                # Use the instance logger; no module-level ``logger`` is defined here.
                self.logger.info("prepossing the dataset for client:{}".format(client_index))
                self.train_loaders.append(evaluate_and_sort_ours_voc(
                    (self.args, self.args), train_dataset, self.model, self.tokenizer,
                    col_func=self.train_dataset.collate_fn))
            # Create the cache directory first so the expensive scoring above
            # is not lost to a missing-directory error.
            os.makedirs(os.path.dirname(file_name), exist_ok=True)
            with open(file_name, 'wb') as file:
                pickle.dump(self.train_loaders, file)

        for client_index, loader in enumerate(self.train_loaders):
            total_score = sum(batch[0] for batch in loader)
            self.client_score_dict[client_index] = total_score / len(loader)

        print("the length of the training loaders:", len(self.train_loaders))
    

    def client_train(self, idxs_users, train_dataloaders, local_weights, time_list, global_weights, epoch):
        """
        进行客户端训练
        :param local_v:
        :param local_P:
        :param idxs_users:
        :param global_model:
        :param user_groups:
        :param epoch:
        :param train_dataset:
        :param train_losses:
        :param local_weights:
        :param local_losses:
        :return:
        """
        total_loss = 0
        
        
        for idx in idxs_users:
            
            # self.model.train()
            # self.model.update_trainable_weights_from_dict(copy.deepcopy(self.client_weights[idx]))
            self.model.train()
            self.model.update_trainable_weights_from_dict(copy.deepcopy(global_weights))
            start = time.time()

            w, loss, mask = train_personalize_with_prune_growth((self.args, self.args), train_dataloaders[idx], self.model, self.mask[idx], epoch)
            self.mask[idx] = copy.deepcopy(mask)

            local_weights.append([len(train_dataloaders[idx]), copy.deepcopy(w)])
            delta_time = time.time() - start
            time_list.append(delta_time)
            total_loss += loss * len(train_dataloaders[idx])
            print("{}:{:.4f}".format(idx, loss), end=" ")
            print("update the local client:{}".format(idx))
            # self.client_weights[idx] = copy.deepcopy(self.model.get_copy_of_trainable_weights())
    
   

    def train(self):
        """Run the full federated training loop and report per-round results.

        Steps: load the model and data, mark the trainable parameters, create
        per-client weight copies, build one mask per client with
        ``generate_mask``, evaluate once, then for ``self.args.rounds`` rounds
        sample ``self.args.m`` clients without replacement, train them via
        ``client_train``, average their weights with ``average_weights``, load
        the average into the model and evaluate.

        Returns:
            dict: The same dictionary that is printed at the end, with keys
            ``acc``, ``eval_loss``, ``best_acc``, ``training_time``,
            ``training_loss`` and ``num_transfer_params``. (Previously nothing
            was returned.)
        """
        # Load the model.
        self.load_model()

        # Load the dataset and the per-client partitions.
        self.init_data()

        # Mark the trainable parameters and record self.server_weights.
        transfer_layer_index = self.layer_index_list[:self.general_layer_num]
        self.generate_prompt(transfer_layer_index)
        print("the transfer layer name:", self.transfer_parameters_name)

        # Generate the client weights, self.client_weights.
        self.generate_weights_for_clients()

        # Move the model to the device chosen in __init__ instead of calling
        # .cuda() unconditionally, which fails on hosts without CUDA.
        self.model = self.model.to(self.args.device)

        # Per-round bookkeeping.
        eval_losses = []
        test_accs = []
        max_times = []
        best_acc = 0
        training_losses = []
        training_parameters = []

        # Initialise the mask of every client from the same global weights.
        global_weights = self.model.get_copy_of_trainable_weights()
        self.mask = []
        for i in range(self.args.num_clients):
            self.model.train()
            self.model.update_trainable_weights_from_dict(copy.deepcopy(global_weights))
            mask = generate_mask((self.args, self.args), self.train_loaders[i], self.model)
            self.mask.append(copy.deepcopy(mask))

        # First evaluation, before any training round.
        self.model.train()
        self.model.update_trainable_weights_from_dict(copy.deepcopy(global_weights))
        test_loss, test_acc = evaluate((self.args, self.args, self.args), self.model, self.tokenizer)
        print("-train loss:{:.4f} -test acc:{}".format(test_loss, test_acc))

        for epoch in range(self.args.rounds):
            start = time.time()
            local_weights, time_list = [], []
            print(f'\n | Global Training Round : {epoch} |\n')

            # Select the participating clients for this round.
            idxs_users = np.random.choice(range(self.args.num_clients), self.args.m, replace=False)

            training_loss = self.client_train(
                idxs_users, self.train_loaders, local_weights, time_list, global_weights, epoch)

            # Update all server layers with the averaged client weights.
            global_weights = average_weights(local_weights)
            self.model.train()
            self.model.update_trainable_weights_from_dict(copy.deepcopy(global_weights))

            test_loss, test_acc = evaluate((self.args, self.args, self.args), self.model, self.tokenizer)

            test_accs.append(test_acc)
            max_times.append(max(time_list))
            eval_losses.append(test_loss)
            training_losses.append(training_loss)
            training_parameters.append(self.general_layer_num)
            if test_acc > best_acc:
                best_acc = test_acc

            print("epoch{:4d} - loss: {:.4f} - accuracy: {:.4f} - lr: {:.4f} - time: {:.2f}, {},{}".format(
                epoch, test_loss, test_acc, self.args.learning_rate,
                time.time() - start, sum(time_list), training_loss))

        res_dict = {
            "acc": test_accs,
            "eval_loss": eval_losses,
            "best_acc": best_acc,
            "training_time": max_times,
            "training_loss": training_losses,
            "num_transfer_params": training_parameters,
        }

        print(res_dict)
        return res_dict


# if __name__ == "__main__":
#     t = CentralTraining(args, share_percent=10, iid=0, unequal=False, prune_interval=30, prune_rate=0.6, auto_rate=True, auto_mu=False, server_mu=0, client_mu=0)
#     t.train()
