from common.utils import *
from orig_eval_squad import metric_max_over_ground_truths, exact_match_score, f1_score
import random

def extract_worst_attack(dataset, predictions):
    """Keep, for every attacked question, the adversarial variant with the lowest F1.

    A question counts as adversarial when its id contains a '-'; the part of
    the id before the first '-' is used as the id of the original question.

    Args:
        dataset: iterable of article dicts. Each article has a 'title' and a
            'paragraphs' list; each paragraph has a 'qas' list whose entries
            carry an 'id' and an 'answers' list of dicts with a 'text' field.
        predictions: mapping from question id to a sequence whose first item
            is a dict with a 'text' field (the predicted answer string).

    Returns:
        dict mapping original question id to ``(adv_article, prediction, f1)``,
        where ``adv_article`` is a new single-paragraph article wrapping the
        attacked paragraph, titled ``"<title>-adv<paragraph index>"``.

    Raises:
        AssertionError: if a paragraph containing an adversarial question does
            not hold exactly one question.
        KeyError: if an adversarial question id is missing from ``predictions``.
    """
    worst_by_orig_id = {}
    for article in dataset:
        for i_par, paragraph in enumerate(article['paragraphs']):
            for qa in paragraph['qas']:
                if '-' not in qa['id']:
                    # Original (non-adversarial) question: nothing to record.
                    continue
                orig_id = qa['id'].split('-')[0]
                # An adversarial paragraph is expected to carry a single
                # question, so the whole paragraph can be kept as-is below.
                assert len(paragraph['qas']) == 1

                ground_truths = [answer['text'] for answer in qa['answers']]
                prediction = predictions[qa['id']][0]['text']
                cur_f1 = metric_max_over_ground_truths(f1_score, prediction, ground_truths)

                current_worst = worst_by_orig_id.get(orig_id)
                # Strict '<' means the first variant seen wins on ties.
                if current_worst is None or cur_f1 < current_worst[2]:
                    adv_article = {
                        'title': f"{article['title']}-adv{i_par}",
                        'paragraphs': [paragraph],
                    }
                    worst_by_orig_id[orig_id] = (adv_article, prediction, cur_f1)
    print(len(worst_by_orig_id))
    return worst_by_orig_id

def sample_addone(dataset, predictions):
    """Write a small sample of worst-case adversarial articles to disk.

    Each attacked question's worst adversarial article is put in one of two
    groups: "flipped" when its F1 is lower than the F1 on the original
    question, "still" otherwise. Both groups are shuffled with a fixed seed
    and up to ten articles from each are written to
    'outputs/sample-addone-dev_squad.json'.

    Args:
        dataset: dict with a 'version' entry and a 'data' list of articles.
        predictions: mapping from question id to the model predictions, in
            the form expected by ``extract_worst_attack``.

    Raises:
        KeyError: if an attacked question has no original counterpart in
            ``dataset``.
    """
    print(dataset.keys())
    # `extract_orig_scores` is not defined in this module. The original F1 is
    # the third field of the tuples built by `extract_orig_articles`.
    # NOTE(review): confirm no `extract_orig_scores` was meant to come from
    # the star import with different semantics.
    orig_articles = extract_orig_articles(dataset['data'], predictions)
    adv_articles = extract_worst_attack(dataset['data'], predictions)

    flipped_articles = []
    still_articles = []
    for qid, (adv_article, _, adv_f1) in adv_articles.items():
        orig_f1 = orig_articles[qid][2]
        if adv_f1 < orig_f1:
            flipped_articles.append(adv_article)
        else:
            still_articles.append(adv_article)

    # Fixed seed keeps the sample reproducible across runs.
    random.seed(123)
    random.shuffle(flipped_articles)
    random.shuffle(still_articles)

    sampled_articles = flipped_articles[:10] + still_articles[:10]

    sampled_dataset = {'version': dataset['version'], 'data': sampled_articles}
    dump_json(sampled_dataset, 'outputs/sample-addone-dev_squad.json')

