import torch.nn as nn
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, RobertaForMaskedLM
from transformers.optimization import get_linear_schedule_with_warmup, AdamW
import pytorch_lightning as pl
import numpy as np

from util import f1_score
from functools import partial
from typing import List, TypedDict

# Llama-2 chat prompt markers, used by KnowEnhancer.llm_generate on its 'llama2' path:
# B_INST / E_INST wrap the text of each user turn, and B_SYS / E_SYS wrap the system
# prompt that gets prepended to the first non-system message.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


class ChatFormat:
    """Turn a chat dialog into a flat list of token ids using header markers.

    Every message is rendered as
    ``<|start_header_id|>{role}<|end_header_id|>\\n\\n{content}<|eot_id|>``
    and a whole dialog is prefixed with ``<|begin_of_text|>``.

    The prompt is assembled from many separately tokenized fragments, so each
    fragment is encoded with ``add_special_tokens=False``.  Otherwise a
    tokenizer that automatically prepends/appends special tokens (the default
    for Hugging Face tokenizers' ``encode``) would inject them in front of
    every fragment and duplicate the explicit ``<|begin_of_text|>`` marker.
    """

    def __init__(self, tokenizer):
        """Store the tokenizer.

        Args:
            tokenizer: Object exposing ``encode(text, add_special_tokens=...)``
                that returns an iterable of token ids.
        """
        self.tokenizer = tokenizer

    def _encode(self, text: str) -> List[int]:
        """Tokenize one prompt fragment without automatically added special tokens."""
        return self.tokenizer.encode(text, add_special_tokens=False)

    def encode_header(self, message) -> List[int]:
        """Return the token ids of the role header of *message*.

        Only ``message["role"]`` is read.
        """
        tokens = []
        for fragment in ("<|start_header_id|>", message["role"], "<|end_header_id|>", "\n\n"):
            tokens.extend(self._encode(fragment))
        return tokens

    def encode_message(self, message) -> List[int]:
        """Return header + stripped content + ``<|eot_id|>`` for one message."""
        tokens = self.encode_header(message)
        tokens.extend(self._encode(message["content"].strip()))
        tokens.extend(self._encode("<|eot_id|>"))
        return tokens

    def encode_dialog_prompt(self, dialog) -> List[int]:
        """Return the token ids of a full prompt for *dialog*.

        Args:
            dialog: Iterable of messages, each a mapping with ``"role"`` and
                ``"content"`` string entries.

        Returns:
            Token ids for ``<|begin_of_text|>``, every message in order, and a
            trailing assistant header for the model to complete.
        """
        tokens = []
        tokens.extend(self._encode("<|begin_of_text|>"))
        for message in dialog:
            tokens.extend(self.encode_message(message))
        # Add the start of an assistant message for the model to complete.
        tokens.extend(self.encode_header({"role": "assistant", "content": ""}))
        return tokens


