import sys
sys.path.extend(["../../","../","./"])
import random
import argparse
from driver.Config import *
import numpy as np
import torch
import pickle
from data.Dataloader import *
from modules.Parser import *
from modules.WordEncoder import *
from modules.PretrainedWordEncoder import *
from modules.Sent2Span import *
from modules.EDUEncdoer import *
from modules.Decoder import *
from modules.TagEncoder import *
from modules.BertModel import *
from modules.BertTokenHelper import *
from driver.TrainTest import predict, evaluate

if __name__ == '__main__':
    # Script entry point: load a saved vocabulary and parser checkpoint, run
    # prediction over --test_file, write the result to '<test_file>.out', and
    # call evaluate() on the input file and that output file.

    # Seed every RNG touched here (Python, NumPy, torch CPU and CUDA).
    random.seed(666)
    np.random.seed(666)
    torch.cuda.manual_seed(666)
    torch.manual_seed(666)

    ### gpu
    gpu = torch.cuda.is_available()
    print("GPU available: ", gpu)
    print("CuDNN: \n", torch.backends.cudnn.enabled)

    argparser = argparse.ArgumentParser()
    argparser.add_argument('--config_file', default='experiments/rst_model/config.cfg')
    argparser.add_argument('--model_id', default=4, type=int, help='model id')
    # NOTE(review): --thread is parsed but never read anywhere in this block.
    argparser.add_argument('--thread', default=1, type=int, help='thread num')
    argparser.add_argument('--use-cuda', action='store_true', default=True)
    # '--use-cuda' already defaults to True, so by itself it can never switch
    # CUDA off. '--no-cuda' stores False into the same destination.
    argparser.add_argument('--no-cuda', dest='use_cuda', action='store_false',
                           help='run on CPU even if a GPU is available')
    argparser.add_argument('--test_file', default='', help='without evaluation')

    # Unrecognised options are forwarded to Configurable as extra arguments.
    args, extra_args = argparser.parse_known_args()
    config = Configurable(args.config_file, extra_args)

    # SECURITY NOTE: pickle.load and torch.load both unpickle their input,
    # which can execute arbitrary code. Only point them at trusted files.
    with open(config.load_vocab_path, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    # The checkpoint file name is '<load_model_path>.<model_id>'.
    # map_location='cpu' lets a checkpoint that was saved from GPU tensors be
    # read on a machine without CUDA; the modules are moved to the GPU
    # explicitly further down when config.use_cuda is set.
    checkpoint_path = config.load_model_path + '.' + str(args.model_id)
    discoure_parser_model = torch.load(checkpoint_path, map_location='cpu')

    # CUDA is used only when it is both available and requested.
    config.use_cuda = False
    if gpu and args.use_cuda: config.use_cuda = True
    print("\nGPU using status: ", config.use_cuda)

    if args.test_file != "":
        test_data = read_corpus(args.test_file)
        test_insts = inst(test_data)

        print('Load pretrained encoder.....')
        tok = BertTokenHelper(config.bert_dir)
        enc_model = BertExtractor(config)
        print(enc_model)
        print('Load pretrained encoder ok')

        # Build the sub-modules, then restore their weights from the checkpoint.
        wordEnc = WordEncoder(vocab, config)
        tagEnc = TagEncoder(vocab, config)
        sent2span = Sent2Span(vocab, config)
        EDUEnc = EDUEncoder(vocab, config)
        dec = Decoder(vocab, config)

        # No state dict is loaded for pwordEnc from the checkpoint here; it
        # wraps the pretrained encoder model constructed above.
        pwordEnc = PretrainedWordEncoder(config, enc_model, enc_model.bert_hidden_size, enc_model.layer_num)

        wordEnc.load_state_dict(discoure_parser_model["wordEnc"])
        tagEnc.load_state_dict(discoure_parser_model["tagEnc"])
        sent2span.load_state_dict(discoure_parser_model["sent2span"])
        EDUEnc.load_state_dict(discoure_parser_model["EDUEnc"])
        dec.load_state_dict(discoure_parser_model["dec"])

        if config.use_cuda:
            torch.backends.cudnn.enabled = True
            # torch.backends.cudnn.benchmark = True
            pwordEnc = pwordEnc.cuda()
            wordEnc = wordEnc.cuda()
            tagEnc = tagEnc.cuda()
            sent2span = sent2span.cuda()
            EDUEnc = EDUEnc.cuda()
            dec = dec.cuda()

        parser = DisParser(pwordEnc, wordEnc, tagEnc, sent2span, EDUEnc, dec, config)
        output_file = args.test_file + '.out'
        predict(test_insts, parser, vocab, config, tok, output_file)
        evaluate(args.test_file, output_file)
    else:
        # Previously the script exited silently in this case.
        print("No --test_file given; nothing to predict.")

