import re
import utils
import random
import json
import os
import sys
import h5py
import time
import argparse
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, ConcatDataset, Dataset
import _pickle as cPickle
from tqdm import tqdm
from xml.etree.ElementTree import parse
from dataset import Indexer, WordEmbeddings, Flickr30dataset, read_word_embeddings, load_train_flickr30k
from utils import largest, confidence, union, bbox_is_match, get_match_index, my_load_flickr30k, Evaluator, union_target
from model import NN
import warnings
with warnings.catch_warnings():
    warnings.filterwarnings("ignore",category=FutureWarning)
    import h5py

def train(model, train_loader, test_loader, batch, lr=1e-4, epochs=25, use_bert = False, lite_bert = False):
    """Train ``model`` with Adam + cross-entropy, evaluating on ``test_loader`` after each epoch.

    Each batch from ``train_loader`` must unpack into the 14 tensors listed in
    the unpacking below. The model's forward pass returns ``(probs, target, _)``;
    ``target`` is reduced to class indices with ``argmax`` over dim 1 and used as
    the cross-entropy label for ``probs``.

    Args:
        model: network whose forward pass is used for training and whose
            ``predict`` is used by ``model_eval``.
        train_loader: iterable of training batches.
        test_loader: iterable of evaluation batches, passed to ``model_eval``.
        batch: unused; kept for call-site compatibility.
        lr: Adam learning rate.
        epochs: number of passes over ``train_loader``.
        use_bert, lite_bert: flags forwarded unchanged to the model and to
            ``model_eval``.
    """
    use_gpu = torch.cuda.is_available()
    model = model.float()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    ce_loss = nn.CrossEntropyLoss(reduction="mean")

    print("---Before Training...")
    score, supacc = model_eval(test_loader, model, use_bert, lite_bert)
    print("     supervised accuracy on test dataset:", supacc)
    print("     eval score on test dataset:", score)

    for epoch in range(epochs):
        start = time.time()
        total_loss = 0.0
        correct_preds = 0
        all_preds = 0
        n_batches = 0

        # model_eval switches to eval mode, so (re)enter training mode every epoch.
        model.train(True)
        for batch_tensors in tqdm(train_loader):
            if use_gpu:
                batch_tensors = [t.cuda() for t in batch_tensors]
            (idx, labels, attrs, feature, query, label_feats, sent_toks, phrase_toks,
             entity_indices, entity_feats, bboxes, target_bboxes, num_obj, num_query) = batch_tensors

            optimizer.zero_grad()
            probs, target, _ = model(idx, query, labels, feature, attrs, bboxes, num_obj, num_query,
                                     label_feats, sent_toks, phrase_toks, entity_indices, entity_feats,
                                     use_bert, lite_bert)

            target_pred = torch.argmax(target, dim=1)
            prediction = torch.argmax(probs, dim=1)
            correct_preds += int(prediction.eq(target_pred).sum())
            all_preds += len(prediction)

            loss = ce_loss(probs, target_pred)
            loss.backward()
            optimizer.step()

            # Accumulate a Python float: summing the loss tensor itself would
            # chain every batch's autograd history into total_loss.
            total_loss += loss.item()
            n_batches += 1

        elapsed = time.time() - start
        print("--- EPOCH", epoch)
        print("     time:", elapsed)
        # Mean per-batch loss; guarded so an empty loader does not crash.
        print("     total loss:", total_loss / n_batches if n_batches else 0.0)
        print("     supervised accuracy on training set: ", correct_preds / all_preds if all_preds else 0.0)
        eval_start = time.time()
        score, supacc = model_eval(test_loader, model, use_bert, lite_bert)
        print("     eval time:", time.time() - eval_start)
        print("     supervised accuracy on test dataset:", supacc)
        print("     eval score on test dataset:", score)


