import argparse
import torch
import torchvision.datasets as datasets
import torch.nn.functional as F
import clip
import os
import torch
from scipy.stats import wasserstein_distance
import numpy as np
from scipy.optimize import linear_sum_assignment
import numpy as np
import logging
import torch
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt




# CLIP backbones accepted by --arch.
model_names = ['ViT-L/14', 'ViT-L/14@336px']

# Command-line interface of the script; the arguments are parsed inside main().
parser = argparse.ArgumentParser(description='VPL for medical dataset')
parser.add_argument('--data_path', default='/path/IDRiD', type=str,
                    help='dataset path')
parser.add_argument('--type', default='brain', type=str,
                    help='dataset type; selects the class names and prompt '
                         'templates used to build the text proxy')
parser.add_argument('-a', '--arch', metavar='ARCH', default='ViT-L/14',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: ViT-L/14)')
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('--iters_proxy', default=2000, type=int, metavar='N',
                    help='number of total iterations for learning vision proxy')
parser.add_argument('--iters_sinkhorn', default=20, type=int, metavar='N',
                    help='number of total iterations for optimizing Sinkhorn distance')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256)')

parser.add_argument('--lr', '--learning-rate', default=10, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--tau_t', default=0.01, type=float,
                    help='temperature applied to the text-proxy logits when '
                         'building pseudo labels (default: 0.01)')
parser.add_argument('--tau_i', default=0.04, type=float,
                    help='temperature of the softmax used while learning the '
                         'vision proxy (default: 0.04)')
parser.add_argument('--alpha', default=0.6, type=float,
                    help='pseudo labels whose top probability exceeds this value '
                         'are replaced by a one-hot label (default: 0.6)')
# main() passes args.gamma to sinkhorn(); without this option that read fails
# with AttributeError.
parser.add_argument('--gamma', default=0.0, type=float,
                    help='exponent applied to the estimated class marginal in '
                         'Sinkhorn; 0 uses a uniform marginal (default: 0.0)')

def main():
    """Evaluate CLIP text and vision proxies on an image-folder validation set.

    Every stage logs its top-1 accuracy to ``image_classification.log``:

    1. Encode all images under ``<data_path>/val`` with the CLIP image encoder.
    2. Build the text proxy from the class prompts and score it.
    3. Learn a vision proxy from softmax pseudo labels and score it.
    4. Score the barycentric variant of the proxy.
    5. Refine the pseudo labels with Sinkhorn, relearn the vision proxy and
       score it.

    The model, images and labels are moved with ``.cuda()``, so a CUDA device
    is required.
    """
    logging.basicConfig(filename='image_classification.log', level=logging.INFO,
                        format='%(asctime)s %(levelname)s:%(message)s')
    args = parser.parse_args()
    logging.info(f"Running with args: {args}")

    # Fail fast: prompts exist only for the "eyes" dataset. Any other value
    # previously crashed with a NameError after the whole dataset was encoded.
    if args.type != "eyes":
        parser.error(f"unsupported --type {args.type!r}: only 'eyes' has class prompts defined")

    # NOTE(review): the order presumably has to match the class indices that
    # ImageFolder assigns -- compare with the mapping printed below.
    image_classes = ["mild nonproliferative retinopathy", "moderate nonproliferative retinopathy", "no apparent retinopathy",  "proliferative retinopathy", "severe nonproliferative retinopathy"]
    # zeroshot_classifier formats every template with every class name and
    # averages the resulting embeddings per class.
    image_2_templates = [
        'The absence of microaneurysms, intraretinal hemorrhages, and hard exudates, along with clear retinal vasculature, indicates {}.',
        'The presence of microaneurysms, mild retinal hemorrhages, mild cotton wool spots, and mild venous beading indicates {}.',
        'Presence of retinal hemorrhages, Moderate cotton wool spots, Intraretinal microvascular abnormalities, Absence of neovascularization indicates {}',
        'The presence of more extensive areas of retinal hemorrhage, venous beading and loops, severe cotton wool spots, and neovascularization suggests {}.',
        'The presence of neovascularization, fibrous proliferation, vitreous hemorrhage, and tractional retinal detachment indicates {}.',
    ]

    # Use 0 (sinkhorn's own default) when the parser defines no --gamma option.
    gamma = getattr(args, 'gamma', 0.0)

    print(args)

    print('load pre-trained model')
    model, preprocess = clip.load(args.arch)
    model = model.cuda()
    model.eval()

    print('load data')
    valdir = os.path.join(args.data_path, 'val')
    val_set = datasets.ImageFolder(valdir, transform=preprocess)
    loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, num_workers=args.workers)
    print("Class to index mapping:", val_set.class_to_idx)

    image_feat = []
    image_label = []
    with torch.no_grad():
        for images, target in loader:
            image_features = model.encode_image(images.cuda())
            image_feat.append(F.normalize(image_features, dim=1))
            image_label.append(target.cuda())

    # can keep fp16 for efficiency on GPU
    image_feat = torch.cat(image_feat, dim=0).float()
    image_label = torch.cat(image_label, dim=0)
    n = len(image_label)

    def log_top1(name, logits):
        """Log the top-1 accuracy (in percent) of ``logits`` against the labels."""
        (hits,) = accuracy(logits, image_label, topk=(1,))
        logging.info(f"accuracy with {name}: {hits / n * 100:.2f}")

    logging.info('obtain text proxy')
    text_classifier = zeroshot_classifier(clip, model, image_classes, image_2_templates).float()
    logits_t = image_feat @ text_classifier
    log_top1('text proxy', logits_t)

    logging.info('obtain vision proxy without Sinkhorn distance')
    plabel = F.softmax(logits_t / args.tau_t, dim=1)
    image_classifier = image_opt(image_feat, text_classifier, plabel, args.lr, args.iters_proxy, args.tau_i, args.alpha)
    log_top1('image proxy', image_feat @ image_classifier)

    logging.info('barycentric visual proxy')
    image_classifier = barycentric_distance_optimization(image_feat, text_classifier, args.lr, args.iters_proxy)
    log_top1('image proxy + barycentric', image_feat @ image_classifier)

    logging.info('obtain refined labels by Sinkhorn distance')
    plabel = sinkhorn(logits_t, args.tau_t, gamma, args.iters_sinkhorn)

    logging.info('obtain vision proxy with Sinkhorn distance')
    image_classifier = image_opt(image_feat, text_classifier, plabel, args.lr, args.iters_proxy, args.tau_i, args.alpha)
    log_top1('image proxy + sinkhorn', image_feat @ image_classifier)


