import os
import sys
import torch
import time
import json
from tqdm.auto import tqdm
from datetime import datetime
import random

sys.path.append('../inference')

from classifiers import load_formality_model, load_style_model, text_to_style, compute_style_loss
from inference_utils import get_setup, batched_controlled_paraphrase

from transformers import PegasusForConditionalGeneration, PegasusTokenizer

# sys.path.append('../gyafc/eval')
# from eval_attribute import load_internal_formality_model



def generate_tasks(data, max_source_texts=None):
    """Build one transfer task per (source text, target author) pair.

    Args:
        data: Mapping with a ``'source_authors'`` and a ``'target_authors'``
            entry, each mapping an author identifier to that author's texts.
        max_source_texts: When not ``None``, only the first
            ``max_source_texts`` texts of each source author are used.

    Returns:
        A list of dicts with the keys ``source_author``, ``source_text``,
        ``source_author_texts``, ``target_author_texts`` and
        ``target_author``. Tasks are ordered by source author, then target
        author, then source text.
    """
    tasks = []
    for src_name, src_texts in data['source_authors'].items():
        # The (possibly truncated) list is both iterated and stored on every
        # task generated for this source author.
        kept = src_texts if max_source_texts is None else src_texts[:max_source_texts]
        tasks.extend(
            {
                'source_author': src_name,
                'source_text': text,
                'source_author_texts': kept,
                'target_author_texts': tgt_texts,
                'target_author': tgt_name,
            }
            for tgt_name, tgt_texts in data['target_authors'].items()
            for text in kept
        )

    return tasks
       
