import argparse
import glob
import os
import numpy as np
import pandas as pd
import torch
import tqdm
from sklearn.linear_model import LassoCV
from statsmodels.stats.proportion import proportion_confint
from sklearn.metrics import silhouette_score
from sklearn.cluster import KMeans
import sys

import warnings
from sklearn.exceptions import ConvergenceWarning
warnings.simplefilter("ignore", ConvergenceWarning)
warnings.filterwarnings("ignore")

parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from prinfer import DataBalancer, DataProcessor, Estimator
from helper import *
from concurrent.futures import ProcessPoolExecutor, as_completed


def parse_arguments():
    """Parse the command-line options of the estimation script.

    Returns:
        argparse.Namespace with the attributes ``metric``, ``datatype``,
        ``mingroupsize``, ``sampling``, ``loco``, ``compute_cis`` and
        ``n_iterations``.
    """
    # (flag, type, default, allowed values, help text) for every option that
    # is restricted to a fixed set of choices.
    choice_options = (
        ('--metric', str, 'accuracy',
         ['accuracy', 'crossentropy', 'brierscore', 'other'],
         'Metric to use for evaluation'),
        ('--datatype', str, 'language',
         ['language', 'vision'],
         'Type of data to process'),
        ('--mingroupsize', int, 20,
         [5, 10, 15, 20, 25, 30, 50, 100],
         'Minimum group size'),
        ('--sampling', str, 'equal',
         ['equal', 'prop'],
         'Sampling method'),
        ('--loco', str, 'all-in',
         ['all-in', 'exclude-embeddings', 'exclude-modeldummy', 'exclude-taskdummy'],
         'Covariates to be excluded'),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default, allowed, help_text in choice_options:
        parser.add_argument(flag, type=value_type, default=default,
                            choices=allowed, help=help_text)

    # Free-form options (no fixed choice list).
    parser.add_argument('--compute_cis', action='store_true',
                        help='Compute confidence intervals')
    parser.add_argument('--n_iterations', type=int, default=100,
                        help='Number of iterations')

    return parser.parse_args()


def get_files(datatype):
    """Return the prediction files to process for ``datatype``.

    Files are globbed from ``data/predictions/**/*.pt`` relative to the
    current working directory and kept only when their path contains the
    marker substring associated with the data type.

    Args:
        datatype: ``"vision"`` or ``"language"``.

    Returns:
        List of matching file paths.

    Raises:
        ValueError: if ``datatype`` is neither ``"vision"`` nor ``"language"``.
    """
    if datatype not in ("vision", "language"):
        raise ValueError("Data type not defined!")

    # Substring a path must contain to be selected for each data type.
    # Phi-3 contains all other LLMs as well
    marker = 'allmodels' if datatype == "vision" else 'Phi-3'

    candidates = glob.glob("data/predictions/**/*.pt")
    return [path for path in candidates if marker in path]


def process_iteration(iter, dr, mingroupsize, sampling, unique_groups, compute_cis, loco):
    """Run one subsample-and-estimate round and collect the per-group results.

    Args:
        iter: Iteration index; it is also used as the subsampling seed.
        dr: Balancer object providing ``subsample_proportional`` /
            ``subsample_allequal`` and, afterwards, the attributes
            ``X_train``, ``groups_train``, ``z_train`` and ``pz_train``.
        mingroupsize: Forwarded to the subsampler as ``min_group_size``.
        sampling: ``'prop'`` selects proportional subsampling; any other
            value selects the all-equal subsampler.
        unique_groups: Group labels for which the ``'Baseline-pz'`` estimate
            is computed.
        compute_cis: When truthy, 95% Wilson intervals are stored for the
            ``"DT"`` estimates; the flag is also forwarded to
            ``Estimator.estimate_eb``.
        loco: Forwarded unchanged to ``expand_data``.

    Returns:
        dict with the keys ``"estimates"`` (estimates keyed by method name),
        ``"cis"`` (intervals keyed by method name, empty unless
        ``compute_cis``), ``"train_counts"`` (``est.n_groups``) and
        ``"counts"`` (number of distinct training groups).
    """
    # Draw this iteration's training subsample, seeded by the iteration index.
    draw_subsample = dr.subsample_proportional if sampling == 'prop' else dr.subsample_allequal
    draw_subsample(min_group_size=mingroupsize, seed=iter)

    X_train, z_train, groups_train, pz_train = expand_data(
        dr.X_train, dr.z_train, dr.groups_train, dr.pz_train, loco
    )

    n_groups = len(np.unique(groups_train))

    # The number of CV folds is capped by the number of distinct groups.
    regmodel = LassoCV(
        alphas=np.logspace(-20, 20, 20),
        cv=min(10, n_groups),
        precompute=True,
        fit_intercept=True,
    )
    est = Estimator(X=X_train, z=z_train, groups=groups_train, reg=regmodel)

    # NOTE: the estimator calls below are kept in their original order in case
    # later attributes (e.g. ``preds_ebreg``) are populated by earlier calls.
    estimates = {"DT": est.z}

    cis = {}
    if compute_cis:
        from statsmodels.stats.proportion import proportion_confint
        # NOTE(review): only "DT" intervals are stored here, while
        # process_file also looks for an "EB" entry in this dict - confirm
        # where the EB intervals are supposed to come from.
        wilson_intervals = {}
        for g in est.z:
            wilson_intervals[g] = proportion_confint(
                est.n_groups[g] * est.z[g], est.n_groups[g], alpha=0.05, method='wilson'
            )
        cis["DT"] = wilson_intervals

    estimates["GM"] = est.estimate_gm()
    estimates["JS"] = est.estimate_js()
    estimates["SR"] = est.preds
    # estimates["STR"] = est.estimate_structured_regression()
    baseline = {}
    for g in unique_groups:
        baseline[g] = pz_train[groups_train == g].mean()
    estimates['Baseline-pz'] = baseline
    estimates['EB'] = est.estimate_eb(reg=regmodel, compute_cis=compute_cis)
    estimates['EBreg'] = est.preds_ebreg

    result = {
        "estimates": estimates,
        "cis": cis,
        "train_counts": est.n_groups,
        "counts": n_groups,
    }
    return result
    

def construct_save_dir(loco, sampling, mingroupsize, datatype, metric, setting):
    """Build, create and return the results directory for one configuration.

    The directory is
    ``<parent_dir>/results/estimates/<prefix><datatype>/<metric>/<setting>``,
    where ``<prefix>`` concatenates, in this order, one piece for every
    option that differs from its baseline value:

    * ``loco-<loco>_`` when ``loco != 'all-in'``
    * ``sampling-<sampling>_`` when ``sampling != 'equal'``
    * ``mingroupsize-<mingroupsize>/`` when ``mingroupsize != 5``

    Args:
        loco: Covariate-exclusion setting.
        sampling: Sampling method.
        mingroupsize: Minimum group size.
        datatype: Data type path component.
        metric: Metric path component.
        setting: Final path component (e.g. the supervision setting).

    Returns:
        The directory path; it exists when the function returns.

    Raises:
        OSError: if the directory cannot be created.
    """
    prefix_parts = []
    if loco != 'all-in':
        prefix_parts.append(f"loco-{loco}_")
    if sampling != 'equal':
        prefix_parts.append(f"sampling-{sampling}_")
    if mingroupsize != 5:
        prefix_parts.append(f"mingroupsize-{mingroupsize}/")

    relative_dir = "results/estimates/" + "".join(prefix_parts) + f"{datatype}/{metric}/{setting}"
    save_dir = os.path.join(parent_dir, relative_dir)

    # exist_ok avoids the check-then-create race of a separate
    # os.path.exists() test when several runs start at the same time.
    os.makedirs(save_dir, exist_ok=True)
    return save_dir
    

def preprocess_data(dp, metric, filepath):
    """Format the data, estimate ``z`` for ``metric`` and make sure groups exist.

    Args:
        dp: Data processor object; its ``format_data`` method, one of its
            ``estimate_*`` methods and one of its grouping methods are called.
        metric: ``"accuracy"``, ``"brierscore"`` or ``"crossentropy"``; any
            other value falls back to ``dp.estimate_general_metric``.
        filepath: Path of the file being processed; only used to choose the
            cluster-count range when groups have to be created.
    """
    dp.format_data()

    # Name of the z-estimation routine for each known metric.
    estimator_by_metric = {
        "accuracy": "estimate_z_binacc",
        "brierscore": "estimate_z_brierscore",
        "crossentropy": "estimate_z_crossentropy",
    }
    estimator_name = estimator_by_metric.get(metric, "estimate_general_metric")
    estimate_z = getattr(dp, estimator_name)
    estimate_z()

    # Groups supplied with the data only need their dimensions concatenated.
    if dp.groups is not None:
        dp.concatenate_group_dimensions()
        return

    # No groups supplied: create them by clustering. Files whose path
    # contains 'anli' use exactly 10 clusters, all others between 30 and 100.
    if 'anli' in filepath:
        min_clusters, max_clusters = 10, 10
    else:
        min_clusters, max_clusters = 30, 100
    dp.create_groups(clustering_algorithm=create_groups_with_kmeans,
                     min_clusters=min_clusters, max_clusters=max_clusters)


def _should_skip_file(filepath, metric):
    """Return True when ``filepath`` is not evaluated under ``metric``.

    The decision only looks at substrings of the path, so it can be taken
    before the (potentially large) file is loaded.
    """
    if metric in ('accuracy', 'brierscore', 'crossentropy'):
        # qasper and lambada seem to have problems
        blocked = ('lambada', 'qasper', 'trivia', '-generation', 'squad')
        return any(sub in filepath for sub in blocked)
    # Every other metric is only evaluated on these datasets.
    required = ('squad', 'bench-generation', 'trivia')
    return not any(sub in filepath for sub in required)


def process_file(filepath, args):
    """Run the full estimation experiment for one prediction file.

    Steps: load the file, build the inputs for the chosen metric, preprocess
    and filter by group size, run ``args.n_iterations`` subsampling
    iterations via ``process_iteration``, save the collected results with
    ``torch.save`` and print/save MSE, CI-coverage and CI-width summaries.

    The file is skipped (the function returns ``None`` without saving
    anything) when its path is excluded for the metric, when its tensors
    cannot be built, when fewer than 5 groups remain after filtering, or when
    no iteration was requested.

    Args:
        filepath: Path to a ``.pt`` file; the last two ``/``-separated
            components are used to name the output files.
        args: Namespace with ``metric``, ``datatype``, ``mingroupsize``,
            ``sampling``, ``loco``, ``compute_cis`` and ``n_iterations``.

    Returns:
        None.
    """
    metric = args.metric
    datatype = args.datatype
    mingroupsize = args.mingroupsize
    sampling = args.sampling
    loco = args.loco
    compute_cis = args.compute_cis
    n_iterations = args.n_iterations

    print(filepath)

    # Decide from the path alone before paying for torch.load.
    if _should_skip_file(filepath, metric):
        return

    # NOTE(review): torch.load unpickles the file; only use it on trusted
    # prediction files.
    df = torch.load(filepath)

    excluded_paths = ["math", "anli"]

    try:
        if metric in ('accuracy', 'brierscore', 'crossentropy'):
            y = torch.tensor(df["Target"])
            X = torch.tensor(df["Features"])
            pp = torch.tensor(df['AuxiliaryPredictionProb'])
            z = None
        else:
            y, X, pp, z = [torch.tensor(df[key]) for key in ["Metric", "Features", "Metric", "Metric"]]
    except Exception as exc:
        # Best effort: a malformed file is skipped, but the reason is
        # reported instead of being silently swallowed.
        print(f"Skipping {filepath}: could not build input tensors ({exc!r})")
        return

    if "Strata" in df.keys() and len(df['Strata']) > 0 and not any(path in filepath for path in excluded_paths):
        groups = np.array(df["Strata"])
        setting = "supervised"
        dp = DataProcessor(X=X, y=y, pp=pp, groups=groups, z=z)
    else:
        setting = "unsupervised"
        dp = DataProcessor(X=X, y=y, pp=pp, z=z)

    save_dir = construct_save_dir(loco, sampling, mingroupsize, datatype, metric, setting)
    preprocess_data(dp, args.metric, filepath)

    minsize = mingroupsize * 2 if sampling == 'prop' else mingroupsize * 4
    dp.filter_by_groupsize(min_group_size=minsize)

    X, groups, z, pz = dp.X, dp.groups, dp.z, dp.pz

    # if too few groups, skip
    if len(np.unique(groups)) < 5:
        return

    dr = DataBalancer(X=X, groups=groups, z=z, pz=pz)

    _, zall, groupsall, _ = expand_data(X, z, groups, pz, loco)

    # compute ground truth: the mean of z within each group of the full data
    unique_groups = np.unique(groupsall)
    gt = {g: np.mean(zall[groupsall == g]) for g in unique_groups}

    store = {}
    for iteration in range(n_iterations):
        store[iteration] = process_iteration(iteration, dr, mingroupsize, sampling, unique_groups, compute_cis, loco)
        store[iteration]['gt'] = gt

    if not store:
        # n_iterations <= 0: nothing to save or summarise.
        print(f"Skipping {filepath}: no iterations were run")
        return

    # Output files are named after the last two path components of the input.
    base_name = "-".join(filepath.split("/")[-2:])
    # Only the file extension is swapped when deriving the CSV names.
    stem = os.path.splitext(base_name)[0]

    # Save results from one evaluation (points estimates + CIs)
    first = store[0]
    if len(first['cis']) > 0:
        data = []
        for approach in ('DT', 'EB'):
            approach_cis = first['cis'].get(approach)
            if approach_cis is None:
                # process_iteration may not provide intervals for every
                # approach; export only the ones that are present.
                continue
            for g, values in approach_cis.items():
                data.append({
                    'group': g,
                    'gt': gt[g],
                    'lb': values[0],
                    'ub': values[1],
                    'method': approach,
                    'estimate': first['estimates'][approach][g]
                })

        df_iter = pd.DataFrame(data)
        iter_save_dir = save_dir.replace('estimates/', 'iterdata/')
        os.makedirs(iter_save_dir, exist_ok=True)
        filename_iter = os.path.join(iter_save_dir, f"{stem}.csv")
        df_iter.to_csv(filename_iter, index=False)

    torch.save(store, os.path.join(save_dir, base_name))

    # Print MSEs
    mses = {method: [] for method in first['estimates']}
    ci_methods = first['cis']
    coverage = {method: {g: [] for g in unique_groups} for method in ci_methods}
    width = {method: {g: [] for g in unique_groups} for method in ci_methods}

    for values in store.values():
        # gt is the same object in every iteration
        for method, innerval in values['estimates'].items():
            mses[method].append(np.mean([(innerval[g] - gt[g]) ** 2 for g in unique_groups]))
        for method, innerval in values['cis'].items():
            for g in unique_groups:
                lb, ub = innerval[g]
                coverage[method][g].append(1 if lb <= gt[g] <= ub else 0)
                width[method][g].append(ub - lb)

    for method, values in mses.items():
        print(f"Method: {method} w/ MSE: {np.mean(values):.2f}")
    for method, values in coverage.items():
        print(f"Method: {method} w/ coverage: {np.mean([np.mean(val) for val in values.values()]):.2f}")
    for method, values in width.items():
        print(f"Method: {method} w/ width: {np.mean([np.mean(val) for val in values.values()]):.2f}")

    # transform the dictionaries into dataframes (one column per method)
    width_df = pd.DataFrame.from_dict({method: {g: np.mean(val) for g, val in values.items()} for method, values in width.items()}, orient='index').T
    coverage_df = pd.DataFrame.from_dict({method: {g: np.mean(val) for g, val in values.items()} for method, values in coverage.items()}, orient='index').T

    # Ensure the directories exist
    coverage_save_dir = save_dir.replace('estimates/', 'cis/')
    os.makedirs(coverage_save_dir, exist_ok=True)

    # Create paths and save them
    coverage_filepath = os.path.join(coverage_save_dir, f"{stem}_coverage_.csv")
    width_filepath = os.path.join(coverage_save_dir, f"{stem}_width_.csv")

    if not coverage_df.empty:
        coverage_df.to_csv(coverage_filepath, index=False)

    if not width_df.empty:
        width_df.to_csv(width_filepath, index=False)
        
def main():
    """Parse the CLI options and process every matching prediction file.

    Files are handled one after another with a tqdm progress bar. The
    ProcessPoolExecutor import at the top of the file is not used here.
    """
    args = parse_arguments()
    for filepath in tqdm.tqdm(get_files(args.datatype)):
        process_file(filepath, args)


if __name__ == "__main__":
    main()