import glob
import torch
import numpy as np
import pandas as pd
import argparse
import os
from pathlib import Path
import sys
import tqdm

# Put the directory one level above this file's own directory on the import
# path (appended last, so it does not take precedence over existing entries).
_this_file = os.path.abspath(__file__)
parent_dir = os.path.dirname(os.path.dirname(_this_file))
sys.path.append(parent_dir)

def parse_arguments():
    """Parse the command-line options that select which results to consolidate.

    Every option defaults to the empty string, which the caller can treat as
    "not specified". ``--metric``, ``--datatype`` and ``--setting`` are
    restricted to a fixed set of choices when given; ``--base_dir`` is free
    text.

    Returns:
        argparse.Namespace with ``metric``, ``datatype``, ``setting`` and
        ``base_dir`` string attributes.
    """
    # (flag, allowed values, help text) for the options that share a shape.
    choice_options = (
        ('--metric', ['accuracy', 'crossentropy', 'brierscore', 'other'], 'Metric to use for evaluation'),
        ('--datatype', ['language', 'vision'], 'Type of data to process'),
        ('--setting', ['all', 'supervised', 'unsupervised'], 'Predefined groups or not'),
    )

    parser = argparse.ArgumentParser()
    for flag, allowed, description in choice_options:
        parser.add_argument(flag, type=str, default='', choices=allowed, help=description)
    parser.add_argument('--base_dir', type=str, default='', help='Base directory for results')
    return parser.parse_args()

def get_files(base_dir, datatype, metric, setting):
    """Return the ``.pt`` result files for one configuration.

    Files are looked up under ``{base_dir}/{datatype}/{metric}/<sub>/``.
    ``<sub>`` is the given ``setting``, except for ``"all"``, where every
    sub-directory of the metric directory is searched.

    Args:
        base_dir: Root directory of the results tree.
        datatype: Sub-directory name below ``base_dir``.
        metric: Sub-directory name below ``datatype``.
        setting: Setting sub-directory name, or ``"all"`` for every one.

    Returns:
        List of matching path strings (possibly empty), in ``glob`` order.
    """
    # glob is not called with recursive=True, so "**" acts like a plain "*"
    # and matches exactly one directory level (the setting directory).
    subdir_pattern = "**" if setting == "all" else setting
    return glob.glob(f"{base_dir}/{datatype}/{metric}/{subdir_pattern}/*.pt")

def collect_estimates(df):
    """Flatten the per-iteration estimates into a list of row dicts.

    Args:
        df: Mapping of iteration key -> record, where ``record['estimates']``
            maps a method name to a ``{group: estimate}`` mapping.

    Returns:
        One dict per (iteration, method, group) with the keys ``'method'``,
        ``'group'``, ``'estimate'`` and ``'iter'``, in iteration order.
    """
    rows = []
    for iteration in df:
        per_method = df[iteration]['estimates']
        for method, per_group in per_method.items():
            rows.extend(
                {'method': method, 'group': group, 'estimate': per_group[group], 'iter': iteration}
                for group in per_group
            )
    return rows

def collect_ground_truth(df):
    """Flatten the per-iteration ground-truth values into a list of row dicts.

    Args:
        df: Mapping of iteration key -> record, where ``record['gt']`` maps a
            group to its ground-truth value.

    Returns:
        One dict per (iteration, group) with the keys ``'group'``, ``'gt'``
        and ``'iter'``, in iteration order.
    """
    rows = []
    for iteration in df:
        rows.extend(
            {'group': group, 'gt': truth, 'iter': iteration}
            for group, truth in df[iteration]['gt'].items()
        )
    return rows

def collect_counts(df, count_type):
    """Build per-group count rows from the record stored under key ``0``.

    Only ``df[0]`` is read; records under other keys are not consulted.

    Args:
        df: Indexable container whose element ``0`` holds a mapping under
            ``count_type``.
        count_type: Key of the ``{group: count}`` mapping to read, e.g.
            ``'train_counts'``.

    Returns:
        One dict per group with the keys ``'group'`` and
        ``f'{count_type}_count'``.
    """
    count_column = f'{count_type}_count'
    return [
        {'group': group, count_column: count}
        for group, count in df[0][count_type].items()
    ]

