from evaluate import load
from dateutil import parser
import numpy as np
from scipy.optimize import linear_sum_assignment
from dateutil.tz import tzutc
from geonames_collection import *
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import ast
from dateutil.relativedelta import relativedelta
from haversine import haversine, Unit
from itertools import combinations
from utils import *
from dataset_collection.preprocessing_utils import *



# Load the METEOR and ROUGE metric objects once, at import time, through
# `load` (imported from the `evaluate` package above). They are kept as
# module-level handles and used by the text-overlap tasks of evaluate().
meteor = load('meteor')
rouge = load('rouge')


def convert_to_date(date_str):
    """Parse a free-form date string into a datetime, or return None.

    Args:
        date_str: Text to hand to ``dateutil.parser.parse``.

    Returns:
        The parsed ``datetime`` on success, ``None`` when the input cannot be
        interpreted as a date.
    """
    try:
        return parser.parse(date_str)
    except (ValueError, TypeError, OverflowError):
        # ValueError: unknown/invalid format; TypeError: non-string input;
        # OverflowError: numeric fields too large for the platform. All of
        # them mean "not a usable date" here, so they map to None.
        return None


def date_distance(dt1, dt2):
    """Return the absolute distance between two datetimes, in (approximate) years.

    Args:
        dt1: First datetime, or None.
        dt2: Second datetime, or None.

    Returns:
        ``float('inf')`` when either argument is None. Otherwise the absolute
        value of ``years + months / 12 + days / 365.25`` taken from the
        ``relativedelta`` between the two datetimes; smaller components
        (hours, minutes, ...) are ignored.
    """
    # The None check has to come first: reading `.tzinfo` on None would raise
    # AttributeError before the sentinel distance could be returned.
    if dt1 is None or dt2 is None:
        return float('inf')

    # Naive datetimes are treated as UTC so that naive and aware values can
    # be compared with each other.
    if dt1.tzinfo is None:
        dt1 = dt1.replace(tzinfo=tzutc())
    if dt2.tzinfo is None:
        dt2 = dt2.replace(tzinfo=tzutc())

    delta = relativedelta(dt1, dt2)
    return abs(delta.years + delta.months / 12 + delta.days / 365.25)


def location_coordinate_distance(coordinates1,coordinates2,unit=1000):
    '''
    Distance between two locations, each given as a collection of candidate coordinates.

    Every coordinate of the first location is paired with every coordinate of the
    second one; the haversine distance (in kilometers) of the closest pair is kept,
    as an optimistic heuristic, and divided by `unit` (so with the default of 1000
    the result is expressed in thousands of kilometers).
    '''
    pairwise_km = (
        haversine(first, second, unit=Unit.KILOMETERS)
        for first in coordinates1
        for second in coordinates2
    )
    closest = min(pairwise_km)
    return closest / unit


def hierarchical_distance_metric(pred_hierarchy, gt_hierarchy):
    """Distance between a predicted and a ground-truth hierarchy (two sequences of nodes).

    The distance is 0 when every ground-truth node also appears somewhere in
    the prediction. Otherwise it is the number of nodes of both hierarchies
    that lie beyond their common leading prefix.
    """
    # Zero distance as soon as the prediction contains all ground-truth nodes.
    for node in gt_hierarchy:
        if node not in pred_hierarchy:
            break
    else:
        return 0

    # Count how many leading positions hold the same node in both hierarchies.
    shared_prefix = 0
    for pred_node, gt_node in zip(pred_hierarchy, gt_hierarchy):
        if pred_node == gt_node:
            shared_prefix += 1
            continue
        break

    return len(pred_hierarchy) + len(gt_hierarchy) - 2 * shared_prefix


def location_hierarchy_distance(hierarchy1,hierarchy2):
    """Smallest hierarchical distance over all pairs of candidate hierarchies.

    Each argument is a collection of hierarchies; every hierarchy of the first
    is compared with every hierarchy of the second and the minimum is returned.
    """
    pair_distances = (
        hierarchical_distance_metric(left, right)
        for left in hierarchy1
        for right in hierarchy2
    )
    return min(pair_distances)


