from torch.utils.data import Dataset
import os
import torch
from transformers import BertTokenizer


class FairClassificationDataset(Dataset):
    """Tokenized classification dataset pairing each text with a public and a sensitive label.

    Items are dicts of the form
    ``{"text": {"input_ids", "token_type_ids", "attention_mask"},
    "sensitive_label": ..., "public_label": ...}`` with every tensor moved to
    ``args.device``.
    """

    def __init__(self, args, split):
        """
        Args:
            args: experiment namespace. Reads ``args.dataset`` (a key of the
                module-level ``reader`` registry), ``args.reduce_training_size``,
                ``args.filter`` and ``args.device``. The constructor also writes
                tokenizer and label-count metadata back onto ``args``
                (``sos_token``, ``number_of_tokens``, ``tokenizer``,
                ``padding_idx``, ``number_of_public_labels``,
                ``number_sensitive_label``).
            split: which split to load; one of ``'train'``, ``'test'``, ``'val'``.
        """
        assert split in ['train', 'test', 'val']
        self.args = args
        # Keep the split *name* separately: ``self.split`` is reused below for the
        # loaded data dict, and ``tokenize_split`` needs to know which split this is.
        self.split_name = split
        self.reader = reader[args.dataset]
        self.split = self.reader.get_split(split)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # NOTE(review): the SEP id is exported as the start-of-sequence token — confirm intended.
        self.args.sos_token = self.tokenizer.sep_token_id
        self.args.number_of_tokens = self.tokenizer.vocab_size
        self.args.tokenizer = self.tokenizer
        self.args.padding_idx = self.tokenizer.pad_token_id
        self.tokenized_text = self.tokenize_split()
        self.length = len(self.tokenized_text['input_ids'])
        self.args.number_of_public_labels = self.reader.number_public_label
        self.args.number_sensitive_label = self.reader.number_sensitive_label

    def tokenize_split(self):
        """Tokenize (a prefix of) the loaded split's texts.

        With ``args.reduce_training_size`` set, only the training split is cut
        down to its first ``args.filter`` fraction of examples; the other splits
        keep every example except the last one. Without that flag, every split is
        cut down to the ``args.filter`` fraction.

        Returns:
            The tokenizer output (padded, truncated to ``reader.max_length``,
            returned as PyTorch tensors).
        """
        texts = self.split['texts']
        if self.args.reduce_training_size and self.split_name != 'train':
            # NOTE(review): ``[:-1]`` drops the final example; kept from the
            # original implementation — confirm it is intended.
            texts = texts[:-1]
        else:
            texts = texts[:int(len(texts) * self.args.filter)]
        return self.tokenizer(texts, return_tensors="pt", padding=True,
                              truncation=True, max_length=self.reader.max_length)

    def __len__(self):
        """Number of tokenized examples (may be fewer than the raw split)."""
        return self.length

    def __getitem__(self, idx):
        """Return example ``idx`` with its encoding and labels moved to ``args.device``."""
        device = self.args.device
        encoding = {
            key: self.tokenized_text[key][idx].to(device)
            for key in ('input_ids', 'token_type_ids', 'attention_mask')
        }
        return {
            "text": encoding,
            "sensitive_label": torch.tensor(self.split['sensitive_label'][idx]).to(device),
            "public_label": torch.tensor(self.split['public_label'][idx]).to(device),
        }


class PanReader:
    """Reader for the PAN tweet data.

    Main (public) label: mention.
    Sensitive label: age or gender, selected by ``private_label``.
    """

    def __init__(self, private_label):
        """
        Args:
            private_label: sensitive attribute to load, ``'age'`` or ``'gender'``;
                selects the ``data/pan/tweet_<label>_processed`` directory.
        """
        assert private_label in ['age', 'gender']
        self.data_path = 'data/pan/tweet_{}_processed'.format(private_label)
        self.number_sensitive_label = 2
        self.number_public_label = 2
        self.max_length = 80

    def get_split(self, split):
        """Parse ``<data_path>/x_<split>`` as tab-separated ``text, public, sensitive`` rows.

        The first line is skipped (treated as a header) and rows that do not have
        exactly three fields are dropped.

        Returns:
            dict with ``'texts'`` (list of str) plus integer ``'sensitive_label'``
            and ``'public_label'`` lists, aligned by index.
        """
        with open(os.path.join(self.data_path, 'x_{}'.format(split)), 'r') as file:
            lines = file.readlines()
        lines = [line.replace('\n', '').split('\t') for line in lines][1:]
        lines = [line for line in lines if len(line) == 3]
        texts = [line[0] for line in lines]
        sensitive_label = [int(line[-1]) for line in lines]
        public_label = [int(line[1]) for line in lines]
        return {'texts': texts, 'sensitive_label': sensitive_label, 'public_label': public_label}


class BlogReader:
    """Reader for the blog-authorship data.

    Main (public) label: topic (``number_public_label`` = 10 classes).
    Sensitive label: age or gender, selected by ``private_label``.
    """

    def __init__(self, private_label):
        """
        Args:
            private_label: ``'age'`` or ``'gender'``; selects the
                ``data/blogs/<label>_topic`` directory.
        """
        assert private_label in ['age', 'gender']
        self.max_length = 400
        self.number_public_label = 10
        self.number_sensitive_label = 2
        self.data_path = 'data/blogs/{}_topic'.format(private_label)

    def get_split(self, split):
        """Read ``<data_path>/<split>.txt``: tab-separated text, public label, ..., sensitive label.

        The first line is skipped (treated as a header).
        """
        path = os.path.join(self.data_path, '{}.txt'.format(split))
        texts, sensitive_label, public_label = [], [], []
        with open(path, 'r') as handle:
            next(handle, None)  # discard the first (header) line
            for raw in handle:
                # NOTE(review): unlike the other readers, rows are not filtered to
                # exactly three fields here — confirm the blog files never need it.
                fields = raw.rstrip('\n').split('\t')
                texts.append(fields[0])
                sensitive_label.append(int(fields[-1]))
                public_label.append(int(fields[1]))
        return {'texts': texts, 'sensitive_label': sensitive_label, 'public_label': public_label}