def process_file(filepath):
    """Load one results file and flatten it into a single DataFrame.

    The estimates, ground-truth values and train counts stored in the file
    are each turned into a table and inner-joined. The result is tagged with
    the dataset, model and setting derived from the file's location.

    Args:
        filepath: ``pathlib.Path`` of a results file. Its stem must look like
            ``"<dataset>-<model>"`` (split at the first ``-``), and its parent
            directory name is used as the setting.

    Returns:
        pandas.DataFrame with the merged rows plus ``'dataset'``, ``'model'``
        and ``'setting'`` columns.

    Raises:
        ValueError: If the file stem contains no ``-``.
    """
    # NOTE(review): torch.load unpickles the file contents, which can execute
    # arbitrary code (depending on the torch version's ``weights_only``
    # default) -- only point this at trusted result files.
    results = torch.load(filepath)

    estimates = pd.DataFrame(collect_estimates(results))
    ground_truth = pd.DataFrame(collect_ground_truth(results))
    train_counts = pd.DataFrame(collect_counts(results, 'train_counts'))

    # No join keys are given, so pandas joins on the columns the two frames
    # share. The per-group 'counts' table is currently not merged in.
    merged = (
        estimates
        .merge(ground_truth, how='inner')
        .merge(train_counts, how='inner')
    )

    dataset, model = filepath.stem.split('-', 1)
    return merged.assign(
        dataset=dataset,
        model=model,
        setting=filepath.parts[-2],
    )

def main():
    """Consolidate result files into one CSV per (base dir, datatype, metric).

    Command-line options left empty are expanded to a default sweep: every
    directory under ``results/estimates``, both datatypes, all four metrics,
    and the ``'all'`` setting. For each combination the matching ``.pt``
    files are processed and concatenated, then written to
    ``results/consolidated/<base>_<datatype>_<metric>_detailed_results.csv``.

    NOTE(review): the output file name does not include the setting, so runs
    with different ``--setting`` values for the same base dir, datatype and
    metric write to the same path.
    """
    args = parse_arguments()

    save_dir = Path("results/consolidated")
    save_dir.mkdir(parents=True, exist_ok=True)

    # An empty option means "not specified": fall back to the full sweep.
    if args.base_dir:
        base_dirs = [args.base_dir]
    else:
        base_dirs = glob.glob('results/estimates/*')
    datatypes = ['language', 'vision'] if not args.datatype else [args.datatype]
    metrics = ['accuracy', 'crossentropy', 'brierscore', 'other'] if not args.metric else [args.metric]
    settings = ['all'] if not args.setting else [args.setting]

    # Lazily enumerate every combination, in the nested-loop order
    # base_dir -> datatype -> metric -> setting.
    combinations = (
        (base_dir, datatype, metric, setting)
        for base_dir in base_dirs
        for datatype in datatypes
        for metric in metrics
        for setting in settings
    )

    for base_dir, datatype, metric, setting in combinations:
        print(f"Processing: Metric={metric}, Datatype={datatype}, Setting={setting}, Base directory={base_dir}")
        files = get_files(base_dir, datatype, metric, setting)

        if not files:
            print(f"No files found for: Metric={metric}, Datatype={datatype}, Setting={setting}, Base directory={base_dir}")
            continue

        frames = [process_file(Path(path)) for path in tqdm.tqdm(files)]
        if not frames:
            continue

        combined = pd.concat(frames)
        config_name = f"{Path(base_dir).name}_{datatype}_{metric}"
        output_file = save_dir / f"{config_name}_detailed_results.csv"
        combined.to_csv(output_file, index=False)

# Script entry point: run the consolidation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()