def is_strict_subset(sublist, mainlist):
    """Tell whether `sublist` is contained in `mainlist` and is shorter than it.

    Containment is checked on the sets of elements, while "shorter" compares
    the lengths of the two sequences as given (duplicates included).
    """
    candidate_items = set(sublist)
    container_items = set(mainlist)
    if not candidate_items <= container_items:
        return False
    return len(sublist) < len(mainlist)


def find_locations_to_remove(l):
    """Return the indices of entries of `l` that are subsumed by another entry.

    `l` is a list whose entries are themselves lists of sequences. An entry is
    reported when at least one of its sequences is a strict subset (see
    is_strict_subset) of a sequence belonging to a different entry.
    Indices are returned in increasing order.
    """
    def is_subsumed(position):
        own_sequences = l[position]
        rival_entries = [entry for k, entry in enumerate(l) if k != position]
        return any(
            is_strict_subset(own, rival)
            for own in own_sequences
            for rival_entry in rival_entries
            for rival in rival_entry
        )

    return [position for position in range(len(l)) if is_subsumed(position)]


def _best_assignment_score(predicted, reference, distance_fn):
    """Score the best one-to-one matching between predicted and reference items.

    When there are more predictions than references, every subset of the
    predictions with as many elements as the references is tried. For each
    candidate a minimum-cost assignment is computed with
    ``linear_sum_assignment`` and scored as
    ``sum(1 / (1 + distance)) / len(reference)`` over the matched pairs.

    Returns:
        A ``(score, exact)`` tuple: the highest score over all candidates and
        whether all matched pairs of that best candidate had distance 0.
        ``(0, False)`` when either input is empty.
    """
    if not predicted or not reference:
        # Nothing to match; also avoids an empty cost matrix and a 1/0 below.
        return 0, False

    if len(predicted) > len(reference):
        candidates = combinations(predicted, len(reference))
    else:
        candidates = [predicted]

    best_score = 0
    best_exact = False
    for candidate in candidates:
        distances = np.array(
            [[distance_fn(p, g) for g in reference] for p in candidate]
        )
        row_ind, col_ind = linear_sum_assignment(distances)
        matched = distances[row_ind, col_ind]
        total = 0
        for d in sorted(matched):
            total += 1 / (1 + d)
        score = (1 / len(reference)) * total
        if score > best_score:
            best_score = score
            best_exact = bool(np.all(matched == 0))
    return best_score, best_exact


