''' Preprocess ParaNMT sample data: add style embeddings to a dataset saved on disk. '''

import os
import sys
import json
import random
import numpy as np

# import torch
import click
from datasets import Dataset
from datasets import load_dataset
from datasets import disable_caching
from datetime import datetime
from tqdm import tqdm
import torch

from transformers import RobertaTokenizer

from datasets import load_from_disk

sys.path.append('../../inference')
from classifiers import text_to_style, load_style_model
from luar import load_uar_hf_model, get_uar_embeddings





def add_embeddings(*, example, style_tokenizer, luar_tokenizer, style_model, luar_model,
                   device='cuda', max_length=512):
    '''Add style embeddings to a batch of examples.

    Intended for ``Dataset.map(..., batched=True)``: ``example['text']`` is
    passed as a list of texts to ``text_to_style`` and the resulting
    embeddings are stored in ``example['style_embedding']`` as a list of
    numpy arrays, one per returned embedding.

    Args:
        example: batch dict with a ``'text'`` column; updated in place and returned.
        style_tokenizer: tokenizer forwarded to ``text_to_style``.
        luar_tokenizer: currently ignored (LUAR embeddings are disabled).
        style_model: style model; when None, no style embeddings are added.
        luar_model: currently ignored (LUAR embeddings are disabled).
        device: device forwarded to ``text_to_style`` (default ``'cuda'``).
        max_length: max token length forwarded to ``text_to_style`` (default 512).

    Returns:
        The same ``example`` dict.

    Raises:
        KeyError: if ``example`` has no ``'text'`` entry.
    '''
    texts = example['text']

    if style_model is None:
        return example

    # Pure inference: disable autograd so no graph / saved activations are kept
    # for the batch. Harmless if text_to_style already does this internally.
    with torch.no_grad():
        style_embedding = text_to_style(model=style_model, tokenizer=style_tokenizer,
                                        texts=texts, device=device, model_type='style',
                                        max_length=max_length)
    example['style_embedding'] = [x.detach().cpu().numpy() for x in style_embedding]

    # NOTE: LUAR embeddings are disabled; luar_model / luar_tokenizer are
    # accepted only to keep the call signature stable and are not used.
    return example


# Command-line entry point; click supplies --dataset_path.
@click.command()
@click.option('--dataset_path', help='path_to_dataset', required=True)
def main(dataset_path):
    '''Attach style embeddings to every row of a dataset saved on disk.

    Steps:
    - Load the dataset at ``dataset_path`` with ``load_from_disk``.
    - Run ``add_embeddings`` over it in batches of 32, with the style model on
      CUDA and LUAR disabled.
    - Save the result to ``<normalized dataset_path>_with_style_embeds``.
    '''
    # Turn off the datasets library's caching of transform results.
    disable_caching()

    # normpath drops a trailing separator so the suffix attaches to the directory name.
    destination = os.path.normpath(dataset_path) + '_with_style_embeds'

    source_dataset = load_from_disk(dataset_path)

    # Style model plus its tokenizer; the third value from load_style_model is not needed here.
    style_model, style_tokenizer, _ = load_style_model()
    style_model.to('cuda')

    # LUAR is switched off: load_uar_hf_model() is intentionally not called.
    luar_model = None
    luar_tokenizer = None

    def embed_batch(batch):
        return add_embeddings(
            example=batch,
            style_tokenizer=style_tokenizer,
            luar_tokenizer=luar_tokenizer,
            style_model=style_model,
            luar_model=luar_model,
        )

    embedded = source_dataset.map(embed_batch, batched=True, batch_size=32)
    embedded.save_to_disk(destination)
    


    
    


if __name__ == '__main__':
    # Invoke the click command; its options are parsed from the command line.
    main()
