import os, sys, json, logging, pprint, tqdm
import numpy as np
import torch
import jieba
from torch.utils.data import DataLoader
from transformers import MBart50Tokenizer, MT5Tokenizer, XLMRobertaTokenizer
from model import GenerativeModel, Prefix_fn_cls
from data import GenDataset, EEDataset
from utils import compute_f1, lang_map
from argparse import ArgumentParser, Namespace
from template_generate_ace import eve_template_generator
import ipdb
from collections import Counter

# configuration
parser = ArgumentParser()
# config / checkpoint paths: the trigger (ed) model is optional, the argument (eae) model is mandatory
for short_flag, long_flag, is_required in (
    ('-ced', '--ed_config', False),
    ('-ceae', '--eae_config', True),
    ('-ed', '--ed_model', False),
    ('-eae', '--eae_model', True),
):
    parser.add_argument(short_flag, long_flag, required=is_required)
parser.add_argument('-g', '--gold_trigger', action='store_true', default=False)
parser.add_argument('-w', '--write_output', type=str, required=False)
# plain on/off switches
for switch in ('--no_dev', '--ere', '--constrained_decode'):
    parser.add_argument(switch, default=False, action='store_true')
parser.add_argument('--beam', type=int)
args = parser.parse_args()

# the EAE config is a JSON object exposed through attribute access
with open(args.eae_config) as fp:
    eae_config = Namespace(**json.load(fp))

# a (non-zero) --beam on the command line overrides the configured beam size
if args.beam:
    eae_config.beam_size = args.beam

# with --ere the ERE template generator replaces the ACE one imported at the top of the file
if args.ere:
    from template_generate_ere import eve_template_generator

# fix random seed
np.random.seed(eae_config.seed)
torch.manual_seed(eae_config.seed)
torch.backends.cudnn.enabled = False

# logger
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(message)s',
    datefmt='[%Y-%m-%d %H:%M:%S]',
)
logger = logging.getLogger(__name__)

def get_span_idx(lang, pi, pieces, token_start_idxs, span, tokenizer, no_bos=False, trigger_span=None):
    """Locate a text span inside a tokenized sentence and return its token offsets.

    The span is split into words, every word is encoded with ``tokenizer`` and
    the concatenated id sequence is searched for in ``pieces``. A match only
    counts when both its first piece offset and the offset just past its last
    piece appear in ``token_start_idxs``.

    Args:
        lang: "english", "arabic" or "spanish" (whitespace split) or "chinese"
            (jieba segmentation).
        pi: unused; kept so existing positional calls keep working.
        pieces: sequence of piece ids of the sentence.
        token_start_idxs: piece offsets at which tokens start; the returned
            indices are positions in this sequence.
        span: the text to look for.
        tokenizer: object providing ``encode(text, add_special_tokens=True)``
            and ``decode(piece_id)``.
        no_bos: if True, only the last encoded id (end marker) is stripped from
            each word; otherwise the first and the last id are stripped.
        trigger_span: optional span whose first element is a token index; when
            given, the match whose start is closest to it is returned,
            otherwise the first match.

    Returns:
        ``(start, end)`` positions in ``token_start_idxs``, or ``(-1, -1)`` when
        the span cannot be aligned.

    Raises:
        ValueError: if ``lang`` is not one of the supported languages.
    """
    if lang in ("english", "arabic", "spanish"):
        t_span = span.split(' ')
    elif lang == "chinese":
        t_span = [w for w in jieba.cut(span) if w.strip()]
    else:
        # previously this fell through and crashed on an unbound local
        raise ValueError(f"Unsupported language for span matching: {lang!r}")

    # strip the special ids the tokenizer wraps around each word:
    # leading + trailing id normally, trailing id only when no_bos is set
    first_kept = 0 if no_bos else 1
    words = []
    for word in t_span:
        words.extend(tokenizer.encode(word, add_special_tokens=True)[first_kept:-1])

    # try every start offset; ids decoding to "" are skipped on either side
    candidates = []
    for i in range(len(pieces)):
        j = 0
        k = 0
        while j < len(words) and i+k < len(pieces):
            if pieces[i+k] == words[j]:
                j += 1
                k += 1
            elif tokenizer.decode(words[j]) == "":
                j += 1
            elif tokenizer.decode(pieces[i+k]) == "":
                k += 1
            else:
                break
        if j == len(words):
            candidates.append((i, i+k))

    # piece offset -> token index, keeping the first occurrence like list.index
    token_of_piece = {}
    for token_idx, piece_offset in enumerate(token_start_idxs):
        token_of_piece.setdefault(piece_offset, token_idx)

    candidates = [
        (token_of_piece[c1], token_of_piece[c2])
        for c1, c2 in candidates
        if c1 in token_of_piece and c2 in token_of_piece
    ]
    if not candidates:
        return -1, -1
    if trigger_span is None:
        return candidates[0]
    # min() returns the first of equally distant matches, same as a stable sort
    return min(candidates, key=lambda c: abs(trigger_span[0] - c[0]))
        
