import os
import sys
import json
import click
import numpy as np
import torch

from tqdm import tqdm



import sys


from datetime import datetime

sys.path.append('../../emnlp_eval')
sys.path.append('../../baselines/tinystyle')
sys.path.append('../../inference')


from tinystyle_train import TinyStyle, smart_tokenizer_and_embedding_resize, MODEL_TO_MODEL_TYPE

from transformers import PegasusForConditionalGeneration, PegasusTokenizer, AutoTokenizer


from classifiers import load_style_model, text_to_style

style_model = None
style_tokenizer = None

# from tinystyle_authorship import perform_authorship_transfer



from tinystyle_generate_formal_informal import build_embedding_mapping

def _select_ctrl_embeds(args, target_embeds, num_inputs):
    """Build one target-style control embedding per input text.

    Args:
        args: Run configuration; reads ``embed_selection`` and, for
            ``'mean'``, ``mean_sample``.
        target_embeds: List of style embeddings for the target label.
        num_inputs: Number of rows to produce (one per input text).

    Returns:
        A tensor whose first dimension is ``num_inputs``.

    Raises:
        ValueError: If ``embed_selection`` is not 'random', 'first' or 'mean'.
    """
    # Function-scope import: the file never imports ``random`` at module level,
    # so the sampling branches below previously failed with a NameError.
    import random

    def sampled_mean(num_selected):
        # Average a random subset, never asking for more embeddings than exist.
        num_selected = min(num_selected, len(target_embeds))
        return torch.stack(random.sample(target_embeds, num_selected)).mean(dim=0)

    selection = args['embed_selection']

    if selection == 'random':
        return random.choice(target_embeds).unsqueeze(0).repeat(num_inputs, 1)

    if selection == 'first':
        return target_embeds[0].unsqueeze(0).repeat(num_inputs, 1)

    if selection == 'mean':
        mean_sample = args['mean_sample']
        assert isinstance(mean_sample, (int, list))

        if isinstance(mean_sample, list):
            # A fresh subset size is drawn from the list for every input.
            return torch.stack([sampled_mean(random.choice(mean_sample)) for _ in range(num_inputs)])

        if mean_sample > 0 and mean_sample != len(target_embeds):
            # A different random subset of fixed size for every input.
            return torch.stack([sampled_mean(mean_sample) for _ in range(num_inputs)])

        # Non-positive or "exactly all of them": use the mean over every embedding.
        return torch.stack(target_embeds).mean(dim=0).unsqueeze(0).repeat(num_inputs, 1)

    raise ValueError(f"Unknown embed selection method: {selection}")


def perform_authorship_transfer(*, args, tasks, target_author_embeddings, out_file_name='results.jsonl'):
    """Generate style-transferred text for every task and write results as JSONL.

    For each task the target-style embedding (chosen per ``embed_selection``)
    is interpolated with the style embedding of the input text itself:
    ``interp_percent == 0`` keeps the source embedding, ``1`` uses the target
    embedding. The result is passed to ``model.generate`` as ``style``.

    Args:
        args: Run configuration dict. Keys read here:
            base_model, device, checkpoint, out_dir,
            embed_selection, mean_sample (only for 'mean'),
            max_length_input, max_length_output,
            use_actual_input, combine_actual_para, do_lower (optional),
            do_sample, top_p, temp, interp_percent.
        tasks: Iterable of dicts with 'target_author', 'source_paraphrase'
            (list of str) and 'source_text'. Each dict is mutated: the
            generated texts are stored under 'output'.
        target_author_embeddings: Mapping from target label to a list of
            style embeddings.
        out_file_name: File name, inside ``args['out_dir']``, of the JSONL
            output (one task per line).

    Raises:
        ValueError: For an unknown ``embed_selection`` or a base model whose
            type is not 't5' (llama is explicitly unsupported).
    """
    print("Loading TinyStyle model")
    device = args['device']
    model_type = MODEL_TO_MODEL_TYPE[args['base_model']]

    model = TinyStyle(base_model=args['base_model'], use_style=True, ctrl_embed_dim=768)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args['base_model'])

    # Use the public ``pad_token`` attribute rather than the private ``_pad_token``.
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token='[PAD]'),
            tokenizer=tokenizer,
            model=model.model
        )

    if model_type == 'llama':
        tokenizer.padding_side = 'left'

    model.to(device)

    print(f"Loading TinyStyle model state dict from {args['checkpoint']}")
    # Overlay the checkpoint on the freshly initialised weights so that keys
    # absent from the checkpoint keep their current values.
    current_state = model.state_dict()
    saved_state_dict = torch.load(args['checkpoint'], map_location=device)
    current_state.update(saved_state_dict)
    model.load_state_dict(current_state)

    out_dir = args['out_dir']
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    out_file = os.path.join(out_dir, out_file_name)

    with open(out_file, 'w') as f:
        for task in tqdm(tasks):
            target_embeds = target_author_embeddings[task['target_author']]

            input_text = task['source_paraphrase']
            num_inputs = len(input_text)

            if args['use_actual_input'] and args['combine_actual_para']:
                # Paraphrases followed by the original text repeated once per paraphrase.
                input_text = input_text + [task['source_text']] * num_inputs
                num_inputs = len(input_text)
            elif args['use_actual_input']:
                # Replace every paraphrase with the original text.
                input_text = [task['source_text']] * num_inputs

            if args.get('do_lower', False):
                input_text = [text.lower() for text in input_text]

            batch_ctrl_embeds = _select_ctrl_embeds(args, target_embeds, num_inputs)

            # Style embeddings of the inputs themselves: the interpolation start point.
            source_embeds = build_embedding_mapping(args, {'source': input_text})['source']

            batch_ctrl_embeds = torch.stack([
                interpolate_points_nd(source_embeds[i], target_embed, percent=args['interp_percent'])
                for i, target_embed in enumerate(batch_ctrl_embeds)
            ]).to(device)

            if model_type == 't5':
                encoded_input = tokenizer(
                    input_text,
                    return_tensors='pt',
                    padding=True,
                    max_length=args['max_length_input'],
                    truncation=True,
                ).to(device)

                # Inference only: no autograd bookkeeping needed.
                with torch.no_grad():
                    outputs = model.generate(
                        **encoded_input,
                        style=batch_ctrl_embeds,
                        max_length=args['max_length_output'],
                        do_sample=args['do_sample'],
                        top_p=args['top_p'],
                        temperature=args['temp'],
                    )
            elif model_type == 'llama':
                raise ValueError("LLAMA model not supported")
            else:
                raise ValueError(f"Unknown model type: {model_type}")

            task['output'] = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            f.write(json.dumps(task) + '\n')



