# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fine-tune entity-marker models (RoBERTa, ERNIE) for relation classification on the Wiki80 dataset."""

from __future__ import absolute_import, division, print_function

import argparse
import glob
import logging
import os
import random
from collections import defaultdict
import re
import shutil

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tensorboardX import SummaryWriter
from tqdm import tqdm, trange

from transformers import (WEIGHTS_NAME, BertConfig,
                                  BertTokenizer,
                                  RobertaConfig,
                                  RobertaForSequenceClassification,
                                  RobertaForMarkerSequenceClassification,
                                  RobertaTokenizer,
                                  get_linear_schedule_with_warmup,
                                  AdamW
                                  )


from transformers import AutoTokenizer

from torch.utils.data import TensorDataset, Dataset
import json
import pickle
import numpy as np
import unicodedata
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score
import time

from knowledge_bert.modeling import ERNIEForMarkerSequenceClassification



logger = logging.getLogger(__name__)

# Every pretrained shortcut name known to the BERT and RoBERTa config classes,
# in declaration order; used to build the --model_name_or_path help text.
ALL_MODELS = tuple(
    shortcut
    for config_cls in (BertConfig, RobertaConfig)
    for shortcut in config_cls.pretrained_config_archive_map.keys()
)

# --model_type -> (config class, model class, tokenizer class).
# NOTE: 'ernie' reuses the RoBERTa config and tokenizer classes.
MODEL_CLASSES = dict(
    roberta=(RobertaConfig, RobertaForMarkerSequenceClassification, RobertaTokenizer),
    ernie=(RobertaConfig, ERNIEForMarkerSequenceClassification, RobertaTokenizer),
)

