import os
import sys

import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import ndcg_score

sys.path.insert(1, os.path.join(sys.path[0], '..'))
from utils import *
from baseline.generation_utils import *


def cosine_similarity(vec1, vec2):
    '''
    Compute cosine similarity between two embeddings.
    Params:
        vec1 (array-like): the first embedding vector.
        vec2 (array-like): the second embedding vector.
    Returns:
        The dot product of the two vectors divided by the product of their
        Euclidean norms, or 0.0 when at least one of the vectors has a zero norm.
    '''
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denominator == 0:
        # The cosine is undefined for a zero vector: dividing would yield NaN
        # (with a RuntimeWarning) and poison every score computed downstream.
        return 0.0
    return np.dot(vec1, vec2) / denominator

def compute_clip_score(image_index, 
                       evidence_index_list,
                       image_embeddings,
                       clip_evidence_embeddings):
    '''
    Score every requested piece of textual evidence against one image.
    Params:
        image_index (int): the index of the image in the image embedding matrix.
        evidence_index_list (list): a list containing the indexes of the evidence in the evidence embedding matrix.
        image_embeddings (numpy.array): a matrix containing image embeddings (computed with CLIP)
        clip_evidence_embeddings (numpy.array): a matrix containing evidence embeddings (computed with CLIP)
    Returns:
        A list with one cosine similarity per entry of evidence_index_list, in the same order.
    '''
    query_embedding = image_embeddings[image_index]
    return [
        cosine_similarity(query_embedding, clip_evidence_embeddings[evidence_idx])
        for evidence_idx in evidence_index_list
    ]


def generate_ngrams(text, n=3):
    '''
    Generate the set of word n-grams of a given text.
    Params:
        text (str): the text to split into n-grams.
        n (int): the largest n-gram size; every size from 1 to n is generated.
    Returns:
        A set with the n-grams found in the text once English stop words are removed,
        or None when the text cannot be vectorized.
    '''
    vectorizer = CountVectorizer(ngram_range=(1, n), stop_words='english')
    try:
        vectorizer.fit([text])
    except (ValueError, AttributeError, TypeError):
        # Best effort: the vectorizer rejects documents that leave it with an empty
        # vocabulary (e.g. empty or stop-word-only text) as well as non-string input.
        # Callers treat None as "no n-grams". Anything else is a genuine error and
        # is deliberately left to propagate instead of being silently swallowed.
        return None
    return set(vectorizer.get_feature_names_out())

def ngram_overlap_score(passage, answer, n=2):
    '''
    Measure how many n-grams a passage shares with an answer.
    Params:
        passage (str): the text being scored.
        answer (str): the reference text.
        n (int): the largest n-gram size, forwarded to generate_ngrams.
    Returns:
        The number of shared n-grams divided by the size of the larger of the two
        n-gram sets, or 0 when either text yields no n-grams.
    '''
    passage_grams = generate_ngrams(passage, n)
    answer_grams = generate_ngrams(answer, n)
    # A missing (None) or empty n-gram set on either side means there is nothing to compare.
    if not (answer_grams and passage_grams):
        return 0
    shared = passage_grams & answer_grams
    largest = max(len(passage_grams), len(answer_grams), 1)
    return len(shared) / largest

def eval_ranking(evidence,answer,predicted_ranking):
    '''
    Compute the NDCG of a predicted ranking against an n-gram based target ranking.
    Params:
        evidence (list): the evidence texts that were ranked.
        answer (str): the reference answer each evidence text is compared with.
        predicted_ranking (list): one predicted score per evidence text.
    Returns:
        The NDCG score, or None when there are fewer than two evidence texts or
        when the number of predicted scores differs from the number of evidence texts.
    '''
    # The target relevance of an evidence text is its n-gram overlap with the answer.
    relevance = np.array([ngram_overlap_score(passage, answer) for passage in evidence])
    comparable = len(relevance) > 1 and len(relevance) == len(predicted_ranking)
    if not comparable:
        return None
    return ndcg_score([relevance], [predicted_ranking])
    

