import torch
from torch._C import dtype
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from argparse import ArgumentParser
import os
import pprint
import jsonlines
from tqdm import tqdm
from collections import defaultdict
import random
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
import pytorch_lightning as pl
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from typing import *
from utils import Foobar_pruning
from functools import reduce
import copy

# Relations the experiments are restricted to.
ConceptNetRelations = [
    'AtLocation',
    'CapableOf',
    'Causes',
    'CausesDesire',
    'Desires',
    'HasA',
    'HasPrerequisite',
    'HasProperty',
    'HasSubevent',
    'IsA',
    'MadeOf',
    'MotivatedByGoal',
    'NotDesires',
    'PartOf',
    'ReceivesAction',
    'UsedFor',
]

# Integer id -> relation name; the position in the tuple below is the id.
id2relation = dict(enumerate((
    'HasSubevent',
    'MadeOf',
    'HasPrerequisite',
    'MotivatedByGoal',
    'AtLocation',
    'CausesDesire',
    'IsA',
    'NotDesires',
    'Desires',
    'CapableOf',
    'PartOf',
    'HasA',
    'UsedFor',
    'ReceivesAction',
    'Causes',
    'HasProperty',
)))

# Inverse lookup: relation name -> integer id.
relation2id = {name: index for index, name in id2relation.items()}


def load_masks(model_name: str, bli: int, tli: int, relations: List, init_method: str,
               mask_dir: str = "/home1/roy/commonsense/LAMA/masks"):
    """
    Load one pre-computed pruning mask per relation from disk.

    The file name of a mask is built from the model name, the relation, the
    number of masked matrices (six per layer in ``[bli, tli]``), the layer
    range itself and the initialisation method.

    Args:
        model_name: model identifier used as the file-name prefix.
        bli: index of the first (bottom) masked layer.
        tli: index of the last (top) masked layer, inclusive.
        relations: relation names to load masks for.
        init_method: initialisation tag that appears in the file name.
        mask_dir: directory holding the mask files; defaults to the location
            that used to be hard-coded here.

    Returns:
        A list with the loaded mask objects, in the order of ``relations``.
    """
    masks = []
    for relation in relations:
        file_name = "{}_{}_{}_{}_{}_init>{}.pickle".format(
            model_name, relation, (tli - bli + 1) * 6, bli, tli, init_method)
        # Plain concatenation (not os.path.join) so the default directory
        # yields exactly the path string the function always produced.
        mask_pth = "{}/{}".format(mask_dir.rstrip("/"), file_name)
        # NOTE(review): torch.load unpickles the file; only point this at
        # mask files from a trusted source.
        with open(mask_pth, mode='rb') as f:
            masks.append(torch.load(f))
        print("Loading mask from {}".format(mask_pth))
    return masks


def union_masks(*masks):
    """
    Combine several soft masks into one hard mask by element-wise logical OR.

    Every argument is a sequence of ``torch.nn.Parameter`` matrices holding
    mask logits. Each matrix is binarised (an entry is kept when
    ``sigmoid(logit) > 0.5``) and matrices at the same position across all
    arguments are OR-ed together. The number of pruned (zero) entries, the
    total number of entries and their ratio are printed.

    Args:
        *masks: one sequence of logit matrices per mask; all sequences are
            zipped, so they are expected to be aligned position by position.

    Returns:
        A list of boolean tensors, one per matrix position. The list is empty
        when no masks are given.
    """
    thresholded_masks = []
    for mask in masks:
        assert isinstance(mask[0], torch.nn.Parameter)
        # `.data` keeps the thresholding out of the autograd graph.
        thresholded_masks.append(
            [torch.sigmoid(matrix.data) > 0.5 for matrix in mask])

    final_masks = [
        reduce(torch.logical_or, same_position)
        for same_position in zip(*thresholded_masks)
    ]

    cnt = sum(mask.nelement() for mask in final_masks)
    num_0_all = sum((~mask).sum().item() for mask in final_masks)
    print(num_0_all)
    print(cnt)
    # Without any mask there are no entries; skip the ratio instead of
    # dividing by zero.
    if cnt:
        print(num_0_all / cnt)
    return final_masks


