import sys
sys.path.extend(["../../","../","./"])
import random
import itertools
import argparse
from data.Dataloader import *
from driver.Config import *
import time
from modules.Parser import *
from modules.WordLSTM import *
from modules.PretrainedWordEncoder import *
from modules.Sent2Span import *
from modules.BiSent2Sent import *
from modules.EDULSTM import *
from modules.Decoder import *
from modules.BertModel import *
from modules.BertTokenHelper import *
import pickle


class Optimizer:
    """Adam optimizer bundled with a step-wise exponential learning-rate decay.

    After ``n`` calls to :meth:`step`, the learning rate of every parameter
    group is ``lr * config.decay ** (n // config.decay_steps)``.
    """

    def __init__(self, parameter, config, lr):
        """Build the Adam optimizer and its ``LambdaLR`` schedule.

        Args:
            parameter: iterable of parameters (or parameter groups) to optimise.
            config: object providing ``beta_1``, ``beta_2``, ``epsilon``,
                ``L2_REG``, ``decay`` and ``decay_steps``.
            lr: base learning rate.
        """
        self.optim = torch.optim.Adam(parameter, lr=lr, betas=(config.beta_1, config.beta_2),
                                      eps=config.epsilon, weight_decay=config.L2_REG)
        decay, decay_step = config.decay, config.decay_steps

        def lr_factor(step_count):
            # Piecewise-constant multiplier: shrink by ``decay`` every ``decay_step`` steps.
            return decay ** (step_count // decay_step)

        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optim, lr_lambda=lr_factor)

    def step(self):
        """Apply one parameter update, advance the LR schedule, then clear gradients."""
        self.optim.step()
        self.schedule()
        self.optim.zero_grad()

    def schedule(self):
        """Advance the learning-rate schedule by one step."""
        self.scheduler.step()

    def zero_grad(self):
        """Clear the gradients of all optimised parameters."""
        self.optim.zero_grad()

    @property
    def lr(self):
        """Current learning rate(s), one entry per parameter group."""
        # ``get_lr()`` is intended to be called only from inside
        # ``scheduler.step()``; recent PyTorch warns when it is called
        # directly. ``get_last_lr()`` returns the value computed by the last
        # step. Fall back to ``get_lr()`` for old releases that lack it.
        get_last_lr = getattr(self.scheduler, "get_last_lr", None)
        if get_last_lr is not None:
            return get_last_lr()
        return self.scheduler.get_lr()

def _trainable_parameters(parser):
    """Return, as a list, every parameter of the parser sub-modules that requires grad."""
    modules = (parser.pwordEnc, parser.bisent2sent, parser.sent2span,
               parser.wordLSTM, parser.EDULSTM, parser.dec)
    return [p for p in itertools.chain.from_iterable(m.parameters() for m in modules)
            if p.requires_grad]


def _save_checkpoint(parser, path):
    """Save the state dict of every parser sub-module to ``path`` with ``torch.save``."""
    discourse_parser_model = {
        "pwordEnc": parser.pwordEnc.state_dict(),
        "sent2span": parser.sent2span.state_dict(),
        "bisent2sent": parser.bisent2sent.state_dict(),
        "wordLSTM": parser.wordLSTM.state_dict(),
        "EDULSTM": parser.EDULSTM.state_dict(),
        "dec": parser.dec.state_dict()
    }
    torch.save(discourse_parser_model, path)
    print('Saving model to ', path)


