import pathlib
from typing import Optional
from loguru import logger
import IPython
import pandas as pd
# Root of the KG outputs: presumably one sub-folder per language model, each
# named ``..._<lm_name>`` and containing a ``Prompts.csv``.
base = pathlib.Path('/nas/home/qasemi/Mowgli-CoreQuisite/outputs/KG')

# Raise a real exception rather than ``assert``, which ``python -O`` strips.
if not base.is_dir():
    raise NotADirectoryError(f'KG output folder not found: {base}')

# Combined table: one column per LM, indexed by (subject, object, type).
# The first LM loaded fixes the row index; later LMs are aligned onto it, so
# their rows absent from that index are dropped and gaps become NaN.
df_data: Optional[pd.DataFrame] = None


# ``iterdir`` yields entries in arbitrary order; sort so the LM that fixes the
# row index (and the resulting column order) is reproducible between runs.
for lm_fold in sorted(base.iterdir()):
    # Skip plain files and Jupyter autosave folders.
    if not lm_fold.is_dir() or '.ipynb_checkpoints' in lm_fold.name:
        continue
    lm_name = lm_fold.name.split('_')[-1]
    samples: pd.DataFrame = (
        pd.read_csv(
            lm_fold/'Prompts.csv', sep=',',
            usecols=['object', 'subject', '0', 'Unnamed: 7']
        )
        # Column '0' is this LM's per-prompt value (presumably its answer or
        # score -- confirm); 'Unnamed: 7' is the header-less 8th column,
        # used here as the prompt type.
        .rename({'0': lm_name, 'Unnamed: 7': 'type'}, axis=1)
        .set_index(keys=['subject', 'object', 'type'])
    )
    # Keep the first row per key: index alignment below needs unique labels.
    samples = samples[~samples.index.duplicated()]
    logger.info(f'read {len(samples)} rows of data for {lm_name}')
    if df_data is None:
        df_data = samples
    else:
        if lm_name in df_data.columns:
            logger.warning(
                f'duplicate LM name {lm_name!r} from {lm_fold.name}; '
                'overwriting the earlier column'
            )
        # Assign the single column as a Series so it aligns on the index.
        df_data[lm_name] = samples[lm_name]

if df_data is None:
    logger.warning(f'no LM folders found under {base}')


def process_row(row: pd.Series) -> pd.Series:
    """Invert a row: group the row's index labels by the value each holds.

    Presumably applied to ``df_data`` row-wise (one column per LM -- confirm),
    so each output entry maps a distinct value to the LMs that produced it.

    Args:
        row: Series whose index labels are grouped by their values.

    Returns:
        Series indexed by each distinct non-missing value, in first-seen
        order, whose entries are lists of the labels holding that value.
        All labels with a missing value (NaN/None) are collected under a
        single trailing NaN key, instead of producing one empty group per
        missing cell (NaN never compares equal to itself).
    """
    groups: dict = {}
    missing: list = []
    # Single pass over (label, value); ``tolist`` yields plain Python scalars.
    for label, value in zip(row.index, row.tolist()):
        if pd.isna(value):
            missing.append(label)
        else:
            groups.setdefault(value, []).append(label)
    if missing:
        groups[float('nan')] = missing
    logger.info(groups)
    return pd.Series(groups)

if __name__ == '__main__':
    # Open an interactive IPython shell with this module's names (e.g.
    # ``df_data``, ``process_row``) in scope. Guarded so that merely
    # importing the module does not start an interactive session.
    IPython.embed()