import json
import os

import torch
from PIL import Image
import numpy as np
import clip
from loguru import logger
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import random
# TensorBoard writer used by the training loop; event files are written to "logs_att".
writer = SummaryWriter(log_dir="logs_att")
import time
 
# Unix timestamp that gives every run its own loguru log file name.
timestamp = time.time()
logger.add(f"test_{timestamp}.log")

# JSON annotation files: the training file feeds D2IDataset, the test file feeds evaluate().
train_file_path = "/opt/data/private/llama2/dial2image/finetune_salience/raw_llama3_train_dialogcc_desc_filtered.json"
test_file_path = "/opt/data/private/llama2/dial2image/finetune_salience/raw_llama3_test_dialogcc_desc_filtered.json" 

def image_enhancement(model, image_features, image_patch_features, text_features):
    """Add a text-derived residual to the global image feature.

    Every image patch is compared with every text token (cosine similarity
    scaled by 100 and softmaxed over the tokens).  The per-token weights are
    averaged over all patches and used to pool the L2-normalized token
    embeddings into a single vector, which is added to ``image_features``.

    Args:
        model: Unused; kept so existing call sites remain valid.
        image_features: Global image embedding.  The original author's note
            gives the shape as [1, 512].
        image_patch_features: Per-patch image embeddings, 2-D [P, D]
            (original note: [50, 512]).
        text_features: Per-token text embeddings, 2-D [T, D]
            (original note: [77, 512]).  Not modified by this function.

    Returns:
        ``image_features`` plus the pooled token embedding of shape [1, D].
    """
    # Normalize out of place.  The previous in-place `/=` rewrote the caller's
    # tensor and, through aliasing, also the variable that was meant to hold
    # the "raw" text features.  The pooled sum below therefore keeps using the
    # normalized tokens, which is what the in-place version effectively did.
    text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
    image_patch_features_norm = image_patch_features / image_patch_features.norm(dim=-1, keepdim=True)

    # Match the patch dtype instead of forcing fp16, so the matmul also works
    # when the model runs in fp32; with fp16 patches this is the same cast the
    # original `.half()` performed.
    text_for_sim = text_features_norm.to(image_patch_features_norm.dtype)
    text_to_patch_sim = 100 * image_patch_features_norm @ text_for_sim.T  # [P, T]
    text_to_patch_sim = text_to_patch_sim.softmax(dim=1)  # distribution over tokens per patch

    # Average the token weights over all patches: [P, T] -> [1, T].
    mean_sim = torch.mean(text_to_patch_sim, dim=0, keepdim=True)

    # Weighted sum of the token embeddings: [1, T] @ [T, D] -> [1, D].
    weight_images = mean_sim @ text_features_norm.to(mean_sim.dtype)
    return torch.add(image_features, weight_images)

class D2IDataset(Dataset):
    """Image / text dataset built from a JSON annotation file.

    The JSON file maps an image id to a record with the keys
    ``shared_image_caption``, ``dialogue_description`` and
    ``sentence_description``.  All texts are tokenized once with
    ``clip.tokenize`` at construction time; images are loaded lazily in
    ``__getitem__``.

    Args:
        img_path: Directory holding ``<image id>.jpg`` files.
        text_path: Path of the JSON annotation file.
        is_train: Stored on the instance; not otherwise used by this class.
        preprocess: Callable applied to each opened PIL image.
    """

    def __init__(self, img_path, text_path, is_train, preprocess):
        # Locations and flags.
        self.img_path = img_path
        self.text_path = text_path
        self.is_train = is_train
        # Image transform.
        self.img_process = preprocess
        # Per-sample data, index-aligned across all lists.
        self.img_samples = []
        self.photo_descs = []
        self.photo_descs_short = []
        self.dialogue_descs = []
        self.sentence_descs = []

        # The file is only read, so open it read-only; the previous 'r+' mode
        # required write permission and failed on read-only data mounts.
        with open(self.text_path, 'r', encoding='utf-8') as f:
            records = json.load(f)

        for key, record in records.items():
            self.img_samples.append(f"{key}.jpg")
            self.photo_descs.append(record['shared_image_caption'])
            # NOTE(review): the "short" caption reads the same key as the full
            # caption, so both lists are identical -- confirm whether a
            # different field was intended.
            self.photo_descs_short.append(record['shared_image_caption'])
            self.dialogue_descs.append(record['dialogue_description'])
            self.sentence_descs.append(record['sentence_description'])

        # Tokenize every description up front.
        self.photo_tokens = clip.tokenize(self.photo_descs, truncate=True)
        self.dialogue_tokens = clip.tokenize(self.dialogue_descs, truncate=True)
        self.sentence_tokens = clip.tokenize(self.sentence_descs, truncate=True)
        self.photo_short_tokens = clip.tokenize(self.photo_descs_short, truncate=True)

    def __len__(self):
        """Return the number of samples."""
        return len(self.img_samples)

    def __getitem__(self, idx):
        """Return the transformed image and its four token tensors.

        Returns:
            Tuple ``(image, photo_tokens, photo_short_tokens,
            dialogue_tokens, sentence_tokens)`` for sample ``idx``.
        """
        file_name = self.img_samples[idx]
        photo_tokens = self.photo_tokens[idx]
        photo_short_tokens = self.photo_short_tokens[idx]
        dialogue_tokens = self.dialogue_tokens[idx]
        sentence_tokens = self.sentence_tokens[idx]
        # Load and transform the image; the context manager releases the file
        # handle even if the transform raises.
        with Image.open(self.img_path + '/' + file_name) as raw_image:
            image = self.img_process(raw_image)
        return image, photo_tokens, photo_short_tokens, dialogue_tokens, sentence_tokens


