import os
import sys
import torch
import logging

# Add the current directory and its parent to the import path (in that
# order), presumably so the local `tools` package imported below resolves
# whether this file is run from its own directory or from one level up.
sys.path.extend(['.', '../'])

# NOTE: getLogger() with no name returns the *root* logger.
logger = logging.getLogger()

from tools.utils import matchEmb, centeringFn
from tools.ioFn import readJsonl

def createData(data_infos: list):
    """
    Convert a list of matched records into (X, Y, doc_id) tensors.

    Each record is a dict that must provide:
      - 'embedding': a numeric sequence; all records must share the same
        shape because they are stacked along a new leading dimension,
      - 'is_oracle': a label convertible to float,
      - 'id': a string whose part before the first '_' parses as an int
        (used as the document id).

    Args:
        data_infos: list of such records.

    Returns:
        (x_datas, y_datas, doc_ids):
          x_datas: float32 tensor, the embeddings stacked along dim 0,
          y_datas: float32 tensor of shape (N,), the 'is_oracle' labels,
          doc_ids: int64 tensor of shape (N,), the parsed document ids.
        For an empty input, empty tensors of shapes (0, 0), (0,) and (0,)
        are returned.
    """
    if len(data_infos) == 0:
        # torch.stack raises on an empty list, and the embedding width is
        # unknown without at least one record, so return empty placeholders.
        return (
            torch.empty((0, 0), dtype=torch.float32),
            torch.empty(0, dtype=torch.float32),
            torch.empty(0, dtype=torch.long),
        )

    x_datas = torch.stack(
        [torch.tensor(info['embedding'], dtype=torch.float32) for info in data_infos],
        dim=0,
    )
    y_datas = torch.tensor([info['is_oracle'] for info in data_infos], dtype=torch.float32)
    # The document id is the integer prefix of the record id, e.g. '12_3' -> 12.
    doc_ids = torch.tensor([int(info['id'].split('_')[0]) for info in data_infos], dtype=torch.long)
    return x_datas, y_datas, doc_ids


def load_data(valid_size, input_dir, embfile, datafile, centering=False):
    """
    Load a JSONL dataset, split it into train/dev, attach embeddings and
    convert both splits to tensors.

    The last ``valid_size`` records of ``datafile`` form the dev set and the
    records before them form the train set.

    Args:
        valid_size: number of trailing records held out as the dev set.
            0 holds out nothing; a value >= the number of records puts every
            record in the dev set.
        input_dir: directory containing ``datafile``; also forwarded to
            ``matchEmb``.
        embfile: forwarded to ``matchEmb`` as ``embkey``.
        datafile: JSONL file name, relative to ``input_dir``.
        centering: if True, ``centeringFn`` is applied to the combined
            train + dev matches before they are split back apart.

    Returns:
        (train_xs, train_ys, train_doc_ids, dev_xs, dev_ys, dev_doc_ids),
        each triple as produced by ``createData``.

    Raises:
        ValueError: if ``valid_size`` is negative.
    """
    if valid_size < 0:
        raise ValueError("valid_size must be non-negative, got {}".format(valid_size))

    data_infos = readJsonl(os.path.join(input_dir, datafile))
    # Use an explicit split index: slicing with [:-valid_size] / [-valid_size:]
    # is wrong for valid_size == 0 (train would be empty, dev would be all).
    split_at = max(len(data_infos) - valid_size, 0)
    train_set = data_infos[:split_at]
    dev_set = data_infos[split_at:]

    train_matches = matchEmb(
        data_infos=train_set,
        input_dir=input_dir,
        embkey=embfile
    )
    num_train = len(train_matches)

    dev_matches = matchEmb(
        data_infos=dev_set,
        input_dir=input_dir,
        embkey=embfile
    )

    if centering:
        logger.warning("Centering is applied ...")
        # Center train and dev jointly, then split back by position; this
        # assumes centeringFn preserves length and order — verify in tools.utils.
        concat_matches = centeringFn(train_matches + dev_matches)
        train_matches = concat_matches[:num_train]
        dev_matches = concat_matches[num_train:]

    num_dev = len(dev_matches)
    logger.warning("Size of the train set: %s", num_train)
    logger.warning("Size of the dev set: %s", num_dev)

    train_xs, train_ys, train_doc_ids = createData(train_matches)
    dev_xs, dev_ys, dev_doc_ids = createData(dev_matches)
    return train_xs, train_ys, train_doc_ids, dev_xs, dev_ys, dev_doc_ids