# def interpolate_points_nd(point_1, point_2, n=10):
#     return np.array([point_1 + (point_2 - point_1) * i/n for i in range(n+1)])

def interpolate_points_nd(point_1, point_2, percent=0.1):
    """Return the point ``percent`` of the way from ``point_1`` to ``point_2``.

    ``percent=0`` yields ``point_1`` and ``percent=1`` yields ``point_2``.
    Works for any operands that support ``-``, ``*`` and ``+``.
    """
    offset = (point_2 - point_1) * percent
    return point_1 + offset




def load_style_tasks(path: str, target_label: str) -> list:
    """Build one style-transfer task per line of the text file at ``path``.

    Every line (blank lines included) is stripped of surrounding whitespace
    and becomes the task's ``source_text``.

    Args:
        path: Text file with one source sentence per line.
        target_label: Label stored under ``'target_author'`` in every task.

    Returns:
        A list of task dicts with the keys ``source_text``,
        ``source_paraphrase`` (a single-element list holding an empty
        string), ``target_author`` and ``file_path``.
    """
    with open(path) as f:
        return [
            {
                'source_text': line.strip(),
                'source_paraphrase': [''],
                'target_author': target_label,
                'file_path': path,
            }
            for line in f
        ]

def build_embedding_mapping(args, labels_to_texts):
    """Embed every group of texts with the shared style model.

    The style model and tokenizer are loaded on the first call and cached in
    the module-level ``style_model`` / ``style_tokenizer`` globals.

    Args:
        args: Run configuration; only ``args['device']`` is read.
        labels_to_texts: Mapping from label to the list of texts to embed.

    Returns:
        Dict mapping each label to a list of embeddings, each detached and
        moved to the CPU.
    """
    global style_model, style_tokenizer

    # Lazy one-time load; later calls reuse the cached model.
    if style_model is None:
        style_model, style_tokenizer, _ = load_style_model()

    target_device = args['device']
    style_model.eval()
    style_model.to(target_device)

    embeds_by_label = {}
    for label, texts in labels_to_texts.items():
        raw_embeds = text_to_style(
            model=style_model,
            tokenizer=style_tokenizer,
            texts=texts,
            device=target_device,
            model_type='style',
        )
        embeds_by_label[label] = [embed.detach().cpu() for embed in raw_embeds]

    return embeds_by_label


def _read_exemplars(path, limit):
    """Return the first ``limit`` whitespace-stripped lines of the file at ``path``."""
    with open(path) as f:
        return [line.strip() for line in f][:limit]