class BioReader:
    """Reader for the processed bias-in-bios data.

    Main (public) label: one of ``number_public_label`` = 28 classes.
    Sensitive label: gender.
    """

    def __init__(self):
        self.max_length = 140  # 5% truncate
        self.number_public_label = 28
        self.number_sensitive_label = 2
        self.data_path = 'data/biais_bios/bio_processed/'

    def get_split(self, split):
        """Read ``<data_path>/<split>.txt`` as tab-separated ``text, public, sensitive`` rows.

        The first line is skipped (treated as a header); rows without exactly
        three fields are discarded.
        """
        path = os.path.join(self.data_path, '{}.txt'.format(split))
        with open(path, 'r') as handle:
            body = handle.readlines()[1:]
        rows = [fields for fields in (raw.rstrip('\n').split('\t') for raw in body)
                if len(fields) == 3]
        return {
            'texts': [text for text, _, _ in rows],
            'sensitive_label': [int(sensitive) for _, _, sensitive in rows],
            'public_label': [int(public) for _, public, _ in rows],
        }


class TrustPilotReader:
    """Reader for the Trustpilot data.

    Main (public) label: sentiment. It is stored 1-based in the files
    (presumably 1..5, given ``number_public_label`` = 5) and shifted to 0-based here.
    Sensitive label: age or gender, selected by ``private_label``.
    """

    def __init__(self, private_label):
        """
        Args:
            private_label: ``'age'`` or ``'gender'``; selects the data directory.
        """
        assert private_label in ['age', 'gender']
        self.max_length = 140  # 5% truncate
        self.number_public_label = 5
        self.number_sensitive_label = 2
        # NOTE(review): the 'blog_' prefix under data/trust_pilot looks copied from
        # another reader — confirm it matches the directory on disk.
        self.data_path = 'data/trust_pilot/blog_{}_processed'.format(private_label)

    def get_split(self, split):
        """Read ``<data_path>/<split>.txt`` as tab-separated ``text, public, sensitive`` rows.

        ``'val'`` is read from ``valid.txt``. The first line is skipped and rows
        without exactly three fields are discarded.
        """
        # The validation file on disk is named 'valid', not 'val'.
        file_stem = 'valid' if split == 'val' else split
        path = os.path.join(self.data_path, '{}.txt'.format(file_stem))
        with open(path, 'r') as handle:
            # NOTE(review): this header skip was previously annotated "small bug";
            # if a file has no header row its first example is lost — verify the data.
            body = handle.readlines()[1:]
        rows = [fields for fields in (raw.rstrip('\n').split('\t') for raw in body)
                if len(fields) == 3]
        return {
            'texts': [text for text, _, _ in rows],
            'sensitive_label': [int(sensitive) for _, _, sensitive in rows],
            'public_label': [int(public) - 1 for _, public, _ in rows],
        }


class DialReader:
    """Reader for the DIAL data.

    Main (public) label: mention or sentiment, selected by ``main_label``.
    Sensitive label: race.
    """

    def __init__(self, main_label):
        """
        Args:
            main_label: ``'mention'`` or ``'sentiment'``; selects the
                ``data/dial/dial_<label>_race_processed`` directory.
        """
        assert main_label in ['mention', 'sentiment']
        self.data_path = 'data/dial/dial_{}_race_processed'.format(main_label)
        self.number_sensitive_label = 2
        self.number_public_label = 2
        self.max_length = 65

    def get_split(self, split):
        """Parse ``<data_path>/x_<split>`` as tab-separated ``text, public, sensitive`` rows.

        The first line is skipped (treated as a header) and rows that do not have
        exactly three fields are dropped.

        Returns:
            dict with ``'texts'`` (list of str) plus integer ``'sensitive_label'``
            and ``'public_label'`` lists, aligned by index.
        """
        with open(os.path.join(self.data_path, 'x_{}'.format(split)), 'r') as file:
            lines = file.readlines()
        lines = [line.replace('\n', '').split('\t') for line in lines][1:]
        lines = [line for line in lines if len(line) == 3]
        texts = [line[0] for line in lines]
        sensitive_label = [int(line[-1]) for line in lines]
        public_label = [int(line[1]) for line in lines]
        return {'texts': texts, 'sensitive_label': sensitive_label, 'public_label': public_label}


# Registry mapping an ``args.dataset`` name to its reader instance.
# Built at import time; the reader constructors only set paths and constants,
# so no file is read until ``get_split`` is called.
reader = dict(
    blog_age=BlogReader('age'),
    blog_gender=BlogReader('gender'),
    pan_age=PanReader('age'),
    pan_gender=PanReader('gender'),
    dial_mention=DialReader('mention'),
    dial_sentiment=DialReader('sentiment'),
    trust_age=TrustPilotReader('age'),
    trust_gender=TrustPilotReader('gender'),
    bio=BioReader(),
)
if __name__ == '__main__':
    # Smoke test: parse the DIAL sentiment training split from disk.
    # Reuse the registry entry under a distinct name; rebinding ``reader`` here
    # would replace the module-level registry that FairClassificationDataset uses.
    dial_reader = reader['dial_sentiment']
    dial_reader.get_split('train')