# addone sample
# for article in dataset:
#     for paragraph in article['paragraphs']:
#     #   for qa in paragraph['qas']:
#     #     orig_id = qa['id'].split('-')[0]
#     #     if id_set and orig_id not in id_set: continue
#         all_ids.add(orig_id)
#         if qa['id'] not in predictions:
#           message = 'Unanswered question ' + qa['id'] + ' will receive score 0.'
#           print(message, file=sys.stderr)
#           continue
#         ground_truths = list(map(lambda x: x['text'], qa['answers']))
#         prediction = predictions[qa['id']]
#         cur_exact_match = metric_max_over_ground_truths(exact_match_score,
#                                                         prediction, ground_truths)
#         cur_f1 = metric_max_over_ground_truths(f1_score, prediction, ground_truths)
#         if orig_id == qa['id']:
#           # This is an original example
#           orig_f1_score += cur_f1
#           orig_exact_match_score += cur_exact_match
#           if orig_id not in adv_f1_scores:
#             # Haven't seen adversarial example yet, so use original for adversary
#             adv_ids[orig_id] = orig_id
#             adv_f1_scores[orig_id] = cur_f1
#             adv_exact_match_scores[orig_id] = cur_exact_match
#         else:
#           # This is an adversarial example
#           if (orig_id not in adv_f1_scores or adv_ids[orig_id] == orig_id 
#               or adv_f1_scores[orig_id] > cur_f1):
#             # Always override if currently adversary currently using orig_id
#             adv_ids[orig_id] = qa['id']
#             adv_f1_scores[orig_id] = cur_f1
#             adv_exact_match_scores[orig_id] = cur_exact_match
#   if verbose:
#     print_details(dataset, predictions, adv_ids)
#   orig_f1 = 100.0 * orig_f1_score / len(all_ids)
#   orig_exact_match = 100.0 * orig_exact_match_score / len(all_ids)
#   adv_exact_match = 100.0 * sum(adv_exact_match_scores.values()) / len(all_ids)
#   adv_f1 = 100.0 * sum(adv_f1_scores.values()) / len(all_ids)
#   return OrderedDict([
#       ('orig_exact_match', orig_exact_match),
#       ('orig_f1', orig_f1),
#       ('adv_exact_match', adv_exact_match),
#       ('adv_f1', adv_f1),
#   ])


def extract_all_attack(dataset, predictions):
    """Collect every adversarial variant of each attacked question.

    A question counts as adversarial when its id contains a '-'; the part of
    the id before the first '-' is used as the id of the original question.

    Args:
        dataset: iterable of article dicts. Each article has a 'title' and a
            'paragraphs' list; each paragraph has a 'qas' list whose entries
            carry an 'id' and an 'answers' list of dicts with a 'text' field.
        predictions: mapping from question id to a sequence whose first item
            is a dict with a 'text' field (the predicted answer string).

    Returns:
        dict mapping original question id to a list of ``(adv_article, f1)``
        tuples, one per adversarial variant, in dataset order. Each
        ``adv_article`` is a new single-paragraph article titled
        ``"<title>-adv<paragraph index>"``.

    Raises:
        AssertionError: if a paragraph containing an adversarial question does
            not hold exactly one question.
        KeyError: if an adversarial question id is missing from ``predictions``.
    """
    # NOTE(review): the previous body read and wrote an undefined
    # `kept_article_by_id`, so it always raised NameError. Gathering all
    # variants follows the function name; confirm this is the wanted contract.
    adv_articles_by_id = {}
    for article in dataset:
        for i_par, paragraph in enumerate(article['paragraphs']):
            for qa in paragraph['qas']:
                if '-' not in qa['id']:
                    # Original (non-adversarial) question: nothing to record.
                    continue
                orig_id = qa['id'].split('-')[0]
                # An adversarial paragraph is expected to carry a single
                # question, so the whole paragraph can be kept as-is below.
                assert len(paragraph['qas']) == 1

                ground_truths = [answer['text'] for answer in qa['answers']]
                prediction = predictions[qa['id']][0]['text']
                cur_f1 = metric_max_over_ground_truths(f1_score, prediction, ground_truths)

                adv_article = {
                    'title': f"{article['title']}-adv{i_par}",
                    'paragraphs': [paragraph],
                }
                adv_articles_by_id.setdefault(orig_id, []).append((adv_article, cur_f1))
    print(len(adv_articles_by_id))
    return adv_articles_by_id

