import random
from collections import defaultdict


def get_sub_dict(full_dict, keys):
    """Return a new dict containing only the entries of ``full_dict`` listed in ``keys``.

    Entries appear in the order given by ``keys``. A key that is missing from
    ``full_dict`` raises ``KeyError``.
    """
    selected = {}
    for wanted in keys:
        selected[wanted] = full_dict[wanted]
    return selected


def split_question_id(question_id):
    """Split a dash-separated question id string into a dict of named parts.

    The number of dash-separated parts selects the field names:
    3 -> img_id, a_ix, t_ix
    4 -> img_id, f_ix, a_ix, t_ix
    5 -> img_id, f_ix, a_ix, t_ix, s_ix

    All values are kept as strings. Any other number of parts raises ValueError.
    """
    field_names_by_count = {
        3: ['img_id', 'a_ix', 't_ix'],
        4: ['img_id', 'f_ix', 'a_ix', 't_ix'],
        5: ['img_id', 'f_ix', 'a_ix', 't_ix', 's_ix'],
    }
    parts = question_id.split('-')
    field_names = field_names_by_count.get(len(parts))
    if field_names is None:
        raise ValueError('Question id %s passed with unexpected format. Aborting.' % question_id)
    return dict(zip(field_names, parts))


def join_question_id(question_id):
    """Build a dash-separated question id string from a dict of named parts.

    The number of entries in ``question_id`` selects which keys are read and
    in which order they are joined:
    3 -> img_id, a_ix, t_ix
    4 -> img_id, f_ix, a_ix, t_ix
    5 -> img_id, f_ix, a_ix, t_ix, s_ix

    Each value is converted with ``str``. Any other size raises ValueError.
    """
    key_order_by_size = {
        3: ('img_id', 'a_ix', 't_ix'),
        4: ('img_id', 'f_ix', 'a_ix', 't_ix'),
        5: ('img_id', 'f_ix', 'a_ix', 't_ix', 's_ix'),
    }
    ordered_keys = key_order_by_size.get(len(question_id))
    if ordered_keys is None:
        raise ValueError('Question id %s passed with unexpected format. Aborting.' % question_id)
    pieces = [str(question_id[name]) for name in ordered_keys]
    return '-'.join(pieces)


def get_ans_from_group(group):
    """Return the single answer shared by every question in ``group``.

    For each question the answer is taken to be the first key of its
    ``'label'`` mapping. The group must be non-empty and all questions must
    yield the same answer; otherwise the assertion fails.
    """
    first_labels = [list(entry['label'].keys())[0] for entry in group]
    distinct = set(first_labels)
    # Exactly one distinct answer: rejects both empty and mixed groups.
    assert len(distinct) == 1
    return first_labels[0]


def get_balanced_question_ids(questions):
    """Return question ids subsampled so that every answer has equally many groups.

    Question ids are grouped by every id component except ``t_ix`` (the group
    key is called an "assignment id" here). Each group must map to a single
    answer (see ``get_ans_from_group``). For every answer, as many groups are
    randomly drawn as the rarest answer has, and the question ids of the drawn
    groups are returned.

    Args:
        questions: mapping from question id string to question dict.

    Returns:
        A list of question ids. Empty when ``questions`` is empty. Sampling
        uses the module-level ``random`` state, so results vary between calls
        unless the caller seeds it.

    Raises:
        ValueError: if a question id has an unexpected format.
        AssertionError: if the questions within one group disagree on the answer.
    """
    if not questions:
        # Nothing to balance; without this guard the min() below would raise
        # ValueError on an empty sequence.
        return []

    # Group question ids by assignment id (all id components except t_ix).
    assignment_id_2_q_ids = defaultdict(list)
    for q_id in questions.keys():
        parts = split_question_id(q_id)
        if 's_ix' in parts:
            group_fields = ('img_id', 'f_ix', 'a_ix', 's_ix')
        elif 'f_ix' in parts:
            group_fields = ('img_id', 'f_ix', 'a_ix')
        else:
            group_fields = ('img_id', 'a_ix')
        assignment_id = '-'.join(parts[field] for field in group_fields)
        assignment_id_2_q_ids[assignment_id].append(q_id)

    # Bucket assignment ids by the answer shared by their questions.
    answer_2_a_ids = defaultdict(list)
    for a_id, q_ids in assignment_id_2_q_ids.items():
        answer = get_ans_from_group([questions[q_id] for q_id in q_ids])
        answer_2_a_ids[answer].append(a_id)

    # Downsample every answer to the size of the rarest one (without replacement).
    answer_class_min = min(len(a_ids) for a_ids in answer_2_a_ids.values())
    balanced_q_ids = []
    for a_ids in answer_2_a_ids.values():
        for a_id in random.sample(a_ids, k=answer_class_min):
            balanced_q_ids.extend(assignment_id_2_q_ids[a_id])
    return balanced_q_ids


