#!/usr/bin/env python

import argparse
import csv
import os
import re
import sys

from collections import defaultdict
import nltk

def make_dicts(n, datatype):
    """Return a list of ``n`` separate ``defaultdict(datatype)`` objects.

    ``datatype`` is used as the default factory of every dictionary, so a
    missing key yields ``datatype()`` (e.g. ``0`` for ``int``).
    """
    return [defaultdict(datatype) for _ in range(n)]

def precision_recall(data, syslist, use_crowd=False):
    """Count unweighted hits, misses and false positives per system.

    Every row is scored against the majority rating of either the expert
    (``internal``) or the crowd judgments.  Rating '1' is treated as the
    positive class and '3' as the negative class; '2' ('Ungram') ratings
    are discarded before the majority is taken.

    Args:
        data: iterable of 8-field rows laid out as ``(unitid, prep,
            sentence, preplocation, sys1pred, sys2pred, internal, crowd)``.
            ``internal`` and ``crowd`` are '|'-separated rating strings.
        syslist: names of the systems to score; each must be ``'sys1'``
            or ``'sys2'``.
        use_crowd: score against the crowd ratings instead of the expert
            ratings.

    Returns:
        A ``(hits, misses, false_positives)`` tuple of
        ``defaultdict(int)`` objects keyed by system name.

    Raises:
        ValueError: if a row does not have exactly eight fields.
        KeyError: if ``syslist`` names a system other than 'sys1'/'sys2'
            (only raised once a row with a '1' or '3' majority is seen).
    """
    hits, misses, false_positives = make_dicts(3, int)
    for row in data:

        # Get the various fields in the data row
        unitid, prep, sentence, preplocation, sys1pred, sys2pred, internal, crowd = row

        # Explicit name -> prediction table.  This replaces the former
        # eval(sys + 'pred') lookup, which executed a dynamically built
        # expression just to read a local variable.
        predictions = {'sys1': sys1pred, 'sys2': sys2pred}

        # use expert/crowd ratings depending on what's asked for, and make
        # sure to ignore any 'Ungram' (indicated by '2') ratings
        source = crowd if use_crowd else internal
        ratings = [r for r in source.split('|') if r != '2']

        # ignore instances where every rating was 'Ungram'
        if not ratings:
            continue

        # Use an NLTK FreqDist to count the ratings and get the majority
        # rating.  NOTE(review): how ties are broken is up to
        # FreqDist.max() -- confirm if tied rows matter.
        fd = nltk.FreqDist()
        fd.update(ratings)
        majority_rating = fd.max()

        # Rows whose majority is neither '1' nor '3' contribute nothing.
        if majority_rating not in ('1', '3'):
            continue

        # compute the hits/misses/false positives for the requested systems
        # (the loop variable is not called `sys` so that it cannot be
        # confused with the imported sys module)
        for system in syslist:
            pred = predictions[system]
            if majority_rating == '1':
                if pred == '1':
                    hits[system] += 1
                elif pred == '3':
                    misses[system] += 1
            elif pred == '1':
                false_positives[system] += 1

    # return the hits/misses/false positives
    return hits, misses, false_positives

