#from tkinter import E
from worldformer2.dataset.jerichoworld_dataset_conditional import JerichoWorldDataset
#from worldformer2.dataset.jerichoworld_dataset import JerichoWorldDataset
from worldformer2.dataset.training_dataset_bart import JerichoWorldTorchTrainDataset
from worldformer2.tools.text_utils import parse_actions_text
from worldformer2.tools.metrics import action_metrics
from worldformer2.tools.metrics import print_test_metrics
from pprint import pprint

from transformers.generation_utils import BeamSearchScorer
from transformers.generation_logits_process import LogitsProcessorList, MinLengthLogitsProcessor

from transformers import AutoTokenizer
from pathlib import Path
import numpy as np
import copy
import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
import torch
from torch import optim
import argparse
import pickle

#from transformers import T5Tokenizer, T5Model
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForMaskedLM, AutoModelForCausalLM
from torch.utils.data import DataLoader

from worldformer2.tokenization.custom_tokenizers import get_tokenizer
from worldformer2.models.bart import BartModel

import logging
from worldformer2.tools.logging_util import basic_logging
basic_logging()

import json
import ipdb
import wandb

class JerichoWorldTrainer(object):
    """
    """

    def __init__(self, args):
        # Prepare tokenizer and dataset
        self.args = args

        # Train mode
        if self.args.mode == 'train':
            self.register_model()
            self.model = torch.nn.DataParallel(self.model).cuda()

            if not self.args.pretrained_ckpt is None:
                state = {k[23:] :v for k,v in torch.load(self.args.pretrained_ckpt).items()}
                if self.args.init_mode in ['both', 'graph']:
                    self.model.module.graph_encoder_decoder.load_state_dict(copy.deepcopy(state))
                    logging.info(f"Loaded weights from {self.args.pretrained_ckpt} into graph bart")
                if self.args.init_mode in ['both', 'text']:
                    self.model.module.text_encoder_decoder.load_state_dict(copy.deepcopy(state))
                    logging.info(f"Loaded weights from {self.args.pretrained_ckpt} into action bart")

            #self.load_checkpoint('/home/mnskim/workspace/tbg/tbg1/ckpts/models/worldformer/exp_newdata_1_notie/model_best.pt')

            self.register_optim()
            self.register_loss()
            self.register_data()
            self.register_training_logger()
            if not self.args.wandb_project is None:
                self.register_wandb()

        elif self.args.mode == 'predict':
            self.register_model()
            self.model = torch.nn.DataParallel(self.model).cuda()
            self.load_checkpoint(self.args.load_path)
            #self.model = self.model.cuda()

            """
            to_load = {}
            # NOTE tmp hack
            tmp = torch.load(self.args.load_path)
            for k, v in tmp.items():
                to_load[k.replace('module.','')] = v
            #ipdb.set_trace()
            self.text_encoder.load_state_dict(to_load)
            #self.text_encoder = torch.nn.DataParallel(self.text_encoder).cuda()
            self.text_encoder = self.text_encoder.cuda()
            """

            self.register_data()
            #self.register_training_logger()


    def register_wandb(self):
        """
        Start a wandb run using ``args.wandb_project`` as the project,
        ``args.wandb_name`` as the run name and the full args as the config.
        """
        run_settings = {
            'project': self.args.wandb_project,
            'name': self.args.wandb_name,
            'config': self.args,
        }
        wandb.init(**run_settings)

    def register_model(self):
        """
        Create the four tokenizers and the BART model.

        Sets ``self.tokenizer``, ``self.graph_encoder_tokenizer``,
        ``self.action_decoder_tokenizer`` and ``self.graph_decoder_tokenizer``
        from the matching ``*tokenizer_path`` / ``*added_tokens_path`` args,
        then builds ``self.model`` from a deep copy of the args extended with
        ``input_text_n_vocab`` (the size of the text tokenizer).
        """
        logging.info(f"Loading tokenizer from {self.args.tokenizer_path}")

        # '' is the text tokenizer; the other prefixes select the graph
        # encoder, action decoder and graph decoder tokenizers.
        for prefix in ('', 'graph_encoder_', 'action_decoder_', 'graph_decoder_'):
            tokenizer = get_tokenizer(
                getattr(self.args, f"{prefix}tokenizer_path"),
                getattr(self.args, f"{prefix}added_tokens_path"),
                self.args.add_atomic_tokens,
            )
            setattr(self, f"{prefix}tokenizer", tokenizer)

        # The model config is the args namespace plus the text vocabulary size.
        config = copy.deepcopy(self.args)
        config.input_text_n_vocab = len(self.tokenizer)

        self.model = BartModel(config)


    def register_optim(self):

        self.optim = optim.Adam(self.model.parameters(), lr=self.args.lr)

    def register_loss(self):
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

    def register_data(self):
        """
        Load the raw datasets and build the train / val / test dataloaders.

        The validation set is split off the training file
        (``split_validation(0.1, seed=args.seed)``); the test set comes from
        ``args.test_file``.  For each split a ``JerichoWorldTorchTrainDataset``
        is created, pretokenized (cached under ``args.cache_path``) and wrapped
        in a ``DataLoader``.  Only the training loader shuffles.

        Sets ``self.{train,val,test}_dataset``,
        ``self.{train,val,test}_dataset_torch`` and
        ``self.{train,val,test}_dataloader``.
        """
        # Read every setting from the trainer's own args so this method does
        # not depend on a module-level ``args`` global being defined.
        args = self.args

        # Paths
        cache_dir = Path(args.cache_path)
        train_cache_path = cache_dir / "train_cached.pth"
        val_cache_path = cache_dir / "val_cached.pth"
        test_cache_path = cache_dir / "test_cached.pth"

        # Raw datasets
        self.train_dataset = JerichoWorldDataset(args.train_file)
        self.train_dataset.flatten()
        val_instances = self.train_dataset.split_validation(0.1, seed=args.seed)

        # The validation dataset is loaded from the train file and then has
        # its instances replaced by the held-out split.
        self.val_dataset = JerichoWorldDataset(args.train_file)
        self.val_dataset.instances = val_instances

        self.test_dataset = JerichoWorldDataset(args.test_file)
        self.test_dataset.flatten()

        cache_dir.mkdir(exist_ok=True, parents=True)

        # Train
        self.train_dataset_torch, self.train_dataloader = self._build_split(
            self.train_dataset, train_cache_path, shuffle=True)

        # Val
        self.val_dataset_torch, self.val_dataloader = self._build_split(
            self.val_dataset, val_cache_path, shuffle=False)

        # Test
        self.test_dataset_torch, self.test_dataloader = self._build_split(
            self.test_dataset, test_cache_path, shuffle=False)

    def _build_split(self, dataset, cache_path, shuffle):
        """
        Tokenize one split and wrap it in a DataLoader.

        args:
            dataset: raw ``JerichoWorldDataset`` whose examples are registered
            cache_path: path handed to the torch dataset for its token cache
            shuffle: whether the DataLoader shuffles
        returns:
            (torch_dataset, dataloader)
        """
        args = self.args
        torch_dataset = JerichoWorldTorchTrainDataset(self.tokenizer,
                                                      self.graph_encoder_tokenizer,
                                                      self.action_decoder_tokenizer,
                                                      self.graph_decoder_tokenizer,
                                                      cache_path,
                                                      refresh_cache=args.refresh_cache
                                                     )
        torch_dataset.register_examples(dataset)
        torch_dataset.pretokenize_examples(refresh_cache=args.refresh_cache, add_ckg=args.add_ckg)

        # All splits use the training batch size, as before.
        dataloader = DataLoader(torch_dataset,
                                batch_size=args.train_batchsize,
                                shuffle=shuffle,
                                num_workers=args.num_workers,
                                sampler=None,
                                batch_sampler=None,
                                collate_fn=torch_dataset.collate_fn,
                                pin_memory=False,
                                worker_init_fn=None,
                               )
        return torch_dataset, dataloader

    def model_forward(self, batch):
        text_encoder_input_ids = batch['input_text'].cuda()
        text_encoder_attention_mask = batch['input_text_mask'].cuda()
        graph_encoder_input_ids = batch['input_graph'].cuda()
        graph_encoder_attention_mask = batch['input_graph_mask'].cuda()
        action_decoder_input_ids = batch['output_valid_act'].cuda()
        action_decoder_attention_mask = batch['output_valid_act_mask'].cuda()
        graph_decoder_input_ids = batch['output_graph_diff'].cuda()
        graph_decoder_attention_mask = batch['output_graph_diff_mask'].cuda()
        #labels = batch['labels'].cuda()


        output = self.model(text_encoder_input_ids=text_encoder_input_ids,
                            text_encoder_attention_mask=text_encoder_attention_mask,
                            graph_encoder_input_ids=graph_encoder_input_ids,
                            graph_encoder_attention_mask=graph_encoder_attention_mask,
                            action_decoder_input_ids=action_decoder_input_ids,
                            action_decoder_attention_mask=action_decoder_attention_mask,
                            graph_decoder_input_ids=graph_decoder_input_ids,
                            graph_decoder_attention_mask=graph_decoder_attention_mask,
                           )

        return output, action_decoder_input_ids, graph_decoder_input_ids


    def train(self, dataloader):
        """
        Train a single epoch with gradient accumulation.

        Per batch, the next-token cross-entropy of the action decoder and/or
        the graph decoder (selected by ``args.task_mode``) is summed, divided
        by ``args.n_accum`` and back-propagated.  Every ``n_accum`` batches,
        and on the last batch, the optimizer steps, gradients are cleared, the
        per-task losses collected since the previous step (summed and divided
        by ``n_accum``) are appended to
        ``self.info['epoch_metrics'][curr_epoch]['train']`` and optionally
        logged to wandb, and ``self.info['curr_iter']`` is incremented.
        """
        action_modes = ['full_multitask', 'action_only', 'action_multi']
        graph_modes = ['full_multitask', 'graph_only', 'graph_multi']

        def next_token_loss(output, logits_key, decoder_ids):
            # Targets are the decoder ids without their first token; padding
            # is replaced by -100.  Logits drop their last position so both
            # sides have the same length.
            targets = decoder_ids[:, 1:].contiguous()
            targets = targets.masked_fill(targets == self.tokenizer.pad_token_id, -100)
            logits = output[logits_key][:, :-1, :].contiguous()
            return self.loss_fn(logits.view(-1, logits.size()[2]), targets.view(-1))

        self.model.train()
        pending = {'action_loss': [], 'graph_loss': []}

        for bidx, batch in enumerate(tqdm.tqdm(dataloader)):
            output, action_ids, graph_ids = self.model_forward(batch)

            loss = 0
            if self.args.task_mode in action_modes:
                action_loss = next_token_loss(output, 'action_logits', action_ids)
                loss = loss + action_loss
                pending['action_loss'].append(action_loss.item())

            if self.args.task_mode in graph_modes:
                graph_loss = next_token_loss(output, 'graph_logits', graph_ids)
                loss = loss + graph_loss
                pending['graph_loss'].append(graph_loss.item())

            # normalize for accumulation
            loss = loss / self.args.n_accum
            loss.backward()

            at_step = ((bidx + 1) % self.args.n_accum == 0) or (bidx + 1 == len(dataloader))
            if not at_step:
                continue

            # curr_iter counts actual optimizer updates, not batches.
            self.optim.step()
            self.model.zero_grad()

            step_losses = {name: np.sum(values) / self.args.n_accum
                           for name, values in pending.items()}
            self.info['epoch_metrics'][self.info['curr_epoch']]['train'].append(step_losses)

            if self.args.wandb_project is not None:
                prefixed = {f"train_{name}": value for name, value in step_losses.items()}
                wandb.log({"epoch": self.info['curr_epoch'], **prefixed}, step=self.info['curr_iter'])

            # Start collecting for the next optimizer step.
            pending = {'action_loss': [], 'graph_loss': []}
            self.info['curr_iter'] += 1

    def validate(self, dataloader, mode='val', save_ckpt=False):
        """
        Evaluate the teacher-forced losses over ``dataloader``.

        args:
            dataloader: batches to evaluate
            mode: label used for wandb keys; when it is 'val' the epoch means
                  are also stored in
                  ``self.info['epoch_metrics'][curr_epoch]['val']``
            save_ckpt: when True and ``args.ckpt_policy == 'all'``, save the
                       model state dict as ``model_epoch{curr_epoch}.pt``

        The mean action / graph losses are logged together with the mean
        training losses recorded for the current epoch.
        """
        action_modes = ['full_multitask', 'action_only', 'action_multi']
        graph_modes = ['full_multitask', 'graph_only', 'graph_multi']

        def next_token_loss(output, logits_key, decoder_ids):
            # Same target construction as in training: shift by one token and
            # replace padding by -100.
            targets = decoder_ids[:, 1:].contiguous()
            targets = targets.masked_fill(targets == self.tokenizer.pad_token_id, -100)
            logits = output[logits_key][:, :-1, :].contiguous()
            return self.loss_fn(logits.view(-1, logits.size()[2]), targets.view(-1))

        collected = {'action_loss': [], 'graph_loss': []}
        self.model.eval()
        for bidx, batch in enumerate(tqdm.tqdm(dataloader)):

            with torch.no_grad():
                output, action_ids, graph_ids = self.model_forward(batch)

            if self.args.task_mode in action_modes:
                collected['action_loss'].append(
                    next_token_loss(output, 'action_logits', action_ids).item())

            if self.args.task_mode in graph_modes:
                collected['graph_loss'].append(
                    next_token_loss(output, 'graph_logits', graph_ids).item())

        epoch_val_action_loss = np.mean(collected['action_loss'])
        epoch_val_graph_loss = np.mean(collected['graph_loss'])

        if mode == 'val':
            self.info['epoch_metrics'][self.info['curr_epoch']]['val'] = {
                'action_loss': epoch_val_action_loss,
                'graph_loss': epoch_val_graph_loss,
            }

        if self.args.wandb_project is not None:
            wandb_update = {
                f"{mode}_action_loss": epoch_val_action_loss,
                f"{mode}_graph_loss": epoch_val_graph_loss,
                'epoch': self.info['curr_epoch'],
            }
            wandb.log(wandb_update, step=self.info['curr_iter'])

        # Mean of the per-step training losses recorded for this epoch.
        train_steps = self.info['epoch_metrics'][self.info['curr_epoch']]['train']
        epoch_train_action_loss = np.mean([step['action_loss'] for step in train_steps])
        epoch_train_graph_loss = np.mean([step['graph_loss'] for step in train_steps])

        logging.info(f"### Epoch {self.info['curr_epoch']} finished. Train loss - action: {epoch_train_action_loss:.2f}, graph: {epoch_train_graph_loss:.2f}, Valid loss - action: {epoch_val_action_loss:.2f}, graph: {epoch_val_graph_loss:.2f}")

        if save_ckpt and self.args.ckpt_policy == 'all':
            torch.save(self.model.state_dict(), self.save_path / f"model_epoch{self.info['curr_epoch']}.pt")

    def predict(self, dataloader, do_logging):
        """
        Inference: generate actions and/or graph triples with beam search and
        score them against the targets.

        args:
            dataloader: batches providing 'input_text', 'input_graph', their
                        masks, the padded targets 'output_valid_act' /
                        'output_graph_diff', and per-example 'rom' / 'ex_id'
            do_logging: accepted for call compatibility; not used
        returns:
            dict ``metrics[rom][ex_id]`` with 'action' / 'graph' scores and the
            decoded 'inputs' / 'outputs' of every example

        Which decoders run is selected by ``args.task_mode``.  Outputs of a
        decoder that does not run are stored as None.  When ``args.save_gen``
        is set, one JSON line per example is written to
        ``<generation dir>/<ckpt dir>_<ckpt file>/<cache stem>.jsonl``.
        Running per-game macro scores are printed after every batch and a
        per-game summary is printed at the end.
        """
        test_games = ['zork1','library','detective','balances','pentari','ztuu','ludicorp','deephome','temple']
        num_beams = 15
        task_labels = {'action': 'Action', 'graph': 'Graph'}

        active_tasks = []
        if self.args.task_mode in ['full_multitask', 'action_only', 'action_multi']:
            active_tasks.append('action')
        if self.args.task_mode in ['full_multitask', 'graph_only', 'graph_multi']:
            active_tasks.append('graph')

        self.model.eval()

        # Savefiles
        save_gen_file = None
        if self.args.save_gen:
            name_ckpt = self.args.load_path.split('/')[-2]
            name_epoch = self.args.load_path.split('/')[-1]
            save_name = f"{name_ckpt}_{name_epoch}"
            # An optional ``gen_save_dir`` attribute on args overrides the
            # historical default location.
            gen_root = getattr(self.args, 'gen_save_dir', None) or '/home/mnskim/workspace/tbg/tbg1/results/generations/'
            path_save_gen = Path(gen_root) / save_name
            path_save_gen.mkdir(parents=True, exist_ok=True)
            data_name_stem = dataloader.dataset.__dict__['cache_path'].stem
            save_gen_file = open(path_save_gen / f"{data_name_stem}.jsonl", 'w')

        metrics = {}
        try:
            for bidx, batch in enumerate(tqdm.tqdm(dataloader)):
                bsize = batch['input_text'].size(0)

                for ridx in range(bsize):
                    rom = batch['rom'][ridx]
                    ex_id = batch['ex_id'][ridx]

                    if rom not in metrics:
                        metrics[rom] = {}
                    if ex_id not in metrics[rom]:
                        metrics[rom][ex_id] = {'action': {'f1': [], 'em': [], 'n_tp': 0, 'n_fp': 0, 'n_fn': 0},
                                               'graph': {'f1': [], 'em': [], 'n_tp': 0, 'n_fp': 0, 'n_fn': 0},
                                               'inputs': {'text': None, 'graph': None},
                                               'outputs': {'pred_actions': None, 'target_actions': None,
                                                           'pred_graph': None, 'target_graph': None,
                                                           }
                                              }
                    record = metrics[rom][ex_id]

                    inputs, outputs, scores = self._predict_example(batch, ridx, num_beams, active_tasks)
                    for task, values in scores.items():
                        record[task].update(values)
                    record['inputs'].update(inputs)
                    record['outputs'].update(outputs)

                    if save_gen_file is not None:
                        single_ex = {'bidx': bidx,
                                     'ridx': ridx,
                                     'text_input': inputs['text'],
                                     'graph_input': inputs['graph'],
                                     'pred_actions': outputs['pred_actions'],
                                     'target_actions': outputs['target_actions'],
                                     'pred_graph': outputs['pred_graph'],
                                     'target_graph': outputs['target_graph']
                                    }
                        save_gen_file.write(json.dumps(single_ex) + '\n')

                # Running macro scores over everything seen so far.
                for rom in metrics:
                    for task in active_tasks:
                        for met in ['em', 'f1']:
                            values = [item[task][met] for item in metrics[rom].values()]
                            print(f"{rom} {task_labels[task]} Macro {met.upper()}: {np.mean(values)} over {len(values)} examples")
        finally:
            # Close the generations file even when inference fails part-way.
            if save_gen_file is not None:
                save_gen_file.close()

        # Final per-game macro scores, listed in the order of ``test_games``;
        # games without any example are reported as nan.
        for task in ['graph', 'action']:
            if task not in active_tasks:
                continue
            for met in ['em', 'f1']:
                per_game = []
                for key in test_games:
                    values = [item[task][met] for item in metrics.get(key, {}).values()]
                    per_game.append(round(float(np.mean(values)), 12) if values else float('nan'))
                print(task, met, per_game)

        return metrics

    def _predict_example(self, batch, ridx, num_beams, active_tasks):
        """
        Run generation and scoring for example ``ridx`` of ``batch``.

        returns:
            inputs:  {'text', 'graph'} decoded (unpadded) encoder inputs
            outputs: {'pred_actions', 'target_actions', 'pred_graph',
                      'target_graph'}; None for a task that did not run
            scores:  per active task, the metric values to store
        """
        # Keep only the positions whose attention mask is 1, then add a batch axis.
        text_input = batch['input_text'][ridx][:torch.sum(batch['input_text_mask'][ridx]==1)].cuda().unsqueeze(0)
        graph_input = batch['input_graph'][ridx][:torch.sum(batch['input_graph_mask'][ridx]==1)].cuda().unsqueeze(0)

        outputs = {'pred_actions': None, 'target_actions': None,
                   'pred_graph': None, 'target_graph': None}
        scores = {}
        net = self.model.module

        with torch.no_grad():
            # Encoder inputs are repeated once per beam before encoding
            # (presumably so the encoder outputs line up with the beams inside
            # generate -- TODO confirm).
            text_encoder_outputs = net.text_encoder_decoder.get_encoder()(text_input.repeat_interleave(num_beams, dim=0), return_dict=True)
            graph_encoder_outputs = net.graph_encoder_decoder.get_encoder()(graph_input.repeat_interleave(num_beams, dim=0), return_dict=True)

            # Aggregator - no mask required here since a single unpadded example is processed
            aggregator_inputs = torch.cat([text_encoder_outputs['last_hidden_state'], graph_encoder_outputs['last_hidden_state']], 1)
            aggregator_outputs = net.aggregator(inputs_embeds=aggregator_inputs)
            state_vec = aggregator_outputs['last_hidden_state'].mean(1)

            if 'action' in active_tasks:
                pred_ids = net.text_encoder_decoder.generate(text_input,
                                                             num_beams=num_beams,
                                                             max_length=1024,
                                                             state_vector=state_vec,
                                                             encoder_outputs=text_encoder_outputs)[0]
                pred_text = self._decode_generated(pred_ids)
                target_text = self._decode_target(batch['output_valid_act'][ridx])

                outputs['pred_actions'] = parse_actions_text(pred_text, skip_empty=True)
                outputs['target_actions'] = parse_actions_text(target_text, skip_empty=True)

                action_em, action_f1, counts = action_metrics(outputs['pred_actions'], outputs['target_actions'], return_counts=True)
                tp, fp, fn = counts
                scores['action'] = {'n_tp': tp, 'n_fp': fp, 'n_fn': fn,
                                    'em': action_em, 'f1': action_f1}

            if 'graph' in active_tasks:
                graph_pred_ids = net.graph_encoder_decoder.generate(graph_input,
                                                                    num_beams=num_beams,
                                                                    max_length=1024,
                                                                    state_vector=state_vec,
                                                                    encoder_outputs=graph_encoder_outputs)[0]
                graph_pred_text = self._decode_generated(graph_pred_ids)
                graph_target_text = self._decode_target(batch['output_graph_diff'][ridx])

                outputs['pred_graph'] = parse_actions_text(graph_pred_text, action_token='[TRIPLE]', skip_empty=True)
                outputs['target_graph'] = parse_actions_text(graph_target_text, action_token='[TRIPLE]', skip_empty=True)

                graph_em, graph_f1 = action_metrics(outputs['pred_graph'], outputs['target_graph'])
                scores['graph'] = {'em': graph_em, 'f1': graph_f1}

        inputs = {'text': self.tokenizer.decode(text_input[0]),
                  'graph': self.tokenizer.decode(graph_input[0])}
        return inputs, outputs, scores

    def _decode_generated(self, ids):
        """
        Decode a generated id sequence to stripped text.

        The first two ids are dropped, as is a trailing id 2.  Sequences are
        expected to start with [2, 0] (presumably ``</s><s>`` for this
        tokenizer -- TODO confirm); anything else is logged and still decoded.
        """
        prefix = ids[:2].tolist()
        if prefix != [2, 0]:
            logging.warning(f"Generated sequence starts with {prefix}, expected [2, 0]")
        ids = ids[2:-1] if ids[-1] == 2 else ids[2:]
        return self.tokenizer.decode(ids).strip()

    def _decode_target(self, padded_ids):
        """
        Decode a padded target sequence to stripped text, dropping the padding,
        the leading BOS token and a trailing EOS token if present.
        """
        ids = padded_ids[:torch.sum(padded_ids != self.tokenizer.pad_token_id)]
        assert ids[0] == self.tokenizer.bos_token_id
        ids = ids[1:-1] if ids[-1] == self.tokenizer.eos_token_id else ids[1:]
        return self.tokenizer.decode(ids).strip()


    
    

    def register_training_logger(self):
        """
        Training logger
        optionally handle wandb
        """
        self.info = {'curr_epoch': 0,
                     'curr_iter': 0,
                     'epoch_metrics': {},
                     'train_metrics': [],
                     'val_metrics': [],
                     'best_val': None
                    }

        self.save_path = Path(self.args.save_path)
        self.save_path.mkdir(parents=True, exist_ok=True)

    def log_train(self, loss):
        """
        Call on every batch (==iter)
        """
        self.info['curr_iter'] += 1
        if not self.info['curr_epoch'] in self.info['epoch_metrics']:
            self.info['epoch_metrics'][self.info['curr_epoch']] = {}
        #self.info['train_metrics'].append((self.info['curr_epoch'], self.info['curr_iter'], loss))
        self.info['epoch_metrics'][self.info['curr_epoch']].append({})
        if self.info['curr_iter'] % self.args.log_freq == 0:
            total_avg_loss = np.mean([item[2] for item in self.info['train_metrics']])
            logging.info(f"Iter {self.info['curr_iter']} loss: {loss:.2f}. total avg loss: {total_avg_loss:.2f}")

    def log_val(self, entry):
        """
        Call on val end
        """ 
        entry = (self.info['curr_epoch'], self.info['curr_iter'], entry)
        ipdb.set_trace()
        epoch_train_loss = np.mean([item[2] for item in self.info['train_metrics'] if item[0] == self.info['curr_epoch']])
        output_str = f"Epoch {self.info['curr_epoch']} val loss: {loss:.2f}, train loss: {epoch_train_loss:.2f}."
        self.info['val_metrics'].append(entry)

        if self.info['best_val'] is None:
            self.info['best_val'] = entry
        else:
            if self.info['best_val'][2] > entry[2]:
                self.info['best_val'] = entry
                output_str += ' (new best ckpt)'
                torch.save(self.model.state_dict(), self.save_path / f"model_best.pt")
                if self.args.ckpt_policy == 'best_only':
                    pass
                elif self.args.ckpt_policy == 'all':
                    torch.save(self.model.state_dict(), self.save_path / f"model_epoch{self.info['curr_epoch']}.pt")
        logging.info(output_str)

    def generate(self, input_text, decoder):
        """
        args:
            input_text:
            decoder: 'graph' or 'action'
        """
        input_ids = self.graph_decoder_tokenizer.encode(input_text, return_tensors='pt')
        beam_output = self.model.generate(
                        input_ids,
                        max_length=50,
                        num_beams=15,
                        early_stopping=True
        )
        return beam_output

    def load_checkpoint(self, ckpt):
        state = torch.load(ckpt)
        self.model.load_state_dict(state)
        print(f"Loaded weights from {ckpt}")

