import logging
import re
import json
import os.path as osp
from collections import defaultdict

import nltk
from nltk.stem.porter import *
import numpy as np
import random

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
from sklearn.cluster import KMeans
from datasets.cluster import kmeans_clustering, dbscan_clustering
from sentence_transformers import SentenceTransformer


def seed_worker(worker_id):
    """Seed NumPy and the stdlib ``random`` module from torch's initial seed.

    Passed as ``worker_init_fn`` to the ``DataLoader`` objects built in this
    module. The torch seed is reduced modulo 2**32 before it is handed to
    NumPy and ``random``.

    Args:
        worker_id: Worker index supplied by the caller; it is not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)


def load_data(args, config, tokenizer, split="train"):
    """Build the ``KeyphraseDataset`` for ``split`` and wrap it in a ``DataLoader``.

    Args:
        args: Run options. Reads ``distributed``, ``num_workers`` and the
            batch size matching the split (``train_batch_size``,
            ``val_batch_size`` or ``test_batch_size``); also forwarded to
            ``KeyphraseDataset``.
        config: Forwarded to ``KeyphraseDataset``.
        tokenizer: Forwarded to ``KeyphraseDataset``.
        split: One of ``"train"``, ``"valid"`` or ``"test"``.

    Returns:
        ``(dataloader, sampler)``. ``sampler`` is a ``DistributedSampler``
        for the train/valid splits when ``args.distributed`` is truthy,
        otherwise ``None``.

    Raises:
        ValueError: If ``split`` is not train/valid/test. The dataset is
            constructed before this check is made.
    """
    dataset = KeyphraseDataset(args, config, tokenizer, split)

    # The test loader is the plain one: no sampler, workers or pinned memory.
    if split == "test":
        test_loader = DataLoader(dataset,
                                 batch_size=args.test_batch_size,
                                 collate_fn=dataset.llama_collate_fn,
                                 shuffle=False,
                                 drop_last=False)
        return test_loader, None

    if split not in ("train", "valid"):
        raise ValueError("Data split must be either train/valid/test.")

    sampler = None
    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)

    training = split == "train"
    batch_size = args.train_batch_size if training else args.val_batch_size
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        collate_fn=dataset.llama_collate_fn,
                        worker_init_fn=seed_worker,
                        num_workers=args.num_workers,
                        sampler=sampler,
                        # Only the training loader shuffles, and only when no
                        # sampler is in charge of the ordering.
                        shuffle=training and sampler is None,
                        # Only the training loader drops the last partial batch.
                        drop_last=training,
                        pin_memory=True)
    return loader, sampler


class KeyphraseDataset(Dataset):
    """Dataset of tokenized keyphrase-selection prompts backed by a JSON-lines cache.

    If the cache file does not exist yet it is built from the raw
    ``{split}_src.txt`` / ``{split}_trg.txt`` / ``{split}_predictions_*.txt``
    files (one JSON record per line). Only the byte offset of every record is
    kept in memory; records are parsed on demand in ``__getitem__``.
    """

    # NOTE(review): the example inside the prompt numbers candidates from [1],
    # while the candidate listing built below starts at [0]. Left untouched
    # because changing either side alters the cached data — confirm intent.
    _PROMPT_TEMPLATE = "Select the final set of keyphrases for a given document from the candidate keyphrases marked by numbers. T signifies choosing the keyphrase, and F signifies discarding it. For example, T F F indicates choose candidate [1] while discard candidates [2] and [3]. The final set should be semantically non-redundant.\nDocument: {}\nCandidates:\n{}Final set:"

    # Compiled once instead of on every phrase.
    _MULTI_SPACE_RE = re.compile(r"\s{2,}")
    _WORD_SPLIT_RE = re.compile(r"[ -]")

    def __init__(self, args, config, tokenizer, split="train"):
        """Resolve the cache file, build it if missing, and index its lines.

        Args:
            args: Run options. Reads ``few_shot``, ``absent``, ``random`` and
                ``data_dir`` here; building the cache additionally reads
                ``cluster`` and ``llm_max_text_len``.
            config: Object exposing ``max_position_embeddings``.
            tokenizer: Callable as ``tokenizer(text, add_special_tokens=False)``
                returning an object with ``input_ids``; must also expose
                ``bos_token_id`` and ``eos_token_id``.
            split: Split name used in all file names.

        Raises:
            ValueError: If no cache file name is defined for the requested
                combination (train split, ``few_shot != -1``, ``random`` off).
        """
        self.args = args
        self.config = config
        self.tokenizer = tokenizer
        self.split = split
        self.max_pos_num = config.max_position_embeddings
        self.few_shot = args.few_shot

        absent_str = "_absent" if args.absent else ""
        if self.args.random:
            self.save_file = osp.join(args.data_dir, f"{split}_random{absent_str}.json")
        elif self.few_shot == -1 or split != "train":
            self.save_file = osp.join(args.data_dir, f"{split}{absent_str}.json")
        else:
            # Previously this fell through without setting ``save_file`` and
            # crashed later with an opaque AttributeError.
            raise ValueError(
                f"No cache file is defined for split='train' with few_shot={self.few_shot} "
                "when args.random is disabled; use few_shot=-1 or enable args.random."
            )

        self.stemmer = PorterStemmer()

        if not osp.exists(self.save_file):
            self.__load_and_cache_examples()

        self.offset_dict = self._build_offset_index(self.save_file)

    @staticmethod
    def _build_offset_index(path):
        """Return ``{line_index: byte_offset}`` for every line of ``path``."""
        offsets = {}
        position = 0
        # Binary mode so that offsets are exact byte positions.
        with open(path, "rb") as cache_f:
            for index, raw_line in enumerate(cache_f):
                offsets[index] = position
                position += len(raw_line)
        return offsets

    @classmethod
    def _split_phrases(cls, line):
        """Split a ``;``-separated line into phrases with collapsed whitespace.

        Empty pieces are dropped before normalisation, so a whitespace-only
        piece survives as ``""`` (this keeps candidate positions stable).
        """
        return [cls._MULTI_SPACE_RE.sub(" ", piece).strip() for piece in line.split(";") if piece]

    def __load_and_cache_examples(self):
        '''
        Build the JSON-lines cache at ``self.save_file``.

        inputs (one line per document, read in lockstep):
        {
            "src"         document text
            "candidates"  ``;``-separated candidate keyphrases
            "trg"         ``;``-separated gold keyphrases, optionally split
                          into present/absent parts by ``<peos>``
        }
        outputs (one JSON object per kept document):
        {
            "input_ids"
            "attention_mask"
            "labels"
            "target_mask"
            "candidates_number"
            "targets_number"
        }

        The cache is written to a per-process temporary file and moved into
        place only after every line was processed, so an interrupted run never
        leaves a partial cache that a later run would pick up.
        '''
        import os  # local import: the module itself only imports ``os.path``

        is_train = self.split == "train"

        # train/valid sources sit in data_dir; any other split keeps them one
        # level up (data/inspec/test_beam10 --> data/inspec).
        if self.split in ("train", "valid"):
            root_data_dir = self.args.data_dir
        else:
            root_data_dir = osp.join(self.args.data_dir, "..")
        src_path = osp.join(root_data_dir, f"{self.split}_src.txt")
        trg_path = osp.join(root_data_dir, f"{self.split}_trg.txt")
        if self.args.absent:
            candidate_path = osp.join(self.args.data_dir, f"{self.split}_predictions_absent.txt")
        else:
            candidate_path = osp.join(self.args.data_dir, f"{self.split}_predictions_present.txt")

        # The embedding model is only needed for clustering, so skip the load
        # otherwise. Shuffling takes precedence over clustering.
        use_cluster = (not self.args.random) and self.args.cluster
        model = SentenceTransformer("data/PLM/keyphrase-mpnet-v1") if use_cluster else None

        stem = self.stemmer.stem
        split_words = self._WORD_SPLIT_RE.split

        empty_trg_line = 0
        long_line = 0
        empty_candidate_line = 0

        tmp_path = f"{self.save_file}.{os.getpid()}.tmp"
        try:
            with open(src_path) as src_f, open(trg_path) as trg_f, \
                    open(candidate_path) as candidate_f, open(tmp_path, "w") as process_data_f:
                # NOTE: zip stops at the shortest file; the three files are
                # expected to be line-aligned.
                for line_i, (src_line, trg_line, candidate_line) in enumerate(zip(src_f, trg_f, candidate_f)):
                    if line_i % 10000 == 0:
                        print(f"Processing {line_i}th line")

                    src_line = src_line.strip()
                    # Filter empty lines
                    if src_line == "":
                        continue

                    trg_list = self._split_phrases(trg_line.strip())
                    candidates_list = self._split_phrases(candidate_line.strip())

                    # Keep only the present or the absent half of the targets.
                    if "<peos>" in trg_list:
                        peos_idx = trg_list.index("<peos>")
                        if self.args.absent:
                            trgs = trg_list[peos_idx + 1:]
                        else:
                            trgs = trg_list[:peos_idx]
                    else:
                        trgs = trg_list

                    if is_train and not trgs:
                        empty_trg_line += 1
                        continue

                    trgs_stem = [
                        " ".join(stem(w.lower().strip()) for w in split_words(trg.strip()))
                        for trg in trgs
                    ]

                    # Lowercase, detach possessive 's and drop a trailing s'.
                    candidates_list = [
                        c.lower().replace("'s", " 's").replace("s'", "s")
                        for c in candidates_list
                    ]

                    if self.args.random:
                        random.shuffle(candidates_list)
                    elif self.args.cluster:
                        candidates_list = dbscan_clustering(model, candidates_list)

                    candidates_list_stem = [
                        " ".join(stem(w) for w in split_words(c.strip()))
                        for c in candidates_list
                    ]
                    candidates_str = "".join(
                        f"[{i}] {candidate}\n" for i, candidate in enumerate(candidates_list)
                    )

                    # One T/F flag per candidate: T when its stemmed form is a
                    # stemmed target.
                    target_stems = set(trgs_stem)
                    is_target = [c in target_stems for c in candidates_list_stem]
                    answer_str = " ".join("T" if hit else "F" for hit in is_target)
                    candidate_true_number = sum(is_target)

                    # Training examples without a single correct candidate
                    # (which includes the no-candidate case) are dropped.
                    if is_train and candidate_true_number == 0:
                        empty_candidate_line += 1
                        continue

                    input_str = self._PROMPT_TEMPLATE.format(src_line, candidates_str)
                    input_ids = [self.tokenizer.bos_token_id]
                    input_ids += self.tokenizer(input_str, add_special_tokens=False).input_ids
                    target_mask = [0] * len(input_ids)

                    answer_ids = self.tokenizer(answer_str, add_special_tokens=False).input_ids
                    if is_train:
                        # Training: the answer and EOS are appended to the
                        # prompt and flagged with 1 in target_mask.
                        input_ids += answer_ids
                        input_ids += [self.tokenizer.eos_token_id]
                        target_mask += [1] * (len(input_ids) - len(target_mask))
                        labels = input_ids.copy()
                    else:
                        # Evaluation: input is the prompt only, labels are the
                        # answer followed by EOS.
                        labels = answer_ids + [self.tokenizer.eos_token_id]

                    # Over-long examples are dropped everywhere except in the
                    # test split.
                    if len(input_ids) > self.args.llm_max_text_len and self.split != "test":
                        long_line += 1
                        continue

                    attention_mask = [1] * len(input_ids)
                    assert len(input_ids) == len(target_mask) == len(attention_mask)
                    result = {
                        "input_ids": input_ids,
                        "attention_mask": attention_mask,
                        "labels": labels,
                        "target_mask": target_mask,
                        "candidates_number": len(candidates_list),
                        "targets_number": len(trgs_stem)
                    }

                    json.dump(result, process_data_f)
                    process_data_f.write("\n")

            # Publish the finished cache in one step.
            os.replace(tmp_path, self.save_file)
        finally:
            # Only still present when something failed before the replace.
            if osp.exists(tmp_path):
                os.remove(tmp_path)

        print(f"Empty trg line: {empty_trg_line}")
        print(f"Empty candidate line: {empty_candidate_line}")
        print(f"Long line: {long_line}")

    def llama_collate_fn(self, batch):
        """Left-pad a list of cached records into batched ``LongTensor`` objects.

        ``input_ids``, ``attention_mask`` and ``target_mask`` are left-padded
        with 0 to the longest input in the batch. ``labels`` are left-padded
        with -100 to the same length for the train split; for other splits
        they are right-padded with -100 to the longest label sequence.

        Args:
            batch: Non-empty list of records as returned by ``__getitem__``.

        Returns:
            Dict of tensors with keys ``input_ids``, ``attention_mask``,
            ``labels``, ``target_mask`` (one row per record) and
            ``candidates_number``, ``targets_number`` (shape ``(len(batch), 1)``).

        Raises:
            ValueError: If a record is longer than ``self.max_pos_num``
                (or if ``batch`` is empty).
        """
        example_max_length = max(len(example["input_ids"]) for example in batch)
        max_length = min(self.max_pos_num, example_max_length)
        max_label_length = max(len(example["labels"]) for example in batch)
        is_train = self.split == "train"

        columns = {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
            "target_mask": [],
            "candidates_number": [],
            "targets_number": [],
        }
        for example in batch:
            input_ids = example["input_ids"]
            if len(input_ids) > max_length:
                raise ValueError(f"Input length {len(input_ids)} is larger than max length {max_length}")

            padding_length = max_length - len(input_ids)
            left_pad = [0] * padding_length
            labels = example["labels"]
            if is_train:
                labels = [-100] * padding_length + labels
            else:
                labels = labels + [-100] * (max_label_length - len(labels))

            columns["input_ids"].append(torch.tensor(left_pad + input_ids).long())
            columns["attention_mask"].append(torch.tensor(left_pad + example["attention_mask"]).long())
            columns["labels"].append(torch.tensor(labels).long())
            columns["target_mask"].append(torch.tensor(left_pad + example["target_mask"]).long())
            columns["candidates_number"].append(torch.tensor(example["candidates_number"]).long())
            columns["targets_number"].append(torch.tensor(example["targets_number"]).long())

        return {name: torch.vstack(tensors) for name, tensors in columns.items()}

    def __len__(self):
        """Return the number of cached records."""
        return len(self.offset_dict)

    def __getitem__(self, line):
        """Parse and return the cached record at index ``line``.

        Raises:
            KeyError: If ``line`` is not a valid record index.
        """
        offset = self.offset_dict[line]
        # Binary mode: the stored offsets are byte positions.
        with open(self.save_file, "rb") as cache_f:
            cache_f.seek(offset)
            return json.loads(cache_f.readline())
