from worldformer2.dataset.jerichoworld_dataset_conditional import JerichoWorldDataset
from worldformer2.dataset.training_dataset_rafbart_v2 import JerichoWorldTorchTrainDataset

from worldformer2.tools.text_utils import parse_actions_text
from worldformer2.tools.metrics import action_metrics
from worldformer2.tools.metrics import print_test_metrics
from worldformer2.tools.python_utils import split

from transformers.generation_utils import BeamSearchScorer
from transformers.generation_logits_process import LogitsProcessorList, MinLengthLogitsProcessor

from transformers import AutoTokenizer
from pathlib import Path
import numpy as np
import copy
import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
import torch
from torch import optim

from torch.cuda.amp import autocast, GradScaler

import torch.distributed as dist
import torch.multiprocessing as mp

#from transformers import T5Tokenizer, T5Model
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForMaskedLM, AutoModelForCausalLM
from torch.utils.data import DataLoader

from worldformer2.tokenization.custom_tokenizers import get_tokenizer
#from worldformer2.models.bart import BartModel
from worldformer2.models.rafbart_v2 import RafBartModel

import logging
from worldformer2.tools.logging_util import basic_logging
basic_logging()

import json
import ipdb
import wandb
from pprint import pprint

#from tmp import main_worker
import argparse
import torch
from torch import optim


import torch.distributed as dist
import torch.multiprocessing as mp
import ipdb


