import torch
from torch.utils.data import Dataset, DataLoader
from utils import * 

class S4DatasetBatch(Dataset):
    """Word-level dataset for BERTSeg with batching support.

    Each item is one word together with its frequency, its character tokens
    (prefixed with ``<s>``), the vocab indices of those tokens, a ``(T, T)``
    table of vocab ids for every contiguous token span, and an embedding.

    Config keys read by this class: ``freq_table_file``, ``text_file``,
    ``use_frequency`` (``"none"``, ``"one"`` or ``"threshold"``),
    ``dataset_size``, ``train``, ``embedding_file``, ``vocab_path``,
    ``vocab_size``, ``volt_flag`` and ``char_dict_file``.

    Vocab layout produced by ``build_vocab``: ``<pad>`` = 0, ``<unk>`` = 1,
    ``<s>`` = 2, ``</s>`` = 3, then subwords, then single characters.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Read words and frequencies (the table is created on first use).
        self.generate_word_freq_table()
        self.word_freq_table = read_frequency_table(self.config['freq_table_file'])  # word_freq_table[word] = freq

        self.words = self._select_words()
        if self.config['dataset_size'] > 0:  # truncate, e.g. for quick tests
            self.words = self.words[:self.config['dataset_size']]

        self.embeddings = self.gene_embeddings()  # dict[word] = embedding
        # Kept as an equality test (not truthiness) so that non-boolean config
        # values behave exactly as before.
        if self.config['train'] == True:  # noqa: E712
            self.words = self.filter_train_words(self.words)  # drop long words
        self.tot_frequency = sum(self.word_freq_table.values())  # int

        self.words_with_tokens = self.add_special_tokens(self.words)  # dict[word] = [<s>, c1, ..., cn]

        # Build (or load) the subword vocab and its inverse mapping.
        self.vocab = self.build_vocab()  # dict[token] = index
        self.vocab_size = len(self.vocab)
        self.id2word = {i: w for (w, i) in self.vocab.items()}  # dict[index] = token

        self.segs_pos = self.gene_segs_pos(self.words)  # dict[word] = (T, T) int array
        self.words_with_tokens_idx = self.build_words_with_tokens_idx()

    def __len__(self):
        """Return the number of (possibly repeated) words in the dataset."""
        return len(self.words)

    def __getitem__(self, index):
        """Return the sample dict for the word at position ``index``.

        Words without an entry in ``self.embeddings`` receive
        ``self.default_embedding``.
        """
        word = self.words[index]
        return {
            'word': word,
            'freq': self.word_freq_table[word],
            'word_with_tokens': self.words_with_tokens[word],
            'word_with_tokens_idx': self.words_with_tokens_idx[word],
            'segs_pos': self.segs_pos[word],
            'embedding': self.embeddings.get(word, self.default_embedding),
        }

    def _select_words(self) -> list:
        """Expand the frequency table into the word list per ``use_frequency``.

        * ``"none"``: every word repeated ``freq`` times.
        * ``"one"``: every word once, most frequent first.
        * ``"threshold"``: every word repeated ``int(freq / 10)`` times.

        Raises:
            ValueError: for any other ``use_frequency`` value (previously this
                left ``self.words`` undefined and failed later on).
        """
        mode = self.config["use_frequency"]
        table = self.word_freq_table
        if mode == "none":
            return [word for (word, freq) in table.items() for _ in range(freq)]
        if mode == "one":
            # sorted() is stable, so ties keep the table's iteration order.
            return sorted(table, key=lambda w: table[w], reverse=True)
        if mode == "threshold":
            return [word for (word, freq) in table.items() for _ in range(int(freq / 10))]
        raise ValueError(
            f"Unknown use_frequency {mode!r}; expected 'none', 'one' or 'threshold'"
        )

    def generate_word_freq_table(self):
        """Create the frequency-table file from ``text_file`` if it is missing."""
        if os.path.exists(self.config['freq_table_file']):
            return
        print ("Creating frequency table")
        lines = read_lines(self.config['text_file'])
        lines = filter_lines(lines)
        frequency_table = extract_frequency_table_from_lines(lines)
        save_frequency_table(frequency_table, self.config['freq_table_file'])

    def filter_words(self, words) -> list:
        """Keep only words that have an embedding containing no NaN."""
        filtered_words = []
        for word in words:
            embedding = self.embeddings.get(word, None)
            if embedding is not None and not np.isnan(embedding).any():
                filtered_words.append(word)
        return filtered_words

    def filter_train_words(self, words, max_len=30) -> list:
        """Drop words longer than ``max_len`` characters (default 30)."""
        return [word for word in words if len(word) <= max_len]

    def add_special_tokens(self, words):
        """Map each distinct word to ``['<s>'] + list(word)``.

        No ``</s>`` token is appended.
        """
        return {word: ['<s>'] + list(word) for word in words}

    def build_words_with_tokens_idx(self):
        """Map each distinct word to the vocab indices of its tokens.

        Tokens missing from the vocab are mapped to the ``<unk>`` index.
        """
        unk_idx = self.vocab['<unk>']
        words_with_tokens_idx = {}
        for word in self.words:
            if word not in words_with_tokens_idx:
                words_with_tokens_idx[word] = [
                    self.vocab.get(tok, unk_idx) for tok in self.words_with_tokens[word]
                ]
        return words_with_tokens_idx

    def gene_segs_pos(self, words) -> dict:
        """Generate the subword position table for each distinct word.

        For a word with ``T`` tokens the result is a ``(T, T)`` int array in
        which entry ``[end][start]`` (``start <= end``) holds the vocab id of
        the joined tokens ``start..end``, or 0 when that span is not in the
        vocab. Entries with ``start > end`` stay 0.
        """
        segs_pos = {}
        for word in words:
            if word in segs_pos:
                continue
            tokens = self.words_with_tokens[word]
            T = len(tokens)
            table = np.zeros((T, T), dtype=int)
            for end_pos in range(T):
                for start_pos in range(end_pos + 1):
                    seg = ''.join(tokens[start_pos:end_pos + 1])
                    table[end_pos][start_pos] = self.vocab.get(seg, 0)
            segs_pos[word] = table
        return segs_pos

    def gene_embeddings(self):
        """Load the word embeddings, generating the pickle file first if needed.

        Also sets ``self.default_embedding``, the fallback used by
        ``__getitem__`` for words without an embedding: the embedding of
        ``'a'`` when present, otherwise the first embedding in the table
        (``None`` for an empty table).

        Returns:
            The unpickled embeddings object (used as ``dict[word] = embedding``).

        Raises:
            FileNotFoundError: if the file is missing and the generator script
                did not produce it.
        """
        embedding_file = self.config['embedding_file']
        if os.path.exists(embedding_file):
            print ("Loading embedding from pickle file")
            print (embedding_file)
        else:
            print ("Generating embedding")
            gene_embed_py = "./characterBERT_embedding.py"
            command = f"python {gene_embed_py} --input_file {self.config['text_file']} --output_file {embedding_file}"
            os.system(command)
            if not os.path.exists(embedding_file):
                raise FileNotFoundError(
                    f"Embedding generation did not produce {embedding_file!r} (command: {command})"
                )
        # NOTE(review): pickle executes arbitrary code on load; only point
        # 'embedding_file' at trusted files.
        with open(embedding_file, 'rb') as f:
            embeddings = pickle.load(f)
        if 'a' in embeddings:
            self.default_embedding = embeddings['a']
        else:
            # 'a' is only an arbitrary placeholder; fall back to any entry
            # instead of failing on tables that lack it.
            self.default_embedding = next(iter(embeddings.values()), None)
        return embeddings

    def build_char_dict(self):
        """Count characters over ``self.words`` and write them to ``char_dict_file``.

        The file holds one ``char<TAB>count`` line per character, most
        frequent first.

        Returns:
            dict[char] = occurrence count.
        """
        char_dict = {}
        for word in self.words:
            for c in word:
                char_dict[c] = char_dict.get(c, 0) + 1
        # Explicit UTF-8 so arbitrary characters can always be written.
        with open(self.config['char_dict_file'], 'w', encoding='utf-8') as f:
            for (c, freq) in sorted(char_dict.items(), key=lambda x: x[1], reverse=True):
                f.write(c + '\t' + str(freq) + '\n')
        return char_dict

    def save_vocab(self, vocab_path):
        """Pickle ``self.vocab`` to ``vocab_path``."""
        with open(vocab_path, 'wb') as f:
            pickle.dump(self.vocab, f)

    def build_vocab(self):
        """Load the vocab from ``vocab_path`` or build and save it.

        Build steps:
            1. add the special tokens ``<pad>``, ``<unk>``, ``<s>``, ``</s>``
               (indices 0-3);
            2. run ``./scripts/gene_bpe_vocab.sh`` and add the subwords from
               the file ``<text_file>.<vocab_size>.subword.vocab``;
            3. add every character of ``self.words``, most frequent first
               (this also writes ``char_dict_file``);
            4. pickle the result to ``vocab_path``.

        Returns:
            dict[token] = index.
        """
        self.vocab_path = self.config["vocab_path"]
        vocab_path = self.vocab_path

        if os.path.exists(vocab_path):
            print ("Loading vocab from pickle file")
            # NOTE(review): pickle executes arbitrary code on load; only point
            # 'vocab_path' at trusted files.
            with open(vocab_path, 'rb') as f:
                self.vocab = pickle.load(f)
            return self.vocab

        # Special tokens; each new entry gets the next free index, i.e. the
        # current size of the vocab.
        vocab = {}
        for token in ('<pad>', '<unk>', '<s>', '</s>'):
            vocab[token] = len(vocab)

        # Subwords produced by the external BPE script.
        gene_bpe_vocab = "./scripts/gene_bpe_vocab.sh"
        command = f"bash {gene_bpe_vocab} {self.config['text_file']} {self.config['vocab_size']} {self.config['volt_flag']}"
        print (f'Running the command: {command}')
        os.system(command)
        self.subword_vocab_path = f"{self.config['text_file']}.{self.config['vocab_size']}.subword.vocab"
        # Assumes the script writes UTF-8 (the '▁' marker is non-ASCII) --
        # TODO confirm against the script.
        with open(self.subword_vocab_path, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:  # tolerate blank lines
                    continue
                word = fields[0].strip('▁').strip()
                if word not in vocab:
                    vocab[word] = len(vocab)

        # Add all characters to the vocab so single characters are never <unk>.
        self.char_dict = self.build_char_dict()
        for (c, _freq) in sorted(self.char_dict.items(), key=lambda x: x[1], reverse=True):
            if c not in vocab:
                vocab[c] = len(vocab)

        # Write vocab to vocab_path.
        self.vocab = vocab
        self.save_vocab(vocab_path)
        return vocab

    def word_in_vocab(self, word) -> bool:
        """Return True if ``word`` is a key of the vocab."""
        return word in self.vocab

    def collate_fn(self, batch):
        """Merge a list of ``__getitem__`` samples into one padded batch.

        ``segs_pos`` tables are zero-padded to ``(max_len, max_len)`` and token
        index lists are padded with the ``<pad>`` index to ``max_len``, where
        ``max_len`` is the longest token list in the batch. The string token
        lists are returned unpadded because they take no part in computation.

        Returns:
            dict with keys ``max_len``, ``words``, ``words_with_tokens``,
            ``words_with_tokens_idx`` (numpy array), ``segs_poss`` (numpy
            array), ``batch_size``, ``freqs`` and ``embeddings`` (torch tensor
            with one row per sample).
        """
        batch_size = len(batch)
        max_len = max(len(x['word_with_tokens']) for x in batch)
        freqs = [x['freq'] for x in batch]
        words = [x['word'] for x in batch]
        words_with_tokens = [x['word_with_tokens'] for x in batch]

        # Stack through numpy first, then convert the whole array to a tensor.
        embeddings = torch.tensor(np.array([x['embedding'] for x in batch]))

        segs_poss = []
        for x in batch:
            segs_pos = x['segs_pos']
            extra = max_len - len(segs_pos)
            segs_poss.append(
                np.pad(segs_pos, ((0, extra), (0, extra)), 'constant', constant_values=0)
            )

        pad_idx = self.vocab["<pad>"]
        words_with_tokens_idx = []
        for x in batch:
            idx = x['word_with_tokens_idx']
            words_with_tokens_idx.append(
                np.pad(idx, (0, max_len - len(idx)), 'constant', constant_values=pad_idx)
            )

        return {
            "max_len": max_len,
            "words": words,
            "words_with_tokens": words_with_tokens,
            "words_with_tokens_idx": np.array(words_with_tokens_idx),
            "segs_poss": np.array(segs_poss),
            "batch_size": batch_size,
            "freqs": freqs,
            "embeddings": embeddings,
        }