import pickle
from functools import total_ordering
from logging import debug
from typing import Dict, List, Optional

from transformers import AutoTokenizer

# Optional dependency: only MarisaTrie needs it.
try:
    import marisa_trie
except ModuleNotFoundError:
    pass


class Trie(object):
    """Prefix trie over token-id sequences, stored as nested dicts.

    Each node maps a token id to its child node; a node with no children ends a
    stored sequence. ``get(prefix)`` returns the token ids that may follow
    ``prefix``. A second trie can be chained with :meth:`append` (see
    :meth:`_get_from_trie` for how the two are combined).

    Note: a sequence that is a strict prefix of another stored sequence has no
    leaf of its own, so :meth:`__iter__` does not yield it.
    """

    def __init__(self, sequences: Optional[List[List[int]]] = None):
        self.trie_dict = {}
        # Counts insertions (duplicates included); ``load_from_dict`` counts leaves instead.
        self.len = 0
        # ``None`` replaces the former mutable ``[]`` default; any falsy value means "empty".
        for sequence in sequences or ():
            self.add(sequence)

        # Optional chained trie and the token id that hands over to it (set via ``append``).
        self.append_trie = None
        self.bos_token_id = None

    def append(self, trie, bos_token_id):
        """Chain ``trie`` after this one, handing over where ``bos_token_id`` would be allowed."""
        self.append_trie = trie
        self.bos_token_id = bos_token_id

    def add(self, sequence: List[int]):
        """Insert ``sequence`` and increment the insertion counter."""
        Trie._add_to_trie(sequence, self.trie_dict)
        self.len += 1

    def get(self, prefix_sequence: List[int]):
        """Return the list of token ids allowed after ``prefix_sequence`` (may be empty)."""
        return Trie._get_from_trie(
            prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id
        )

    @staticmethod
    def load_from_dict(trie_dict):
        """Wrap an existing nested-dict trie; ``len`` becomes its number of leaves."""
        trie = Trie()
        trie.trie_dict = trie_dict
        trie.len = sum(1 for _ in trie)
        return trie

    @staticmethod
    def _add_to_trie(sequence: List[int], trie_dict: Dict):
        """Insert ``sequence`` into ``trie_dict`` in place, creating missing nodes.

        Iterative, so long sequences neither hit the recursion limit nor pay for
        a ``sequence[1:]`` copy at every level.
        """
        node = trie_dict
        for token in sequence:
            node = node.setdefault(token, {})

    @staticmethod
    def _get_from_trie(
        prefix_sequence: List[int],
        trie_dict: Dict,
        append_trie=None,
        bos_token_id: Optional[int] = None,
    ):
        """Walk ``prefix_sequence`` down ``trie_dict`` and return the allowed next tokens.

        - Prefix fully matched: return the keys of the reached node. If a chained
          trie is set and ``bos_token_id`` is among them, it is replaced by the
          chained trie's root tokens.
        - A token is missing: delegate the unmatched remainder of the prefix to
          ``append_trie.get`` when a chained trie is set, otherwise return ``[]``.

        ``append_trie`` is tested by truthiness, i.e. via ``Trie.__len__``, so a
        chained trie with ``len == 0`` is treated as absent.
        """
        node = trie_dict
        for position, token in enumerate(prefix_sequence):
            if token in node:
                node = node[token]
            elif append_trie:
                return append_trie.get(prefix_sequence[position:])
            else:
                return []

        output = list(node.keys())
        if append_trie and bos_token_id in output:
            output.remove(bos_token_id)
            output += list(append_trie.trie_dict.keys())
        return output

    def __iter__(self):
        """Yield every stored sequence that ends at a leaf, as a list of token ids."""

        def _traverse(prefix_sequence, trie_dict):
            if trie_dict:
                for next_token, child in trie_dict.items():
                    yield from _traverse(prefix_sequence + [next_token], child)
            else:
                yield prefix_sequence

        # An empty root holds no sequences; without this guard it would yield one ``[]``.
        if not self.trie_dict:
            return iter(())
        return _traverse([], self.trie_dict)

    def __len__(self):
        return self.len

    def __getitem__(self, value):
        return self.get(value)


