import argparse
import json
import os
import random
from tqdm import tqdm
from datetime import datetime

from openai import OpenAI
# Module-level OpenAI client shared by hit_openai(). It is constructed at
# import time, so simply importing this module instantiates it.
# NOTE(review): the SDK presumably picks up its API key from the environment
# (OPENAI_API_KEY) and fails here if none is configured — confirm.
client = OpenAI()

def load_style_tasks(path, target_label):
    """
    Load one style-transfer task per line of a text file.

    Args:
        path: Path to a text file containing one source text per line.
        target_label: Style label the texts should be rewritten into; stored
            as ``target_author`` on every task.

    Returns:
        A list of task dicts, one per line in file order, each with keys
        ``source_text`` (the line with surrounding whitespace stripped),
        ``source_paraphrase`` (a placeholder ``['']``), ``target_author``
        and ``file_path``.
    """
    data = []
    # Read explicitly as UTF-8 so non-ASCII input text does not depend on the
    # platform's locale default encoding.
    with open(path, encoding='utf-8') as f:
        for line in f:
            data.append({
                'source_text': line.strip(),
                'source_paraphrase': [''],
                'target_author': target_label,
                'file_path': path,
            })
    return data

def generate_tasks(data, max_source_texts=None):
    """
    Expand an authorship dataset into (source text, target author) tasks.

    Args:
        data: Mapping with ``'source_authors'`` and ``'target_authors'``, each
            mapping an author name to a list of that author's texts.
        max_source_texts: If not None, only the first ``max_source_texts``
            texts of each source author are used.

    Returns:
        A list of task dicts ordered by source author, then target author,
        then source text. Each dict has keys ``source_author``,
        ``source_text``, ``source_author_texts`` (the possibly truncated list),
        ``target_author_texts`` and ``target_author``.
    """
    def _limited(texts):
        if max_source_texts is None:
            return texts
        return texts[:max_source_texts]

    per_author = {
        author: _limited(texts)
        for author, texts in data['source_authors'].items()
    }

    return [
        {
            'source_author': src_author,
            'source_text': text,
            'source_author_texts': src_texts,
            'target_author_texts': tgt_texts,
            'target_author': tgt_author,
        }
        for src_author, src_texts in per_author.items()
        for tgt_author, tgt_texts in data['target_authors'].items()
        for text in src_texts
    ]
       
def load_raw_authorship_data(data_path):
    """
    Load and return the parsed contents of a JSON file.

    Args:
        data_path: Path to the JSON file.

    Returns:
        The deserialized JSON value, as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON exchanged between systems must be UTF-8 (RFC 8259); don't depend on
    # the platform's locale default encoding.
    with open(data_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def hit_openai(message, model_name):
    """
    Send one prompt to the OpenAI chat completions API in JSON mode.

    Uses the module-level ``client``. A system message instructs the model to
    output JSON, and ``response_format`` requests a JSON object.

    Args:
        message: User prompt text.
        model_name: Chat model identifier passed to the API.

    Returns:
        The first choice's message content, parsed with ``json.loads``.

    Raises:
        ValueError: If the response was cut off at the token limit (its JSON is
            then likely incomplete) or carries no text content.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    response = client.chat.completions.create(
        model=model_name,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": "You are a helpful assistant designed to output JSON."},
            {"role": "user", "content": message},
        ],
    )

    choice = response.choices[0]
    # In JSON mode a generation stopped by the token limit can return a
    # truncated, unparseable JSON string; report that explicitly instead of
    # surfacing a confusing JSONDecodeError.
    if choice.finish_reason == "length":
        raise ValueError(
            f"OpenAI response for model {model_name!r} was truncated at the token limit"
        )

    result = choice.message.content
    # content can be None (e.g. a refusal); json.loads(None) would raise an
    # unhelpful TypeError.
    if result is None:
        raise ValueError(f"OpenAI response for model {model_name!r} contained no content")

    return json.loads(result)

def do_prompted_transfer(*, original_text, examples, model_name, target_style):
    """
    Ask the model to rewrite ``original_text`` in ``target_style`` using few-shot examples.

    The prompt lists each example as a JSON object ``{"text": ...}`` on its
    own line. It then gives the rewrite instruction, followed by the text to
    rewrite, which is also JSON-encoded.

    Args:
        original_text: Text to rewrite.
        examples: Iterable of example texts written in the target style.
        model_name: Chat model identifier forwarded to ``hit_openai``.
        target_style: Style label interpolated into the prompt (e.g. ``'formal'``).

    Returns:
        Whatever ``hit_openai`` returns for the prompt (its parsed JSON response).
    """
    # JSON-encoding each text escapes embedded quotes/newlines, so every
    # example stays on exactly one line of the prompt.
    example_block = ''.join(json.dumps({'text': text}) + '\n' for text in examples)
    message = (
        f'The following texts are written in {target_style} style: \n'
        + example_block
        + f"\n\nCan you rewrite the following text to make it look like the above {target_style} style:\n"
        + json.dumps({'text': original_text})
        + '\n'
    )
    return hit_openai(message, model_name=model_name)




