# Distinct fact ids appearing in any of comp_data, ind_data or learn_data.
# NOTE(review): `+` is assumed to concatenate here, i.e. each 'fact_id'
# column is a plain list — confirm against how the datasets are loaded.
all_ids = [
    *set(
        comp_data['fact_id']
        + ind_data['fact_id']
        + learn_data['fact_id']
    )
]

# index mnemonics
# mnemonic_map[fact_id] pins one canonical A/B labelling per fact, taken from
# the first row in which the fact appears (datasets are scanned in the order
# comp_data, ind_data, learn_data). Each entry is a two-way lookup:
#   'A' / 'B'       -> that fact's mnemonic
#   mnemonic value  -> 'A' / 'B'
mnemonic_map = dict()
for d in [comp_data, ind_data, learn_data]:
    for fact_id, mn_a, mn_b in zip(d['fact_id'], d['mnemonic_a'], d['mnemonic_b']):
        if fact_id in mnemonic_map:
            # Keep the first labelling seen; later rows get aligned to it.
            continue
        # Reverse entries must point each mnemonic back to its own label.
        mnemonic_map[fact_id] = {'A': mn_a, 'B': mn_b, mn_a: 'A', mn_b: 'B'}

# comparison
# Store each fact's vote counts as [A votes, B votes, ties], where A/B follow
# the canonical labelling in mnemonic_map rather than this row's column order.
direct_comparisons_dict = {}
for i, fact_id in enumerate(comp_data['fact_id']):
    a_votes = comp_data['mnemonic_a_votes'][i]
    b_votes = comp_data['mnemonic_b_votes'][i]
    ties = comp_data['tie_votes'][i]
    mn_a = comp_data['mnemonic_a'][i]
    mn_b = comp_data['mnemonic_b'][i]

    canonical = mnemonic_map[fact_id]
    # The row is "swapped" when its mnemonic_b is the canonical A.
    swap = canonical['A'] == mn_b
    if swap:
        expected_A, expected_B = mn_b, mn_a
        votes = [b_votes, a_votes, ties]
    else:
        expected_A, expected_B = mn_a, mn_b
        votes = [a_votes, b_votes, ties]

    # Sanity check: this row must contain exactly the canonical pair.
    assert canonical['A'] == expected_A and canonical['B'] == expected_B
    direct_comparisons_dict[fact_id] = votes

# likert
# Per-fact ratings from ind_data, re-keyed so that "A"/"B" follow the
# canonical labelling in mnemonic_map rather than the row's own a/b columns.
ratings_A_dict = {}
ratings_B_dict = {}
for i, fact_id in enumerate(ind_data['fact_id']):
    mn_a = ind_data['mnemonic_a'][i]
    mn_b = ind_data['mnemonic_b'][i]

    canonical = mnemonic_map[fact_id]
    # The row is "swapped" when its mnemonic_b is the canonical A.
    swap = canonical['A'] == mn_b
    if swap:
        expected_A, expected_B = mn_b, mn_a
        column_for_A, column_for_B = 'mnemonic_b_ratings', 'mnemonic_a_ratings'
    else:
        expected_A, expected_B = mn_a, mn_b
        column_for_A, column_for_B = 'mnemonic_a_ratings', 'mnemonic_b_ratings'

    # Sanity check: this row must contain exactly the canonical pair.
    assert canonical['A'] == expected_A and canonical['B'] == expected_B
    ratings_A_dict[fact_id] = ind_data[column_for_A][i]
    ratings_B_dict[fact_id] = ind_data[column_for_B][i]