class MarisaTrie(object):
    """Prefix trie over token-id sequences backed by ``marisa_trie.Trie``.

    Every token id is mapped to one Unicode character (``int2char``), so a
    sequence becomes a string key; ``char2int`` is the inverse mapping.

    When ``cache_fist_branch`` is true (the misspelled name is kept for
    compatibility), the first two levels are precomputed. All sequences must then
    share a single first token, and ``get`` answers the empty prefix and that
    one-token prefix from the cache instead of querying the trie.
    """

    def __init__(
        self,
        sequences: Optional[List[List[int]]] = None,
        cache_fist_branch=True,
        max_token_id=256001,
    ):
        # Materialise once: ``sequences`` is traversed up to three times below, so a
        # generator argument would otherwise be exhausted after the first pass.
        # ``None`` replaces the former mutable ``[]`` default.
        sequences = [] if sequences is None else list(sequences)

        # Ids below 55000 map to the same code point; larger ids are shifted up by
        # 10000, skipping code points 55000-64999 (presumably to avoid the UTF-16
        # surrogate range U+D800-U+DFFF, which cannot be UTF-8 encoded).
        self.int2char = [chr(i) for i in range(min(max_token_id, 55000))] + (
            [chr(i) for i in range(65000, max_token_id + 10000)]
            if max_token_id >= 55000
            else []
        )
        # ``int2char`` has exactly ``max_token_id`` entries, so this covers ids 0..max-1.
        self.char2int = {char: i for i, char in enumerate(self.int2char)}

        self.cache_fist_branch = cache_fist_branch
        if self.cache_fist_branch:
            self.zero_iter = list({sequence[0] for sequence in sequences})
            assert len(self.zero_iter) == 1, (
                "cache_fist_branch needs non-empty sequences that all share one first token"
            )
            self.first_iter = list({sequence[1] for sequence in sequences})

        self.trie = marisa_trie.Trie(
            "".join(self.int2char[i] for i in sequence) for sequence in sequences
        )

    def get(self, prefix_sequence: List[int]):
        """Return the distinct token ids that can follow ``prefix_sequence``.

        The cached answers are used only when ``prefix_sequence`` is empty or equals
        ``zero_iter`` (a one-element list). Otherwise every key under the encoded
        prefix is scanned and the character right after the prefix is decoded.
        """
        if self.cache_fist_branch and len(prefix_sequence) == 0:
            return self.zero_iter
        elif (
            self.cache_fist_branch
            and len(prefix_sequence) == 1
            and self.zero_iter == prefix_sequence
        ):
            return self.first_iter
        else:
            key = "".join([self.int2char[i] for i in prefix_sequence])
            return list(
                {
                    self.char2int[e[len(key)]]
                    for e in self.trie.keys(key)
                    if len(e) > len(key)
                }
            )

    def __iter__(self):
        """Yield every stored sequence decoded back to a list of token ids."""
        for sequence in self.trie.iterkeys():
            yield [self.char2int[e] for e in sequence]

    def __len__(self):
        return len(self.trie)

    def __getitem__(self, value):
        return self.get(value)


class DummyTrieMention(object):
    """Trie stand-in whose ``get`` ignores the prefix and always answers the same."""

    def __init__(self, return_values):
        # Kept by reference: ``get`` hands back this very object, not a copy.
        self._return_values = return_values

    def get(self, indices=None):
        """Return the stored values; ``indices`` is accepted only for interface parity."""
        fixed_answer = self._return_values
        return fixed_answer


