import os
import sys
sys.path.append('.')
from os.path import join
from common.utils import read_json, dump_json
from itertools import permutations, combinations_with_replacement, product
from utils import *
import json

def make_two_par_qa(contexts, nations, question):
    """Build yes/no comparison examples from ordered pairs of paragraphs.

    Args:
        contexts: sequence whose items are indexable as ``(template, name)``;
            ``template`` contains the ``#NATION`` and ``#CITY`` placeholders
            and ``name`` is what gets substituted into the question.
        nations: sequence whose items are indexable as ``(nation, city)``.
        question: question template with ``#NAME1`` and ``#NAME2``
            placeholders.

    Returns:
        A list of ``(question, context, answer)`` tuples. Every ordered pair
        of paragraphs is combined with every ordered pair of nations (with
        repetition), and each combination is emitted twice: once per order of
        the two names in the question. ``answer`` is ``'yes'`` when both
        paragraphs were filled with the same nation, otherwise ``'no'``.
    """
    paragraph_pairs = list(permutations(contexts, 2))
    nation_pairs = list(product(nations, repeat=2))

    examples = []
    for paragraphs in paragraph_pairs:
        for assigned in nation_pairs:
            label = 'yes' if assigned[0][0] == assigned[1][0] else 'no'

            # Fill each paragraph template with the nation/city assigned to it.
            passage = ' '.join(
                para[0].replace('#NATION', nat[0]).replace('#CITY', nat[1])
                for para, nat in zip(paragraphs, assigned)
            )

            # Ask the same question with the two names in both orders.
            name_a = paragraphs[0][1]
            name_b = paragraphs[1][1]
            for first, second in ((name_a, name_b), (name_b, name_a)):
                asked = question.replace('#NAME1', first).replace('#NAME2', second)
                examples.append((asked, passage, label))

    return examples

def make_three_par_qa(contexts, nations, question):
    """Build yes/no comparison examples from ordered triples of paragraphs.

    Args:
        contexts: sequence whose items are indexable as ``(template, name)``;
            ``template`` contains the ``#NATION`` and ``#CITY`` placeholders
            and ``name`` is what gets substituted into the question.
        nations: sequence whose items are indexable as ``(nation, city)``.
        question: question template with ``#NAME1`` and ``#NAME2``
            placeholders.

    Returns:
        A list of ``(question, context, answer)`` tuples. Each ordered triple
        of paragraphs is combined with each nation assignment (assignments
        where all three paragraphs share one nation are skipped) and with
        each ordered pair of distinct paragraph positions to ask about.
        ``answer`` is ``'yes'`` when the two asked-about paragraphs were
        filled with the same nation, otherwise ``'no'``.
    """
    paragraph_triples = list(permutations(contexts, 3))

    nation_triples = []
    for triple in product(nations, repeat=3):
        lead = triple[0][0]
        if lead == triple[1][0] and lead == triple[2][0]:
            # All three paragraphs would share a nation; skip this assignment.
            continue
        nation_triples.append(triple)

    # Ordered pairs of distinct paragraph positions the question can refer to.
    asked_positions = list(permutations([0, 1, 2], 2))

    examples = []
    for paragraphs in paragraph_triples:
        for assigned in nation_triples:
            passage = ' '.join(
                para[0].replace('#NATION', nat[0]).replace('#CITY', nat[1])
                for para, nat in zip(paragraphs, assigned)
            )
            for first_pos, second_pos in asked_positions:
                asked = (
                    question
                    .replace('#NAME1', paragraphs[first_pos][1])
                    .replace('#NAME2', paragraphs[second_pos][1])
                )
                same_nation = assigned[first_pos][0] == assigned[second_pos][0]
                examples.append((asked, passage, 'yes' if same_nation else 'no'))
    return examples