def get_span_idx2(tokens, span, trigger_span=None):
    """Find the run of consecutive tokens whose concatenation equals ``span``.

    Tokens are joined without any separator. Returns ``(start, end)`` with an
    exclusive end index, or ``(-1, -1)`` if no run matches. When several runs
    match, the first one is returned unless ``trigger_span`` is given, in which
    case the run whose start is closest to ``trigger_span[0]`` wins.
    """
    matches = []
    for start in range(len(tokens)):
        joined = ""
        for end in range(start, len(tokens)):
            joined += tokens[end]
            if joined == span:
                matches.append((start, end + 1))
                break
            if not span.startswith(joined):
                # the growing string has diverged from the span
                break

    if not matches:
        return -1, -1
    if trigger_span is None:
        return matches[0]
    # min() keeps the earliest match among equally distant ones
    return min(matches, key=lambda m: abs(trigger_span[0] - m[0]))
        
def cal_scores(gold_triggers, pred_triggers, gold_roles, pred_roles):
    """Score predicted triggers and roles against the gold annotations.

    All four arguments are per-example lists aligned with each other. Four
    criteria are evaluated:

    * ``tri_id``  - triggers compared on their first two fields
    * ``tri_cls`` - triggers compared as whole tuples
    * ``arg_id``  - roles compared on (r[0][2],) + r[1] without its last field
    * ``arg_cls`` - roles compared on (r[0][2],) + the complete r[1]

    Returns a dict mapping each criterion to
    ``(gold_count, pred_count, match_count) + compute_f1(pred, gold, match)``.
    """
    def _count(gold_items, pred_items, key):
        # per example, deduplicate by key and accumulate sizes and overlap
        n_gold = n_pred = n_match = 0
        for gold, pred in zip(gold_items, pred_items):
            gold_keys = {key(item) for item in gold}
            pred_keys = {key(item) for item in pred}
            n_gold += len(gold_keys)
            n_pred += len(pred_keys)
            n_match += len(gold_keys & pred_keys)
        return n_gold, n_pred, n_match

    criteria = (
        ('tri_id', gold_triggers, pred_triggers, lambda t: (t[0], t[1])),
        ('tri_cls', gold_triggers, pred_triggers, lambda t: t),
        ('arg_id', gold_roles, pred_roles, lambda r: (r[0][2],) + r[1][:-1]),
        ('arg_cls', gold_roles, pred_roles, lambda r: (r[0][2],) + r[1]),
    )

    scores = {}
    for name, gold_items, pred_items, key in criteria:
        n_gold, n_pred, n_match = _count(gold_items, pred_items, key)
        scores[name] = (n_gold, n_pred, n_match) + compute_f1(n_pred, n_gold, n_match)
    return scores

# set GPU device
torch.cuda.set_device(eae_config.gpu_device)

# check ed_model: unless gold triggers are used, both the ED config and the ED checkpoint are required
assert args.gold_trigger or (args.ed_config and args.ed_model)
if args.ed_model:
    with open(args.ed_config) as fp:
        ed_config = Namespace(**json.load(fp))

# check valid styles
eae_input_styles = ('special_type', 'en_type', 'event_type', 'event_type_sent', 'triggers', 'triggerword', 'template')
eae_output_styles = ('argument:sentence', 'argument:roletype', 'argument:englishrole')
assert all(style in eae_input_styles for style in eae_config.input_style)
assert all(style in eae_output_styles for style in eae_config.output_style)
if args.ed_model:
    ed_input_styles = ('event_type_sent', 'keywords', 'template')
    ed_output_styles = ('trigger:sentence',)
    assert all(style in ed_input_styles for style in ed_config.input_style)
    assert all(style in ed_output_styles for style in ed_config.output_style)

# check valid language
assert eae_config.lang in lang_map
lang_code = lang_map[eae_config.lang]
              
# tokenizer
# A leading "copy+" in the configured model name is stripped: those variants
# load the tokenizer of the underlying base checkpoint.
model_name = eae_config.model_name
if model_name.startswith("copy+"):
    model_name = model_name.split('copy+', 1)[1]

if model_name.startswith("facebook/mbart-large-50"):
    eae_tokenizer = MBart50Tokenizer.from_pretrained(model_name, src_lang=lang_code, tgt_lang=lang_code, cache_dir=eae_config.cache_dir)
elif model_name.startswith("google/mt5-"):
    eae_tokenizer = MT5Tokenizer.from_pretrained(model_name, cache_dir=eae_config.cache_dir)
elif model_name.startswith("xlm-roberta-"):
    eae_tokenizer = XLMRobertaTokenizer.from_pretrained(model_name, cache_dir=eae_config.cache_dir)
else:
    # fail here instead of with a NameError on eae_tokenizer further down
    raise ValueError(f"Unsupported EAE model name: {eae_config.model_name!r}")