def extract_orig_articles(dataset, predictions):
    """Build one single-question article per original (non-adversarial) question.

    Questions whose id contains a '-' are treated as adversarial and skipped.

    Args:
        dataset: iterable of article dicts. Each article has a 'title' and a
            'paragraphs' list; each paragraph has a 'qas' list whose entries
            carry an 'id' and an 'answers' list of dicts with a 'text' field.
        predictions: mapping from question id to a sequence whose first item
            is a dict with a 'text' field (the predicted answer string).

    Returns:
        dict mapping question id to ``(article, prediction, f1)``. ``article``
        is a new article titled
        ``"<title>-orig<paragraph index>-qa<question index>"`` whose only
        paragraph is a shallow copy of the source paragraph with 'qas'
        narrowed to that one question.

    Raises:
        KeyError: if an original question id is missing from ``predictions``.
    """
    orig_articles = {}
    for article in dataset:
        for i_par, paragraph in enumerate(article['paragraphs']):
            for i_qa, qa in enumerate(paragraph['qas']):
                if '-' in qa['id']:
                    # Adversarial variant: handled by the attack extractors.
                    continue
                orig_id = qa['id']

                ground_truths = [answer['text'] for answer in qa['answers']]
                prediction = predictions[orig_id][0]['text']
                cur_f1 = metric_max_over_ground_truths(f1_score, prediction, ground_truths)

                # Copy before replacing 'qas' so the source paragraph keeps
                # its full question list.
                new_paragraph = paragraph.copy()
                new_paragraph['qas'] = [qa]

                new_article = {
                    'title': f"{article['title']}-orig{i_par}-qa{i_qa}",
                    'paragraphs': [new_paragraph],
                }
                orig_articles[orig_id] = (new_article, prediction, cur_f1)

    return orig_articles

def sample_robust_and_inrobust(dataset, predictions):
    """Print the questions of a seeded sample of attacked examples.

    Every attacked question is labelled by comparing its worst adversarial
    variant with the original:
      * 'still'   - the predicted text is identical,
      * 'flipped' - the text differs and the adversarial F1 is lower,
      * 'wired'   - the text differs but the F1 is not lower.
    The three groups are shuffled with a fixed seed, their sizes are printed,
    and the question text of up to ten 'flipped' followed by up to ten 'still'
    original articles is printed.

    Args:
        dataset: dict with a 'data' list of articles.
        predictions: mapping from question id to the model predictions.
    """
    originals = extract_orig_articles(dataset['data'], predictions)
    worst_attacks = extract_worst_attack(dataset['data'], predictions)

    groups = {'flipped': [], 'still': [], 'wired': []}
    for qid, (_, adv_prediction, adv_f1) in worst_attacks.items():
        _, orig_prediction, orig_f1 = originals[qid]
        if adv_prediction == orig_prediction:
            label = 'still'
        elif adv_f1 < orig_f1:
            label = 'flipped'
        else:
            label = 'wired'
        groups[label].append(qid)

    # Shuffle order (flipped, still, wired) matters for reproducibility.
    random.seed(321)
    for label in ('flipped', 'still', 'wired'):
        random.shuffle(groups[label])
    print(len(groups['flipped']), len(groups['still']), len(groups['wired']))

    # Articles are taken from the originals; use `worst_attacks` here instead
    # to sample the adversarial versions.
    source = originals
    picked_qids = groups['flipped'][:10] + groups['still'][:10]
    for qid in picked_qids:
        sampled_article = source[qid][0]
        print(sampled_article['paragraphs'][0]['qas'][0]['question'])


