# Plot a 2-D t-SNE projection of the formal/informal example embeddings and the transfer outputs.
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np
import json
import os
import pickle

import glob

from tinystyle_generate_formal_informal_interp import build_embedding_mapping

# def run_pca(points):
#     pca = PCA(n_components=2)
#     pca.fit(points)
#     return pca

def embed_mapping_to_points(args, embed_mapping, labels_to_texts, random_state=None):
    """Project all embeddings to 2-D with t-SNE and save two scatter plots.

    The groups in ``embed_mapping`` are stacked in sorted label order into a
    single matrix and reduced to two dimensions with one t-SNE fit.  Each
    group is drawn with the color and marker configured for its label and the
    plot is written to ``tsne_new.png``.  Points whose label contains
    ``'middle'`` are then annotated with their text and the same figure is
    written again to ``tsne_new_with_text.png``.

    Drawing happens on the current pyplot figure, which is left open.

    Args:
        args: Unused; kept so existing callers keep working.
        embed_mapping: Maps a label to a sequence of embeddings accepted by
            ``np.vstack``.  Every label must have an entry in the color and
            marker tables below.
        labels_to_texts: Maps each label to its texts, index-aligned with
            ``embed_mapping[label]``.
        random_state: Forwarded to ``TSNE``.  Pass an int for a reproducible
            layout; the default ``None`` keeps the previous unseeded
            behavior.

    Raises:
        KeyError: If a label has no configured color/marker or no texts.
        IndexError: If a label has fewer texts than embeddings.
    """
    label_to_color = {
        'formal': 'blue',
        'informal': 'green',
        'transfer_formal': 'black',
        'transfer_informal': 'black',
        'source_formal': 'black',
        'source_informal': 'black',
        'transfer_formal_middle': 'black',
        'transfer_informal_middle': 'black',
    }

    label_to_marker = {
        'formal': 'o',
        'informal': 'o',
        'transfer_formal': '^',
        'transfer_informal': '^',
        'source_formal': 'o',
        'source_informal': 'o',
        'transfer_formal_middle': '^',
        'transfer_informal_middle': '^',
    }

    keys = sorted(embed_mapping)
    embeds = []
    index_to_text = {}
    # Points of one label are contiguous in the stacked matrix, so a
    # (start, stop) span per label replaces a scan over every point.
    key_to_span = {}
    num_points = 0
    for key in keys:
        group = embed_mapping[key]
        texts = labels_to_texts[key]
        start = num_points
        for offset in range(len(group)):
            index_to_text[start + offset] = texts[offset]
        num_points = start + len(group)
        key_to_span[key] = (start, num_points)
        embeds.append(np.vstack(group))

    # Fit once on all groups together so they share one 2-D layout.
    tsne = TSNE(n_components=2, random_state=random_state)
    transform_embeds = tsne.fit_transform(np.concatenate(embeds, axis=0))

    for key in keys:
        start, stop = key_to_span[key]
        # linestyle='None' draws markers only.  The marker is passed solely
        # as a keyword: combining it with an 'o' format string makes
        # matplotlib warn that the marker is redundantly defined.
        plt.plot(
            transform_embeds[start:stop, 0],
            transform_embeds[start:stop, 1],
            linestyle='None',
            label=key,
            color=label_to_color[key],
            marker=label_to_marker[key],
            markersize=5,
        )

    # Hide the axes; t-SNE coordinates carry no interpretable scale.
    plt.axis('off')
    plt.savefig('tsne_new.png')

    # Second pass: annotate only the intermediate ("middle") points.
    for key in keys:
        if 'middle' not in key:
            continue
        start, stop = key_to_span[key]
        for i in range(start, stop):
            plt.text(transform_embeds[i, 0], transform_embeds[i, 1], index_to_text[i], fontsize=8)

    plt.savefig('tsne_new_with_text.png')
  

    

    