class Wiki80Dataset(Dataset):
    """Wiki80 relation-classification dataset backed by a JSON-lines file.

    Each non-blank line is a JSON object; this class reads its ``token`` list,
    the ``h`` / ``t`` entity dicts (``pos`` and ``id`` keys) and the
    ``relation`` label. Examples are kept as parsed JSON and converted to
    tensors lazily in ``__getitem__``.
    """

    def __init__(self, tokenizer, args=None, evaluate=False, do_test=False, ent_emb=None, entity2id=None):
        """Load one data split and ``label2id.json`` from ``args.data_dir``.

        Args:
            tokenizer: tokenizer exposing ``tokenize``, ``convert_tokens_to_ids``,
                ``vocab_size`` and ``bos``/``eos``/``pad`` token ids.
            args: parsed command-line namespace. Reads ``data_dir``,
                ``train_file``, ``max_seq_length``, ``position_bias``,
                ``entity_K`` and ``model_type``, and optionally ``dev_file`` /
                ``test_file``.
            evaluate: read the dev split (or test split with ``do_test``)
                instead of ``args.train_file``.
            do_test: together with ``evaluate``, read the test split.
            ent_emb: entity embedding module; only used when
                ``args.model_type == 'ernie'``.
            entity2id: mapping from entity id to embedding index (ERNIE only).

        Raises:
            FileNotFoundError: if the selected data file does not exist.
        """
        if evaluate:
            # Honour --dev_file / --test_file. The fallbacks are the names this
            # class previously hard-coded, for callers whose args lack them.
            if do_test:
                file_path = getattr(args, 'test_file', 'test.jsonl')
            else:
                file_path = getattr(args, 'dev_file', 'dev.jsonl')
        else:
            file_path = args.train_file

        file_path = os.path.join(args.data_dir, file_path)
        if not os.path.isfile(file_path):
            raise FileNotFoundError("Wiki80 data file not found: {}".format(file_path))
        self.args = args
        self.data_json = []
        with open(file_path, "r", encoding='utf-8') as f:
            print('reading file:', file_path)
            for line in f:
                if line.strip():  # tolerate blank (e.g. trailing) lines
                    self.data_json.append(json.loads(line))
            print('done reading')
        with open(os.path.join(args.data_dir, 'label2id.json'), encoding='utf-8') as f:
            self.label2id = json.load(f)

        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size
        self.max_seq_length = args.max_seq_length

        # NOTE(review): the attributes below are not read anywhere in this class;
        # kept for compatibility with external code that may inspect them.
        self.zero_item = [0] * self.max_seq_length
        self.one_zero_item = [1] + [0] * (self.max_seq_length - 1)
        self._max_mention_length = 30
        self.dim = 768
        self.position_bias = args.position_bias
        self.entity_K = args.entity_K
        self.null_entity_position = [[-1] * self._max_mention_length for _ in range(self.entity_K)]
        self.qid2id = {}

        self.entity2id = entity2id
        self.ent_emb = ent_emb
        print('Loaded data')

    def __len__(self):
        """Return the number of examples in the split."""
        return len(self.data_json)

    def __getitem__(self, idx):
        """Return the tensor list for example ``idx`` (see ``one_example_to_tensors``)."""
        return self.one_example_to_tensors(self.data_json[idx])

    def tokenize(self, raw_tokens, pos_head, pos_tail):
        """Tokenize ``raw_tokens`` and insert four entity-marker ids.

        Args:
            raw_tokens: list of word tokens.
            pos_head: head mention word span; ``[0]`` is its start and ``[-1]``
                the word index at which the closing marker is placed.
            pos_tail: tail mention word span, same convention.

        Returns:
            ``(indexed_tokens, h_pos, t_pos)``: the subword ids with markers
            inserted, and the ``[open, close]`` marker indices of head and tail.

        Raises:
            Exception: if a word boundary cannot be mapped to a subword offset.
        """
        sst = self.tokenizer.tokenize(" ".join(raw_tokens))
        bped = " ".join(sst)

        def get_ins(num_words):
            """Return how many subword tokens cover the first ``num_words`` words."""
            prefix = " ".join(raw_tokens[:num_words])
            bped_prefix = " ".join(self.tokenizer.tokenize(prefix))
            if bped.startswith(bped_prefix):
                return len(bped_prefix.split())
            # Retry with a trailing space; presumably the last word of the prefix
            # can tokenize differently at end-of-string.
            bped_prefix = " ".join(self.tokenizer.tokenize(prefix + " "))
            if bped.startswith(bped_prefix):
                return len(bped_prefix.split())
            raise Exception("Cannot locate the position")

        hiL = get_ins(pos_head[0])
        hiR = get_ins(pos_head[-1])
        tiL = get_ins(pos_tail[0])
        tiR = get_ins(pos_tail[-1])

        indexed_tokens = self.tokenizer.convert_tokens_to_ids(sst)

        if self.args.model_type == 'ernie':
            E1b, E1e, E2b, E2e = self.tokenizer.convert_tokens_to_ids(['madeupword0000', 'madeupword0001', 'madeupword0002', 'madeupword0003'])
        else:
            E1b = self.vocab_size      # 'madeupword0000'
            E1e = self.vocab_size + 1  # 'madeupword0001'
            E2b = self.vocab_size + 2  # 'madeupword0002'
            E2e = self.vocab_size + 3  # 'madeupword0003'

        ins = [(hiL, E1b), (hiR, E1e), (tiL, E2b), (tiR, E2e)]
        end_markers = (E1e, E2e)
        # Sort by subword position. At equal positions, closing markers go before
        # opening ones, so that a mention ending exactly where the other begins
        # stays properly bracketed in either head/tail order. The marker id is the
        # final tie-breaker, as in the previous (position, id) ordering.
        ins.sort(key=lambda item: (item[0], item[1] not in end_markers, item[1]))

        h_pos = [-1, -1]
        t_pos = [-1, -1]
        # Each earlier insertion shifts later targets right by one, hence "+ offset".
        for offset, (position, marker) in enumerate(ins):
            at = position + offset
            indexed_tokens.insert(at, marker)
            if marker == E1b:
                h_pos[0] = at
            if marker == E1e:
                h_pos[1] = at
            if marker == E2b:
                t_pos[0] = at
            if marker == E2e:
                t_pos[1] = at
        assert h_pos[0] >= 0 and h_pos[1] >= 0 and t_pos[0] >= 0 and t_pos[1] >= 0
        return indexed_tokens, h_pos, t_pos

    def one_example_to_tensors(self, example):
        """Convert one parsed JSON example into model input tensors.

        Returns:
            ``[input_ids, attention_mask, token_type_ids, label, ht_position]``,
            plus ``[input_ent_emb, ent_mask]`` when ``args.model_type == 'ernie'``.
            ``ht_position`` holds the indices of the head and tail opening
            markers, clipped to ``max_seq_length - 1``.
        """
        max_length = self.max_seq_length
        input_ids, h_pos, t_pos = self.tokenize(example['token'], example['h']['pos'], example['t']['pos'])

        h_id = example['h']['id']
        t_id = example['t']['id']
        # Shift marker indices by one for the BOS token prepended below.
        h_pos = [p + 1 for p in h_pos]
        t_pos = [p + 1 for p in t_pos]
        # Truncate so that BOS + tokens + EOS fits into max_length.
        input_ids = [self.tokenizer.bos_token_id] + input_ids[:max_length - 2] + [self.tokenizer.eos_token_id]

        if self.args.model_type == 'ernie':
            input_ent = torch.zeros((max_length,), dtype=torch.int64)
            ent_mask = torch.zeros((max_length,), dtype=torch.int64)

            # Unknown entities get -1, so after the +1 shift below they map to index 0.
            head_id = self.entity2id.get(h_id, -1)
            tail_id = self.entity2id.get(t_id, -1)

            # Tag the tokens strictly between each mention's markers; stop before
            # the last sequence slot.
            for p in range(h_pos[0] + 1, h_pos[1]):
                if p >= max_length - 1:
                    break
                input_ent[p] = head_id + 1
                ent_mask[p] = 1

            for p in range(t_pos[0] + 1, t_pos[1]):
                if p >= max_length - 1:
                    break
                input_ent[p] = tail_id + 1
                ent_mask[p] = 1

            input_ent_emb = self.ent_emb(input_ent)

        # Zero-pad up to the sequence length.
        padding_length = max_length - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * padding_length
        token_type_ids = [0] * max_length
        input_ids += [self.tokenizer.pad_token_id] * padding_length

        assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
        assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length)
        assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids), max_length)

        label = self.label2id[example['relation']]
        item = [torch.tensor(input_ids), torch.tensor(attention_mask), torch.tensor(token_type_ids), torch.tensor(label)]
        item.append(torch.tensor([min(h_pos[0], max_length - 1), min(t_pos[0], max_length - 1)]))
        if self.args.model_type == 'ernie':
            item += [input_ent_emb, ent_mask]
        return item
 

