# Erik McGuire, 2021

import ast
import re
import sys

import numpy as np
import pandas as pd

from zuco_params import *

def debug(x):
    """Print a value, or each member of a tuple, then stop the script.

    Every printed item is padded with blank lines.  The function never
    returns normally: it always finishes by calling ``sys.exit()``.
    """
    # Only an exact ``tuple`` is unpacked; any other object (lists and
    # tuple subclasses included) is printed as a single item.
    items = x if type(x) is tuple else (x,)
    for item in items:
        print(f'\n{item}\n')
    sys.exit()

def shuffle_scores(df, t: str = 'agg'):
    """Randomly permute the existing scores among the rows of ``df``.

    The frame is modified in place and is also returned.  The RNG is
    re-seeded with ``args.seed`` on every call, so the same input yields
    the same permutation.

    Args:
        df: Frame with a ``piece_score`` column (when ``t == 'avg'``) or a
            ``score`` column (for any other ``t``).
        t: ``'avg'`` shuffles every ``piece_score`` value; any other value
            shuffles only the non-zero ``score`` entries and leaves the
            zero entries where they are.

    Returns:
        The same ``df`` object with the selected column permuted.
    """
    # NOTE(review): `random` and `args` are not defined in this file; they
    # presumably arrive through the `zuco_params` star-import -- confirm.
    random.seed(args.seed)
    if t == 'avg':
        shuffled = df.piece_score.values.tolist()
        random.shuffle(shuffled)
        df.piece_score = shuffled
    else:
        # Build the row filter once and write back with a single masked
        # assignment.  A label-by-label `df.loc[ix, ...]` loop is slow and
        # would overwrite every row sharing a duplicated index label.
        nonzero = df.score != 0
        shuffled = df.loc[nonzero, 'score'].values.tolist()
        random.shuffle(shuffled)
        if shuffled:  # nothing to write when every score is zero
            df.loc[nonzero, 'score'] = shuffled
    return df

def shuffle_scores_b(df, t: str = 'agg'):
    """Replace scores with uniform random draws from their observed range.

    Each selected score is replaced by ``random.uniform(lo, hi)``, where
    ``lo`` and ``hi`` are the minimum and maximum of the selected scores.
    The frame is modified in place and is also returned.  The RNG is
    re-seeded with ``args.seed`` on every call.

    Args:
        df: Frame with a ``piece_score`` column (when ``t == 'avg'``) or a
            ``score`` column (for any other ``t``).
        t: ``'avg'`` resamples every ``piece_score`` value; any other value
            resamples only the non-zero ``score`` entries.

    Returns:
        The same ``df`` object.  It is returned unchanged when there are no
        scores to resample.
    """
    # NOTE(review): `random` and `args` are not defined in this file; they
    # presumably arrive through the `zuco_params` star-import -- confirm.
    random.seed(args.seed)
    use_piece_scores = t == 'avg'
    if use_piece_scores:
        observed = df.piece_score.values.tolist()
    else:
        # Computed once and reused for the write-back below.
        nonzero = df.score != 0
        observed = df.loc[nonzero, 'score'].values.tolist()

    # min()/max() raise ValueError on an empty list; with no scores there
    # is simply nothing to resample.
    if not observed:
        return df

    lo, hi = min(observed), max(observed)
    sampled = [random.uniform(lo, hi) for _ in observed]
    if use_piece_scores:
        df.piece_score = sampled
    else:
        # One masked assignment instead of a per-label `df.loc` loop.
        df.loc[nonzero, 'score'] = sampled
    return df

