import glob
import os
import pickle
import subprocess

import numpy as np
import pandas as pd

# Question identifiers, in the order the disabled per-question scoring code in
# main() pairs them with rows of the survey CSV.
# NOTE(review): these look like ANES feeling-thermometer variable names --
# confirm against anes2020_pilot_prompt_probing.csv.
questions = [
    'fttrump1',
    'ftobama1',
    'ftbiden1',
    'ftwarren1',
    'ftsanders1',
    'ftbuttigieg1',
    'ftharris1',
    'ftklobuchar1',
    'ftpence1',
    'ftyang1',
    'ftpelosi1',
    'ftrubio1',
    'ftocasioc1',
    'fthaley1',
    'ftthomas1',
    'ftfauci1',
    'ftblack',
    'ftwhite',
    'fthisp',
    'ftasian',
    'ftillegal',
    'ftfeminists',
    'ftmetoo',
    'fttransppl',
    'ftsocialists',
    'ftcapitalists',
    'ftbigbusiness',
    'ftlaborunions',
    'ftrepublicanparty',
    'ftdemocraticparty',
]


def _latest_checkpoint(model_path):
    """Return the checkpoint entry under *model_path* with the highest number.

    An entry counts as a checkpoint when its basename contains ``checkpoint``
    and the text after its last ``-`` parses as an integer (for example
    ``checkpoint-500``). Entries that do not match are skipped.

    Args:
        model_path: Directory whose direct children are searched.

    Returns:
        The path (as returned by ``glob``) of the highest-numbered checkpoint.

    Raises:
        FileNotFoundError: If no matching entry exists under *model_path*.
    """
    numbered = []
    for path in glob.glob(f'{model_path}/*'):
        name = os.path.basename(path)
        # Filter on the basename only: a parent directory that happens to
        # contain 'checkpoint' must not pull unrelated files into the list.
        if 'checkpoint' not in name:
            continue
        try:
            step = int(name.split('-')[-1])
        except ValueError:
            continue
        numbered.append((step, path))
    if not numbered:
        raise FileNotFoundError(
            f'No numbered checkpoint directories found under {model_path!r}')
    # Keeping (step, path) pairs together guarantees the chosen path is the
    # one the step number was parsed from.
    return max(numbered, key=lambda pair: pair[0])[1]


def _run_command(command, env=None):
    """Echo and run *command* without a shell, returning its exit status.

    Args:
        command: Argument list, program name first.
        env: Optional environment mapping for the child process; ``None``
            inherits the current environment.

    Returns:
        The child's exit status. A non-zero status is reported on stdout but
        does not raise, so the pipeline stays best-effort.
    """
    print(' '.join(command))
    returncode = subprocess.run(command, env=env).returncode
    if returncode != 0:
        print(f'WARNING: command exited with status {returncode}: {command[1]}')
    return returncode