def _collect_prunable_weights(layers, bli, tli):
    """Return ``(module, 'weight')`` pairs for the six linear maps of each layer in ``[bli, tli]``."""
    collected = []
    for i in range(bli, tli + 1):
        layer = layers[i]
        try:
            # BERT / RoBERTa style attribute layout.
            modules = (
                layer.attention.self.query,
                layer.attention.self.key,
                layer.attention.self.value,
                layer.attention.output.dense,
                layer.intermediate.dense,
                layer.output.dense,
            )
        except AttributeError:
            # DistilBERT style attribute layout.
            modules = (
                layer.attention.q_lin,
                layer.attention.k_lin,
                layer.attention.v_lin,
                layer.attention.out_lin,
                layer.ffn.lin1,
                layer.ffn.lin2,
            )
        collected.extend((module, 'weight') for module in modules)
    return collected


def apply_masks(model, model_name, bli, tli, masks):
    """
    Permanently apply pre-computed pruning masks to a transformer model.

    The encoder layers are looked up under ``model.roberta``,
    ``model.distilbert`` or ``model.bert`` depending on ``model_name``. For
    every layer in ``[bli, tli]`` six weight matrices are masked (query, key,
    value, attention output and the two feed-forward projections), in that
    order, so ``masks`` must contain six entries per layer.

    Args:
        model: model exposing the backbone attribute described above.
        model_name: model identifier; selects the backbone attribute.
        bli: index of the first (bottom) layer to mask.
        tli: index of the last (top) layer to mask, inclusive.
        masks: one mask per selected weight matrix, ordered layer by layer.

    Raises:
        AssertionError: if the number of masks differs from the number of
            selected weight matrices.
    """
    if 'roberta' in model_name:
        layers = model.roberta.encoder.layer
    elif 'distil' in model_name:
        layers = model.distilbert.transformer.layer
    else:
        layers = model.bert.encoder.layer

    parameters_tobe_pruned = _collect_prunable_weights(layers, bli, tli)
    assert len(masks) == len(parameters_tobe_pruned), \
        f"{len(parameters_tobe_pruned)} != {len(masks)}"

    for mask, (module, name) in zip(masks, parameters_tobe_pruned):
        prune.custom_from_mask(module, name, mask)
        # Drop the reparametrization right away so the masked weight becomes
        # the module's ordinary parameter.
        prune.remove(module, name)
    print("Pre-computed mask applied to {}".format(model_name))


