import argparse
import json
import os
import random
from tqdm import tqdm
from datetime import datetime
import time

import sys


from openai import OpenAI
# Module-level OpenAI client used by hit_openai(). It is constructed at import
# time, so any configuration/credential error raised by OpenAI() surfaces when
# this module is imported. (Presumably reads OPENAI_API_KEY from the
# environment — TODO confirm against the SDK version in use.)
client = OpenAI()

def generate_tasks(data, max_source_texts=None):
    """Build one transfer task per (source author, target author, source text).

    Args:
        data: mapping with 'source_authors' and 'target_authors' entries, each
            a mapping from author name to a sliceable sequence of texts.
        max_source_texts: if not None, only the first ``max_source_texts``
            texts of each source author are used, both as transfer inputs and
            as that author's 'source_author_texts'.

    Returns:
        A list of dicts with keys 'source_author', 'source_text',
        'source_author_texts', 'target_author_texts', 'target_author'.
        Tasks are ordered by source author, then target author, then source
        text. Every source author is paired with every target author; no
        pair is skipped, even when the two names coincide.
    """
    tasks = []
    for src_name, src_pool in data['source_authors'].items():
        # Truncate once per source author, so every task for this author
        # shares the same (possibly shortened) list object.
        pool = src_pool if max_source_texts is None else src_pool[:max_source_texts]
        tasks.extend(
            dict(
                source_author=src_name,
                source_text=text,
                source_author_texts=pool,
                target_author_texts=tgt_pool,
                target_author=tgt_name,
            )
            for tgt_name, tgt_pool in data['target_authors'].items()
            for text in pool
        )
    return tasks
       
def load_raw_authorship_data(data_path):
    """Load and return the parsed contents of the JSON file at ``data_path``.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    locale encoding. JSON text is UTF-8, and author texts routinely contain
    non-ASCII characters that would otherwise fail to decode or be mis-decoded
    on non-UTF-8 locales.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def hit_openai(message, model_name):
    """Send one user message to a chat model in JSON mode and return the parsed reply.

    Args:
        message: the user prompt text.
        model_name: chat model identifier passed to ``chat.completions.create``.

    Returns:
        The value produced by ``json.loads`` on the first choice's reply text
        (normally a dict, since JSON mode requests a JSON object).

    Raises:
        ValueError: if the reply carries no text content; the model name and
            finish reason are included in the message to aid debugging.
        json.JSONDecodeError: if the reply text is not valid JSON (e.g. the
            output was cut off by the token limit).
    """
    response = client.chat.completions.create(
        model=model_name,
        # JSON mode; the system prompt below explicitly asks for JSON, which
        # JSON mode expects the conversation to do.
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": "You are a helpful assistant designed to output JSON."},
            {"role": "user", "content": message},
        ],
    )

    choice = response.choices[0]
    content = choice.message.content
    if content is None:
        # message.content is nullable in the API (e.g. on refusals); passing
        # None to json.loads would only raise an uninformative TypeError.
        raise ValueError(
            f"model {model_name!r} returned no content "
            f"(finish_reason={choice.finish_reason!r})"
        )
    return json.loads(content)

def do_prompted_transfer(*, original_text, examples, model_name):
    """Ask the model to rewrite ``original_text`` in the style of ``examples``.

    The prompt lists each example as a ``{"text": ...}`` JSON line, then asks
    for a rewrite of ``original_text``, which is given in the same JSON form.

    Args:
        original_text: the text to be restyled.
        examples: iterable of texts by the target author.
        model_name: chat model identifier forwarded to ``hit_openai``.

    Returns:
        Whatever ``hit_openai`` returns for the prompt (the parsed JSON reply).

    NOTE(review): ``json.dumps`` uses its default ``ensure_ascii=True``, so
    non-ASCII characters (accents, emoji) reach the model as ``\\uXXXX``
    escapes. The prompt also names no output key, so the key holding the
    rewrite in the returned object is up to the model. Both are kept as-is to
    preserve existing experiment behaviour; revisit if that is not intended.
    """
    parts = ['The following comments are written by a single author: \n']
    parts.extend(json.dumps({'text': text}) + '\n' for text in examples)
    parts.append("\n\nCan you rewrite the following comment to make it look like the above author's style:\n")
    parts.append(json.dumps({'text': original_text}) + '\n')

    return hit_openai(''.join(parts), model_name=model_name)


# example usage:
# python hit_openai.py --out_dir /path/to/output --data /path/to/data.json --max_examples 16 --max_source_texts 10

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='')
    # out_dir and data are required: without them the run would create a
    # folder literally named "None/<timestamp>" or fail later in open(None).
    parser.add_argument('--out_dir', type=str, required=True,
                        help='Parent directory; a timestamped run folder is created inside it.')
    parser.add_argument('--data', type=str, required=True,
                        help="JSON file with 'source_authors' and 'target_authors' mappings.")
    parser.add_argument('--max_examples', type=int, default=16,
                        help='Maximum number of target-author texts included in each prompt.')
    parser.add_argument('--max_source_texts', type=int, default=None,
                        help='If set, only the first N texts of each source author are transferred.')
    parser.add_argument('--model_name', type=str, default="gpt-3.5-turbo-0125")  # alternative: "gpt-4-turbo"

    hparams = vars(parser.parse_args())

    # One fresh folder per run; exist_ok=False refuses to reuse the folder of
    # a run started within the same second instead of overwriting its output.
    dtime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    task_folder = os.path.join(hparams['out_dir'], dtime)
    os.makedirs(task_folder, exist_ok=False)

    # Record the exact CLI configuration alongside the outputs.
    with open(os.path.join(task_folder, "hparams.json"), 'w', encoding='utf-8') as f:
        json.dump(hparams, f)

    data = load_raw_authorship_data(hparams['data'])
    tasks = generate_tasks(data, max_source_texts=hparams['max_source_texts'])

    with open(os.path.join(task_folder, "style.jsonl"), 'w', encoding='utf-8') as out:
        for task in tqdm(tasks):
            target_texts = task['target_author_texts'][: hparams['max_examples']]
            task['output'] = do_prompted_transfer(
                original_text=task['source_text'],
                examples=target_texts,
                model_name=hparams['model_name'],
            )
            out.write(json.dumps(task) + '\n')
            # Each record costs an API call; flush so finished results are
            # on disk even if the process is killed mid-run.
            out.flush()
