import librosa, os, torch, json
import pandas as pd

class DataProcessor():
    """Convert newline-separated dialogue text into padded token-ID rows.

    Each line of the input text is tokenized separately, prefixed with
    ``[CLS]``, optionally tagged with a speaker token, and then padded to
    ``max_token_len`` together with a matching 1/0 attention mask.
    """

    def __init__(self, tokenizer, hparams, max_token_len=512):
        """Store the configuration and build the special-token ID table.

        Args:
            tokenizer: Object exposing ``vocab`` (a mapping that contains
                ``"[PAD]"``), ``wordpiece_tokenizer.tokenize`` and
                ``convert_tokens_to_ids``.
            hparams: Object exposing ``use_speaker_ids`` and ``memory_size``.
            max_token_len: Fixed length of every padded row.
        """
        self.max_token_len = max_token_len
        self.use_speaker_ids = hparams.use_speaker_ids
        self.memory_size = hparams.memory_size
        self.tokenizer = tokenizer

        # The two speaker labels ("상담자" = counselor, "내담자" = client) get
        # IDs directly past the end of the tokenizer vocabulary.
        # NOTE(review): presumably the model's embedding table is enlarged
        # elsewhere to cover these two extra IDs -- confirm.
        vocab_size = len(tokenizer.vocab)
        self.ids_dic = {
            "[PAD]": tokenizer.vocab["[PAD]"],
            "상담자": vocab_size,
            "내담자": vocab_size + 1,
        }

    def get_nerual_input(self, texts, speaker_ids, is_history=False):
        """Build padded model inputs for a block of dialogue lines.

        Args:
            texts: Newline-separated utterances, or a missing value
                (``None``/NaN), which is treated as one empty utterance
                spoken by "상담자".
            speaker_ids: Newline-separated speaker labels, one per utterance.
                A non-string value (e.g. ``None``/NaN) is treated as "no
                labels available" instead of raising.
            is_history: When true and ``memory_size`` is non-zero, only the
                first ``memory_size`` lines are kept.

        Returns:
            dict with ``input_ids`` (padded rows), ``input_mask`` (matching
            1/0 masks), ``token_len`` (longest unpadded row length) and
            ``history_size`` (number of lines kept).
        """
        # Handle missing or empty text data.
        if pd.isna(texts):
            texts, speaker_ids = "", "상담자"

        lines = texts.split("\n")
        # A missing speaker column must not crash the call; rows without a
        # label are simply left without a speaker token.
        speakers = speaker_ids.split("\n") if isinstance(speaker_ids, str) else []

        # NOTE(review): this keeps the *first* memory_size lines; whether the
        # history is ordered newest-first cannot be told from this file.
        if is_history and self.memory_size != 0:
            lines = lines[:self.memory_size]
            speakers = speakers[:self.memory_size]

        token_ids = self.add_speaker_ids(self.tokenize(lines), speakers)

        return {
            "input_ids": [self.add_pad(ids) for ids in token_ids],
            "input_mask": [self.make_mask(ids) for ids in token_ids],
            "token_len": max((len(ids) for ids in token_ids), default=0),
            "history_size": len(lines),
        }

    # Correctly spelled alias; the original (misspelled) name is kept so that
    # existing callers continue to work.
    get_neural_input = get_nerual_input

    def tokenize(self, texts):
        """Tokenize each text into ``[CLS]``-prefixed, length-limited IDs.

        One slot is held back when speaker IDs are enabled so that the
        speaker token inserted later still fits within ``max_token_len``.
        """
        reserved = 1 if self.use_speaker_ids else 0
        # Clamp at zero: a negative slice bound would cut from the end of the
        # row instead of truncating it.
        limit = max(self.max_token_len - reserved, 0)

        token_ids = []
        for text in texts:
            tokens = ["[CLS]"] + self.tokenizer.wordpiece_tokenizer.tokenize(text)
            token_ids.append(self.tokenizer.convert_tokens_to_ids(tokens)[:limit])
        return token_ids

    def add_speaker_ids(self, token_ids, speaker_ids):
        """Insert each row's speaker token right after its first token.

        The rows are modified in place and the same list is returned. Rows
        without a corresponding speaker label are left unchanged. When
        speaker IDs are disabled the input is returned untouched.

        Raises:
            KeyError: If a label is not a key of ``ids_dic``.
        """
        if not self.use_speaker_ids:
            return token_ids

        for ids, speaker in zip(token_ids, speaker_ids):
            # Tolerate stray whitespace such as a trailing "\r" from CRLF data.
            label = speaker.strip()
            try:
                speaker_token = self.ids_dic[label]
            except KeyError:
                known = [key for key in self.ids_dic if key != "[PAD]"]
                raise KeyError(
                    f"unknown speaker label {speaker!r}; expected one of {known}"
                ) from None
            ids.insert(1, speaker_token)
        return token_ids

    def make_mask(self, ids):
        """Return a mask with 1 for each token in ``ids`` and 0 for padding."""
        used = len(ids)
        # Rows already at or beyond max_token_len receive no padding.
        return [1] * used + [0] * max(self.max_token_len - used, 0)

    def add_pad(self, ids):
        """Return ``ids`` right-padded with ``[PAD]`` up to ``max_token_len``."""
        missing = max(self.max_token_len - len(ids), 0)
        return ids + [self.ids_dic["[PAD]"]] * missing