def train(train_inst, dev_data, test_data, parser, vocab, config, tokenizer):
    """Train ``parser`` on ``train_inst``, validating and checkpointing periodically.

    Gradients are accumulated over ``config.update_every`` batches (and on the
    last batch of each epoch) before a clipped optimizer step. Every
    ``config.validate_every`` batches (and on the last batch) the dev and test
    sets are decoded and scored. When the dev full F-score improves and
    ``epoch >= config.save_after >= 0``, a checkpoint is written to
    ``config.save_model_path + "." + <global_step>``.

    Args:
        train_inst: training instances, consumed through ``data_iter``.
        dev_data: dev instances, passed to ``predict``.
        test_data: test instances, passed to ``predict``.
        parser: the ``DisParser`` to train.
        vocab: vocabulary used to build the batch tensors.
        config: run configuration.
        tokenizer: BERT token helper used to build the pretrained-encoder inputs.
    """
    # Must be a list, not a lazy ``filter``: the optimizer constructor would
    # otherwise exhaust the iterator, and ``clip_grad_norm_`` below would
    # silently receive no parameters and never clip anything.
    model_param = _trainable_parameters(parser)
    model_optimizer = Optimizer(model_param, config, config.learning_rate)

    global_step = 0
    best_FF = 0
    batch_num = int(np.ceil(len(train_inst) / float(config.train_batch_size)))

    for epoch in range(config.train_iters):
        start_time = time.time()
        print('Iteration: ' + str(epoch))
        batch_iter = 0

        overall_action_correct, overall_total_action = 0, 0
        for onebatch in data_iter(train_inst, config.train_batch_size, True):

            bert_indice_input, bert_segments_ids, bert_piece_ids, bert_mask, \
            bisent2pre_sent_offset, bisent2sent_offset = \
                batch_pretrain_variable_sent_level(onebatch, vocab, config, tokenizer)
            sent2span_index = batch_sent2span_offset(onebatch, config)

            batch_feats, batch_actions, batch_action_indexes, batch_candidate = \
                actions_variable(onebatch, vocab)

            edu_words, edu_extwords, edu_tags, word_mask, edu_mask, word_denominator, edu_types = \
                batch_data_variable(onebatch, vocab, config)

            # Re-enable training mode every batch: validation below switches the parser to eval.
            parser.train()
            parser.encode(bert_indice_input, bert_segments_ids, bert_piece_ids, bert_mask,
                          bisent2pre_sent_offset, bisent2sent_offset, sent2span_index,
                          edu_words, edu_extwords, edu_tags, word_mask, edu_mask, word_denominator, edu_types)
            predict_actions = parser.decode(onebatch, batch_feats, batch_candidate, vocab)

            # Divide by ``update_every`` because gradients are accumulated over
            # that many batches before each optimizer step.
            loss = parser.compute_loss(batch_action_indexes)
            loss = loss / config.update_every
            loss_value = loss.item()
            loss.backward()

            total_actions, correct_actions = parser.compute_accuracy(predict_actions, batch_actions)
            overall_total_action += total_actions
            overall_action_correct += correct_actions
            during_time = float(time.time() - start_time)
            acc = overall_action_correct / overall_total_action if overall_total_action else 0.0
            print("Step:%d, Iter:%d, batch:%d, time:%.2f, acc:%.2f, loss:%.2f"
                  % (global_step, epoch, batch_iter, during_time, acc, loss_value))
            batch_iter += 1

            if batch_iter % config.update_every == 0 or batch_iter == batch_num:
                nn.utils.clip_grad_norm_(model_param, max_norm=config.clip)
                model_optimizer.step()
                model_optimizer.zero_grad()

                global_step += 1

            if batch_iter % config.validate_every == 0 or batch_iter == batch_num:
                dev_output = config.dev_file + '.' + str(global_step)
                test_output = config.test_file + '.' + str(global_step)

                print("Dev:")
                predict(dev_data, parser, vocab, config, tokenizer, dev_output)
                dev_FF = evaluate(config.dev_file, dev_output)

                print("Test:")
                predict(test_data, parser, vocab, config, tokenizer, test_output)
                evaluate(config.test_file, test_output)

                if dev_FF > best_FF:
                    print("Exceed best Full F-score: history = %.2f, current = %.2f" % (best_FF, dev_FF))
                    best_FF = dev_FF
                    if config.save_after >= 0 and epoch >= config.save_after:
                        _save_checkpoint(parser, config.save_model_path + "." + str(global_step))