# Extra tokens are selected by the configured input/output styles.
# NOTE: the order below determines the ids handed out by add_tokens -- do not reorder.
special_tokens = []
sep_tokens = []
if "event_type" in eae_config.input_style:
    sep_tokens += ["<--e_type-->", "<--s_type-->"]
if "event_type_sent" in eae_config.input_style:
    sep_tokens += ["<--e_sent-->"]
if "triggerword" in eae_config.input_style:
    sep_tokens += ["<--triggerword-->"]
if "template" in eae_config.input_style:
    sep_tokens += ["<--template-->"]
if "special_type" in eae_config.input_style:
    special_tokens += [
        '<--movement-->', '<--transaction-->', '<--business-->', '<--conflict-->', '<--contact-->', '<--personnel-->', '<--justice-->', 
        '<--be-born-->', '<--marry-->', '<--divorce-->', '<--injure-->', '<--die-->', '<--transport-->', '<--transfer-ownership-->', 
        '<--transfer-money-->', '<--start-organization-->', '<--merge-organization-->', '<--declare-bankruptcy-->', '<--end-organization-->', 
        '<--attack-->', '<--demonstrate-->', '<--meet-->', '<--phone-write-->', '<--start-position-->', '<--end-position-->', '<--nominate-->', 
        '<--elect-->', '<--arrest-jail-->', '<--release-parole-->', '<--trial-hearing-->', '<--charge-indict-->', '<--sue-->', '<--convict-->', 
        '<--sentence-->', '<--fine-->', '<--execute-->', '<--extradite-->', '<--acquit-->', '<--pardon-->', '<--appeal-->', 
    ]
if "argument:roletype" in eae_config.output_style:
    special_tokens += [
        "<--Person-->", "</--Person-->", 
        "<--Entity-->", "</--Entity-->", 
        "<--Defendant-->", "</--Defendant-->", 
        "<--Prosecutor-->", "</--Prosecutor-->", 
        "<--Plaintiff-->", "</--Plaintiff-->", 
        "<--Buyer-->", "</--Buyer-->", 
        "<--Artifact-->", "</--Artifact-->", 
        "<--Seller-->", "</--Seller-->", 
        "<--Destination-->", "</--Destination-->", 
        "<--Origin-->", "</--Origin-->", 
        "<--Vehicle-->", "</--Vehicle-->", 
        "<--Agent-->", "</--Agent-->", 
        "<--Attacker-->", "</--Attacker-->", 
        "<--Target-->", "</--Target-->", 
        "<--Victim-->", "</--Victim-->", 
        "<--Instrument-->", "</--Instrument-->", 
        "<--Giver-->", "</--Giver-->", 
        "<--Recipient-->", "</--Recipient-->", 
        "<--Org-->", "</--Org-->", 
        "<--Place-->", "</--Place-->", 
        "<--Adjudicator-->", "</--Adjudicator-->", 
        "[and]" ,"[None]"]
if "argument:englishrole" in eae_config.output_style:
    special_tokens += [
        "[CLS_SEP]", "[and]" ,"[None]"]
eae_tokenizer.add_tokens(sep_tokens+special_tokens)    
if args.ed_model:
    # imported here because the file's top-level transformers import does not
    # include AutoTokenizer (using it used to raise NameError)
    from transformers import AutoTokenizer
    ed_tokenizer = AutoTokenizer.from_pretrained(ed_config.model_name, cache_dir=ed_config.cache_dir)
    special_tokens = ['<Trigger>', '<sep>']
    ed_tokenizer.add_tokens(special_tokens)

# mT5 checkpoints are flagged no_bos; the evaluation loops then leave the first
# input id untouched and get_span_idx strips only the trailing id of each word
no_bos = model_name.startswith("google/mt5-")

# id written into position 0 of every encoder input (only when no_bos is False)
if eae_config.in_start_code == "lang_code":
    in_start_code = eae_tokenizer.lang_code_to_id[lang_code]
elif eae_config.in_start_code == "bos_code":
    in_start_code = eae_tokenizer.bos_token_id
elif eae_config.in_start_code == "none":
    in_start_code = -1
elif no_bos:
    # the value is never read for no_bos models; define it to avoid a dangling name
    in_start_code = -1
else:
    raise ValueError(f"Unsupported in_start_code: {eae_config.in_start_code!r}")

# first decoder token to force when it differs from the model's decoder start token
if eae_config.out_start_code == "lang_code":
    out_start_code = eae_tokenizer.lang_code_to_id[lang_code]
elif eae_config.out_start_code == "bos_code":
    out_start_code = eae_tokenizer.bos_token_id
elif eae_config.out_start_code == "eos_code":
    out_start_code = eae_tokenizer.eos_token_id
elif eae_config.out_start_code == "pad_code":
    out_start_code = eae_tokenizer.pad_token_id