def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators with ``args.seed``.

    When ``args.n_gpu`` is positive, the CUDA generators of all devices are
    seeded as well.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed)


def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
    if not args.save_total_limit:
        return
    if args.save_total_limit <= 0:
        return

    # Check if we should delete older checkpoint(s)
    glob_checkpoints = glob.glob(os.path.join(args.output_dir, '{}-*'.format(checkpoint_prefix)))
    if len(glob_checkpoints) <= args.save_total_limit:
        return

    ordering_and_checkpoint_path = []
    for path in glob_checkpoints:
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            regex_match = re.match('.*{}-([0-9]+)'.format(checkpoint_prefix), path)
            if regex_match and regex_match.groups():
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
    for checkpoint in checkpoints_to_be_deleted:
        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
        shutil.rmtree(checkpoint)

def train(args, model, tokenizer, ent_emb=None, entity2id=None):
    """Fine-tune ``model`` on the Wiki80 training split.

    Builds the training ``Wiki80Dataset``, an AdamW optimizer with a linear
    warmup/decay schedule, and optionally wraps the model for apex fp16,
    ``DataParallel`` or ``DistributedDataParallel``. When
    ``args.evaluate_during_training`` is set, the dev split is evaluated every
    ``args.save_steps`` optimizer steps. A checkpoint is written whenever the dev
    F1 improves.

    Args:
        args: parsed command-line namespace. ``train_batch_size`` is written onto
            it, and ``num_train_epochs`` too when ``max_steps > 0``.
        model: model to train.
        tokenizer: tokenizer forwarded to ``Wiki80Dataset``.
        ent_emb: ERNIE entity embedding module, forwarded to the datasets.
        entity2id: ERNIE entity-id map, forwarded to the datasets.

    Returns:
        ``(global_step, average_training_loss, best_dev_f1)``. The average loss is
        0.0 if no optimizer step was taken.
    """
    if args.local_rank in [-1, 0]:
        # Name the TensorBoard run after the last path component of output_dir.
        run_name = os.path.basename(os.path.normpath(args.output_dir))
        tb_writer = SummaryWriter(os.path.join("logs/wiki80_log", run_name))

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    train_dataset = Wiki80Dataset(tokenizer=tokenizer, args=args, ent_emb=ent_emb, entity2id=entity2id)

    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    # NOTE(review): no worker process is used when 'test' appears in output_dir;
    # presumably meant for quick debugging runs.
    num_workers = int(args.output_dir.find('test') == -1)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                  batch_size=args.train_batch_size, num_workers=num_workers)

    if args.max_steps > 0:
        t_total = args.max_steps
        # Guard against a dataloader shorter than the accumulation window.
        updates_per_epoch = max(1, len(train_dataloader) // args.gradient_accumulation_steps)
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

    # Honour --warmup_steps when given; otherwise warm up over 10% of all steps.
    warmup_steps = getattr(args, 'warmup_steps', 0)
    if warmup_steps <= 0:
        warmup_steps = int(0.1 * t_total)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
    )

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    logger.info("  Warmup steps = %d", warmup_steps)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproductibility (even between python 2 and 3)
    best_acc = 0
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                # XLM and RoBERTa don't use segment_ids
                'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
                'labels': batch[3],
                'ht_position': batch[4],
            }
            if args.model_type == 'ernie':
                inputs['input_ent'] = batch[5]
                inputs['ent_mask'] = batch[6]

            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # Clip once per update, on the fully accumulated gradients.
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                # Since PyTorch 1.1 the optimizer must step before the LR scheduler.
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics; get_last_lr() replaces the deprecated get_lr() where available.
                    if hasattr(scheduler, 'get_last_lr'):
                        current_lr = scheduler.get_last_lr()[0]
                    else:
                        current_lr = scheduler.get_lr()[0]
                    tb_writer.add_scalar('lr', current_lr, global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save a checkpoint only when dev F1 improves. Only evaluate when single GPU,
                    # otherwise metrics may not average well.
                    if args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer, ent_emb=ent_emb, entity2id=entity2id)

                        acc = results['f1']
                        if acc > best_acc:
                            checkpoint_prefix = 'checkpoint'

                            best_acc = acc
                            print('Best F1', best_acc)
                            output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
                            os.makedirs(output_dir, exist_ok=True)
                            model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training

                            # ERNIE checkpoints are raw state dicts; the others use save_pretrained().
                            if args.model_type.startswith('ernie'):
                                torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch.pt'))
                            else:
                                model_to_save.save_pretrained(output_dir)

                            torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                            logger.info("Saving model checkpoint to %s", output_dir)
                            _rotate_checkpoints(args, checkpoint_prefix)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    average_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, average_loss, best_acc


