import argparse
import re
import os
import os.path as osp

import torch
from tqdm import tqdm

from datasets.kp20k import load_data
from config import *
from transformers import BitsAndBytesConfig, AutoConfig, AutoTokenizer
from logger import FileLogger
from model import LlamaForCausalLM
from peft import PeftModel
import torch.nn.functional as F


class Ranker:
    """Label keyphrase candidates with a LoRA-adapted causal LM and write predictions.

    For every test example the model decodes one label token per candidate
    ('T' or 'F'); candidates labelled 'T' are written to the prediction file and
    per-candidate label accuracy is logged at the end.

    NOTE(review): relies on the module-level globals ``args``, ``log`` and
    ``out_filepath`` that are created in the ``__main__`` block.
    """

    def __init__(self):
        """Load config/tokenizer, the test split and the 4-bit base model + adapter.

        Raises:
            FileNotFoundError: if ``args.ckpt_model_path`` does not exist. This is
                checked first so a bad path fails before any data/weights load.
        """
        if not osp.exists(args.ckpt_model_path):
            raise FileNotFoundError(f"No checkpoint found at {args.ckpt_model_path}!")

        self.config = AutoConfig.from_pretrained(
            args.llm_base_tokenizer, cache_dir=args.llm_cache_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.llm_base_tokenizer, cache_dir=args.llm_cache_dir)
        # First token id of each label word; index 0 is 'T', index 1 is 'F'.
        self.valid_label_ids = [
            self.tokenizer(label, add_special_tokens=False)['input_ids'][0]
            for label in ('T', 'F')
        ]

        self.test_loader, _ = load_data(args, self.config, self.tokenizer, split="test")

        log.console(f"Loading model checkpoint from {args.ckpt_model_path}...")
        # Base model is loaded with 4-bit NF4 quantization (double quant, bf16 compute).
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        base_model = LlamaForCausalLM.from_pretrained(
            args.llm_base_model,
            quantization_config=bnb_config,
            device_map='auto',
            cache_dir=args.llm_cache_dir,
        )
        # Attach the trained PEFT adapter stored at the checkpoint path.
        self.model = PeftModel.from_pretrained(base_model, args.ckpt_model_path)

    @torch.no_grad()
    def rank(self):
        """Decode labels for the whole test set, write predictions and log accuracy.

        Writes one line per example to the module-level ``out_filepath``: the
        ';'-joined candidates whose predicted label is 'T'. Then logs the dict
        returned by ``calculate_scores``.
        """
        self.model.eval()

        # Boolean vocabulary mask: True only for the 'T'/'F' token ids, so greedy
        # decoding can only ever emit one of the two labels.
        vocab_mask = torch.zeros(self.config.vocab_size, dtype=torch.bool, device=args.device)
        vocab_mask[self.valid_label_ids] = True

        all_preds = []
        all_labels = []
        # `with` guarantees the prediction file is closed even if decoding raises.
        with open(out_filepath, "w", encoding="utf-8") as f:
            with tqdm(desc="Ranking", total=len(self.test_loader), ncols=100) as pbar:
                for inputs in self.test_loader:
                    for k, v in inputs.items():
                        inputs[k] = v.to(args.device)

                    candidates_number = inputs['candidates_number']
                    max_length = candidates_number.max().item()
                    outputs = self.rank_step(inputs, vocab_mask, max_length=max_length,
                                             candidates_number=candidates_number)
                    assert outputs.shape[1] == max_length
                    # Positions at or beyond each example's candidate count are padding.
                    outputs_mask = (torch.arange(max_length, device=outputs.device)[None, :]
                                    < candidates_number[:, None])
                    outputs_mask = outputs_mask.squeeze(1)
                    outputs = outputs.masked_fill(~outputs_mask, -100)

                    all_preds.append(outputs)
                    all_labels.append(inputs['labels'])

                    pred_kp_list = self._select_keyphrases(inputs['input_ids'], outputs)
                    self.write_pred_to_file(pred_kp_list, f)

                    pbar.update(1)

        if not all_preds:
            log.console("Test loader is empty; nothing to score.")
            return
        all_preds = self.pad_fn(all_preds, padding=-100)
        all_labels = self.pad_fn(all_labels, padding=-100)
        scores = self.calculate_scores(all_preds, all_labels)
        log.console(scores)

    def _select_keyphrases(self, input_ids, label_ids):
        """Return, per example, the prompt's candidates whose predicted label is 'T'.

        Candidates are paired with labels by position. If the prompt lists fewer
        candidates than there are label positions, the extra positions are ignored
        instead of raising IndexError.
        """
        input_strs = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        keep_rows = (label_ids == self.valid_label_ids[0]).tolist()
        return [
            [cand for cand, keep in zip(self._parse_candidates(text), keep_row) if keep]
            for text, keep_row in zip(input_strs, keep_rows)
        ]

    @staticmethod
    def _parse_candidates(input_str):
        """Extract the candidate phrases listed between "Candidates:" and "Final set".

        Candidate lines look like ``"[idx] phrase"``. Only the text after the
        *first* ``]`` is kept, so phrases that themselves contain ``]`` stay intact.
        Lines without a ``]`` are skipped. Returns ``[]`` when there is no
        "Candidates:" section or the section is empty.
        """
        _, found, tail = input_str.partition("Candidates:")
        if not found:
            return []
        section = tail.split("Final set")[0].strip()
        candidates = []
        for line in section.split("\n"):
            _, bracket, phrase = line.partition("]")
            if bracket:
                candidates.append(phrase.strip())
        return candidates

    def rank_step(self, inputs, vocab_mask, max_length=None, candidates_number=None, min_true=5):
        """Greedily decode one label token per candidate position using the KV cache.

        Args:
            inputs: batch dict; reads ``input_ids`` and ``attention_mask``.
            vocab_mask: bool tensor over the vocabulary, True for allowed tokens.
            max_length: number of label tokens to decode; defaults to
                ``self.config.max_length``.
            candidates_number: optional per-example candidate counts. When given,
                the "at least ``min_true`` T" fallback only counts and flips
                positions of real candidates, not padding. When omitted, every
                decoded position is considered (original behaviour).
            min_true: minimum number of 'T' labels enforced per example. When an
                example has fewer, its most 'T'-probable positions are set to 'T'.

        Returns:
            Tensor of shape ``(batch, max_length)`` with the decoded label ids.
            If ``max_length <= 0``, the tensor has shape ``(batch, 0)``.

        NOTE(review): ``position_ids`` are not passed to the model. If prompts are
        left-padded, the cached decoding steps may use shifted positions. Confirm
        the loader's padding side.
        """
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        steps = max_length if max_length is not None else self.config.max_length
        batch_size = input_ids.shape[0]
        true_id = self.valid_label_ids[0]

        if steps <= 0:
            # Nothing to decode, e.g. every example in the batch has 0 candidates.
            return input_ids.new_empty((batch_size, 0))

        past_key_values = None
        step_tokens = []
        step_logits = []
        for _ in range(steps):
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask,
                                 past_key_values=past_key_values, use_cache=True,
                                 return_dict=True)
            # The first step feeds the whole prompt; later steps feed one token,
            # so the last position always holds the next-token logits.
            next_logits = outputs['logits'][:, -1, :]
            valid_logits = next_logits.masked_fill(~vocab_mask, -float('inf'))
            next_token = torch.argmax(valid_logits, dim=-1).unsqueeze(-1)

            step_tokens.append(next_token)
            step_logits.append(valid_logits)

            # Only the new token is fed next; the prefix lives in the KV cache.
            past_key_values = outputs.past_key_values
            input_ids = next_token
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((batch_size, 1))], dim=-1)

        pred_ids = torch.cat(step_tokens, dim=-1)            # (batch, steps)
        all_valid_logits = torch.stack(step_logits, dim=1)   # (batch, steps, vocab)

        if candidates_number is not None:
            counts = candidates_number.reshape(-1).tolist()
        else:
            counts = [steps] * batch_size

        # Ensure every example has at least `min_true` 'T' labels among its real
        # candidates; otherwise force the top-k most 'T'-probable positions to 'T'.
        for i, n_real in enumerate(counts):
            n_real = min(int(n_real), steps)
            if n_real <= 0:
                continue
            if (pred_ids[i, :n_real] == true_id).sum() < min_true:
                prob = torch.softmax(all_valid_logits[i, :n_real], -1)
                k = min(min_true, n_real)
                topk = torch.topk(prob[:, true_id], k).indices
                pred_ids[i, topk] = true_id

        return pred_ids

    def write_pred_to_file(self, pred_kp_list, _f):
        """Write one line per example: its selected keyphrases joined by ';'."""
        for pred_kp_l in pred_kp_list:
            line = ";".join(pred_kp_l)
            _f.write(f"{line.strip()}\n")

    def calculate_scores(self, preds, labels, eos_token_id=2):
        """Compute per-candidate label accuracy, averaged over examples.

        Args:
            preds: predicted label ids, -100 at padding positions.
            labels: gold label ids, -100 at padding. Occurrences of
                ``eos_token_id`` are ignored, and the final column is dropped
                before comparing with ``preds``.
            eos_token_id: label id to ignore. The default of 2 keeps the
                previously hard-coded value. Presumably this is the tokenizer's
                EOS id; TODO confirm.

        Returns:
            ``{'Acc', 'Acc@T', 'Acc@F'}`` as floats, or None if either input is None.
            An example with no gold position of the relevant kind (for example,
            no gold 'T') is excluded from that metric's mean instead of counting
            as 0. A metric with no eligible example at all is 0.0.
        """
        if preds is None or labels is None:
            return None

        labels = labels.masked_fill(labels == eos_token_id, -100)[:, :-1]
        correct = preds == labels
        label_mask = labels != -100
        label_T_mask = labels == self.valid_label_ids[0]
        label_F_mask = labels == self.valid_label_ids[1]

        return {
            'Acc': self._mean_accuracy(correct, label_mask),
            'Acc@T': self._mean_accuracy(correct, label_T_mask),
            'Acc@F': self._mean_accuracy(correct, label_F_mask),
        }

    @staticmethod
    def _mean_accuracy(correct, mask):
        """Mean over rows of ``(correct & mask).sum / mask.sum``, skipping rows with an empty mask."""
        counts = mask.sum(-1)
        defined = counts > 0
        if not bool(defined.any()):
            return 0.0
        hits = (correct & mask).sum(-1)
        return (hits[defined].float() / counts[defined].float()).mean().item()

    def pad_fn(self, lst, padding=0):
        """Right-pad tensors to a common last-dim length and merge them.

        1-D tensors are stacked into ``(len(lst), max_len)``. 2-D tensors are
        concatenated along dim 0. Tensors of any other rank are returned as the
        original list, unpadded. An empty list is returned as-is. The caller's
        list is never modified.
        """
        if len(lst) == 0:
            return lst
        max_len = max(x.shape[-1] for x in lst)
        ndim = lst[0].dim()
        if ndim not in (1, 2):
            return lst
        padded = [F.pad(x, (0, max_len - x.shape[-1]), "constant", padding) for x in lst]
        if ndim == 1:
            return torch.stack(padded, dim=0)
        return torch.cat(padded, dim=0)  # (total_rows, max_len)
    

if __name__ == "__main__":
    set_template(args)
    # Default checkpoint location: <EXPERIMENT_ROOT>/<base model name>/kp20k/bset-f1
    # NOTE(review): "bset-f1" looks like a typo for "best-f1". It is kept as-is
    # because it must match the directory the training run wrote; confirm.
    if args.ckpt_model_path is None:
        model_name = args.llm_base_model.split('/')[-1]
        export_root = osp.join(EXPERIMENT_ROOT, model_name, "kp20k")
        args.ckpt_model_path = osp.join(export_root, "bset-f1")

    os.makedirs(args.test_output_dir, exist_ok=True)

    # `log` and `out_filepath` are module-level globals that Ranker reads.
    log = FileLogger(args.test_output_dir, is_master=True, is_rank0=True, log_to_file=True)
    log.console(args)

    ranker = Ranker()
    out_filepath = osp.join(args.test_output_dir, "predictions.txt")
    log.console('Writing predictions to {}'.format(out_filepath))
    ranker.rank()