def evaluate(model, preprocess, text_path, is_train):
    """Compute Recall@1/5/10 for text-to-image retrieval over one annotation file.

    For every entry of the JSON file the image is encoded and passed through
    ``image_enhancement`` together with the token features of its caption.
    Each query is the normalized sum of the sentence-description and
    dialogue-description text embeddings.  Two score matrices are built --
    query vs. enhanced image and query vs. caption text embedding -- each
    softmaxed over the query axis (``dim=0``) and then added.  An entry counts
    as a hit at ``k`` when its own image index is among the ``k`` highest
    scores of its row.

    Args:
        model: Module providing ``encode_text``, ``encode_text_tokens`` and an
            ``encode_image`` that returns ``(patch_features, image_features)``.
        preprocess: Callable turning a PIL image into an image tensor.
        text_path: JSON file mapping image id -> record with the keys
            ``shared_image_caption``, ``sentence_description`` and
            ``dialogue_description``.
        is_train: Selects the train photo directory when true, otherwise the
            test photo directory.

    Returns:
        Tuple ``(recall@1, recall@5, recall@10)`` as fractions of the entries.

    Raises:
        ValueError: If the annotation file contains no entries.
    """
    # Read-only access is sufficient; 'r+' needlessly required write permission.
    with open(text_path, 'r', encoding='utf-8') as handle:
        test_file = json.load(handle)

    keys = list(test_file.keys())
    num = len(keys)
    if num == 0:
        raise ValueError(f"no entries found in {text_path}")

    if is_train:
        image_path = "/root/dial2image/dialogCC/train_photo/"
    else:
        image_path = "/root/dial2image/dialogCC/test_photo/"

    # Use the device the model lives on rather than a module-level global that
    # only exists when this file is executed as a script.
    device = next(model.parameters()).device

    images = []
    photo_descs = []
    sentence_descs = []
    keyword_descs = []

    with torch.no_grad():
        # Images are encoded one at a time, each enhanced with its own caption tokens.
        for item in keys:
            entry = test_file[item]
            p_desc = entry['shared_image_caption']
            photo_text_tokens = clip.tokenize(p_desc, truncate=True).to(device)
            photo_text_tokens_features = model.encode_text_tokens(photo_text_tokens)
            with Image.open(f'{image_path}{item}.jpg') as raw_image:
                image = preprocess(raw_image).unsqueeze(0).to(device)
            image_patch_features, image_features = model.encode_image(image)
            images.append(
                image_enhancement(model, image_features, image_patch_features[0],
                                  photo_text_tokens_features[0]))
            sentence_descs.append(entry['sentence_description'])
            keyword_descs.append(entry['dialogue_description'])
            photo_descs.append(p_desc)

        # Caption embeddings (the "scene" side of the score).
        text_photo_desc = clip.tokenize(photo_descs, truncate=True).to(device)
        text_photo_desc_features = model.encode_text(text_photo_desc)
        text_photo_desc_features = text_photo_desc_features / text_photo_desc_features.norm(dim=-1, keepdim=True)

        # Query embeddings: sentence description + dialogue description.
        text_sentence_desc = clip.tokenize(sentence_descs, truncate=True).to(device)
        text_sentence_desc_features = model.encode_text(text_sentence_desc)
        text_keyword_desc = clip.tokenize(keyword_descs, truncate=True).to(device)
        text_keyword_desc_features = model.encode_text(text_keyword_desc)
        text_fusion_features = text_sentence_desc_features + text_keyword_desc_features
        text_fusion_features = text_fusion_features / text_fusion_features.norm(dim=-1, keepdim=True)

        # torch.cat already yields a fresh tensor on the right device, so the
        # former tensor -> numpy -> tensor round trip was pure overhead.
        images = torch.cat(images)
        images = images / images.norm(dim=-1, keepdim=True)

        similarity_balanced = (100.0 * text_fusion_features @ images.T).softmax(dim=0)
        similarity_scene = (100.0 * text_fusion_features @ text_photo_desc_features.T).softmax(dim=0)
        similarity = similarity_balanced + similarity_scene

    # Row i holds the scores of query i against every image; the correct image
    # shares the index i because queries and images were built in the same order.
    cutoffs = (1, 5, 10)
    hits = dict.fromkeys(cutoffs, 0)
    for row, scores in enumerate(similarity):
        for k in cutoffs:
            # Clamp k so small evaluation sets do not make topk raise.
            top_indices = scores.topk(min(k, num)).indices.tolist()
            if row in top_indices:
                hits[k] += 1

    return hits[1] / num, hits[5] / num, hits[10] / num