def load_raw_authorship_data(data_path):
    """Load the authorship dataset stored as JSON at ``data_path``.

    The file is decoded as UTF-8 explicitly, so the result does not depend
    on the platform's default locale encoding.

    Args:
        data_path: Path to the JSON file.

    Returns:
        The parsed JSON content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        return json.load(f)




if __name__ == '__main__':
    # Script entry point.
    #
    # Usage: python <this script> <lr>
    #
    # For every task produced by generate_tasks() from INPUT_PATH, the source
    # text is (unless args.use_actual is set) paraphrased with Pegasus and then
    # passed to batched_controlled_paraphrase() with a style loss computed
    # against the target author's texts. One JSON line per task is written
    # to <OUT_DIR>/<timestamp>_<lr>/to_<TASK>.jsonl.

    # The learning rate is the only command-line argument; exit with a usage
    # message rather than an IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <lr>")

    random.seed(1234)
    torch.manual_seed(1234)

    NUM_INFERENCES_PER_INPUT = 1

    TASK = 'style'

    # Validate up front, before any model is loaded or any output directory
    # is created. Explicit raises are used because `assert` statements are
    # removed when Python runs with -O.
    if TASK not in ('formal', 'informal', 'style'):
        raise ValueError(f"Unknown task: {TASK}")
    if TASK in ('formal', 'informal'):
        raise ValueError("Formality task not supported")

    INPUT_PATH = '../../emnlp_eval/styll_and_metrics/styll_data_and_metrics/dataset/random.json'

    OUT_DIR = '../../emnlp_eval/results/random/paraguide/paraguide_random'

    # Forwarded verbatim to get_setup() as keyword arguments.
    hparams = {
        'size': 80,
        'lr': float(sys.argv[1]),  # 5000, 1000, 200
        'total_t': 200,
        'num_drift_steps': 3,
        'use_sqrt_schedule': True,
        'top_p': 0.8,
        'temperature': 3.0,
        'straight_through': False,
        'use_actual': False,
        'model_path': '/mnt/reddit_mud/raw_all/emnlp/models/paraguide_uncond/ssd_cs_dbs80/best_checkpoint/'
    }

    (
        args,
        model,
        tokenizer,
        model_embedding_lut,
        embedding_sum_layer,
        timestep_layer,
        ctr_embed_projection,  # unused
    ) = get_setup(**hparams)

    # Each run gets its own timestamped folder; exist_ok=False makes a
    # name collision an error instead of silently reusing a folder.
    dtime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    task_folder = f"{OUT_DIR}/{dtime}_{hparams['lr']}"
    os.makedirs(task_folder, exist_ok=False)

    # Added only after get_setup(**hparams) so 'data' is recorded in
    # hparams.json without being passed to get_setup.
    hparams['data'] = INPUT_PATH

    with open(os.path.join(task_folder, "hparams.json"), 'w') as f:
        json.dump(hparams, f)

    with open(os.path.join(task_folder, "args.txt"), 'w') as f:
        json.dump(str(args), f)

    # Style task setup: attach the style model, its tokenizer and its
    # embeddings to args and move them to the accelerator device.
    # NOTE(review): this rebinds `tokenizer`, so the tokenizer returned by
    # get_setup() is replaced by the one from load_style_model() in the
    # batched_controlled_paraphrase() call below - confirm this is intended.
    args.optimizing_label_index = None
    ctr_model, tokenizer, ctr_embeds = load_style_model()
    args.ctr_model = ctr_model
    args.tokenizer = tokenizer
    args.ctr_embeds = ctr_embeds.to(args.accelerator.device)
    args.ctr_model.to(args.accelerator.device)
    args.ctr_model.eval()

    # load AR paraphraser
    paraphraser_tokenizer = PegasusTokenizer.from_pretrained('tuner007/pegasus_paraphrase')
    paraphraser_model = PegasusForConditionalGeneration.from_pretrained('tuner007/pegasus_paraphrase').to(args.accelerator.device)

    def style_loss(embeds, mask):
        # Sum of compute_style_loss over the (embedding, mask) pairs, each
        # evaluated as a batch of one. args.target_embeds is looked up when
        # the loss is called, so one definition serves every task.
        # To do: move set up batching inside of loss function
        loss = 0
        for e, m in zip(embeds, mask):
            loss += compute_style_loss(
                e.unsqueeze(0),
                model=args.ctr_model,
                target_embeds=args.target_embeds,
                attention_mask=m.float().unsqueeze(0),
                model_type='style',
            )

        return loss

    data = load_raw_authorship_data(hparams['data'])
    tasks = generate_tasks(data, max_source_texts=None)

    total_transfers = len(tasks)
    with open(os.path.join(task_folder, f"to_{TASK}.jsonl"), 'w') as out:
        with tqdm(total=total_transfers) as pbar:
            for task in tasks:

                target_style_examples = task['target_author_texts']

                original_text = task['source_text']

                # Style embeddings of the target author's texts; read by
                # style_loss through args.target_embeds.
                args.target_embeds = text_to_style(
                    model=args.ctr_model,
                    tokenizer=args.tokenizer,
                    texts=target_style_examples,
                    device=args.accelerator.device,
                    model_type='style',
                )

                args.loss_fn = style_loss

                # Skip paraphrasing if using actual text
                if args.use_actual:
                    model_input = original_text
                    paraphrase = ''
                    print(f'Using actual: {original_text}')

                # Otherwise, first paraphrase the input
                else:
                    # Both the tokenized source and the sampled paraphrase
                    # are limited to 60 tokens.
                    encoded = paraphraser_tokenizer([original_text], return_tensors='pt', padding=True, truncation=True, max_length=60).to(args.accelerator.device)
                    paraphrase = paraphraser_model.generate(**encoded, max_length=60, do_sample=True, top_p=0.8, temperature=1.5)
                    paraphrase = paraphraser_tokenizer.batch_decode(paraphrase, skip_special_tokens=True)[0]
                    model_input = paraphrase
                    print(f'Paraphrased: {original_text} -> {model_input}')

                outputs = batched_controlled_paraphrase(
                    [model_input] * NUM_INFERENCES_PER_INPUT,
                    num_samples=1,
                    args=args,
                    model=model,
                    tokenizer=tokenizer,
                    model_embedding_lut=model_embedding_lut,
                    embedding_sum_layer=embedding_sum_layer,
                    timestep_layer=timestep_layer,
                    ctr_embed_projection=None,
                    batch_ctrl_embeds=None,
                    logging=False,
                )

                # The task dict is extended in place and written as one
                # JSON line, so each record carries its inputs and output.
                task['output'] = outputs

                out.write(json.dumps(task) + '\n')
                # Flush per record so finished results survive an
                # interruption of the run.
                out.flush()

                print(f'{original_text} -> {paraphrase} ->' + "\n\t->" + "\n\t->".join(outputs[0]))
                pbar.update(1)