class Relation2Sentence:
    """
    Lookup from a relation name to its natural language templates.

    A template contains the placeholders ``[subj]`` and ``[obj]``. On
    construction the hand-written templates are installed;
    ``collect_templates`` can alternatively mine templates from the LAMA
    file at ``lama_path``.
    """
    lama_path = "./data/ConceptNet/test.jsonl"

    def __init__(self) -> None:
        super().__init__()
        # `set` rather than a lambda as factory so the object stays picklable.
        self.relation2sentences = defaultdict(set)
        self.collect_template_manual()

    def collect_templates(self):
        """
        Mine templates from the jsonlines file at ``lama_path``.

        For each record the subject string is replaced by ``[subj]`` and
        ``[MASK]`` by ``[obj]`` in the first masked sentence; records where
        either placeholder does not end up in the sentence are skipped.
        Per-relation and total template counts are printed and the total is
        stored in ``self.cnt``.
        """
        self.cnt = 0
        # Context manager so the file is closed even if parsing fails.
        with open(self.lama_path, mode='r', encoding='utf-8') as file:
            for instance in tqdm(jsonlines.Reader(file)):
                relation = instance['pred']
                subj = instance['sub']
                masked_sentence = instance['masked_sentences'][0]
                sentence_template = masked_sentence.replace(subj, '[subj]')
                if '[subj]' not in sentence_template:
                    continue
                sentence_template = sentence_template.replace(
                    '[MASK]', '[obj]')
                if '[obj]' not in sentence_template:
                    continue
                self.relation2sentences[relation].add(sentence_template)
        for rel, templates in self.relation2sentences.items():
            print("{} templates for relation {}".format(len(templates), rel))
            self.cnt += len(templates)
        print("Templates collected")
        print("Total num: {}".format(self.cnt))

    def collect_template_manual(self):
        """
        Natural language templates for each commonsense relation in ConceptNet(LAMA)
        """
        self.relation2sentences.update({
            'AtLocation': ["Something you find at [obj] is [subj]."],
            'CapableOf': ["[subj] can [obj]."],
            'Causes': ["[subj] causes [obj]."],
            'CausesDesire': ["[subj] would make you want to [obj]."],
            'Desires': ["[subj] wants [obj]."],
            'HasA': ["[subj] contains [obj]."],
            'HasPrerequisite': ["[subj] requires [obj]."],
            'HasProperty': ["[subj] can be [obj]."],
            'HasSubevent': ["when [subj], [obj]."],
            'IsA': ["[subj] is a [obj]."],
            'MadeOf': ["[subj] can be made of [obj]."],
            'MotivatedByGoal': ["you would [subj] because [obj]."],
            'NotDesires': ["[subj] does not want [obj]."],
            'PartOf': ["[subj] is part of [obj]."],
            'ReceivesAction': ["[subj] can be [obj]."],
            'UsedFor': ["[subj] may be used for [obj]."],
        })
        print("Finish collecting templates for commonsense relations")

    def __getitem__(self, key):
        """Return the templates of relation ``key`` (an empty set is created for unknown relations)."""
        return self.relation2sentences[key]


class SplitDataset(Dataset):
    """
    One split (train/dev1/dev2/dev_total/test) whose structured triples have
    been verbalised into natural language sentences.

    Every item is the list
    ``[sentence, head, tail, relation, template, label]``.
    """

    def __init__(self, split: str, triples: list, transformation: dict) -> None:
        super().__init__()
        self.split = split
        self.triples = triples
        self.transformation = transformation
        self.datas = []
        self.transform()
        assert len(self.datas) > 0

    def transform(self):
        """
        Fill ``self.datas`` by rendering each triple with a randomly drawn
        template of its relation.

        Training triples always receive label 1 (their fourth field is
        ignored); for the other splits the fourth field is the label and is
        converted to ``int`` when it is a string.
        """
        is_train = self.split == "train"
        for triple in self.triples:
            relation, head, tail, raw_label = triple
            template = random.sample(self.transformation[relation], 1)[0]
            with_subject = template.replace('[subj]', head)
            sentence = with_subject.replace('[obj]', tail)
            if is_train:
                label = 1
            elif isinstance(raw_label, str):
                label = int(raw_label)
            else:
                label = raw_label
            self.datas.append(
                [sentence, head, tail, relation, template, label])
        assert len(self.datas) == len(self.triples)
        print("transformation of {} triples finished".format(len(self.datas)))

    def __len__(self) -> int:
        """Number of verbalised triples."""
        return len(self.datas)

    def __getitem__(self, index: int):
        """Return the item stored at ``index``."""
        return self.datas[index]


