import sys
import json
from tqdm import tqdm
from compute_em import get_tokens
import collections

# how abstractive are abstractive answers?

def compute_f1(gold_toks, pred_toks):
    """Return the token-overlap F1 between two token sequences.

    Overlap is counted as a multiset intersection, so a repeated token only
    matches as many times as it occurs in both sequences.

    Args:
        gold_toks: Reference tokens.
        pred_toks: Predicted tokens.

    Returns:
        ``1`` or ``0`` (int) when either sequence is empty, depending on
        whether the two sequences are equal; ``0`` (int) when there is no
        shared token; otherwise the harmonic mean of precision and recall
        as a float.
    """
    gold_counts = collections.Counter(gold_toks)
    pred_counts = collections.Counter(pred_toks)
    overlap = sum((gold_counts & pred_counts).values())

    # An empty side means "no answer": score 1 only if both sides agree.
    if not len(gold_toks) or not len(pred_toks):
        return int(gold_toks == pred_toks)
    if not overlap:
        return 0

    precision = 1.0 * overlap / len(pred_toks)
    recall = 1.0 * overlap / len(gold_toks)
    return (2 * precision * recall) / (precision + recall)


def _best_span_in_context(context_toks, pred_toks, max_span_len):
    """Return ``(f1, span)`` for the contiguous span that best matches ``pred_toks``.

    Spans are enumerated by increasing length (1..``max_span_len``) and, for
    each length, by increasing start offset; ties keep the first span seen.
    When there is no candidate span at all (empty context or
    ``max_span_len`` < 1) the result is ``(0, [])``.
    """
    best_f1, best_span = None, []
    # Spans longer than the context do not exist, so cap the length there.
    longest = min(max_span_len, len(context_toks))
    for span_len in range(1, longest + 1):
        for start in range(len(context_toks) - span_len + 1):
            span = context_toks[start:start + span_len]
            f1 = compute_f1(span, pred_toks)
            # Strict '>' keeps the earliest best span on ties.
            if best_f1 is None or f1 > best_f1:
                best_f1, best_span = f1, span
    if best_f1 is None:
        return 0, []
    return best_f1, best_span


def compute_extractive_upper(preds, data, ex_ids, window_size, num_ctxs=100):
    """Print, per example, the best F1 any context span reaches against the prediction.

    For every example whose index is not in ``ex_ids``, each context's text
    is tokenized with ``get_tokens`` and all contiguous spans of up to
    ``window_size * len(prediction tokens)`` tokens are scored against the
    prediction with ``compute_f1``. The best score over all contexts is
    printed together with the prediction, its tokens and the winning span,
    followed at the end by the list of per-example scores, their mean, and
    the number of examples whose best score is zero.

    Args:
        preds: Prediction strings, aligned index-by-index with ``data``.
        data: Sequence of examples; each is indexed with ``'ctxs'`` to get an
            iterable of contexts, each of which is indexed with ``'text'``.
        ex_ids: Indices of examples to skip.
        window_size: Multiplier on the prediction length that bounds the
            longest span considered.
        num_ctxs: Expected number of contexts per example (default 100, the
            value previously hard-coded). Pass ``None`` to disable the check.

    Returns:
        None. Results are written to stdout.

    Raises:
        AssertionError: If ``preds`` and ``data`` differ in length, or an
            example does not have ``num_ctxs`` contexts.
    """
    assert len(preds) == len(data)
    # Set membership is O(1); the caller typically passes a list.
    skip_ids = set(ex_ids)
    ex_uppers = []
    zero_cnt = 0
    for i, example in enumerate(data):
        if i in skip_ids:
            continue

        # The prediction does not change across contexts: tokenize it once.
        # NOTE(review): assumes get_tokens is a pure function of its input.
        pred_toks = get_tokens(preds[i])
        max_span_len = window_size * len(pred_toks)

        ctxs = example['ctxs']
        if num_ctxs is not None:
            assert len(ctxs) == num_ctxs, (
                'example %d has %d contexts, expected %d'
                % (i, len(ctxs), num_ctxs))

        max_f1, max_span_ex = None, []
        # loop through each context
        for c in ctxs:
            f1, span = _best_span_in_context(
                get_tokens(c['text']), pred_toks, max_span_len)
            # Strict '>' keeps the first context that reaches the best score.
            if max_f1 is None or f1 > max_f1:
                max_f1, max_span_ex = f1, span
        if max_f1 is None:
            # No contexts at all: nothing can be extracted.
            max_f1 = 0

        print('%f\t%s\t%s\t%s'%(max_f1, preds[i], '['+', '.join(pred_toks)+']', '['+', '.join(max_span_ex)+']'))
        ex_uppers.append(max_f1)
        if max_f1 == 0:
            zero_cnt += 1

    print(ex_uppers)
    # Guard the mean so a run where every example is skipped does not crash.
    mean_upper = sum(ex_uppers)/float(len(ex_uppers)) if ex_uppers else 0.0
    print('extractive upperbound:', mean_upper)
    print('zero cnt:', zero_cnt)
    
if __name__ == '__main__':
    # Indices of examples to exclude from the computation, one integer per line.
    with open('../FiD/extractive_indices.txt') as f:
        ex_ids = [int(line.strip('\n')) for line in f]
    with open('/data/timchen0618/open_domain_data/NQ/dev.json') as f:
        data = json.load(f)
    # Predictions are tab-separated; the prediction text is the second column.
    with open('../FiD/pred_dir/dev/final_output.txt') as f:
        preds = [line.strip('\n').split('\t')[1] for line in f]
    # Candidate spans may be up to window_size times the prediction length.
    window_size = 2
    compute_extractive_upper(preds, data, ex_ids, window_size)