def model_eval(train_loader, model, use_bert = False, lite_bert = False):
    """Run ``model.predict`` over every batch of a loader and score the results.

    Despite its name, ``train_loader`` is simply the loader to evaluate
    (``train`` passes its test loader here). Each batch must unpack into the
    same 14 tensors used for training.

    Args:
        train_loader: iterable of evaluation batches.
        model: network exposing ``predict``.
        use_bert, lite_bert: flags forwarded unchanged to ``model.predict``.

    Returns:
        ``(acc, supacc)``:
        - ``acc`` is the accuracy computed by ``evaluator_acc`` from the
          predicted boxes, target boxes and per-sample query counts.
        - ``supacc`` is the fraction of rows where ``argmax(probs)`` equals
          ``argmax(target)``, or 0.0 if the loader yielded no rows.
    """
    use_gpu = torch.cuda.is_available()
    model = model.float()
    was_training = model.training
    model.eval()

    correct_preds = 0
    all_preds = 0
    pred_bboxes_list = []
    target_bboxes_list = []
    num_query_list = []

    try:
        # Inference only: do not build autograd graphs for every batch.
        with torch.no_grad():
            for batch_tensors in tqdm(train_loader):
                if use_gpu:
                    batch_tensors = [t.cuda() for t in batch_tensors]
                (idx, labels, attrs, feature, query, label_feats, sent_toks, phrase_toks,
                 entity_indices, entity_feats, bboxes, target_bboxes, num_obj, num_query) = batch_tensors

                # NOTE: predict takes (num_obj, num_query, bboxes) in a different
                # order from the training forward pass.
                pred_bboxes, _, probs, target = model.predict(
                    idx, query, labels, feature, attrs, num_obj, num_query, bboxes,
                    label_feats, sent_toks, phrase_toks, entity_indices, entity_feats,
                    use_bert, lite_bert)

                # Supervised accuracy: compare argmax of scores with argmax of target.
                target_pred = torch.argmax(target, dim=1)
                prediction = torch.argmax(probs, dim=1)
                correct_preds += int(prediction.eq(target_pred).sum())
                all_preds += len(prediction)

                pred_bboxes_list += pred_bboxes.cpu().tolist()
                target_bboxes_list += target_bboxes.cpu().tolist()
                num_query_list += num_query.cpu().tolist()
    finally:
        # Put the model back into whatever mode the caller had it in.
        model.train(was_training)

    acc = evaluator_acc(pred_bboxes_list, target_bboxes_list, num_query_list)
    supacc = correct_preds / all_preds if all_preds else 0.0
    return acc, supacc


def evaluator_acc(pred_bboxes, target_bboxes, num_query):
    """Compute box-level grounding accuracy over a whole evaluation set.

    For each sample only the first ``nq`` query rows are used, and samples with
    ``nq <= 0`` are skipped. Each sample's target rows go through
    ``union_target``. That presumably yields one merged ground-truth box per
    query — TODO confirm in ``utils``. All samples are pooled before a single
    ``Evaluator().evaluate`` call.

    Args:
        pred_bboxes: per-sample predicted boxes (nested lists).
        target_bboxes: per-sample target boxes (nested lists).
        num_query: per-sample count of valid queries.

    Returns:
        The first value returned by ``Evaluator.evaluate`` (the accuracy).
    """
    scorer = Evaluator()
    pooled_gt = []
    pooled_pred = []
    for sample_pred, sample_target, n_valid in zip(pred_bboxes, target_bboxes, num_query):
        if not (n_valid > 0):
            continue
        pooled_gt += union_target(sample_target[:n_valid])
        pooled_pred += sample_pred[:n_valid]
    accuracy, _unused = scorer.evaluate(pooled_pred, pooled_gt)
    return accuracy


def calculate_score(pred_bboxes, target_bboxes, num_query, *, verbose=False):
    """Count queries whose predicted box matches the union of their target boxes.

    Expected shapes (unverified — confirm against callers):
        pred: [B, query, 4]
        target: [B, query, 12, 5]
        num_query: [B]

    Only the first ``nq`` queries of each sample are scored. Each query's
    target boxes are merged with ``union`` and compared with ``bbox_is_match``.

    Args:
        pred_bboxes: per-sample predicted boxes.
        target_bboxes: per-sample target boxes.
        num_query: per-sample count of valid queries.
        verbose: if true, print each merged target box (debug output).

    Returns:
        ``(match_sum, all_query)``: number of matching queries and the total
        ``sum(num_query)``.
    """
    match_sum = 0
    for ipred, itarget, nq in zip(pred_bboxes, target_bboxes, num_query):
        for j in range(nq):
            # Named to avoid shadowing the module-level utils.union_target import.
            merged_box = union(itarget[j])
            if verbose:
                print(merged_box)
            if bbox_is_match(ipred[j], [merged_box]):
                match_sum += 1
    all_query = sum(num_query)
    return match_sum, all_query


