import click
import os

# Cache location exported before the model libraries below are imported.
# NOTE(review): Hugging Face libraries document HF_HOME / HF_HUB_CACHE (and the
# legacy TRANSFORMERS_CACHE) for relocating their cache; 'HF_CACHE' is
# presumably read by project code -- confirm it has the intended effect.
os.environ['HF_CACHE'] = '/mnt/swordfish-pool2//hf_cache'

import json

import sys
import torch
import random

from datetime import datetime

from tqdm import tqdm

import pickle

from transformers import PegasusForConditionalGeneration, PegasusTokenizer, AutoTokenizer


sys.path.append('../inference')
# sys.path.append('../baselines/t5_cond')

from classifiers import load_style_model, text_to_style
# from luar import load_uar_hf_model, get_uar_embeddings

sys.path.append('../baselines/tinystyle')

from tinystyle_train import TinyStyle, smart_tokenizer_and_embedding_resize, MODEL_TO_MODEL_TYPE

import hashlib


def get_paraphrases(*, model, tokenizer, input_texts, num_return_sequences=1, top_p=0.80, temp=1.5, max_length_input=60, max_length_output=60, device='cuda'):
    """Sample paraphrases for a batch of texts and group them per input.

    Args:
        model: Generation model exposing ``generate``.
        tokenizer: Tokenizer callable that also provides ``batch_decode``.
        input_texts: Texts to paraphrase.
        num_return_sequences: Number of generations requested per input text.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.
        temp: Sampling temperature forwarded to ``generate``.
        max_length_input: ``max_length`` used (with truncation) when tokenizing.
        max_length_output: ``max_length`` forwarded to ``generate``.
        device: Device the tokenized batch is moved to.

    Returns:
        A list with one entry per input text; entry ``i`` is the slice of
        decoded generations at positions
        ``[i * num_return_sequences, (i + 1) * num_return_sequences)``.
    """
    encoded = tokenizer(
        input_texts,
        truncation=True,
        padding='longest',
        max_length=max_length_input,
        return_tensors="pt",
    )
    encoded = encoded.to(device)

    sampling_kwargs = dict(
        max_length=max_length_output,
        do_sample=True,
        top_p=top_p,
        num_return_sequences=num_return_sequences,
        temperature=temp,
    )
    generated = model.generate(**encoded, **sampling_kwargs)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)

    # The slicing assumes the generations of one input are returned
    # consecutively, in input order.
    return [
        decoded[idx * num_return_sequences:(idx + 1) * num_return_sequences]
        for idx, _ in enumerate(input_texts)
    ]

def generate_tasks(*, data, source_author_paraphrases=None):
    """Build one transfer task per (source author, target author, source text).

    Args:
        data: Mapping with ``'source_authors'`` and ``'target_authors'`` keys,
            each mapping an author to that author's list of texts.
        source_author_paraphrases: Optional mapping from source author to a
            list of paraphrase entries aligned with that author's texts. When
            falsy, every task gets ``None`` as its ``'source_paraphrase'``.

    Returns:
        A list of task dicts, ordered by source author, then target author,
        then source text.

    Raises:
        AssertionError: If an author's paraphrase list and text list differ
            in length.
    """
    tasks = []
    for source_author, source_texts in data['source_authors'].items():
        if source_author_paraphrases:
            paraphrases = source_author_paraphrases[source_author]
        else:
            paraphrases = [None] * len(source_texts)

        assert len(source_texts) == len(paraphrases)

        text_pairs = list(zip(source_texts, paraphrases))
        tasks.extend(
            {
                'source_author': source_author,
                'source_paraphrase': paraphrase,
                'source_text': text,
                'target_author': target_author,
                'source_author_texts': source_texts,
                'target_author_texts': target_texts,
            }
            for target_author, target_texts in data['target_authors'].items()
            for text, paraphrase in text_pairs
        )

    return tasks
       
def load_raw_authorship_data(data_path):
    """Read the JSON file at ``data_path`` and return the parsed object."""
    with open(data_path, 'r') as handle:
        return json.load(handle)