class JerichoWorldTrainer(object):
    """
    """

    def __init__(self, args):
        # Prepare tokenizer and dataset
        N_SHARDS=8
        #N_SHARDS=4

        self.args = args
        if self.args.mine_negatives:
            # These are the train roms
            test_roms = ['ballyhoo', 'omniquest', 'wishbringer', 'karn', 'yomomma', 'jewel', 'zork2', 'reverb', 'huntdark', 'zork3', 'acorncourt', 'enter', 'moonlit', 'night', 'dragon', 'loose', 'tryst205', '905', 'zenon', 'inhumane', 'snacktime', 'gold', 'murdac', 'weapon', 'afflicted', 'adventureland', 'enchanter']
        else:
            test_roms = ['zork1','library','detective','balances','pentari','ztuu','ludicorp','deephome','temple']
        self.test_roms_shard = self.args.test_roms_shard
        assert self.test_roms_shard < N_SHARDS
        if self.args.test_roms_shard == -1:
            self.test_roms = test_roms
        else:
            self.test_roms = list(split(test_roms, N_SHARDS))[self.test_roms_shard]
        #self.test_roms = ['zork1','library','detective','balances','pentari','ztuu','ludicorp','deephome','temple']
        #self.test_roms = ['zork1','library','detective','balances']
        #self.test_roms = ['deephome','pentari','ztuu','ludicorp','temple']
        #self.test_roms = ['temple']
        #self.test_roms = ['library']
        #self.test_roms = ['deephome']
        #self.test_roms = ['zork1']

        # Train mode
        if self.args.mode == 'train':
            self.register_model()
            self.model = self.model.cuda()                                                                
        
            #self.model = torch.nn.DataParallel(self.model).cuda()

            if not self.args.pretrained_ckpt is None:
                self.model.load_state_dict(torch.load(self.args.pretrained_ckpt), strict=False) # Only cls head should be incompatible
                #state = {k[23:] :v for k,v in torch.load(self.args.pretrained_ckpt).items()}
                #if self.args.init_mode in ['both', 'graph']:
                #    self.model.module.graph_encoder_decoder.load_state_dict(copy.deepcopy(state))
                #    logging.info(f"Loaded weights from {self.args.pretrained_ckpt} into graph bart")
                #if self.args.init_mode in ['both', 'text']:
                #    self.model.module.text_encoder_decoder.load_state_dict(copy.deepcopy(state))
                #    logging.info(f"Loaded weights from {self.args.pretrained_ckpt} into action bart")

            #self.load_checkpoint('/home/mnskim/workspace/tbg/tbg1/ckpts/models/worldformer/exp_newdata_1_notie/model_best.pt')

            self.register_optim()
            self.register_loss()

            if self.args.distributed:
                print(self.args.distributed, '## Training')
                
                
                self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.args.local_rank], 
                                                                    output_device=self.args.local_rank,
                                                                    find_unused_parameters=True)

            self.register_data()
            self.register_training_logger()
            if not self.args.wandb_project is None:
                self.register_wandb()

        elif self.args.mode == 'predict':
            self.register_model()
            #self.model = torch.nn.DataParallel(self.model).cuda()
            #self.model = torch.nn.DataParallel(self.model)
            #self.load_checkpoint(self.args.load_path)
            self.model = self.model.cuda()            
            #self.model = torch.nn.DataParallel(self.model)
            self.load_checkpoint(self.args.load_path)

            """
            to_load = {}
            # NOTE tmp hack
            tmp = torch.load(self.args.load_path)
            for k, v in tmp.items():
                to_load[k.replace('module.','')] = v
            #ipdb.set_trace()
            self.text_encoder.load_state_dict(to_load)
            #self.text_encoder = torch.nn.DataParallel(self.text_encoder).cuda()
            self.text_encoder = self.text_encoder.cuda()
            """

            self.register_data()
            #self.register_training_logger()


    def register_wandb(self):
        """Start a Weights & Biases run whose config is the CLI argument namespace."""
        run_args = self.args
        wandb.init(
            config=run_args,
            name=run_args.wandb_name,
            project=run_args.wandb_project,
        )

    def register_model(self):
        """Load the four tokenizers and instantiate ``RafBartModel``.

        Tokenizers are loaded for the text input, graph encoder, action decoder
        and graph decoder, each from its own path/added-tokens pair in
        ``self.args``. The model config is a deep copy of ``self.args``. Only
        the text tokenizer's vocabulary size is added to it, as
        ``input_text_n_vocab``.
        """
        args = self.args
        use_atomic = args.add_atomic_tokens

        logging.info(f"Loading tokenizer from {args.tokenizer_path}")
        self.tokenizer = get_tokenizer(args.tokenizer_path, args.added_tokens_path, use_atomic)
        self.graph_encoder_tokenizer = get_tokenizer(
            args.graph_encoder_tokenizer_path, args.graph_encoder_added_tokens_path, use_atomic)
        self.action_decoder_tokenizer = get_tokenizer(
            args.action_decoder_tokenizer_path, args.action_decoder_added_tokens_path, use_atomic)
        self.graph_decoder_tokenizer = get_tokenizer(
            args.graph_decoder_tokenizer_path, args.graph_decoder_added_tokens_path, use_atomic)

        model_config = copy.deepcopy(args)
        model_config.input_text_n_vocab = len(self.tokenizer)
        self.model = RafBartModel(model_config)


    def register_optim(self):

        self.optim = optim.Adam(self.model.parameters(), lr=self.args.lr)

    def register_loss(self):
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

    def register_data(self):
        """Load the raw splits and build the train/val/test torch datasets and loaders.

        The train file is flattened and 10% of it (seeded by ``args.seed``) is
        split off as validation; the test file is flattened separately. Each
        split is wrapped in ``JerichoWorldTorchTrainDataset`` with a cache file
        under ``args.cache_path``, then registered and pretokenized. The results
        are exposed as ``self.{train,val,test}_dataset_torch`` and
        ``self.{train,val,test}_dataloader``. Only the train loader shuffles,
        through a ``DistributedSampler`` when ``args.distributed`` is set. When
        ``args.mine_negatives`` is set, the train split is restricted to
        ``self.test_roms``.
        """
        args = self.args
        cache_dir = Path(args.cache_path)

        # Raw datasets are always rebuilt; cache reuse is decided by the torch
        # datasets below through ``refresh_cache``.
        self.train_dataset = JerichoWorldDataset(args.train_file)
        self.train_dataset.flatten()
        held_out = self.train_dataset.split_validation(0.1, seed=args.seed)

        self.val_dataset = JerichoWorldDataset(args.train_file)
        self.val_dataset.instances = held_out

        self.test_dataset = JerichoWorldDataset(args.test_file)
        self.test_dataset.flatten()

        cache_dir.mkdir(exist_ok=True, parents=True)

        def build_split(raw_dataset, cache_name, negative_mode):
            # Wrap one raw split, then register and pretokenize it (in that order).
            torch_dataset = JerichoWorldTorchTrainDataset(
                self.tokenizer,
                self.graph_encoder_tokenizer,
                self.action_decoder_tokenizer,
                self.graph_decoder_tokenizer,
                cache_dir / cache_name,
                refresh_cache=args.refresh_cache,
                save_text=False,
                negative_mode=negative_mode,
                skip_nohole=args.skip_nohole,
            )
            torch_dataset.register_examples(raw_dataset)
            torch_dataset.pretokenize_examples(refresh_cache=args.refresh_cache,
                                               add_ckg=args.add_ckg)
            return torch_dataset

        def build_loader(torch_dataset, batch_size, shuffle, sampler=None):
            # All splits share worker count and collate function; only
            # batch size, shuffling and sampler differ.
            return DataLoader(
                torch_dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=args.num_workers,
                sampler=sampler,
                batch_sampler=None,
                collate_fn=torch_dataset.collate_fn,
                pin_memory=False,
                worker_init_fn=None,
            )

        # Train
        self.train_dataset_torch = build_split(self.train_dataset, "train_cached.pth",
                                               args.train_negative_mode)
        if args.mine_negatives:
            logging.info('-- Negatives mining')
            self.train_dataset_torch.filter_roms(self.test_roms)

        train_sampler = None
        if args.distributed:
            print('## DIST sampler')
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                self.train_dataset_torch, shuffle=True)
        self.train_dataloader = build_loader(self.train_dataset_torch,
                                             args.train_batchsize,
                                             shuffle=train_sampler is None,
                                             sampler=train_sampler)

        # Validation
        self.val_dataset_torch = build_split(self.val_dataset, "val_cached.pth",
                                             args.val_negative_mode)
        self.val_dataloader = build_loader(self.val_dataset_torch, args.val_batchsize,
                                           shuffle=False)

        # Test
        self.test_dataset_torch = build_split(self.test_dataset, "test_cached.pth",
                                              args.test_negative_mode)
        self.test_dataloader = build_loader(self.test_dataset_torch, args.test_batchsize,
                                            shuffle=False)

    def model_forward(self, batch, call_module=False):
        """Move one collated batch to the GPU and run the model on it.

        Args:
            batch: collated batch dict; the keys in ``key_map`` below are read.
            call_module: if True, call ``self.model.module`` (the model inside a
                parallel wrapper) instead of ``self.model``.

        Returns:
            ``(output, action_decoder_input_ids, graph_decoder_input_ids)``.
            The two id tensors are the GPU copies that were fed to the model.
        """
        # (model keyword argument, batch key) pairs.
        key_map = (
            ('text_encoder_input_ids', 'input_text'),
            ('text_encoder_attention_mask', 'input_text_mask'),
            ('graph_encoder_input_ids', 'input_graph'),
            ('graph_encoder_attention_mask', 'input_graph_mask'),
            ('action_decoder_input_ids', 'output_valid_act'),
            ('action_decoder_attention_mask', 'output_valid_act_mask'),
            ('graph_decoder_input_ids', 'output_graph_diff'),
            ('graph_decoder_attention_mask', 'output_graph_diff_mask'),
            ('labels', 'binary_labels'),
        )
        model_inputs = {arg_name: batch[batch_key].cuda() for arg_name, batch_key in key_map}

        forward_fn = self.model.module if call_module else self.model
        output = forward_fn(**model_inputs)
        return (output,
                model_inputs['action_decoder_input_ids'],
                model_inputs['graph_decoder_input_ids'])

    def model_loss(self, batch, batch_losses_logging, acc=None):
        """Compute the (multi-task) loss for one batch.

        Which terms are included depends on ``args.task_mode``:

        * template classification loss: raf_multitask, action_only, action_multi
        * action generation loss (x ``lambda1``): raf_multitask, full_multitask,
          action_only, action_multi
        * graph generation loss (x ``lambda2``): raf_multitask, full_multitask,
          graph_only, graph_multi

        Generation losses only count non-padding target tokens of examples whose
        template label is positive and that have at least one hole. Template
        classification is never masked.

        Args:
            batch: collated batch dict (see ``model_forward``), plus 'n_holes'.
            batch_losses_logging: dict of lists; the scalar value of every
                computed loss term is appended under its name.
            acc: optional dict with 'n_correct' / 'n_incorrect' counters. In
                raf_multitask mode it is incremented (with Python ints) by the
                template classification hits and misses of this batch.

        Returns:
            The weighted sum of the computed loss terms (the int 0 if none applies).
        """
        output, action_decoder_input_ids, graph_decoder_input_ids = self.model_forward(batch)
        task_mode = self.args.task_mode

        # DataParallel / DistributedDataParallel do not forward attribute lookups
        # to the wrapped model, so submodules must be reached through ``.module``.
        base_model = self.model
        if isinstance(base_model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            base_model = base_model.module

        loss = 0

        # Per-example 0/1 flags used to mask the generation losses.
        # nonzero_holes_flag is (batch, 1); binary_labels presumably has the same
        # shape so both broadcast over the token axis — TODO confirm.
        pos_template_flag = batch['binary_labels'].cuda()
        nonzero_holes_flag = (torch.tensor(batch['n_holes']).unsqueeze(1) > 0).long().cuda()

        def shifted_targets(input_ids):
            # Teacher forcing: position t predicts token t+1. Padding tokens and
            # all tokens of masked-out examples become -100 (ignored by loss_fn).
            # NOTE(review): graph targets are also checked against
            # self.tokenizer's pad id, not graph_decoder_tokenizer's — confirm they agree.
            targets = input_ids[:, 1:].contiguous()
            keep = (targets != self.tokenizer.pad_token_id) * pos_template_flag * nonzero_holes_flag
            return targets * keep - (1 - keep.long()) * 100

        def token_loss(logits, targets):
            # Drop the last position's logits so they align with the shifted targets.
            preds = logits[:, :-1, :].contiguous()
            return self.loss_fn(preds.view(-1, preds.size(2)), targets.view(-1))

        if task_mode in ('raf_multitask', 'action_only', 'action_multi'):
            template_cls_loss = base_model.text_encoder_decoder.cls_loss(
                output['template_cls_logits'], pos_template_flag)
            loss = loss + template_cls_loss
            batch_losses_logging['template_cls_loss'].append(template_cls_loss.item())

        if task_mode in ('raf_multitask', 'full_multitask', 'action_only', 'action_multi'):
            action_loss = token_loss(output['action_logits'], shifted_targets(action_decoder_input_ids))
            loss = loss + action_loss * self.args.lambda1
            batch_losses_logging['action_loss'].append(action_loss.item())

        if task_mode in ('raf_multitask', 'full_multitask', 'graph_only', 'graph_multi'):
            graph_loss = token_loss(output['graph_logits'], shifted_targets(graph_decoder_input_ids))
            loss = loss + graph_loss * self.args.lambda2
            batch_losses_logging['graph_loss'].append(graph_loss.item())

        if task_mode == 'raf_multitask' and acc is not None:
            # argmax over logits equals argmax over their softmax; counts are
            # stored as Python ints so they can be logged and serialized directly.
            preds = torch.argmax(output['template_cls_logits'], 1)
            targets = pos_template_flag.view(-1)
            acc['n_correct'] += (targets == preds).sum().item()
            acc['n_incorrect'] += (targets != preds).sum().item()

        return loss

    def train(self, dataloader):
        """Train for a single epoch with gradient accumulation.

        Gradients are accumulated over ``args.n_accum`` batches before each
        optimizer step (through ``self.scaler`` when ``args.fp16``). The last
        group of the epoch may be shorter than ``n_accum``, and each batch's
        loss is divided by the size of the group it belongs to. After every
        optimizer step, the mean of each loss term over that group is appended
        to ``self.info['epoch_metrics'][curr_epoch]['train']``. If a wandb
        project is configured, the same values are also logged at step
        ``self.info['curr_iter']``, which is then incremented.
        """
        self.model.train()

        # DistributedSampler only reshuffles when told the epoch; otherwise every
        # epoch reuses the same order. Assumes self.info['curr_epoch'] is the
        # integer epoch counter maintained by the caller — TODO confirm.
        sampler = getattr(dataloader, 'sampler', None)
        if isinstance(sampler, torch.utils.data.distributed.DistributedSampler):
            sampler.set_epoch(self.info['curr_epoch'])

        n_accum = self.args.n_accum
        n_batches = len(dataloader)
        # First batch index of a trailing group shorter than n_accum
        # (equals n_batches when the epoch divides evenly).
        tail_start = (n_batches // n_accum) * n_accum
        tail_size = n_batches - tail_start

        def new_loss_log():
            return {'action_loss': [], 'graph_loss': [], 'template_cls_loss': []}

        batch_losses_logging = new_loss_log()

        for bidx, batch in enumerate(tqdm.tqdm(dataloader)):
            # Normalize by the actual size of this batch's accumulation group.
            group_size = tail_size if bidx >= tail_start else n_accum

            if self.args.fp16:
                with autocast():
                    loss = self.model_loss(batch, batch_losses_logging)
                    loss = loss / group_size
                self.scaler.scale(loss).backward()
            else:
                loss = self.model_loss(batch, batch_losses_logging)
                loss = loss / group_size
                loss.backward()

            end_of_group = ((bidx + 1) % n_accum == 0) or (bidx + 1 == n_batches)
            if not end_of_group:
                continue

            # One optimizer update per accumulation group.
            if self.args.fp16:
                self.scaler.step(self.optim)
                self.scaler.update()
            else:
                self.optim.step()
            self.model.zero_grad(set_to_none=True)

            # Mean of each raw (un-normalized) loss term over the group;
            # terms not computed in this task mode are reported as 0.
            step_losses = {k: (np.sum(v) / len(v) if v else 0.0)
                           for k, v in batch_losses_logging.items()}
            curr_epoch = self.info['curr_epoch']
            self.info['epoch_metrics'][curr_epoch]['train'].append(step_losses)

            if self.args.wandb_project is not None:
                wandb_update = {f"train_{k}": v for k, v in step_losses.items()}
                wandb.log({"epoch": curr_epoch, **wandb_update}, step=self.info['curr_iter'])

            batch_losses_logging = new_loss_log()
            self.info['curr_iter'] += 1

    def validate(self, dataloader, mode='val', save_ckpt=False):
        """Evaluate the model on ``dataloader`` and log an epoch summary.

        Args:
            dataloader: loader yielding collated batches (e.g. val or test).
            mode: prefix for the logged metric names. Only ``'val'`` results are
                stored in ``self.info['epoch_metrics'][curr_epoch]['val']``.
            save_ckpt: if True and ``args.ckpt_policy == 'all'``, save the
                (unwrapped) model weights as ``model_epoch{epoch}.pt`` under
                ``self.save_path``. In distributed mode only rank 0 writes.
        """
        losses = {'action_loss': [], 'graph_loss': [], 'template_cls_loss': []}
        acc = {'n_correct': 0, 'n_incorrect': 0}
        self.model.eval()

        # Gradients are never used here. Disabling autograd saves activation
        # memory and stops a DDP-wrapped model from preparing for a backward
        # pass that never happens.
        with torch.no_grad():
            for batch in tqdm.tqdm(dataloader):
                self.model_loss(batch, losses, acc=acc)

        epoch_val_action_loss = np.mean(losses['action_loss'])
        epoch_val_graph_loss = np.mean(losses['graph_loss'])

        if self.args.task_mode in ['raf_multitask']:
            epoch_val_template_cls_loss = np.mean(losses['template_cls_loss'])
            epoch_val_template_cls_acc = acc['n_correct'] / float(acc['n_correct'] + acc['n_incorrect'])
        else:
            epoch_val_template_cls_loss = 0
            epoch_val_template_cls_acc = 0

        curr_epoch = self.info['curr_epoch']
        if mode == 'val':
            self.info['epoch_metrics'][curr_epoch]['val'] = {'action_loss': epoch_val_action_loss,
                                                             'graph_loss': epoch_val_graph_loss,
                                                             'template_cls_loss': epoch_val_template_cls_loss,
                                                             'template_cls_acc': epoch_val_template_cls_acc}

        if self.args.wandb_project is not None:
            wandb_update = {f"{mode}_action_loss": epoch_val_action_loss,
                            f"{mode}_graph_loss": epoch_val_graph_loss,
                            f"{mode}_template_cls_loss": epoch_val_template_cls_loss,
                            f"{mode}_template_cls_acc": epoch_val_template_cls_acc
                            }
            wandb_update['epoch'] = curr_epoch
            wandb.log(wandb_update, step=self.info['curr_iter'])

        train_records = self.info['epoch_metrics'][curr_epoch]['train']
        epoch_train_action_loss = np.mean([item['action_loss'] for item in train_records])
        epoch_train_graph_loss = np.mean([item['graph_loss'] for item in train_records])

        log_str = f"### Epoch {self.info['curr_epoch']} finished. Train loss - action: {epoch_train_action_loss:.2f}, graph: {epoch_train_graph_loss:.2f}, {mode} eval: loss - action: {epoch_val_action_loss:.2f}, graph: {epoch_val_graph_loss:.2f}, template cls loss: {epoch_val_template_cls_loss:.2f}, template cls acc: {epoch_val_template_cls_acc:.2f}"
        logging.info(log_str)

        if save_ckpt and self.args.ckpt_policy == 'all':
            # Every DDP rank holds identical weights; write from rank 0 only so
            # the processes do not race on the same file.
            if not self.args.distributed or dist.get_rank() == 0:
                model_to_save = self.model.module if self.args.distributed else self.model
                torch.save(model_to_save.state_dict(), self.save_path / f"model_epoch{self.info['curr_epoch']}.pt")

    def calc_recall(self, res_rom):
        # per example recall of examples in the rom
        recalls = []
        for k,v in res_rom.items():
            recall_values = np.array([item['template_label']==item['template_pred'] for item in v if item['template_label']==1])
            if len(recall_values) == 0:
                # NOTE No positives. There shouldnt be too many of these
                pass
            else:
                recall_avg = recall_values.mean()            
                recalls.append(recall_avg)
        return recalls

    def inference(self, dataloader, mode='val', save_ckpt=False):
    
        """
        Inference
        """
        #num_beams = 25
        num_beams = 15

        # Savefiles
        name_ckpt = self.args.load_path.split('/')[-2]
        name_epoch = self.args.load_path.split('/')[-1]
        save_name = f"{name_ckpt}_{name_epoch}"

        if self.args.save_gen:
            path_save_gen = Path('/home/mnskim/workspace/tbg/tbg1/results/generations/') / save_name
            if not path_save_gen.exists():
                path_save_gen.mkdir(parents=True,exist_ok=True)
            data_name_stem = dataloader.dataset.__dict__['cache_path'].stem                
            save_gen_file = open(path_save_gen / f"{data_name_stem}.jsonl", 'w')

        batch_losses = []
        self.model.eval()
        metrics = {}
        rom_ex_seen = set()
        for bidx, batch in enumerate(tqdm.tqdm(dataloader)):
            bsize = batch['input_text'].size(0)

            text_encoder_attention_mask = batch['input_text_mask'].cuda()
            graph_encoder_attention_mask = batch['input_graph_mask'].cuda()
            aggregator_input_mask = torch.cat([text_encoder_attention_mask, graph_encoder_attention_mask], 1)

            for ridx in range(bsize):
                rom = batch['rom'][ridx]
                if not rom in self.test_roms: # NOTE 
                    continue
                ex_id = batch['ex_id'][ridx]                
                template = batch['template'][ridx]
                template_label = batch['binary_labels'][ridx].item()
                #template_pred = batch['template_pred'][ridx]
                n_holes = batch['n_holes'][ridx]
                #n_holes = get_nholes(template, ['<mask>'])


                if not rom in metrics:
                    metrics[rom] = {'action': {'f1': [], 'em': [], 'examples':{}, 'outputs':{}},
                                    'graph': {'f1': [], 'em': [], 'examples':{}, 'outputs':{}},
                                    'text_input': {},
                                    'graph_input': {},
                                   }

                if not ex_id in metrics[rom]['action']['examples']:
                    assert ex_id not in metrics[rom]['graph']['examples']
                    metrics[rom]['action']['examples'][ex_id] = {'n_tp': 0, 'n_fp': 0, 'n_fn': 0}
                    metrics[rom]['graph']['examples'][ex_id] = {'n_tp': 0, 'n_fp': 0, 'n_fn': 0}

                if not ex_id in metrics[rom]['action']['outputs']:
                    assert ex_id not in metrics[rom]['graph']['outputs']
                    metrics[rom]['action']['outputs'][ex_id] = {}

                #res[rom][ex_id]                                   
                text_input = batch['input_text'][ridx][:torch.sum(batch['input_text_mask'][ridx]==1)]
                #text_input_aslist = text_input.numpy().tolist()
                graph_input = batch['input_graph'][ridx][:torch.sum(batch['input_graph_mask'][ridx]==1)]
                #graph_input_aslist = graph_input.numpy().tolist()
                text_input = text_input.unsqueeze(0)
                graph_input = graph_input.unsqueeze(0)

                # action
                ff = batch['output_valid_act'][ridx][:torch.sum(batch['output_valid_act'][ridx]!=self.tokenizer.pad_token_id)]
                #assert ff[0] == self.tokenizer.bos_token_id
                assert ff[0] == self.tokenizer.bos_token_id
                if ff[-1] == self.tokenizer.eos_token_id:
                    target_text = self.tokenizer.decode(ff[1:-1])
                else:
                    target_text = self.tokenizer.decode(ff[1:])
                target_text = target_text.strip()
                #target_text = self.tokenizer.decode(ff)
                #target_text = target_text.rstrip("</s>")

                # graph
                gg = batch['output_graph_diff'][ridx][:torch.sum(batch['output_graph_diff'][ridx]!=self.tokenizer.pad_token_id)]
                #assert gg[0] == self.tokenizer.bos_token_id
                #graph_target_text = self.tokenizer.decode(gg)
                #graph_target_text = graph_target_text.rstrip("</s>")
                assert gg[0] == self.tokenizer.bos_token_id
                #graph_target_text = self.tokenizer.decode(gg[1:])
                #graph_target_text = graph_target_text.rstrip("</s>")                        
                if gg[-1] == self.tokenizer.eos_token_id:
                    graph_target_text = self.tokenizer.decode(gg[1:-1])
                else:
                    graph_target_text = self.tokenizer.decode(gg[1:])
                graph_target_text = graph_target_text.strip()
                
                    
                text_input = text_input.cuda()
                graph_input = graph_input.cuda()

                # Replicate the inputs num_beams
                text_encoder_outputs = self.model.text_encoder_decoder.get_encoder()(
                                            text_input.repeat_interleave(num_beams, dim=0), return_dict=True)
                graph_encoder_outputs = self.model.graph_encoder_decoder.get_encoder()(
                                            graph_input.repeat_interleave(num_beams, dim=0), return_dict=True)

                # Get just one
                text_input = text_input[0].unsqueeze(0)
                text_last_hidden = text_encoder_outputs['last_hidden_state'][0].unsqueeze(0)
                graph_input = graph_input[0].unsqueeze(0)
                graph_last_hidden = graph_encoder_outputs['last_hidden_state'][0].unsqueeze(0)

                # Aggregator - no mask required here                
                aggregator_inputs = torch.cat([text_encoder_outputs['last_hidden_state'],
                                               graph_encoder_outputs['last_hidden_state']], 1)
                #aggregator_outputs = self.aggregator(inputs_embeds=aggregator_inputs, attention_mask=aggregator_input_mask)
                aggregator_outputs = self.model.aggregator(inputs_embeds=aggregator_inputs,)
                #aggregator_outputs_sum = torch.sum(aggregator_input_mask.unsqueeze(2) * aggregator_outputs['last_hidden_state'], 1)
                #state_vec = aggregator_outputs_sum / (aggregator_input_mask.sum(1)).unsqueeze(1)
                state_vec = aggregator_outputs['last_hidden_state'].mean(1)
                                
                if not self.args.use_oracle_template:
                    sentence_representation = self.model.text_encoder_decoder.get_encoder_sentence_rep(text_input, text_last_hidden)
                    if self.args.use_state_vec:
                        sentence_representation = torch.cat([sentence_representation, state_vec[0].unsqueeze(0)], 1)
                    if self.args.use_graph_vec:
                        graph_sentence_representation = self.model.graph_encoder_decoder.get_encoder_sentence_rep(graph_input, graph_last_hidden)
                        sentence_representation = torch.cat([sentence_representation, graph_sentence_representation], 1)
                    template_cls_logits = self.model.text_encoder_decoder.classification_head(sentence_representation)
                    template_pred = torch.argmax(torch.softmax(template_cls_logits, 1),1).item()
                    class_logits = torch.nn.functional.softmax(template_cls_logits,1).tolist()[0]
                else:
                    template_pred = copy.deepcopy(template_label)
                    class_logits = [-1, -1]

                metrics[rom]['action']['outputs'][ex_id][template] = {'t_label': template_label,
                                                                      't_pred': template_pred,
                                                                      'n_holes': n_holes,
                                                                      'class_logits': class_logits,
                                                                      }

                if template_pred == 0 and template_label == 0: # TN 
                    """
                    res[rom][ex_id]['templates'].append({'template': template,
                                                'template_label': 0,
                                                'template_pred': 0,
                                                'action_pred': None,
                                                'action_target': target_text,
                                                'graph_pred': None,
                                                'graph_target': graph_target_text,
                                                'text_input': text_input_aslist,
                                                'graph_input': graph_input_aslist,
                                                'rom': rom
                                                })                                               
                    """
                    #continue
                    pass

                if self.args.task_mode in ['raf_multitask', 'full_multitask', 'action_only', 'action_multi']:
                    model_kwargs = {
                        "state_vector": state_vec,
                        "encoder_outputs": text_encoder_outputs
                    }
                        #"encoder_outputs": self.model.module.text_encoder_decoder.get_encoder()(text_input.repeat_interleave(num_beams, dim=0), return_dict=True)}

                    if template_pred == 1:
                        #"""
                        #pred_text = self.tokenizer.decode(self.model.module.text_encoder_decoder.generate(text_input, num_beams=15, max_length=1024)[0])
                        #dd = self.model.module.text_encoder_decoder.generate(text_input, num_beams=15, max_length=1024)[0]
                        pred_ids = self.model.text_encoder_decoder.generate(text_input, 
                                                                            num_beams=num_beams,
                                                                            decoder_start_token_id=0,
                                                                            max_length=1024,
                                                                            **model_kwargs)[0]
                       
                        try:
                            assert pred_ids[:2].tolist() == [0,0]
                        except AssertionError as err:
                            ipdb.set_trace()
                        if pred_ids[-1] == 2:
                            pred_ids = pred_ids[2:-1]
                        else:
                            pred_ids = pred_ids[2:]

                        pred_text = self.tokenizer.decode(pred_ids)
                        pred_text = pred_text.strip()
                        # action generation

                        # assuming start and ends with </s>"
                        #assert pred_text.startswith("</s>")
                        #assert pred_text.endswith("</s>")
                        #pred_text = pred_text.lstrip(self.tokenizer.bos_token)
                        #pred_text = pred_text.rstrip(self.tokenizer.eos_token).strip()

                        pred_actions = parse_actions_text(pred_text, skip_empty=True)
                        #"""

                        metrics[rom]['action']['outputs'][ex_id][template]['pred_actions'] = pred_actions
                    else:
                        metrics[rom]['action']['outputs'][ex_id][template]['pred_actions'] = None

                    #target_text = target_text.lstrip(self.tokenizer.bos_token)
                    #target_text = target_text.rstrip(self.tokenizer.eos_token).strip()
                    target_actions = parse_actions_text(target_text, skip_empty=True)
                    
                    metrics[rom]['action']['outputs'][ex_id][template]['target_actions'] = target_actions

                    if template_label == 1:
                        if template_pred == 1: # TP
                            if n_holes == 0:
                                metrics[rom]['action']['examples'][ex_id]['n_tp'] += 1                                
                            else:
                                action_em, action_f1, counts = action_metrics(pred_actions, 
                                                                              target_actions,
                                                                              return_counts=True)
                                action_tp, action_fp, action_fn = counts
                                metrics[rom]['action']['examples'][ex_id]['n_tp'] += action_tp
                                metrics[rom]['action']['examples'][ex_id]['n_fp'] += action_fp
                                metrics[rom]['action']['examples'][ex_id]['n_fn'] += action_fn
                            #print('tp')
                        elif template_pred == 0: # FN
                            if n_holes == 0:
                                metrics[rom]['action']['examples'][ex_id]['n_fn'] += 1
                            else:                        
                                metrics[rom]['action']['examples'][ex_id]['n_fn'] += len(target_actions) 
                    elif template_label == 0:
                        if template_pred == 1: # FP
                            if n_holes == 0:
                                metrics[rom]['action']['examples'][ex_id]['n_fp'] += 1
                            else:                                        
                                metrics[rom]['action']['examples'][ex_id]['n_fp'] += len(pred_actions)                            
                            #print('fp')

                    #action_em, action_f1 = action_metrics(pred_actions, target_actions)
                    
                    #metrics[rom]['action']['em'].append(action_em)
                    #metrics[rom]['action']['f1'].append(action_f1)

                    #print(target_actions)
                    #print(pred_actions)

        
                if self.args.task_mode in ['raf_multitask', 'full_multitask', 'graph_only', 'graph_multi']:
                    if not f"{rom}_{ex_id}" in rom_ex_seen:

                        # TODO improve
                        metrics[rom]['text_input'][ex_id] = self.tokenizer.decode(text_input[0])
                        metrics[rom]['graph_input'][ex_id] = self.tokenizer.decode(graph_input[0])


                        model_kwargs = {
                            "state_vector": state_vec,
                            "encoder_outputs": graph_encoder_outputs
                        }
                        graph_pred_ids = self.model.graph_encoder_decoder.generate(graph_input,
                                                                                   num_beams=num_beams, 
                                                                                   max_length=1024, 
                                                                                   decoder_start_token_id=0, 
                                                                                   **model_kwargs)[0]
                        #graph_pred_text = self.tokenizer.decode(graph_pred_ids) 
                        try:
                            assert graph_pred_ids[:2].tolist() == [0,0]                            
                        except AssertionError as err:
                            ipdb.set_trace()
                        if graph_pred_ids[-1] == 2:
                            graph_pred_ids = graph_pred_ids[2:-1]
                        else:
                            graph_pred_ids = graph_pred_ids[2:]                            

                        graph_pred_text = self.tokenizer.decode(graph_pred_ids)
                        graph_pred_text = graph_pred_text.strip()

                        # assuming start and ends with </s>"
                        #assert graph_pred_text.startswith("</s>")
                        #assert graph_pred_text.endswith("</s>")
                        #graph_pred_text = graph_pred_text.strip("</s>")
                        #graph_pred_text = graph_pred_text.lstrip(self.tokenizer.bos_token)
                        #graph_pred_text = graph_pred_text.rstrip(self.tokenizer.eos_token).strip()
                        #graph_target_text = graph_target_text.lstrip(self.tokenizer.bos_token)
                        #graph_target_text = graph_target_text.rstrip(self.tokenizer.eos_token).strip()                  

                        pred_graph = parse_actions_text(graph_pred_text, action_token='[TRIPLE]', skip_empty=True)
                        target_graph = parse_actions_text(graph_target_text, action_token='[TRIPLE]', skip_empty=True)

                        metrics[rom]['action']['outputs'][ex_id][template]['pred_graph'] = pred_graph
                        metrics[rom]['action']['outputs'][ex_id][template]['target_graph'] = target_graph

                        graph_em, graph_f1, counts = action_metrics(pred_graph, target_graph, return_counts=True)
                        graph_tp, graph_fp, graph_fn = counts
                        metrics[rom]['graph']['examples'][ex_id]['n_tp'] += graph_tp
                        metrics[rom]['graph']['examples'][ex_id]['n_fp'] += graph_fp
                        metrics[rom]['graph']['examples'][ex_id]['n_fn'] += graph_fn
                        
                        rom_ex_seen.add(f"{rom}_{ex_id}")

                  
                        """
                        res[rom][ex_id]['templates'].append({'template': template,
                                                            'template_label': template_label,
                                                            'template_pred': template_pred,
                                                            'action_pred': pred_ids,
                                                            'action_target': target_text,
                                                            'graph_pred': graph_pred_ids,
                                                            'graph_target': graph_target_text,
                                                                })
                        """
                #res[rom][ex_id]['text_input'] = self.tokenizer.decode(text_input[0])                

                if self.args.save_gen:
                    single_ex = {'bidx': bidx,
                                 'ridx': ridx,
                                 'text_input': self.tokenizer.decode(text_input[0]),
                                 'graph_input': self.tokenizer.decode(graph_input[0]),
                                 'pred_actions': pred_actions,
                                 'target_actions': target_actions,
                                 'pred_graph': pred_graph,
                                 'target_graph': target_graph
                                }
                    save_gen_file.write(json.dumps(single_ex) + '\n')


            #rom_lengths = [(key,len(metrics[key]['action']['em'])) for key in metrics.keys()]            
            #n_total = sum(np.array([item[1] for item in rom_lengths]))
            #rom_weights = {key: len(metrics[key]['action']['em'])/n_total for key in metrics.keys()}

            
            for rom in self.test_roms:
                if rom in set(metrics.keys()):
                    for task in ['action', 'graph']:
                        f1_list = []
                        em_list = []
                        overall_tp = 0
                        overall_fp = 0
                        overall_fn = 0
                        for _ex_id in metrics[rom][task]['examples'].keys():
                            f1_divisor = float(metrics[rom][task]['examples'][_ex_id]['n_tp']
                                            + 0.5*(metrics[rom][task]['examples'][_ex_id]['n_fp']
                                            + metrics[rom][task]['examples'][_ex_id]['n_fn']))
                            if f1_divisor == 0: # TODO improve 
                                # this means fp,fn,tp are all 0
                                #f1 = 0 # this was a bug
                                f1 = 1.0
                            else:                                                                                
                                f1 = metrics[rom][task]['examples'][_ex_id]['n_tp'] / f1_divisor
                            f1_list.append(f1)      

                            """
                            # EM calculation - essentially recall
                            maximum = metrics[rom][task]['examples'][_ex_id]['n_fn'] + metrics[rom][task]['examples'][_ex_id]['n_tp'] 
                            if maximum == 0:
                                if metrics[rom][task]['examples'][_ex_id]['n_fp'] == 0:
                                    em = 1
                                else:
                                    em = 0
                            else:
                                em = metrics[rom][task]['examples'][_ex_id]['n_tp'] /float(maximum)
                            em_list.append(em)
                            """

                            # EM Calculation fix
                            #(metrics[rom][task]['examples'][_ex_id]['n_tp'])/float(metrics[rom][task]['examples'][_ex_id]['n_tp'] +
                            #                                                       metrics[rom][task]['examples'][_ex_id]['n_fp'] + 
                            #                                                       metrics[rom][task]['examples'][_ex_id]['n_fn'])
                            
                            # EM calculation - fix
                            maximum = metrics[rom][task]['examples'][_ex_id]['n_fn'] + metrics[rom][task]['examples'][_ex_id]['n_tp'] + metrics[rom][task]['examples'][_ex_id]['n_fp']
                            if maximum == 0: # no true positives
                                if (metrics[rom][task]['examples'][_ex_id]['n_fp'] == 0 and metrics[rom][task]['examples'][_ex_id]['n_fn'] == 0):
                                    em = 1
                                else:
                                    em = 0
                            else:
                                em = metrics[rom][task]['examples'][_ex_id]['n_tp'] /float(maximum)
                            em_list.append(em)
                            

                            overall_tp += metrics[rom][task]['examples'][_ex_id]['n_tp']
                            overall_fp += metrics[rom][task]['examples'][_ex_id]['n_fp']
                            overall_fn += metrics[rom][task]['examples'][_ex_id]['n_fn']
                            #

                        metrics[rom][task]['f1'] = copy.deepcopy(f1_list)
                        metrics[rom][task]['em'] = copy.deepcopy(em_list)

                        #print(f"{rom} {task}, F1: {np.mean(f1_list):4f}, EM: {np.mean(em_list):4f}, TP: {overall_tp}, FP: {overall_fp}, FN: {overall_fn}")

            if bidx % 50 == 0:

                for key in metrics.keys():
                    print(f"{key} Action Macro EM: {np.mean(metrics[key]['action']['em'])} over {len(metrics[key]['action']['em'])} examples")
                    print(f"{key} Action Macro F1: {np.mean(metrics[key]['action']['f1'])} over {len(metrics[key]['action']['f1'])} examples")
                    print(f"{key} Graph Macro EM: {np.mean(metrics[key]['graph']['em'])} over {len(metrics[key]['graph']['em'])} examples")
                    print(f"{key} Graph Macro F1: {np.mean(metrics[key]['graph']['f1'])} over {len(metrics[key]['graph']['f1'])} examples")

                rom_lengths = [(key,len(metrics[key]['action']['em'])) for key in metrics.keys()]            
                n_total = sum(np.array([item[1] for item in rom_lengths]))
                rom_weights = {key: len(metrics[key]['action']['em'])/n_total for key in metrics.keys()}

                for task in ['action', 'graph']:
                    for met in ['em', 'f1']:
                        weighted_val = 0
                        for key in metrics.keys():
                            weighted_val += np.mean(metrics[key][task][met]) * rom_weights[key]
                        print(task, met, weighted_val)

        # Final print
        for key in metrics.keys():
            print(f"{key} Action Macro EM: {np.mean(metrics[key]['action']['em'])} over {len(metrics[key]['action']['em'])} examples")
            print(f"{key} Action Macro F1: {np.mean(metrics[key]['action']['f1'])} over {len(metrics[key]['action']['f1'])} examples")
            print(f"{key} Graph Macro EM: {np.mean(metrics[key]['graph']['em'])} over {len(metrics[key]['graph']['em'])} examples")
            print(f"{key} Graph Macro F1: {np.mean(metrics[key]['graph']['f1'])} over {len(metrics[key]['graph']['f1'])} examples")

        rom_lengths = [(key,len(metrics[key]['action']['em'])) for key in metrics.keys()]            
        n_total = sum(np.array([item[1] for item in rom_lengths]))
        rom_weights = {key: len(metrics[key]['action']['em'])/n_total for key in metrics.keys()}

        for task in ['action', 'graph']:
            for met in ['em', 'f1']:
                weighted_val = 0
                for key in metrics.keys():
                    weighted_val += np.mean(metrics[key][task][met]) * rom_weights[key]
                print(task, met, weighted_val)


        for task in ['graph','action']:
            for met in ['em', 'f1']:
                print(task,met,[round(np.mean(metrics[key][task][met]),12) for key in self.test_roms]) 
        
        ipdb.set_trace()
        #save_name = f"{Path(self.args.load_path).parent.stem}_{Path(self.args.load_path).stem}"
        print_test_metrics(metrics, save_name)




    def register_training_logger(self):
        """
        Reset the training bookkeeping state and prepare the checkpoint directory.

        Sets ``self.info`` to fresh counters and empty histories
        (``curr_epoch``, ``curr_iter``, ``epoch_metrics``, ``train_metrics``,
        ``val_metrics``, ``best_val``), stores ``self.args.save_path`` as a
        ``Path`` in ``self.save_path`` and creates that directory (including
        parents) if it does not exist yet. No wandb setup is performed here.
        """
        self.info = dict(
            curr_epoch=0,
            curr_iter=0,
            epoch_metrics={},
            train_metrics=[],
            val_metrics=[],
            best_val=None,
        )

        ckpt_dir = Path(self.args.save_path)
        self.save_path = ckpt_dir
        ckpt_dir.mkdir(exist_ok=True, parents=True)

    def log_train(self, loss):
        """
        Record the loss of one training batch (== one iteration).

        Call on every batch. The loss is stored both in the global history
        ``self.info['train_metrics']`` as ``(epoch, iter, loss)`` tuples and in
        ``self.info['epoch_metrics'][epoch]['train']``. Every
        ``self.args.log_freq`` iterations the current loss and the running
        average over all recorded losses are logged.

        Args:
            loss: scalar loss (Python number or 0-dim tensor).
        """
        # Store a plain float so the history never keeps autograd graphs alive.
        loss = float(loss)
        self.info['curr_iter'] += 1
        epoch = self.info['curr_epoch']

        self.info['train_metrics'].append((epoch, self.info['curr_iter'], loss))
        # Same layout main_worker pre-creates for each epoch: {'train': [], 'val': []}.
        epoch_metrics = self.info['epoch_metrics'].setdefault(epoch, {'train': [], 'val': []})
        epoch_metrics.setdefault('train', []).append(loss)

        if self.info['curr_iter'] % self.args.log_freq == 0:
            total_avg_loss = np.mean([item[2] for item in self.info['train_metrics']])
            logging.info(f"Iter {self.info['curr_iter']} loss: {loss:.2f}. total avg loss: {total_avg_loss:.2f}")

    def log_val(self, entry):
        """
        Record one validation result and checkpoint the model when it improves.

        Call at the end of each validation pass.

        Args:
            entry: validation score for the current epoch. Lower is better
                (it is compared with ``>`` against the best score so far).
                Stored as ``(curr_epoch, curr_iter, entry)`` in
                ``self.info['val_metrics']`` and in
                ``self.info['epoch_metrics'][epoch]['val']``.

        Side effects:
            - Saves ``model_best.pt`` under ``self.save_path`` whenever the
              score is the best seen so far (including the very first call).
            - With ``self.args.ckpt_policy == 'all'`` additionally saves
              ``model_epoch{epoch}.pt`` on every call.
        """
        epoch = self.info['curr_epoch']
        record = (epoch, self.info['curr_iter'], entry)
        self.info['val_metrics'].append(record)
        epoch_metrics = self.info['epoch_metrics'].setdefault(epoch, {'train': [], 'val': []})
        epoch_metrics.setdefault('val', []).append(entry)

        # Average of the train losses recorded (by log_train) during this epoch.
        epoch_losses = [item[2] for item in self.info['train_metrics'] if item[0] == epoch]
        epoch_train_loss = float(np.mean(epoch_losses)) if epoch_losses else float('nan')
        output_str = f"Epoch {epoch} val loss: {entry:.2f}, train loss: {epoch_train_loss:.2f}."

        best = self.info['best_val']
        if best is None or best[2] > entry:
            # The first result is trivially the best so far and must be saved too.
            self.info['best_val'] = record
            output_str += ' (new best ckpt)'
            torch.save(self.model.state_dict(), self.save_path / "model_best.pt")
        if self.args.ckpt_policy == 'all':
            torch.save(self.model.state_dict(), self.save_path / f"model_epoch{epoch}.pt")
        logging.info(output_str)

    def generate(self, input_text, decoder):
        """
        Beam-search decode a single input string with one of the two decoders.

        Args:
            input_text: raw text, encoded with ``self.tokenizer`` (the tokenizer
                this class also uses to decode both action and graph outputs).
            decoder: ``'graph'`` to decode with ``self.model.graph_encoder_decoder``
                or ``'action'`` to decode with ``self.model.text_encoder_decoder``.

        Returns:
            The token-id tensor returned by the selected sub-model's
            ``generate`` (15 beams, ``max_length=50``, early stopping).

        Raises:
            ValueError: if ``decoder`` is neither ``'graph'`` nor ``'action'``.
        """
        # NOTE(review): the inference path additionally passes state_vector /
        # encoder_outputs and decoder_start_token_id=0; confirm the sub-models
        # can decode without them.
        if decoder == 'graph':
            sub_model = self.model.graph_encoder_decoder
        elif decoder == 'action':
            sub_model = self.model.text_encoder_decoder
        else:
            raise ValueError(f"decoder must be 'graph' or 'action', got {decoder!r}")

        input_ids = self.tokenizer.encode(input_text, return_tensors='pt')
        # Keep the inputs on the same device as the weights.
        input_ids = input_ids.to(next(sub_model.parameters()).device)
        return sub_model.generate(
            input_ids,
            max_length=50,
            num_beams=15,
            early_stopping=True,
        )

    def load_checkpoint(self, ckpt):
        """
        Load model weights from a state dict saved with ``torch.save``.

        Args:
            ckpt: path to the checkpoint file.

        Behavior:
            - Tensors are loaded onto the CPU first (``map_location='cpu'``), so a
              checkpoint written from a GPU process can be opened anywhere.
              ``load_state_dict`` then copies them into the model's existing
              parameters on whatever device those live.
            - If exactly one of checkpoint / model has every key prefixed with
              ``module.`` (the prefix DistributedDataParallel adds), the
              checkpoint keys are adjusted to match. Otherwise a non-strict load
              would silently skip every tensor.
            - Loading is non-strict; missing and unexpected keys are logged as
              warnings instead of being ignored silently.
        """
        state = torch.load(ckpt, map_location='cpu')

        prefix = 'module.'
        model_keys = list(self.model.state_dict().keys())
        ckpt_prefixed = bool(state) and all(k.startswith(prefix) for k in state)
        model_prefixed = bool(model_keys) and all(k.startswith(prefix) for k in model_keys)
        if ckpt_prefixed and not model_prefixed:
            state = {k[len(prefix):]: v for k, v in state.items()}
        elif model_prefixed and not ckpt_prefixed:
            state = {prefix + k: v for k, v in state.items()}

        result = self.model.load_state_dict(state, strict=False)
        missing = getattr(result, 'missing_keys', [])
        unexpected = getattr(result, 'unexpected_keys', [])
        if missing:
            logging.warning(f"{len(missing)} model keys not found in {ckpt}: {missing[:10]}")
        if unexpected:
            logging.warning(f"{len(unexpected)} unexpected keys in {ckpt}: {unexpected[:10]}")
        print(f"Loaded weights from {ckpt}")



def main_worker(args):
    """
    Run training or prediction depending on ``args.mode``.

    ``'train'``: build a JerichoWorldTrainer, attach a GradScaler when
    ``args.fp16`` is set, then for each of ``args.n_epochs`` epochs reset the
    epoch bookkeeping, run one training pass, and evaluate without gradients
    on the val split (checkpointing enabled) and the test split (no
    checkpointing). Under ``args.distributed`` only local rank 0 evaluates and
    every rank then waits at ``dist.barrier()``.

    ``'predict'``: build a JerichoWorldTrainer and run inference without
    gradients on the test split, or on the train split when
    ``args.mine_negatives`` is set.

    Any other mode does nothing.
    """
    if args.mode == 'train':
        trainer = JerichoWorldTrainer(args)
        model = trainer.model  # noqa: F841 - handle kept for interactive debugging

        if args.fp16:
            trainer.scaler = GradScaler()

        def _evaluate():
            # Validation drives checkpointing; the test pass is report-only.
            with torch.no_grad():
                trainer.validate(trainer.val_dataloader, mode='val', save_ckpt=True)
                trainer.validate(trainer.test_dataloader, mode='test', save_ckpt=False)

        for epoch_idx in range(args.n_epochs):
            trainer.info['curr_epoch'] = epoch_idx
            trainer.info['epoch_metrics'][epoch_idx] = {'train': [], 'val': []}

            trainer.train(trainer.train_dataloader)

            if not args.distributed:
                _evaluate()
                continue

            if args.local_rank == 0:
                _evaluate()
            dist.barrier()

    if args.mode == 'predict':
        predictor = JerichoWorldTrainer(args)
        # Negative mining scores the training split instead of the test split.
        if args.mine_negatives:
            source_loader = predictor.train_dataloader
        else:
            source_loader = predictor.test_dataloader
        with torch.no_grad():
            predictor.inference(source_loader, mode='test', save_ckpt=False)
    




if __name__=="__main__":
        #save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/models/pretraining"

    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file', type=str, help='train json file')
    parser.add_argument('--test_file', type=str, help='train json file')

    # Pretrained ckpt
    parser.add_argument('--pretrained_ckpt', type=str, default=None, help='path/name')
    parser.add_argument('--init_mode', type=str, default=None, choices=['both','text','graph'], help='how to use pretrained ckpt')


    # Tokenizers
    parser.add_argument('--tokenizer_path', type=str, help='path/name to huggingface tokenizer')
    parser.add_argument('--added_tokens_path', type=str, help='path to additional tokenizer data created from data')
    parser.add_argument('--add_atomic_tokens', action='store_true', default=False, help='')

    # Tokenizer for graph encoder
    parser.add_argument('--graph_encoder_tokenizer_path', type=str, help='path/name to huggingface tokenizer')
    parser.add_argument('--graph_encoder_added_tokens_path', type=str, default=None, help='path to additional tokenizer data created from data')

    # Tokenizer for action decoder
    parser.add_argument('--action_decoder_tokenizer_path', type=str, help='path/name to huggingface tokenizer')
    parser.add_argument('--action_decoder_added_tokens_path', type=str, default=None, help='path to additional tokenizer data created from data')

    # Tokenizer for graph decoder
    parser.add_argument('--graph_decoder_tokenizer_path', type=str, help='path/name to huggingface tokenizer')
    parser.add_argument('--graph_decoder_added_tokens_path', type=str, default=None, help='path to additional tokenizer data created from data')

    parser.add_argument('--replacement_tokenizer_path', type=str, default=None, help='path to tokenizer')
    parser.add_argument('--n_vocab', type=int, default=None, help='')

    parser.add_argument('--save_path', type=str, help='path to save results (train mode)')
    parser.add_argument('--load_path', type=str, help='path to load weights (predict mode)')
    parser.add_argument('--cache_path', type=str, help='path to cache pretokenized data results')
    parser.add_argument('--refresh_cache', action='store_true', default=False, help='refresh data cache')

    parser.add_argument('--save_gen', action='store_true', default=False, help='Save generation results (test mode)')

    # Dataloading
    parser.add_argument('--train_batchsize', type=int, default=16, help='')
    parser.add_argument('--val_batchsize', type=int, default=16, help='')
    parser.add_argument('--test_batchsize', type=int, default=16, help='')
    parser.add_argument('--num_workers', type=int, default=8, help='')

    # Data
    parser.add_argument('--skip_nohole', action='store_true', default=False, help='')
    parser.add_argument('--train_negative_mode', type=str, default=None, help='template negative mode')
    parser.add_argument('--val_negative_mode', type=str, default=None, help='template negative mode')
    parser.add_argument('--test_negative_mode', type=str, default=None, help='template negative mode')

    parser.add_argument('--test_roms_shard', type=int, default=-1, help='for inference on shards')

    # Training/Infer Modes
    parser.add_argument('--mode', type=str, default=None, choices=['train', 'predict'], help='Select a mode')
    parser.add_argument('--mine_negatives', action='store_true', default=False, help='')
    # Task Modes
    parser.add_argument('--task_mode', type=str, default=None, choices=['raf_multitask', 'full_multitask', 'action_only', 'action_multi', 'graph_multi'], help='Select a mode')

    # Training
    parser.add_argument('--n_epochs', type=int, default=5, help='')
    parser.add_argument('--lr', type=float, default=2e-5)
    parser.add_argument('--n_accum', type=int, default=1, help='gradient accumulation')
    parser.add_argument('--log_freq', type=int, default=20, help='gradient accumulation')
    parser.add_argument('--ckpt_policy', type=str, default=None, choices=['best_only', 'all'], help='Select a mode')
    parser.add_argument('--clip', type=float, default=1.0, help='gradient clipping')
    parser.add_argument('--tie_embeddings', action='store_true', default=False, help='tie embedding weights')
    parser.add_argument('--lambda1', type=float, default=1.0, help='weight factor for action loss')
    parser.add_argument('--lambda2', type=float, default=1.0, help='weight factor for graph loss')

    # Commonsense data
    parser.add_argument('--add_ckg', action='store_true', default=False, help='add ckg triples')

    # scheduler

    # Other
    parser.add_argument('--seed', type=int, default=42, help='')

    # FP16
    parser.add_argument("--fp16", action='store_true', default=False)

    # DDP
    parser.add_argument('--distributed', action='store_true', default=False)
    parser.add_argument("--local_rank", type=int)
    #parser.add_argument("--local_world_size", type=int, default=1)
    #parser.add_argument('--world-size', default=-1, type=int,
    #                    help='number of nodes for distributed training')
    #parser.add_argument('--rank', default=-1, type=int,
    #                    help='node rank for distributed training')
    #parser.add_argument('--dist-url', default='env://', type=str,
    #                    help='url used to set up distributed training')
    #parser.add_argument('--dist-backend', default='nccl', type=str,
    #                    help='distributed backend')
    #parser.add_argument('--seed', default=None, type=int,
    #                    help='seed for initializing training. ')
    #parser.add_argument('--gpu', default=None, type=int,
    #                    help='GPU id to use.')
    #parser.add_argument('--multiprocessing-distributed', action='store_true',
    #                    help='Use multi-processing distributed training to launch '
    #                        'N processes per node, which has N GPUs. This is the '
    #                        'fastest way to use PyTorch for either single node or '
    #                        'multi node data parallel training')

    # CLS Module
    #parser.add_argument('--bart_encoder_cls', action='store_true', default=False, help='use encoder-only bart for classification')
    parser.add_argument('--use_state_vec', action='store_true', default=False, help='use state vec in classification')
    parser.add_argument('--use_graph_vec', action='store_true', default=False, help='use state vec in classification')

    parser.add_argument('--use_oracle_template', action='store_true', default=False, help='use state vec in classification')

    # Wandb 
    parser.add_argument('--wandb_project', type=str, default=None, help='wandb project name')
    parser.add_argument('--wandb_name', type=str, default=None, help='wandb run name')

    args = parser.parse_args()

    #ipdb.set_trace()
    
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    main_worker(args)
