import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.decomposition import PCA

from classifiers import load_style_model, text_to_style
from tinystyle_generate_formal_informal import build_embedding_mapping


def interpolate_points_nd(point_1, point_2, n=10):
    """Return ``n + 1`` evenly spaced points on the segment from *point_1* to *point_2*.

    Works for points of any dimensionality (scalars, 1-D vectors, or
    higher-rank arrays); both endpoints are included.

    Args:
        point_1: Start point (array-like).
        point_2: End point (array-like, same shape as / broadcastable with
            *point_1*).
        n: Number of equal steps between the endpoints; must be >= 1.

    Returns:
        ``np.ndarray`` of shape ``(n + 1, *point.shape)`` whose i-th row is
        ``point_1 + (point_2 - point_1) * i / n``.

    Raises:
        ValueError: If ``n < 1`` (the step ``1 / n`` would be undefined).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    # Accept plain Python sequences as well as arrays; subtracting two lists
    # or tuples directly would raise TypeError.
    start = np.asarray(point_1)
    delta = np.asarray(point_2) - start
    # Column of step indices 0..n, shaped to broadcast against every axis of
    # the points. Keep the same operation order as a per-step loop,
    # (delta * i) / n, so the results are bit-for-bit identical.
    steps = np.arange(n + 1).reshape((-1,) + (1,) * delta.ndim)
    return start + delta * steps / n

def run_pca(points, n_components=2):
    """Fit and return a PCA model on *points*.

    Args:
        points: Samples to fit on -- anything ``sklearn.decomposition.PCA.fit``
            accepts, e.g. a 2-D array or a list of equal-length 1-D arrays
            (one row per sample).
        n_components: Number of principal components to keep. Defaults to 2,
            which is what this script's 2-D scatter plot needs.

    Returns:
        The fitted ``PCA`` instance; call ``.transform`` on it to project data.
    """
    # PCA.fit returns the estimator itself, so fitting and returning chain.
    return PCA(n_components=n_components).fit(points)


TARGET_PATH = 'inputs_formal.txt'
INPUT_PATH = 'inputs_informal.txt'

# Run on the GPU when one is present; otherwise fall back to the CPU so the
# script still works on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def _read_examples(path):
    """Return the stripped, non-blank lines of the UTF-8 text file at *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        # Skip blank lines so they don't become empty-string examples.
        return [line.strip() for line in f if line.strip()]


def _embed(texts):
    """Return one style embedding, as a NumPy array, per string in *texts*.

    Relies on the module-level ``ctr_model`` and ``tokenizer`` loaded below.
    """
    # Inference only: no autograd graph is needed. The list is built inside
    # the context in case text_to_style yields lazily.
    with torch.no_grad():
        return [
            x.detach().cpu().numpy()
            for x in text_to_style(model=ctr_model, tokenizer=tokenizer, texts=texts,
                                   device=DEVICE, model_type='style')
        ]


target_examples = _read_examples(TARGET_PATH)
input_examples = _read_examples(INPUT_PATH)

# The interpolation below uses input example #2 and target example #0, so fail
# fast (before the expensive model load) if either file is too short.
if len(input_examples) < 3 or len(target_examples) < 1:
    raise ValueError(
        f"need >= 3 non-blank lines in {INPUT_PATH} and >= 1 in {TARGET_PATH}; "
        f"got {len(input_examples)} and {len(target_examples)}"
    )

ctr_model, tokenizer, ctr_embeds = load_style_model()

ctr_model.eval()
ctr_model.to(DEVICE)

target_embeds = _embed(target_examples)
input_embeds = _embed(input_examples)

# Fit one PCA on both sets so they share the same 2-D projection, then
# project each set with it.
pca = run_pca(target_embeds + input_embeds)
target_embeds = pca.transform(target_embeds)
input_embeds = pca.transform(input_embeds)

# Scatter the projected points: targets (formal file) in red, inputs
# (informal file) in blue.
plt.plot(*zip(*target_embeds), 'ro')
plt.plot(*zip(*input_embeds), 'bo')

# Endpoints of the interpolation line: the third input example and the first
# target example.
point_1 = input_embeds[2]
point_2 = target_embeds[0]

# Hide the tick labels.
plt.xticks([])
plt.yticks([])

plt.savefig('pca.png')

# Overlay the straight-line interpolation between the endpoints in black.
plt.plot(*zip(*interpolate_points_nd(point_1, point_2, 10)), 'k-')

plt.savefig('pca_interp.png')


# def read_input_file(file_path):
#     with open(file_path, 'r') as f:
#         lines = f.readlines()
#         points = []
#         for line in lines:
#             points.append(np.array([int(i) for i in line.split()]))
#         return points


# plt.plot(*zip(point_1, point_2), 'ro')
# # plt.plot(*zip(*interpolate_points_nd(point_1, point_2, 10)), 'b-')
# # show each point on the line segment

# plt.savefig('interpolate.png')