def load_data(data_dir, expected_args=None):
    """Load prepared task data and target-author embeddings from ``data_dir``.

    Reads three files from the directory: ``args.json``, ``task_data.jsonl``
    and ``target_author_embeddings.pkl``.

    Args:
        data_dir: Directory holding the prepared files.
        expected_args: Optional arguments to compare against the stored
            ``args.json``. When ``None`` (the default) the stored arguments
            are only parsed, not compared.

    Returns:
        A ``(task_data, target_author_embeddings)`` tuple: the list of task
        objects parsed from the JSONL file and the unpickled embeddings object.

    Raises:
        ValueError: If ``expected_args`` is given and differs from the stored
            arguments.
    """
    # Still parsed when no comparison is requested, so a missing or corrupt
    # args.json is reported before the larger files are read.
    with open(os.path.join(data_dir, 'args.json'), 'r') as f:
        stored_args = json.load(f)
    if expected_args is not None and stored_args != expected_args:
        raise ValueError(
            f"Arguments stored in {data_dir} do not match the expected arguments"
        )

    with open(os.path.join(data_dir, 'task_data.jsonl'), 'r') as f:
        # Stream line by line; blank lines (e.g. a trailing newline) are skipped.
        task_data = [json.loads(line) for line in f if line.strip()]

    # NOTE: pickle can execute arbitrary code while loading; only point this at
    # directories produced by this pipeline.
    with open(os.path.join(data_dir, 'target_author_embeddings.pkl'), 'rb') as f:
        target_author_embeddings = pickle.load(f)

    return task_data, target_author_embeddings

def make_data(args, data_dir):
    """Build the prepared dataset and persist it under ``data_dir``.

    Steps: load the raw authorship JSON, sample paraphrases for every source
    author's texts with the Pegasus paraphraser, compute style embeddings for
    every target author's texts, expand everything into tasks, and write
    ``args.json``, ``task_data.jsonl`` and ``target_author_embeddings.pkl``.

    Args:
        args: Data-preparation settings. Keys read here: ``data_path``,
            ``paraphraser_name``, ``paraphraser_args`` (forwarded to
            ``get_paraphrases``; its ``device`` also places the paraphraser),
            ``paraphraser_batch_size`` and ``device`` (style model).
        data_dir: Output directory; created if missing.

    Returns:
        A ``(task_data, target_author_embeddings)`` tuple, the same objects
        that were written to disk.
    """
    raw = load_raw_authorship_data(args['data_path'])

    paraphraser_tokenizer = PegasusTokenizer.from_pretrained(args['paraphraser_name'])
    paraphraser_model = PegasusForConditionalGeneration.from_pretrained(args['paraphraser_name'])
    paraphraser_model = paraphraser_model.to(args['paraphraser_args']['device'])

    print("Generating source author paraphrases")
    source_author_paraphrases = {}
    for author in tqdm(sorted(raw['source_authors'])):
        texts = raw['source_authors'][author]
        step = args['paraphraser_batch_size']

        author_paraphrases = []
        for start in range(0, len(texts), step):
            chunk = texts[start:start + args['paraphraser_batch_size']]
            author_paraphrases.extend(
                get_paraphrases(
                    **args['paraphraser_args'],
                    model=paraphraser_model,
                    tokenizer=paraphraser_tokenizer,
                    input_texts=chunk,
                )
            )
        # One paraphrase group per source text is required downstream.
        assert len(author_paraphrases) == len(texts)
        source_author_paraphrases[author] = author_paraphrases

    print("Generating target author embeddings")
    style_model, style_tokenizer, _ = load_style_model()
    style_model.eval()
    style_model.to(args['device'])
    target_author_embeddings = {}
    for author in tqdm(sorted(raw['target_authors'])):
        style_vectors = text_to_style(
            model=style_model,
            tokenizer=style_tokenizer,
            texts=raw['target_authors'][author],
            device=args['device'],
            model_type='style',
        )
        # Keep detached CPU copies so the embeddings can be pickled.
        target_author_embeddings[author] = [vec.detach().cpu() for vec in style_vectors]

    task_data = generate_tasks(data=raw, source_author_paraphrases=source_author_paraphrases)

    os.makedirs(data_dir, exist_ok=True)

    with open(os.path.join(data_dir, 'args.json'), 'w') as args_file:
        json.dump(args, args_file)

    with open(os.path.join(data_dir, 'task_data.jsonl'), 'w') as tasks_file:
        tasks_file.writelines(json.dumps(task) + '\n' for task in task_data)

    with open(os.path.join(data_dir, 'target_author_embeddings.pkl'), 'wb') as embeds_file:
        pickle.dump(target_author_embeddings, embeds_file)

    return task_data, target_author_embeddings