def get_ndcg_score(dataset, task, evidence, image_embeddings_map,sort_with_date=False,
                   image_embeddings=None, clip_evidence_embeddings=None):
    '''
    Reports the NDCG for a specific ranking method on a specific task.
    Params:
        dataset (list): the records to evaluate; each one provides 'image path' and the task field.
        task (str): the key of the ground-truth field in each record.
        evidence (list): all the retrieved evidence; each item provides 'image path'
            (and 'date' when sort_with_date is True).
        image_embeddings_map (dict): maps an image path to its index in the image embedding matrix.
        sort_with_date (bool): rank the evidence by date if True, by CLIP similarity otherwise.
        image_embeddings (numpy.array): a matrix containing image embeddings. If None, the
            module-level variable with the same name is used (CLIP ranking only).
        clip_evidence_embeddings (numpy.array): a matrix containing evidence embeddings. If None,
            the module-level variable with the same name is used (CLIP ranking only).
    Returns:
        The mean NDCG, as a percentage rounded to 2 decimals, over the images that have more
        than 3 pieces of evidence and a valid ranking; NaN when no image could be scored.
    '''
    if not sort_with_date:
        # Backward compatibility: the matrices used to be read from module-level globals.
        if image_embeddings is None:
            image_embeddings = globals()['image_embeddings']
        if clip_evidence_embeddings is None:
            clip_evidence_embeddings = globals()['clip_evidence_embeddings']

    # Group the evidence per image once, remembering each item's position in the full
    # list. The positions are used as row indexes of the evidence embedding matrix.
    # NOTE(review): assumes the matrix rows follow the order of `evidence` - confirm.
    evidence_by_image = {}
    for position, ev in enumerate(evidence):
        evidence_by_image.setdefault(ev['image path'], []).append((position, ev))

    total_ndcg = 0
    count = 0
    for image in dataset:
        image_path = image['image path']
        matches = evidence_by_image.get(image_path, [])
        # Only images with more than 3 pieces of evidence are evaluated.
        if len(matches) <= 3:
            continue
        evidence_index = [position for position, _ in matches]
        image_evidence = [ev for _, ev in matches]

        if sort_with_date:
            # Each evidence is scored with its position in the date-descending order
            # (0 for the most recent one).
            # NOTE(review): ndcg_score ranks higher scores first, so this favours the
            # oldest evidence - confirm that this is the intended direction.
            dates = pd.Series([ev['date'] for ev in image_evidence], name='date')
            date_order = dates.sort_values(ascending=False).index.to_list()
            rank_by_item = {item: rank for rank, item in enumerate(date_order)}
            predicted_ranking = [rank_by_item[item] for item in range(len(image_evidence))]
        else:
            #Retrieve the index of the image in the embedding matrix
            image_index = int(image_embeddings_map[image_path])
            predicted_ranking = compute_clip_score(image_index, evidence_index,
                                                   image_embeddings, clip_evidence_embeddings)

        # Split the prompt back into one text per evidence, dropping the first two
        # characters that follow each 'Evidence ' marker.
        # NOTE(review): presumably a numbering prefix - verify against get_evidence_prompt.
        evidence_text = [text[2:] for text in get_evidence_prompt(image_evidence).split('Evidence ')[1:]]
        ndcg = eval_ranking(evidence_text, image[task], predicted_ranking)
        if ndcg is not None:
            total_ndcg += ndcg
            count += 1

    if count == 0:
        # Nothing could be scored: report NaN rather than dividing by zero.
        return float('nan')
    return round(100*total_ndcg/count,2)
    

if __name__=='__main__':
    # The two embedding matrices stay at module level: get_ndcg_score looks them up
    # as module globals when it ranks the evidence with CLIP.
    clip_evidence_embeddings = np.load('dataset/embeddings/evidence_embeddings.npy')
    image_embeddings = np.load('dataset/embeddings/image_embeddings.npy')
    image_embeddings_map = load_json('dataset/embeddings/image_embeddings_map.json')
    evidence = load_json('dataset/retrieval_results/evidence.json')

    test = load_json('dataset/test.json')
    # Keep the full records (get_ndcg_score reads 'image path' and the task field from
    # each of them) and drop, per task, the ones without a usable ground truth.
    source = [t for t in test if t['source']!='not enough information']
    date = [t for t in test if t['date numeric label']!='not enough information']
    location = [t for t in test if t['location']!='not enough information']
    motivation = [t for t in test if t['motivation']!='not enough information']

    # (output format, task key, records to evaluate) - each task is scored on its own subset.
    reports = [
        ('Source %s', 'source', source),
        ('Date %s', 'date numeric label', date),
        ('Location %s', 'location', location),
        ('Motivation %s', 'motivation', motivation),
    ]
    for title, by_date in (('Time ranking', True), ('CLIP ranking', False)):
        print('-----------')
        print(title)
        print('-----------')
        for line_format, task, subset in reports:
            print(line_format%get_ndcg_score(subset,task,evidence,image_embeddings_map,by_date))