def get_balanced_questions(in_questions):
    """Return the questions that survive answer balancing, as a list.

    Args:
        in_questions: either a list of question dicts (each with a
            ``'question_id'`` entry) or a dict mapping question id to
            question dict. The argument is not modified.

    Returns:
        A list of the question dicts whose ids were selected by
        ``get_balanced_question_ids``, in the input's original order.

    Raises:
        TypeError: if ``in_questions`` is neither a list nor a dict.
    """
    if isinstance(in_questions, list):
        questions = {q['question_id']: q for q in in_questions}
    elif isinstance(in_questions, dict):
        questions = in_questions
    else:
        raise TypeError('Expected list or dict, but received questions of type %s' % type(in_questions))
    balanced_ids = set(get_balanced_question_ids(questions))
    # Filter into a new list instead of deleting keys, so a dict passed in by
    # the caller is left untouched.
    return [question for q_id, question in questions.items() if q_id in balanced_ids]

def get_inconsistent(preds_list: list):
    """Return the keys on which the dictionaries in preds_list do not all agree.

    The key universe is taken from the first dictionary; the result is every
    key of it that ``get_consistent`` does not report. Order is unspecified.
    """
    agreeing_ids = get_consistent(preds_list)
    all_ids = set(preds_list[0].keys())
    return [*all_ids.difference(agreeing_ids)]


def get_consistent(preds_list: list):
    """Return the keys on which every dictionary in preds_list has the same value.

    Args:
        preds_list: non-empty list of dicts. The keys of the first dict define
            which keys are checked; values must be hashable.

    Returns:
        A list of keys, in the first dict's key order, whose value is
        identical across all dicts.

    Raises:
        AssertionError: if any dict has a different number of entries than
            the first one.
        IndexError: if ``preds_list`` is empty.
        KeyError: if a dict of matching size lacks one of the first dict's keys.
    """
    reference = preds_list[0]
    expected_size = len(reference)
    # Raised explicitly (rather than via `assert`) so the check is not
    # stripped under `python -O`; the exception type is kept for callers.
    if any(len(preds) != expected_size for preds in preds_list):
        raise AssertionError('Some preds are incomplete.')
    return [
        q_id for q_id in reference
        if len({preds[q_id] for preds in preds_list}) == 1
    ]


def get_correct(questions, preds, subset=None):
    """Return the question ids whose prediction is one of the question's labels.

    Args:
        questions: dict mapping question id to question dict, or any iterable
            of question dicts that each carry a ``'question_id'`` entry. Each
            question dict must have a ``'label'`` container. Not modified.
        preds: mapping from question id to the predicted answer.
        subset: optional iterable of question ids to check; defaults to all
            question ids.

    Returns:
        A list of ids, in ``subset`` iteration order, for which
        ``preds[q_id]`` is contained in the question's ``'label'``.

    Raises:
        KeyError: if an id in ``subset`` is missing from ``questions`` or ``preds``.
    """
    # No copy is needed: the questions are only read, never modified.
    if isinstance(questions, dict):
        lookup = questions
    else:
        lookup = {q['question_id']: q for q in questions}
    if subset is None:
        subset = lookup.keys()
    return [q_id for q_id in subset if preds[q_id] in lookup[q_id]['label']]


def get_incorrect(questions, preds, subset=None):
    """Return the question ids whose prediction is not one of the question's labels.

    Args:
        questions: dict mapping question id to question dict, or any iterable
            of question dicts that each carry a ``'question_id'`` entry.
            Not modified.
        preds: mapping from question id to the predicted answer. Only the ids
            that are actually checked need to be present.
        subset: optional iterable of question ids to restrict the check to;
            ids that are not in ``questions`` are ignored. Defaults to all
            question ids.

    Returns:
        A list (in unspecified order) of the checked ids that
        ``get_correct`` does not report as correct.
    """
    # No copy is needed: the questions are only read, never modified.
    if isinstance(questions, dict):
        lookup = questions
    else:
        lookup = {q['question_id']: q for q in questions}
    if subset is None:
        subset = lookup.keys()
    q_ids = set(lookup.keys()).intersection(subset)
    # Evaluate correctness only for the ids under consideration, so preds does
    # not have to cover questions outside the subset.
    correct_ids = get_correct(lookup, preds, q_ids)
    return list(q_ids.difference(correct_ids))
