# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning models on sorting task (e.g. Bert, DistilBERT, XLM).
    Adapted from `examples/text-classification/run_xnli.py`"""


import argparse
import glob
import logging
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange


from transformers import (
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    EncoderDecoderModel, EncoderDecoderConfig,
)
from transformers.file_utils import is_sklearn_available, requires_sklearn
from datasets.processors import _pairwise_convert_examples_to_features
from datasets.processors import data_processors as processors
from datasets.processors import output_modes

# Custom dataset classes.
from datasets.processors import HeadPredDataset
from datasets.processors import SortDatasetV1
from datasets.processors import PureClassDataset

# Sorting methods.
from .topological_sort import Graph

# Metrics.
from .metrics import compute_metrics

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter


logger = logging.getLogger(__name__)


def set_seed(args):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU) with ``args.seed``;
    when ``args.n_gpu > 0`` all CUDA devices are seeded as well.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


def get_models(args):
    """Load the model(s), tokenizer and config(s) required by ``args.sort_method``.

    The returned tuple's layout depends on the sort method:

    * ``"pure_decode"`` (substring match): ``(model_1, tokenizer_1, config_1)``
      with an ``EncoderDecoderModel``.
    * ``"head_and_topological"``:
      ``(model_1, model_2, tokenizer_1, config_1, config_2)`` where model_1 has
      2 labels and model_2 has ``args.max_story_length`` labels.
    * any other method containing ``"topological"``, or
      ``"pure_classification"``: ``(model_1, tokenizer_1, config_1)``.
    * ``"head_and_pairwise"``:
      ``(model_1, model_2, tokenizer_1, config_1, config_2)``.
    * ``"head_and_pairwise_abductive"``:
      ``(model_1, model_2, model_3, tokenizer_1, config_1, config_2, config_3)``.

    All models are moved to ``args.device``. In distributed runs, ranks other
    than 0 wait at a barrier until rank 0 has finished loading, so only the
    first process downloads the pretrained files.

    Raises:
        NotImplementedError: for an unknown sort method, or an unsupported
            ``args.abd_pred_method`` with ``head_and_pairwise_abductive``.
    """
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will
        # download model & vocab: other ranks wait here for rank 0.
        torch.distributed.barrier()

    try:
        return _load_sort_models(args)
    finally:
        if args.local_rank == 0:
            # Rank 0 has finished loading: release the waiting ranks. This
            # barrier used to sit after branches that all returned, so it was
            # never reached and the other ranks deadlocked.
            torch.distributed.barrier()


def _load_sort_models(args):
    """Build the model/tokenizer/config tuple for ``args.sort_method`` (see ``get_models``)."""

    def load_config(config_name, model_path, **kwargs):
        # An explicit config name wins; otherwise reuse the model path.
        return AutoConfig.from_pretrained(config_name or model_path, **kwargs)

    def load_tokenizer():
        return AutoTokenizer.from_pretrained(
            args.tokenizer_name_1 or args.model_name_or_path_1,
            do_lower_case=args.do_lower_case,
        )

    def load_classifier(model_path, config):
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            from_tf=bool(".ckpt" in model_path),
            config=config,
        )
        model.to(args.device)
        return model

    method = args.sort_method

    # NOTE: branch order matters because most checks are substring matches.
    if "pure_decode" in method:
        config_1 = EncoderDecoderConfig.from_pretrained(
            args.config_name_1 or args.model_name_or_path_1,
        )
        tokenizer_1 = load_tokenizer()
        model_1 = EncoderDecoderModel.from_pretrained(
            args.model_name_or_path_1,
            from_tf=bool(".ckpt" in args.model_name_or_path_1),
            config=config_1,
        )
        model_1.to(args.device)
        return model_1, tokenizer_1, config_1

    if "head_and_topological" in method:
        config_1 = load_config(args.config_name_1, args.model_name_or_path_1,
                               num_labels=2)
        config_2 = load_config(args.config_name_2, args.model_name_or_path_2,
                               num_labels=args.max_story_length)
        tokenizer_1 = load_tokenizer()
        model_1 = load_classifier(args.model_name_or_path_1, config_1)
        model_2 = load_classifier(args.model_name_or_path_2, config_2)
        return model_1, model_2, tokenizer_1, config_1, config_2

    if "topological" in method or method == "pure_classification":
        # num_labels is left to the pretrained config here.
        config_1 = load_config(args.config_name_1, args.model_name_or_path_1)
        tokenizer_1 = load_tokenizer()
        model_1 = load_classifier(args.model_name_or_path_1, config_1)
        return model_1, tokenizer_1, config_1

    if "head_and_pairwise" in method:
        abductive = method == "head_and_pairwise_abductive"
        # Fail fast, before loading any weights, on an unsupported method.
        if abductive and args.abd_pred_method != "binary":
            raise NotImplementedError("Pred method {} not implemented"
                                      " yet!".format(args.abd_pred_method))

        config_1 = load_config(args.config_name_1, args.model_name_or_path_1,
                               num_labels=2)
        config_2 = load_config(args.config_name_2, args.model_name_or_path_2,
                               num_labels=args.max_story_length)
        tokenizer_1 = load_tokenizer()
        model_1 = load_classifier(args.model_name_or_path_1, config_1)
        model_2 = load_classifier(args.model_name_or_path_2, config_2)
        if not abductive:
            return model_1, model_2, tokenizer_1, config_1, config_2

        abd_num_labels = 2  # "binary" is the only supported prediction method.
        config_3 = load_config(args.config_name_3, args.model_name_or_path_3,
                               num_labels=abd_num_labels)
        model_3 = load_classifier(args.model_name_or_path_3, config_3)
        return (model_1, model_2, model_3, tokenizer_1,
                config_1, config_2, config_3)

    raise NotImplementedError("Sort method {} not "
                              "implemented yet.".format(args.sort_method))


def topological_inference(args, model, seqs, tokenizer):
    """Order each story by topologically sorting pairwise ordering predictions.

    For every pair of sentence positions ``(i, j)`` with ``i < j`` the pair
    ``(story[i], story[j])`` is classified by ``model``; label 1 ("ordered")
    adds the edge ``i -> j`` to a ``Graph``, any other label adds ``j -> i``.
    Each story's prediction is ``graph.topologicalSort()``.

    Args:
        args: namespace providing ``max_seq_length`` and ``device``.
        model: pairwise sequence classifier.
        seqs: position-major batch (``seqs[j][i]`` is sentence ``j`` of story
            ``i``), as consumed by ``debatch_stories``.
        tokenizer: tokenizer used to encode the sentence pairs.

    Returns:
        ``(preds, loss)``: one predicted order per story, and a placeholder loss
        that is always 0.0. The loss is divided by the number of scored pairs,
        so a batch with no pairs raises ZeroDivisionError.
    """
    stories = debatch_stories(seqs)
    num_sents = len(seqs)

    predictions = []
    pair_count = 0
    loss = 0
    for story in stories:
        graph = Graph(num_sents)
        for i in range(num_sents):
            for j in range(i + 1, num_sents):
                encoded = tokenizer(
                    [(story[i], story[j])],
                    max_length=args.max_seq_length,
                    padding="max_length",
                    truncation=True,
                    return_token_type_ids=True,
                    return_tensors="pt",
                )
                inputs = {key: encoded[key].to(args.device) for key in encoded}
                with torch.no_grad():
                    outputs = model(**inputs)
                    pair_logits = outputs[0].detach().cpu().numpy()[0]
                    label = np.argmax(pair_logits)
                # Label 1 means story[i] precedes story[j].
                src, dst = (i, j) if label == 1 else (j, i)
                graph.addEdge(src, dst)
                pair_count += 1
        predictions.append(graph.topologicalSort())

    return predictions, loss / pair_count


def head_and_topological_inference(args, head_model, pairwise_model,
                                   seqs, tokenizer, abductive_model=None):
    """Order stories by topological sort of pairwise predictions, anchored on a head.

    For each story, ``head_model`` predicts the index of the first sentence,
    and ``pairwise_model`` classifies every pair ``(i, j)`` with ``i < j``.
    Label 1 ("ordered") adds the edge ``i -> j``; any other label adds
    ``j -> i``. The graph is then sorted with
    ``Graph.topologicalSort(assert_head=head_idx)``, which presumably forces the
    predicted head first.

    Args:
        args: namespace providing ``max_seq_length``, ``device`` and the fields
            read by ``head_and_sequential_inference`` for head prediction.
        head_model: classifier over sentence positions.
        pairwise_model: binary pairwise ordering classifier.
        seqs: position-major batch (``seqs[j][i]`` is sentence ``j`` of story
            ``i``), as consumed by ``debatch_stories``.
        tokenizer: tokenizer shared by both models.
        abductive_model: accepted for interface compatibility; unused.

    Returns:
        ``(preds, loss)``: one predicted order per story, and a placeholder loss
        that is always 0.0.
    """
    batch_seqs = debatch_stories(seqs)
    len_seq = len(seqs)

    preds = []
    cnt = 0
    loss = 0
    for curr_seq in batch_seqs:
        # Predict the head of THIS story. With return_head_idx=True,
        # head_and_sequential_inference returns the head of the first story it
        # is given, so re-wrap this single story in the position-major layout.
        # (Previously the first story's head was reused for the whole batch.)
        single_story = [[sent] for sent in curr_seq]
        head_idx = head_and_sequential_inference(args, head_model,
                                                 pairwise_model,
                                                 single_story, tokenizer,
                                                 abductive_model=None,
                                                 return_head_idx=True)

        graph = Graph(len_seq)
        for i in range(len_seq):
            for j in range(i + 1, len_seq):
                batch_encoding = tokenizer(
                    [(curr_seq[i], curr_seq[j])],
                    max_length=args.max_seq_length,
                    padding="max_length",
                    truncation=True,
                    return_token_type_ids=True,
                    return_tensors="pt",
                )
                batch_encoding = {k: batch_encoding[k].to(args.device)
                                  for k in batch_encoding}
                with torch.no_grad():
                    logits = pairwise_model(**batch_encoding)
                    binary_preds = logits[0].detach().cpu().numpy()[0]
                    pred_label = np.argmax(binary_preds)
                if pred_label == 1:  # Ordered.
                    graph.addEdge(i, j)
                else:  # Unordered.
                    graph.addEdge(j, i)
                cnt += 1
        preds.append(graph.topologicalSort(assert_head=head_idx))
    loss /= cnt
    return preds, loss


def head_and_sequential_inference(args, head_model, pairwise_model,
                                  seqs, tokenizer, abductive_model=None,
                                  return_head_idx=False):
    """Greedily order each story: predict the head, then repeatedly pick the next.

    The sentences of a story are tokenized one by one (``per_seq_max_length``
    each). Padding is stripped, the pieces are concatenated into a single
    ``max_seq_length`` input whose token type is the sentence position, and
    that input is fed to ``head_model``, whose argmax is taken as the head.
    Remaining positions are then appended one at a time with ``select_next``.

    Args:
        args: namespace providing ``per_seq_max_length``, ``max_seq_length``,
            ``device`` and ``replace_token_type_embeddings``.
        head_model: classifier whose argmax is the index of the first sentence.
        pairwise_model: binary pairwise ordering classifier used by
            ``select_next``.
        seqs: position-major batch (``seqs[j][i]`` is sentence ``j`` of story
            ``i``), as consumed by ``debatch_stories``.
        tokenizer: tokenizer shared by the models.
        abductive_model: optional extra scorer forwarded to ``select_next``.
        return_head_idx: if true, return only the predicted head index of the
            FIRST story in the batch (callers wanting per-story heads must pass
            one story at a time).

    Returns:
        ``(preds, loss)``: one predicted index order per story, and a
        placeholder loss that is always 0.0. Alternatively, a single head index
        when ``return_head_idx`` is set.
    """
    batch_seqs = debatch_stories(seqs)
    # The pad id was previously hard-coded as 1, which is only correct for
    # RoBERTa-style vocabularies.
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = 1

    preds = []
    cnt = 0
    loss = 0
    for curr_seq in batch_seqs:
        # Predict the head first.
        batch_encoding = tokenizer(
            curr_seq,
            max_length=args.per_seq_max_length,
            padding="max_length",
            truncation=True,
        )

        # Strip padding from every sentence and concatenate; each token's type
        # id is the position of the sentence it came from.
        pieces = [row[row != pad_id]
                  for row in np.asarray(batch_encoding["input_ids"])]
        empty = np.asarray([], dtype=int)
        cat_input_ids = np.concatenate([empty] + pieces)
        cat_token_type_ids = np.concatenate(
            [empty] + [np.full(len(piece), pos, dtype=int)
                       for pos, piece in enumerate(pieces)])

        # Truncate / right-pad the concatenation to max_seq_length.
        length = min(args.max_seq_length, len(cat_input_ids))
        padded_input_ids = np.full(args.max_seq_length, pad_id, dtype=int)
        padded_token_type_ids = np.zeros(args.max_seq_length, dtype=int)
        padded_input_ids[:length] = cat_input_ids[:length]
        padded_token_type_ids[:length] = cat_token_type_ids[:length]

        # Build long tensors directly (torch.Tensor went through float32).
        input_ids = torch.as_tensor(padded_input_ids,
                                    dtype=torch.long).unsqueeze(0)
        token_type_ids = torch.as_tensor(padded_token_type_ids,
                                         dtype=torch.long).unsqueeze(0)
        attention_mask = (input_ids != pad_id).long()
        batch_encoding = {
            "input_ids": input_ids.to(args.device),
            "attention_mask": attention_mask.to(args.device),
            # Sentence-position token types are only passed when
            # args.replace_token_type_embeddings is set.
            "token_type_ids": (token_type_ids.to(args.device)
                               if args.replace_token_type_embeddings
                               else None),
        }
        with torch.no_grad():
            logits = head_model(**batch_encoding)
            head_preds = logits[0].detach().cpu().numpy()[0]
        head_preds_label = np.argmax(head_preds)

        if return_head_idx:
            return head_preds_label

        curr_pred = [head_preds_label]
        # Candidate pool = this story's actual positions (previously
        # range(args.max_story_length), which over-indexed shorter stories).
        story_seq_idx = list(range(len(curr_seq)))

        # Sequentially predict nexts; with one element left there is no choice.
        while len(story_seq_idx) > 1:
            story_seq_idx.remove(curr_pred[-1])
            next_seq_idx = select_next(args, pairwise_model, tokenizer,
                                       story_seq_idx, curr_pred, curr_seq,
                                       abductive_model=abductive_model)
            curr_pred.append(next_seq_idx)
        preds.append(curr_pred)
        cnt += 1

    loss /= cnt
    return preds, loss


def select_next(args, pairwise_model, tokenizer, seq_idx_left,
                curr_pred, seq, abductive_model=None):
    """Choose which remaining sentence index should follow the last prediction.

    Each candidate index in ``seq_idx_left`` is scored with ``pairwise_score``
    of ``(seq[curr_pred[-1]], seq[candidate])``. When an ``abductive_model`` is
    given and at least two sentences are already placed, ``0.1`` times the
    ``abductive_score`` is added. The candidate with the highest score wins;
    ties go to the earliest candidate (``np.argmax``).

    Returns:
        The chosen element of ``seq_idx_left``.
    """
    last_idx = curr_pred[-1]
    use_abductive = abductive_model is not None and len(curr_pred) >= 2

    def candidate_score(cand_idx):
        base = pairwise_score(args, pairwise_model, tokenizer,
                              seq[last_idx], seq[cand_idx])
        if not use_abductive:
            return base
        bonus = abductive_score(args, abductive_model, tokenizer,
                                curr_pred, cand_idx, seq)
        return base + bonus * 0.1

    all_scores = np.asarray([candidate_score(c) for c in seq_idx_left])
    best_position = int(np.argmax(all_scores))
    return seq_idx_left[best_position]

def abductive_score(args, abductive_model, tokenizer, curr_pred, idx, seq):
    """Score ``seq[idx]`` as the continuation of the last two placed sentences.

    The three sentences ``seq[curr_pred[-2]]``, ``seq[curr_pred[-1]]`` and
    ``seq[idx]`` are tokenized separately (``per_seq_max_length`` each). Padding
    is stripped, the pieces are concatenated into one ``max_seq_length`` input
    (token type = sentence position 0..2), and the input is classified by
    ``abductive_model``.

    Args:
        args: namespace providing ``abd_pred_method``, ``per_seq_max_length``,
            ``max_seq_length``, ``device`` and ``replace_token_type_embeddings``.
        abductive_model: sequence classifier.
        tokenizer: tokenizer; its ``pad_token_id`` (fallback 1) marks padding.
        curr_pred: indices placed so far (needs at least two).
        idx: candidate sentence index.
        seq: the story's sentences.

    Returns:
        The model's logit at class position 1.

    Raises:
        NotImplementedError: if ``args.abd_pred_method`` is not ``"binary"``.
            Previously any method other than "binary"/"contrastive" fell
            through to an UnboundLocalError.
    """
    if args.abd_pred_method != "binary":
        raise NotImplementedError("Prediction method {} not"
                                  " done yet!".format(args.abd_pred_method))

    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = 1

    story_seq = [seq[curr_pred[-2]], seq[curr_pred[-1]], seq[idx]]
    batch_encoding = tokenizer(
        story_seq,
        max_length=args.per_seq_max_length,
        padding="max_length",
        truncation=True,
    )

    # Strip padding and concatenate; token type = position of source sentence.
    pieces = [row[row != pad_id]
              for row in np.asarray(batch_encoding["input_ids"])]
    empty = np.asarray([], dtype=int)
    cat_input_ids = np.concatenate([empty] + pieces)
    cat_token_type_ids = np.concatenate(
        [empty] + [np.full(len(piece), pos, dtype=int)
                   for pos, piece in enumerate(pieces)])

    # Truncate / right-pad the concatenation to max_seq_length.
    length = min(args.max_seq_length, len(cat_input_ids))
    padded_input_ids = np.full(args.max_seq_length, pad_id, dtype=int)
    padded_token_type_ids = np.zeros(args.max_seq_length, dtype=int)
    padded_input_ids[:length] = cat_input_ids[:length]
    padded_token_type_ids[:length] = cat_token_type_ids[:length]

    input_ids = torch.as_tensor(padded_input_ids, dtype=torch.long).unsqueeze(0)
    token_type_ids = torch.as_tensor(padded_token_type_ids,
                                     dtype=torch.long).unsqueeze(0)
    attention_mask = (input_ids != pad_id).long()
    batch_encoding = {
        "input_ids": input_ids.to(args.device),
        "attention_mask": attention_mask.to(args.device),
        # Sentence-position token types are only passed when
        # args.replace_token_type_embeddings is set.
        "token_type_ids": (token_type_ids.to(args.device)
                           if args.replace_token_type_embeddings else None),
    }

    with torch.no_grad():
        logits = abductive_model(**batch_encoding)
        binary_preds = logits[0].detach().cpu().numpy()[0]
    # Score is determined by the logits at position 1.
    return binary_preds[1]


def pairwise_score(args, pairwise_model, tokenizer, curr_sent, next_sent):
    """Return how strongly ``pairwise_model`` believes ``next_sent`` follows ``curr_sent``.

    The sentence pair is encoded as one padded example of ``max_seq_length``
    tokens, and the class-1 logit of the model's first output is returned.
    """
    encoded = tokenizer(
        [(curr_sent, next_sent)],
        max_length=args.max_seq_length,
        padding="max_length",
        truncation=True,
        return_token_type_ids=True,
        return_tensors="pt",
    )
    model_inputs = {}
    for key in encoded:
        model_inputs[key] = encoded[key].to(args.device)

    with torch.no_grad():
        outputs = pairwise_model(**model_inputs)
        first_row = outputs[0].detach().cpu().numpy()[0]
        # Position 1 is the "ordered" class; its logit is the score.
        result = first_row[1]
    return result


def pure_class_inference(args, pure_class_model, seqs, tokenizer, id2label):
    """Predict each story's order with a single classification over all sentences.

    The sentences of a story are tokenized one by one (``per_seq_max_length``
    each). Padding is stripped, the pieces are concatenated into one
    ``max_seq_length`` input (token type = sentence position), and the input is
    classified by ``pure_class_model``. The argmax label is mapped through
    ``id2label``, presumably to an ordering of sentence indices (TODO confirm
    against ``PureClassDataset``).

    Args:
        args: namespace providing ``per_seq_max_length``, ``max_seq_length``,
            ``device`` and ``replace_token_type_embeddings``.
        pure_class_model: sequence classifier over ordering labels.
        seqs: position-major batch (``seqs[j][i]`` is sentence ``j`` of story
            ``i``), as consumed by ``debatch_stories``.
        tokenizer: tokenizer; its ``pad_token_id`` (fallback 1) marks padding.
        id2label: mapping from integer class id to prediction.

    Returns:
        ``(preds, loss)``: one prediction per story, and a placeholder loss that
        is always 0.0.
    """
    batch_seqs = debatch_stories(seqs)
    # The pad id was previously hard-coded as 1 (RoBERTa-only).
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = 1

    preds = []
    cnt = 0
    loss = 0
    for curr_seq in batch_seqs:
        # Encode each sentence of the story separately.
        batch_encoding = tokenizer(
            curr_seq,
            max_length=args.per_seq_max_length,
            padding="max_length",
            truncation=True,
        )

        # Strip padding and concatenate; token type = sentence position.
        pieces = [row[row != pad_id]
                  for row in np.asarray(batch_encoding["input_ids"])]
        empty = np.asarray([], dtype=int)
        cat_input_ids = np.concatenate([empty] + pieces)
        cat_token_type_ids = np.concatenate(
            [empty] + [np.full(len(piece), pos, dtype=int)
                       for pos, piece in enumerate(pieces)])

        # Truncate / right-pad the concatenation to max_seq_length.
        length = min(args.max_seq_length, len(cat_input_ids))
        padded_input_ids = np.full(args.max_seq_length, pad_id, dtype=int)
        padded_token_type_ids = np.zeros(args.max_seq_length, dtype=int)
        padded_input_ids[:length] = cat_input_ids[:length]
        padded_token_type_ids[:length] = cat_token_type_ids[:length]

        input_ids = torch.as_tensor(padded_input_ids,
                                    dtype=torch.long).unsqueeze(0)
        token_type_ids = torch.as_tensor(padded_token_type_ids,
                                         dtype=torch.long).unsqueeze(0)
        attention_mask = (input_ids != pad_id).long()
        batch_encoding = {
            "input_ids": input_ids.to(args.device),
            "attention_mask": attention_mask.to(args.device),
            # Sentence-position token types are only passed when
            # args.replace_token_type_embeddings is set.
            "token_type_ids": (token_type_ids.to(args.device)
                               if args.replace_token_type_embeddings
                               else None),
        }
        with torch.no_grad():
            logits = pure_class_model(**batch_encoding)
            curr_preds = logits[0].detach().cpu().numpy()[0]
        preds_label = int(np.argmax(curr_preds))

        preds.append(id2label[preds_label])
        cnt += 1

    loss /= cnt
    return preds, loss


def pure_decode_inference(args, encoder_decoder_model, seqs, tokenizer):
    """Predict each story's order by generating it with an encoder-decoder model.

    The sentences of a story are tokenized one by one (``per_seq_max_length``
    each). Padding is stripped and the pieces are concatenated into one
    ``max_seq_length`` input, which is passed (with its attention mask) to
    ``encoder_decoder_model.generate`` using 5-beam search,
    ``no_repeat_ngram_size=2``, ``max_length=len(seqs)``, and the decoder's pad
    id as the start token.

    NOTE(review): ``curr_pred`` is the raw generated id sequence, which likely
    begins with the decoder start token; confirm that the metrics expect this.

    Args:
        args: namespace providing ``per_seq_max_length``, ``max_seq_length`` and
            ``device``.
        encoder_decoder_model: model exposing ``generate`` and
            ``config.decoder.pad_token_id``.
        seqs: position-major batch (``seqs[j][i]`` is sentence ``j`` of story
            ``i``), as consumed by ``debatch_stories``.
        tokenizer: tokenizer; its ``pad_token_id`` (fallback 1) marks padding.

    Returns:
        ``(preds, loss)``: one generated id list per story, and a placeholder
        loss that is always 0.0.
    """
    batch_seqs = debatch_stories(seqs)
    # The pad id was previously hard-coded as 1 (RoBERTa-only).
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = 1

    preds = []
    cnt = 0
    loss = 0
    for curr_seq in batch_seqs:
        # Encode each sentence of the story separately.
        batch_encoding = tokenizer(
            curr_seq,
            max_length=args.per_seq_max_length,
            padding="max_length",
            truncation=True,
        )

        # Strip padding and concatenate. (Token types were previously built
        # here too, but generate() never received them.)
        pieces = [row[row != pad_id]
                  for row in np.asarray(batch_encoding["input_ids"])]
        cat_input_ids = np.concatenate([np.asarray([], dtype=int)] + pieces)

        # Truncate / right-pad the concatenation to max_seq_length.
        length = min(args.max_seq_length, len(cat_input_ids))
        padded_input_ids = np.full(args.max_seq_length, pad_id, dtype=int)
        padded_input_ids[:length] = cat_input_ids[:length]

        input_ids = torch.as_tensor(padded_input_ids,
                                    dtype=torch.long).unsqueeze(0)
        attention_mask = (input_ids != pad_id).long()

        with torch.no_grad():
            # https://huggingface.co/blog/how-to-generate
            # The attention mask is passed explicitly so the encoder ignores
            # the right padding.
            outputs = encoder_decoder_model.generate(
                input_ids.to(args.device),
                attention_mask=attention_mask.to(args.device),
                max_length=len(seqs),
                num_beams=5,
                no_repeat_ngram_size=2,
                decoder_start_token_id=encoder_decoder_model.config.decoder.pad_token_id,
            )

        curr_pred = list(outputs.cpu().numpy()[0])
        preds.append(curr_pred)
        cnt += 1

    loss /= cnt
    return preds, loss


def debatch_stories(seqs):
    """Transpose a position-major batch into a list of stories.

    ``seqs[j][i]`` is sentence ``j`` of story ``i`` (presumably the layout
    produced by DataLoader collation). Element ``i`` of the result is the list
    ``[seqs[0][i], seqs[1][i], ...]``. The number of stories is taken from
    ``len(seqs[0])``.
    """
    num_stories = len(seqs[0])
    return [[position[story_idx] for position in seqs]
            for story_idx in range(num_stories)]


def model_wise_evaluate(args, models, batch, tokenizer, id2label=None):
    """Run the inference routine selected by ``args.sort_method`` on one batch.

    ``batch`` is unpacked as ``(stories, labels)``. ``models`` is indexed by
    position: for the head-based methods, ``models[0]`` is used as the pairwise
    model and ``models[1]`` as the head model; ``models[2]`` is the abductive
    model for ``head_and_pairwise_abductive``. Every other method uses
    ``models[0]`` only. Any method containing ``"topological"``, other than
    ``head_and_topological``, uses plain topological inference.

    Returns:
        ``(preds, labels, loss)``.

    Raises:
        NotImplementedError: for an unrecognized sort method.
    """
    stories, labels = batch
    method = args.sort_method

    if method == "head_and_topological":
        result = head_and_topological_inference(args, models[1], models[0],
                                                stories, tokenizer)
    elif "topological" in method:
        result = topological_inference(args, models[0], stories, tokenizer)
    elif method in ("head_and_pairwise", "head_and_pairwise_abductive"):
        abductive = (models[2] if method == "head_and_pairwise_abductive"
                     else None)
        result = head_and_sequential_inference(args, models[1], models[0],
                                               stories, tokenizer,
                                               abductive_model=abductive)
    elif method == "pure_classification":
        result = pure_class_inference(args, models[0], stories, tokenizer,
                                      id2label)
    elif method == "pure_decode":
        result = pure_decode_inference(args, models[0], stories, tokenizer)
    else:
        raise NotImplementedError("Sort method {} not "
                                  "implemented yet.".format(args.sort_method))

    preds, loss = result
    return preds, labels, loss


def evaluate(args, models, tokenizer, prefix=""):
    """Evaluate story ordering on every task in ``args.task_names``.

    For each task, the examples of ``args.data_phase`` ("val" -> dev,
    "train" -> train, anything else -> test) are wrapped in ``SortDatasetV1``.
    They are run batch by batch through ``model_wise_evaluate`` and scored with
    every metric in ``args.metrics``. All results are logged and written to
    ``<args.output_dir>/eval_results.txt``.

    Args:
        args: parsed command-line namespace. ``args.eval_batch_size`` is set
            here as a side effect.
        models: list of models indexed as ``model_wise_evaluate`` expects. With
            more than one GPU, its entries are wrapped in
            ``torch.nn.DataParallel`` in place, except for ``pure_decode``.
        tokenizer: tokenizer shared by all models.
        prefix: unused; kept for interface compatibility.

    Returns:
        dict mapping ``"<task>_accuracy on <metric>"`` to the accuracy rounded
        to 6 decimals.
    """
    eval_outputs_dir = args.output_dir

    results = {}

    for eval_task in args.task_names:
        # .get so that an unknown task is logged and skipped, as the None check
        # intends, instead of raising KeyError.
        task_proc_class = processors.get(eval_task)
        if task_proc_class is None:
            logger.error("No processor for task: {}".format(eval_task))
            continue

        task_proc = task_proc_class(max_story_length=args.max_story_length)

        if args.data_phase == "val":
            eval_examples = task_proc.get_dev_examples()
        elif args.data_phase == "train":
            eval_examples = task_proc.get_train_examples()
        else:
            eval_examples = task_proc.get_test_examples()

        eval_dataset = SortDatasetV1(eval_examples, tokenizer,
                                     max_length=args.max_seq_length,
                                     max_story_length=args.max_story_length,
                                     seed=args.seed)

        id2label = None
        if args.sort_method == "pure_classification":
            # Only needed for its label mapping.
            dummy_dataset = PureClassDataset(eval_examples, tokenizer,
                                             max_length=args.max_seq_length,
                                             max_story_length=args.max_story_length,
                                             seed=args.seed)
            id2label = dummy_dataset.id2label

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Evaluation walks the dataset in order in a single process.
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        # Multi-GPU eval. The encoder-decoder path calls `.generate()` and reads
        # `.config` on the model, which torch.nn.DataParallel does not proxy,
        # so that model stays unwrapped.
        if (args.n_gpu > 1 and args.sort_method != "pure_decode"
                and not isinstance(models[0], torch.nn.DataParallel)):
            for i in range(len(models)):
                models[i] = torch.nn.DataParallel(models[i])

        # Eval!
        logger.info("***** Running evaluation on {} *****".format(eval_task))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        preds = []
        labels = []
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            # The per-batch loss is a placeholder (always 0), so it is ignored.
            b_preds, b_labels, _ = model_wise_evaluate(args, models, batch,
                                                       tokenizer,
                                                       id2label=id2label)
            preds += b_preds
            labels += b_labels

        for metrics in args.metrics:
            acc = round(compute_metrics(args, metrics, preds, labels), 6)
            results[eval_task + "_accuracy on {}".format(metrics)] = acc

    if eval_outputs_dir:
        os.makedirs(eval_outputs_dir, exist_ok=True)
    output_eval_file = os.path.join(eval_outputs_dir, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval Results *****")
        for key in sorted(results.keys()):
            logger.info("  %s = %s", key, str(results[key]))
            writer.write("%s = %s\n" % (key, str(results[key])))

    return results


def _build_arg_parser():
    """Build the command-line parser for the sorting-task evaluation script."""
    parser = argparse.ArgumentParser()

    model_path_help = ("Path to pretrained model or model identifier from "
                       "huggingface.co/models")

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir.",
    )
    # Up to three models can be supplied; only the first one is mandatory.
    for idx in (1, 2, 3):
        parser.add_argument(
            "--model_name_or_path_{}".format(idx),
            default=None,
            type=str,
            required=(idx == 1),
            help=model_path_help,
        )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=False,
        help=("The output directory where the results will be stored."),
    )
    parser.add_argument(
        "--max_story_length",
        default=5,
        type=int,
        required=False,
        help=("The maximum length of the story sequence."),
    )
    parser.add_argument(
        "--per_seq_max_length",
        default=64,
        type=int,
        required=False,
        help=("The maximum length of EACH of the story sequence."),
    )
    parser.add_argument(
        "--sort_method",
        default=None,
        type=str,
        choices=[
            "topological", "topological_sort", "head_and_topological",
            "head_and_pairwise", "head_and_pairwise_abductive",
            "pure_classification", "pure_decode",
        ],
        required=True,
        help=("The method for predicting the sorted sequence."),
    )
    parser.add_argument(
        "--abd_pred_method",
        default="binary",
        type=str,
        required=False,
        choices=["binary", "contrastive"],
        help=("The prediction method, see datasets/roc.py"),
    )
    parser.add_argument(
        "--data_phase",
        default="val",
        choices=["train", "val", "test"],
        type=str,
        required=False,
        help=("The phase of evaluation data loading."),
    )
    parser.add_argument(
        "--task_names",
        default=None,
        nargs="+",
        type=str,
        required=True,
        help=("The list of task names."),
    )
    parser.add_argument(
        "--metrics",
        # With nargs="+" the parsed value is a list, so the default must be a
        # list too: a bare string would be iterated character by character
        # by the evaluation loop (`for metrics in args.metrics`).
        default=["partial_match"],
        nargs="+",
        choices=["partial_match", "exact_match", "distance_based",
                 "longest_common_subsequence", "lcs",
                 "ms", "wms", "others"],
        type=str,
        required=False,
        help=("The metrics for evaluating the sorted sequence."),
    )
    parser.add_argument(
        "--replace_token_type_embeddings",
        action="store_true",
        help="If replace the pretrained token type embedding with new one.",
    )

    # Other parameters
    for idx in (1, 2, 3):
        parser.add_argument(
            "--config_name_{}".format(idx),
            default="",
            type=str,
            help=("Pretrained config name "
                  "or path if not the same as model_name"),
        )
        parser.add_argument(
            "--tokenizer_name_{}".format(idx),
            default="",
            type=str,
            help=("Pretrained tokenizer name or path if not the same as "
                  "model_name"),
        )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. "
             "Sequences longer than this will be truncated, sequences "
             "shorter will be padded.",
    )
    parser.add_argument(
        "--do_lower_case", action="store_true",
        help="Set this flag if you are using an uncased model."
    )

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=1, type=int,
        help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument("--logging_steps", type=int, default=500,
                        help="Log every X updates steps.")
    parser.add_argument("--no_cuda", action="store_true",
                        help="Avoid using CUDA when available")
    parser.add_argument("--seed", type=int, default=42,
                        help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision "
             "(through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in "
             "['O0', 'O1', 'O2', and 'O3']. "
             "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="",
                        help="For distant debugging.")
    parser.add_argument("--server_port", type=str, default="",
                        help="For distant debugging.")
    return parser


def _setup_device(args):
    """Pick the torch device, storing it and the GPU count on ``args``.

    Sets ``args.device`` and ``args.n_gpu`` and returns the device.
    """
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs; one GPU per process.
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device
    return device


def _load_models(args):
    """Load the model list and tokenizer required by ``args.sort_method``.

    Returns:
        (models, tokenizer): ``models`` is the list passed to ``evaluate``;
        its contents/order depend on the sort method.

    Raises:
        NotImplementedError: for a sort method without a loading recipe.
    """
    method = args.sort_method
    # The number/order of values returned by get_models depends on the sort
    # method; configs are returned too but are not needed here.
    if method == "head_and_topological":
        pairwise_model, head_model, tokenizer, _, _ = get_models(args)
        models = [pairwise_model, head_model]
    elif method in ("topological", "topological_sort"):
        pairwise_model, tokenizer, _ = get_models(args)
        models = [pairwise_model]
    elif method == "head_and_pairwise":
        pairwise_model, head_model, tokenizer, _, _ = get_models(args)
        models = [pairwise_model, head_model]
    elif method == "head_and_pairwise_abductive":
        (pairwise_model, head_model, abductive_model, tokenizer,
         _, _, _) = get_models(args)
        models = [pairwise_model, head_model, abductive_model]
    elif method == "pure_classification":
        pure_class_model, tokenizer, _ = get_models(args)
        models = [pure_class_model]
    elif method == "pure_decode":
        encoder_decoder_model, tokenizer, _ = get_models(args)
        models = [encoder_decoder_model]
    else:
        raise NotImplementedError("Sort method {} not "
                                  "implemented yet.".format(method))
    return models, tokenizer


def main():
    """Parse CLI arguments, load the sorting model(s) and run evaluation.

    Returns:
        dict: The results dict produced by ``evaluate``. It is empty on
        distributed ranks other than 0 (``local_rank`` not in [-1, 0]),
        which skip evaluation.
    """
    args = _build_arg_parser().parse_args()

    # --output_dir is optional (default None); os.path.exists(None) would
    # raise TypeError, so only create the directory when one was given.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        # The port is parsed as a string by argparse; the socket address
        # tuple needs an integer port.
        ptvsd.enable_attach(address=(args.server_ip, int(args.server_port)),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    device = _setup_device(args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed "
        "training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)

    models, tokenizer = _load_models(args)

    logger.info("Training/evaluation parameters %s", args)

    # Evaluation (only on the main process)
    results = {}
    if args.local_rank in [-1, 0]:
        results.update(evaluate(args, models, tokenizer, prefix=""))

    return results


if __name__ == "__main__":
    # Script entry point: evaluate only when executed directly, not on import.
    main()