def extract_info_for_ner(dataset, predictions):
    """Dump id, question text and attack behavior of every attacked question.

    Each attacked question is labelled by comparing its worst adversarial
    variant with the original:
      * 'still'   - the predicted text is identical,
      * 'flipped' - the text differs and the adversarial F1 is lower,
      * 'wired'   - the text differs but the F1 is not lower.
    The records are written to 'misc/squad_adv_questions.json', all 'flipped'
    entries first, then 'still', then 'wired'.

    Args:
        dataset: dict with a 'data' list of articles.
        predictions: mapping from question id to the model predictions.
    """
    originals = extract_orig_articles(dataset['data'], predictions)
    worst_attacks = extract_worst_attack(dataset['data'], predictions)

    qids_by_behavior = {'flipped': [], 'still': [], 'wired': []}
    for qid, (_, adv_prediction, adv_f1) in worst_attacks.items():
        _, orig_prediction, orig_f1 = originals[qid]
        if adv_prediction == orig_prediction:
            behavior = 'still'
        elif adv_f1 < orig_f1:
            behavior = 'flipped'
        else:
            behavior = 'wired'
        qids_by_behavior[behavior].append(qid)

    print(
        len(qids_by_behavior['flipped']),
        len(qids_by_behavior['still']),
        len(qids_by_behavior['wired']),
    )

    # Question id and text are read from the original article, not from the
    # adversarial one.
    adv_questions = []
    for behavior in ('flipped', 'still', 'wired'):
        for qid in qids_by_behavior[behavior]:
            qa = originals[qid][0]['paragraphs'][0]['qas'][0]
            adv_questions.append(
                {'id': qa['id'], 'question': qa['question'], 'behavior': behavior}
            )
    print(len(adv_questions))
    dump_json(adv_questions, 'misc/squad_adv_questions.json', indent=2)

def get_attacked(dataset, predictions):
    """Write the original article of every attacked question to disk.

    Attacked questions are grouped by comparing the worst adversarial variant
    with the original ('still': same predicted text; 'flipped': different
    text and lower F1; 'wired': different text, F1 not lower). The group
    sizes are printed and the original single-question articles are written
    to 'outputs/attacked-orig_squad.json' in the order flipped, still, wired.

    Args:
        dataset: dict with a 'version' entry and a 'data' list of articles.
        predictions: mapping from question id to the model predictions.
    """
    originals = extract_orig_articles(dataset['data'], predictions)
    worst_attacks = extract_worst_attack(dataset['data'], predictions)

    flipped_qids, still_qids, wired_qids = [], [], []
    for qid, (_, adv_prediction, adv_f1) in worst_attacks.items():
        _, orig_prediction, orig_f1 = originals[qid]
        if adv_prediction == orig_prediction:
            target = still_qids
        elif adv_f1 < orig_f1:
            target = flipped_qids
        else:
            target = wired_qids
        target.append(qid)

    print(len(flipped_qids), len(still_qids), len(wired_qids))

    # Articles come from the originals; use `worst_attacks` here instead to
    # export the adversarial versions.
    source = originals
    ordered_qids = flipped_qids + still_qids + wired_qids
    sampled_articles = [source[qid][0] for qid in ordered_qids]

    sampled_dataset = {'version': dataset['version'], 'data': sampled_articles}
    dump_json(sampled_dataset, 'outputs/attacked-orig_squad.json', indent=2)

if __name__ == "__main__":
    # Script entry point: label every attacked question of the addsent dev
    # set and dump the result.
    #
    # Earlier runs, kept for reference (all take the same two arguments):
    #   sample_addone(<addone dataset>, <addone predictions>)
    #   sample_robust_and_inrobust(<addone or addsent dataset>, <matching predictions>)
    #   get_attacked(<addsent dataset>, <addsent predictions>)
    # where the addone files are 'outputs/addone-dev_squad.json' and
    # 'outputs/addone-dev_squad_predictions.json'.
    dataset_path = 'outputs/addsent-dev_squad.json'
    predictions_path = 'outputs/addsent-dev_squad_predictions.json'
    extract_info_for_ner(read_json(dataset_path), read_json(predictions_path))