else:
    raise ValueError(f"Unsupported out_start_code: {eae_config.out_start_code!r}")

# load data
def _count_batches(dataset, batch_size):
    # number of batches needed to cover the dataset, counting a partial last batch
    full, leftover = divmod(len(dataset), batch_size)
    return full + (1 if leftover else 0)

dataset_options = dict(max_length=eae_config.max_length, no_bos=no_bos)
dev_set = EEDataset(eae_tokenizer, eae_config.dev_file, **dataset_options)
test_set = EEDataset(eae_tokenizer, eae_config.test_file, **dataset_options)
dev_batch_num = _count_batches(dev_set, eae_config.eval_batch_size)
test_batch_num = _count_batches(test_set, eae_config.eval_batch_size)
with open(eae_config.vocab_file) as f:
    vocab = json.load(f)

# load model
logger.info(f"Loading model from {args.eae_model}")
eae_model = GenerativeModel(eae_config, eae_tokenizer)
# the checkpoint is passed straight into load_state_dict so no extra reference to it is kept
eae_model.load_state_dict(
    torch.load(args.eae_model, map_location=f'cuda:{eae_config.gpu_device}')
)
eae_model.cuda(device=eae_config.gpu_device)
eae_model.eval()

# the trigger model is only loaded when an ED checkpoint was given
if args.ed_model:
    logger.info(f"Loading model from {args.ed_model}")
    ed_model = GenerativeModel(ed_config, ed_tokenizer)
    ed_model.load_state_dict(
        torch.load(args.ed_model, map_location=f'cuda:{ed_config.gpu_device}')
    )
    ed_model.cuda(device=ed_config.gpu_device)
    ed_model.eval()

