import numpy as np

import copy

def action_em(pred_actions, target_actions):
    """
    Return the fraction of gold actions that were predicted exactly.

    Actions are compared as multisets: each gold action is matched to at most
    one equal, not-yet-matched predicted action, so duplicates count one-for-one.
    Comparison is exact (case-sensitive).

    Args:
        pred_actions: predicted actions. Not modified.
        target_actions: gold actions. Not modified.

    Returns:
        matched / len(target_actions) as a float, or 0 when target_actions is empty.
    """
    if len(target_actions) == 0:
        return 0

    # Work on a private copy so the caller's list is untouched. Iterate over
    # target_actions without mutating it; the previous version popped from the
    # list it was iterating over and silently skipped elements.
    remaining_preds = list(pred_actions)
    positive = 0
    for action in target_actions:
        if action in remaining_preds:
            # Consume one matching prediction so it cannot be matched twice.
            remaining_preds.remove(action)
            positive += 1

    return positive / float(len(target_actions))

def action_metrics(pred_actions, target_actions, return_counts=False, lower=True):
    """
    Compute multiset-matching EM and F1 between predicted and gold actions.

    Each gold action is matched to at most one equal, not-yet-matched predicted
    action (duplicates are matched one-for-one). From that matching:
        tp = matched pairs, fp = unmatched predictions, fn = unmatched gold actions.

    em: tp / (tp + fp + fn). Unlike the original paper's definition ("EM of 20
        means that if there were 100 gold standard valid actions in an instance,
        the model predicted 20 of them exactly", i.e. tp / len(target_actions)),
        this variant also penalizes extra predictions. See action_metrics_backup
        for the paper-style definition.
    f1: tp / (tp + (fp + fn) / 2), the standard F1 written in count form.
    When both lists are empty, em is 1 and f1 is 1.0.

    Args:
        pred_actions: predicted actions. Not modified.
        target_actions: gold actions. Not modified.
        return_counts: if True, also return the (tp, fp, fn) counts.
        lower: if True, compare case-insensitively (items must then be str).

    Returns:
        (em, f1), or (em, f1, (tp, fp, fn)) when return_counts is True.
    """
    # Fresh lists, so the caller's inputs are never mutated.
    if lower:
        remaining_preds = [item.lower() for item in pred_actions]
        targets = [item.lower() for item in target_actions]
    else:
        remaining_preds = list(pred_actions)
        targets = list(target_actions)

    tp = 0
    for action in targets:
        if action in remaining_preds:
            # Consume one matching prediction so it cannot be matched twice.
            remaining_preds.remove(action)
            tp += 1

    fp = len(remaining_preds)  # predictions left unmatched
    fn = len(targets) - tp     # gold actions left unmatched

    # denom is 0 only when both lists are empty; that counts as a perfect score.
    denom = tp + 0.5 * (fp + fn)
    f1 = 1.0 if denom == 0 else tp / denom

    total = tp + fp + fn
    em = 1 if total == 0 else tp / total

    if return_counts:
        return em, f1, (tp, fp, fn)
    return em, f1


def action_metrics_backup(pred_actions, target_actions, return_counts=False, lower=True):
    """
    Compute paper-style EM and F1 between predicted and gold actions.

    Each gold action is matched to at most one equal, not-yet-matched predicted
    action (duplicates are matched one-for-one). From that matching:
        tp = matched pairs, fp = unmatched predictions, fn = unmatched gold actions.

    em: follows the original paper ("EM of 20 means that if there were 100 gold
        standard valid actions in an instance, the model predicted 20 of them
        exactly"), i.e. tp / len(target_actions). With no gold actions, em is 1
        if there are also no predictions, otherwise 0.
    f1: tp / (tp + (fp + fn) / 2). When both lists are empty, f1 is 1.0.

    Args:
        pred_actions: predicted actions. Not modified.
        target_actions: gold actions. Not modified.
        return_counts: if True, also return the (tp, fp, fn) counts.
        lower: if True, compare case-insensitively (items must then be str).

    Returns:
        (em, f1), or (em, f1, (tp, fp, fn)) when return_counts is True.
    """
    # Fresh lists, so the caller's inputs are never mutated.
    if lower:
        remaining_preds = [item.lower() for item in pred_actions]
        targets = [item.lower() for item in target_actions]
    else:
        remaining_preds = list(pred_actions)
        targets = list(target_actions)

    tp = 0
    for action in targets:
        if action in remaining_preds:
            # Consume one matching prediction so it cannot be matched twice.
            remaining_preds.remove(action)
            tp += 1

    fp = len(remaining_preds)  # predictions left unmatched
    fn = len(targets) - tp     # gold actions left unmatched

    # denom is 0 only when both lists are empty; that counts as a perfect score.
    denom = tp + 0.5 * (fp + fn)
    f1 = 1.0 if denom == 0 else tp / denom

    if len(targets) == 0:
        # No gold actions: perfect only if nothing was predicted either.
        em = 1 if len(pred_actions) == 0 else 0
    else:
        em = tp / float(len(targets))

    if return_counts:
        return em, f1, (tp, fp, fn)
    return em, f1


def print_test_metrics(metrics, save_name,
                       out_dir="/home/mnskim/workspace/tbg/tbg1/results/metrics",
                       test_games=None):
    """
    Write per-game mean metrics to ``<out_dir>/<save_name>.csv``.

    Despite the name, nothing is printed; the summary goes to a file. For each
    task in ('graph', 'action') the file contains a line with the task name.
    Then, for each metric in ('em', 'f1'), it contains a line with the metric
    name followed by one comma-separated row of
    ``np.mean(metrics[game][task][metric])`` values, one per game in
    ``test_games`` order.

    Args:
        metrics: nested mapping indexed as metrics[game][task][metric]; each leaf
            is presumably a sequence of per-instance scores (anything np.mean accepts).
        save_name: file name stem for the output CSV.
        out_dir: output directory; defaults to the original hard-coded location.
        test_games: games to report, in column order; defaults to the nine test games.
    """
    import os

    if test_games is None:
        test_games = ['zork1', 'library', 'detective', 'balances', 'pentari',
                      'ztuu', 'ludicorp', 'deephome', 'temple']

    path = os.path.join(out_dir, f"{save_name}.csv")
    # The context manager guarantees the file is flushed and closed, even on error.
    with open(path, 'w', encoding='utf-8') as outfile:
        for task in ['graph', 'action']:
            outfile.write(f"{task}\n")
            for met in ['em', 'f1']:
                outfile.write(f"{met}\n")
                out_line = ','.join(str(np.mean(metrics[game][task][met]))
                                    for game in test_games)
                outfile.write(out_line + '\n')