def main(args):
    """
    Run the trainer in the mode selected by ``args.mode``.

    'train': for each of ``args.n_epochs`` epochs, train on the train
    dataloader, validate on the val dataloader (with checkpoint saving), then
    evaluate on the test dataloader (without checkpoint saving).

    'predict': run prediction over the test dataloader with logging disabled.

    Any other mode value does nothing.
    """
    mode = args.mode

    if mode == 'train':
        trainer = JerichoWorldTrainer(args)

        for epoch in range(args.n_epochs):
            # Open a fresh metrics bucket for the epoch before training starts.
            trainer.info['curr_epoch'] = epoch
            trainer.info['epoch_metrics'][epoch] = dict(train=[], val=[])

            trainer.train(trainer.train_dataloader)
            trainer.validate(trainer.val_dataloader, mode='val', save_ckpt=True)
            trainer.validate(trainer.test_dataloader, mode='test', save_ckpt=False)

    elif mode == 'predict':
        predictor = JerichoWorldTrainer(args)
        predictor.predict(predictor.test_dataloader, do_logging=False)


def build_arg_parser():
    """
    Build the command-line parser for the training / prediction script.

    returns:
        argparse.ArgumentParser with every option of the script registered
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file', type=str, help='train json file')
    parser.add_argument('--test_file', type=str, help='test json file')

    # Pretrained ckpt
    parser.add_argument('--pretrained_ckpt', type=str, default=None, help='path/name')
    parser.add_argument('--init_mode', type=str, default=None, choices=['both', 'text', 'graph'], help='how to use pretrained ckpt')

    # Tokenizers
    parser.add_argument('--tokenizer_path', type=str, help='path/name to huggingface tokenizer')
    parser.add_argument('--added_tokens_path', type=str, help='path to additional tokenizer data created from data')
    parser.add_argument('--add_atomic_tokens', action='store_true', default=False, help='')

    # Tokenizer for graph encoder
    parser.add_argument('--graph_encoder_tokenizer_path', type=str, help='path/name to huggingface tokenizer')
    parser.add_argument('--graph_encoder_added_tokens_path', type=str, default=None, help='path to additional tokenizer data created from data')

    # Tokenizer for action decoder
    parser.add_argument('--action_decoder_tokenizer_path', type=str, help='path/name to huggingface tokenizer')
    parser.add_argument('--action_decoder_added_tokens_path', type=str, default=None, help='path to additional tokenizer data created from data')

    # Tokenizer for graph decoder
    parser.add_argument('--graph_decoder_tokenizer_path', type=str, help='path/name to huggingface tokenizer')
    parser.add_argument('--graph_decoder_added_tokens_path', type=str, default=None, help='path to additional tokenizer data created from data')

    parser.add_argument('--replacement_tokenizer_path', type=str, default=None, help='path to tokenizer')
    parser.add_argument('--n_vocab', type=int, default=None, help='')

    # Paths
    parser.add_argument('--save_path', type=str, help='path to save results (train mode)')
    parser.add_argument('--load_path', type=str, help='path to load weights (predict mode)')
    parser.add_argument('--cache_path', type=str, help='path to cache pretokenized data results')
    parser.add_argument('--refresh_cache', action='store_true', default=False, help='refresh data cache')

    parser.add_argument('--save_gen', action='store_true', default=False, help='Save generation results (test mode)')

    # Dataloading
    parser.add_argument('--train_batchsize', type=int, default=16, help='')
    parser.add_argument('--val_batchsize', type=int, default=16, help='')
    parser.add_argument('--test_batchsize', type=int, default=16, help='')
    parser.add_argument('--num_workers', type=int, default=8, help='')

    # Training/Infer Modes
    parser.add_argument('--mode', type=str, default=None, choices=['train', 'predict'], help='Select a mode')
    # Task Modes
    parser.add_argument('--task_mode', type=str, default=None, choices=['full_multitask', 'action_only', 'action_multi', 'graph_multi'], help='Select a mode')

    # Training
    parser.add_argument('--n_epochs', type=int, default=5, help='')
    parser.add_argument('--lr', type=float, default=2e-5)
    parser.add_argument('--n_accum', type=int, default=1, help='gradient accumulation')
    parser.add_argument('--log_freq', type=int, default=20, help='log the training loss every N iterations')
    parser.add_argument('--ckpt_policy', type=str, default=None, choices=['best_only', 'all'], help='Select a mode')
    parser.add_argument('--clip', type=float, default=1.0, help='gradient clipping')
    parser.add_argument('--tie_embeddings', action='store_true', default=False, help='tie embedding weights')

    # Commonsense data
    parser.add_argument('--add_ckg', action='store_true', default=False, help='add ckg triples')

    # Other
    parser.add_argument('--seed', type=int, default=42, help='')

    # Wandb
    parser.add_argument('--wandb_project', type=str, default=None, help='wandb project name')
    parser.add_argument('--wandb_name', type=str, default=None, help='wandb run name')

    return parser


if __name__ == "__main__":

    parser = build_arg_parser()
    args = parser.parse_args()

    # NOTE(review): this hard-coded path is not read anywhere in this block;
    # it looks like a leftover next to the --save_path option. Confirm nothing
    # else in the file relies on the global before removing it.
    save_path = "/home/mnskim/workspace/tbg/tbg1/ckpts/models/pretraining"

    main(args)
