import os
import sys
import json
import click

from datetime import datetime

# Make the sibling project directories importable (classifiers,
# tinystyle_authorship, ...). The relative locations are resolved against this
# file's directory instead of the current working directory, so the imports
# below work no matter where the script is launched from.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
for _rel_dir in ('../../emnlp_eval', '../../baselines/tinystyle', '../../inference'):
    sys.path.append(os.path.normpath(os.path.join(_SCRIPT_DIR, _rel_dir)))
del _rel_dir


from classifiers import load_style_model, text_to_style

from tinystyle_authorship import perform_authorship_transfer

def load_style_tasks(path, target_label):
    """
    Load one style-transfer task per line of a text file.

    Args:
        path: Path to a UTF-8 text file with one source sentence per line.
        target_label: Style label every sentence should be transferred to;
            stored under the ``'target_author'`` key.

    Returns:
        A list of task dicts in file order. Each has the keys
        ``'source_text'`` (the line with surrounding whitespace stripped),
        ``'source_paraphrase'`` (a fresh ``['']`` placeholder list),
        ``'target_author'`` (``target_label``) and ``'file_path'`` (``path``).
        Blank lines are kept and yield tasks with an empty ``'source_text'``.
    """
    # Explicit encoding: the platform default is locale-dependent and can
    # mis-decode non-ASCII sentences on some machines.
    with open(path, encoding='utf-8') as f:
        return [
            {
                'source_text': line.strip(),
                # Placeholder paraphrase; a new list is built for every task so
                # tasks never share a mutable object.
                'source_paraphrase': [''],
                'target_author': target_label,
                'file_path': path,
            }
            for line in f
        ]

def build_embedding_mapping(args, labels_to_texts):
    """
    Embed every exemplar text with the style model, grouped by label.

    Args:
        args: Config dict. Only ``args['device']`` is read: the model is moved
            there and it is forwarded to ``text_to_style``.
        labels_to_texts: Mapping from a style label to the list of exemplar
            texts for that label.

    Returns:
        A dict with the same keys as ``labels_to_texts``. Each value is the list
        returned by ``text_to_style`` for that label's texts, with every item
        detached and moved to CPU.
    """
    device = args['device']
    model, tokenizer, _unused = load_style_model()
    model.eval()
    model.to(device)

    def _embed(texts):
        raw = text_to_style(model=model, tokenizer=tokenizer, texts=texts, device=device, model_type='style')
        # Items are assumed to be torch tensors: detaching drops any autograd
        # history and .cpu() keeps the stored embeddings off the accelerator.
        return [emb.detach().cpu() for emb in raw]

    return {label: _embed(texts) for label, texts in labels_to_texts.items()}


def _read_lines(path, limit=None):
    """
    Return the whitespace-stripped lines of a UTF-8 text file.

    Args:
        path: File to read.
        limit: If given, keep only the first ``limit`` lines.
    """
    with open(path, encoding='utf-8') as f:
        lines = [line.strip() for line in f]
    return lines[:limit]


def main(*, run_informal_to_formal=False):
    """
    Run the GYAFC (Entertainment_Music) style-transfer timing experiment.

    Steps:
      1. Create a timestamped run directory under ``args['out_dir']``.
      2. Load the formal and informal test inputs as transfer tasks.
      3. Save the run configuration to ``args.json`` in the run directory.
      4. Embed up to ``args['path_max_examples']`` formal and informal exemplar
         sentences with the style model.
      5. Call ``perform_authorship_transfer`` for formal -> informal
         (``out_file_name='to_informal.jsonl'``).

    Args:
        run_informal_to_formal: Also run the informal -> formal direction
            (``out_file_name='to_formal.jsonl'``). Off by default.
    """
    args = {
        'base_model': 'google/t5-v1_1-large',
        'device': 'cuda',
        'embed_selection': 'mean',
        # Other values tried: 8, 16, 128, and 20000 (effectively "use all").
        'mean_sample': 64,
        'max_length_input': 80,
        'max_length_output': 80,
        'use_actual_input': True,
        'combine_actual_para': False,
        'do_sample': True,
        'checkpoint': '/mnt/swordfish-pool2//reddit_mud/raw_all/emnlp/supervised_ft_models/enc_dec_ft_v2_config_1_fixed/2024-05-28-02.06.56/best_model_google_t5-v1_1-large_1e-05_64.pt',
        'top_p': 0.80,
        'temp': 1.0,
        'do_lower': False,
        'out_dir': 'TIMING/sft_v2_outputs',
        'path_to_formal_examples': '/home///gyafc/data/GYAFC_Corpus 2/Entertainment_Music/tune/formal_exemplar_sample_filtered_0.95.128',
        'path_to_informal_examples': '/home///gyafc/data/GYAFC_Corpus 2/Entertainment_Music/tune/informal_exemplar_sample_filtered_0.95.128',
        'path_to_formal_input': '/home///gyafc/data/GYAFC_Corpus 2/Entertainment_Music/test/formal_for_timing.txt',
        'path_to_informal_input': '/home///gyafc/data/GYAFC_Corpus 2/Entertainment_Music/test/informal',
        # Other values tried: 16, 64.
        'path_max_examples': 128,
    }

    # Each run gets its own timestamped subdirectory so separate runs write
    # their outputs to separate directories.
    cur_date = datetime.now().strftime('%Y-%m-%d-%H_%M_%S')
    args['out_dir'] = os.path.join(args['out_dir'], cur_date)
    os.makedirs(args['out_dir'], exist_ok=True)

    formal_to_informal_tasks = load_style_tasks(args['path_to_formal_input'], target_label='informal')
    informal_to_formal_tasks = load_style_tasks(args['path_to_informal_input'], target_label='formal')

    # Persist the exact configuration next to the outputs for reproducibility.
    with open(os.path.join(args['out_dir'], 'args.json'), 'w') as f:
        json.dump(args, f, indent=2)

    max_examples = args['path_max_examples']
    formal_examples = _read_lines(args['path_to_formal_examples'], max_examples)
    informal_examples = _read_lines(args['path_to_informal_examples'], max_examples)

    target_embed_mapping = build_embedding_mapping(
        args, labels_to_texts={'formal': formal_examples, 'informal': informal_examples}
    )

    print('Performing formal to informal transfer')
    perform_authorship_transfer(args=args, tasks=formal_to_informal_tasks, target_author_embeddings=target_embed_mapping, out_file_name='to_informal.jsonl')

    if run_informal_to_formal:
        print('Performing informal to formal transfer')
        perform_authorship_transfer(args=args, tasks=informal_to_formal_tasks, target_author_embeddings=target_embed_mapping, out_file_name='to_formal.jsonl')


# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == '__main__':
    main()