# eval dev set
if not args.no_dev:
    # module that holds one template class per event type (ERE or ACE flavour)
    template_module = sys.modules['template_generate_ere' if args.ere else 'template_generate_ace']

    progress = tqdm.tqdm(total=dev_batch_num, ncols=75, desc='Dev')
    dev_gold_triggers, dev_gold_roles, dev_pred_triggers, dev_pred_roles = [], [], [], []
    dev_pred_wnd_ids, dev_gold_eae_outputs, dev_pred_eae_outputs, dev_eae_inputs = [], [], [], []
    for batch in DataLoader(dev_set, batch_size=eae_config.eval_batch_size, shuffle=False, collate_fn=dev_set.collate_fn):
        progress.update(1)

        # Per-example buffers follow the real batch size: the last batch can be
        # smaller than eval_batch_size, and padding it would misalign the
        # collected predictions with the gold lists.
        batch_len = len(batch.tokens)

        # trigger predictions
        if args.gold_trigger:
            p_triggers = batch.triggers
        else:
            p_triggers = [[] for _ in range(batch_len)]
            # the ED model is queried once per event type for the whole batch
            for event_type in vocab['event_type_itos']:
                theclass = getattr(template_module, event_type.replace(':', '_').replace('-', '_'), False)

                # NOTE(review): the EAE path below builds the same template classes with two
                # extra arguments (lang, info) -- confirm this 4-argument call is still valid.
                inputs = []
                for tokens in batch.tokens:
                    template = theclass(ed_config.input_style, ed_config.output_style, tokens, event_type)
                    inputs.append(template.generate_input_str(''))

                inputs = ed_tokenizer(inputs, return_tensors='pt', padding=True, max_length=ed_config.max_length)
                enc_idxs = inputs['input_ids'].cuda()
                enc_attn = inputs['attention_mask'].cuda()

                # inference only: no autograd bookkeeping needed
                with torch.no_grad():
                    outputs = ed_model.model.generate(input_ids=enc_idxs, attention_mask=enc_attn, num_beams=4, max_length=ed_config.max_output_length)
                final_outputs = [ed_tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True) for output in outputs]

                for bid, (tokens, p_text) in enumerate(zip(batch.tokens, final_outputs)):
                    template = theclass(ed_config.input_style, ed_config.output_style, tokens, event_type)
                    pred_object = template.decode(p_text)
                    # get_span_idx takes (lang, pi, pieces, token_start_idxs, span, tokenizer);
                    # the previous 4-argument call raised TypeError. Argument order mirrors
                    # the EAE call below.
                    # NOTE(review): assumes the batch language is eae_config.lang -- confirm.
                    triggers_ = [
                        get_span_idx(eae_config.lang, batch.pieces[bid], batch.piece_idxs[bid], batch.token_start_idxs[bid], span, ed_tokenizer) + (event_type, )
                        for span, _, _ in pred_object
                    ]
                    triggers_ = [t for t in triggers_ if t[0] != -1]
                    p_triggers[bid].extend(triggers_)

            # optionally discard triggers predicted in windows whose trailing id number is below 4
            if ed_config.ignore_first_header:
                for bid, wnd_id in enumerate(batch.wnd_ids):
                    if int(wnd_id.split('-')[-1]) < 4:
                        p_triggers[bid] = []

        # argument predictions
        p_roles = [[] for _ in range(batch_len)]
        p_eae_outputs = [[] for _ in range(batch_len)]
        g_eae_outputs = [[] for _ in range(batch_len)]
        p_eae_inputs = [[] for _ in range(batch_len)]
        event_templates = [
            eve_template_generator(tokens, triggers, roles, eae_config.input_style, eae_config.output_style, vocab, eae_config.lang, False)
            for tokens, triggers, roles in zip(batch.tokens, p_triggers, batch.roles)
        ]

        # flatten to one EAE query per generated item, remembering its example index
        inputs = []
        gold_outputs = []
        events = []
        bids = []
        for i, event_temp in enumerate(event_templates):
            for data in event_temp.get_training_data():
                inputs.append(data[0])
                gold_outputs.append(data[1])
                events.append(data[2])
                bids.append(i)
                p_eae_inputs[i].append(data[0])

        if inputs:
            inputs = eae_tokenizer(inputs, return_tensors='pt', padding=True, max_length=eae_config.max_length+2)
            enc_idxs = inputs['input_ids']
            if not no_bos:
                # overwrite the first input id with the configured start code
                enc_idxs[:, 0] = in_start_code
            enc_idxs = enc_idxs.cuda()
            enc_attn = inputs['attention_mask'].cuda()

            # only force the first generated token when it differs from the decoder start token
            forced_bos_token_id = out_start_code if out_start_code != eae_model.model.config.decoder_start_token_id else None

            # the model reads the encoder ids from _cache_input_ids; with beam search
            # every input row is repeated beam_size times
            if eae_config.beam_size == 1:
                eae_model.model._cache_input_ids = enc_idxs
            else:
                expanded_return_idx = (
                    torch.arange(enc_idxs.shape[0]).view(-1, 1).repeat(1, eae_config.beam_size).view(-1).to(enc_idxs.device)
                )
                eae_model.model._cache_input_ids = enc_idxs.index_select(0, expanded_return_idx)

            with torch.no_grad():
                generate_kwargs = dict(
                    input_ids=enc_idxs,
                    attention_mask=enc_attn,
                    num_beams=eae_config.beam_size,
                    max_length=eae_config.max_output_length,
                    forced_bos_token_id=forced_bos_token_id,
                )
                if args.constrained_decode:
                    prefix_fn_obj = Prefix_fn_cls(eae_tokenizer, ["[and]"], enc_idxs)
                    generate_kwargs['prefix_allowed_tokens_fn'] = lambda batch_id, sent: prefix_fn_obj.get(batch_id, sent)
                outputs = eae_model.model.generate(**generate_kwargs)
            final_outputs = [eae_tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True) for output in outputs]

            for p_text, g_text, info, bid in zip(final_outputs, gold_outputs, events, bids):
                theclass = getattr(template_module, info['event type'].replace(':', '_').replace('-', '_'), False)
                assert theclass
                template = theclass(eae_config.input_style, eae_config.output_style, info['tokens'], info['event type'], eae_config.lang, info)

                pred_object = template.decode(p_text)

                # map every predicted argument string back to token offsets
                for span, role_type, _ in pred_object:
                    if eae_config.lang == "chinese":
                        sid, eid = get_span_idx2(batch.tokens[bid], span, trigger_span=info['trigger span'])
                    else:
                        sid, eid = get_span_idx(eae_config.lang, batch.pieces[bid], batch.piece_idxs[bid], batch.token_start_idxs[bid], span, eae_tokenizer, no_bos, trigger_span=info['trigger span'])

                    if sid == -1:
                        # span could not be aligned with the sentence
                        continue
                    p_roles[bid].append(((info['trigger span']+(info['event type'],)), (sid, eid, role_type)))

                g_eae_outputs[bid].append(g_text)
                p_eae_outputs[bid].append(p_text)

        # drop duplicate role predictions
        p_roles = [sorted(set(role)) for role in p_roles]

        dev_gold_triggers.extend(batch.triggers)
        dev_gold_roles.extend(batch.roles)
        dev_pred_triggers.extend(p_triggers)
        dev_pred_roles.extend(p_roles)
        dev_pred_wnd_ids.extend(batch.wnd_ids)
        dev_gold_eae_outputs.extend(g_eae_outputs)
        dev_pred_eae_outputs.extend(p_eae_outputs)
        dev_eae_inputs.extend(p_eae_inputs)

    progress.close()

    # calculate scores; each entry is (gold, pred, match, P, R, F) as indexed below
    dev_scores = cal_scores(dev_gold_triggers, dev_pred_triggers, dev_gold_roles, dev_pred_roles)

    print("---------------------------------------------------------------------")
    print('Trigger I  - P: {:6.2f} ({:4d}/{:4d}), R: {:6.2f} ({:4d}/{:4d}), F: {:6.2f}'.format(
        dev_scores['tri_id'][3] * 100.0, dev_scores['tri_id'][2], dev_scores['tri_id'][1], 
        dev_scores['tri_id'][4] * 100.0, dev_scores['tri_id'][2], dev_scores['tri_id'][0], dev_scores['tri_id'][5] * 100.0))
    print('Trigger C  - P: {:6.2f} ({:4d}/{:4d}), R: {:6.2f} ({:4d}/{:4d}), F: {:6.2f}'.format(
        dev_scores['tri_cls'][3] * 100.0, dev_scores['tri_cls'][2], dev_scores['tri_cls'][1], 
        dev_scores['tri_cls'][4] * 100.0, dev_scores['tri_cls'][2], dev_scores['tri_cls'][0], dev_scores['tri_cls'][5] * 100.0))
    print("---------------------------------------------------------------------")
    print('Role I     - P: {:6.2f} ({:4d}/{:4d}), R: {:6.2f} ({:4d}/{:4d}), F: {:6.2f}'.format(
        dev_scores['arg_id'][3] * 100.0, dev_scores['arg_id'][2], dev_scores['arg_id'][1], 
        dev_scores['arg_id'][4] * 100.0, dev_scores['arg_id'][2], dev_scores['arg_id'][0], dev_scores['arg_id'][5] * 100.0))
    print('Role C     - P: {:6.2f} ({:4d}/{:4d}), R: {:6.2f} ({:4d}/{:4d}), F: {:6.2f}'.format(
        dev_scores['arg_cls'][3] * 100.0, dev_scores['arg_cls'][2], dev_scores['arg_cls'][1], 
        dev_scores['arg_cls'][4] * 100.0, dev_scores['arg_cls'][2], dev_scores['arg_cls'][0], dev_scores['arg_cls'][5] * 100.0))
    print("---------------------------------------------------------------------")

    if args.write_output:
        # one record per window id
        outputs = {}
        for (dev_pred_wnd_id, dev_pred_trigger, dev_gold_role, dev_pred_role, dev_gold_eae_output, dev_pred_eae_output, dev_eae_input) in zip(
            dev_pred_wnd_ids, dev_pred_triggers, dev_gold_roles, dev_pred_roles, dev_gold_eae_outputs, dev_pred_eae_outputs, dev_eae_inputs):
            outputs[dev_pred_wnd_id] = {
                "input": dev_eae_input, 
                "triggers": dev_pred_trigger,
                "roles": [(t[0], t[1][:2]+('', ), t[1][2]) for t in dev_pred_role],
                "g_roles": dev_gold_role,
                "g_text": dev_gold_eae_output,
                "p_text": dev_pred_eae_output,
                "entities": [], 
                "entity coreference": [], 
                "tokens": [], 
            }

        # exist_ok avoids the check-then-create race of the exists()/makedirs() pair
        os.makedirs(args.write_output, exist_ok=True)

        with open(os.path.join(args.write_output, 'dev.pred.json'), 'w') as fp:
            json.dump(outputs, fp, indent=2)
    
    
