import os
import sys
import torch
import time
import json
from tqdm.auto import tqdm
from datetime import datetime
import random

sys.path.append('../inference')

from classifiers import load_formality_model, load_style_model, text_to_style, compute_style_loss
from inference_utils import get_setup, batched_controlled_paraphrase


from transformers import PegasusForConditionalGeneration, PegasusTokenizer

sys.path.append('../gyafc/eval')
from eval_attribute import load_internal_formality_model

if __name__ == '__main__':
    # Timing evaluation for ParaGuide on GYAFC.
    # For every line of INPUT_PATH: optionally paraphrase it with Pegasus, then
    # run guided diffusion (batched_controlled_paraphrase) toward the target
    # formality. Each input gets one JSON line in <task_folder>/to_<TASK>.jsonl,
    # and its elapsed time is printed to stdout.
    #
    # Usage: python <this script> TASK LR
    #   TASK -- target style; this timing eval only supports 'informal'
    #   LR   -- float forwarded to get_setup as hparams['lr']

    random.seed(1234)
    torch.manual_seed(1234)

    NUM_INFERENCES_PER_INPUT = 1

    # Validate the CLI explicitly. `assert` is stripped under `python -O`,
    # and indexing sys.argv directly would fail with an unhelpful IndexError.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} TASK LR  (TASK in formal/informal/style; LR is a float)")

    TASK = sys.argv[1]
    if TASK not in ['formal', 'informal', 'style']:
        raise ValueError(f"Unknown task: {TASK}")

    OPPOSITE_TASK = 'formal' if TASK == 'informal' else 'informal'

    # The timing input file holds formal sentences, so only formal -> informal
    # is run here. A previous (commented-out) configuration read the full test
    # split of the opposite style instead:
    #   '../../gyafc/data/GYAFC_Corpus 2/Entertainment_Music/test/' + OPPOSITE_TASK
    if TASK != 'informal':
        raise ValueError(f"Timing eval only supports TASK='informal', got {TASK!r}")
    INPUT_PATH = '../../gyafc/data/GYAFC_Corpus 2/Entertainment_Music/test/formal_for_timing.txt'
    OUT_DIR = '../../gyafc/eval/TIMING/paraguide'

    hparams = {
        'size': 80,
        'lr': float(sys.argv[2]),  # values previously tried: 5000, 1000, 200
        'total_t': 200,
        'num_drift_steps': 3,
        'use_sqrt_schedule': True,
        'top_p': 0.8,
        'temperature': 3.0,
        'straight_through': False,
        'use_actual': False,
        'model_path': '/mnt/reddit_mud/raw_all/emnlp/models/paraguide_uncond/ssd_cs_dbs80/best_checkpoint/'  # requires downloading model
    }

    (
        args,
        model,
        tokenizer,
        model_embedding_lut,
        embedding_sum_layer,
        timestep_layer,
        ctr_embed_projection,  # unused
    ) = get_setup(**hparams)

    # One timestamped folder per run. exist_ok=False refuses to overwrite a run
    # that started in the same second with the same lr.
    dtime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    task_folder = f"{OUT_DIR}/{dtime}_{hparams['lr']}"
    os.makedirs(task_folder, exist_ok=False)

    with open(os.path.join(task_folder, "hparams.json"), 'w', encoding='utf-8') as f:
        json.dump(hparams, f)

    # `args` presumably is not JSON-serialisable (it carries e.g. the
    # accelerator), so its repr is stored as a JSON string.
    with open(os.path.join(task_folder, "args.txt"), 'w', encoding='utf-8') as f:
        json.dump(str(args), f)

    if TASK in ['formal', 'informal']:
        # Load formality guidance model.
        # NOTE(review): this rebinds `tokenizer`, replacing the one returned by
        # get_setup. The classifier's tokenizer is therefore the one passed to
        # batched_controlled_paraphrase below. Presumably both share a
        # vocabulary; confirm before relying on that.
        ctr_model, tokenizer, label_mapping = load_internal_formality_model()
        args.optimizing_label_index = label_mapping[TASK]
        ctr_embeds = ctr_model.get_input_embeddings().weight.detach()
        args.ctr_model = ctr_model
        args.ctr_embeds = ctr_embeds.to(args.accelerator.device)
        args.tokenizer = tokenizer
        args.ctr_model.to(args.accelerator.device)
        args.ctr_model.eval()

        # Loss to optimise: takes word embeddings (not token ids) and a sequence
        # mask, and returns the negative log-probability of the target label,
        # summed over the batch.
        args.loss_fn = lambda embeds, mask: -torch.nn.functional.log_softmax(
            args.ctr_model(inputs_embeds=embeds, attention_mask=mask).logits, dim=-1
        )[:, args.optimizing_label_index].sum()

    elif TASK in ['style']:
        raise ValueError("Style task not supported")

    else:
        raise ValueError(f"Unknown task: {TASK}")

    # Load the autoregressive (Pegasus) paraphraser.
    paraphraser_tokenizer = PegasusTokenizer.from_pretrained('tuner007/pegasus_paraphrase')
    paraphraser_model = PegasusForConditionalGeneration.from_pretrained('tuner007/pegasus_paraphrase').to(args.accelerator.device)

    with open(INPUT_PATH, 'r', encoding='utf-8') as f:
        input_data = [line.strip() for line in f]

    total_transfers = len(input_data)
    with open(os.path.join(task_folder, f"to_{TASK}.jsonl"), 'w', encoding='utf-8') as out:
        with tqdm(total=total_transfers) as pbar:
            for original_text in input_data:

                # perf_counter is monotonic and high-resolution, unlike
                # time.time(), so clock adjustments cannot skew the timings.
                start = time.perf_counter()

                if args.use_actual:
                    # Skip paraphrasing: feed the source text directly.
                    model_input = original_text
                    paraphrase = ''
                    print(f'Using actual: {original_text}')
                else:
                    # Otherwise, first paraphrase the input (sampled decoding).
                    encoded = paraphraser_tokenizer([original_text], return_tensors='pt').to(args.accelerator.device)
                    generated = paraphraser_model.generate(**encoded, max_length=60, do_sample=True, top_p=0.8, temperature=1.5)
                    paraphrase = paraphraser_tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
                    model_input = paraphrase
                    print(f'Paraphrased: {original_text} -> {model_input}')

                outputs = batched_controlled_paraphrase(
                    [model_input] * NUM_INFERENCES_PER_INPUT,
                    num_samples=1,
                    args=args,
                    model=model,
                    tokenizer=tokenizer,
                    model_embedding_lut=model_embedding_lut,
                    embedding_sum_layer=embedding_sum_layer,
                    timestep_layer=timestep_layer,
                    ctr_embed_projection=None,
                    batch_ctrl_embeds=None,
                    logging=False,
                )

                # Earlier versions of this script wrote the keys
                # `original_text`/`decoded` instead of `source_text`/`output`.
                result = dict(
                    input_label=INPUT_PATH,
                    paraphrase=paraphrase,
                    source_text=original_text,
                    target_label=TASK,
                    output=outputs)

                print(f'{original_text} -> {paraphrase} ->' + "\n\t->" + "\n\t->".join(outputs[0]))
                out.write(json.dumps(result) + '\n')
                print('Elapsed:', time.perf_counter() - start)
                pbar.update(1)