if __name__ == '__main__':
    # Seed every RNG in use and force deterministic cuDNN behaviour so runs are repeatable.
    seed = 42
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers multi-GPU setups
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Build the model.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net, preprocess = clip.load("ViT-B/32", device=device)

    # Freeze the image encoder; only the remaining parameters are optimized.
    for param in net.visual.parameters():
        param.requires_grad = False

    # Baseline recall on the test set before any fine-tuning.
    with torch.no_grad():
        test_recall1, test_recall5, test_recall10 = evaluate(net, preprocess, test_file_path, False)
    logger.info(f"zero-shot-test-set: R@1-{test_recall1}  R@5-{test_recall5}  R@10-{test_recall10}")

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=1e-5, betas=(0.9, 0.98), eps=1e-6,
                           weight_decay=0.001)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # Load the training set.
    d2i_dataset = D2IDataset(img_path='/root/dial2image/dialogCC/train_photo',
                             text_path=train_file_path,
                             is_train=True, preprocess=preprocess)
    dataset_size = len(d2i_dataset)
    print(dataset_size)
    batch_size = 56
    train_dataloader = DataLoader(d2i_dataset, batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False)

    phase = "train"
    model_name = "salience_dialogcc_llama"
    epoches = 30
    # try/finally guarantees the TensorBoard writer is flushed and closed even
    # when an epoch fails part-way through.
    try:
        for epoch in range(epoches):
            # Accumulate the loss as a plain float.  Summing the loss tensors
            # themselves keeps every batch's autograd history reachable for
            # the whole epoch and steadily grows memory use.
            total_loss = 0.0
            # Mixed precision keeps GPU memory usage down.
            # NOTE(review): the per-epoch evaluation below runs inside this
            # autocast region while the zero-shot baseline above does not --
            # confirm that the two are meant to be compared directly.
            with torch.cuda.amp.autocast(enabled=True):
                for images, photo_tokens, photo_short_tokens, dialogue_tokens, sentence_tokens in train_dataloader:
                    # Move the batch to the training device.
                    images = images.to(device)
                    photo_tokens = photo_tokens.to(device)
                    photo_short_tokens = photo_short_tokens.to(device)
                    dialogue_tokens = dialogue_tokens.to(device)
                    sentence_tokens = sentence_tokens.to(device)

                    # Reset gradients from the previous step.
                    optimizer.zero_grad()
                    with torch.set_grad_enabled(phase == "train"):
                        # The model's forward pass returns the training loss directly.
                        cur_loss = net(images, photo_tokens, photo_short_tokens, dialogue_tokens, sentence_tokens)
                        total_loss += cur_loss.item()
                        if phase == "train":
                            cur_loss.backward()
                            optimizer.step()
                            if device != "cpu":
                                # Cast applicable weights back to fp16 after the update.
                                # NOTE(review): the common CLIP fine-tuning recipe casts the
                                # parameters to fp32 before optimizer.step(); confirm that
                                # stepping on the weights as-is is intended.
                                clip.model.convert_weights(net)

                epoch_loss = total_loss / dataset_size
                scheduler.step()

                with torch.no_grad():
                    test_recall1, test_recall5, test_recall10 = evaluate(net, preprocess, test_file_path, False)
                logger.info(f"{epoch}: test-set R@1-{test_recall1}  R@5-{test_recall5}  R@10-{test_recall10}")

                torch.save(net.state_dict(),
                           f"/opt/data/private/llama2/dial2image/finetune_salience/output/{model_name}_epoch_{epoch}.pth")
                logger.info(f"weights_{epoch} saved")

                lr = optimizer.param_groups[0]['lr']
                writer.add_scalar("Learning Rate", lr, epoch)
                writer.add_scalar("loss", total_loss, epoch)
                logger.info('{} Loss: {:.4f}'.format(
                    phase, epoch_loss))
    finally:
        writer.close()
