import torch
from config import *
from augmentation import *
import numpy as np
from tqdm import trange
import os

def parse_data_fast(file_path, tokenizer, sequence_len, token_style, punctuation_dict):
    """
    Parse a tab-separated token/punctuation file into fixed-length NumPy sequences.

    :param file_path: text file path that contains tokens and punctuations separated by tab in lines
    :param tokenizer: tokenizer that will be used to further tokenize word for BERT like models
    :param sequence_len: maximum length of each sequence (must be at least 3: start token, one
        word token and end token)
    :param token_style: For getting index of special tokens in config.TOKEN_IDX
    :param punctuation_dict: mapping from punctuation label to class index; labels that are not in
        the mapping fall back to the index stored under the ' ' key
    :return: list of [tokens_index, punctuation_index, attention_masks, punctuation_mask], each a
        NumPy array of length sequence_len
    punctuation_mask is used to ignore special indices like padding and intermediate sub-word token during evaluation
    :raises ValueError: if sequence_len is smaller than 3
    """
    if sequence_len < 3:
        raise ValueError(f"sequence_len must be at least 3, got {sequence_len}")

    special = TOKEN_IDX[token_style]
    # Room left for word tokens once the start and end markers are accounted for.
    max_word_tokens = sequence_len - 2

    with open(file_path, 'r', encoding='utf-8') as f:
        print(f"Loading '{file_path}'")
        # Only lines that contain a tab carry a (word, punctuation) pair.
        lines = [line for line in f.read().split('\n') if '\t' in line]

    data_items = []
    idx = 0
    # The bar is advanced manually by the number of lines each sequence consumed, so the
    # loop itself is driven by the data and never by a progress estimate.
    with trange(len(lines)) as progress:
        # loop until end of the entire text
        while idx < len(lines):
            first_idx = idx

            # np.int64 instead of np.long: the alias was removed in NumPy 1.24.
            x = np.full(sequence_len, special['PAD'], dtype=np.int64)
            x[0] = special['START_SEQ']

            attn_mask = np.zeros(sequence_len, dtype=np.int64)
            attn_mask[0] = 1

            y = np.zeros(sequence_len, dtype=np.int64)
            # which positions we need to consider while evaluating i.e., ignore pad or sub tokens
            y_mask = np.zeros(sequence_len)
            y_mask[0] = 1

            index = 1

            # loop until we have required sequence length
            # -1 because we will have a special end of sequence token at the end
            while index < sequence_len - 1 and idx < len(lines):
                parts = lines[idx].split('\t')
                word = parts[0]
                # Anything after the first tab is treated as the punctuation label.
                punc = "\t".join(parts[1:])

                tokens = tokenizer.tokenize(word)
                if index == 1 and len(tokens) > max_word_tokens:
                    # A word that cannot fit even into an empty sequence would otherwise
                    # never be consumed; keep only the sub-tokens that fit.
                    tokens = tokens[:max_word_tokens]

                # if taking these tokens exceeds sequence length we finish current sequence with padding
                # then start next sequence from this token
                if len(tokens) + index >= sequence_len:
                    break

                # Leading sub-tokens are attended to but carry no label (y and y_mask stay 0).
                for token in tokens[:-1]:
                    x[index] = tokenizer.convert_tokens_to_ids(token)
                    attn_mask[index] = 1
                    index += 1

                # The last sub-token (or UNK when the tokenizer produced nothing) carries the label.
                if tokens:
                    x[index] = tokenizer.convert_tokens_to_ids(tokens[-1])
                else:
                    x[index] = special['UNK']

                if punc in punctuation_dict:
                    y[index] = punctuation_dict[punc]
                else:
                    y[index] = punctuation_dict[' ']

                y_mask[index] = 1
                attn_mask[index] = 1
                index += 1
                idx += 1

            x[index] = special['END_SEQ']
            # END_SEQ is a real (non-PAD) token, so it is attended to like every other token.
            attn_mask[index] = 1
            y_mask[index] = 1

            data_items.append([x, y, attn_mask, y_mask])
            progress.update(idx - first_idx)

    print(f"Processed total of {len(data_items)} examples")
    return data_items


