import argparse
import torch
import os
from best.data_iterators import iter_best_files, iter_best_old_files
from best.model import model_handler
from best.test_ids import test_ids, valid_ids
from best.custom_logging import logging
from best.belief_model import belief_model_handler
from best.results_to_file import write_results


def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` cannot be used for this: argparse hands the raw string to the
    converter and ``bool("False")`` is ``True`` (any non-empty string is truthy),
    so ``--bidirectional False`` would silently enable bidirectionality.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean
            spelling (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


parser = argparse.ArgumentParser(description='train.py')
parser.add_argument('--data-root', default='data/ldc2016e114/data/eng',
                    help='path to LDC 2016 e114 English data')
parser.add_argument('--attitude', default='belief',
                    help='Choose from belief, sentiment')
parser.add_argument('--old-data-root', default=None,
                    help='path to LDC 2016 e27v2 English data')
parser.add_argument('--glove', default='data/glove.6B/',
                    help='path to pretrained glove word embeddings')
parser.add_argument('--write-results', default='results/ldc2016e114/neuralnetmodel/',
                    help='path to write results to for LDC 2016 e114 English data for neural net model')
# Embedding modes: None = no pretrained embeddings, projection = projection
# layers, direct = pretrained embeddings are learned directly.
parser.add_argument('--pretrained', default="doc_projection",
                    help='Choose from: None, doc_projection, projection, direct')
parser.add_argument('--n-epochs', type=int, default=6)
parser.add_argument('--embedding-dim', type=int, default=50)
parser.add_argument('--hidden-dim', type=int, default=40)
parser.add_argument('--learning-rate', type=float, default=0.07)
parser.add_argument('--momentum', type=float, default=0.0)
parser.add_argument('--decay', type=float, default=0.0001)
parser.add_argument('--num-layers', type=int, default=2)
parser.add_argument('--minibatch', type=int, default=10)
# A dedicated converter is required: type=bool maps every non-empty string,
# including "False", to True.
parser.add_argument('--bidirectional', type=_str2bool, default=True,
                    help='true or false')
parser.add_argument('--dropout', type=float, default=0.10,
                    help="Should be at most 0.5")
parser.add_argument('--rng-seed', type=int, default=1)
parser.add_argument('--attention', default="multilinear",
                    help='Choose from None, multilinear, mlp')
parser.add_argument('--multilinear-hyperparam', type=int, default=17)
parser.add_argument('--encode-relation', default="vector",
                    help="Choose from naive, vector, vectoraffine, affine")
parser.add_argument('--encode-event', default="vector",
                    help="Choose from naive, vector, affine")
parser.add_argument('--optim', default="Adam",
                    help="Choose from Adam, SGD, RMSProp")


def sum(x):
    """Add up ``x[key]`` for every key produced by iterating *x*.

    For a dict (e.g. a mapping of ids to counts) this is the total of its
    values; an empty input yields 0.

    NOTE: this definition shadows the builtin ``sum`` for the whole module.
    """
    running_total = 0
    for key in x:
        running_total = running_total + x[key]
    return running_total


if __name__ == '__main__':

    opt = parser.parse_args()

    # Each supported --attitude value maps to the handler that trains and
    # validates it; both handlers are invoked with the same arguments below.
    handlers = {
        'sentiment': model_handler,
        'belief': belief_model_handler,
    }
    # Validate up front instead of only after every document has been loaded.
    if opt.attitude not in handlers:
        raise ValueError('unknown --attitude {!r}; choose from: {}'.format(
            opt.attitude, ', '.join(sorted(handlers))))

    def _resolve_dir(path):
        """Return *path* if it is an existing directory, otherwise its parent.

        The parent fallback keeps the previous ``os.path.dirname`` behaviour
        (paths with a trailing separator, or pointing at a file), while an
        existing directory given without a trailing slash is no longer
        silently replaced by its parent directory.
        """
        if os.path.isdir(path):
            return os.path.normpath(path)
        return os.path.dirname(path)

    # Split the documents into train / validation / test by document id.
    train_docs = []
    valid_docs = []
    test_docs = []  # held out; not used further in this script
    for doc in iter_best_files(opt.data_root):
        if doc.doc_id in valid_ids:
            valid_docs.append(doc)
        elif doc.doc_id in test_ids:
            test_docs.append(doc)
        else:
            train_docs.append(doc)
    # Documents from the older data release (if given) only extend training.
    if opt.old_data_root:
        train_docs.extend(iter_best_old_files(opt.old_data_root))

    logging.info('# train docs: {}'.format(len(train_docs)))
    logging.info('# valid docs: {}'.format(len(valid_docs)))

    # argparse passes '--pretrained None' through as the *string* 'None',
    # which is truthy; honour the documented "no pretrained embeddings" mode.
    if opt.pretrained not in (None, '', 'None'):
        # Only the 50-dimensional GloVe file is loaded, so the embedding size
        # is forced to match it regardless of --embedding-dim.
        opt.embedding_dim = 50
        glove_path = os.path.join(_resolve_dir(opt.glove), 'glove.6B.50d.txt')
    else:
        glove_path = None
    write_results_dir = _resolve_dir(opt.write_results)
    file_name = '{}_events={}_relation={}_pretrain={}_attention={}'.format(opt.attitude,
                                                                           opt.encode_event,
                                                                           opt.encode_relation,
                                                                           opt.pretrained,
                                                                           opt.attention)
    torch.manual_seed(opt.rng_seed)
    validation_results = handlers[opt.attitude](train_docs,
                                                valid_docs,
                                                glove_path,
                                                embedding_dim=opt.embedding_dim,
                                                n_epochs=opt.n_epochs,
                                                hidden_dim=opt.hidden_dim,
                                                initial_learning_rate=opt.learning_rate,
                                                momentum=opt.momentum,
                                                weight_decay=opt.decay,
                                                num_layers=opt.num_layers,
                                                minibatch=opt.minibatch,
                                                dropout=opt.dropout,
                                                bidirectional=opt.bidirectional,
                                                pretrained=opt.pretrained,
                                                optimize=opt.optim,
                                                attention=opt.attention,
                                                encode_relation=opt.encode_relation,
                                                encode_event=opt.encode_event,
                                                multilinear_hyperparam=opt.multilinear_hyperparam)
    write_results(results_dir=write_results_dir,
                  file_name=file_name,
                  validation_results=validation_results,
                  opt=opt)