def evaluate(args, model, tokenizer, prefix="", do_test=False, ent_emb=None, entity2id=None):
    """Evaluate ``model`` on the Wiki80 dev split, or on the test split when ``do_test``.

    Args:
        args: parsed command-line namespace. ``eval_batch_size`` is written onto it.
        model: model returning ``(loss, logits, ...)`` when given labels.
        tokenizer: tokenizer forwarded to ``Wiki80Dataset``.
        prefix: label used in the log header (e.g. a checkpoint step).
        do_test: evaluate the test split and save the predictions to
            ``<output_dir>/output_labels.npy``.
        ent_emb: ERNIE entity embedding module, forwarded to the dataset.
        entity2id: ERNIE entity-id map, forwarded to the dataset.

    Returns:
        ``{'f1': micro-averaged F1 of the argmax predictions}``.
    """
    eval_output_dir = args.output_dir

    eval_dataset = Wiki80Dataset(tokenizer=tokenizer, args=args, evaluate=True, do_test=do_test, ent_emb=ent_emb, entity2id=entity2id)

    if args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Evaluation always iterates sequentially (no DistributedSampler), so the
    # saved predictions follow the dataset order.
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=1)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    preds = []
    out_label_ids = []
    model.eval()
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)

        inputs = {
            'input_ids': batch[0],
            'attention_mask': batch[1],
            # XLM and RoBERTa don't use segment_ids
            'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
            'labels': batch[3],
            'ht_position': batch[4],
        }
        if args.model_type == 'ernie':
            inputs['input_ent'] = batch[5]
            inputs['ent_mask'] = batch[6]

        with torch.no_grad():
            outputs = model(**inputs)
        tmp_eval_loss, logits = outputs[:2]

        eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1

        preds.append(torch.argmax(logits, dim=1).cpu().numpy())
        out_label_ids.append(inputs['labels'].cpu().numpy())

    eval_loss = eval_loss / max(nb_eval_steps, 1)
    logger.info("  Eval loss = %.4f", eval_loss)

    preds = np.concatenate(preds, axis=0)
    out_label_ids = np.concatenate(out_label_ids, axis=0)
    if do_test:
        np.save(os.path.join(args.output_dir, 'output_labels.npy'), preds)

    # scikit-learn's signature is f1_score(y_true, y_pred, ...).
    f1 = f1_score(out_label_ids, preds, average="micro")

    return {'f1': f1}




