from sklearn import metrics
import os
import sys
import time
import pickle
import random
import numpy as np
import tensorflow as tf
from collections import defaultdict
import json
import math
from input import DataInputTrainTarget, DataInputTrainSource, DataInputTest, DataInputTransferTarget, \
    DataInputTransferSource, DataInputTrainAttackerReal, DataInputTrainAttackerFakeSource, DataInputTestAttack
from model import PrivNet
import pprint


# Shared pretty-printer used below to dump the config and the best-metric tables.
pp = pprint.PrettyPrinter()

# Vanilla source network vs
# Adversarial source network

# Shell / environment setup notes left by the original author (not executed):
# setenv PYTHONUSERBASE /qydata/ghuac/home
# setenv CUDA_VISIBLE_DEVICES 1
# cd /qydata/ghuac/ml-1m/
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# gpu_options = tf.GPUOptions(allow_growth=True)

# Seed the Python, NumPy and TensorFlow random generators with the same value.
RANDOM_SEED = 2019
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
tf.set_random_seed(RANDOM_SEED)

# Run configuration. The whole dict is handed to PrivNet and to several data
# iterators; keys not read in this file are presumably consumed there
# (TODO confirm against model.py / input.py). 'num_classes_<attribute>' entries
# are added later, once the private-attribute maps have been loaded.
config = {
    'data_name_target': 'target',  # filename prefix of the target-domain data files
    'data_name_source': 'source',  # filename prefix of the source-domain data files
    'train_trans_batch_size': 128,  # target w/ source, and source w/ target
    'train_batch_size': 128,  # batch size passed to the target base-network iterator
    'train_attack_batch_size': 128,  # presumably read by the attacker iterators via config
    'test_batch_size': 4098,  # batch size passed to DataInputTest in _eval
    'test_batch_size_attack': 1024,  # presumably read by DataInputTestAttack via config
    'lr': 5e-4,
    'lr_attack': 1e-3,
    'hidden_units': 80,
    'clip_norm': 5,
    'clip_norm_attack': 10,
    'fc_layer': 64,
    'nepoch': 50,  # upper bound on training epochs (early stopping may end sooner)
    'ratio_train_tgt': 1.0,  # leading fraction of the target training set that is kept
    'data_user_gender_map': '../user_gender_map.data',
    'data_user_age_map': '../user_age_map.data',
    'data_user_occupation_map': '../user_occupation_map.data',
    'topKs': [1, 5, 10, 20, 35],  # cutoffs at which HR and NDCG are reported
    'adversary_weight': 1.0,  # presumably weights the fake-label loss term in the model
}

start_time = time.time()


def _load_user_latest_hist(path):
    """Load a JSON ``{user: history}`` file into a defaultdict keyed by int user id.

    JSON object keys are strings, so each key is converted with ``int``.
    Users absent from the file read back as an empty list.
    """
    latest = defaultdict(list)
    with open(path) as ifile:
        for raw_user, latest_history in json.load(ifile).items():
            latest[int(raw_user)] = latest_history  # [hist_latest, i_latest]
    return latest


# Source side: used to compute the target cold user rep when testing
user_latest_hist_src = _load_user_latest_hist(config['data_name_source'] + '.source.user_latest_hist.json')
print('#user_latest_hist_src = {}'.format(len(user_latest_hist_src)))  # 10743

# Target side: NOT used except for computing shared users since we don't evaluate on source
user_latest_hist_tgt = _load_user_latest_hist(config['data_name_target'] + '.target.user_latest_hist.json')
print('#user_latest_hist_tgt = {}'.format(len(user_latest_hist_tgt)))  # 9045

# NOTE: pickle can execute arbitrary code while loading; these files are
# assumed to be produced locally by the preprocessing step, never downloaded.
with open(config['data_name_target'] + '.target.prnet.pkl', 'rb') as f:
    # Four objects are stored back-to-back and must be read in this order:
    #   train records (user_id, hist, pos_list[i]/neg_train_list[i], 1/0),
    #   test records (user_id, pos, neg), max item id, max user id.
    train_set_tgt, test_set_tgt, max_item_id_tgt, max_user_id_tgt = [pickle.load(f) for _ in range(4)]