class DummyTrieEntity(object):
    """Trie stand-in that emits a fixed token schedule based only on prefix length.

    ``codes`` must provide the keys ``start_mention_token``, ``end_mention_token``,
    ``start_entity_token``, ``end_entity_token`` and ``EOS``.
    """

    def __init__(self, return_values, codes):
        # The free-choice set excludes the mention/entity delimiter tokens.
        delimiters = {
            codes[name]
            for name in (
                "start_mention_token",
                "end_mention_token",
                "start_entity_token",
            )
        }
        self._return_values = list(set(return_values).difference(delimiters))
        self._codes = codes

    def get(self, indices, depth=0):
        """Return the allowed continuation after ``indices``.

        - Non-empty ``indices`` ending in ``end_entity_token``: return ``codes["EOS"]``.
        - Otherwise, with ``level = depth + len(indices)``: level 0 returns
          ``codes["end_mention_token"]``, level 1 returns
          ``codes["start_entity_token"]``, and any other level returns the
          filtered free-choice list.

        The first two cases return a single code value rather than a list, as the
        original recursive version did.

        This closed form matches the former one-element-per-call recursion. It runs
        in O(1) and cannot raise ``RecursionError`` on long prefixes.
        """
        remaining = len(indices)
        if remaining and indices[remaining - 1] == self._codes["end_entity_token"]:
            return self._codes["EOS"]

        level = depth + remaining
        if level == 0:
            return self._codes["end_mention_token"]
        if level == 1:
            return self._codes["start_entity_token"]
        return self._return_values


def get_trie(args, tokenizer):
    """Tokenize every entity listed in ``{args.data_dir}/temp.txt`` and index it in a Trie.

    Each non-blank line is split on tabs. The entity name is field 2, then the last
    field, then the fields between them, joined by spaces.
    NOTE(review): with exactly three fields, field 2 is also the last field and is
    therefore repeated; confirm this matches the file format.

    Args:
        args: object exposing ``data_dir``, the directory containing ``temp.txt``.
        tokenizer: callable returning an object with ``input_ids`` and exposing
            ``eos_token_id`` (presumably a Hugging Face tokenizer).

    Returns:
        ``(trie, rel2id)``. ``trie`` stores ``[eos_token_id] + input_ids`` for each
        entity. ``rel2id`` maps entity name to its 0-based position among the
        parsed entities; a later duplicate name overwrites the earlier index.
    """
    rel2id = {}
    total_entity_ids = []
    with open(f"{args.data_dir}/temp.txt", "r", encoding="utf-8") as file:
        for line in file:
            stripped = line.strip()
            if not stripped:
                continue  # tolerate blank / trailing empty lines
            fields = stripped.split("\t")
            entity_name = " ".join([fields[2], fields[-1]] + fields[3:-1]).strip()
            input_ids = tokenizer(entity_name, add_special_tokens=True).input_ids
            rel2id[entity_name] = len(total_entity_ids)
            total_entity_ids.append(input_ids)
            if len(total_entity_ids) <= 5:  # show the first few entities
                print(entity_name, input_ids)

    max_len = max((len(ids) for ids in total_entity_ids), default=0)
    print("*" * 10 + f"max output length : {max_len}" + "*" * 10)

    # Prefix every sequence with the EOS id. NOTE(review): presumably the decoder
    # starts generation from EOS; confirm against the decoding setup.
    # TODO: the layout is the same for all models. An earlier "bart" branch on
    # args.model_name_or_path was identical in both arms; add a model-specific
    # layout here if one is needed.
    eos_id = tokenizer.eos_token_id
    trie = Trie([[eos_id] + ids for ids in total_entity_ids])
    return trie, rel2id

if __name__ == "__main__":
    # Build a Trie over FB15k-237 entity names tokenized with BART and pickle it.
    model_name = "facebook/bart-base"
    dataset = "FB15k-237"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    total_entity_ids = []
    with open(f"../dataset/{dataset}/entity2text.txt", "r", encoding="utf-8") as file:
        for line in file:
            if not line.strip():
                continue  # a blank line would otherwise become a bogus sequence
            # The entity text is the last tab-separated field.
            entity_name = line.split("\t")[-1].strip()
            input_ids = tokenizer(entity_name, add_special_tokens=True).input_ids
            total_entity_ids.append(input_ids)
            if len(total_entity_ids) <= 5:  # show the first few entities
                print(entity_name, input_ids)

    # Drop each sequence's first token (presumably the BOS added by
    # add_special_tokens=True) and prefix the EOS id instead (2 for bart-base).
    eos_id = tokenizer.eos_token_id
    trie = Trie([[eos_id] + ids[1:] for ids in total_entity_ids])
    model_real_name = model_name.split("/")[-1]
    with open(f"{model_real_name}_{dataset}.pkl", "wb") as file:
        pickle.dump(trie, file)
    


