import copy
import ipdb
import json
import tqdm
import random
import numpy as np
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

def make_data_split_3way(args):
    """Load every sample and split it into train / validation / test sets.

    The split is done in two stages with ``train_test_split``: the test set
    is carved off first, then the validation set is carved off the remainder.
    Vectors, labels and texts are passed to the same call so they are always
    partitioned with the same indices.

    Args:
        args: Object exposing ``all_vectors`` and ``all_labels`` (paths handed
            to ``PreDataset``), ``test_ratio`` and ``val_ratio`` (split sizes)
            and ``seed`` (used as ``random_state`` for both stages).

    Returns:
        A ``(train, val, test)`` tuple of ``MainDataset`` objects, each
        carrying vectors, labels and texts.
    """
    all_data = PreDataset(args.all_vectors, args.all_labels)

    X, y, texts = [], [], []
    for ii in tqdm.tqdm(range(len(all_data)), desc='Making data splits'):
        vector, label, text = all_data[ii]
        X.append(vector)
        y.append(label)
        texts.append(text)

    # The validation split is taken from what is left after the test split,
    # so the ratio is rescaled relative to that remainder.
    val_ratio = args.val_ratio * 100 / ((1 - args.test_ratio) * 100)

    # Stage 1: hold out the test set. Splitting all three sequences in one
    # call guarantees they stay aligned without re-running the split.
    X_train, X_test, y_train, y_test, text_train, text_test = train_test_split(
        X, y, texts, test_size=args.test_ratio, random_state=args.seed)

    # Stage 2: hold out the validation set from the remaining samples.
    X_train, X_val, y_train, y_val, text_train, text_val = train_test_split(
        X_train, y_train, text_train, test_size=val_ratio, random_state=args.seed)

    train = MainDataset(X_train, y_train, text=text_train)
    val = MainDataset(X_val, y_val, text=text_val)
    test = MainDataset(X_test, y_test, text=text_test)

    return train, val, test

def _make_data_split_3way(args):
    """Shuffle sample indices with a fixed seed and cut them 80 / 10 / 10.

    Note that this reseeds the module-level ``random`` generator with
    ``args.seed`` as a side effect.

    Args:
        args: Object exposing ``all_vectors`` and ``all_labels`` (paths handed
            to ``PreDataset``) and ``seed``.

    Returns:
        A ``(train_indices, dev_indices, test_indices)`` tuple of lists of
        integer positions into the dataset.
    """
    all_data = PreDataset(args.all_vectors, args.all_labels)

    # range() already yields the indices in a fixed ascending order, so the
    # shuffle below is deterministic for a given seed.
    indices = list(range(len(all_data)))
    random.seed(args.seed)
    random.shuffle(indices)

    split_1 = int(0.8 * len(indices))
    split_2 = int(0.9 * len(indices))
    train_indices = indices[:split_1]
    dev_indices = indices[split_1:split_2]
    test_indices = indices[split_2:]

    return train_indices, dev_indices, test_indices

class PreDataset(object):
    """Raw samples read from a vectors file and a tab-separated labels file.

    Line ``i`` of the vectors file is paired with line ``i`` of the labels
    file. Vector lines are kept as raw strings and only parsed on access.
    """

    def __init__(self, vectors_path, labels_path):
        """Read both files fully into memory.

        Args:
            vectors_path: Path to the file holding one vector per line.
            labels_path: Path to the file whose lines are
                ``<int label>\\t<text>``.

        Raises:
            IndexError: If a labels line contains no tab character.
            ValueError: If the first field of a labels line is not an integer.
        """
        # Context managers make sure the file handles are closed promptly.
        with open(vectors_path) as vectors_file:
            self.data = vectors_file.readlines()
        with open(labels_path) as labels_file:
            label_lines = labels_file.readlines()

        rows = [line.split('\t') for line in label_lines]
        # The text field is kept exactly as read, so when it is the last
        # field on the line it still carries the trailing newline.
        self.text = [row[1] for row in rows]
        self.labels = [int(row[0]) for row in rows]

    def __getitem__(self, idx):
        """Return ``(vector, label, text)`` for the sample at ``idx``."""
        arr = convert_str_to_array(self.data[idx])
        label = self.labels[idx]
        text = self.text[idx]
        return arr, label, text

    def __len__(self):
        """Return the number of lines read from the vectors file."""
        return len(self.data)

class MainDataset(Dataset):
    """In-memory dataset of feature vectors, labels and optional texts."""

    def __init__(self, features, labels, text=None):
        """Store the given sequences as-is.

        Args:
            features: Indexable collection of feature vectors.
            labels: Indexable collection of labels, aligned with ``features``.
            text: Optional indexable collection of texts, aligned with
                ``features``. When omitted, items are returned without text.
        """
        self.data = features
        self.labels = labels
        self.text = text

    def __getitem__(self, idx):
        """Return ``(vector, label)``, plus the text when texts are stored."""
        sample = (self.data[idx], self.labels[idx])
        if self.text is not None:
            sample += (self.text[idx],)
        return sample

    def __len__(self):
        """Return the number of stored feature vectors."""
        return len(self.data)

    def label_counts(self):
        """Count labels equal to 1 and to 0.

        Returns:
            A ``(n_pos, n_neg)`` tuple; labels that are neither 0 nor 1 are
            ignored.
        """
        positives = 0
        negatives = 0
        for lbl in self.labels:
            if lbl == 0:
                negatives += 1
                continue
            if lbl == 1:
                positives += 1
        return positives, negatives


def make_data_split(args, test_size=0.06136):
    """Load every sample and split it into train and test sets.

    Args:
        args: Object exposing ``all_vectors`` and ``all_labels`` (paths handed
            to ``PreDataset``) and ``seed`` (used as ``random_state``).
        test_size: Value forwarded to ``train_test_split`` as ``test_size``.
            Defaults to the previously hard-coded ``0.06136``.

    Returns:
        A ``(train, test)`` tuple of ``MainDataset`` objects holding vectors
        and labels only (no text).
    """
    all_data = PreDataset(args.all_vectors, args.all_labels)

    X, y = [], []
    for ii in tqdm.tqdm(range(len(all_data)), desc='Making data splits'):
        # PreDataset items also carry the text; only the vector and the
        # label are needed for this split.
        vector, label = all_data[ii][:2]
        X.append(vector)
        y.append(label)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=args.seed)

    train = MainDataset(X_train, y_train)
    test = MainDataset(X_test, y_test)

    return train, test



def convert_str_to_array(text, expected_dim=768):
    """Parse one comma-separated line of numbers into a float32 array.

    Args:
        text: A line of comma-separated numeric fields; trailing newline
            characters are ignored.
        expected_dim: Required number of fields. Pass ``None`` to skip the
            length check. Defaults to 768.

    Returns:
        A one-dimensional ``numpy.ndarray`` of dtype ``float32``.

    Raises:
        ValueError: If the number of fields differs from ``expected_dim`` or
            a field cannot be converted to a float.
    """
    values = text.rstrip('\n').split(',')
    # Explicit check rather than ``assert`` so it still runs under ``-O``.
    if expected_dim is not None and len(values) != expected_dim:
        raise ValueError(
            'expected %d comma-separated values, got %d'
            % (expected_dim, len(values)))
    return np.array(values, dtype=np.float32)