# test set
# module that holds one template class per event type (ERE or ACE flavour)
template_module = sys.modules['template_generate_ere' if args.ere else 'template_generate_ace']

progress = tqdm.tqdm(total=test_batch_num, ncols=75, desc='Test')
test_gold_triggers, test_gold_roles, test_pred_triggers, test_pred_roles = [], [], [], []
test_pred_wnd_ids, test_gold_eae_outputs, test_pred_eae_outputs, test_eae_inputs = [], [], [], []
# histograms over EAE outputs: tokenized length and number of decoded items
gold_target_len = Counter()
pred_target_len = Counter()
gold_item_len = Counter()
pred_item_len = Counter()
for batch in DataLoader(test_set, batch_size=eae_config.eval_batch_size, shuffle=False, collate_fn=test_set.collate_fn):
    progress.update(1)

    # Per-example buffers follow the real batch size: the last batch can be
    # smaller than eval_batch_size, and padding it would misalign the
    # collected predictions with the gold lists.
    batch_len = len(batch.tokens)

    # trigger predictions
    if args.gold_trigger:
        p_triggers = batch.triggers
    else:
        p_triggers = [[] for _ in range(batch_len)]
        # the ED model is queried once per event type for the whole batch
        for event_type in vocab['event_type_itos']:
            theclass = getattr(template_module, event_type.replace(':', '_').replace('-', '_'), False)

            # NOTE(review): the EAE path below builds the same template classes with two
            # extra arguments (lang, info) -- confirm this 4-argument call is still valid.
            inputs = []
            for tokens in batch.tokens:
                template = theclass(ed_config.input_style, ed_config.output_style, tokens, event_type)
                inputs.append(template.generate_input_str(''))

            inputs = ed_tokenizer(inputs, return_tensors='pt', padding=True, max_length=ed_config.max_length)
            enc_idxs = inputs['input_ids'].cuda()
            enc_attn = inputs['attention_mask'].cuda()

            # inference only: no autograd bookkeeping needed
            with torch.no_grad():
                outputs = ed_model.model.generate(input_ids=enc_idxs, attention_mask=enc_attn, num_beams=4, max_length=ed_config.max_output_length)
            final_outputs = [ed_tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True) for output in outputs]

            for bid, (tokens, p_text) in enumerate(zip(batch.tokens, final_outputs)):
                template = theclass(ed_config.input_style, ed_config.output_style, tokens, event_type)
                pred_object = template.decode(p_text)
                # get_span_idx takes (lang, pi, pieces, token_start_idxs, span, tokenizer);
                # the previous 4-argument call raised TypeError. Argument order mirrors
                # the EAE call below.
                # NOTE(review): assumes the batch language is eae_config.lang -- confirm.
                triggers_ = [
                    get_span_idx(eae_config.lang, batch.pieces[bid], batch.piece_idxs[bid], batch.token_start_idxs[bid], span, ed_tokenizer) + (event_type, )
                    for span, _, _ in pred_object
                ]
                triggers_ = [t for t in triggers_ if t[0] != -1]
                p_triggers[bid].extend(triggers_)

        # optionally discard triggers predicted in windows whose trailing id number is below 4
        if ed_config.ignore_first_header:
            for bid, wnd_id in enumerate(batch.wnd_ids):
                if int(wnd_id.split('-')[-1]) < 4:
                    p_triggers[bid] = []

    # argument predictions
    p_roles = [[] for _ in range(batch_len)]
    p_eae_outputs = [[] for _ in range(batch_len)]
    g_eae_outputs = [[] for _ in range(batch_len)]
    p_eae_inputs = [[] for _ in range(batch_len)]
    event_templates = [
        eve_template_generator(tokens, triggers, roles, eae_config.input_style, eae_config.output_style, vocab, eae_config.lang, False)
        for tokens, triggers, roles in zip(batch.tokens, p_triggers, batch.roles)
    ]

    # flatten to one EAE query per generated item, remembering its example index
    inputs = []
    gold_outputs = []
    events = []
    bids = []
    for i, event_temp in enumerate(event_templates):
        for data in event_temp.get_training_data():
            inputs.append(data[0])
            gold_outputs.append(data[1])
            events.append(data[2])
            bids.append(i)
            p_eae_inputs[i].append(data[0])

    if inputs:
        inputs = eae_tokenizer(inputs, return_tensors='pt', padding=True, max_length=eae_config.max_length+2)
        enc_idxs = inputs['input_ids']
        if not no_bos:
            # overwrite the first input id with the configured start code
            enc_idxs[:, 0] = in_start_code
        enc_idxs = enc_idxs.cuda()
        enc_attn = inputs['attention_mask'].cuda()

        # Only force the first generated token when it differs from the decoder start
        # token. The previous chained comparison `a != a != b` was always False, so the
        # start token was never forced on the test set, unlike in the dev loop.
        forced_bos_token_id = out_start_code if out_start_code != eae_model.model.config.decoder_start_token_id else None

        # the model reads the encoder ids from _cache_input_ids; with beam search
        # every input row is repeated beam_size times
        if eae_config.beam_size == 1:
            eae_model.model._cache_input_ids = enc_idxs
        else:
            expanded_return_idx = (
                torch.arange(enc_idxs.shape[0]).view(-1, 1).repeat(1, eae_config.beam_size).view(-1).to(enc_idxs.device)
            )
            eae_model.model._cache_input_ids = enc_idxs.index_select(0, expanded_return_idx)

        with torch.no_grad():
            generate_kwargs = dict(
                input_ids=enc_idxs,
                attention_mask=enc_attn,
                num_beams=eae_config.beam_size,
                max_length=eae_config.max_output_length,
                forced_bos_token_id=forced_bos_token_id,
            )
            if args.constrained_decode:
                prefix_fn_obj = Prefix_fn_cls(eae_tokenizer, ["[and]"], enc_idxs)
                generate_kwargs['prefix_allowed_tokens_fn'] = lambda batch_id, sent: prefix_fn_obj.get(batch_id, sent)
            outputs = eae_model.model.generate(**generate_kwargs)

        final_outputs = [eae_tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True) for output in outputs]

        for p_text, g_text, info, bid in zip(final_outputs, gold_outputs, events, bids):
            theclass = getattr(template_module, info['event type'].replace(':', '_').replace('-', '_'), False)
            assert theclass
            template = theclass(eae_config.input_style, eae_config.output_style, info['tokens'], info['event type'], eae_config.lang, info)

            pred_object = template.decode(p_text)

            # map every predicted argument string back to token offsets
            for span, role_type, _ in pred_object:
                if eae_config.lang == "chinese":
                    sid, eid = get_span_idx2(batch.tokens[bid], span, trigger_span=info['trigger span'])
                else:
                    sid, eid = get_span_idx(eae_config.lang, batch.pieces[bid], batch.piece_idxs[bid], batch.token_start_idxs[bid], span, eae_tokenizer, no_bos, trigger_span=info['trigger span'])
                if sid == -1:
                    # span could not be aligned with the sentence
                    continue
                p_roles[bid].append(((info['trigger span']+(info['event type'],)), (sid, eid, role_type)))

            g_eae_outputs[bid].append(g_text)
            p_eae_outputs[bid].append(p_text)
            gold_target_len[len(eae_tokenizer.tokenize(g_text))] += 1
            pred_target_len[len(eae_tokenizer.tokenize(p_text))] += 1
            pred_item_len[len(pred_object)] += 1
            gold_item_len[len(template.decode(g_text))] += 1

    # drop duplicate role predictions
    p_roles = [sorted(set(role)) for role in p_roles]

    test_gold_triggers.extend(batch.triggers)
    test_gold_roles.extend(batch.roles)
    test_pred_triggers.extend(p_triggers)
    test_pred_roles.extend(p_roles)
    test_pred_wnd_ids.extend(batch.wnd_ids)
    test_gold_eae_outputs.extend(g_eae_outputs)
    test_pred_eae_outputs.extend(p_eae_outputs)
    test_eae_inputs.extend(p_eae_inputs)