class ConceptNet100kDataset:
    """
    Container for all splits of the triple classification data.

    Loads the tab-separated triple files, drops triples whose relation is not
    listed in ``ConceptNetRelations`` and wraps every split in a
    ``SplitDataset``. The templates come from the module-level name
    ``relation2sentence``, which must be defined before this class is
    instantiated.
    """

    # attribute prefix and printed name of every split, in reporting order
    _SPLITS = ('train', 'dev1', 'dev2', 'dev_total', 'test')

    def __init__(self, train_path: str, dev1_path: str, dev2_path: str, dev_total_path: str, test_path: str) -> None:
        super().__init__()
        self.train_path = train_path
        self.dev1_path = dev1_path
        self.dev2_path = dev2_path
        self.dev_total_path = dev_total_path
        self.test_path = test_path
        self.train_triples = self.load_triples(self.train_path)
        self.dev1_triples = self.load_triples(self.dev1_path)
        self.dev2_triples = self.load_triples(self.dev2_path)
        self.dev_total_triples = self.load_triples(self.dev_total_path)
        self.test_triples = self.load_triples(self.test_path)
        self._report_sizes('before')

        # preprocessing step
        self.preprocess()

        # pytorch dataset
        self.train_dataset = SplitDataset(
            'train', self.train_triples, relation2sentence)
        self.dev1_dataset = SplitDataset(
            'dev1', self.dev1_triples, relation2sentence)
        self.dev2_dataset = SplitDataset(
            'dev2', self.dev2_triples, relation2sentence)
        self.dev_total_dataset = SplitDataset(
            'dev_total', self.dev_total_triples, relation2sentence)
        self.test_dataset = SplitDataset(
            'test', self.test_triples, relation2sentence)

    def _report_sizes(self, stage: str) -> None:
        """Print the triple count of every split; ``stage`` is 'before' or 'after'."""
        for split in self._SPLITS:
            triples = getattr(self, split + '_triples')
            print("Number of triples in {} set {} cleaning: {}".format(
                split, stage, len(triples)))

    def load_triples(self, file_path: str):
        """
        Read a tab-separated file into a list of field lists.

        Each line is stripped of surrounding whitespace and split on tabs.

        Raises:
            AssertionError: if ``file_path`` does not exist.
        """
        assert os.path.exists(file_path)
        # Context manager so the handle is closed once the file is read.
        with open(file_path, mode='r', encoding='utf-8') as handle:
            return [line.strip().split("\t") for line in handle]

    def preprocess(self):
        """
        Keep only triples whose relation (first field) is in
        ``ConceptNetRelations``. The split lists are filtered in place, then
        the resulting sizes are printed.
        """
        for split in self._SPLITS:
            triples = getattr(self, split + '_triples')
            triples[:] = [
                triple for triple in triples
                if triple[0] in ConceptNetRelations
            ]
        self._report_sizes('after')


class ConceptNetCollator:
    """
    Collate function that groups a mini-batch by relation and tokenizes it.

    With ``split == 'train'`` negative examples are created inside the
    mini-batch by swapping in the head, the tail and (optionally) the
    template of another instance of the same batch.
    """

    def __init__(self, model_name: str, split: str, sample_relation: bool = False) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.split = split
        self.sample_relation = sample_relation

    @staticmethod
    def _add_example(batch_dict, relation, sentence, label):
        """Append one (sentence, label) pair to the bucket of ``relation``."""
        if relation not in batch_dict:
            batch_dict[relation] = {'sentence': [], 'label': []}
        batch_dict[relation]['sentence'].append(sentence)
        batch_dict[relation]['label'].append(label)

    @staticmethod
    def _sample_other_index(current, size):
        """Draw a random index in ``[0, size)`` different from ``current``; needs ``size > 1``."""
        candidate = current
        while candidate == current:
            candidate = random.choice(range(0, size))
        return candidate

    def _encode(self, batch_dict):
        """Tokenize every relation bucket and turn its labels into a float tensor."""
        all_relations = list(batch_dict.keys())
        sentences = []
        labels = []
        for relation in batch_dict:
            encoded_sentences = self.tokenizer(
                batch_dict[relation]['sentence'], padding='longest', truncation=False, return_tensors='pt', return_attention_mask=True)
            encoded_labels = torch.tensor(batch_dict[relation]['label']).float()
            assert encoded_sentences['input_ids'].size(
                0) == encoded_labels.size(0)
            sentences.append(encoded_sentences)
            labels.append(encoded_labels)
        return sentences, labels, all_relations

    def __call__(self, batch_list):
        """
        batch_list: List of
            [sentence: str, head: str, tail: str, relation: str,
             template: str, label: int]

        Returns:
            ``(sentences, labels, relations)``: three parallel lists with one
            entry per relation occurring in the batch — the tokenizer output
            for that relation's sentences, the matching float label tensor
            and the relation name.
        """
        is_train = self.split == 'train'
        batch_size = len(batch_list)
        # Negatives are borrowed from *another* instance of the batch. A
        # batch with a single instance has no such instance, and the
        # rejection sampling would never terminate, so it only contributes
        # its positive example.
        corrupt = is_train and batch_size > 1

        batch_dict = dict()
        for i, instance in enumerate(batch_list):
            sentence, head, tail, relation, template, label = instance
            if is_train:
                assert label == 1
            self._add_example(batch_dict, relation, sentence, label)
            if not corrupt:
                continue

            # replace head
            head_rep = batch_list[self._sample_other_index(i, batch_size)][1]
            self._add_example(
                batch_dict, relation,
                template.replace('[subj]', head_rep).replace('[obj]', tail), 0)

            # replace tail
            tail_rep = batch_list[self._sample_other_index(i, batch_size)][2]
            self._add_example(
                batch_dict, relation,
                template.replace('[subj]', head).replace('[obj]', tail_rep), 0)

            if self.sample_relation:
                # replace relation: render this triple with the template of
                # another instance and file it under that instance's relation
                other = batch_list[self._sample_other_index(i, batch_size)]
                self._add_example(
                    batch_dict, other[3],
                    other[4].replace('[subj]', head).replace('[obj]', tail), 0)
        return self._encode(batch_dict)