def do_transfer(*, hparams, tasks, out_file_name, examples, output_dir=None):
    """
    Run prompted style transfer for every task and write results as JSON lines.

    For each task, a random subset of ``examples`` is drawn as few-shot
    demonstrations. The model output is stored on the task under ``'output'``
    (the task dict is mutated in place). The task is then written to the output
    file as one JSON object per line.

    Args:
        hparams: Dict providing ``'max_examples'`` (few-shot examples per
            prompt) and ``'model_name'``.
        tasks: List of task dicts with ``'source_text'`` and ``'target_author'``.
        out_file_name: Name of the JSON-lines file created in the output directory.
        examples: Sequence of target-style example texts to sample from. If it
            holds fewer than ``hparams['max_examples']`` items, all of them are
            used (in random order) for every prompt.
        output_dir: Directory for the output file. If omitted, falls back to the
            module-level ``task_folder`` set by the ``__main__`` block
            (the original behavior).

    Raises:
        NameError: If ``output_dir`` is omitted and no module-level
            ``task_folder`` exists (e.g. the module was imported, not run).
    """
    folder = task_folder if output_dir is None else output_dir
    # random.sample raises ValueError when k exceeds the population size, e.g.
    # when --path_max_examples or the examples file is smaller than
    # --max_examples; use every available example instead of crashing.
    num_examples = min(hparams['max_examples'], len(examples))

    with open(os.path.join(folder, out_file_name), 'w', encoding='utf-8') as out:
        for task in tqdm(tasks):
            target_texts = random.sample(examples, num_examples)
            task['output'] = do_prompted_transfer(
                original_text=task['source_text'],
                examples=target_texts,
                model_name=hparams['model_name'],
                target_style=task['target_author'],
            )
            out.write(json.dumps(task) + '\n')
            # Each line costs an API call; flush so completed results are not
            # lost in Python's buffer if the process is killed mid-run.
            out.flush()



if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Few-shot formal <-> informal style transfer with OpenAI chat models.'
    )
    # The output directory and all four input files are needed for a run.
    # Marking them required makes argparse report a missing one up front,
    # instead of failing later on open(None) or writing under a literal "None/" directory.
    parser.add_argument('--out_dir', type=str, required=True,
                        help='Directory under which a timestamped run folder is created.')
    parser.add_argument('--path_to_formal_input', type=str, required=True,
                        help='Text file of formal sentences (one per line) to make informal.')
    parser.add_argument('--path_to_informal_input', type=str, required=True,
                        help='Text file of informal sentences (one per line) to make formal.')
    parser.add_argument('--path_to_formal_examples', type=str, required=True,
                        help='Text file of formal example sentences used as few-shot demonstrations.')
    parser.add_argument('--path_to_informal_examples', type=str, required=True,
                        help='Text file of informal example sentences used as few-shot demonstrations.')
    parser.add_argument('--path_max_examples', type=int, default=128,
                        help='Maximum number of lines read from each examples file.')
    parser.add_argument('--max_examples', type=int, default=16,
                        help='Number of few-shot examples sampled per prompt.')
    parser.add_argument('--max_source_texts', type=int, default=None,
                        help='If set, only the first N input lines of each direction are transferred.')
    parser.add_argument('--model_name', type=str, default="gpt-3.5-turbo",
                        help='OpenAI chat model name (e.g. gpt-3.5-turbo-0125, gpt-4-turbo).')

    cmd_args = parser.parse_args()
    hparams = vars(cmd_args)
    out_dir = hparams['out_dir']

    # One timestamped folder per run; exist_ok=False so a run never writes
    # into an existing folder. do_transfer() reads this module-level name.
    dtime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    task_folder = os.path.join(out_dir, dtime)
    os.makedirs(task_folder, exist_ok=False)

    # Record the exact run configuration next to its outputs.
    with open(os.path.join(task_folder, "hparams.json"), 'w', encoding='utf-8') as f:
        json.dump(hparams, f)

    formal_to_informal_tasks = load_style_tasks(hparams['path_to_formal_input'], target_label='informal')
    informal_to_formal_tasks = load_style_tasks(hparams['path_to_informal_input'], target_label='formal')

    if hparams['max_source_texts'] is not None:
        formal_to_informal_tasks = formal_to_informal_tasks[:hparams['max_source_texts']]
        informal_to_formal_tasks = informal_to_formal_tasks[:hparams['max_source_texts']]

    with open(hparams['path_to_formal_examples'], encoding='utf-8') as f:
        formal_examples = [x.strip() for x in f.readlines()]
    formal_examples = formal_examples[:hparams['path_max_examples']]

    with open(hparams['path_to_informal_examples'], encoding='utf-8') as f:
        informal_examples = [x.strip() for x in f.readlines()]
    informal_examples = informal_examples[:hparams['path_max_examples']]

    # Each direction samples its few-shot demonstrations from the *target* style pool.
    print('Performing formal to informal transfer')
    do_transfer(hparams=hparams, tasks=formal_to_informal_tasks, out_file_name='to_informal.jsonl', examples=informal_examples)

    print('Performing informal to formal transfer')
    do_transfer(hparams=hparams, tasks=informal_to_formal_tasks, out_file_name='to_formal.jsonl', examples=formal_examples)