progress.close()
print("gold target output length distribution: (length, count)")
print(sorted(gold_target_len.items()))
print("prediction length distribution: (length, count)")
print(sorted(pred_target_len.items()))

print("gold target item distribution: (role num, count)")
print(sorted(gold_item_len.items()))
print("prediction item distribution: (role num, count)")
print(sorted(pred_item_len.items()))


# calculate scores; each entry is (gold, pred, match, P, R, F) as indexed below
test_scores = cal_scores(test_gold_triggers, test_pred_triggers, test_gold_roles, test_pred_roles)

print("---------------------------------------------------------------------")
print('Trigger I  - P: {:6.2f} ({:4d}/{:4d}), R: {:6.2f} ({:4d}/{:4d}), F: {:6.2f}'.format(
    test_scores['tri_id'][3] * 100.0, test_scores['tri_id'][2], test_scores['tri_id'][1], 
    test_scores['tri_id'][4] * 100.0, test_scores['tri_id'][2], test_scores['tri_id'][0], test_scores['tri_id'][5] * 100.0))
print('Trigger C  - P: {:6.2f} ({:4d}/{:4d}), R: {:6.2f} ({:4d}/{:4d}), F: {:6.2f}'.format(
    test_scores['tri_cls'][3] * 100.0, test_scores['tri_cls'][2], test_scores['tri_cls'][1], 
    test_scores['tri_cls'][4] * 100.0, test_scores['tri_cls'][2], test_scores['tri_cls'][0], test_scores['tri_cls'][5] * 100.0))
