import os
import pandas as pd
import torch.utils.data as util_data
from torch.utils.data import Dataset

# Raw label string (as read from the CSV ``label`` column, cast to str) ->
# integer class id.  Ids run 0..9.  Most "R"-prefixed labels share the id of
# their base label; "R276" is the exception and has its own id.
# NOTE(review): the R276 asymmetry looks deliberate (otherwise id 1 would be
# unused) — confirm against the dataset definition.
label2id = {
    "276": 0,
    "R276": 1,
    "31": 2,
    "R31": 2,
    "17": 3,
    "R17": 3,
    "47": 4,
    "R47": 4,
    "161": 5,
    "36": 6,
    "R36": 6,
    "57": 7,
    "40": 8,
    "R40": 8,
    "463": 9,
}

class TextClustering(Dataset):
    """Map-style dataset that pairs each text with its label.

    ``train_x`` and ``train_y`` are indexed in lockstep; item ``i`` is the
    dict ``{'text': train_x[i], 'label': train_y[i]}``.
    """

    def __init__(self, train_x, train_y):
        # The two sequences must be index-aligned.
        assert len(train_x) == len(train_y)
        self.train_x, self.train_y = train_x, train_y

    def __len__(self):
        n_items = len(self.train_x)
        return n_items

    def __getitem__(self, idx):
        text = self.train_x[idx]
        label = self.train_y[idx]
        return dict(text=text, label=label)

class AugmentPairSamples(Dataset):
    """Dataset of an original text, three augmented variants of it, and a label.

    All five sequences are indexed in lockstep.  Item ``i`` is a dict with
    keys ``text``, ``text1``, ``text2``, ``text3`` (in that order) followed
    by ``label``.
    """

    def __init__(self, train_x, train_x1, train_x2, train_x3, train_y):
        # Every sequence must have the same length as the labels.
        assert len({len(train_y), len(train_x), len(train_x1), len(train_x2), len(train_x3)}) == 1
        self.train_x, self.train_x1 = train_x, train_x1
        self.train_x2, self.train_x3 = train_x2, train_x3
        self.train_y = train_y

    def __len__(self):
        # Length is taken from the labels; the constructor guarantees all agree.
        n_items = len(self.train_y)
        return n_items

    def __getitem__(self, idx):
        text_columns = (self.train_x, self.train_x1, self.train_x2, self.train_x3)
        item = {key: column[idx]
                for key, column in zip(('text', 'text1', 'text2', 'text3'), text_columns)}
        item['label'] = self.train_y[idx]
        return item


def augment_loader(args, num_workers=4):
    """Build a shuffled DataLoader over original and augmented texts.

    Reads the CSV at ``args.data_path`` and uses the columns ``tokens``,
    ``bt_aug_tokens``, ``t5_aug_tokens``, ``swap_aug_tokens_v2`` and
    ``label``.  Missing texts are replaced with ``'.'``.  Labels are cast to
    ``str`` and mapped to integer ids through the module-level ``label2id``.
    The label distribution is printed to stdout.

    Args:
        args: namespace providing ``data_path`` and ``batch_size``.
        num_workers: number of DataLoader worker processes (default 4, the
            previously hard-coded value).

    Returns:
        A ``torch.utils.data.DataLoader`` over an ``AugmentPairSamples``
        dataset, with ``shuffle=True``.

    Raises:
        ValueError: if any ``label`` value has no entry in ``label2id``.
    """
    train_data = pd.read_csv(args.data_path)
    train_text = train_data['tokens'].fillna('.').values
    train_text1 = train_data['bt_aug_tokens'].fillna('.').values
    train_text2 = train_data['t5_aug_tokens'].fillna('.').values
    train_text3 = train_data['swap_aug_tokens_v2'].fillna('.').values

    train_data['int_label'] = train_data['label'].astype(str).map(label2id)
    # Series.map yields NaN for labels missing from label2id; without this check
    # the astype(int) below fails with an opaque "non-finite values" error.
    unknown = train_data.loc[train_data['int_label'].isna(), 'label'].astype(str).unique()
    if len(unknown):
        raise ValueError(f"labels not found in label2id: {sorted(unknown.tolist())}")
    train_label = train_data['int_label'].astype(int).values

    print("train rels distribution:")
    print(train_data['int_label'].value_counts())

    train_dataset = AugmentPairSamples(train_text, train_text1, train_text2, train_text3, train_label)
    train_loader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=num_workers)
    return train_loader

def train_unshuffle_loader(args, frac=0.2, num_workers=0):
    """Build an unshuffled DataLoader over a random subset of the data.

    Reads the CSV at ``args.data_path`` and maps the ``label`` column (cast to
    ``str``) to integer ids through the module-level ``label2id``.  It then
    draws a ``frac`` sample of rows with ``random_state=args.seed``, so the
    subset is reproducible for a given seed.  The sample size and its label
    distribution are printed to stdout.  Missing ``tokens`` become ``'.'``.

    Args:
        args: namespace providing ``data_path``, ``seed`` and ``batch_size``.
        frac: fraction of rows to sample (default 0.2, the previously
            hard-coded value).
        num_workers: number of DataLoader worker processes (default 0).

    Returns:
        A ``torch.utils.data.DataLoader`` over a ``TextClustering`` dataset,
        with ``shuffle=False``.

    Raises:
        ValueError: if any sampled row has a ``label`` with no entry in
            ``label2id``.
    """
    train_data = pd.read_csv(args.data_path)
    train_data['int_label'] = train_data['label'].astype(str).map(label2id)
    sample_data = train_data.sample(frac=frac, random_state=args.seed)
    print(len(sample_data))

    # Series.map yields NaN for labels missing from label2id.  Only the sampled
    # rows are checked, so this fails in exactly the cases the astype(int)
    # below would have, but with a readable message.
    unknown = sample_data.loc[sample_data['int_label'].isna(), 'label'].astype(str).unique()
    if len(unknown):
        raise ValueError(f"labels not found in label2id: {sorted(unknown.tolist())}")

    train_text = sample_data['tokens'].fillna('.').values
    train_label = sample_data['int_label'].astype(int).values
    print("valid rels distribution:")
    print(sample_data['int_label'].value_counts())

    train_dataset = TextClustering(train_text, train_label)
    train_loader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False,
                                        num_workers=num_workers)
    return train_loader