def w2v_att(query, det, glove):
    """Score each detected class label against a phrase with attention-pooled embeddings.

    Steps:
    1. Lower-case the phrase and split it on whitespace.
    2. Each token embedding attends to the label embeddings (scaled
       dot-product, softmax over labels).
    3. Each token's peak attention weight, renormalised to sum to 1, weights
       the token embeddings into one phrase vector.
    4. Return the cosine similarity of that vector with every label embedding.

    Args:
        query: phrase string.
        det: list of detected class-label strings.
        glove: lookup exposing ``get_embeddings(words)`` that returns a numpy
            array (presumably one row per word). When falsy, the global
            ``model_fast.wv[...]`` is used instead.
            NOTE(review): ``model_fast`` is not defined in the visible code —
            confirm it exists at module scope before relying on that branch.

    Returns:
        List of floats, one cosine similarity per label-embedding row.
    """
    tokens = query.lower().split()

    if not glove:
        phrase_vecs = torch.from_numpy(model_fast.wv[tokens])
        label_vecs = torch.from_numpy(model_fast.wv[det])
    else:
        phrase_vecs = torch.from_numpy(glove.get_embeddings(tokens))
        label_vecs = torch.from_numpy(glove.get_embeddings(det))

    # Scaled dot-product attention of every token over every label.
    scale = 1.0 / np.sqrt(phrase_vecs.size(-1))
    raw_scores = phrase_vecs.matmul(label_vecs.transpose(0, 1)).mul_(scale)
    weights = nn.functional.softmax(raw_scores, dim=1)

    # Weight tokens by how sharply they attend to some label.
    peak = weights.max(dim=1).values
    token_weights = (peak / peak.sum()).unsqueeze(0)

    phrase_vec = token_weights.matmul(phrase_vecs).repeat(len(det), 1)
    similarity = nn.functional.cosine_similarity(phrase_vec, label_vecs, dim=1, eps=1e-8)
    return similarity.tolist()


def load_entries(name='train', dataroot='data/flickr30k/'):
    """Load Flickr30k entries for one split.

    Files read:
    - ``<dataroot>/<name>_imgid2idx.pkl``: image-id -> index mapping.
    - ``<dataroot>/<name>.hdf5``: the ``image_bb`` and ``pos_boxes`` datasets.

    The loaded data is passed to ``my_load_flickr30k``.

    Args:
        name: split name used to build the file names (e.g. ``'train'``).
        dataroot: directory containing the split files.

    Returns:
        Whatever ``my_load_flickr30k`` returns for this split.
    """
    # NOTE(review): pickle can execute arbitrary code; only load trusted files.
    with open(os.path.join(dataroot, '%s_imgid2idx.pkl' % name), 'rb') as f:
        img_id2idx = cPickle.load(f)
    h5_path = os.path.join(dataroot, '%s.hdf5' % name)

    print('loading features from h5 file...')
    with h5py.File(h5_path, 'r') as hf:
        # Only the box arrays are consumed below. The image/spatial feature
        # datasets (presumably large) are deliberately not read into memory.
        bbox = np.array(hf.get('image_bb'))
        pos_boxes = np.array(hf.get('pos_boxes'))

    print("load flickr30k data successfully.")
    entries = my_load_flickr30k(dataroot, img_id2idx, bbox, pos_boxes)
    return entries


def entries_id2img(train_entries):
    """Group entries by their ``'image'`` value.

    Returns a plain dict mapping each image id to the list of entries carrying
    it. Both keys and per-key lists follow the order in which entries appear.
    """
    grouped = {}
    for record in train_entries:
        grouped.setdefault(record['image'], []).append(record)
    return grouped