def _build_arg_parser():
    """Return the ``argparse`` parser holding every command-line option of this script."""
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .jsonl data splits and label2id.json.")
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to evaluate the saved checkpoints (on the test split).")
    parser.add_argument("--do_test", action='store_true',
                        help="Whether to run test on the test set.")

    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Run evaluation on the dev set every --save_steps steps during training.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")

    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=100,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")

    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
    parser.add_argument('--save_total_limit', type=int, default=1,
                        help='Limit the total amount of checkpoints, delete the older checkpoints in the output_dir '
                             '(default keeps 1; a value <= 0 disables deletion)')

    parser.add_argument('--entity_K', type=int, default=0)
    parser.add_argument("--train_file", default="train.jsonl", type=str)
    parser.add_argument("--dev_file", default="dev.jsonl", type=str)
    parser.add_argument("--test_file", default="test.jsonl", type=str)
    parser.add_argument('--shareqkv', action='store_true')
    parser.add_argument('--position_bias', type=int, default=2)
    return parser


def _load_ernie_kg(model_dir):
    """Load ERNIE knowledge-graph entity embeddings and the entity-id map.

    Reads ``<model_dir>/kg_embed/entity2vec.vec`` (one tab-separated float
    vector per line) and ``<model_dir>/kg_embed/entity2id.txt`` (a header line,
    then ``qid<TAB>id`` lines).

    Returns:
        ``(ent_emb, entity2id)``. ``ent_emb`` is a ``torch.nn.Embedding`` built
        with ``from_pretrained`` (frozen by default); its row 0 is an all-zero
        vector, so file line ``i`` lands at index ``i + 1``. ``entity2id`` maps
        each qid string to its integer id.
    """
    # Row 0 is a 100-dim zero vector; the file's vectors presumably must be
    # 100-dimensional as well (torch.FloatTensor rejects ragged rows).
    vecs = [[0] * 100]
    with open(os.path.join(model_dir, "kg_embed/entity2vec.vec"), 'r') as fin:
        for line in tqdm(fin):
            vecs.append([float(x) for x in line.strip().split('\t')])
    # Build the tensor once; the embedding matrix can be large.
    ent_emb = torch.nn.Embedding.from_pretrained(torch.FloatTensor(vecs))

    entity2id = {}
    with open(os.path.join(model_dir, "kg_embed/entity2id.txt")) as fin:
        fin.readline()  # skip the header line
        for line in fin:
            qid, eid = line.strip().split('\t')
            entity2id[qid] = int(eid)
    return ent_emb, entity2id