def make_qa_dataset(qas):
    """Wrap ``(question, context, answer)`` triples in a versioned dataset dict.

    Args:
        qas: iterable of ``(question, context, answer)`` triples.

    Returns:
        ``{'version': '1.1', 'data': [...]}`` with one entry per triple. Each
        entry is titled ``perturbation<i>`` (``i`` is the triple's position)
        and holds a single paragraph with a single question whose ``id``
        equals the title. Answers are stored with ``answer_start`` set to -1.
    """
    def build_entry(index, triple):
        # One dataset entry holding exactly one paragraph and one question.
        asked, passage, label = triple
        ident = 'perturbation%d' % index
        return {
            'title': ident,
            'paragraphs': [{
                'context': passage,
                'qas': [{
                    'id': ident,
                    'question': asked,
                    'answers': [{'answer_start': -1, 'text': label}],
                    'is_yesno': True,
                    'question_type': 'comparison',
                }],
            }],
        }

    return {
        'version': '1.1',
        'data': [build_entry(index, triple) for index, triple in enumerate(qas)],
    }

def make_perturbs():
    """Generate the three-paragraph perturbation set and write it to disk.

    Reads ``lime/sp_case/nation.json`` (keys used: ``nations``, ``names``,
    ``contexts``, ``question``), fills ``#NAME`` into each context template,
    builds examples with ``make_three_par_qa``, wraps them with
    ``make_qa_dataset`` and dumps the result to
    ``outputs/nation-perturb-dev_hpqa.json``.
    """
    spec = read_json('lime/sp_case/nation.json')

    # zip(*...) transposes the stored lists into one tuple per row
    # (presumably (nation, city) and (name in context, name in question)).
    nation_rows = list(zip(*spec['nations']))
    name_rows = list(zip(*spec['names']))
    templates = spec['contexts']
    question = spec['question']

    # Pair each template (with its in-context name filled in) with the name
    # that will be used in the question.
    contexts = [
        (template.replace('#NAME', in_context), in_question)
        for template, (in_context, in_question) in zip(templates, name_rows)
    ]

    # Only the three-paragraph variant is generated here; make_two_par_qa is
    # not part of the output.
    examples = make_three_par_qa(contexts, nation_rows, question)

    dump_json(
        make_qa_dataset(examples),
        join('outputs/', 'nation-perturb-dev_hpqa.json'),
        indent=1,
    )


def name_order_in_question_not_matter(raw_data, norm_preds):
    """Print groups whose prediction changes with the name order in the question.

    Examples are grouped by the names found in the question (taken in the
    fixed order Scott, Wood, James, so the order in the question itself does
    not affect the key) together with the full context. The number of groups
    is printed first; then every group containing both a prediction above 0.5
    and one at or below 0.5 is dumped in full.

    Args:
        raw_data: mapping from example id to a dict with ``'question'`` and
            ``'context'`` entries.
        norm_preds: mapping from example id to a numeric score; scores above
            0.5 are counted as "yes".
    """
    known_names = ['Scott', 'Wood', 'James']

    members_by_group = {}
    for example_id, example in raw_data.items():
        mentioned = [name for name in known_names if name in example['question']]
        group_key = mentioned[0] + mentioned[1] + example['context']
        members_by_group.setdefault(group_key, []).append(example_id)

    print(len(members_by_group))
    for member_ids in members_by_group.values():
        votes = [norm_preds[example_id] > 0.5 for example_id in member_ids]
        yes_votes = sum(votes)
        no_votes = len(votes) - yes_votes
        if not (yes_votes and no_votes):
            # Unanimous group: nothing to report.
            continue

        print('-------------Group Failed--------------')
        for example_id in member_ids:
            score = norm_preds[example_id]
            print(json.dumps(raw_data[example_id], indent=1))
            print(score > 0.5, '%.2f' % score, '\n')

