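'''
Given a set of observations, fit an IRT model and write the outputs to disk.
subject: a trained model, identified by its training size and noise level
         (ID "<trainsize>_<noise>")
item: imageID
y: response
Outputs: thetas.csv (subject abilities) and diffs.csv (item difficulties),
each row being "ID, mean, sd".
'''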

import argparse
import csv 

import numpy as np

import torch
import torch.nn as nn

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import torch.distributions.constraints as constraints

from models.irt import OneParamLog


parser = argparse.ArgumentParser()
# --infile is mandatory: without it, open(None) below would fail with an
# opaque TypeError instead of a usage message.
parser.add_argument('-i', '--infile', required=True, help='response pattern file')
parser.add_argument('-e', '--num-epochs', default=1000, type=int)
parser.add_argument('--gpu', action='store_true')
parser.add_argument('--priors', help='[vague, hierarchical]', default='hierarchical')
parser.add_argument('--model', help='[1PL] (2PL is not implemented yet)', default='1PL')
args = parser.parse_args()

device = torch.device('cpu')
if args.gpu:
    # Creating a cuda device object never fails by itself; check availability
    # here so the user gets a clear usage error up front instead of a CUDA
    # error later on.
    if not torch.cuda.is_available():
        parser.error('--gpu was given but CUDA is not available')
    device = torch.device('cuda')

# 1. load data from file
# 2. encode each observation as aligned entries of three parallel lists
#    (model index, item index, response), plus ID <-> index lookup tables

models = []
items = []
responses = []

model2idx = {}
item2idx = {}
idx2model = {}
idx2item = {}
modelcounter = 0
itemcounter = 0

# inputs have to be trainsize, noise, itemID, response (first row is a header);
# a "model" (IRT subject) is identified by the string "<trainsize>_<noise>"
# newline='' is what the csv module documents for files it reads.
with open(args.infile, 'r', newline='') as infile:
    inreader = csv.reader(infile, delimiter=',')
    next(inreader, None)  # skip headers (no-op on an empty file)
    for line in inreader:
        if not line:
            continue  # csv.reader yields [] for blank lines, e.g. a trailing one
        if len(line) != 4:
            raise ValueError(
                '{}, line {}: expected 4 columns (trainsize, noise, itemID, '
                'response), got {}'.format(args.infile, inreader.line_num, len(line)))
        trainsize, noise, itemID, response = line
        modelID = '{}_{}'.format(trainsize, noise)
        response = int(response)

        # assign dense, first-seen-order indices to new IDs
        if modelID not in model2idx:
            model2idx[modelID] = modelcounter
            idx2model[modelcounter] = modelID
            modelcounter += 1
        if itemID not in item2idx:
            item2idx[itemID] = itemcounter
            idx2item[itemcounter] = itemID
            itemcounter += 1
        models.append(model2idx[modelID])
        items.append(item2idx[itemID])
        responses.append(response)

if not responses:
    raise ValueError('{}: no observations found'.format(args.infile))

# every registered ID occurs in at least one observation, so the map sizes are
# exactly the number of distinct models / items
num_models = len(model2idx)
num_items = len(item2idx)

# 3. define model and guide accordingly
if args.model == '1PL':
    m = OneParamLog(args.priors, device)
else:
    # Only the 1PL model is implemented. Without this branch any other value
    # left `m` undefined and crashed with a NameError at m.fit below.
    raise NotImplementedError(
        'unsupported --model {!r}; only 1PL is implemented'.format(args.model))


# 4. fit irt model with svi, trace-elbo loss
m.fit(models, items, responses, args.num_epochs)

# 5. once model is fit, collect outputs (diffs and thetas) for writing to disk,
#       keyed by the original modelIDs and itemIDs so we can use them

out_items = {}
out_students = {}

# Routing table for the fitted parameters we export:
#   param name -> (output table, index -> original ID map, column to fill)
# Parameters not listed here are only printed.
param_targets = {
    'loc_diff': (out_items, idx2item, 'mean'),
    'scale_diff': (out_items, idx2item, 'sd'),
    'loc_ability': (out_students, idx2model, 'mean'),
    'scale_ability': (out_students, idx2model, 'sd'),
}

param_store = pyro.get_param_store()
for param_name in param_store.get_all_param_names():
    print(param_name)
    values = pyro.param(param_name).data.cpu().numpy()
    print(values)

    target = param_targets.get(param_name)
    if target is None:
        continue
    table, idx2id, column = target
    for idx, value in enumerate(values):
        # create the row with zeroed stats on first sight, then fill one column
        table.setdefault(idx2id[idx], {'mean': 0, 'sd': 0})[column] = value

def _write_summary_csv(path, table):
    """Write one ``ID, mean, sd`` row per entry of *table* to *path*.

    *table* maps an original ID to a dict with 'mean' and 'sd' keys. No header
    row is written.
    """
    # newline='' lets the csv writer control line endings; without it every
    # row is followed by an extra blank line on Windows.
    with open(path, 'w', newline='') as outfile:
        csv.writer(outfile).writerows(
            [key, stats['mean'], stats['sd']] for key, stats in table.items())


_write_summary_csv('thetas.csv', out_students)
_write_summary_csv('diffs.csv', out_items)