def weighted_precision_recall(data, syslist, use_crowd=False):
    """Accumulate weighted hits, misses and false positives per system.

    Instead of reducing each row to a majority rating, every row
    contributes the proportion of '1' ratings and the proportion of '3'
    ratings (the two classes, Error and OK) among the expert
    (``internal``) or crowd judgments.  '2' ('Ungram') ratings are
    discarded first.

    Args:
        data: iterable of 8-field rows laid out as ``(unitid, prep,
            sentence, preplocation, sys1pred, sys2pred, internal, crowd)``.
            ``internal`` and ``crowd`` are '|'-separated rating strings.
        syslist: names of the systems to score; each must be ``'sys1'``
            or ``'sys2'``.
        use_crowd: score against the crowd ratings instead of the expert
            ratings.

    Returns:
        A ``(hits, misses, false_positives)`` tuple of
        ``defaultdict(float)`` objects keyed by system name.

    Raises:
        ValueError: if a row does not have exactly eight fields.
        KeyError: if ``syslist`` names a system other than 'sys1'/'sys2'.
    """
    hits, misses, false_positives = make_dicts(3, float)
    for row in data:

        # Get the various fields in the data row
        unitid, prep, sentence, preplocation, sys1pred, sys2pred, internal, crowd = row

        # Explicit name -> prediction table.  This replaces the former
        # eval(sys + 'pred') lookup, which executed a dynamically built
        # expression just to read a local variable.
        predictions = {'sys1': sys1pred, 'sys2': sys2pred}

        # use expert/crowd ratings depending on what's asked for, and make
        # sure to ignore any 'Ungram' (indicated by '2') ratings
        source = crowd if use_crowd else internal
        ratings = [r for r in source.split('|') if r != '2']

        # ignore instances where every rating was 'Ungram'
        if not ratings:
            continue

        # Use an NLTK FreqDist to count the ratings
        fd = nltk.FreqDist()
        fd.update(ratings)

        # A row can still lack any '1'/'3' rating (an empty ratings field
        # splits to ['']).  Skip it rather than divide by zero; the
        # unweighted scorer ignores such rows as well.
        total = fd['1'] + fd['3']
        if not total:
            continue

        # compute the proportions of the two classes (Error and OK)
        prop_ones = float(fd['1']) / total
        prop_threes = float(fd['3']) / total

        # Compute the weighted hits/misses/false positives for the
        # requested systems (the loop variable is not called `sys` so that
        # it cannot be confused with the imported sys module)
        for system in syslist:
            pred = predictions[system]
            if pred == '1':
                hits[system] += prop_ones
                false_positives[system] += prop_threes
            elif pred == '3':
                misses[system] += prop_ones

    # return the hits/misses/false positives
    return hits, misses, false_positives

# Script entry point: parse the command line, read the CSV records, score the
# requested systems and print one precision/recall line per system.
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="Compute precision/recall for two given systems.", prog='compute_precision_recall.py')
    parser.add_argument('inputfile', help='the input file containing the records')
    parser.add_argument('-s', '--system', dest='systems', default="both", choices=["1", "2", "both"], help='compute for system 1, 2 or both (default:BOTH)')
    parser.add_argument('-w', '--weighted', dest='weighted', default=False, action="store_true", help='compute weighted precision/recall (default:FALSE)')
    parser.add_argument('-c', '--crowd', dest='crowd', default=False, action="store_true", help='compute against crowd judgments (default: FALSE)')

    # Parse the command line arguments
    args = parser.parse_args()
    precision, recall = make_dicts(2, float)

    # what is the list of systems that we need to compute precision/recall
    # for; argparse's `choices` guarantees args.systems is one of these keys
    syslist = {
        "1": ['sys1'],
        "2": ['sys2'],
        "both": ['sys1', 'sys2'],
    }[args.systems]

    # Read every record, dropping the first (header) row.  The context
    # manager closes the file once the rows have been materialised.
    with open(args.inputfile) as inputfile:
        rows = list(csv.reader(inputfile))[1:]

    # compute unweighted/weighted precision & recall against expert/crowd
    # judgments depending on the command line switches
    scorer = weighted_precision_recall if args.weighted else precision_recall
    hits, misses, false_positives = scorer(rows, syslist, args.crowd)

    # print out the results
    # (the loop variable is not called `sys`, which would rebind the
    # module-level sys import)
    for system in syslist:
        predicted = hits[system] + false_positives[system]
        relevant = hits[system] + misses[system]
        # Report 0.0 rather than crash when a system made no positive
        # predictions or the data holds no positive instances.
        precision[system] = float(hits[system]) / predicted if predicted else 0.0
        recall[system] = float(hits[system]) / relevant if relevant else 0.0
        # Single-argument parenthesised form: prints the same text under the
        # Python 2 print statement and the Python 3 print function.
        print(' %s => Precision: %.4f, Recall: %.4f' % (system, precision[system], recall[system]))


