# coding: utf-8
import argparse
import math
import os

import torch
from torch import nn

from misc import setup_seed, compute_P_R_F1
from model import GraphConvModel, HFetClassifier, Encoder, HFet
from dataset import GraphDataset, TestGraphDataset, DataLoader, HFetSentenceDataset
from step import train_step, test_step

# Device every module in this script is moved to: the GPU when CUDA is
# available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_args(argv=None):
    """Parse the command-line options of this training / evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so calling
            ``get_args()`` with no argument behaves exactly as before.

    Returns:
        argparse.Namespace: the parsed options. Note that three options are
        stored under a different attribute name than their flag:
        ``--hidden`` -> ``n_hidden``, ``--neighbors`` -> ``n_neighbor`` and
        ``--layers`` -> ``n_graphconv``.
    """
    parser = argparse.ArgumentParser()

    # Run identification and data selection.
    parser.add_argument('--name', default='save',
                        help="prefix of the checkpoint files written under 'save/'")
    parser.add_argument('--sample_encoder', default='hfet',
                        help="sample encoder name; the script compares it against 'hfet'")
    parser.add_argument('--dataset', default='ontonotes',
                        help="dataset to use: 'ontonotes' or 'BBN'")
    # The default is the bool False (argparse does not run non-string defaults
    # through `type`); any value given on the command line is parsed as float.
    # The script only tests this value for truthiness when building the
    # classifier.
    parser.add_argument('--hierlossnorm', default=False, type=float,
                        help='non-zero value passes the label ontology file to the classifier')

    # Optimisation and model size.
    parser.add_argument('--batchsize', default=160, type=int,
                        help='batch size of the train and test loaders')
    parser.add_argument('--lr', default=5e-5, type=float,
                        help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.001,
                        help='weight decay passed to the optimizer')
    parser.add_argument('--hidden', dest='n_hidden', default=400, type=int,
                        help='hidden feature size of the graph convolution model')
    parser.add_argument('--neighbors', dest='n_neighbor', default=5, type=int,
                        help='number of sentence neighbours sampled per node')
    parser.add_argument('--layers', dest='n_graphconv', default=1, type=int,
                        help='number of graph convolution layers')

    parser.add_argument('--max_epochs', default=100, type=int,
                        help='number of training epochs')

    parser.add_argument('--inter_keepprob', default=0.3, type=float,
                        help='interaction keep probability of the training graph dataset')
    # Not read in this file; the whole namespace is handed to train_step.
    parser.add_argument('--neg_weight', default=0.5, type=float)

    # Evaluation / checkpoint control.
    parser.add_argument('--test', action='store_true', default=False,
                        help='evaluate on the test set once and exit')
    # Left untyped on purpose: the default 0 means "do not restore"; a value
    # given on the command line stays a string and becomes the checkpoint
    # suffix in 'save/<name>_<restore>.pth'.
    parser.add_argument('--restore', default=0,
                        help='suffix of the checkpoint to restore before running')
    parser.add_argument('--testper', default=200, type=int,
                        help='evaluate on the test set every this many training batches')

    # Both are forwarded unchanged to the HFet sample encoder.
    parser.add_argument('--elmo_option')
    parser.add_argument('--elmo_weight')

    return parser.parse_args(argv)



def _evaluate(encoder, model, classifier, loader, eval_sent_dataset,
              train_sent_dataset, n_lbls, hfet, threshold=0.5):
    """Run one pass over ``loader`` and return the aggregated metrics.

    The encoder is pointed at ``eval_sent_dataset`` for the duration of the
    pass and switched back to ``train_sent_dataset`` afterwards, even if a
    batch raises.

    Returns:
        Whatever ``compute_P_R_F1`` returns for the summed per-batch counts.
    """
    hit, positive, true = 0, 0, 0
    mP, mR, mF1, N = 0, 0, 0, 0

    # NOTE(review): this helper does not touch train/eval mode or autograd;
    # presumably test_step takes care of that — confirm.
    encoder.sentence_dataset = eval_sent_dataset
    try:
        for step, batch_data in enumerate(loader):
            print(step, end='\r')
            hit_i, positive_i, true_i, mP_i, mR_i, mF1_i, n = test_step(
                batch_data, encoder, model, classifier, n_lbls,
                threshold=threshold, hfet=hfet)
            hit, positive, true = hit + hit_i, positive + positive_i, true + true_i
            mP, mR, mF1, N = mP + mP_i, mR + mR_i, mF1 + mF1_i, N + n
    finally:
        encoder.sentence_dataset = train_sent_dataset

    return compute_P_R_F1(hit, positive, true, mP, mR, mF1, N)


def _save_checkpoint(encoder, model, classifier, name, tag):
    """Write the three state dicts to ``save/<name>_<tag>.pth``."""
    # Create the directory up front so a missing 'save/' cannot abort a
    # long training run at the first checkpoint.
    os.makedirs('save', exist_ok=True)
    torch.save({'encoder': encoder.state_dict(),
                'classifier': classifier.state_dict(),
                'model': model.state_dict()},
               'save/{}_{}.pth'.format(name, tag))