print('#tgt(train,test) = {},{}'.format(len(train_set_tgt), len(test_set_tgt)))  # 251432,1982200
print('max_item_id_tgt = {}'.format(max_item_id_tgt))
print('max_user_id_tgt = {}'.format(max_user_id_tgt))
item_count_tgt = max_item_id_tgt + 1

# !!!target domain is too 'dense' to significantly show the benefits of knowledge transfer, so we can sparsify it
_kept_tgt = int(len(train_set_tgt) * config['ratio_train_tgt'])
train_set_tgt = train_set_tgt[:_kept_tgt]
print('#train_set_tgt = {}'.format(len(train_set_tgt)))

with open(config['data_name_source'] + '.source.prnet.pkl', 'rb') as f:
    # Three objects: train records (same layout as the target ones),
    # max item id, max user id. The source file carries no test split.
    train_set_src, max_item_id_src, max_user_id_src = [pickle.load(f) for _ in range(3)]
print('#src(train) = {}'.format(len(train_set_src)))
print('max_item_id_src = {}'.format(max_item_id_src))
print('max_user_id_src = {}'.format(max_user_id_src))
item_count_src = max_item_id_src + 1
# One user-id space is sized to cover both domains.
max_user_id = max(max_user_id_src, max_user_id_tgt)
user_count = max_user_id + 1

# Users that have a latest-history record in both domains.
users_latest_src = set(user_latest_hist_src)
users_latest_tgt = set(user_latest_hist_tgt)
users_train_transfer = list(users_latest_src.intersection(users_latest_tgt))
print('#users_train_transfer = {}'.format(len(users_train_transfer)))
percent_users_transfer = 1.0  # control how much knowledge
print('percent_users_transfer = {}'.format(percent_users_transfer))

# Ensure that the test users are the same for recommendation and privacy:
# the attack test users are exactly the users appearing in the target test set.
test_set_users_attack = {record[0] for record in test_set_tgt}
print('#test_set_users_attack = {}'.format(len(test_set_users_attack)))  # 1098
test_set_users_attack_list = list(test_set_users_attack)

# Private attributes: attribute name -> {user_id: label}, and attribute name -> set of labels.
# NOTE: the inner defaultdict(int) means a user missing from a map file silently
# reads back as label 0 wherever the map is indexed.
user_private_map = defaultdict(lambda: defaultdict(int))
private_labels_set = defaultdict(set)
# A list (not a set literal) keeps the processing and log order stable across
# runs; string hashing is randomized per process, so set order is not.
for private in ['gender', 'age', 'occupation']:
    print('---{}---'.format(private))
    with open(config['data_user_{}_map'.format(private)]) as fin:
        for line in fin:
            # Whitespace-agnostic split: tolerates tabs and repeated spaces,
            # and yields an empty list for blank lines.
            fields = line.split()
            if not fields:
                continue
            user_id = int(fields[0])
            label = int(fields[1])
            user_private_map[private][user_id] = label
            private_labels_set[private].add(label)
    print('#user_{}_map = {}'.format(private, len(user_private_map[private])))
    # Number of distinct labels seen for this attribute; stored in the config
    # that is later passed to the model.
    config['num_classes_{}'.format(private)] = len(private_labels_set[private])
pp.pprint(config)
pp.pprint(private_labels_set)

print('loading done, begin training...[{:.2f}s]'.format(time.time()-start_time))