# unsupervised evaluation
def evaluate(object_detect, eval_entries, glove, strategy = 'largest', *, verbose=False):
    """Unsupervised grounding evaluation based on detector class labels.

    Procedure:
    - Process each detection record whose image id has ground truth in
      ``eval_entries``.
    - Match every entity phrase to the detected class label with the highest
      ``w2v_att`` similarity.
    - Choose one box of that class by ``strategy`` and compare it with the
      entity's target boxes via ``get_match_index``.

    Results are printed, not returned. Entity-image pairs are counted even when
    the image has no detections, so they lower the accuracy.

    Args:
        object_detect: list of dicts with ``"image"`` (file name whose stem is
            an integer id) and ``"objects"`` (see ``_parse_detections``).
        eval_entries: mapping image id -> list of entries with
            ``"entity_ids"``, ``"entity_names"`` and ``"target_bboxes"``
            (the latter indexed by entity id).
        glove: embedding lookup forwarded to ``w2v_att``.
        strategy: ``"largest"``, ``"confidence"`` or ``"union"``; any other
            value picks a random box of the predicted class.
        verbose: if true, print each ``entity predicted_class`` pairing.
    """
    print("start calculating....")
    total_image = 0
    total_entity2img = 0
    correct = 0        # chosen box matches a target box
    correct_all = 0    # some box of the predicted class matches (upper bound)

    object_detect = sorted(object_detect, key=lambda x: x["image"].split('.')[0])

    for img in object_detect:
        img_id = int(img["image"].split('.')[0])
        if img_id not in eval_entries:  # no ground truth for this image
            continue
        total_image += 1
        target_entries = eval_entries[img_id]

        class2bbox, class2score = _parse_detections(img["objects"])
        labels = list(class2bbox)

        for target_entry in target_entries:
            for i, entity in zip(target_entry["entity_ids"], target_entry["entity_names"]):
                total_entity2img += 1
                if not labels:
                    continue
                w2v = w2v_att(entity, labels, glove)
                if len(w2v) == 0:
                    continue

                # Pick the label most similar to the entity phrase.
                pred_class = labels[int(np.argmax(w2v))]
                if verbose:
                    print(entity, pred_class)
                pred_bboxes = class2bbox[pred_class]
                pred_bbox = _select_bbox(strategy, pred_bboxes, class2score[pred_class])

                target = target_entry['target_bboxes'][i]
                if len(get_match_index([pred_bbox], target)) > 0:
                    correct += 1
                if len(get_match_index(pred_bboxes, target)) > 0:
                    correct_all += 1

    # Guard the ratios so an empty evaluation prints 0.0 instead of crashing.
    print('# of images', total_image)
    print('# of entity-image pairs', total_entity2img)
    print("# of correct", correct)
    print('# of correct predict in a image:', correct / total_image if total_image else 0.0)
    print("acc:", 1.0 * correct / total_entity2img if total_entity2img else 0.0)
    print("acc upper bound:", 1.0 * correct_all / total_entity2img if total_entity2img else 0.0)


def _parse_detections(det_objects):
    """Group detector outputs by class label.

    Each detection is a dict with:
    - ``"class"``: the class label.
    - ``"score"``: parsed with ``float``.
    - ``"bbox"``: a parenthesised, comma-separated string of integers such as
      ``"(1,2,3,4)"``.

    Returns:
        ``(class2bbox, class2score)``: dicts mapping each class to its boxes
        (lists of ints) and float scores, in detection order.
    """
    class2bbox = {}
    class2score = {}
    for det in det_objects:
        bbox = [int(c) for c in det["bbox"].strip('(').strip(')').split(',')]
        class2bbox.setdefault(det["class"], []).append(bbox)
        class2score.setdefault(det["class"], []).append(float(det["score"]))
    return class2bbox, class2score


def _select_bbox(strategy, pred_bboxes, scores):
    """Pick one box from ``pred_bboxes`` according to ``strategy``.

    Strategies:
    - ``"largest"``: delegates to ``largest``.
    - ``"confidence"``: delegates to ``confidence`` with the per-box ``scores``.
    - ``"union"``: delegates to ``union``.
    - anything else: picks a random box.
    """
    if strategy == "largest":
        return largest(pred_bboxes)
    if strategy == "confidence":
        return confidence(scores, pred_bboxes)
    if strategy == "union":
        return union(pred_bboxes)
    # random.sample (not random.choice) keeps seeded runs reproducible with earlier results.
    return random.sample(pred_bboxes, 1)[0]