def parse_data(file_path, tokenizer, sequence_len, token_style):
    """
    Parse a tab-separated token/punctuation file into fixed-length list sequences.

    The punctuation label of each line is looked up in a module-level ``punctuation_dict``
    (presumably provided by ``from config import *`` -- confirm); a label missing from it
    raises KeyError.

    :param file_path: text file path that contains tokens and punctuations separated by tab in lines
    :param tokenizer: tokenizer that will be used to further tokenize word for BERT like models
    :param sequence_len: maximum length of each sequence (must be at least 3: start token, one
        word token and end token)
    :param token_style: For getting index of special tokens in config.TOKEN_IDX
    :return: list of [tokens_index, punctuation_index, attention_masks, punctuation_mask], each a
        Python list of length sequence_len
    punctuation_mask is used to ignore special indices like padding and intermediate sub-word token during evaluation
    :raises ValueError: if sequence_len is smaller than 3, or a non-blank line does not contain
        exactly one tab
    """
    if sequence_len < 3:
        raise ValueError(f"sequence_len must be at least 3, got {sequence_len}")

    special = TOKEN_IDX[token_style]
    # Room left for word tokens once the start and end markers are accounted for.
    max_word_tokens = sequence_len - 2

    data_items = []
    with open(file_path, 'r', encoding='utf-8') as f:
        print(f"Loading '{file_path}'")
        lines = [line for line in f.read().split('\n') if line.strip()]

    idx = 0
    # loop until end of the entire text; every sequence consumes at least one line, so
    # len(lines) iterations is an upper bound and the loop normally exits via the break
    for _dummy_ in trange(len(lines)):
        if idx >= len(lines):
            break

        # Sequences are grown as plain lists and padded to sequence_len afterwards.
        x = [special['START_SEQ']]
        y = [0]
        y_mask = [1]  # which positions we need to consider while evaluating i.e., ignore pad or sub tokens

        # loop until we have required sequence length
        # -1 because we will have a special end of sequence token at the end
        while len(x) < sequence_len - 1 and idx < len(lines):
            word, punc = lines[idx].split('\t')
            tokens = tokenizer.tokenize(word)
            if len(x) == 1 and len(tokens) > max_word_tokens:
                # A word that cannot fit even into an empty sequence would otherwise
                # never be consumed; keep only the sub-tokens that fit.
                tokens = tokens[:max_word_tokens]

            # if taking these tokens exceeds sequence length we finish current sequence with padding
            # then start next sequence from this token
            if len(tokens) + len(x) >= sequence_len:
                break

            # Leading sub-tokens carry no label and are ignored during evaluation.
            for token in tokens[:-1]:
                x.append(tokenizer.convert_tokens_to_ids(token))
                y.append(0)
                y_mask.append(0)

            # The last sub-token (or UNK when the tokenizer produced nothing) carries the label.
            if tokens:
                x.append(tokenizer.convert_tokens_to_ids(tokens[-1]))
            else:
                x.append(special['UNK'])
            y.append(punctuation_dict[punc])
            y_mask.append(1)
            idx += 1

        x.append(special['END_SEQ'])
        y.append(0)
        y_mask.append(1)

        # Right-pad all three lists up to sequence_len.
        pad_count = sequence_len - len(x)
        if pad_count > 0:
            x = x + [special['PAD']] * pad_count
            y = y + [0] * pad_count
            y_mask = y_mask + [0] * pad_count
        attn_mask = [1 if token != special['PAD'] else 0 for token in x]

        data_items.append([x, y, attn_mask, y_mask])
    return data_items


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset of fixed-length token/punctuation sequences built with parse_data_fast."""

    def __init__(self, files, tokenizer, sequence_len, token_style, punctuation_dict, is_train=False, augment_rate=0.1,
                 augment_type='substitute'):
        """

        :param files: single file or list of text files containing tokens and punctuations separated by tab in lines
        :param tokenizer: tokenizer that will be used to further tokenize word for BERT like models
        :param sequence_len: length of each sequence
        :param token_style: For getting index of special tokens in config.TOKEN_IDX
        :param punctuation_dict: mapping from punctuation label to class index, forwarded to parse_data_fast
        :param is_train: if false do not apply augmentation
        :param augment_rate: token augmentation rate when preparing data
        :param augment_type: key into AUGMENTATIONS selecting the augmentation function
        """
        if isinstance(files, list):
            self.data = []
            for file in files:
                self.data += parse_data_fast(file, tokenizer, sequence_len, token_style, punctuation_dict)
        else:
            self.data = parse_data_fast(files, tokenizer, sequence_len, token_style, punctuation_dict)
        self.sequence_len = sequence_len
        self.augment_rate = augment_rate
        self.token_style = token_style
        self.is_train = is_train
        self.augment_type = augment_type

    def __len__(self):
        """Return the number of parsed sequences."""
        return len(self.data)

    def _augment(self, x, y, y_mask):
        """
        Randomly augment one sequence and restore it to sequence_len.

        Each position is handed to the configured augmentation with probability augment_rate,
        otherwise it is copied unchanged.

        :param x: token indices of the sequence
        :param y: punctuation indices of the sequence
        :param y_mask: punctuation mask of the sequence
        :return: (x_aug, y_aug, attn_mask, y_mask_aug) as lists of length sequence_len
        """
        x_aug = []
        y_aug = []
        y_mask_aug = []
        augment = AUGMENTATIONS[self.augment_type]
        pad_id = TOKEN_IDX[self.token_style]['PAD']

        for i in range(len(x)):
            if np.random.rand() < self.augment_rate:
                # presumably the augmentation appends its result for position i to the
                # *_aug lists in place -- confirm against the augmentation module
                augment(x, y, y_mask, x_aug, y_aug, y_mask_aug, i, self.token_style)
            else:
                x_aug.append(x[i])
                y_aug.append(y[i])
                y_mask_aug.append(y_mask[i])

        if len(x_aug) > self.sequence_len:
            # len increased due to insert
            x_aug = x_aug[:self.sequence_len]
            y_aug = y_aug[:self.sequence_len]
            y_mask_aug = y_mask_aug[:self.sequence_len]
        elif len(x_aug) < self.sequence_len:
            # len decreased due to delete
            x_aug = x_aug + [pad_id] * (self.sequence_len - len(x_aug))
            y_aug = y_aug + [0] * (self.sequence_len - len(y_aug))
            y_mask_aug = y_mask_aug + [0] * (self.sequence_len - len(y_mask_aug))

        # The mask must describe the augmented tokens: positions no longer line up with the
        # original sequence once tokens have been inserted, deleted or replaced.
        attn_mask = [1 if token != pad_id else 0 for token in x_aug]
        return x_aug, y_aug, attn_mask, y_mask_aug

    def __getitem__(self, index):
        """
        Return the sequence at ``index`` as tensors, augmented when training.

        :param index: position of the sequence in the parsed data
        :return: tuple (x, y, attn_mask, y_mask) of torch tensors
        """
        x, y, attn_mask, y_mask = self.data[index]

        # Augmentation is applied on the fly, so each access may yield a different variant.
        if self.is_train and self.augment_rate > 0:
            x, y, attn_mask, y_mask = self._augment(x, y, y_mask)

        x = torch.tensor(x)
        y = torch.tensor(y)
        attn_mask = torch.tensor(attn_mask)
        y_mask = torch.tensor(y_mask)

        return x, y, attn_mask, y_mask