def _mean_of_random_subset(target_embeds, num_selected):
    """Average ``num_selected`` randomly drawn embeddings into a single row."""
    num_selected = min(num_selected, len(target_embeds))
    subset = random.sample(target_embeds, num_selected)
    return torch.stack(subset).mean(dim=0).unsqueeze(0)


def _select_ctrl_embeds(*, args, target_embeds, num_inputs):
    """Build the control-embedding batch (one row per input) for one task.

    ``args['embed_selection']`` chooses the strategy:

    * ``'random'``: one randomly chosen target embedding, repeated.
    * ``'first'``: the first target embedding, repeated.
    * ``'mean'``: governed by ``args['mean_sample']``. A list means each row
      averages a random subset whose size is drawn from the list; a positive
      int different from ``len(target_embeds)`` means each row averages a
      random subset of that size; any other int means the mean of all
      embeddings, repeated. Subset sizes are capped at ``len(target_embeds)``.

    Raises:
        TypeError: If ``mean_sample`` is neither an int nor a list.
        ValueError: If the selection method is unknown.
    """
    selection = args['embed_selection']

    if selection == 'random':
        return random.choice(target_embeds).unsqueeze(0).repeat(num_inputs, 1)
    if selection == 'first':
        return target_embeds[0].unsqueeze(0).repeat(num_inputs, 1)
    if selection != 'mean':
        raise ValueError(f"Unknown embed selection method: {args['embed_selection']}")

    mean_sample = args['mean_sample']
    if not isinstance(mean_sample, (int, list)):
        raise TypeError(
            f"mean_sample must be an int or a list, got {type(mean_sample).__name__}"
        )

    if isinstance(mean_sample, list):
        # The subset size is drawn before the subset itself for every row, so
        # the sequence of `random` calls stays stable for a fixed seed.
        rows = [
            _mean_of_random_subset(target_embeds, random.choice(mean_sample))
            for _ in range(num_inputs)
        ]
        return torch.cat(rows, dim=0)

    if mean_sample > 0 and mean_sample != len(target_embeds):
        rows = [
            _mean_of_random_subset(target_embeds, mean_sample)
            for _ in range(num_inputs)
        ]
        return torch.cat(rows, dim=0)

    return torch.stack(target_embeds).mean(dim=0).unsqueeze(0).repeat(num_inputs, 1)


