from abc import ABC, abstractmethod
import torch
import os
from transformers import GPT2LMHeadModel, AdamW, get_scheduler
from tqdm import tqdm

from model_utils import FedSPModel




class InitTrainer():
    """Round-0 initialiser: distil the global model into the auxiliary model.

    No labels and no prompt prefix are involved (both models are built with
    prefix length 0): the student (an "aux" ``FedSPModel`` built with
    ``aux_layer_num``) is trained to match the teacher's (a "global"
    ``FedSPModel``) last hidden state with a pure MSE distillation loss.
    """

    def __init__(self, save_dir, dataloader, model_name,
            weight_decay,
            gradient_accumulation_steps, local_epochs, lr, max_train_step, aux_layer_num,
            logger):
        self._save_dir = save_dir
        self._dataloader = dataloader
        self._model_name = model_name
        self._teacher_model = FedSPModel(model_name, "global", 0)
        self._student_model = FedSPModel(model_name, "aux", 0, aux_layer_num)
        self._logger = logger

        self._weight_decay = weight_decay
        self._gradient_accumulation_steps = gradient_accumulation_steps
        self._local_epochs = local_epochs
        self._lr = lr
        # -1 means "no step cap": train for the full number of local epochs.
        self._max_train_step = max_train_step
        self._mse_loss_func = torch.nn.MSELoss()

    def _schedule_lengths(self, local_epochs, max_train_step):
        """Return ``(total_train_step, warmup_step)`` for the linear LR schedule.

        With ``max_train_step == -1`` the schedule covers every optimizer step
        of ``local_epochs`` passes over the dataloader (one step per
        ``gradient_accumulation_steps`` batches); otherwise it covers exactly
        ``max_train_step`` steps. Warm-up is 6% of the total plus one.
        """
        if max_train_step == -1:
            total_train_step = local_epochs * (len(self._dataloader) // self._gradient_accumulation_steps)
        else:
            total_train_step = max_train_step
        warmup_step = int(total_train_step * 0.06) + 1
        return total_train_step, warmup_step

    def _train(self):
        """Run the distillation loop, then save both models under ``<save_dir>/0``.

        The loss is the MSE between the last entry of the student's and the
        teacher's ``hidden_states``; the teacher runs under ``torch.no_grad``
        and is not optimised. Training stops after ``local_epochs`` epochs or
        once ``max_train_step`` optimizer steps have been taken (no cap when
        -1). The teacher's state dict is saved as ``global.pth`` and the
        student's as ``aux.pth``; both models are then deleted.
        """
        total_train_step, warmup_step = self._schedule_lengths(self._local_epochs, self._max_train_step)
        optimizer = AdamW(params=self._student_model.parameters(), lr=self._lr, weight_decay=self._weight_decay)
        lr_scheduler = get_scheduler("linear", optimizer=optimizer,
                                     num_warmup_steps=warmup_step,
                                     num_training_steps=total_train_step)
        device = torch.device("cuda")

        self._student_model.train()
        # The teacher only provides regression targets; put it in eval mode
        # explicitly so dropout (if any) does not perturb those targets.
        self._teacher_model.eval()
        self._student_model.zero_grad(set_to_none=True)
        train_step = 0
        for _epoch in range(self._local_epochs):
            if train_step == self._max_train_step:
                break
            for step, batch in enumerate(tqdm(self._dataloader)):
                data = {k: v.to(device) for k, v in batch.items()}

                _input = data['input']
                output = self._student_model(_input)
                with torch.no_grad():
                    teacher_output = self._teacher_model(_input)
                loss = self._mse_loss_func(output.hidden_states[-1],
                                           teacher_output.hidden_states[-1])

                # Scale so the accumulated gradient averages over micro-batches.
                loss = loss / self._gradient_accumulation_steps
                loss.backward()

                if (step + 1) % self._gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    train_step += 1
                    lr_scheduler.step()
                if train_step == self._max_train_step:
                    break

        round_dir = os.path.join(self._save_dir, str(0))
        os.makedirs(round_dir, exist_ok=True)
        torch.save(self._teacher_model.state_dict(), os.path.join(round_dir, "global.pth"))
        torch.save(self._student_model.state_dict(), os.path.join(round_dir, "aux.pth"))
        del self._teacher_model
        del self._student_model
        torch.cuda.empty_cache()

class ClientTrainer():
    """Local trainer for one federated client in round ``r``.

    The client owns an auxiliary ``FedSPModel`` with a prompt encoder of
    length ``prefix_len``. ``_switch_mode`` selects whether the backbone
    ("model") or the prompt encoder is trainable, and ``_train`` runs one
    phase with that phase's own epochs / learning rate / step cap.
    """

    def __init__(self, save_dir, dataloader, model_name,
            prefix_len, aux_layer_num,
            r, last_r, client_id,
            weight_decay, label_smoothing, gradient_accumulation_steps,
            model_local_epochs, model_lr, model_max_train_step,
            prefix_local_epochs, prefix_lr, prefix_max_train_step,
            logger
        ):
        """Build the client model and restore its weights from disk.

        Non-``prompt_encoder`` weights come from ``<save_dir>/0/aux.pth`` when
        ``last_r == 0`` and from this client's
        ``<save_dir>/<last_r>/client_<client_id>_all.pth`` otherwise. For
        ``r > 1`` the ``prompt_encoder`` weights are replaced by the server
        aggregate in ``<save_dir>/<r-1>/agg_prompt_encoder.pth``; otherwise
        they keep their freshly constructed values.
        """
        self._logger = logger
        self._save_dir = save_dir
        self._dataloader = dataloader
        self._model_name = model_name
        self._prefix_len = prefix_len

        self._r = r
        self._last_r = last_r
        self._client_id = client_id

        self._weight_decay = weight_decay
        self._label_smoothing = label_smoothing
        self._gradient_accumulation_steps = gradient_accumulation_steps

        self._model_local_epochs = model_local_epochs
        self._model_lr = model_lr
        self._model_max_train_step = model_max_train_step

        self._prefix_local_epochs = prefix_local_epochs
        self._prefix_lr = prefix_lr
        self._prefix_max_train_step = prefix_max_train_step

        self._model = FedSPModel(model_name, "aux", self._prefix_len, aux_layer_num)
        raw_state_dict = self._model.state_dict().copy()

        last_round_dir = os.path.join(self._save_dir, str(last_r))
        if self._last_r == 0:
            last_r_aux = torch.load(os.path.join(last_round_dir, "aux.pth"))
        else:
            last_r_aux = torch.load(os.path.join(last_round_dir, f"client_{self._client_id}_all.pth"))
        for key in raw_state_dict:
            if "prompt_encoder" not in key:
                raw_state_dict[key] = last_r_aux[key]
        if self._r > 1:
            prompt_encoder = torch.load(os.path.join(self._save_dir, str(self._r - 1), "agg_prompt_encoder.pth"))
            for key in raw_state_dict:
                if "prompt_encoder" in key:
                    raw_state_dict[key] = prompt_encoder[key]
        self._model.load_state_dict(raw_state_dict, strict=True)

        # reduction="none" keeps per-token losses so they can be masked; it is
        # the non-deprecated spelling of reduce=False.
        self._loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction="none",
                                                   label_smoothing=self._label_smoothing)

    def _switch_mode(self, mode):
        """Select which part of the model is trainable.

        ``"model"`` unfreezes the backbone and freezes the prompt encoder; any
        other value does the opposite.
        """
        if mode == "model":
            self._model._freeze_param(["prompt_encoder"])
            self._model._unfreeze_param(["model"])
        else:
            self._model._freeze_param(["model"])
            self._model._unfreeze_param(["prompt_encoder"])

    def _schedule_lengths(self, local_epochs, max_train_step):
        """Return ``(total_train_step, warmup_step)`` for the linear LR schedule.

        With ``max_train_step == -1`` the schedule covers every optimizer step
        of ``local_epochs`` passes over the dataloader (one step per
        ``gradient_accumulation_steps`` batches); otherwise it covers exactly
        ``max_train_step`` steps. Warm-up is 6% of the total plus one.
        """
        if max_train_step == -1:
            total_train_step = local_epochs * (len(self._dataloader) // self._gradient_accumulation_steps)
        else:
            total_train_step = max_train_step
        warmup_step = int(total_train_step * 0.06) + 1
        return total_train_step, warmup_step

    def _masked_lm_loss(self, data):
        """Return the mask-weighted mean token cross-entropy for one batch.

        ``data`` holds ``'input'`` (2-D token ids, fed to the model),
        ``'target'`` (reshaped to the same ``(batch, len)`` as the per-token
        loss) and ``'mask'`` (multiplied into the per-token loss).
        """
        _input = data['input']
        batch_size, seq_len = _input.shape
        lm_logits = self._model(_input).logits
        token_loss = self._loss_fct(lm_logits.view(-1, lm_logits.size(-1)),
                                    data['target'].view(-1)).view(batch_size, seq_len)
        mask = data['mask']
        # The small epsilon keeps an all-zero mask from dividing by zero.
        return (token_loss * mask).sum() / (mask.sum() + 0.0001)

    def _train(self, mode):
        """Train the client model for one phase and, after the prompt phase, save it.

        ``mode == "model"`` uses the ``model_*`` epoch / LR / step-cap
        settings; any other value uses the ``prefix_*`` settings. Which
        parameters receive gradients is presumably set beforehand via
        ``_switch_mode`` (it is not called here). Training stops after
        ``local_epochs`` epochs or once ``max_train_step`` optimizer steps
        have been taken (no cap when -1).

        When ``mode == "prompt_encoder"`` the full state dict is saved to
        ``<save_dir>/<r>/client_<client_id>_all.pth``, its ``prompt_encoder``
        entries to ``<save_dir>/<r>/client_<client_id>_prompt_encoder.pth``,
        and the model is deleted, so this trainer cannot be used afterwards.
        """
        if mode == "model":
            local_epochs = self._model_local_epochs
            lr = self._model_lr
            max_train_step = self._model_max_train_step
        else:
            local_epochs = self._prefix_local_epochs
            lr = self._prefix_lr
            max_train_step = self._prefix_max_train_step
        total_train_step, warmup_step = self._schedule_lengths(local_epochs, max_train_step)
        optimizer = AdamW(params=self._model.parameters(), lr=lr, weight_decay=self._weight_decay)
        lr_scheduler = get_scheduler("linear", optimizer=optimizer,
                                     num_warmup_steps=warmup_step,
                                     num_training_steps=total_train_step)
        device = torch.device("cuda")

        self._model.train()
        # Drop gradients left over from a previous phase (e.g. trailing
        # micro-batches that never reached an optimizer step). Otherwise the
        # first optimizer.step() here would apply them, including to
        # parameters that have since been frozen (assuming freezing only
        # toggles requires_grad, which leaves .grad in place).
        self._model.zero_grad(set_to_none=True)
        train_step = 0
        for _epoch in range(local_epochs):
            if train_step == max_train_step:
                break
            for step, batch in enumerate(tqdm(self._dataloader)):
                data = {k: v.to(device) for k, v in batch.items()}
                loss = self._masked_lm_loss(data) / self._gradient_accumulation_steps
                loss.backward()

                if (step + 1) % self._gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    train_step += 1
                    lr_scheduler.step()
                if train_step == max_train_step:
                    break

        if mode == "prompt_encoder":
            round_dir = os.path.join(self._save_dir, f"{self._r}")
            os.makedirs(round_dir, exist_ok=True)
            state_dict = self._model.state_dict()
            torch.save(state_dict, os.path.join(round_dir, f"client_{self._client_id}_all.pth"))
            prompt_encoder = {key: value for key, value in state_dict.items() if "prompt_encoder" in key}
            torch.save(prompt_encoder, os.path.join(round_dir, f"client_{self._client_id}_prompt_encoder.pth"))
            del self._model
            torch.cuda.empty_cache()

class ServerTrainer():
    """Server-side trainer for the global model's prompt encoder in round ``r``.

    On construction each listed client's
    ``<save_dir>/<r>/client_<id>_prompt_encoder.pth`` is loaded and averaged
    key-wise; the average is combined with the non-prompt weights from
    ``<save_dir>/0/global.pth``, and the backbone is frozen while the prompt
    encoder is left trainable.
    """

    def __init__(self, save_dir, dataloader, model_name, prefix_len,
            r, client_ids,
            weight_decay, label_smoothing, gradient_accumulation_steps,
            local_epochs, lr, max_train_step, max_eval_step, max_test_step
        ):
        self._save_dir = save_dir
        self._dataloader = dataloader
        self._model_name = model_name
        self._prefix_len = prefix_len

        self._r = r
        self._client_ids = client_ids

        self._weight_decay = weight_decay
        self._label_smoothing = label_smoothing
        self._gradient_accumulation_steps = gradient_accumulation_steps

        self._local_epochs = local_epochs
        self._lr = lr
        # For the three caps below, -1 means "no cap".
        self._max_train_step = max_train_step
        self._max_eval_step = max_eval_step
        self._max_test_step = max_test_step

        self._model = FedSPModel(model_name, "global", self._prefix_len)
        raw_state_dict = self._model.state_dict().copy()

        round_dir = os.path.join(self._save_dir, f"{self._r}")
        clients_prompt_encoder = [
            torch.load(os.path.join(round_dir, f"client_{client_id}_prompt_encoder.pth"))
            for client_id in client_ids
        ]
        mean_prompt_encoder = self._mean_state_dicts(clients_prompt_encoder)
        init_global = torch.load(os.path.join(self._save_dir, str(0), "global.pth"))

        for key in raw_state_dict:
            if "prompt_encoder" not in key:
                raw_state_dict[key] = init_global[key]
            else:
                raw_state_dict[key] = mean_prompt_encoder[key]

        self._model.load_state_dict(raw_state_dict, strict=True)
        self._model._freeze_param(["model"])
        self._model._unfreeze_param(["prompt_encoder"])

        # reduction="none" keeps per-token losses so they can be masked; it is
        # the non-deprecated spelling of reduce=False.
        self._loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction="none",
                                                   label_smoothing=self._label_smoothing)

    @staticmethod
    def _mean_state_dicts(state_dicts):
        """Return the unweighted element-wise mean of a list of state dicts.

        Keys are taken from the first dict; every dict must contain them with
        tensors of identical shape.

        Raises:
            ValueError: if ``state_dicts`` is empty (nothing to aggregate).
        """
        if not state_dicts:
            raise ValueError("no client prompt encoders to aggregate (client_ids is empty)")
        return {
            key: torch.mean(torch.stack([sd[key] for sd in state_dicts]), dim=0)
            for key in state_dicts[0]
        }

    def _schedule_lengths(self, local_epochs, max_train_step):
        """Return ``(total_train_step, warmup_step)`` for the linear LR schedule.

        With ``max_train_step == -1`` the schedule covers every optimizer step
        of ``local_epochs`` passes over the dataloader (one step per
        ``gradient_accumulation_steps`` batches); otherwise it covers exactly
        ``max_train_step`` steps. Warm-up is 6% of the total plus one.
        """
        if max_train_step == -1:
            total_train_step = local_epochs * (len(self._dataloader) // self._gradient_accumulation_steps)
        else:
            total_train_step = max_train_step
        warmup_step = int(total_train_step * 0.06) + 1
        return total_train_step, warmup_step

    def _masked_lm_loss(self, data):
        """Return the mask-weighted mean token cross-entropy for one batch.

        ``data`` holds ``'input'`` (2-D token ids, fed to the model),
        ``'target'`` (reshaped to the same ``(batch, len)`` as the per-token
        loss) and ``'mask'`` (multiplied into the per-token loss).
        """
        _input = data['input']
        batch_size, seq_len = _input.shape
        lm_logits = self._model(_input).logits
        token_loss = self._loss_fct(lm_logits.view(-1, lm_logits.size(-1)),
                                    data['target'].view(-1)).view(batch_size, seq_len)
        mask = data['mask']
        # The small epsilon keeps an all-zero mask from dividing by zero.
        return (token_loss * mask).sum() / (mask.sum() + 0.0001)

    def _evaluate(self, dataloader):
        """Return the mean per-batch masked LM loss over ``dataloader``.

        Stops after ``max_eval_step`` batches (the whole loader when -1). The
        model is left in eval mode.

        Raises:
            ValueError: if ``dataloader`` yields no batches.
        """
        device = torch.device("cuda")
        eval_loss = 0
        eval_step = 0
        self._model.eval()
        with torch.no_grad():
            for batch in tqdm(dataloader):
                data = {k: v.to(device) for k, v in batch.items()}
                eval_loss += self._masked_lm_loss(data)
                eval_step += 1
                if eval_step == self._max_eval_step:
                    break

        if eval_step == 0:
            raise ValueError("cannot evaluate on an empty dataloader")
        avg_eval_loss = eval_loss / eval_step
        return avg_eval_loss.item()

    def _train(self, logger):
        """Tune the global model's prompt encoder on the training dataloader.

        Freezes the backbone, unfreezes the prompt encoder, then trains for up
        to ``local_epochs`` epochs or ``max_train_step`` optimizer steps (no
        cap when -1). ``logger`` is accepted for interface compatibility but
        is not used.
        """
        self._model._freeze_param(["model"])
        self._model._unfreeze_param(["prompt_encoder"])
        total_train_step, warmup_step = self._schedule_lengths(self._local_epochs, self._max_train_step)
        optimizer = AdamW(params=self._model.parameters(), lr=self._lr, weight_decay=self._weight_decay)
        # Use the computed total, not the raw max_train_step: that is -1 when
        # uncapped and would drive the linear schedule's LR to zero right
        # after warm-up.
        lr_scheduler = get_scheduler("linear", optimizer=optimizer,
                                     num_warmup_steps=warmup_step,
                                     num_training_steps=total_train_step)
        device = torch.device("cuda")

        self._model.train()
        # Start from clean gradients so nothing stale is applied on step one.
        self._model.zero_grad(set_to_none=True)
        train_step = 0
        for _epoch in range(self._local_epochs):
            if train_step == self._max_train_step:
                break
            for step, batch in enumerate(tqdm(self._dataloader)):
                data = {k: v.to(device) for k, v in batch.items()}
                loss = self._masked_lm_loss(data) / self._gradient_accumulation_steps
                loss.backward()

                if (step + 1) % self._gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    train_step += 1
                    lr_scheduler.step()
                if train_step == self._max_train_step:
                    break

    def _save(self):
        """Save the prompt-encoder weights to ``<save_dir>/<r>/agg_prompt_encoder.pth`` and free the model."""
        round_dir = os.path.join(self._save_dir, f"{self._r}")
        os.makedirs(round_dir, exist_ok=True)
        state_dict = self._model.state_dict()
        prompt_encoder = {key: value for key, value in state_dict.items() if "prompt_encoder" in key}
        torch.save(prompt_encoder, os.path.join(round_dir, "agg_prompt_encoder.pth"))
        del self._model
        torch.cuda.empty_cache()

    def _generate(self, test_data_list, context_pred_refs_dict, context_list, generation_dir, tokenizer):
        """Beam-search one prediction per test example and write refs/preds to disk.

        Each item of ``test_data_list`` provides ``"query"`` (passed as
        ``input_ids``), ``"query_len"`` (a one-element tensor) and ``"id"``
        (an index into ``context_list``). The prediction is stored in
        ``context_pred_refs_dict[context]["pred"]``, which must still be
        ``"[No answer]"``. At most ``max_test_step`` examples are generated
        (all when -1). Finally ``refs.txt`` (each context's references one per
        line, a blank line after each context) and ``pred.txt`` (one
        prediction per line) are written to ``generation_dir`` in the dict's
        iteration order.
        """
        self._model.eval()
        with torch.no_grad():
            for step, data in enumerate(tqdm(test_data_list)):
                output = self._model.generate(
                    input_ids=data["query"],
                    max_length=data["query_len"].item() + 64 + self._prefix_len,
                    do_sample=False,
                    num_beams=10,
                    no_repeat_ngram_size=4,
                    length_penalty=0.9,
                )
                # Assumes the decoded text contains "<|endoftext|>" between the
                # query and the answer (TODO confirm); the answer is the segment
                # after the first separator, truncated at its first newline.
                text = tokenizer.decode(output.tolist()[0])
                pred = text.split('<|endoftext|>')[1].split('\n', 1)[0].strip()
                context = context_list[data["id"]]
                assert context_pred_refs_dict[context]["pred"] == "[No answer]"
                context_pred_refs_dict[context]["pred"] = pred
                # `step` is 0-based: compare step + 1 so exactly max_test_step
                # examples are generated, the same way _evaluate counts batches.
                if step + 1 == self._max_test_step:
                    break

        with open(os.path.join(generation_dir, "refs.txt"), 'w', encoding='utf8') as refs_writer, \
                open(os.path.join(generation_dir, "pred.txt"), 'w', encoding='utf8') as pred_writer:
            for entry in context_pred_refs_dict.values():
                for ref in entry['refs']:
                    refs_writer.write(ref + '\n')
                refs_writer.write('\n')
                pred_writer.write(entry['pred'] + '\n')


"""
        self._logger.info(self._teacher_model)
        for n,p in self._teacher_model.named_parameters():
            self._logger.info(n+" "+str(p.requires_grad) + " " + str(p.shape))
        self._logger.info(self._student_model)
        for n,p in self._student_model.named_parameters():
            self._logger.info(n+" "+str(p.requires_grad) + " " + str(p.shape))

        self._teacher_model._freeze_param(["model", "prompt_encoder"])
        self._student_model._freeze_param(["prefix_encoder"])
        self._logger.info("\n\nfreeze_teacher\t\t")
        for n,p in self._teacher_model.named_parameters():
            self._logger.info(n+" "+str(p.requires_grad) + " " + str(p.shape))
        self._logger.info("\n\nfreeze_studuent\t\t")
        for n,p in self._student_model.named_parameters():
            self._logger.info(n+" "+str(p.requires_grad) + " " + str(p.shape))
"""

"""

        self._logger.info(str(r) +  str(client_id))
        for n in raw_state_dict:
            self._logger.info(n + " " +  str(raw_state_dict[n].shape) + " " + str(raw_state_dict[n].requires_grad))
        self._logger.info("named")
        for n,p in self._model.named_parameters():
            self._logger.info(n+" "+str(p.requires_grad) + " " + str(p.shape))
"""