def main():
    """Run formal->informal and informal->formal transfer on the GYAFC test files.

    Usage: ``python <script> <interp_percent>`` where ``interp_percent`` is a
    float giving how far to move from each input's own style embedding toward
    the target style embedding.

    Results go to a timestamped sub-directory of ``out_dir``: ``args.json``
    (the configuration), ``to_informal.jsonl`` and ``to_formal.jsonl``.

    Raises:
        SystemExit: If the ``interp_percent`` argument is missing or not a float.
    """
    # Validate the command line first so a bad invocation fails with a usage
    # message before any directory is created or any model is loaded.
    if len(sys.argv) < 2:
        sys.exit("usage: python <script> <interp_percent>  (interp_percent: float, e.g. 0.5)")
    try:
        interp_percent = float(sys.argv[1])
    except ValueError:
        sys.exit(f"interp_percent must be a float, got {sys.argv[1]!r}")

    args = {
            'base_model': 'google/t5-v1_1-large',
            'device': 'cuda',
            'embed_selection': 'mean',
            'mean_sample': 64,  # previously tried: 8, 64, 128, 20000 (basically use all)
            'max_length_input': 80,
            'max_length_output': 80,
            'use_actual_input': True,
            'combine_actual_para': False,
            'do_sample': True,
            'checkpoint': '/mnt/swordfish-pool2//reddit_mud/raw_all/emnlp/supervised_ft_models/enc_dec_ft_v2_config_1_fixed/2024-05-28-02.06.56/best_model_google_t5-v1_1-large_1e-05_64.pt',
            # Other checkpoints used previously:
            # '/mnt/swordfish-pool2//reddit_mud/raw_all/emnlp/supervised_ft_models/enc_dec_ft_v1_config_1_lower/2024-05-21-01.15.22/best_model_google_t5-v1_1-large_1e-05_64.pt',
            # '/mnt/swordfish-pool2//reddit_mud/raw_all/emnlp/supervised_ft_models/enc_dec_ft_v2_config_1_fixed'
            # '/mnt/swordfish-pool2//reddit_mud/raw_all/emnlp/supervised_ft_models/enc_dec_ft_v2_config_1/2024-05-21-11.41.35/best_model_google_t5-v1_1-large_1e-05_64.pt',
            # '/mnt/swordfish-pool2//reddit_mud/raw_all/emnlp/supervised_ft_models/enc_dec_ft_v1_config_7_skip_optim/2024-05-17-02.21.47/best_model_google_t5-v1_1-large_1.25e-06_64.pt',
            # '/mnt/swordfish-pool2//reddit_mud/raw_all/emnlp/supervised_ft_models/enc_dec_ft_v1_config_6_skip_optim/2024-05-17-01.35.01/best_model_google_t5-v1_1-large_2.5e-06_64.pt',
            # '/mnt/swordfish-pool2//reddit_mud/raw_all/emnlp/supervised_ft_models/enc_dec_ft_v1_config_1/2024-05-16-17.58.41/best_model_google_t5-v1_1-large_1e-05_64.pt',
            'top_p': 0.80,
            'temp': 1.0,
            'do_lower': False,
            'out_dir': 'sft_v2_outputs_interp_v2',
            'path_to_formal_examples': '/home/V2/gyafc/data/GYAFC_Corpus 2/Entertainment_Music/tune/formal_exemplar_sample_filtered_0.95.128',
            'path_to_informal_examples': '/home/V2/gyafc/data/GYAFC_Corpus 2/Entertainment_Music/tune/informal_exemplar_sample_filtered_0.95.128',
            'path_to_formal_input': '/home/V2/gyafc/data/GYAFC_Corpus 2/Entertainment_Music/test/formal',
            'path_to_informal_input': '/home/V2/gyafc/data/GYAFC_Corpus 2/Entertainment_Music/test/informal',
            'path_max_examples': 64,  # previously tried: 16
            'interp_percent': interp_percent,
    }

    # Each run writes into its own timestamped sub-directory.
    cur_date = datetime.now().strftime('%Y-%m-%d-%H_%M_%S')
    args['out_dir'] = os.path.join(args['out_dir'], cur_date)
    os.makedirs(args['out_dir'], exist_ok=True)

    formal_to_informal_tasks = load_style_tasks(args['path_to_formal_input'], target_label='informal')
    informal_to_formal_tasks = load_style_tasks(args['path_to_informal_input'], target_label='formal')

    # Save the exact configuration next to the outputs.
    with open(os.path.join(args['out_dir'], 'args.json'), 'w') as f:
        json.dump(args, f, indent=2)

    formal_examples = _read_exemplars(args['path_to_formal_examples'], args['path_max_examples'])
    informal_examples = _read_exemplars(args['path_to_informal_examples'], args['path_max_examples'])

    target_embed_mapping = build_embedding_mapping(args, labels_to_texts={'formal': formal_examples, 'informal': informal_examples})

    # Disabled alternative: interpolate between the two class-mean embeddings
    # instead of between each input's own embedding and the target.
    # target_to_mean_embed = {}
    # for key in target_embed_mapping:
    #     target_to_mean_embed[key] = torch.mean(torch.stack(target_embed_mapping[key]), dim=0)
    #
    # new_target_embeddings = {}
    # new_target_embeddings['formal'] = [interpolate_points_nd(target_to_mean_embed['informal'], target_to_mean_embed['formal'], percent=args['interp_percent'])]
    # new_target_embeddings['informal'] = [interpolate_points_nd(target_to_mean_embed['formal'], target_to_mean_embed['informal'], percent=args['interp_percent'])]
    #
    # target_embed_mapping = new_target_embeddings

    print('Performing formal to informal transfer')
    perform_authorship_transfer(args=args, tasks=formal_to_informal_tasks, target_author_embeddings=target_embed_mapping, out_file_name='to_informal.jsonl')

    print('Performing informal to formal transfer')
    perform_authorship_transfer(args=args, tasks=informal_to_formal_tasks, target_author_embeddings=target_embed_mapping, out_file_name='to_formal.jsonl')


# Script entry point; main() reads the interpolation fraction from sys.argv[1].
if __name__ == '__main__':
    main()


