import argparse
import torch
import os
from best.data_iterators import iter_best_files, iter_best_old_files
from best.model import model_handler
from best.test_ids import test_ids, valid_ids
from best.custom_logging import logging
from best.belief_model import belief_model_handler
from best.results_to_file import write_results
from best.mention_statistics import mention_statistics
from pathlib import Path 


def _str2bool(value):
    """Convert a command-line string into a real boolean.

    argparse's ``type=bool`` does not work for flags: it calls
    ``bool('False')``, which is ``True`` because the string is non-empty.
    This converter accepts the usual spellings instead.

    Args:
        value: the raw command-line string (or an already-parsed bool).

    Returns:
        True for true/t/yes/y/1, False for false/f/no/n/0 or the empty
        string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised
            boolean spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean (true/false), got {!r}'.format(value))


# Command-line options for the script. Every option has a default, so the
# script can be launched without arguments.
parser = argparse.ArgumentParser(description='train.py')

# --- data and output locations ---
parser.add_argument('--data-root', default='data/ldc2016e114/data/eng',
                    help='path to LDC 2016 e114 English data')
parser.add_argument('--attitude', default='sentiment',
                    help='Choose from belief, sentiment')
parser.add_argument('--old-data-root', default=None,
                    help='path to LDC 2016 e27v2 English data')
parser.add_argument('--glove', default='data',
                    help='path to pretrained glove word embeddings')
parser.add_argument('--elmo', default='data',
                    help='path to pretrained elmo embeddings')
parser.add_argument('--write-results', default='results/ldc2016e114/neuralnetmodel/',
                    help='path to write results to for LDC 2016 e114 English data for neural net model')

# --- model and optimisation hyper-parameters ---
# None = no pretrained embeddings, projection = projection layers,
# direct = pretrained embeddings learned directly.
parser.add_argument('--pretrained', default="frozen",
                    help='Choose from: None, frozen, doc_projection, projection, direct')
parser.add_argument('--n-epochs', type=int, default=100)
parser.add_argument('--embedding-dim', type=int, default=3372)
parser.add_argument('--hidden-dim', type=int, default=128)
parser.add_argument('--learning-rate', type=float, default=0.10)
parser.add_argument('--momentum', type=float, default=0.0)
parser.add_argument('--decay', type=float, default=0.0001)
parser.add_argument('--num-layers', type=int, default=2)
parser.add_argument('--batch', type=int, default=10)
# Boolean options take an explicit value, e.g. ``--bidirectional false``.
parser.add_argument('--batch_mentions', type=_str2bool, default=False,
                    help='Boolean: true or false')
parser.add_argument('--bidirectional', type=_str2bool, default=True,
                    help='Boolean: true or false')
parser.add_argument('--dropout', type=float, default=0.20,
                    help="Should be at most 0.5")
parser.add_argument('--rng-seed', type=int, default=1)
parser.add_argument('--attention', default="interiormultilinear",
                    help='Choose from None, multiplicative, multilinear, interiormultilinear, mlp')
parser.add_argument('--parameterization', default="classify",
                    help='Choose from classify, rank')
parser.add_argument('--attention-hyperparam', type=int, default=32)
parser.add_argument('--encode-relation', default="naive",
                    help="Choose from naive, vector, vectoraffine, affine")
parser.add_argument('--encode-event', default="naive",
                    help="Choose from naive, vector, affine")
parser.add_argument('--optim', default="Adam",
                    help="Choose from Adam, SGD, RMSProp")


if __name__ == '__main__':
    # Script entry point: parse the options, split the corpus into
    # train/valid/test by document id, report mention statistics and
    # (beyond the early exit below) train the selected attitude model and
    # write its validation results.

    opt = parser.parse_args()

    # Documents whose id is listed in valid_ids / test_ids are held out;
    # every other document is training data.
    train_docs = []
    valid_docs = []
    test_docs = []
    for doc in iter_best_files(opt.data_root):
        if doc.doc_id in valid_ids:
            valid_docs.append(doc)
        elif doc.doc_id in test_ids:
            test_docs.append(doc)
        else:
            train_docs.append(doc)

    mention_statistics(train_docs)
    # NOTE(review): the script terminates here with status 0, so nothing
    # below this line currently runs. This looks like a leftover from a
    # statistics-only run; delete the next statement to re-enable training.
    # SystemExit is raised directly because the ``exit()`` helper comes from
    # the ``site`` module and is meant for the interactive shell.
    raise SystemExit(0)

    # For Final Submission - SPNLP 2019: results are reported on the test
    # split, so the validation split collected above is dropped.
    valid_docs = test_docs
    if opt.old_data_root:
        train_docs.extend(iter_best_old_files(opt.old_data_root))

    logging.info('# train docs: {}'.format(len(train_docs)))
    logging.info('# valid docs: {}'.format(len(valid_docs)))

    # The command line can only deliver the string "None"; map it to the
    # real None so the truthiness test below works.
    if opt.pretrained == "None":
        opt.pretrained = None

    if opt.pretrained:
        glove_path = Path(opt.glove)
        if opt.embedding_dim in {50, 3122}:
            # These two dims select the 50-d GloVe 6B vectors.
            glove_path = glove_path / 'glove.6B'
            glove_file = 'glove.6B.50d.txt'
        else:
            # Every other dim selects the 300-d GloVe 840B vectors; dims
            # other than 300 and 3072 are forced to 3372.
            if opt.embedding_dim not in {300, 3072, 3372}:
                logging.info('embedding dim {} is not supported; using 3372'
                             .format(opt.embedding_dim))
                opt.embedding_dim = 3372
            glove_path = glove_path / 'glove.840B'
            glove_file = 'glove.840B.300d.txt'
    else:
        glove_path = None
        # Must be bound even without pretrained embeddings: the sentiment
        # handler call below passes it positionally and would otherwise
        # fail with a NameError.
        glove_file = None

    # dirname() keeps everything before the last '/', so the default value
    # (which ends in '/') yields the directory itself.
    # NOTE(review): a value without a trailing slash loses its last path
    # component here -- confirm that is intended.
    write_results_dir = os.path.dirname(opt.write_results)
    file_name = '{}_events={}_relation={}_pretrain={}_attention={}'.format(
        opt.attitude,
        opt.encode_event,
        opt.encode_relation,
        opt.pretrained,
        opt.attention)

    # Keyword arguments passed unchanged to both model handlers.
    handler_kwargs = dict(
        embedding_dim=opt.embedding_dim,
        n_epochs=opt.n_epochs,
        hidden_dim=opt.hidden_dim,
        initial_learning_rate=opt.learning_rate,
        momentum=opt.momentum,
        weight_decay=opt.decay,
        num_layers=opt.num_layers,
        batch=opt.batch,
        dropout=opt.dropout,
        bidirectional=opt.bidirectional,
        pretrained=opt.pretrained,
        optimize=opt.optim,
        attention=opt.attention,
        encode_relation=opt.encode_relation,
        encode_event=opt.encode_event,
        attention_hyperparam=opt.attention_hyperparam,
        parameterization=opt.parameterization,
        batch_mentions=opt.batch_mentions,
    )

    torch.manual_seed(opt.rng_seed)
    if opt.attitude == 'sentiment':
        validation_results = model_handler(train_docs,
                                           valid_docs,
                                           glove_path,
                                           glove_file,
                                           **handler_kwargs)
    elif opt.attitude == 'belief':
        # The belief handler is called without the GloVe file name.
        validation_results = belief_model_handler(train_docs,
                                                  valid_docs,
                                                  glove_path,
                                                  **handler_kwargs)
    else:
        raise ValueError(
            "unknown --attitude {!r}; choose from 'belief', 'sentiment'"
            .format(opt.attitude))

    write_results(results_dir=write_results_dir,
                  file_name=file_name,
                  validation_results=validation_results,
                  opt=opt)