def _eval(sess, model, topKs, epoch=-1):
    """Evaluate recommendation quality on the target-domain test set.

    Every test positive is ranked against its own negatives (the asserts
    require exactly 1 positive and 99 negatives per group).

    Args:
        sess: session handed through to ``model.test``.
        model: object whose ``test(sess, batch)`` returns parallel sequences
            of users, items, scores and parents; a parent of -1 marks a
            positive, any other value is the positive a negative belongs to.
        topKs: cutoffs at which HR and NDCG are computed.
        epoch: unused; kept for call compatibility.

    Returns:
        ``(AUC, HRs, NDCGs, MRR)``. AUC is the fraction of positive/negative
        pairs where the positive scores strictly higher, pooled over all
        pairs. HRs and NDCGs are lists aligned with ``topKs``; they and MRR
        are averaged per user over that user's positives, then over users.
    """
    scored_pos = defaultdict(lambda: defaultdict(list))
    scored_neg = defaultdict(lambda: defaultdict(list))
    test_batches = DataInputTest(test_set_tgt, user_latest_hist_src, config['test_batch_size'])
    for _, uij_record in test_batches:
        batch_users, batch_items, batch_scores, batch_parents = model.test(sess, uij_record)
        for u, j, s, parent in zip(batch_users, batch_items, batch_scores, batch_parents):
            if parent != -1:
                # a negative, filed under the positive it was sampled for
                scored_neg[u][parent].append((j, s))
            else:
                # a positive
                scored_pos[u][j] = [(j, s)]  # 1:99
    assert len(scored_pos) == len(scored_neg)

    num_cutoffs = len(topKs)
    AUC = 0.0
    num_comparisons = 0
    HRs = [0.0] * num_cutoffs
    NDCGs = [0.0] * num_cutoffs
    MRR = 0.0
    for user, user_pos in scored_pos.items():
        user_neg = scored_neg[user]
        assert len(user_pos) == len(user_neg)
        user_hits = [0.0] * num_cutoffs
        user_gain = [0.0] * num_cutoffs
        user_rr = 0.0
        for pos, pos_entry in user_pos.items():
            neg_entries = user_neg[pos]
            assert len(pos_entry) == 1
            assert len(neg_entries) == 99
            # AUC: count the negatives that the positive strictly outscores
            s_pos = pos_entry[0][1]
            AUC += sum(1 for _, s_neg in neg_entries if s_pos > s_neg)
            num_comparisons += len(neg_entries)
            # HR, NDCG, MRR: rank of the positive among all candidates.
            # The positive is listed first and the sort is stable, so it
            # wins ties against negatives.
            candidates = pos_entry + neg_entries
            assert len(candidates) == 100
            ranked = sorted(candidates, key=lambda pair: -pair[1])  # high to low
            pred_rank_gt = next(
                (rank for rank, (item, _) in enumerate(ranked) if item == pos),
                len(ranked),
            )  # the best rank is 0
            for ind, topK in enumerate(topKs):
                if pred_rank_gt < topK:
                    user_hits[ind] += 1
                    user_gain[ind] += math.log(2) / math.log(pred_rank_gt + 2)
            user_rr += 1.0 / (pred_rank_gt + 1)
        # normalize over num_of_positives for a user
        num_pos = len(user_pos)
        for ind in range(num_cutoffs):
            HRs[ind] += user_hits[ind] / num_pos
            NDCGs[ind] += user_gain[ind] / num_pos
        MRR += user_rr / num_pos
    AUC /= num_comparisons
    # normalize over users
    num_users = len(scored_pos)
    HRs = [total / num_users for total in HRs]
    NDCGs = [total / num_users for total in NDCGs]
    MRR /= num_users
    return AUC, HRs, NDCGs, MRR


# True until the first attack evaluation has printed the per-label user counts.
is_private_label_users = True


def _eval_attack(sess, model, epoch=-1):
    """Evaluate how well the attacker head predicts the test users' private attributes.

    Args:
        sess: session handed through to ``model.test_attack``.
        model: object whose ``test_attack(sess, batch)`` returns parallel
            sequences of users and per-attribute score vectors for gender,
            age and occupation.
        epoch: unused; kept for call compatibility.

    Returns:
        Nested defaultdict ``report[attribute]`` holding, for each of
        'micro' / 'macro' / 'weighted', a dict with 'precision', 'recall'
        and 'f1', plus an 'accuracy' entry.

    Side effects:
        On the first call only, pretty-prints how many test users carry each
        true label and then clears the module-level ``is_private_label_users``.
    """
    global is_private_label_users
    attributes = ('gender', 'age', 'occupation')
    true_predict_results = {name: {'y_true': [], 'y_pred': []} for name in attributes}
    if is_private_label_users:
        private_label_users = {name: defaultdict(int) for name in attributes}
    private_accuracy = defaultdict(int)
    attack_batches = DataInputTestAttack(test_set_users_attack_list, user_latest_hist_src, user_private_map, config)
    for _, uij_record in attack_batches:
        u_src, score_gender, score_age, score_occupation = model.test_attack(sess, uij_record)
        for scored_user in zip(u_src, score_gender, score_age, score_occupation):
            u = scored_user[0]
            # scored_user[1:] is aligned with `attributes`
            for private, scores in zip(attributes, scored_user[1:]):
                truth = user_private_map[private][u]
                pred = np.argmax(scores)  # predicted label = index of the highest score
                collected = true_predict_results[private]
                collected['y_true'].append(truth)
                collected['y_pred'].append(pred)
                private_accuracy[private] += int(truth == pred)
                if is_private_label_users:
                    private_label_users[private][truth] += 1
    if is_private_label_users:
        pp.pprint(private_label_users)
        is_private_label_users = False
    assert all(
        len(true_predict_results[name]['y_true']) == len(true_predict_results[name]['y_pred'])
        for name in attributes
    )
    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html
    report_gender_age_occupation = defaultdict(lambda: defaultdict(object))
    for private in attributes:
        y_true = true_predict_results[private]['y_true']
        y_pred = true_predict_results[private]['y_pred']
        for average in ('micro', 'macro', 'weighted'):
            precision, recall, f1, _ = metrics.precision_recall_fscore_support(y_true, y_pred, average=average)
            report_gender_age_occupation[private][average] = {
                'precision': precision,
                'recall': recall,
                'f1': f1,
            }
        report_gender_age_occupation[private]['accuracy'] = private_accuracy[private] / len(y_true)

    return report_gender_age_occupation