def zeroshot_classifier(clip, model, classnames, templates):
    """Build a text classifier with one unit-norm column per class.

    For each class name every template is formatted with that name, the
    prompts are tokenized and encoded, the embeddings are L2-normalized and
    averaged, and the average is L2-normalized again.

    Args:
        clip: object providing ``tokenize(list_of_str)`` that returns a tensor.
        model: object providing ``encode_text(tokens)``.
        classnames: iterable of class names.
        templates: iterable of format strings with one ``{}`` placeholder.

    Returns:
        Tensor whose columns are the class embeddings, stacked along dim 1.
    """
    # Follow the model's device rather than hard-coding ``.cuda()``, which
    # fails on CPU and misplaces tensors when the model is not on the current
    # CUDA device. CUDA remains the fallback when no parameter is exposed.
    parameters = getattr(model, 'parameters', None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    device = first_param.device if first_param is not None else torch.device('cuda')

    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            prompts = [template.format(classname) for template in templates]
            tokens = clip.tokenize(prompts).to(device)
            class_embeddings = model.encode_text(tokens)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights


def image_opt(feat, init_classifier, plabel, lr=10, iter=2000, tau_i=0.04, alpha=0.6):
    """Learn a vision proxy by gradient descent on pseudo-labelled features.

    Rows of ``plabel`` whose largest entry exceeds ``alpha`` are replaced by a
    one-hot label. Starting from ``init_classifier``, each step moves along
    ``feat.T @ (softmax(feat @ W / tau_i) - labels)`` scaled by
    ``lr / (n * tau_i)`` and then renormalizes every column to unit L2 norm.
    The learning rate is halved whenever the gradient norm grew compared with
    the previous step.

    Args:
        feat: 2-D feature tensor, one row per sample.
        init_classifier: initial classifier, one column per class (not modified).
        plabel: pseudo labels, one row per sample (not modified).
        lr: initial learning rate.
        iter: number of gradient steps.
        tau_i: softmax temperature.
        alpha: confidence threshold for hardening a pseudo label.

    Returns:
        The optimized classifier, same shape as ``init_classifier``.
    """
    num_samples = feat.shape[0]

    # Work on a copy: hardening the labels in place would silently overwrite
    # the tensor owned by the caller.
    labels = plabel.clone()
    confidence, top_class = torch.max(labels, dim=1)
    confident = confidence > alpha
    labels[confident, :] = 0
    labels[confident, top_class[confident]] = 1

    # Constant part of the gradient.
    base = feat.T @ labels
    classifier = init_classifier.clone()
    pre_norm = float('inf')
    for _ in range(iter):
        prob = F.softmax(feat @ classifier / tau_i, dim=1)
        grad = feat.T @ prob - base
        grad_norm = torch.norm(grad)
        if grad_norm > pre_norm:
            lr /= 2.
        pre_norm = grad_norm
        classifier -= (lr / (num_samples * tau_i)) * grad
        classifier = F.normalize(classifier, dim=0)
    return classifier
###############


def barycentric_distance_optimization(image_feat, text_classifier, lr=10, iters=1000):
    """Shift a copy of the text classifier by a centroid-distance step.

    The distance is the L2 norm between the mean image feature and a vector
    derived from the text classifier (its per-row class mean, expanded and
    averaged again, which yields the same value in every position). Each
    iteration subtracts ``lr * distance`` from every entry of the classifier
    copy. The distance depends on neither the iteration nor the evolving
    classifier, so it is computed once.

    Args:
        image_feat: 2-D tensor, one feature row per image.
        text_classifier: 2-D tensor, one column per class (not modified).
        lr: step size multiplied with the distance.
        iters: number of subtraction steps.

    Returns:
        The shifted copy of ``text_classifier``.
    """
    image_classifier = text_classifier.clone()
    if iters <= 0:
        return image_classifier

    # Loop-invariant quantities, hoisted out of the iteration loop.
    text_classifier_mean = text_classifier.mean(dim=1)
    text_classifier_mean = text_classifier_mean.unsqueeze(1).expand(-1, image_feat.size(1))
    centroid_image = image_feat.mean(dim=0)
    centroid_text = text_classifier_mean.mean(dim=0)
    step = lr * torch.norm(centroid_image - centroid_text)

    # Repeated subtraction is kept (rather than one multiplied step) so the
    # floating-point result matches the step-by-step update exactly.
    for _ in range(iters):
        image_classifier -= step

    return image_classifier

def mahalanobis_distance(x, y, cov):
    """Return ``sqrt((x - y) @ inv(cov) @ (x - y).T)``.

    Args:
        x, y: tensors whose difference is a 1-D vector or a 2-D batch of row
            vectors.
        cov: square matrix that is inverted with ``torch.inverse``.

    Returns:
        A 0-dim tensor for 1-D input; for 2-D input the element-wise square
        root of the pairwise product matrix between the rows.
    """
    diff = x - y
    projected = diff @ torch.inverse(cov)
    if diff.dim() < 2:
        # ``.T`` on a 1-D tensor is deprecated in PyTorch; for vectors the
        # quadratic form is a plain dot product.
        return torch.sqrt(projected @ diff)
    return torch.sqrt(projected @ diff.T)

def mahalanobis_distance_optimization(image_feat, text_classifier, lr, iters):
    """Shift each class column of the text classifier by a Mahalanobis step.

    For every class column, the Mahalanobis distance between the mean image
    feature and that column is measured under the covariance of the image
    features; ``lr * distance`` is then subtracted from the column ``iters``
    times. The work is done on CPU and the result is moved back to CUDA when
    ``text_classifier`` lives there.

    Args:
        image_feat: 2-D tensor, one feature row per image.
        text_classifier: 2-D tensor, one column per class (not modified).
        lr: step size multiplied with each distance.
        iters: number of subtraction steps.

    Returns:
        The shifted copy of ``text_classifier``.
    """
    image_feat_cpu = image_feat.cpu()
    text_classifier_cpu = text_classifier.cpu()

    cov_matrix = torch.Tensor(np.cov(image_feat_cpu.numpy(), rowvar=False))

    image_classifier = text_classifier_cpu.clone()

    if iters > 0:
        # The distances depend only on the inputs, never on the evolving
        # classifier, so compute each one (including its matrix inverse) once
        # instead of once per iteration.
        feat_mean = image_feat_cpu.mean(dim=0)
        steps = [
            lr * mahalanobis_distance(feat_mean, text_classifier_cpu[:, col], cov_matrix)
            for col in range(text_classifier_cpu.shape[1])
        ]
        # Repeated subtraction keeps the result identical to the step-by-step
        # update.
        for _ in range(iters):
            for col, step in enumerate(steps):
                image_classifier[:, col] -= step

    if text_classifier.is_cuda:
        image_classifier = image_classifier.cuda()

    return image_classifier



# """
# def gromov_wasserstein_distance(C1, C2, p, q, epsilon=1e-3):
#
#     n = C1.shape[0]
#     m = C2.shape[0]
#
#
#     T = np.outer(p, q)
#
#     while True:
#         T_old = T.copy()
#
#         r = np.sum(T, axis=1)
#         c = np.sum(T, axis=0)
#         M = C1 / (r[:, None] + epsilon) + C2 / (c[None, :] + epsilon)
#
#         row_ind, col_ind = linear_sum_assignment(M)
#         T = np.zeros_like(T)
#         T[row_ind, col_ind] = p[row_ind] * q[col_ind]
#
#         if np.linalg.norm(T - T_old) < epsilon:
#             break
#
#     return np.sum(T * M)
# """
def image_opt_gw(image_feat, text_classifier, gw_distance, lr, iters):
    """Return a copy of ``text_classifier`` lowered by ``lr * gw_distance`` per step.

    Args:
        image_feat: accepted for signature compatibility; not read.
        text_classifier: starting classifier (left untouched).
        gw_distance: scale factor of the step.
        lr: learning rate multiplied with ``gw_distance``.
        iters: how many times the step is subtracted.

    Returns:
        The shifted copy of ``text_classifier``.
    """
    proxy = text_classifier.clone()
    for _step in range(iters):
        # The same scalar step is subtracted from every entry on each pass.
        proxy -= lr * gw_distance
    return proxy

def sinkhorn(M, tau_t=0.01, gamma=0, iter=20):
    """Refine soft labels with Sinkhorn normalization in the log domain.

    Starts from the row-wise softmax of ``M / tau_t`` and alternates between
    matching a target column marginal and normalizing the rows.

    Args:
        M: 2-D score tensor, one row per sample and one column per class
            (not modified).
        tau_t: softmax temperature.
        gamma: when > 0 the target column marginal is the column mass of the
            initial softmax raised to ``gamma`` and renormalized; otherwise
            it is uniform.
        iter: number of Sinkhorn iterations.

    Returns:
        Tensor of the same shape as ``M`` whose rows each sum to 1, so it can
        be used directly as per-sample label distributions.
    """
    _, col = M.shape
    log_P = M / tau_t  # new tensor, so the in-place updates below never touch M
    log_P -= torch.logsumexp(log_P, dim=1, keepdim=True)

    if gamma > 0:
        q = torch.exp(torch.logsumexp(log_P, dim=0, keepdim=True))
        q = q**gamma
        q /= q.sum()
        log_col_marginal = torch.log(q)
    else:
        log_col_marginal = -torch.log(torch.tensor(col, dtype=log_P.dtype, device=log_P.device))

    for _ in range(iter):
        # Match the target column marginal.
        log_P -= torch.logsumexp(log_P, dim=0, keepdim=True)
        log_P += log_col_marginal
        # Normalize rows to sum to 1. An extra uniform ``-log(row)`` shift here
        # would be cancelled by the next column normalization and would only
        # leave the returned rows summing to 1/row, inconsistent with the
        # ``iter=0`` result.
        log_P -= torch.logsumexp(log_P, dim=1, keepdim=True)
    return torch.exp(log_P)

def accuracy(output, target, topk=(1,)):
    """Count the samples whose target is among the top-k predictions.

    Args:
        output: 2-D score tensor, one row per sample and one column per class.
        target: 1-D tensor of class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        List of floats, one per k, each the NUMBER of correct samples (not a
        percentage).
    """
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # ``.item()`` replaces ``float()`` on a shape-(1,) NumPy array, which NumPy
    # deprecated; summing the boolean mask keeps the count exact.
    return [float(correct[:k].sum().item()) for k in topk]


# Run the evaluation pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()