if __name__ == '__main__':
    args = get_args()

    # set dimensions
    label_dim = 2048
    sample_dim = 2048
    n_hidden = args.n_hidden
    n_neighbor = args.n_neighbor
    n_graphconv = args.n_graphconv

    if args.dataset == 'ontonotes':
        n_lbls = 89
        lbl2id_fn = 'data/ontology/onto_ontology.txt'
    elif args.dataset == 'BBN':
        n_lbls = 47
        lbl2id_fn = 'data/BBN/hierarchy.txt'
    else:
        # Without this branch an unknown dataset surfaced as a NameError on
        # n_lbls a few lines below.
        raise ValueError(
            "Unsupported dataset {!r}: expected 'ontonotes' or 'BBN'".format(args.dataset))

    print('# Labels:', n_lbls)

    # load data
    train_graph_dataset = GraphDataset(root='.', sample_n={'sent': n_neighbor}, partial_n=None,
                                       interaction_keep_prob=args.inter_keepprob)
    train_loader = DataLoader(train_graph_dataset, batch_size=args.batchsize, shuffle=True)
    # test on testset
    test_graph_dataset = TestGraphDataset(root='.', sample_n={'sent': n_neighbor},
                                          graph_info_fn='tmp/graph_info_test.pkl')
    test_on_test_loader = DataLoader(test_graph_dataset, batch_size=args.batchsize, shuffle=False)

    # To train on the augmented OntoNotes data instead, point both 'train'
    # paths below at 'data/ontonotes/augmented_train.json'.
    train_sent_path = 'data/{}/g_train.json'.format(args.dataset)
    test_sent_path = 'data/{}/g_test.json'.format(args.dataset)
    train_sent_dataset = HFetSentenceDataset(train_sent_path)
    test_sent_dataset = HFetSentenceDataset({'train': train_sent_path,
                                             'test': test_sent_path})

    print('train dataset len: {}, test dataset len: {}'.format(len(train_graph_dataset), len(test_graph_dataset)))
    print('train sents: {}, test sents: {}'.format(len(train_sent_dataset), len(test_sent_dataset)))

    # initialize model parameters
    lbl_embeds = nn.Embedding(n_lbls, label_dim).to(device)
    torch.nn.init.kaiming_uniform_(lbl_embeds.weight, a=math.sqrt(5))

    print('using hfet encoder')
    sample_encoder = HFet(n_lbls, args.elmo_option, args.elmo_weight, elmo_dropout=.5,
                          repr_dropout=.2, dist_dropout=.2, latent_size=0)

    encoder = Encoder(n_lbls, sample_encoder, train_sent_dataset, lbl_embeds).to(device)
    model = GraphConvModel(in_features=sample_dim, hidden_features=n_hidden, n_layers=n_graphconv).to(device)

    print('using hfet classifier')
    # `a and b` is kept deliberately: the classifier receives the falsy
    # hierlossnorm value itself when the option is off, and the ontology
    # file name when it is on.
    classifier = HFetClassifier(sample_dim=sample_dim, label_dim=label_dim, n_lbls=n_lbls,
                                hierlossnorm_ontology=args.hierlossnorm and lbl2id_fn).to(device)

    # restore from file
    if args.restore:
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only machine (and vice versa).
        load_dict = torch.load('save/{}_{}.pth'.format(args.name, args.restore),
                               map_location=device)
        model.load_state_dict(load_dict['model'])
        classifier.load_state_dict(load_dict['classifier'])
        encoder.load_state_dict(load_dict['encoder'])

    ishfet = args.sample_encoder == 'hfet'

    # Test-only mode: evaluate once and stop before any optimizer is built.
    if args.test:
        print('--------Test on testset--------')
        metrics = _evaluate(encoder, model, classifier, test_on_test_loader,
                            test_sent_dataset, train_sent_dataset, n_lbls, hfet=ishfet)
        print(metrics)
        raise SystemExit(0)

    # criterion, optimizer, strategy
    from pytorch_pretrained_bert import BertAdam
    criterion = nn.BCELoss()
    param_groups = [
        {'params': [p for p in module.parameters() if p.requires_grad], 'initial_lr': args.lr}
        for module in (classifier, model, encoder)
    ]
    optimizer = BertAdam(param_groups, lr=args.lr, warmup=.1,
                         weight_decay=args.weight_decay,
                         t_total=args.max_epochs * (len(train_graph_dataset) // args.batchsize + 1))

    best_macro_f1 = 0
    best_micro_f1 = 0
    for epoch in range(args.max_epochs):
        losses = []

        for batch_idx, batch_data in enumerate(train_loader):
            print('Epoch {}, Batch {}'.format(epoch + 1, batch_idx + 1), end='\r')
            loss_i = train_step(batch_data, encoder, model, classifier, criterion, optimizer,
                                n_lbls, args, negsampling=not ishfet)
            losses.append(loss_i)

            # Periodic evaluation every `testper` training batches.
            if (batch_idx + 1) % args.testper == 0:
                print('--------Test on testset--------')
                metrics = _evaluate(encoder, model, classifier, test_on_test_loader,
                                    test_sent_dataset, train_sent_dataset, n_lbls, hfet=ishfet)
                print(metrics)
                if metrics['MACRO']['F1'] > best_macro_f1:
                    best_macro_f1 = metrics['MACRO']['F1']
                    print('saving best macro F1 model')
                    _save_checkpoint(encoder, model, classifier, args.name, 'macf1')
                if metrics['MICRO']['F1'] > best_micro_f1:
                    best_micro_f1 = metrics['MICRO']['F1']
                    print('saving best micro F1 model')
                    _save_checkpoint(encoder, model, classifier, args.name, 'micf1')

        # An empty loader would otherwise divide by zero here.
        if losses:
            print('Epoch {}, Loss {:.3f}'.format(epoch + 1, sum(losses) / len(losses)))