length_topKs = len(config['topKs'])
# Best value seen so far for each recommendation metric and the epoch that
# produced it (-1 = not set yet). HR / NDCG entries are lists aligned with
# config['topKs'].
recommendation_metric_best = dict(
    AUC_best=0.0,
    HRs_best=[0.0 for _ in range(length_topKs)],
    NDCGs_best=[0.0 for _ in range(length_topKs)],
    MRR_best=0.0,
    AUC_best_epoch=-1,
    HRs_best_epoch=[-1 for _ in range(length_topKs)],
    NDCGs_best_epoch=[-1 for _ in range(length_topKs)],
    MRR_best_epoch=-1,
)

length_metric = 2  # [value,epoch]


def _blank_private_entry():
    """Return one attribute's best-so-far table; every slot is its own [value, epoch] list."""
    entry = {
        average: {metric: [0.0] * length_metric for metric in ('precision', 'recall', 'f1')}
        for average in ('micro', 'macro', 'weighted')
    }
    entry['accuracy'] = [0.0] * length_metric
    return entry


# Best attacker metrics seen so far, per private attribute. Each leaf is a
# [value, epoch] pair that the training loop replaces when a higher value appears.
gender_age_occupation_metric_best = {
    private_name: _blank_private_entry() for private_name in ('gender', 'age', 'occupation')
}