def main(args):
    """Generate community opinions from the latest checkpoint and score them.

    Steps:
      1. Locate the highest-numbered checkpoint of the fine-tuned model.
      2. Run ``generate_community_opinion.py`` for 5 seeds x 4 prompt formats.
      3. Run ``compute_group_stance.py`` over the generated outputs.
      4. Load the survey CSV and the predictions CSV and compute the weighted
         survey scores. The code that would compare them is currently
         disabled (see the NOTE near the end of this function).

    Args:
        args: Namespace providing ``mp`` (0 selects the model trained without
            message passing, anything else the ``_mp_`` variant),
            ``model_name``, ``community``, ``n_workers`` and ``gpu`` (value
            for ``CUDA_VISIBLE_DEVICES``; empty leaves the environment as is).

    Raises:
        FileNotFoundError: If no checkpoint is found, or if one of the input
            files read at the end is missing.
    """
    if args.mp == 0:
        print(f'Evaluating on community {args.community} without message passing.\n')
        base_path = f'finetuned_{args.model_name}_community_{args.community}_all'
    else:
        print(f'Evaluating on community {args.community} with message passing.\n')
        base_path = f'finetuned_{args.model_name}_community_mp_{args.community}_all'
    model_path = f'../train_clm/models/{base_path}'
    output_path = f'output/{base_path}'
    ckp_path = _latest_checkpoint(model_path)

    # `os.system('export ...')` only changes a throwaway subshell, so the GPU
    # selection is handed to the generation process through its environment.
    # An empty --gpu keeps the inherited environment untouched.
    generation_env = None
    if args.gpu:
        generation_env = dict(os.environ, CUDA_VISIBLE_DEVICES=args.gpu)

    # generate responses
    for run in [1, 2, 3, 4, 5]:
        for prompt in ['Prompt2', 'Prompt3', 'Prompt4', 'Prompt1']:
            print(f'Generating run {run} {prompt}')
            _run_command(
                [
                    'python', 'generate_community_opinion.py',
                    f'--model_path={ckp_path}',
                    '--prompt_data_path=anes2020_pilot_prompt_probing.csv',
                    f'--prompt_option={prompt}',
                    f'--output_path={output_path}/run_{run}',
                    f'--seed={run}',
                ],
                env=generation_env,
            )

    # compute sentiment scores
    _run_command([
        'python', 'compute_group_stance.py',
        f'--data_folder={output_path}',
        '--anes_csv_file=anes2020_pilot_prompt_probing.csv',
        f'--output_filename={output_path}/{base_path}_predictions.csv',
        f'--n_workers={args.n_workers}',
    ])

    # compute weighted average of survey results, serving as ground truths
    # NOTE(security): pickle can execute arbitrary code on load; only use a
    # comm2political_all.pkl that this project produced itself.
    with open('comm2political_all.pkl', 'rb') as pkl_file:
        comm2political = pickle.load(pkl_file)
    # presumably (democrat share, republican share) per community -- confirm
    # against the script that writes the pickle.
    dem_percent, rep_percent = comm2political[args.community][0], comm2political[args.community][1]
    df_survey_results = pd.read_csv('anes2020_pilot_prompt_probing.csv')
    df_weighted_score = dem_percent * df_survey_results['Democrat'] + rep_percent * df_survey_results['Republican']
    weighted_scores = df_weighted_score.values

    # run the evaluation for each run and prompt
    df_output = pd.read_csv(f'{output_path}/{base_path}_predictions.csv')

    # NOTE: everything below is disabled, so `weighted_scores`, `df_output`
    # and `rows` are currently computed but never consumed. The first
    # commented block additionally needs `from scipy import stats`.
    rows = []
    # for run in [1, 2, 3, 4, 5]:
    #     run = f"run_{run}"
    #     for prompt_format in [1, 2, 3, 4]:
    #         prompt_format = "Prompt{}".format(prompt_format)
    #         df_ = df_output[(df_output['run'] == run) & (df_output['prompt_format'] == prompt_format)]
    #         pred_scores = df_['group_sentiment'].values
    #         corr, p_val = stats.pearsonr(weighted_scores, pred_scores)
    #         rows.append([args.community, run, prompt_format, corr, p_val])

    # for prompt_format in [1, 2, 3, 4]:
    #     prompt_format = "Prompt{}".format(prompt_format)
    #     for i, question in enumerate(questions):
    #         df_ = df_output[(df_output['prompt_format'] == prompt_format) & (df_output['question'] == question)]
    #         pred_score = df_['group_sentiment'].values.mean()
    #         gt_score = weighted_scores[i]
    #         pred_score = round(pred_score, 2)
    #         gt_score = round(gt_score, 2)
    #         rows.append([args.mp, args.community, prompt_format, question, pred_score, gt_score])
    #
    # df_scores = pd.DataFrame(rows, columns=['mp', 'community', 'prompt_format', 'question', 'pred_score', 'gt_score'])
    #
    # if os.path.exists(f'df_scores_{args.model_name}_all.csv'):
    #     df_scores_ = pd.read_csv(f'df_scores_{args.model_name}_all.csv')
    #     df_scores = pd.concat([df_scores_, df_scores])
    #
    # df_scores.to_csv(f'df_scores_{args.model_name}_all.csv', index=False)

    # # os.makedirs('results', exist_ok=True)
    # if args.mp == 0:
    #     df_scores.to_csv(f'{output_path}/df_results_{args.community}.csv', index=False)
    # else:
    #     df_scores.to_csv(f'{output_path}/df_results_mp_{args.community}.csv', index=False)


if __name__ == '__main__':
    import argparse

    # (flag, value type, default) for every command-line option; main() reads
    # each of them as an attribute of the parsed namespace.
    option_specs = (
        ('--model_name', str, 'gpt2'),
        ('--community', int, 0),
        ('--mp', int, 0),
        ('--n_workers', int, 8),
        ('--gpu', str, ''),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, default_value in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value)
    args = cli.parse_args()

    main(args)
