import numpy as np
from sklearn.manifold import TSNE
import os
import seaborn as sns
import matplotlib.pyplot as plt
import time
import sys

def read_img_features(feature_store, views=36):
    """Load per-viewpoint image features from a base64-encoded TSV file.

    Each row of ``feature_store`` is tab-separated with the columns
    ``scanId, viewpointId, image_w, image_h, vfov, features``. The
    ``features`` column holds base64-encoded float32 bytes, which are
    decoded and reshaped to ``(views, -1)``.

    Args:
        feature_store: path to the TSV file.
        views: number of rows each decoded feature blob is reshaped into.
            Defaults to 36.

    Returns:
        dict mapping ``"<scanId>_<viewpointId>"`` to a read-only float32
        ndarray of shape ``(views, feature_dim)``.

    Raises:
        ValueError: if a decoded blob is not a whole number of float32
            values, or its length is not divisible by ``views``.
    """
    import base64
    import csv

    # Feature blobs are far larger than csv's default field size limit.
    # sys.maxsize does not fit in a C long on some platforms (e.g. Windows),
    # so fall back to the largest 32-bit signed value there.
    try:
        csv.field_size_limit(sys.maxsize)
    except OverflowError:
        csv.field_size_limit(2**31 - 1)

    print("Start loading the image feature")
    start = time.time()

    tsv_fieldnames = ['scanId', 'viewpointId', 'image_w', 'image_h', 'vfov', 'features']
    features = {}
    # newline="" is what the csv module expects for files it reads.
    with open(feature_store, "r", newline="") as tsv_in_file:
        reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=tsv_fieldnames)
        for item in reader:
            long_id = item['scanId'] + "_" + item['viewpointId']
            # base64.decodestring was removed in Python 3.9; decodebytes is
            # its direct replacement.
            raw = base64.decodebytes(item['features'].encode('ascii'))
            # Result is (views, feature_dim); presumably (36, 2048) for the
            # ResNet-152 features this script loads.
            features[long_id] = np.frombuffer(raw, dtype=np.float32).reshape((views, -1))

    print("Finish Loading the image feature from %s in %0.4f seconds" % (feature_store, time.time() - start))
    return features


# Matterport scan IDs to visualise. A scan's position in this list is the
# integer class label used for colouring.
scans = ['759xd9YjKW5', 'q9vSo1VnCiC', 'pa4otMbVnkk', 'Uxmj2M2itWa', 'JF19kD82Mey', '2azQ1b91cZZ', 'YVUC4YcDtcY', 'VVfe2KiqLaN', 'D7N2EKCX4Sj', 'RPmz2sHmrrY']

features = 'ResNet-152-imagenet.tsv'
feat_dict = read_img_features(features)

# Each <path>/<scan> directory is expected to contain one entry per viewpoint.
# The entry name is joined with the scan ID to look up its feature in feat_dict.
path = "../Documents/views_img"
rows = []
Y = []
for label, scan in enumerate(scans):
    scan_dir = os.path.join(path, scan)
    if not os.path.isdir(scan_dir):
        continue
    print("scan:", scan)
    # Sorted so the row order (and thus the plot) is reproducible across runs.
    for view in sorted(os.listdir(scan_dir)):
        # NOTE: an alternative feature source used previously was
        # np.load("learned_features/" + scan + "_" + view + ".npy").
        image_feature = feat_dict[scan + '_' + view]
        # Average over axis 0 (the per-view rows) to get one (1, dim) vector
        # per viewpoint.
        rows.append(np.mean(image_feature, axis=0, keepdims=True))
        Y.append(label)

if not rows:
    raise SystemExit("No viewpoints found under %s for the listed scans" % path)

# Concatenate once instead of growing the array inside the loop (quadratic).
X = np.concatenate(rows, axis=0)
print(X.shape)
X_embedded = TSNE(n_components=2).fit_transform(X)
print(X_embedded.shape)

plt.figure(figsize=(16, 10))
sns.scatterplot(
    # Plot the 2-D t-SNE embedding, not the first two raw feature dimensions.
    x=X_embedded[:, 0], y=X_embedded[:, 1],
    hue=Y,
    # One fixed colour per scan label. This stays valid (and stable) even when
    # some scan directories are missing and fewer hue levels are present.
    palette=dict(enumerate(sns.color_palette("hls", len(scans)))),
    legend="full",
    alpha=1,
)
plt.show()