def perform_authorship_transfer(*, args, tasks, target_author_embeddings, out_file_name='results.jsonl'):
    """Run TinyStyle generation for every task and write the results as JSONL.

    Each task dict gets an ``'output'`` key holding the decoded generations
    and is written as one JSON line to ``<args['out_dir']>/<out_file_name>``.
    The task dicts passed in are modified in place.

    Args:
        args: Transfer settings. Keys read here:
            ``base_model``, ``device``, ``checkpoint``, ``out_dir``,
            ``use_actual_input``, ``combine_actual_para``,
            ``do_lower`` (optional, defaults to ``False``),
            ``embed_selection``, ``mean_sample`` (only for ``'mean'``),
            ``max_length_input``, ``max_length_output``, ``do_sample``,
            ``top_p`` and ``temp``.
        tasks: Iterable of task dicts with at least ``'target_author'``,
            ``'source_paraphrase'`` (a list of texts) and ``'source_text'``.
        target_author_embeddings: Mapping from target author to a list of
            embedding tensors (presumably 1-D style vectors -- TODO confirm).
        out_file_name: Name of the JSONL file created inside ``out_dir``.

    Raises:
        ValueError: For an unknown ``embed_selection``, or when the base
            model's type is not ``'t5'`` (llama generation is not supported).
        TypeError: If ``mean_sample`` is neither an int nor a list.
    """
    print("Loading TinyStyle model")
    device = args['device']
    model = TinyStyle(base_model=args['base_model'], use_style=True, ctrl_embed_dim=768)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args['base_model'])

    # Tokenizers without a pad token get one added, with the embeddings resized.
    if tokenizer._pad_token is None:
        special_tokens_dict = dict(pad_token='[PAD]')
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=tokenizer,
            model=model.model
        )

    # Looked up once; it does not change between tasks.
    model_type = MODEL_TO_MODEL_TYPE[args['base_model']]
    if model_type == 'llama':
        tokenizer.padding_side = 'left'

    model.to(device)

    print(f"Loading TinyStyle model state dict from {args['checkpoint']}")
    # Overlay the checkpoint on the freshly built state so parameters missing
    # from the checkpoint keep their current values.
    current_state = model.state_dict()
    saved_state_dict = torch.load(args['checkpoint'], map_location=device)
    current_state.update(saved_state_dict)
    model.load_state_dict(current_state)

    out_dir = args['out_dir']
    os.makedirs(out_dir, exist_ok=True)
    out_file = os.path.join(out_dir, out_file_name)

    with open(out_file, 'w') as f:
        for task in tqdm(tasks):
            target_embeds = target_author_embeddings[task['target_author']]

            input_text = task['source_paraphrase']
            num_inputs = len(input_text)

            if args['use_actual_input'] and args['combine_actual_para']:
                # Paraphrases followed by one copy of the source text per paraphrase.
                input_text = input_text + [task['source_text']] * num_inputs
                num_inputs = len(input_text)
            elif args['use_actual_input']:
                input_text = [task['source_text']] * num_inputs

            if args.get('do_lower', False):
                input_text = [text.lower() for text in input_text]

            batch_ctrl_embeds = _select_ctrl_embeds(
                args=args,
                target_embeds=target_embeds,
                num_inputs=num_inputs,
            ).to(device)

            if model_type == 't5':
                encoded_input = tokenizer(input_text, return_tensors='pt', padding=True, max_length=args['max_length_input'], truncation=True).to(device)

                # Inference only: do not track gradients.
                with torch.no_grad():
                    outputs = model.generate(
                        **encoded_input,
                        style=batch_ctrl_embeds,
                        max_length=args['max_length_output'],
                        do_sample=args['do_sample'],
                        top_p=args['top_p'],
                        temperature=args['temp'],
                    )
            elif model_type == 'llama':
                # Batched causal-LM generation is not implemented here; it would
                # presumably rely on the left padding configured above.
                raise ValueError("LLAMA model not supported")
            else:
                raise ValueError(f"Unknown model type: {model_type}")

            task['output'] = tokenizer.batch_decode(outputs, skip_special_tokens=True)

            f.write(json.dumps(task) + '\n')