def get_new_vals(df, mod, t):
    """Parse a column of stringified arrays into float64 NumPy arrays.

    The column read depends on the arguments: ``'avgvals'`` when
    ``t == 'avg'``, ``'summedvals'`` when ``t == 'sum'``, otherwise
    ``'eegvals'`` when ``mod == 'eeg'`` and ``'etvals'`` for any other
    ``mod``.

    Args:
        df: Frame holding the column selected above.  Cells are expected to
            be whitespace-separated, bracketed number lists (possibly
            spanning several lines).
        mod: Modality name.  ``'gaze'`` gives a ``(1, 4)`` zero placeholder
            for missing cells; anything else gives ``(1, 104)``.
        t: Aggregation tag, see the column selection above.

    Returns:
        A 1-D object array with one float64 array per row.  Cells that are
        a float NaN or whose text contains ``"nan"`` become all-zero
        placeholders.

    Raises:
        ValueError, SyntaxError: If a cell is not a plain numeric literal.
    """
    if t in ('sum', 'avg'):
        vals = 'avgvals' if t == 'avg' else 'summedvals'
    else:
        vals = 'eegvals' if mod == 'eeg' else 'etvals'

    missing_shape = (1, 4) if mod == "gaze" else (1, 104)
    column = df[vals].values
    # Explicit object buffer: each slot holds a whole array.
    new_vals = np.empty(len(column), dtype=object)
    for ix, val in enumerate(column):
        # An empty cell comes back as a float NaN, which has no `in`
        # operator; treat it like a textual "nan" entry.
        if isinstance(val, float) or "nan" in val:
            v = np.zeros(shape=missing_shape)
        else:
            # Turn "[[1. 2.\n 3.]]" into "[[1., 2., 3.]]".
            text = val.replace("\n", " ")
            text = re.sub(r" +", r", ", text).replace("[,", "[")
            # The text comes from a data file, so parse it as a literal
            # rather than executing it with eval().
            v = ast.literal_eval(text)
        new_vals[ix] = np.array(v, dtype=np.float64)
    return new_vals

def get_new_vals_hot(df, mod, t):
    """Parse the ``label_one_hot`` column into int64 NumPy arrays.

    Args:
        df: Frame with a ``label_one_hot`` column.  Cells are expected to
            be whitespace-separated, bracketed integer lists (possibly
            spanning several lines).
        mod: Unused; kept so the signature matches ``get_new_vals``.
        t: Unused; kept so the signature matches ``get_new_vals``.

    Returns:
        A 1-D object array with one int64 array per row.  Cells that are a
        float NaN or whose text contains ``"nan"`` become ``(1, 11)``
        all-zero placeholders.

    Raises:
        ValueError, SyntaxError: If a cell is not a plain numeric literal.
    """
    column = df['label_one_hot'].values
    # Explicit object buffer: each slot holds a whole array.
    new_vals = np.empty(len(column), dtype=object)
    for ix, val in enumerate(column):
        # isinstance also catches float subclasses such as np.float64.
        if isinstance(val, float) or "nan" in val:
            v = np.zeros(shape=(1, 11))
        else:
            # Turn "[[0 1\n 0]]" into "[[0, 1, 0]]".
            text = val.replace("\n", " ")
            text = re.sub(r" +", r", ", text).replace("[,", "[")
            # The text comes from a data file, so parse it as a literal
            # rather than executing it with eval().
            v = ast.literal_eval(text)
        new_vals[ix] = np.array(v, dtype=np.int64)
    return new_vals

def load_df(pth, mod, t: str = ""):
    """Read a tab-separated file and decode its stringified array columns.

    Args:
        pth: Location of the TSV file, as a string or path-like object.
            Whether its text contains ``"piece"`` decides which columns
            are decoded.
        mod: Modality name, forwarded to ``get_new_vals`` and
            ``get_new_vals_hot``.
        t: Aggregation tag, forwarded to ``get_new_vals``.

    Returns:
        The loaded frame with the decoded arrays written into column
        position 3 (no ``"piece"`` in the path) or positions 6 and 9
        (``"piece"`` in the path).
    """
    df = pd.read_csv(pth, sep="\t")
    # Presumably a stray index column left by an earlier `to_csv`; dropped
    # only when present.
    df = df.drop(columns='Unnamed: 0', errors='ignore')
    # str() keeps the substring test working for pathlib.Path arguments.
    # NOTE(review): positions 3, 6 and 9 are assumed to be the columns that
    # get_new_vals / get_new_vals_hot read by name -- confirm against the
    # file layout.
    if "piece" not in str(pth):
        df.iloc[:, 3] = get_new_vals(df, mod, t)
    else:
        df.iloc[:, 6] = get_new_vals(df, mod, t)
        df.iloc[:, 9] = get_new_vals_hot(df, mod, "pieced")
    return df