def main():
    """Command-line entry point: parse options, load model/tokenizer, then train and/or evaluate."""
    args = _build_arg_parser().parse_args()

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        # With --no_cuda the model stays on CPU, so visible GPUs must not be
        # counted (n_gpu drives DataParallel and batch-size scaling).
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)

    # Set seed
    set_seed(args)

    num_labels = 80  # number of Wiki80 relation labels

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels)
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)

    ent_emb = None
    entity2id = {}
    if args.model_type == 'ernie':
        model = model_class.from_pretrained(args.model_name_or_path)[0]
        ent_emb, entity2id = _load_ernie_kg(args.model_name_or_path)
    else:
        model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)

    # When the vocabulary lacks '[unused0]', register four '[unusedN]' special tokens
    # and grow the embedding matrix to len(tokenizer); presumably this provides rows
    # for the entity-marker ids placed right after the base vocabulary.
    if args.model_type.startswith('roberta') and tokenizer.convert_tokens_to_ids('[unused0]') == tokenizer.unk_token_id:
        special_tokens_dict = {'additional_special_tokens': ['[unused' + str(x) + ']' for x in range(4)]}
        tokenizer.add_special_tokens(special_tokens_dict)
        model.roberta.resize_token_embeddings(len(tokenizer))

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        global_step, tr_loss, best_acc = train(args, model, tokenizer, ent_emb=ent_emb, entity2id=entity2id)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

        # Final dev evaluation; save the model if it beats the best in-training checkpoint.
        if args.local_rank == -1 or torch.distributed.get_rank() == 0:
            results = evaluate(args, model, tokenizer, ent_emb=ent_emb, entity2id=entity2id)
            acc = results['f1']
            if acc > best_acc:
                best_acc = acc
                print('Best F1', best_acc)
                checkpoint_prefix = 'checkpoint'
                output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
                os.makedirs(output_dir, exist_ok=True)
                model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training

                if args.model_type.startswith('ernie'):
                    torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'pytorch.pt'))
                else:
                    model_to_save.save_pretrained(output_dir)

                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                logger.info("Saving model checkpoint to %s", output_dir)
                _rotate_checkpoints(args, checkpoint_prefix)

    # Evaluation: every checkpoint under output_dir is scored on the test split.
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        # ERNIE checkpoints are raw state dicts; the others come from save_pretrained().
        weights_name = 'pytorch.pt' if args.model_type.startswith('ernie') else 'pytorch_model.bin'

        checkpoints = [os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + weights_name, recursive=True))]
        for noisy_logger in ("pytorch_transformers.modeling_utils", "transformers.modeling_utils"):
            logging.getLogger(noisy_logger).setLevel(logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""

            if args.model_type.startswith('ernie'):
                model = model_class.from_pretrained(args.model_name_or_path, torch.load(os.path.join(checkpoint, 'pytorch.pt')))[0]
            else:
                model = model_class.from_pretrained(checkpoint)

            model.to(args.device)
            result = evaluate(args, model, tokenizer, prefix=global_step, do_test=True, ent_emb=ent_emb, entity2id=entity2id)
            results.update({k + '_{}'.format(global_step): v for k, v in result.items()})
        output_eval_file = os.path.join(args.output_dir, "results.json")
        with open(output_eval_file, "w") as f:
            json.dump(results, f)

    if args.local_rank in [-1, 0]:
        print(results)


if __name__ == '__main__':
    # Script entry point: run the command-line driver.
    main()