def main():
    """Run the single-author style-transfer experiment end to end.

    Prepared data (paraphrases, tasks, target-author embeddings) is cached
    under ``<prepared_data>/<sha256 of the JSON-encoded data args>``; it is
    loaded when that directory exists and built otherwise. Results go to a new
    timestamped sub-directory of ``transfer_args['out_dir']``, together with
    an ``args.json`` snapshot of the configuration.
    """
    # Only Python's `random` module is seeded here; torch's RNG is not.
    random.seed(42)

    # Other checkpoints previously tried for 'checkpoint' (all under
    # /mnt/swordfish-pool2//reddit_mud/raw_all/emnlp/):
    #   supervised_ft_models/enc_dec_ft_v2_config_1_fixed/2024-05-28-02.06.56/best_model_google_t5-v1_1-large_1e-05_64.pt
    #   supervised_ft_models/enc_dec_ft_v2_config_1/2024-05-21-11.41.35/best_model_google_t5-v1_1-large_1e-05_64.pt
    #   supervised_ft_models/enc_dec_ft_v1_config_1_lower/2024-05-21-01.15.22/best_model_google_t5-v1_1-large_1e-05_64.pt
    #   supervised_ft_models/enc_dec_ft_v1_config_1/2024-05-16-17.58.41/best_model_google_t5-v1_1-large_1e-05_64.pt
    #   supervised_ft_models/enc_dec_ft_v1_config_6_skip_optim/2024-05-17-01.35.01/best_model_google_t5-v1_1-large_2.5e-06_64.pt
    #   models/causal_fixed/2024-05-10-01.40.22/best_model_TinyLlama_TinyLlama-1.1B-intermediate-step-1431k-3T_1e-05_64.pt
    #   models/enc_dec_cond_opt_resume/2024-05-10-13.11.18/best_model_google_t5-v1_1-large_1e-05_64.pt
    args = {
        'transfer_args': {
            'base_model': 'google/t5-v1_1-large',  # alternative: 'TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T'
            'device': 'cuda',
            'embed_selection': 'mean',  # alternatives: 'first', 'random'
            'mean_sample': 8,  # previously also 20
            'max_length_input': 80,
            'max_length_output': 80,
            'use_actual_input': False,
            'combine_actual_para': False,
            'do_sample': True,
            'checkpoint': '/mnt/swordfish-pool2//reddit_mud/raw_all/emnlp/models/enc_dec_cond/2024-05-05-03.31.44_backup/best_model_google_t5-v1_1-large_1e-05_64.pt',
            'do_lower': False,
            'top_p': 0.80,
            'temp': 1.0,
            'out_dir': 'results/single/tinystyle_230k_unsup',  # previously also 'tinystyle_230k_unsup', 'sft_v1_outputs'
        },
        # NOTE: the cache directory name is a hash of this dict's JSON form, so
        # changing any value (or the key order) triggers a fresh data build.
        'data_args': {
            'data_path': "styll_and_metrics/styll_data_and_metrics/dataset/single.json",
            'paraphraser_name': 'tuner007/pegasus_paraphrase',
            'device': 'cuda',
            'paraphraser_batch_size': 16,
            'paraphraser_args': {
                'num_return_sequences': 5,
                'top_p': 0.80,
                'temp': 1.5,
                'max_length_input': 60,
                'max_length_output': 60,
                'device': 'cuda'
            },
            'prepared_data': 'single_processed_5_para',
        }
    }

    data_args = args['data_args']
    os.makedirs(data_args['prepared_data'], exist_ok=True)

    hashed_args = hashlib.sha256(json.dumps(data_args).encode()).hexdigest()
    print(f"Hashed args: {hashed_args}")

    data_dir = os.path.join(data_args['prepared_data'], hashed_args)

    if os.path.exists(data_dir):
        print(f"Output directory already exists: {data_dir}, loading data")
        task_data, target_author_embeddings = load_data(data_dir)
    else:
        print(f"Output directory does not exist: {data_dir}, making data")
        task_data, target_author_embeddings = make_data(data_args, data_dir)

    args['input_data_path'] = data_dir
    transfer_args = args['transfer_args']

    # Each run writes into its own timestamped directory; exist_ok=False makes
    # a same-second collision fail instead of overwriting earlier results.
    base_out_dir = transfer_args['out_dir']
    os.makedirs(base_out_dir, exist_ok=True)
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    write_directory = os.path.join(base_out_dir, current_time)
    os.makedirs(write_directory, exist_ok=False)

    # The snapshot is written before 'out_dir' is redirected, so it records
    # the base output directory.
    with open(os.path.join(write_directory, 'args.json'), 'w') as f:
        json.dump(args, f)

    args['transfer_args']['out_dir'] = write_directory

    perform_authorship_transfer(args=transfer_args, tasks=task_data, target_author_embeddings=target_author_embeddings)

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