class TripleScorer(pl.LightningModule):
    """
    Commonsense knowledge triple scorer.

    A pretrained transformer encodes the verbalised triple and an MLP head
    maps the hidden state of the first token to a score in (0, 1). When
    ``apply_mask`` is set, the encoder weights registered in
    ``parameters_tobe_pruned`` are masked with the mask of the current
    relation before each forward pass and restored right after it.
    """

    def __init__(self, model_name: str, num_of_relations: int, hidden_size: int, lr: float, bli: int, tli: int, relation2id: dict, apply_mask: bool, num_steps: int):
        super().__init__()
        self.save_hyperparameters()
        self.model_name = model_name
        self.bli = bli
        self.tli = tli
        self.hidden_size = hidden_size
        self.apply_mask = apply_mask
        self.num_steps = num_steps
        self.num_of_relations = num_of_relations
        self.relation2id = relation2id
        self.pretrained_language_model = AutoModel.from_pretrained(
            self.model_name, return_dict=True)
        # Snapshot of the freshly loaded weights; `restore` reloads it.
        self.init_state_dict = copy.deepcopy(
            self.pretrained_language_model.state_dict())

        self.parameters_tobe_pruned = []
        self.get_parameters_tobe_pruned(self.bli, self.tli)

        # One independent slot per relation (``[[]] * n`` would make every
        # slot the same list object). Filled from outside with the masks.
        self.pruning_mask_generators = [[] for _ in ConceptNetRelations]

        # MLP classifier; a single head is shared by all relations.
        self.classifiers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.pretrained_language_model.config.hidden_size,
                          self.hparams.hidden_size),
                nn.ReLU(),
                nn.Linear(self.hparams.hidden_size, 1),
                nn.Sigmoid(),
            )
        ])

        # loss function
        self.loss_fn = nn.BCELoss()

    @staticmethod
    def _layer_linear_modules(layer):
        """Return the six linear modules of one encoder layer (BERT-style layout first, DistilBERT-style as fallback)."""
        try:
            return [
                layer.attention.self.query,
                layer.attention.self.key,
                layer.attention.self.value,
                layer.attention.output.dense,
                layer.intermediate.dense,
                layer.output.dense,
            ]
        except AttributeError:
            return [
                layer.attention.q_lin,
                layer.attention.k_lin,
                layer.attention.v_lin,
                layer.attention.out_lin,
                layer.ffn.lin1,
                layer.ffn.lin2,
            ]

    def get_parameters_tobe_pruned(self, bli, tli):
        """
        Register the ``(module, 'weight')`` pairs that pruning acts on.

        Does nothing when the pairs were already collected. For model names
        containing 'albert' the six matrices of the first layer of the first
        layer group are used and ``bli``/``tli`` are ignored; otherwise the
        six matrices of every layer in ``[bli, tli]`` are used.
        """
        if len(self.parameters_tobe_pruned) > 0:
            return
        backbone = self.pretrained_language_model
        if 'albert' in self.model_name:
            shared_layer = backbone.encoder.albert_layer_groups[
                0].albert_layers[0]
            modules = [
                shared_layer.attention.query,
                shared_layer.attention.key,
                shared_layer.attention.value,
                shared_layer.attention.dense,
                shared_layer.ffn,
                shared_layer.ffn_output,
            ]
        else:
            # 'roberta' is tested before 'distil' on purpose, so a name
            # containing both keeps resolving to `encoder.layer`.
            if 'roberta' in self.model_name:
                layers = backbone.encoder.layer
            elif 'distil' in self.model_name:
                layers = backbone.transformer.layer
            else:
                layers = backbone.encoder.layer
            modules = []
            for i in range(bli, tli + 1):
                modules.extend(self._layer_linear_modules(layers[i]))
        self.parameters_tobe_pruned = tuple(
            (module, 'weight') for module in modules)

    def prune(self, pruning_masks):
        """Apply ``pruning_masks`` position by position to the registered weights via ``Foobar_pruning``."""
        for pruning_mask, (module, name) in zip(pruning_masks, self.parameters_tobe_pruned):
            Foobar_pruning(module, name, pruning_mask)

    def restore(self):
        """
        Remove the pruning from every registered weight and reload the
        weight snapshot taken at construction time.

        NOTE(review): the snapshot covers the whole backbone, so this resets
        all backbone weights to their construction-time values — confirm
        that this is intended while the backbone is being optimised.
        """
        for module, name in self.parameters_tobe_pruned:
            prune.remove(module, name)
        self.pretrained_language_model.load_state_dict(self.init_state_dict)

    def _hard_masks(self, relation_id):
        """Binarise the stored masks of ``relation_id``: 1 where ``sigmoid(mask) > 0.5``, else 0."""
        # Computed on detached values: the thresholded mask carries no
        # gradient information.
        return [
            (torch.sigmoid(mask.detach()) > 0.5).to(mask.dtype)
            for mask in self.pruning_mask_generators[relation_id]
        ]

    @staticmethod
    def _count_correct(logits, label):
        """Count predictions (score > 0.5 means positive) that equal ``label``."""
        # Flatten both sides: comparing (N, 1) scores with (N,) labels
        # directly would broadcast to an (N, N) matrix.
        pred = (logits.detach().reshape(-1) > 0.5).to(label.dtype)
        return (pred == label.reshape(-1)).sum().item()

    def _accuracy_over_batch(self, batch):
        """Run every relation group of an evaluation batch and return the accuracy over all its sentences."""
        sentences, labels, relations = batch
        total_correct_predictions = 0
        cnt = 0
        for i, relation in enumerate(relations):
            cnt += sentences[i]['input_ids'].size(0)
            relation_id = self.relation2id[relation]
            if self.apply_mask:
                self.prune(pruning_masks=self._hard_masks(relation_id))
            logits = self(sentences[i], None, relation_id)
            total_correct_predictions += self._count_correct(
                logits, labels[i])
            if self.apply_mask:
                self.restore()
        return total_correct_predictions / cnt

    def forward(self, input_dict, label, relation_id):
        """
        Score a group of tokenized sentences.

        Args:
            input_dict: keyword arguments for the backbone model.
            label: float targets, or ``None`` to skip the loss.
            relation_id: currently unused; the same head serves all relations.

        Returns:
            ``(loss, logits)`` when ``label`` is given, otherwise ``logits``;
            ``logits`` has shape (batch_size, 1).
        """
        outputs = self.pretrained_language_model(**input_dict)
        # (batch_size, hidden_size): hidden state of the first token
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        logits = self.classifiers[0](cls_hidden_state)  # (batch_size, 1)
        if label is None:
            return logits
        # BCELoss expects input and target of identical shape.
        loss = self.loss_fn(logits, label.reshape(logits.shape))
        return loss, logits

    def configure_optimizers(self):
        """
        Optimise the backbone together with the classifier head.

        Backbone parameters get weight decay 0.01, except those whose name
        contains 'bias' or a LayerNorm parameter name, which get none. The
        learning rate warms up linearly during the first 20% of
        ``num_steps``.
        """
        # Prepare optimizer
        param_optimizer = list(self.pretrained_language_model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
            {'params': self.classifiers.parameters()}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.lr)
        scheduler = get_linear_schedule_with_warmup(optimizer, 0.2 * self.num_steps, self.num_steps)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_id: int):
        """
        batch: [sentences, labels, relations] from training set

        Returns the loss averaged over the relation groups of the batch and
        the training accuracy for the progress bar.
        """
        sentences, labels, relations = batch
        total_loss = torch.tensor(0, dtype=torch.float32).to(self.device)
        total_correct_predictions = 0
        cnt = 0
        for i, relation in enumerate(relations):
            cnt += sentences[i]['input_ids'].size(0)
            relation_id = self.relation2id[relation]
            if self.apply_mask:
                # pruning
                self.prune(pruning_masks=self._hard_masks(relation_id))
            # forward pass
            loss, logits = self(sentences[i], labels[i], relation_id)
            total_correct_predictions += self._count_correct(
                logits, labels[i])
            total_loss = total_loss + loss
            if self.apply_mask:
                # restore original parameters
                self.restore()
        avg_loss = total_loss / len(relations)
        train_acc = total_correct_predictions / cnt
        tqdm_dict = {'train_accuracy': train_acc}
        return {'loss': avg_loss, 'progress_bar': tqdm_dict}

    def validation_step(self, batch, batch_id: int):
        """
        batch from dev set
        """
        return {'val_acc': self._accuracy_over_batch(batch)}

    def validation_epoch_end(self, outputs: List[Any]):
        """Average the per-batch validation accuracies."""
        val_acces = [output['val_acc'] for output in outputs]
        validation_accuracy = sum(val_acces) / len(val_acces)
        tqdm_dict = {'val_accuracy': validation_accuracy}
        return {'val_acc': validation_accuracy, 'progress_bar': tqdm_dict}

    def test_step(self, batch, batch_id: int):
        """
        batch from test set
        """
        return {'test_acc': self._accuracy_over_batch(batch)}

    def test_epoch_end(self, outputs: List[Any]) -> None:
        """Average the per-batch test accuracies."""
        test_acces = [output['test_acc'] for output in outputs]
        test_accuracy = sum(test_acces) / len(test_acces)
        tqdm_dict = {'test_accuracy': test_accuracy}
        return {'progress_bar': tqdm_dict}


