import pandas as pd
import numpy as np
from torch.utils.data import Dataset
import torch, librosa, os
from collections import Counter
from .data_processing import *

# Back-channel category label -> integer class id.
# "NoBC", "continuer" and "understanding" each have their own id; the four
# remaining labels are all collapsed into the single id 3.
bc_category2id = {
    "NoBC": 0,
    "continuer": 1,
    "understanding": 2,
    **dict.fromkeys(
        (
            "negative surprise",
            "positive surprise",
            "request confirmation",
            "affirmative",
        ),
        3,
    ),
}

class BackChannelDataset(Dataset):
    """Dataset of back-channel samples built from the rows of a DataFrame.

    Every row of ``text_df`` becomes one sample. The text columns are encoded
    once, in ``__init__``, through ``DataProcessor.get_nerual_input``. When
    acoustic features are enabled, the matching wav file is loaded lazily on
    first access and kept in an in-memory cache.

    Samples are addressed by *position* (``0 .. len(dataset) - 1``), so the
    dataset works regardless of the index carried by ``text_df`` (for example
    a filtered or shuffled frame whose index is no longer ``0..n-1``).
    """

    def __init__(self, text_df, tokenizer, hparams, audio_path):
        """Encode the text columns of ``text_df`` and prepare the audio cache.

        Args:
            text_df: DataFrame providing the columns ``file_index``, ``file``,
                ``current_text``, ``current_talker_type`` and ``bc_catetory``
                (spelling as in the data), plus ``history`` and
                ``history_talker_type`` when ``hparams.memory_size != 0``.
            tokenizer: Passed through to ``DataProcessor``.
            hparams: Configuration object; ``accustic_feature`` (audio is
                enabled when it is not the empty string) and ``memory_size``
                (history is enabled when it is not 0) are read here.
            audio_path: Root directory of the wav files; a sample's audio is
                read from ``<audio_path>/<file>/<file_index>.wav``.

        Raises:
            KeyError: If a ``bc_catetory`` value is not a key of
                ``bc_category2id``.
        """
        self.hparams = hparams
        self.audio_path = audio_path
        self.with_audio = hparams.accustic_feature != ""
        self.with_history = hparams.memory_size != 0
        self.file_index = text_df.file_index
        self.file = text_df.file

        data_processor = DataProcessor(tokenizer, hparams)

        self.current_input = text_df.apply(
            lambda row: data_processor.get_nerual_input(
                row.current_text,
                row.current_talker_type
            ),
            axis=1
        )

        if self.with_history:
            self.history_input = text_df.apply(
                lambda row: data_processor.get_nerual_input(
                    row.history,
                    row.history_talker_type,
                    is_history=True
                ),
                axis=1
            )

        if self.with_audio:
            # Positional cache of raw waveforms, filled lazily by load_audio.
            # A plain list keeps cache reads/writes independent of the
            # DataFrame's index labels.
            self.audio_signal = [None] * len(text_df)

        self.target = text_df["bc_catetory"].apply(lambda x: bc_category2id[x])
        # Number of samples per class id.
        self.bc_d_nums = Counter(self.target)

    def __len__(self):
        """Return the number of samples (rows of the source DataFrame)."""
        return len(self.file_index)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` as a dict of tensors.

        Always present: ``cur_input_ids``, ``cur_input_mask``,
        ``cur_token_len``, ``history_size`` (0 unless history is enabled) and
        ``target``. With audio enabled ``accustic_feature`` is added; with
        history enabled ``his_input_ids``, ``his_input_mask`` and
        ``his_token_len`` are added and ``history_size`` is taken from the
        encoded history.
        """
        # .iloc: idx is a position handed out by the sampler, not a label of
        # the DataFrame index.
        current = self.current_input.iloc[idx]
        results = {
            "cur_input_ids": torch.tensor(current["input_ids"]),
            "cur_input_mask": torch.tensor(current["input_mask"]),
            "cur_token_len": torch.tensor(current["token_len"]),
            "history_size": torch.tensor(0),
            "target": torch.tensor(self.target.iloc[idx])
        }

        if self.with_audio:
            results["accustic_feature"] = self.load_audio(idx)

        if self.with_history:
            history = self.history_input.iloc[idx]
            results["his_input_ids"] = torch.tensor(history["input_ids"])
            results["his_input_mask"] = torch.tensor(history["input_mask"])
            results["his_token_len"] = torch.tensor(history["token_len"])
            results["history_size"] = torch.tensor(history["history_size"])

        return results

    def load_audio(self, idx):
        """Return the waveform of the sample at position ``idx`` as a tensor.

        The wav file is read with ``librosa.load(..., sr=None)`` the first
        time the sample is requested; later requests are served from the
        in-memory cache.
        """
        accustic_feature = self.audio_signal[idx]
        if accustic_feature is None:
            wav_path = os.path.join(
                self.audio_path,
                self.file.iloc[idx],
                f"{self.file_index.iloc[idx]}.wav"
            )
            # The sampling rate returned by librosa is not used.
            accustic_feature, _ = librosa.load(wav_path, sr=None)
            self.audio_signal[idx] = accustic_feature
        return torch.tensor(accustic_feature)