def how_nation_order_matter(raw_data, norm_preds):
    """Tally yes/no predictions per ordering of nationalities in the context.

    For every example the context is cut into three segments, each starting
    at the first occurrence of one of the names Scott, Wood and James (taken
    in order of appearance) and running up to the next name, or to the end of
    the context for the last one. Each segment is labelled ``American`` or
    ``English`` by substring search, with ``American`` checked first.
    Examples are grouped by the resulting ``nation-nation-nation`` string.
    The number of groups is printed, followed by the yes/no counts of each
    group.

    Args:
        raw_data: mapping from example id to a dict with a ``'context'``
            entry. Only the context is inspected.
        norm_preds: mapping from example id to a numeric score; scores above
            0.5 are counted as "yes".

    Raises:
        ValueError: propagated from ``str.index`` when one of the three names
            is missing from a context.
        AssertionError: when a nationality could not be found for each of the
            three segments. Raised explicitly so the check also runs under
            ``python -O``.
    """
    names = ('Scott', 'Wood', 'James')

    grouped_ids = {}
    for key, data in raw_data.items():
        context = data['context']

        # Segment boundaries: the sorted first occurrences of the names, with
        # the last segment running to the end of the context.
        starts = sorted(context.index(name) for name in names)
        ends = starts[1:] + [len(context)]

        nations = []
        for start, end in zip(starts, ends):
            segment = context[start:end]
            if 'American' in segment:
                nations.append('American')
            elif 'English' in segment:
                nations.append('English')
        if len(nations) != len(names):
            raise AssertionError(
                'expected a nationality in each of the %d name segments of '
                'example %r, found %d' % (len(names), key, len(nations))
            )

        grouped_ids.setdefault('-'.join(nations), []).append(key)

    print(len(grouped_ids))
    for group, ids in grouped_ids.items():
        num_yes = sum(norm_preds[example_id] > 0.5 for example_id in ids)
        num_no = len(ids) - num_yes
        # Every group is reported, including unanimous ones.
        print('-------------', group, '-------------------')
        print('Num Yes', num_yes, 'Num no', num_no)

def how_nation_and_context_matter(raw_data, norm_preds):
    """Print contexts for which the questions receive mixed predictions.

    Examples are grouped by their exact context string. The number of groups
    is printed first; then every group containing both a prediction above 0.5
    and one at or below 0.5 is reported with its yes/no counts and a full
    dump of its examples.

    Args:
        raw_data: mapping from example id to a JSON-serialisable dict with a
            ``'context'`` entry.
        norm_preds: mapping from example id to a numeric score; scores above
            0.5 are counted as "yes".
    """
    members_by_context = {}
    for example_id, example in raw_data.items():
        members_by_context.setdefault(example['context'], []).append(example_id)

    print(len(members_by_context))
    for member_ids in members_by_context.values():
        votes = [norm_preds[example_id] > 0.5 for example_id in member_ids]
        yes_votes = sum(votes)
        no_votes = len(votes) - yes_votes
        if not (yes_votes and no_votes):
            # Unanimous group: nothing to report.
            continue

        print('-------------Group Failed--------------')
        print('Num Yes', yes_votes, 'Num no', no_votes)
        for example_id in member_ids:
            score = norm_preds[example_id]
            print(json.dumps(raw_data[example_id], indent=1))
            print(score > 0.5, '%.2f' % score, '\n')


def verify_hypothesis():
    """Load stored predictions and perturbation data, then run one check.

    Predictions come from ``lime/sp_case/raw_predictions.json`` and are passed
    through ``normalize_raw_prediction``; examples come from
    ``lime/sp_case/nation_perturb.json`` and are passed through
    ``prepro_data``. Both helpers are defined outside this file (presumably
    provided by ``from utils import *``).
    """
    norm_preds = normalize_raw_prediction(
        read_json('lime/sp_case/raw_predictions.json')
    )
    examples = prepro_data(read_json('lime/sp_case/nation_perturb.json'))

    # Only the name-order check is active; the alternatives are kept for
    # manual switching:
    # how_nation_order_matter(examples, norm_preds)
    # how_nation_and_context_matter(examples, norm_preds)
    name_order_in_question_not_matter(examples, norm_preds)

if __name__ == "__main__":
    # Script entry point. Dataset generation is disabled; uncomment the next
    # line to write the perturbation file before (or instead of) analysing.
    # make_perturbs()
    verify_hypothesis()