def main(args):
    """
    Train and test the triple scorer.

    Builds the datasets and dataloaders, creates the model, attaches the
    pre-computed pruning masks when ``args.apply_mask`` is set, trains with
    early stopping on ``val_acc`` and finally evaluates the best checkpoint
    on the test set.

    Args:
        args: parsed command line arguments (Trainer arguments plus the
            options declared in the script entry point).
    """
    pprint.pprint(vars(args))
    # seed
    seed_everything(args.seed)

    # data
    conceptnet_dataset = ConceptNet100kDataset(
        "./data/CKBC/train100k.txt",
        "./data/CKBC/dev1.txt",
        "./data/CKBC/dev2.txt",
        "./data/CKBC/dev_total.txt",
        "./data/CKBC/test.txt"
    )
    # train/dev/test set
    training_set = conceptnet_dataset.train_dataset
    dev1_set = conceptnet_dataset.dev1_dataset
    dev2_set = conceptnet_dataset.dev2_dataset
    dev_total_set = conceptnet_dataset.dev_total_dataset
    test_set = conceptnet_dataset.test_dataset

    # collator
    training_collator = ConceptNetCollator(
        args.model_name, 'train', args.sample_relation)
    other_collator = ConceptNetCollator(args.model_name, 'other')

    # train/dev/test dataloader
    training_dataloader = DataLoader(
        training_set, batch_size=args.batch_size, collate_fn=training_collator, shuffle=True, num_workers=2)
    dev1_dataloader = DataLoader(
        dev1_set, batch_size=args.batch_size, collate_fn=other_collator, shuffle=False)
    dev2_dataloader = DataLoader(
        dev2_set, batch_size=args.batch_size, collate_fn=other_collator, shuffle=False)
    dev_total_dataloader = DataLoader(
        dev_total_set, batch_size=args.batch_size, collate_fn=other_collator, shuffle=False)
    test_dataloader = DataLoader(
        test_set, batch_size=args.batch_size, collate_fn=other_collator, shuffle=False)

    # compute num steps
    num_steps = args.max_epochs * len(training_dataloader)

    # pl model
    pl_model = TripleScorer(args.model_name, len(
        relation2id), args.hidden_size, args.lr, args.bli, args.tli, relation2id, args.apply_mask, num_steps)

    # load precomputed masks; the model only reads them when masking is on,
    # so the files are not required otherwise
    if args.apply_mask:
        for relation, _id in pl_model.relation2id.items():
            mask_file_pth = "./masks/{}_{}_{}_{}_{}_init>normal.pickle".format(
                args.model_name, relation, 6*(args.tli-args.bli+1), args.bli, args.tli)
            assert os.path.exists(mask_file_pth), mask_file_pth
            # NOTE(review): torch.load unpickles the file; only use mask
            # files from a trusted source.
            with open(mask_file_pth, 'rb') as f:
                pl_model.pruning_mask_generators[_id] = torch.load(f)
    else:
        print("Not applying pruning masks")

    # callbacks
    early_stop_callback = EarlyStopping(
        monitor='val_acc', patience=3, mode='max')
    ckpt_callback = ModelCheckpoint(
        monitor='val_acc', mode='max', dirpath='./ckpt_ckbc/', filename="bert-{epoch:02d}-{val_acc:.2f}")

    # trainer
    trainer = Trainer.from_argparse_args(args, callbacks=[ckpt_callback,
                                                          early_stop_callback], logger=False, val_check_interval=1.0)

    # train
    trainer.fit(pl_model, train_dataloader=training_dataloader,
                val_dataloaders=dev_total_dataloader)

    # test
    loaded_model = TripleScorer.load_from_checkpoint(
        ckpt_callback.best_model_path)
    # The masks are a plain attribute assigned after construction, so a model
    # rebuilt from the checkpoint starts with empty mask lists; hand over the
    # ones loaded above.
    loaded_model.pruning_mask_generators = pl_model.pruning_mask_generators
    trainer.test(loaded_model, test_dataloader)


if __name__ == "__main__":
    # Command line entry point: Trainer options plus the script's own.
    parser = ArgumentParser("Commonsense Knowledge Base Completion as Triple Classification")
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument('--apply_mask', action='store_true', default=False)
    parser.add_argument('--bli', type=int, default=None)
    parser.add_argument('--tli', type=int, default=None)
    parser.add_argument('--init_method', type=str, default=None)
    parser.add_argument('--sample_relation',
                        action='store_true', default=False)
    parser.add_argument('--model_name', type=str, default='bert-base-uncased')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--lr', type=float, default=3e-4)
    # `required=True` made the default unreachable; the option is now
    # optional and falls back to half of 768.
    parser.add_argument('--hidden_size', type=int, default=768//2)
    parser.add_argument('--batch_size', type=int, default=32)
    args = parser.parse_args()
    # Module-level name read by ConceptNet100kDataset when it builds the
    # splits, so it has to exist before main() runs.
    relation2sentence = Relation2Sentence()
    main(args)