def evaluate(gold_file, predict_file):
    """Score a predicted corpus file against its gold corpus file.

    Both files are loaded with ``read_corpus`` and their documents are paired
    in order. Each pair must agree on EDU and sentence counts. Four ``Metric``
    accumulators, labelled S, N, R and F, are filled by
    ``evaluate_labelled_attachments`` and printed in that order. S/N/R/F
    presumably mean span / nuclearity / relation / full; confirm against
    ``evaluate_labelled_attachments``.

    Returns:
        The accuracy of the F metric, which the training loop treats as the
        full F-score.
    """
    gold_docs = read_corpus(gold_file)
    predicted_docs = read_corpus(predict_file)

    # Insertion order matters: it fixes the argument order below and the print order.
    metrics = {label: Metric() for label in ("S", "N", "R", "F")}

    for gold, pred in zip(gold_docs, predicted_docs):
        assert len(gold.EDUs) == len(pred.EDUs)
        assert len(gold.sentences) == len(pred.sentences)
        gold.evaluate_labelled_attachments(pred.result, metrics["S"], metrics["N"],
                                           metrics["R"], metrics["F"])

    for label, metric in metrics.items():
        print(label + ":", end=" ")
        metric.print()

    return metrics["F"].getAccuracy()

def predict(data, parser, vocab, config, tokenizer, outputFile):
    """Decode ``data`` with ``parser`` and write the predictions to ``outputFile``.

    For each document, every original sentence is written on one line as
    space-separated ``word_tag`` tokens followed by the last item of its
    ``doc.sent_types`` entry. Next come the document's ``other_infos`` lines,
    then the predicted tree string (the top of the final state's stack), then
    a blank separator line. Elapsed time is printed at the end.

    Args:
        data: instances to decode, iterated in order with ``data_iter``.
        parser: the ``DisParser``; it is switched to eval mode.
        vocab: vocabulary used to build the batch tensors.
        config: run configuration (``test_batch_size`` etc.).
        tokenizer: BERT token helper used to build the pretrained-encoder inputs.
        outputFile: path of the file to (over)write, UTF-8 encoded.
    """
    start = time.time()
    parser.eval()
    # Inference only: no gradients are needed, so skip building autograd graphs.
    # ``with`` guarantees the output file is closed even if decoding raises.
    with torch.no_grad(), open(outputFile, mode='w', encoding='utf8') as outf:
        for onebatch in data_iter(data, config.test_batch_size, False):
            bert_indice_input, bert_segments_ids, bert_piece_ids, bert_mask, \
            bisent2pre_sent_offset, bisent2sent_offset = \
                batch_pretrain_variable_sent_level(onebatch, vocab, config, tokenizer)
            sent2span_index = batch_sent2span_offset(onebatch, config)

            edu_words, edu_extwords, edu_tags, word_mask, edu_mask, word_denominator, edu_types = \
                batch_data_variable(onebatch, vocab, config)

            parser.encode(bert_indice_input, bert_segments_ids, bert_piece_ids, bert_mask,
                          bisent2pre_sent_offset, bisent2sent_offset, sent2span_index,
                          edu_words, edu_extwords, edu_tags, word_mask, edu_mask, word_denominator, edu_types)
            parser.decode(onebatch, None, None, vocab)

            for idx, instance in enumerate(onebatch):
                doc = instance[0]
                # The last state reached by the decoder for this document holds the full tree.
                final_state = parser.batch_states[idx][parser.step[idx] - 1]
                predict_tree = final_state._stack[final_state._stack_size - 1].str

                for sent, tags, sent_type in zip(doc.origin_sentences, doc.sentences_tags, doc.sent_types):
                    assert len(sent) == len(tags)
                    tokens = ''.join(w + '_' + tag + ' ' for w, tag in zip(sent, tags))
                    outf.write(tokens + sent_type[-1] + '\n')
                for info in doc.other_infos:
                    outf.write(info + '\n')
                outf.write(predict_tree + '\n')
                outf.write('\n')
    during_time = float(time.time() - start)
    print("Doc num: %d,  parser time = %.2f " % (len(data), during_time))