def get_args_and_data(result_path):
    """Load the interpolation percentage and (source, output) pairs of a run.

    The run configuration is read from ``args.json`` in the folder that
    contains ``result_path``, falling back to ``hparams.json``.

    Args:
        result_path: Path to a JSON-lines results file.  Every non-blank line
            must be an object with a ``source_text`` field and a non-empty
            ``output`` sequence.

    Returns:
        A tuple ``(percent, pairs)``: ``percent`` is the configuration's
        ``interp_percent`` value and ``pairs`` is a list of
        ``(source_text, output[0])`` tuples in file order.

    Raises:
        FileNotFoundError: If neither configuration file exists.
        KeyError: If ``interp_percent`` or a required record field is missing.
    """
    folder_path = os.path.dirname(result_path)
    for candidate in ('args.json', 'hparams.json'):
        args_fname = os.path.join(folder_path, candidate)
        if os.path.exists(args_fname):
            break
    else:
        # Explicit raise instead of assert: asserts vanish under ``python -O``.
        raise FileNotFoundError(
            'no args.json or hparams.json found next to ' + repr(result_path)
        )

    # JSON is UTF-8; do not depend on the platform's default encoding.
    with open(args_fname, 'r', encoding='utf-8') as f:
        args = json.load(f)

    percent = args['interp_percent']
    pairs = []
    with open(result_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            data = json.loads(line)
            pairs.append((data['source_text'], data['output'][0]))

    return (percent, pairs)


# Run configuration: the folder holding the interpolation outputs and the
# GYAFC files with exemplar / input sentences.  The same dict is handed to
# build_embedding_mapping and embed_mapping_to_points below.
args = {
    'out_dir': 'sft_v2_outputs_interp_v2',
    'path_to_formal_examples': '/home//gyafc/data/GYAFC_Corpus 2/Entertainment_Music/tune/formal_exemplar_sample_filtered_0.95.128',
    'path_to_informal_examples': '/home//gyafc/data/GYAFC_Corpus 2/Entertainment_Music/tune/informal_exemplar_sample_filtered_0.95.128',
    'path_to_formal_input': '/home//gyafc/data/GYAFC_Corpus 2/Entertainment_Music/test/formal',
    'path_to_informal_input': '/home//gyafc/data/GYAFC_Corpus 2/Entertainment_Music/test/informal',
    'path_max_examples': 64,
    'device': 'cuda',
}

# One results file per run sub-folder, for each transfer direction.
formal_transfer_paths = sorted(glob.glob(args['out_dir'] + '/*/to_formal.jsonl'))
informal_transfer_paths = sorted(glob.glob(args['out_dir'] + '/*/to_informal.jsonl'))

if not formal_transfer_paths or not informal_transfer_paths:
    # Fail with a clear message instead of an IndexError further down.
    raise FileNotFoundError(
        'no to_formal.jsonl / to_informal.jsonl files found under ' + repr(args['out_dir'])
    )

# (percent, pairs) per run, ordered by interpolation percentage.
formal_datapoints = sorted(
    (get_args_and_data(path) for path in formal_transfer_paths),
    key=lambda x: x[0],
)
informal_datapoints = sorted(
    (get_args_and_data(path) for path in informal_transfer_paths),
    key=lambda x: x[0],
)

# Index of the single example that is followed across all runs.
formal_idx = 42
informal_idx = 42

percents = [percent for percent, _ in formal_datapoints]
if percents != [percent for percent, _ in informal_datapoints]:
    # Both directions must cover the same interpolation percentages, because
    # percent_to_idx is used to index into both of them.
    raise ValueError('formal and informal runs cover different interp_percent values')

percent_to_idx = {percent: i for i, percent in enumerate(percents)}

selected_formal_interps = [pairs[formal_idx] for _, pairs in formal_datapoints]
selected_informal_interps = [pairs[informal_idx] for _, pairs in informal_datapoints]

# Source sentence of the tracked example, taken from the lowest-percent run.
source_for_formal = selected_formal_interps[0][0]
source_for_informal = selected_informal_interps[0][0]

# Outputs at the intermediate interpolation percentages.
middle_percents = (0.2, 0.4, 0.6, 0.8)
middle_formal = [selected_formal_interps[percent_to_idx[p]][1] for p in middle_percents]
middle_informal = [selected_informal_interps[percent_to_idx[p]][1] for p in middle_percents]

# Output of the highest-percent run.
final_formal = selected_formal_interps[-1][1]
final_informal = selected_informal_interps[-1][1]

print(source_for_informal, middle_informal, final_informal)
print(source_for_formal, middle_formal, final_formal)

# Exemplar sentences, one per line; explicit UTF-8 so the result does not
# depend on the platform's default encoding.
with open(args['path_to_informal_examples'], 'r', encoding='utf-8') as f:
    informal_examples = [line.strip() for line in f]

with open(args['path_to_formal_examples'], 'r', encoding='utf-8') as f:
    formal_examples = [line.strip() for line in f]

# Only the to-informal trajectory is plotted; the to-formal entries are kept
# commented out so they can be switched back on.
labels_to_texts = {
    'formal': formal_examples,
    'informal': informal_examples,
    # 'transfer_formal': [final_formal],
    # 'transfer_formal_middle': middle_formal,
    'transfer_informal': [final_informal],
    'transfer_informal_middle': middle_informal,
    # 'source_formal': [source_for_formal],
    'source_informal': [source_for_informal],
}

embedding_mapping = build_embedding_mapping(args, labels_to_texts)
embed_mapping_to_points(args, embedding_mapping, labels_to_texts)