# Main training session. Each epoch runs, in order:
#   1) source network on fake privacy labels (adversarial recommender step),
#   2) attacker on real privacy labels,
#   3) source base network,
#   4) privacy evaluation on the test users,
#   5) target network with source knowledge (transfer),
#   6) target base network,
#   7) recommendation evaluation, with early stopping on MRR.
with tf.Session() as sess:
    model = PrivNet(user_count,item_count_tgt,item_count_src,config)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    start_time = time.time()  # restart the clock: reported times exclude data loading
    accumulate_best_mrr_epoch = 0  # epochs since MRR last improved
    accumulate_best_accuracy_epoch = 0  # epochs since any attacker accuracy last improved
    for epoch in range(config['nepoch']):
        # Rebuild the attacker training sets every epoch so that each record
        # gets freshly drawn fake labels.
        train_set_attack_fake_source = []
        train_set_attack_real_source = []
        for t in train_set_src:
            user = t[0]
            if user in test_set_users_attack:  # test users cannot occur in the training, cold-start users
                continue
            t_fake = [e for e in t]
            t_real = [e for e in t]
            for private in ['gender', 'age', 'occupation']:  # must be a list to reserve the order!
                real_label = user_private_map[private][user]
                fake_labels = [i for i in private_labels_set[private] if i != real_label]
                # shuffle-then-take-first is kept as is so the seeded random
                # stream is consumed exactly as before
                random.shuffle(fake_labels)
                fake_label = fake_labels[0]
                t_fake.append(fake_label)
                t_real.append(real_label)
            train_set_attack_fake_source.append(t_fake)
            train_set_attack_real_source.append(t_real)
        random.shuffle(train_set_attack_fake_source)
        random.shuffle(train_set_attack_real_source)
        if epoch == 0:
            print('#train_set_attack_fake_source = {}'.format(len(train_set_attack_fake_source)))
            print('#train_set_attack_real_source = {}'.format(len(train_set_attack_real_source)))
        # Adversarial recommender: Privacy attacker simulation, fool the adversary, Adversary src net, MTL learning
        # 1) fitting user-item, 2) training on the fake privacy labels
        # self.loss_attack_src + self.adversary_weight * self.loss_attack_fake
        # DO NOT update the attacker's parameters!!!
        loss_attack_fake_src_sum = 0.0
        loss_attack_src_sum = 0.0
        for batch, uij_words in DataInputTrainAttackerFakeSource(train_set_attack_fake_source, config):
            loss_attack_fake_src, loss_attack_src = model.train_attack_fake_source(sess, uij_words)
            loss_attack_fake_src_sum += loss_attack_fake_src
            loss_attack_src_sum += loss_attack_src
        print('Epoch={}, loss_attack_fake_src={:.6f}, loss_attack_src={:.6f}, [{:.2f}s]'.format(epoch,
                loss_attack_fake_src_sum/len(train_set_attack_fake_source),
                loss_attack_src_sum/len(train_set_attack_fake_source), time.time()-start_time))
        # adversary attacker
        loss_attack_real_sum = 0.0  # reduce_mean(privates)
        loss_gender_real_sum = 0.0
        loss_age_real_sum = 0.0
        loss_occupation_real_sum = 0.0
        # NOTE(review): DataInputTrainAttackerReal is imported but this loop
        # reuses the fake-source iterator; the two record layouts built above
        # are identical, so this may be intentional -- confirm against input.py.
        for batch, uij_words in DataInputTrainAttackerFakeSource(train_set_attack_real_source, config):
            loss_attack_real, loss_gender_real, loss_age_real, loss_occupation_real = model.train_attack_real_source(sess, uij_words)
            # Each returned loss goes into its own accumulator.
            loss_attack_real_sum += loss_attack_real
            loss_gender_real_sum += loss_gender_real
            loss_age_real_sum += loss_age_real
            loss_occupation_real_sum += loss_occupation_real
        print('Epoch={}, loss_attack_real={:.6f}, gender={:.6f}, age={:.6f}, occupation={:.6f}, [{:.2f}s]'.format(epoch,
                loss_attack_real_sum/len(train_set_attack_real_source),
                loss_gender_real_sum/len(train_set_attack_real_source),
                loss_age_real_sum/len(train_set_attack_real_source),
                loss_occupation_real_sum/len(train_set_attack_real_source), time.time()-start_time))

        # must train source network since the test users are cold in the target domain
        # Source base network
        random.shuffle(train_set_src)
        loss_src_sum = 0.0
        for batch, uij_words in DataInputTrainSource(train_set_src, config):
            loss_src = model.train_src(sess, uij_words)
            loss_src_sum += loss_src
        print('Epoch={}, src_loss={:.6f}, [{:.2f}s]'.format(epoch, loss_src_sum/len(train_set_src),  time.time()-start_time))
        if epoch == 0:
            print('#batches_src = {}'.format(len(train_set_src) // config['train_batch_size']))  # 975

        # Evaluation: privacy protection
        # Track, per attribute and metric, the highest attacker score seen and its epoch.
        report_gender_age_occupation = _eval_attack(sess, model)
        for private in ['gender', 'age', 'occupation']:
            accuracy = report_gender_age_occupation[private]['accuracy']
            if gender_age_occupation_metric_best[private]['accuracy'][0] < accuracy:
                gender_age_occupation_metric_best[private]['accuracy'] = [accuracy, epoch]
                accumulate_best_accuracy_epoch = 0
            print('---{}: accuracy={:.4f}---'.format(private, accuracy))
            for average in ['micro', 'macro', 'weighted']:
                report = report_gender_age_occupation[private][average]
                for metric in ['precision', 'recall', 'f1']:
                    if gender_age_occupation_metric_best[private][average][metric][0] < report[metric]:
                        gender_age_occupation_metric_best[private][average][metric] = [report[metric], epoch]
                print('{}: {:.4f},{:.4f},{:.4f}'.format(average, report['precision'], report['recall'], report['f1']))

        '''
        # Transfer: train source with knowledge from target
        # This is NOT realistic, since: 1) target is sparse; 2) target do not share data to source
        users_train_transfer = list(users_latest_src.intersection(users_latest_tgt))
        random.shuffle(users_train_transfer)
        users_train_transfer = users_train_transfer[: int(percent_users_transfer * len(users_train_transfer))]
        users_train_transfer = set(users_train_transfer)
        train_set_transfer_src = []
        for t in train_set_src:  # user_id, hist, pos_list[i]/neg_list[i], 1/0
            user = t[0]
            if user not in users_train_transfer:
                continue
            hist_i = t[1]
            item = t[2]
            label = t[3]
            history_tgt = user_latest_hist_tgt[user]  # [hist_latest, i_latest]
            hist_i_tgt = history_tgt[0]
            i_tgt = history_tgt[1]
            train_set_transfer_src.append([user,hist_i,item,label, hist_i_tgt,i_tgt])
        if epoch == 0:
            print('#train_set_transfer_src = {}'.format(len(train_set_transfer_src)))  # 175957
            print('#batches_transfer_src = {}'.format(len(train_set_transfer_src) // config['train_trans_batch_size']))
        loss_sum = 0.0
        for batch, uij_words in DataInputTransferSource(train_set_transfer_src, config['train_trans_batch_size']):
            loss = model.train_transfer_src_joint(sess, uij_words)
            loss_sum += loss
        print('Epoch={}, trans_loss_src={:.6f}, [{:.2f}s]'.format(epoch, loss_sum/len(train_set_transfer_src), time.time()-start_time))
        '''

        # Transfer: train target with knowledge from source
        # Only users present in both domains take part; percent_users_transfer
        # keeps a random leading fraction of them.
        users_train_transfer = list(users_latest_src.intersection(users_latest_tgt))
        random.shuffle(users_train_transfer)
        users_train_transfer = users_train_transfer[: int(percent_users_transfer * len(users_train_transfer))]
        users_train_transfer = set(users_train_transfer)
        train_set_transfer_tgt = []
        for t in train_set_tgt:  # user_id, hist, pos_list[i]/neg_list[i], 1/0
            user = t[0]
            if user not in users_train_transfer:
                continue
            hist_i = t[1]
            item = t[2]
            label = t[3]
            history_src = user_latest_hist_src[user]  # [hist_latest, i_latest]
            hist_i_src = history_src[0]
            i_src = history_src[1]
            train_set_transfer_tgt.append([user,hist_i,item,label, hist_i_src,i_src])
        if epoch == 0:
            print('#train_set_transfer_tgt = {}'.format(len(train_set_transfer_tgt)))  # 175957
            print('#batches_transfer_tgt = {}'.format(len(train_set_transfer_tgt) // config['train_trans_batch_size']))
        loss_sum = 0.0
        for batch, uij_words in DataInputTransferTarget(train_set_transfer_tgt, config['train_trans_batch_size']):
            loss = model.train_transfer_tgt_joint(sess, uij_words)
            loss_sum += loss
        print('Epoch={}, trans_loss_tgt={:.6f}, [{:.2f}s]'.format(epoch, loss_sum/len(train_set_transfer_tgt), time.time()-start_time))

        # Target base network
        random.shuffle(train_set_tgt)
        loss_sum = 0.0
        for batch, uij_words in DataInputTrainTarget(train_set_tgt, config['train_batch_size']):
            loss = model.train_tgt(sess, uij_words)
            loss_sum += loss
        print('Epoch={}, tgt_loss={:.6f}, [{:.2f}s]'.format(epoch, loss_sum/len(train_set_tgt), time.time()-start_time))
        if epoch == 0:
            print('#batches_tgt = {}'.format(len(train_set_tgt) // config['train_batch_size']))  # 982

        # Evaluation: recommendation quality
        # Each metric keeps its own best value and the epoch it was reached.
        AUC, HRs, NDCGs, MRR = _eval(sess, model, config['topKs'], epoch)
        if recommendation_metric_best['AUC_best'] < AUC:
            recommendation_metric_best['AUC_best'] = AUC
            recommendation_metric_best['AUC_best_epoch'] = epoch
        for ind in range(length_topKs):
            if recommendation_metric_best['HRs_best'][ind] < HRs[ind]:
                recommendation_metric_best['HRs_best'][ind] = HRs[ind]
                recommendation_metric_best['HRs_best_epoch'][ind] = epoch
            if recommendation_metric_best['NDCGs_best'][ind] < NDCGs[ind]:
                recommendation_metric_best['NDCGs_best'][ind] = NDCGs[ind]
                recommendation_metric_best['NDCGs_best_epoch'][ind] = epoch
        if recommendation_metric_best['MRR_best'] < MRR:
            recommendation_metric_best['MRR_best'] = MRR
            recommendation_metric_best['MRR_best_epoch'] = epoch
            accumulate_best_mrr_epoch = 0  # reset
        print('Epoch={}: (AUC,MRR) = {:.4f}, {:.4f}, [{:.2f}s]'.format(epoch, AUC, MRR, time.time()-start_time))
        for ind, topK in enumerate(config['topKs']):
            print('HR={:.4f}, NDCG={:.4f} at top-{}'.format(HRs[ind], NDCGs[ind], topK))
        sys.stdout.flush()

        accumulate_best_mrr_epoch += 1
        accumulate_best_accuracy_epoch += 1
        # if accumulate_best_mrr_epoch > 5 and accumulate_best_accuracy_epoch > 5:  # early stopping
        # Stop once MRR has failed to improve for more than 5 consecutive epochs.
        if accumulate_best_mrr_epoch > 5:  # early stopping
            print('early stopping...')
            break

        # Runs in the default session opened by the enclosing `with` block.
        model.global_epoch_increment.eval()

    pp.pprint(recommendation_metric_best)
    pp.pprint(gender_age_occupation_metric_best)

    sys.stdout.flush()

print('---{:.2f}min---'.format((time.time()-start_time) / 60))

'''
{'AUC_best': 0.952477818753791,
 'AUC_best_epoch': 1,
 'HRs_best': [0.5680234048600671,
              0.7587802487497468,
              0.8597316587746594,
              0.9415027609379591,
              0.9830705939297468],
 'HRs_best_epoch': [5, 1, 1, 3, 2],
 'MRR_best': 0.657087884577244,
 'MRR_best_epoch': 5,
 'NDCGs_best': [0.5680234048600671,
                0.6662497213513103,
                0.6989336658686313,
                0.719540002771952,
                0.7283576712877663],
 'NDCGs_best_epoch': [5, 1, 1, 1, 1]}
{'age': {'accuracy': [0.19489981785063754, 0],
         'macro': {'f1': [0.17196660494990748, 10],
                   'precision': [0.19089141289469178, 27],
                   'recall': [0.30208781870123724, 0]},
         'micro': {'f1': [0.19489981785063754, 0],
                   'precision': [0.19489981785063754, 0],
                   'recall': [0.19489981785063754, 0]},
         'weighted': {'f1': [0.1475355493493391, 10],
                      'precision': [0.19618889601308825, 27],
                      'recall': [0.19489981785063754, 0]}},
 'gender': {'accuracy': [0.28688524590163933, 23],
            'macro': {'f1': [0.2837191476953198, 23],
                      'precision': [0.3758146943985206, 23],
                      'recall': [0.41015092016373256, 0]},
            'micro': {'f1': [0.28688524590163933, 23],
                      'precision': [0.28688524590163933, 23],
                      'recall': [0.28688524590163933, 23]},
            'weighted': {'f1': [0.2738437393264303, 13],
                         'precision': [0.44445711245705083, 23],
                         'recall': [0.28688524590163933, 23]}},
 'occupation': {'accuracy': [0.08105646630236794, 14],
                'macro': {'f1': [0.01587151094525699, 22],
                          'precision': [0.020787153938630194, 22],
                          'recall': [0.061400519473110204, 24]},
                'micro': {'f1': [0.08105646630236794, 14],
                          'precision': [0.08105646630236794, 14],
                          'recall': [0.08105646630236794, 14]},
                'weighted': {'f1': [0.020777401328269215, 14],
                             'precision': [0.031165043354378145, 22],
                             'recall': [0.08105646630236794, 14]}}}

'''