def evaluate(prediction,
             ground_truth,
             task,
             NER_model = None,
             geonames_data=None,
             geonames_username=None,
             sleep_geonames=2):
    """Score a prediction against its ground truth for one task.

    Args:
        prediction: Predicted text (a list is also accepted for the
            text-overlap tasks).
        ground_truth: Reference answer. For "location NER" and "date" it is a
            string holding a Python literal list, parsed with
            ``ast.literal_eval``.
        task: One of "source", "motivation", "location", "location NER",
            "date".
        NER_model: Model handed to ``extract_named_entities``.
        geonames_data: Handle given to ``load_json`` / ``save_result`` for the
            collected GeoNames records (presumably a file path -- confirm
            against those helpers).
        geonames_username: Username passed to ``search_location``.
        sleep_geonames: Delay in seconds passed to ``search_location`` and
            slept after each look-up of an unknown location.

    Returns:
        * "source" / "motivation" / "location": ``{"rougeL", "meteor"}``.
        * "location NER": ``{"codelta", "hldelta"}`` (coordinate-based and
          hierarchy-based scores; 0 when nothing can be matched).
        * "date": ``{"exact_match", "delta"}``; ``exact_match`` is 0 or 1 and
          is forced to 0 when there are more predicted than reference dates.

    Raises:
        ValueError: If ``task`` is not one of the supported names.
    """
    # Source, Motivation, Location: text-overlap metrics.
    # NOTE: the names must be separate items; adjacent string literals
    # without a comma are concatenated into a single string.
    if task in ("source", "motivation", "location"):
        if not isinstance(prediction, list):
            prediction = [prediction]
        if not isinstance(ground_truth, list):
            ground_truth = [ground_truth]
        rouge_result = rouge.compute(predictions=prediction, references=[ground_truth])['rougeL']
        meteor_result = meteor.compute(predictions=prediction, references=[ground_truth])['meteor']
        return {'rougeL': rouge_result, "meteor": meteor_result}

    # Location NER: separate task for the GeoNames-based metrics.
    if task == "location NER":
        # Load the collected records once and use them for every look-up
        # (`geonames_data` itself is only the handle used to load/save them).
        geonames_records = load_json(geonames_data)
        geonames_entries = {d['query'].lower() for d in geonames_records}

        def cached_records(name):
            wanted = name.lower()
            return [d for d in geonames_records
                    if 'coordinates' in d and d['query'].lower() == wanted]

        # Prepare the predictions.
        # NER_model is passed like in the "date" branch -- assumes
        # extract_named_entities(text, model, entity_type); TODO confirm.
        prediction_coordinates = []
        prediction_hierarchies = []
        for p in extract_named_entities(prediction, NER_model, 'locations'):
            if p.lower() not in geonames_entries:
                # Add a new entry to the collected GeoNames database if the
                # prediction is not there yet.
                matching_records = search_location(p, geonames_username, sleep_geonames)
                time.sleep(sleep_geonames)
                save_result(matching_records, geonames_data)
            else:
                matching_records = cached_records(p)
            if matching_records:
                prediction_coordinates.append([r['coordinates'] for r in matching_records])
                prediction_hierarchies.append([r['hierarchy'] for r in matching_records])

        # Prepare the ground truth (only already-collected records are used).
        ground_truth_coordinates = []
        ground_truth_hierarchies = []
        for g in ast.literal_eval(ground_truth):
            matching_records = cached_records(g)
            if matching_records:
                ground_truth_hierarchies.append([r['hierarchy'] for r in matching_records])
                ground_truth_coordinates.append([r['coordinates'] for r in matching_records])

        # Drop ground-truth locations that are subsumed by another one.
        idx_to_remove = set(find_locations_to_remove(ground_truth_hierarchies))
        ground_truth_coordinates = [c for i, c in enumerate(ground_truth_coordinates)
                                    if i not in idx_to_remove]
        ground_truth_hierarchies = [h for i, h in enumerate(ground_truth_hierarchies)
                                    if i not in idx_to_remove]

        best_codelta, _ = _best_assignment_score(
            prediction_coordinates, ground_truth_coordinates, location_coordinate_distance)
        best_hierarchy_delta, _ = _best_assignment_score(
            prediction_hierarchies, ground_truth_hierarchies, location_hierarchy_distance)
        return {"codelta": best_codelta, "hldelta": best_hierarchy_delta}

    # Date
    if task == "date":
        # Strip the surrounding brackets of a list-like prediction string.
        if prediction != '' and prediction[0] == '[':
            prediction = prediction[1:-1]
        prediction_dates = extract_named_entities(prediction, NER_model, 'dates_and_times')
        prediction_dates = [convert_to_date(date_str) for date_str in prediction_dates]
        prediction_dates = [d for d in prediction_dates if d is not None]
        # Unparseable reference dates are dropped as well; a None here would
        # otherwise break the distance computation.
        ground_truth_dates = [convert_to_date(date_str) for date_str in ast.literal_eval(ground_truth)]
        ground_truth_dates = [d for d in ground_truth_dates if d is not None]

        if not ground_truth_dates or not prediction_dates:
            return {"exact_match": 0, "delta": 0}

        best_delta, best_exact = _best_assignment_score(
            prediction_dates, ground_truth_dates, date_distance)
        # Predicting more dates than the reference can never be an exact match.
        if len(prediction_dates) > len(ground_truth_dates):
            best_exact = False
        return {"exact_match": int(best_exact), "delta": best_delta}

    raise ValueError("Invalid task name")