print("---------------------------------------------------------------------")
print('Role I     - P: {:6.2f} ({:4d}/{:4d}), R: {:6.2f} ({:4d}/{:4d}), F: {:6.2f}'.format(
    test_scores['arg_id'][3] * 100.0, test_scores['arg_id'][2], test_scores['arg_id'][1], 
    test_scores['arg_id'][4] * 100.0, test_scores['arg_id'][2], test_scores['arg_id'][0], test_scores['arg_id'][5] * 100.0))
print('Role C     - P: {:6.2f} ({:4d}/{:4d}), R: {:6.2f} ({:4d}/{:4d}), F: {:6.2f}'.format(
    test_scores['arg_cls'][3] * 100.0, test_scores['arg_cls'][2], test_scores['arg_cls'][1], 
    test_scores['arg_cls'][4] * 100.0, test_scores['arg_cls'][2], test_scores['arg_cls'][0], test_scores['arg_cls'][5] * 100.0))
print("---------------------------------------------------------------------")

if args.write_output:
    # one record per window id
    outputs = {}
    for (test_pred_wnd_id, test_pred_trigger, test_gold_role, test_pred_role, test_gold_eae_output, test_pred_eae_output, test_eae_input) in zip(
        test_pred_wnd_ids, test_pred_triggers, test_gold_roles, test_pred_roles, test_gold_eae_outputs, test_pred_eae_outputs, test_eae_inputs):
        outputs[test_pred_wnd_id] = {
            "input": test_eae_input, 
            "triggers": test_pred_trigger,
            "roles": [(t[0], t[1][:2]+('', ), t[1][2]) for t in test_pred_role],
            "g_roles": test_gold_role,
            "g_text": test_gold_eae_output,
            "p_text": test_pred_eae_output,
            "entities": [], 
            "entity coreference": [], 
            "tokens": [], 
        }

    # exist_ok avoids the check-then-create race of the exists()/makedirs() pair
    os.makedirs(args.write_output, exist_ok=True)

    with open(os.path.join(args.write_output, 'test.pred.json'), 'w') as fp:
        json.dump(outputs, fp, indent=2)