if __name__ == '__main__':
    ### process id
    print("Process ID {}, Process Parent ID {}".format(os.getpid(), os.getppid()))

    # Seed every RNG the pipeline touches so runs are repeatable.
    random.seed(666)
    np.random.seed(666)
    torch.cuda.manual_seed(666)
    torch.manual_seed(666)

    ### gpu
    gpu = torch.cuda.is_available()
    print("GPU available: ", gpu)
    print("CuDNN: \n", torch.backends.cudnn.enabled)

    argparser = argparse.ArgumentParser()
    argparser.add_argument('--config_file', default='examples/default.cfg')
    argparser.add_argument('--model', default='BaseParser')
    argparser.add_argument('--thread', default=4, type=int, help='thread num')
    # ``--use-cuda`` already defaults to True, so it cannot turn CUDA off on
    # its own. ``--no-cuda`` shares the same destination so CPU can be forced.
    argparser.add_argument('--use-cuda', dest='use_cuda', action='store_true', default=True)
    argparser.add_argument('--no-cuda', dest='use_cuda', action='store_false',
                           help='run on CPU even when a GPU is available')
    argparser.add_argument('--conj_model', default="model.101")

    # Unrecognised command-line options are handed to the config loader.
    args, extra_args = argparser.parse_known_args()
    config = Configurable(args.config_file, extra_args)

    train_data = read_corpus(config.train_file)
    dev_data = read_corpus(config.dev_file)
    test_data = read_corpus(config.test_file)
    vocab = creatVocab(train_data, config.min_occur_count)

    torch.set_num_threads(args.thread)

    config.use_cuda = bool(gpu and args.use_cuda)
    print("\nGPU using status: ", config.use_cuda)

    start_a = time.time()
    train_feats, train_actions = get_gold_actions(train_data, vocab)
    print("Get Action Time: ", time.time() - start_a)
    vocab.create_action_table(train_actions)

    train_candidate = get_gold_candid(train_data, vocab)

    train_insts = inst(train_data, train_feats, train_actions, train_candidate)
    dev_insts = inst(dev_data)
    test_insts = inst(test_data)

    print("train num: ", len(train_insts))
    print("dev num: ", len(dev_insts))
    print("test num: ", len(test_insts))

    print('Load pretrained encoder.....')
    tok = BertTokenHelper(config.bert_dir)
    enc_model = BertExtractor(config)
    print(enc_model)
    print('Load pretrained encoder ok')

    pwordEnc = PretrainedWordEncoder(config, enc_model, enc_model.bert_hidden_size, enc_model.layer_num)

    wordLSTM = WordLSTM(vocab, config)
    sent2span = Sent2Span(vocab, config)
    bisent2sent = BiSent2Sent(vocab, config)
    # Use a distinct name so the ``EDULSTM`` class is not shadowed by its instance.
    edu_lstm = EDULSTM(vocab, config)
    dec = Decoder(vocab, config)
    with open(config.save_vocab_path, 'wb') as vocab_file:
        pickle.dump(vocab, vocab_file)

    if args.conj_model != "":
        # Load onto CPU. The modules are still on CPU here and are moved to the
        # GPU below, so GPU-saved checkpoints also load on CPU-only hosts.
        # torch.load unpickles the file, so only point this at trusted checkpoints.
        conj_model = torch.load(args.conj_model, map_location='cpu')
        pwordEnc.load_state_dict(conj_model['pwordEnc'])
        bisent2sent.load_state_dict(conj_model["bisent2sent"])
        sent2span.load_state_dict(conj_model["sent2span"])
        wordLSTM.load_state_dict(conj_model["wordLSTM"])
        edu_lstm.load_state_dict(conj_model["EDULSTM"])
        # NOTE(review): the decoder ("dec") weights are not restored from this
        # checkpoint even though train() saves them; confirm this is intentional.
        print("Load " + args.conj_model + "  model ok....")

    if config.use_cuda:
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.deterministic = True
        # cudnn.benchmark stays off: autotuning may pick non-deterministic kernels.
        pwordEnc = pwordEnc.cuda()
        wordLSTM = wordLSTM.cuda()
        bisent2sent = bisent2sent.cuda()
        sent2span = sent2span.cuda()
        edu_lstm = edu_lstm.cuda()
        dec = dec.cuda()

    parser = DisParser(pwordEnc, wordLSTM, bisent2sent, sent2span, edu_lstm, dec, config)
    train(train_insts, dev_insts, test_insts, parser, vocab, config, tok)