class KnowEnhancer(object):
    """Wrap a causal language model and its tokenizer for sampled chat completion.

    The prompt template and the sampling temperature / top-p are selected by
    testing whether the substring ``'llama2'`` occurs in ``model_name_or_path``.
    """

    def __init__(self, model_name_or_path, dtype, device):
        """Load the model and tokenizer and move the model to *device*.

        Args:
            model_name_or_path: Identifier or path handed to ``from_pretrained``.
            dtype: Passed as ``torch_dtype`` when loading the model.
            device: Device the model is moved to and prompts are sent to.
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.device = device
        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=dtype,
                                                          use_auth_token=True, device_map=None)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model.to(device)

        # Sampling hyper-parameters depend on the model family.
        if 'llama2' in self.model_name_or_path:
            self.temperature = 0.6
            self.top_p = 0.9
        else:
            self.temperature = 0.9
            self.top_p = 1.0

    @staticmethod
    def _llama2_prompt_text(dialog) -> str:
        """Render *dialog* as a single Llama-2 style ``[INST] ... [/INST]`` prompt string."""
        # Fold a leading system message into the following message as
        # "<<SYS>>\n{system_prompt}\n<</SYS>>\n\n{content}".
        if dialog[0]["role"] == "system":
            dialog = [
                {
                    "role": dialog[1]["role"],
                    "content": B_SYS + dialog[0]["content"] + E_SYS + dialog[1]["content"],
                }
            ] + dialog[2:]
        # check roles
        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
            [msg["role"] == "assistant" for msg in dialog[1::2]]
        ), (
            "model only supports 'system', 'user' and 'assistant' roles, "
            "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
        )
        # Completed (user, assistant) exchanges form the chat history.
        texts = [
            f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} "
            for prompt, answer in zip(dialog[::2], dialog[1::2])
        ]
        # The final message must be an open user turn; it closes the prompt.
        assert (
            dialog[-1]["role"] == "user"
        ), f"Last message must be from user, got {dialog[-1]['role']}"
        texts.append(f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}")
        return "".join(texts)

    @torch.inference_mode()
    def llm_generate(self, dialog: List[TypedDict], max_new_tokens=128) -> str:
        """Sample one assistant reply for *dialog*.

        Args:
            dialog: Messages, each a mapping with ``"role"`` and ``"content"``.
                On the ``'llama2'`` path an optional leading ``"system"``
                message is followed by alternating user/assistant messages
                ending with a user message.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded text of the newly generated tokens, with surrounding
            whitespace and leftover template markers removed.

        Raises:
            AssertionError: On the ``'llama2'`` path, when roles do not
                alternate user/assistant or the last message is not a user
                message.
        """
        is_llama2 = 'llama2' in self.model_name_or_path

        if is_llama2:
            prompt_text = self._llama2_prompt_text(dialog)
            text_tokens = self.tokenizer([prompt_text], return_tensors="pt").input_ids
        else:
            chat_format = ChatFormat(self.tokenizer)
            text_tokens = torch.tensor([chat_format.encode_dialog_prompt(dialog)])

        inputs = text_tokens.to(self.device)
        prompt_tokens_len = len(inputs[0])
        generate_kwargs = dict(
            inputs=inputs,
            max_new_tokens=max_new_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=40,
            repetition_penalty=1.0,
            do_sample=True,
        )
        output_ids = self.model.generate(**generate_kwargs)
        # Decode only the tokens that follow the prompt.
        output = self.tokenizer.decode(output_ids[0][prompt_tokens_len:], skip_special_tokens=True)

        # Strip chat-template markers that may remain in the decoded text.
        if is_llama2:
            response = output.strip().replace('</s>', "").strip()
        else:
            response = output.strip().replace('<|eot_id|>', "").replace(
                '<|start_header_id|>assistant<|end_header_id|>\n\n', '').strip()

        return response


class Extractor(pl.LightningModule):
    """Lightning module that classifies via the mask-token logits of a masked LM.

    For every input sequence the logits at the tokenizer's mask position are
    restricted to one vocabulary entry per class (``self.word2label``); those
    restricted logits are trained with cross-entropy plus a weighted auxiliary
    loss computed from two extra encoded inputs and the class-token embedding.
    """

    def __init__(self, model, tokenizer, class_tokens, label_word_idx, na_rel,
                 lr=5e-5, weight_decay=0.01, warmup_ratio=0.1, aux_ratio=0.1):
        """Set up class-token embeddings, losses, metric and bookkeeping.

        Args:
            model: Masked LM exposing ``resize_token_embeddings``,
                ``get_input_embeddings`` and ``get_output_embeddings``.
            tokenizer: Tokenizer matching *model*; must define ``mask_token_id``.
            class_tokens: One token string per class; the first token id each
                one maps to becomes that class's vocabulary entry.
            label_word_idx: For each class, the index/indices of embedding rows
                whose mean (over dim 0) initialises the class-token embedding.
            na_rel: Forwarded to ``f1_score`` as ``na_num``
                (presumably the id of the "no relation" class -- TODO confirm).
            lr: Learning rate for AdamW.
            weight_decay: Weight decay for parameters not in the no-decay group.
            warmup_ratio: Fraction of ``num_training_steps`` used for warmup.
            aux_ratio: Weight of the auxiliary loss in the training objective.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer

        num_labels = len(class_tokens)
        continuous_label_word = [a[0] for a in self.tokenizer(class_tokens, add_special_tokens=False)['input_ids']]
        self.model.resize_token_embeddings(len(self.tokenizer))
        with torch.no_grad():
            word_embeddings = self.model.get_input_embeddings()
            for i, idx in enumerate(label_word_idx):
                word_embeddings.weight[continuous_label_word[i]] = torch.mean(word_embeddings.weight[idx], dim=0)

        assert torch.equal(self.model.get_input_embeddings().weight, word_embeddings.weight)
        # The in-place initialisation above must also be visible through the output embeddings.
        assert torch.equal(self.model.get_input_embeddings().weight, self.model.get_output_embeddings().weight)
        self.word2label = continuous_label_word
        self.loss_fn = nn.CrossEntropyLoss()
        self.aux_fn = torch.nn.LogSigmoid()
        self.aux_ratio = aux_ratio
        self.eval_fn = partial(f1_score, rel_num=num_labels, na_num=na_rel)

        # Per-epoch buffers filled by the *_step hooks and drained by the on_*_epoch_end hooks.
        self.val_outputs = []
        self.train_outputs = []
        self.val_stats = []
        self.init_pred = None
        self.test_outputs = []
        self.best_f1 = 0

        self.lr = lr
        self.weight_decay = weight_decay
        self.warmup_ratio = warmup_ratio
        # May be assigned by the caller before training; see configure_optimizers.
        self.num_training_steps = None

    def forward(self, x):
        """Run the wrapped model on *x* and return its raw output."""
        return self.model(x)

    def pvp(self, logits, input_ids):
        """Reduce LM logits to per-class scores at the mask position.

        Converts ``[batch_size, seq_len, vocab_size]`` logits into
        ``[batch_size, num_labels]`` by selecting the mask-token position of
        every sequence and then the vocabulary columns in ``self.word2label``.

        Raises:
            AssertionError: If the number of mask tokens found differs from the
                batch size.
        """
        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bs = input_ids.shape[0]
        # Validate before indexing so a mismatch produces this message instead
        # of an opaque indexing shape error.
        assert mask_idx.shape[0] == bs, "only one mask in sequence!"
        mask_output = logits[torch.arange(bs, device=logits.device), mask_idx]
        final_output = mask_output[:, self.word2label]

        return final_output

    def get_loss(self, logits, input_ids, labels):
        """Return ``(softmax class probabilities, cross-entropy loss)`` for a batch."""
        final_output = self.pvp(logits, input_ids)
        loss = self.loss_fn(final_output, labels)
        return nn.functional.softmax(final_output, dim=1), loss

    def get_aux_loss(self, e1_input_ids, e1_attention_mask, e2_input_ids, e2_attention_mask, label_inds):
        """Compute the auxiliary loss ``-logsigmoid(0.2 - ||e1 + r - e2|| / n)``.

        ``e1`` and ``e2`` are the last hidden states of the two inputs averaged
        over dim 1, ``r`` holds the output-embedding rows selected by
        *label_inds*, the L2 norm is taken over the whole batch tensor and ``n``
        is the number of rows in ``r``.
        """
        e1_rep = torch.mean(
            self.model(e1_input_ids, e1_attention_mask, return_dict=True, output_hidden_states=True).hidden_states[-1],
            dim=1)
        e2_rep = torch.mean(
            self.model(e2_input_ids, e2_attention_mask, return_dict=True, output_hidden_states=True).hidden_states[-1],
            dim=1)

        real_relation_embedding = self.model.get_output_embeddings().weight[label_inds]
        distance = torch.norm(e1_rep + real_relation_embedding - e2_rep, p=2) / real_relation_embedding.shape[0]
        aux_loss = -1.0 * self.aux_fn(0.2 - distance)
        return aux_loss

    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        """Compute the combined loss for one batch and buffer its predictions."""
        input_ids, attention_mask, labels, e1_input_ids, e1_attention_mask, e2_input_ids, e2_attention_mask = batch
        # Only the logits are consumed here, so hidden states are not requested.
        result = self.model(input_ids, attention_mask, return_dict=True).logits
        final_output, loss = self.get_loss(result, input_ids, labels)
        # Map class indices to vocabulary ids through the same table pvp uses,
        # so this stays correct even when class-token ids are not consecutive.
        label_token_ids = torch.as_tensor(self.word2label, device=labels.device)[labels]
        aux_loss = self.get_aux_loss(e1_input_ids, e1_attention_mask, e2_input_ids, e2_attention_mask,
                                     label_token_ids)
        final_loss = loss + self.aux_ratio * aux_loss
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_aux_loss", aux_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.train_outputs.append({"train_logits": final_output.detach().cpu().numpy(),
                                   "train_labels": labels.detach().cpu().numpy()})
        return final_loss

    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        """Log the validation loss for one batch and buffer its predictions."""
        input_ids, attention_mask, labels = batch
        logits = self.model(input_ids, attention_mask, return_dict=True).logits
        final_output, loss = self.get_loss(logits, input_ids, labels)
        self.log("Eval/loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.val_outputs.append({"eval_logits": final_output.detach().cpu().numpy(),
                                 "eval_labels": labels.detach().cpu().numpy()})

    def on_train_epoch_end(self) -> None:
        """Record per-sample gold-label probabilities and the first prediction snapshot."""
        logits = np.concatenate([o["train_logits"] for o in self.train_outputs])
        labels = np.concatenate([o["train_labels"] for o in self.train_outputs])

        # Probability assigned to the gold label of every training sample.
        val_stat = logits[np.arange(len(labels)), labels].tolist()
        reload_every = self.trainer.reload_dataloaders_every_n_epochs
        epoch_number = self.current_epoch + 1
        # reload_every may be 0 (dataloaders never reloaded); guard the modulo
        # so that configuration does not raise ZeroDivisionError.
        if reload_every and epoch_number % reload_every == 0:
            print(f'save val stat in epoch {self.current_epoch}.')
            self.val_stats.append(val_stat)

        _, init_pred = self.eval_fn(logits, labels)
        if self.init_pred is None and epoch_number == reload_every:
            self.init_pred = init_pred

        self.train_outputs.clear()

    def on_validation_epoch_end(self) -> None:
        """Compute and log the validation F1 and the best F1 seen so far."""
        logits = np.concatenate([o["eval_logits"] for o in self.val_outputs])
        labels = np.concatenate([o["eval_labels"] for o in self.val_outputs])

        f1d, _ = self.eval_fn(logits, labels)
        f1 = f1d['f1']
        self.log("Eval/f1", f1)
        if f1 > self.best_f1:
            self.best_f1 = f1
        self.log("Eval/best_f1", self.best_f1, prog_bar=True, on_epoch=True)

        self.val_outputs.clear()

    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        """Buffer the per-class mask logits and labels of one test batch."""
        input_ids, attention_mask, labels = batch
        logits = self.model(input_ids, attention_mask, return_dict=True).logits
        logits = self.pvp(logits, input_ids)
        self.test_outputs.append({"test_logits": logits.detach().cpu().numpy(),
                                  "test_labels": labels.detach().cpu().numpy()})

    def on_test_epoch_end(self) -> None:
        """Compute and log the test F1 over all buffered test batches."""
        logits = np.concatenate([o["test_logits"] for o in self.test_outputs])
        labels = np.concatenate([o["test_labels"] for o in self.test_outputs])

        f1d, _ = self.eval_fn(logits, labels)
        self.log("Test/f1", f1d['f1'])

        # Drop the buffered batches so a later test run starts clean, matching
        # the train and validation hooks.
        self.test_outputs.clear()

    def configure_optimizers(self):
        """Build AdamW with a decay / no-decay split and a linear warmup schedule.

        Returns:
            Dict with the optimizer and a per-step ``lr_scheduler`` config.
        """
        no_decay_param = ["bias", "LayerNorm.weight"]
        # named_parameters() is a one-shot generator; materialise it so that
        # both groups below see every parameter (otherwise the second group is
        # empty and bias / LayerNorm weights are never updated).
        named_params = list(self.model.named_parameters())

        optimizer_group_parameters = [
            {"params": [p for n, p in named_params if not any(nd in n for nd in no_decay_param)],
             "weight_decay": self.weight_decay},
            {"params": [p for n, p in named_params if any(nd in n for nd in no_decay_param)], "weight_decay": 0}
        ]

        num_training_steps = self.num_training_steps
        if num_training_steps is None:
            # Not supplied by the caller: fall back to Lightning's own estimate
            # of the total number of optimizer steps.
            num_training_steps = self.trainer.estimated_stepping_batches

        optimizer = AdamW(optimizer_group_parameters, lr=self.lr, eps=1e-8)
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=num_training_steps * self.warmup_ratio,
                                                    num_training_steps=num_training_steps)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                'scheduler': scheduler,
                'interval': 'step',  # or 'epoch'
                'frequency': 1,
            }
        }