import os
import sys
import time
import jsonlines
from argparse import Namespace

from pprint import pprint
from tqdm import tqdm

from collections import defaultdict
import copy
from copy import deepcopy
import random

from utils import *
from model import Model



def single_run(llm, stage, recording, config, round):
    """Run one pipeline stage of ``llm`` over every sample, then save a snapshot.

    Each sample dict in ``recording`` is mutated in place: every key/value pair
    returned by ``llm.predict(sample)`` is merged into it. A sample that already
    contains the experiment-name key is skipped, so an interrupted run can be
    resumed with the same ``recording``. If prediction raises, the error message
    is stored under the experiment-name key instead; note this means the sample
    counts as "done" and is NOT retried by a later call for the same stage/round.

    After the loop the whole ``recording`` list is written as JSON to
    ``<recording_root>/<model_name>/<dataset_name>/<exp_name>-<start>-<end>.json``.

    Args:
        llm: Model wrapper exposing ``refresh_stage(cur_stage=..., cur_round=...)``
            and ``predict(sample)`` returning a mapping of new fields.
        stage: Pipeline stage name.
        recording: List of per-sample dicts; each must have an ``'id'`` key.
        config: Namespace with ``model_name``, ``dataset_name``, ``start_index``
            and ``end_index``. An optional ``recording_root`` attribute overrides
            the default output root directory.
        round: Iteration round. Only the iterative stages include it in the
            experiment name. (The parameter name shadows the builtin but is kept
            because callers pass it by keyword.)
    """
    # Imported explicitly: previously ``json`` was only reachable through
    # ``from utils import *``.
    import json

    # Initialization of the LLM wrapper for this stage/round.
    llm.refresh_stage(cur_stage=stage, cur_round=round)

    # Iterative stages run once per round, so their results carry a round prefix;
    # one-off stages are keyed by the stage name alone.
    if stage in ('Contrast-Responses-Merge-Memory', 'Regeneration-w-Suggestion'):
        exp_name = f'{round}-{stage}'
    else:
        exp_name = stage

    for sample in tqdm(recording):
        if exp_name in sample:
            print(f'{exp_name} already done for the {sample["id"]}-th sample')
            continue

        try:
            sample.update(llm.predict(sample))
        except Exception as e:
            # Best-effort: record the failure on the sample and continue with the rest.
            sample[exp_name] = str(e)
            print(f'Error at {sample["id"]}-th sample: {str(e)}', file=sys.stderr)

    # Save the current recording list (all samples, including earlier stages' fields).
    recording_root = getattr(
        config, 'recording_root', '/ossfs/workspace/Faithful-COT-Logic/recording'
    )
    recording_path = os.path.join(recording_root, config.model_name, config.dataset_name)
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(recording_path, exist_ok=True)
    output_file = os.path.join(
        recording_path, f'{exp_name}-{config.start_index}-{config.end_index}.json'
    )
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(recording, f, indent=4)

def complete_run(llm, recording, config, total_iter_rounds):
    """Run the whole regeneration pipeline over ``recording``.

    Stage schedule:
      * round 0: ``Initial-Regeneration``, followed by ``get_cur_major_vote``;
      * for each step ``i`` in ``1..total_iter_rounds``:
        ``Contrast-Responses-Merge-Memory`` at round ``2*i - 1``,
        ``Regeneration-w-Suggestion`` at round ``2*i``,
        then ``get_cur_major_vote`` again.

    ``recording`` is mutated in place by every stage, and each ``single_run``
    call saves its own JSON snapshot. Any exception aborts the remaining stages.
    It is reported on stderr together with its full traceback rather than
    propagated, so the snapshots written by already-completed stages are kept.

    Args:
        llm: Model wrapper passed through to ``single_run``.
        recording: List of per-sample dicts, updated in place.
        config: Run configuration passed through to each stage.
        total_iter_rounds: Number of contrast + regeneration iteration steps.
    """
    import traceback

    try:
        single_run(llm=llm, stage='Initial-Regeneration', recording=recording, config=config, round=0)
        get_cur_major_vote(weight_method='average', recording=recording, config=config)
        for iter_step_id in range(1, total_iter_rounds + 1):
            # Each iteration step occupies two consecutive rounds: odd = contrast, even = regenerate.
            single_run(llm=llm, stage='Contrast-Responses-Merge-Memory', recording=recording,
                       config=config, round=2 * iter_step_id - 1)
            single_run(llm=llm, stage='Regeneration-w-Suggestion', recording=recording,
                       config=config, round=2 * iter_step_id)
            get_cur_major_vote(weight_method='average', recording=recording, config=config)
    except Exception as e:
        # Top-level boundary: keep the run best-effort, but don't lose the stack trace.
        print(f'Error: {str(e)}', file=sys.stderr)
        traceback.print_exc()




if __name__ == "__main__":

    # Run configuration. NOTE(review): all values are hard-coded here; edit them
    # (or switch to argparse) to change dataset, model, or the sample window.
    config = Namespace()
    config.dataset_name = 'GSM8K'
    config.split = 'test'
    config.model_name = 'LLAMA3-8B'
    config.start_index = 0
    config.end_index = 768

    # Dataset configuration
    dataset_frn = f"data/{config.dataset_name}/{config.split}.jsonl"
    dataset = load_data(dataset_frn)

    print(f'Dataset: {config.dataset_name}, Length: {len(dataset)}')

    # Path of the initial responses
    initial_pred_directory = f'Initial-Generation-List/{config.model_name}/{config.dataset_name}'
    initial_pred_path = os.path.join(initial_pred_directory, 'output.jsonl')

    # Read and pair the initial responses with the dataset examples (by position).
    initial_generation_list = read_jsonl_as_list(initial_pred_path)
    if len(initial_generation_list) != len(dataset):
        # zip() below silently truncates to the shorter input; make that visible.
        print(f'Warning: dataset has {len(dataset)} samples but {initial_pred_path} has '
              f'{len(initial_generation_list)}; only the first '
              f'{min(len(dataset), len(initial_generation_list))} pairs will be used.',
              file=sys.stderr)
    recording_list = [
        {
            'id': example['id'],
            'question': example['question'],
            'response': generation['completion'],
            'response-answer': generation['answer'],
        }
        for example, generation in zip(dataset, initial_generation_list)
    ]

    # Use recording_list to keep track of all the intermediate results.
    # Keys at this point: 'id', 'question', 'response' (initial response),
    # 'response-answer' (answer of the initial response).
    print(f'size of initial prediction: {len(recording_list)}')

    # Model initialization
    llm = Model(config, cur_stage='Prepare-Model')

    model_path = '/root/Meta-Llama-3-8B-Instruct'

    if "llama" in model_path.lower():
        llm.prepare_model(model_path)

    # Running. Slice first, then copy: only the selected window is deep-copied,
    # which yields the same independent sample dicts with less work.
    recording = deepcopy(recording_list[config.start_index:config.end_index])
    print(f'size of current run: {len(recording)}')

    complete_run(llm, recording, config, total_iter_rounds=9)