# learning outcomes
# Per-fact 'num_wrong' values from learn_data, re-keyed so that "A"/"B" follow
# the canonical labelling in mnemonic_map rather than the row's own a/b columns.
learning_outcomes_A_dict = {}
learning_outcomes_B_dict = {}
# Iterate over learn_data itself: bounding the loop by another dataset's
# length would either overrun learn_data or silently skip its trailing rows.
for i, fact_id in enumerate(learn_data['fact_id']):
    mn_a, mn_b = learn_data['mnemonic_a'][i], learn_data['mnemonic_b'][i]
    canonical = mnemonic_map[fact_id]
    # The row is "swapped" when its mnemonic_b is the canonical A.
    swap = canonical['A'] == mn_b
    if swap:
        # Sanity check: the row holds the canonical pair in reversed order.
        assert canonical['A'] == mn_b and canonical['B'] == mn_a
        learning_outcomes_A_dict[fact_id] = learn_data['mnemonic_b_num_wrong'][i]
        learning_outcomes_B_dict[fact_id] = learn_data['mnemonic_a_num_wrong'][i]
    else:
        # Sanity check: the row holds the canonical pair in canonical order.
        assert canonical['A'] == mn_a and canonical['B'] == mn_b
        learning_outcomes_A_dict[fact_id] = learn_data['mnemonic_a_num_wrong'][i]
        learning_outcomes_B_dict[fact_id] = learn_data['mnemonic_b_num_wrong'][i]

import numpy as np
import pymc as pm
import arviz as az


# Joint model: every instance gets a latent "effectiveness" in (0, 1) for
# mnemonic A and for mnemonic B. Three observed data sources are each tied to
# those latents through their own slope/intercept link:
#   1. pairwise comparison votes  -> Multinomial over (A, B, tie)
#   2. Likert rating counts       -> Multinomial over n_likert_categories
#   3. learning turn counts       -> NegativeBinomial with n=1
# The data arrays (direct_comparisons, ratings_*_counts, turn_counts_*, *_idx)
# are not defined in this section and must already exist in scope.
with pm.Model() as model:

    # *************** Fixed Values ***************
    n_likert_categories = 5
    # NOTE(review): hard-coded size of the latent vectors; every *_idx array
    # below indexes into them, so this must exceed the largest index — confirm
    # it matches the number of instances in the data.
    n_instances = 472

    # *************** Define Data ***************
    # Plain aliases (no copy) of the externally defined observation arrays.
    direct_comparisons_curr = direct_comparisons
    ratings_A_counts_curr = ratings_A_counts
    ratings_B_counts_curr = ratings_B_counts
    turn_counts_A_curr = turn_counts_A
    turn_counts_B_curr = turn_counts_B

    # Index arrays selecting, for each observed row, which entry of the latent
    # effectiveness vectors it refers to. They are used with `.shape` below, so
    # presumably NumPy arrays — TODO confirm.
    compare_idx_curr = compare_idx
    ratings_A_idx_curr = ratings_A_idx
    ratings_B_idx_curr = ratings_B_idx
    learn_A_idx_curr = learn_A_idx
    learn_B_idx_curr = learn_B_idx

    # *************** Overall Mnemonic Effectiveness ***************
    # One latent value per instance and per mnemonic, with a Beta(1, 1) prior.
    overall_effectiveness_A = pm.Beta('effectiveness_A', alpha=1, beta=1, shape=n_instances)
    overall_effectiveness_B = pm.Beta('effectiveness_B', alpha=1, beta=1, shape=n_instances)

    # *************** Comparison Data ***************
    # Transform effectiveness to a new value for flexibility
    # (the same slope and intercept are shared by A and B).
    compare_slope = pm.Normal('compare_slope', mu=0, sigma=1)
    compare_intercept = pm.Normal('compare_intercept', mu=0, sigma=1)
    compare_a_prob = pm.math.sigmoid(compare_slope * overall_effectiveness_A[compare_idx_curr] + compare_intercept)
    compare_b_prob = pm.math.sigmoid(compare_slope * overall_effectiveness_B[compare_idx_curr] + compare_intercept)

    # Adjust probabilities by sampling the probability of a tie
    # tie_prob is a single scalar shared by all comparisons; dividing by
    # total_prob renormalises the three scores so each row sums to 1.
    tie_prob = pm.Beta('tie_prob', alpha=1, beta=1)
    total_prob = (compare_a_prob + compare_b_prob + tie_prob)
    # Stack gives 3 rows (A, B, tie); `.T` turns that into one row per
    # comparison with columns (A, B, tie). np.ones(...) spreads the scalar
    # tie_prob to one entry per comparison.
    # NOTE(review): assumes direct_comparisons columns are ordered
    # [A votes, B votes, ties] — confirm where that array is built.
    comparison_probs = pm.math.stack([compare_a_prob / total_prob,
                                      compare_b_prob / total_prob,
                                      (np.ones(compare_idx_curr.shape) * tie_prob / total_prob)]).T

    # Use these probabilities for a multinomial distribution
    # (n is the total number of votes in each observed row).
    comparison_votes = pm.Multinomial('comparison_votes', n=pm.math.sum(direct_comparisons_curr, axis=1), p=comparison_probs, observed=direct_comparisons_curr)

    # *************** Likert Scale Data ***************

    # Transform effectiveness to a distribution
    # One slope/intercept per non-reference category (n_likert_categories - 1
    # of them), shared between mnemonic A and mnemonic B.
    likert_slope = pm.Normal('likert_slope', mu=0, sigma=1, shape=(1, n_likert_categories-1))
    likert_intercept = pm.Normal('likert_intercept', mu=0, sigma=1, shape=(1, n_likert_categories-1))

    # Create multinomial distribution for mnemonic A
    # `[:, None]` makes the selected effectiveness values a column so they
    # broadcast against the (1, n_likert_categories-1) slope/intercept.
    # NOTE(review): the per-category scores pass through a sigmoid, so the
    # values fed to softmax are confined to (0, 1), which limits how peaked
    # the category distribution can become — confirm this is intended.
    likert_logits_a = pm.math.sigmoid(likert_slope * overall_effectiveness_A[ratings_A_idx_curr][:, None] + likert_intercept)
    # Prepend a zero column as the fixed score of the first category, giving
    # n_likert_categories columns in total.
    likert_logits_a = pm.math.concatenate([np.zeros((ratings_A_idx_curr.shape[0], 1)), likert_logits_a], axis=1)
    likert_p_a = pm.math.softmax(likert_logits_a, axis=1)
    # n is the total number of ratings in each observed row.
    ratings_A = pm.Multinomial('ratings_A', n=np.sum(ratings_A_counts_curr, axis=1), p=likert_p_a, observed=ratings_A_counts_curr)

    # Do the same for mnemonic B
    likert_logits_b = pm.math.sigmoid(likert_slope * overall_effectiveness_B[ratings_B_idx_curr][:, None] + likert_intercept)
    likert_logits_b = pm.math.concatenate([np.zeros((ratings_B_idx_curr.shape[0], 1)), likert_logits_b], axis=1)
    likert_p_b = pm.math.softmax(likert_logits_b, axis=1)
    ratings_B = pm.Multinomial('ratings_B', n=np.sum(ratings_B_counts_curr, axis=1), p=likert_p_b, observed=ratings_B_counts_curr)

    # *************** Learning Data ***************

    # Slope/Intercept to adjust learn A/B
    # (shared between mnemonic A and mnemonic B).
    learn_slope = pm.Normal('learn_slope', mu=0, sigma=1)
    learn_intercept = pm.Normal('learn_intercept', mu=0, sigma=1)

    # Learning Mnemonic A is Negative Binomial with 1 success needed and an adjusted probability of success altered by effectiveness
    # NOTE(review): under PyMC's (n, p) parameterisation the observed value
    # should be the number of failures before the n-th success — confirm that
    # turn_counts_* hold failure counts rather than total turns.
    prob_learn_a = pm.math.sigmoid(learn_slope * overall_effectiveness_A[learn_A_idx_curr] + learn_intercept)
    learn_A = pm.NegativeBinomial('learn_A', p=prob_learn_a, n=1, observed=turn_counts_A_curr)

    # Do the same for Mnemonic B
    prob_learn_b = pm.math.sigmoid(learn_slope * overall_effectiveness_B[learn_B_idx_curr] + learn_intercept)
    learn_B = pm.NegativeBinomial('learn_B', p=prob_learn_b, n=1, observed=turn_counts_B_curr)

    # *************** Training ***************
    # Sample from the posterior using the No-U-Turn Sampler (NUTS)
    # 5 chains, each with 1000 tuning steps followed by 1000 kept draws; the
    # seed list supplies one seed per chain.
    trace = pm.sample(1000, tune=1000, chains=5, random_seed=[1, 2, 3, 4, 5])