import torch

torch.multiprocessing.set_sharing_strategy('file_system')
import torch.nn as nn
import torch.nn.functional as F
import json
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, GPT2PreTrainedModel, GPT2Model, AutoConfig, \
    LogitsProcessorList
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import copy
from time import sleep
from torch.optim import Adam
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import seed_everything
import multiprocessing
import argparse
import os, sys
import nltk
from nltk import word_tokenize, sent_tokenize
from itertools import chain
from collections import defaultdict

# from multiprocessing import freeze_support
# freeze_support()
from transformers.modeling_outputs import CausalLMOutputWithPastAndCrossAttentions
from bert_score import BERTScorer
from typing import Optional, Union
from util.graph_relation_old import get_conceptnet, KnowledgeGraph
from util.arg import display_argument
from preprocess.extract_emotion_event import extract_emotion_event_at_least_one


def get_idf_sents():
    """Return the IDF corpus built from ``train.json`` (one string per description).

    The corpus is every entry description of every scene, followed by every card
    description attached to the last entry of each scene. It is cached as
    ``train_doc_idf.json`` in the current working directory; when that cache
    exists it is returned directly and ``train.json`` is not read.

    A sentence-split variant (cached as ``train_sent_idf.json``) is defined but
    not used.
    """

    def _load_or_build(cache_path, split):
        # Reuse a previously written cache when available.
        if os.path.exists(cache_path):
            with open(cache_path, encoding='utf-8') as cached:
                return json.load(cached)

        with open('train.json', encoding='utf-8') as source:
            scenes = json.load(source)

            entry_texts = []
            card_texts = []
            for scene in scenes:
                entries = scene['entries']
                for entry in entries:
                    entry_texts.extend(split(entry['description']))
                # Only the cards of the final entry contribute.
                for card in entries[-1]['cards']:
                    card_texts.extend(split(card['description']))

            # All entry texts first, then all card texts.
            corpus = entry_texts + card_texts
            with open(cache_path, 'w', encoding='utf-8') as out:
                json.dump(corpus, out, ensure_ascii=False)
            print('finish get_idf_sent')
            return corpus

    def get_idf_docs():
        # Whole descriptions: one corpus item per description.
        return _load_or_build('train_doc_idf.json', lambda text: [text])

    def get_idf_split_sent():
        # Descriptions split into sentences with NLTK.
        return _load_or_build('train_sent_idf.json', sent_tokenize)

    return get_idf_docs()


def preprocess(scorer):
    """Annotate every scene with its peak-sentence index and write ``*_peak.json`` copies.

    Processes ``test.json``, ``valid.json`` and ``train.json`` in that order.

    For each scene:
    - The last entry's description is split into sentences with NLTK.
    - Every sentence is scored against the space-joined descriptions of that
      entry's cards using ``scorer.score(candidates, references)``. The scorer is
      presumably a ``BERTScorer`` returning ``(P, R, F)``.
    - The index of the highest-recall sentence is stored as ``scene['peak_idx']``.
    - The ``character`` and ``extra_cards`` fields are removed, if present.

    Results go to ``test_peak.json`` / ``valid_peak.json`` / ``train_peak.json``.
    """

    def find_peak_sent(card_text, text):
        # Score each sentence of ``text`` against the same card text and return
        # the index of the sentence with the highest recall.
        # NOTE(review): an empty ``text`` yields no sentences; the scorer or the
        # max over an empty tensor will then likely fail.
        sents = nltk.sent_tokenize(text)
        refs = [card_text] * len(sents)
        _, recall, _ = scorer.score(sents, refs)
        _, indices = torch.max(recall, dim=0)
        return indices.item()

    def process_one_file(in_file_name, out_file_name):
        # Close the input before the (long) scoring loop.
        with open(in_file_name, encoding='utf-8') as f:
            scenes = json.load(f)

        for scene in tqdm(scenes):
            last_entry = scene['entries'][-1]
            card_text = ' '.join(card['description'] for card in last_entry['cards'])
            scene['peak_idx'] = find_peak_sent(card_text, last_entry['description'])
            # Drop fields not needed downstream. Tolerate their absence, e.g.
            # when re-running on an already processed file.
            scene.pop('character', None)
            scene.pop('extra_cards', None)

        with open(out_file_name, 'w', encoding='utf-8') as f:
            json.dump(scenes, f, indent=1, ensure_ascii=False)

        print('finish process', in_file_name)

    process_one_file('test.json', 'test_peak.json')
    process_one_file('valid.json', 'valid_peak.json')
    process_one_file('train.json', 'train_peak.json')


class Helper():
    """Shared GPT-2 tokenizer wrapper that registers the project's marker tokens.

    Loads a ``GPT2Tokenizer`` and appends the outline / target / ending marker
    tokens plus a ``[PAD]`` padding token. ``get_vocab_size()`` therefore counts
    the added tokens.
    """

    def __init__(self, tokenizer_path='gpt2_en_ckpt_origin'):
        """
        Args:
            tokenizer_path: directory or hub name passed to
                ``GPT2Tokenizer.from_pretrained``. Defaults to the previously
                hard-coded local checkpoint.
        """
        self.tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
        self.eoc_token = '<|endofcard|>'
        self.bot_token = '<|beginoftarget|>'
        self.eot_token = '<|endoftarget|>'
        self.boo_token = '<|beginofoutline|>'
        self.bob_token = '<|beginofbedding|>'
        self.boe_token = '<|beginofending|>'
        self.soo_token = '<|sepofoutline|>'
        self.soos_token = '<|sepofoutlinesent|>'
        self.eop_token = '<|endofprompt|>'
        # Defined but deliberately NOT added to the tokenizer below.
        self.son_token = '<|sepofname|>'
        # Placeholder concept name, presumably for words absent from the knowledge graph.
        self.not_a_fact = 'NOT_A_FACT'
        # Order matters: newly added tokens receive consecutive ids in this order.
        self.tokenizer.add_tokens(
            [self.eop_token, self.eoc_token, self.bot_token, self.eot_token, self.bob_token, self.boe_token,
             self.boo_token, self.soo_token, self.soos_token])
        # GPT-2 ships without a padding token; register '[PAD]' as one.
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    def get_token_id(self, token):
        """Return the vocabulary id of a single token string."""
        return self.tokenizer.convert_tokens_to_ids(token)

    def pad_seq(self, ids, max_len):
        """Pad already-encoded inputs up to ``max_len`` with the ``[PAD]`` token.

        ``ids`` is handed to ``tokenizer.pad`` unchanged. It must therefore be a
        form that method accepts, e.g. a ``{'input_ids': [...]}`` mapping or a list
        of such mappings.
        """
        # ``pad``'s second positional parameter is the padding *strategy*, not a
        # length, so both arguments are named explicitly.
        return self.tokenizer.pad(ids, padding='max_length', max_length=max_len)

    def call_tokenizer(self, text):
        """Run the tokenizer's ``__call__`` on ``text`` (returns its BatchEncoding)."""
        return self.tokenizer(text)

    def get_vocab_size(self):
        """Full vocabulary size, including the tokens added in ``__init__``."""
        return len(self.tokenizer)


# Module-level tokenizer wrapper shared by the functions and classes below.
# It is constructed at import time, so importing this module loads the tokenizer.
helper = Helper()
# Knowledge graph queried by Gpt2OutLineModel (get_relation_size / get_triples).
# It is None here and presumably assigned elsewhere before the model is built.
graph = None


class WordTokenizer():
    """Memoizing lookup from a word to its GPT-2 sub-token ids.

    Starts from a precomputed JSON mapping ``word -> token ids``. Unseen words
    are encoded with ``helper.tokenizer`` and cached in memory.
    """

    def __init__(self, file_path=''):
        # file_path: JSON object mapping each word to its list of token ids.
        with open(file_path, encoding='utf-8') as f:
            self.word2ids = json.load(f)

    def encode(self, word):
        """Return the token ids for ``word`` (no special tokens added)."""
        try:
            return self.word2ids[word]
        except KeyError:
            # Only a cache miss falls back to the tokenizer. Other errors
            # (e.g. an unhashable ``word``) are no longer silently swallowed.
            ids = helper.tokenizer.encode(word, add_special_tokens=False)
            self.word2ids[word] = ids
            return ids


# Shared word -> token-id cache, loaded at import time from the precomputed
# outline vocabulary. The path is relative, so it depends on the working directory.
wordTokenizer = WordTokenizer(file_path='vocab/outline_ids.json')


def get_nodes_dis(words, device=None, in_graph=True):
    """Build an additive vocabulary mask that only allows the tokens of ``words``.

    The words are joined with the outline separator token and tokenized. The
    returned float vector (length = tokenizer vocabulary size) is zero at every
    token id that occurs and -1e10 everywhere else.

    When ``in_graph`` is False no restriction applies and an all-zero vector is
    returned. ``device=None`` uses torch's default device.
    """
    allowed = torch.zeros(helper.get_vocab_size(), dtype=torch.float, device=device)
    if not in_graph:
        return allowed

    token_ids = helper.tokenizer.encode(helper.soo_token.join(words))
    allowed[token_ids] = 1
    return (1 - allowed) * (-1e10)


class Gpt2OutLineModel(GPT2LMHeadModel):

    @property
    def wte(self):
        """Token-embedding module of the wrapped GPT-2 transformer."""
        transformer = self.transformer
        return transformer.wte

    def __init__(self, config):
        """Build the GPT-2 LM head model plus the outline / knowledge-graph layers.

        With ``d = config.n_embd``. Requires the module-level ``graph`` to be set,
        because its ``get_relation_size()`` sizes the relation embedding table.
        Module creation order is kept as-is: it fixes parameter initialization
        order and state_dict key order.
        """
        super().__init__(config)
        assert self.lm_head is not None
        d = config.n_embd
        # Heads over [hidden (d) ; KG context vector (2d)] = 3d features.
        self.outline_classify_head = nn.Linear(3 * d, 2, bias=False)
        self.outline_wquery = nn.Linear(d, d, bias=True)
        self.outline_wvalue = nn.Linear(d, d, bias=True)
        self.word_wkey = nn.Linear(d, d, bias=True)
        self.word_wvalue = nn.Linear(d, d, bias=True)
        self.outline_lm_head = nn.Linear(3 * d, helper.get_vocab_size(), bias=False)
        # One learned embedding per knowledge-graph relation type.
        self.relation_num = graph.get_relation_size()
        print('relation num = ', self.relation_num)
        self.relation_tensor = torch.nn.Embedding(self.relation_num, d)
        # Triple scoring: beta = Wr(r) . tanh(Wh(h) + Wt(t)).
        self.Wh = nn.Linear(d, d, bias=False)
        self.Wt = nn.Linear(d, d, bias=False)
        self.Wr = nn.Linear(d, d, bias=False)
        # Projects a 2d graph vector to d for dot-product attention with hidden states.
        self.Wk = nn.Linear(2 * d, d, bias=False)
        # Maps [KG context (2d) ; mean target hidden (d)] to d before scoring vs. word embeddings.
        self.dis_matrix = nn.Linear(3 * d, d, bias=False)
        print('Gpt2OutLineModel init (with config)!')

    def get_graph_vector(self, word):
        """Per-word knowledge-graph vector lookup (not implemented).

        Graph vectors are computed in batch by ``get_graph_vectors_words``. This
        per-word path is only reached from ``get_context_graph_vector`` when no
        precomputed ``gvs`` are passed.

        Raises:
            NotImplementedError: always (a subclass of ``Exception``, so existing
                ``except Exception`` handlers still catch it).
        """
        raise NotImplementedError('get_graph_vector() not implemented!')

    def get_concept_embedding(self, word):
        """Return the GPT-2 token ids of a graph concept ``word``.

        Despite the name, this returns token ids (via the shared ``wordTokenizer``
        cache), not embeddings. Callers look the embeddings up themselves.
        """
        return wordTokenizer.encode(word)

    def get_graph_vectors_words(self, words, device, generate=False):
        """Compute one knowledge-graph vector per word by attending over its triples.

        For every word, ``graph.get_triples(word)`` yields ``(head, relation, tail)``
        triples. For each triple:
        - the head and tail concepts are each embedded as the sum of their
          sub-token input embeddings;
        - the triple is scored as ``beta = Wr(r) . tanh(Wh(h) + Wt(t))``.
        A word's vector is the softmax(beta)-weighted sum of ``[h ; t]`` over its
        own triples.

        Args:
            words: sequence of keyword strings.
            device: device for the index and mask tensors.
            generate: if True, return a Python list of per-word tensors instead
                of one stacked tensor.

        Returns:
            ``[len(words), 2 * hidden_size]`` tensor, or a list of
            ``[2 * hidden_size]`` tensors when ``generate`` is True. A word with no
            triples gets a zero vector.
        """
        embedding_matrix = self.get_input_embeddings().weight
        hidden_size = embedding_matrix.size(-1)

        triple_counts = []  # number of triples per word
        relation_ids = []   # one relation id per triple
        entity_ids = []     # token-id lists, interleaved head, tail, head, tail, ...
        for word in words:
            triples = graph.get_triples(word)
            for h, r, t in triples:
                relation_ids.append(r)
                entity_ids.append(self.get_concept_embedding(h))
                entity_ids.append(self.get_concept_embedding(t))
            triple_counts.append(len(triples))

        if not relation_ids:
            # No triples at all. Each word gets the zero vector it would get
            # individually (a weighted sum over an empty set).
            if generate:
                return [embedding_matrix.new_zeros(2 * hidden_size) for _ in words]
            return embedding_matrix.new_zeros((len(words), 2 * hidden_size))

        # Pad all token-id lists to a common length, look up the embeddings and
        # sum only the real (unpadded) sub-token embeddings of each concept.
        lengths = [len(ids) for ids in entity_ids]
        max_len = max(lengths)
        padded = torch.tensor([ids + [0] * (max_len - len(ids)) for ids in entity_ids],
                              device=device, dtype=torch.long)
        embeds = embedding_matrix[padded]  # [2 * n_triples, max_len, hidden_size]
        valid = (torch.arange(max_len, device=device).unsqueeze(0)
                 < torch.tensor(lengths, device=device).unsqueeze(1))
        entity_emb = torch.sum(embeds * valid.unsqueeze(-1).to(torch.float), dim=1)  # [2 * n_triples, hidden]

        # Split the interleaved rows back into (head, tail) pairs. The width comes
        # from the embedding matrix instead of a hard-coded 768.
        pairs = entity_emb.reshape(-1, 2, hidden_size)
        head_embs = pairs[:, 0, :]
        tail_embs = pairs[:, 1, :]
        concat_emb = torch.cat([head_embs, tail_embs], dim=1)  # [n_triples, 2 * hidden_size]

        relation_embs = self.relation_tensor(torch.tensor(relation_ids, device=device))
        betas = torch.sum(self.Wr(relation_embs) * torch.tanh(self.Wh(head_embs) + self.Wt(tail_embs)), dim=-1)

        # Softmax over each word's own triples, then a weighted sum of [h ; t].
        ans = []
        start = 0
        for count in triple_counts:
            end = start + count
            alphas = F.softmax(betas[start:end], dim=0).unsqueeze(dim=-1)
            ans.append(torch.sum(alphas * concat_emb[start:end], dim=0))
            start = end

        if generate:
            return ans
        return torch.stack(ans, dim=0)

    def compute_all_graph_vectors(self, context_kws, target_kws, outline_kws, device):
        """Precompute graph vectors for each sample's context, target and outline keywords.

        Returns a list with one ``(context, target, outline)`` tuple per sample.
        Each element is the stacked ``get_graph_vectors_words`` output for that
        keyword list, or None when the list is empty.
        """
        def vectors_or_none(keywords):
            return self.get_graph_vectors_words(keywords, device) if keywords else None

        return [
            (vectors_or_none(ctx), vectors_or_none(tgt), vectors_or_none(outl))
            for ctx, tgt, outl in zip(context_kws, target_kws, outline_kws)
        ]

    def get_context_graph_vector(self, hidden_state, words=None, gvs=None, mask=None):
        """Attend from each hidden state over a set of knowledge-graph vectors.

        Args:
            hidden_state: ``[seq_len, hidden_size]`` query states.
            words: keywords whose vectors are fetched one by one through
                ``get_graph_vector``. Only used when ``gvs`` is None.
            gvs: precomputed graph vectors, ``[kws_len, 2 * hidden_size]``.
            mask: optional bool ``[seq_len, kws_len]``. True marks a graph vector
                that position must not attend to.

        Returns:
            ``[seq_len, 2 * hidden_size]`` attention-weighted sums of graph vectors.
        """
        if gvs is not None:
            graph_vectors = gvs
        else:
            graph_vectors = torch.stack([self.get_graph_vector(word) for word in words], dim=0)

        # Dot-product score of every position against every projected graph vector.
        betas = torch.matmul(hidden_state, self.Wk(graph_vectors).T)  # [seq_len, kws_len]
        if mask is not None:
            # Use the dtype's lowest finite value. A hard-coded -1e10 overflows
            # (and raises) for float16 scores under mixed precision.
            betas = betas.masked_fill(mask, torch.finfo(betas.dtype).min)

        alphas = torch.softmax(betas, dim=-1)
        return torch.matmul(alphas, graph_vectors)

    def get_logits(self, hidden_states, input_ids, logits_mask=None, logits_mask_on=None, generate=False,
                   outline_label=None, context_kws=None, target_kws=None, outline_kws=None):
        """Compute vocabulary logits, using the KG-augmented outline head inside the outline span.

        Training (``generate=False``), per row:
        - The outline span runs from the ``<|endoftarget|>`` token up to (not
          including) ``<|beginofbedding|>``. Each row must contain exactly one of
          each (``.nonzero().item()``).
        - Positions outside the span are scored with ``lm_head``.
        - Span positions are scored with ``outline_lm_head`` on
          ``[hidden ; KG context vector]``.
        - ``logits_mask[idx]`` is added where ``outline_label[idx]`` is 1. Labels
          are 0/1 class ids, as in the 2-way classifier loss.
        Returns ``[batch, seq_len, vocab]``.

        Generation (``generate=True``): only the last position is scored.
        - Rows with ``logits_mask_on[idx]`` false use plain ``lm_head`` logits.
        - Other rows attend over ``self.generated_graph_vectors[idx]`` and store
          the resulting context vector in ``self.hidden_combine_kg_res[idx]``.
        - ``logits_mask[idx]`` is added when the outline classifier scores class 1
          above class 0.
        Returns ``[batch, 1, vocab]``.

        NOTE(review): concatenating ``lm_head`` and ``outline_lm_head`` outputs
        assumes both produce the same vocabulary size. Presumably the model
        embeddings are resized to ``helper.get_vocab_size()`` elsewhere.
        """
        if not generate:
            # Outline span of each row: [index of <|endoftarget|>, index of <|beginofbedding|>).
            start_idxs = [(batch == helper.get_token_id(helper.eot_token)).nonzero().item() for batch in input_ids]
            end_idxs = [(batch == helper.get_token_id(helper.bob_token)).nonzero().item() for batch in input_ids]
            a = []
            for idx, batch in enumerate(hidden_states):
                before = batch[:start_idxs[idx]]
                after = batch[end_idxs[idx]:]
                x = batch[start_idxs[idx]: end_idxs[idx]]
                # KG context vector for every span position, built from the context and
                # target keywords plus the outline keywords that precede the position.
                x2 = self.get_hidden_combine_kg(x, input_ids[idx][start_idxs[idx]: end_idxs[idx]],
                                                context_kws=context_kws[idx], target_kws=target_kws[idx],
                                                outline_kws=outline_kws[idx], batch_idx=idx)
                x = torch.cat([x, x2], dim=-1)
                x = self.outline_lm_head(x)
                # Add the vocabulary mask only at positions whose outline label is 1
                # (teacher-forcing the classifier's decision).
                w = logits_mask[idx] * (outline_label[idx].unsqueeze(dim=-1))
                x += w
                before = self.lm_head(before)
                after = self.lm_head(after)
                ts = torch.cat([before, x, after], dim=0)
                a.append(ts)
            return torch.stack(a, dim=0)
        else:
            res = []
            lm_logits = self.lm_head(hidden_states[:, -1, :])
            for idx, batch in enumerate(hidden_states):
                if logits_mask_on[idx] == False:
                    # Outline mode is off for this row: plain LM logits.
                    res.append(lm_logits[idx])
                    continue
                hid = batch[-1].unsqueeze(0)

                hid2 = self.get_context_graph_vector(hid, gvs=torch.stack(self.generated_graph_vectors[idx], dim=0))
                # Cached for combine_relative_dis, which reads it for this row.
                self.hidden_combine_kg_res[idx] = hid2
                combine_hid = torch.cat([hid, hid2], dim=-1).squeeze(0)
                logit = self.outline_classify_head(combine_hid)

                w = self.outline_lm_head(combine_hid)

                # Restrict the vocabulary only when the classifier votes for class 1.
                if logit[1] > logit[0]:
                    w += logits_mask[idx]

                res.append(w)
            return torch.stack(res, dim=0).unsqueeze(1)

    def get_hidden_combine_kg(self, hidden_states, input_ids, context_kws, target_kws, outline_kws, batch_idx):
        """Return the knowledge-graph context vector of every position in an outline span.

        Args:
            hidden_states: ``[seq_len, hidden_size]`` hidden states of the span.
            input_ids: the span's token ids (same length as ``hidden_states``).
            context_kws, target_kws, outline_kws: not read here. The graph vectors
                come from ``self.graph_vectors[batch_idx]``. They are kept for
                interface compatibility.
            batch_idx: row index. Also the key of the per-forward cache
                ``self.hidden_combine_kg_res``.

        Returns:
            ``[seq_len, 2 * hidden_size]``.

        Visibility: each position attends over all context and target graph
        vectors, plus the outline graph vectors of completed keywords. One more
        outline vector becomes visible at each ``<|sepofoutline|>`` /
        ``<|sepofoutlinesent|>`` token, including that token's own position.

        Caching: the result is cached per ``batch_idx``. A cached value is
        returned as-is, whatever the other arguments.
        """
        if batch_idx in self.hidden_combine_kg_res:
            return self.hidden_combine_kg_res[batch_idx]

        # (context, target, outline) graph vectors. Entries may be None when the
        # corresponding keyword list was empty.
        parts = [self.graph_vectors[batch_idx][i] for i in range(3)]
        gvs = torch.cat([p for p in parts if p is not None], dim=0)
        col_num = gvs.shape[0]
        # Context and target graph vectors are always visible.
        visible = sum(p.size(0) for p in parts[:2] if p is not None)

        # Visible-column count per position. It grows by one at (and including)
        # every outline separator token.
        soo_id = helper.get_token_id(helper.soo_token)
        soos_id = helper.get_token_id(helper.soos_token)
        is_sep = (input_ids == soo_id) | (input_ids == soos_id)
        mask_lens = (visible + torch.cumsum(is_sep.long(), dim=0)).to(hidden_states.device)

        # True = column not yet visible at that position (column index >= visible count).
        columns = torch.arange(col_num, device=hidden_states.device)
        masks = columns.unsqueeze(0) >= mask_lens.unsqueeze(1)

        self.hidden_combine_kg_res[batch_idx] = self.get_context_graph_vector(hidden_states, gvs=gvs, mask=masks)
        return self.hidden_combine_kg_res[batch_idx]

    def get_outline_classify_loss(self, hidden_states, input_ids, outline_label, context_kws, target_kws, outline_kws):
        """Cross-entropy loss of the 2-way outline classifier over all outline-span positions.

        For each row, the span from ``<|endoftarget|>`` up to (not including)
        ``<|beginofbedding|>`` is featurized as ``[hidden ; KG context vector]`` and
        classified with ``outline_classify_head``. The targets are
        ``outline_label[idx]``, one label per span position.

        Returns:
            ``(loss, logits)``, where ``logits`` is ``[total_span_len, 2]`` for all
            rows concatenated in batch order.
        """
        eot_id = helper.get_token_id(helper.eot_token)
        bob_id = helper.get_token_id(helper.bob_token)

        span_hidden = []
        span_kg = []
        span_labels = []
        for idx, (row_hidden, row_ids) in enumerate(zip(hidden_states, input_ids)):
            start = (row_ids == eot_id).nonzero().item()
            end = (row_ids == bob_id).nonzero().item()
            hid = row_hidden[start:end]
            span_hidden.append(hid)
            span_kg.append(self.get_hidden_combine_kg(hid, row_ids[start:end],
                                                      context_kws=context_kws[idx], target_kws=target_kws[idx],
                                                      outline_kws=outline_kws[idx], batch_idx=idx))
            span_labels.append(outline_label[idx])

        # Concatenate once; growing tensors inside the loop copies quadratically.
        features = torch.cat([torch.cat(span_hidden, dim=0), torch.cat(span_kg, dim=0)], dim=-1)
        logits = self.outline_classify_head(features)
        loss_fct = nn.CrossEntropyLoss()
        return loss_fct(logits, torch.cat(span_labels)), logits

    def combine_relative_dis(self, hidden_states, input_ids, lm_logits, generate, logits_mask_on=None):
        """Add a target-conditioned relevance score over the vocabulary to the outline logits.

        For each row:
        - The KG context vectors in ``self.hidden_combine_kg_res[idx]`` (which
          must already be filled, e.g. by ``get_logits``) are concatenated with
          the mean hidden state of the target.
        - The result is projected by ``dis_matrix`` and dotted with every input
          word embedding.
        - The resulting ``[rows, vocab]`` scores are added to ``lm_logits``.

        Training: the target is the span from ``<|beginoftarget|>`` up to (not
        including) ``<|endoftarget|>``. Scores are added only to the outline-span
        rows, ``<|endoftarget|>`` up to (not including) ``<|beginofbedding|>``.

        Generation:
        - The target mean is taken over ``self.target_hidden_states[idx]``.
        - The score is added to the last-position logits.
        - Rows with ``logits_mask_on[idx]`` false are passed through unchanged.

        NOTE(review): the addition assumes the input-embedding vocabulary size
        equals the logits' vocabulary size.
        """

        if not generate:
            # Outline span per row: [<|endoftarget|>, <|beginofbedding|>).
            start_idxs = [(batch == helper.get_token_id(helper.eot_token)).nonzero().item() for batch in input_ids]
            end_idxs = [(batch == helper.get_token_id(helper.bob_token)).nonzero().item() for batch in input_ids]

            # Target span per row: [<|beginoftarget|>, <|endoftarget|>).
            target_start_idxs = [(batch == helper.get_token_id(helper.bot_token)).nonzero().item() for batch in
                                 input_ids]
            target_end_idxs = start_idxs
        word_embeddings = self.get_input_embeddings().weight

        res = []
        for idx, batch in enumerate(hidden_states):
            if generate and logits_mask_on[idx] == False:
                # Outline mode is off for this row: keep the logits as they are.
                res.append(lm_logits[idx])
                continue
            if not generate:
                target_ts = torch.mean(batch[target_start_idxs[idx]: target_end_idxs[idx]], dim=0)
            else:
                # Target hidden states collected step by step during generation.
                target_ts = torch.mean(torch.stack(self.target_hidden_states[idx], dim=0), dim=0)
            hid_combine_ts = self.hidden_combine_kg_res[idx]
            # [rows, hidden_size * 3]: KG context (2h) followed by the broadcast target mean (h).
            concat_ts = torch.cat([hid_combine_ts, target_ts.unsqueeze(0).repeat(hid_combine_ts.size(0), 1)], dim=-1)
            # [rows, vocab]: projected features scored against every word embedding.
            d = torch.matmul(self.dis_matrix(concat_ts), word_embeddings.T)
            if not generate:
                before = lm_logits[idx][:start_idxs[idx]]
                mid = lm_logits[idx][start_idxs[idx]:end_idxs[idx]]
                after = lm_logits[idx][end_idxs[idx]:]
                mid_combine = d + mid
                res.append(torch.cat([before, mid_combine, after], dim=0))
            else:
                res.append(lm_logits[idx][-1].unsqueeze(0) + d)

        return torch.stack(res, dim=0)

    def forward(
            self,
            input_ids=None,
            past_key_values=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            labels=None,
            use_cache=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            # sents_dis=None,
            # peak_idx=None,
            # outline_dis=None,
            # return_outline_dis=None,
            # topk=20,
            context_kws=None,
            target_kws=None,
            outline_kws=None,

            logits_mask=None,
            logits_mask_on=None,
            generate=False,
            outline_label=None,
            train_step=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
            ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # from util.timer import Timer
        # timer = Timer()
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        if generate:
            for idx, hidden_state in enumerate(hidden_states):
                if self.target_flag[idx]:
                    self.target_hidden_states[idx].append(hidden_state[-1, :])
        # print(f"input_ids:{input_ids.size()},hidden_states:{hidden_states.size()}")

        # Set device for model parallelism
        # if self.model_parallel:
        #     torch.cuda.set_device(self.transformer.first_device)
        #     hidden_states = hidden_states.to(self.lm_head.weight.device)
        import sys

        # timer.log(sys._getframe().f_lineno)
        if not generate:
            self.graph_vectors = self.compute_all_graph_vectors(context_kws=context_kws, target_kws=target_kws,
                                                                outline_kws=outline_kws,
                                                                device=hidden_states.device)
        # timer.log(sys._getframe().f_lineno)
        self.hidden_combine_kg_res = {}
        lm_logits = self.get_logits(hidden_states, input_ids, logits_mask,
                                    logits_mask_on, generate, outline_label, context_kws=context_kws,
                                    target_kws=target_kws, outline_kws=outline_kws)
        # if train_step is not None and train_step > 100:
        # 应该先和d组合，再mask，而不是先mask再组合
        # 但是这里不需要改，因为logit+d+mask == logit+mask+d
        lm_logits = self.combine_relative_dis(hidden_states, input_ids, lm_logits, generate, logits_mask_on)
        loss = None
        loss2 = None
        cls_logits = None
        if labels is not None:
            # Shift so that tokens < n predict n
            # lm_logits = torch.softmax(lm_logits, dim=-1)
            # lm_logits = torch.log(lm_logits)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            # if train_step is not None and train_step > 100:
            # loss_fct = nn.NLLLoss()
            # else:
            loss_fct = nn.CrossEntropyLoss()

            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            # print('lm loss:', loss.item())
            loss2, cls_logits = self.get_outline_classify_loss(hidden_states, input_ids, outline_label,
                                                               context_kws=context_kws, target_kws=target_kws,
                                                               outline_kws=outline_kws)
            self.hidden_combine_kg_res = {}
            # timer.log(sys._getframe().f_lineno)
            # print(f"loss:{loss.item()},loss2:{loss2.item()}")
            loss += loss2
            # endoftarget_idxs = [
            #     (batch == helper.tokenizer.convert_tokens_to_ids(helper.eot_token)).nonzero()[0][0].item()
            #     for batch in input_ids]
            # lam = 0.05
            # outline_hid = torch.stack(
            #     [hidden_states[batch_idx][idx] for batch_idx, idx in enumerate(endoftarget_idxs)])
            # outline_logits = torch.log_softmax(self.outline_head(outline_hid), dim=-1)
            # # print("outline_dis:", outline_dis.nonzero())
            # # print("outline_logits:", outline_logits)
            # outline_loss = torch.sum(-outline_dis * outline_logits) / outline_hid.size(0)
            # # loss += lam * torch.sum(-emotion_dis * emotion_logits)
            # # print('outline_loss:', outline_loss.item())
            # loss += lam * outline_loss
            # print('overall loss:', loss.item())
            # print('lm loss:', loss.item())
        # print("loss2 = ", loss2)
        if not return_dict:
            output = (lm_logits,) + (cls_logits,) + transformer_outputs[1:]
            return ((loss, loss2) + output) if loss is not None else output

        return CausalLMOutputWithPastAndCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    def init_state(self, input_ids):
        """Reset the per-sample generation bookkeeping used by ``sample``.

        One entry per row of ``input_ids`` is allocated for every tracked field;
        ``update_state`` (and ``forward`` for the target hidden states) fills them
        in while decoding.

        Args:
            input_ids: prompt tensor; its first dimension is the batch size and its
                device is used for the logits-mask buffer.
        """
        n_rows = input_ids.size(0)
        self.logits_mask_on = [False for _ in range(n_rows)]
        # One vocabulary-sized mask row per sample; rebuilt by update_state.
        self.logits_mask = torch.zeros([n_rows, helper.get_vocab_size()], dtype=torch.float,
                                       device=input_ids.device)
        self.target_flag = [False for _ in range(n_rows)]
        self.target_hidden_states = [list() for _ in range(n_rows)]
        self.computed = [False for _ in range(n_rows)]
        # Position of the most recent special token (-1 = none seen yet).
        self.last_spec_pos = [-1 for _ in range(n_rows)]
        self.generated_outline = [list() for _ in range(n_rows)]
        self.generated_graph_vectors = [list() for _ in range(n_rows)]
        self.generated_target_kws = [list() for _ in range(n_rows)]

    def get_target_kws(self, input_ids):
        """Extract the keywords of the target sentence of a single sample.

        The target sentence is the token span strictly between the first
        ``helper.bot_token`` and the first ``helper.eot_token``. It is decoded
        and passed to ``extract_emotion_event_at_least_one(..., per_sent=True)``,
        whose per-sentence keyword lists are flattened.

        Args:
            input_ids: 1-D tensor of token ids for one sample (``[seq_len]``).

        Returns:
            list: flat list of keywords.

        Raises:
            IndexError: if either marker token is missing from ``input_ids``.
        """
        bot_id = helper.get_token_id(helper.bot_token)
        eot_id = helper.get_token_id(helper.eot_token)
        start_idx = (input_ids == bot_id).nonzero()[0][0].item()
        end_idx = (input_ids == eot_id).nonzero()[0][0].item()
        target_ids = input_ids[start_idx + 1: end_idx].tolist()
        sent = helper.tokenizer.decode(target_ids)
        words, _ = extract_emotion_event_at_least_one(sent, per_sent=True)
        # per_sent=True is expected to return one keyword list per sentence.
        assert isinstance(words[0], list)
        return list(chain(*words))

    def update_state(self, input_ids, context_kws, device):
        """Advance the per-sample generation state according to the last token of each row.

        For every row of ``input_ids``:

        * ``bot_token``: start recording target hidden states (``target_flag``).
        * ``eot_token``: stop recording, switch the logits mask on, extract the
          target keywords and compute graph vectors for context + target keywords.
        * ``soo_token`` / ``soos_token`` while the mask is on and the outline is
          not finished: decode the outline word produced since the previous
          special token, store it with its graph vector, and rebuild the logits
          mask from the 1-hop neighbourhood of context + target + generated
          outline words.
        * ``bob_token``: switch the mask off and mark the outline as finished.

        Args:
            input_ids: ``[batch, seq_len]`` ids generated so far.
            context_kws: per-sample lists of context keywords.
            device: device passed to ``get_graph_vectors_words``.
        """
        bot_id = helper.get_token_id(helper.bot_token)
        eot_id = helper.get_token_id(helper.eot_token)
        outline_sep_ids = (helper.get_token_id(helper.soo_token), helper.get_token_id(helper.soos_token))
        bob_id = helper.get_token_id(helper.bob_token)

        for idx, row in enumerate(input_ids):
            last_token = row[-1].item()

            if last_token == bot_id:
                self.target_flag[idx] = True

            if last_token == eot_id:
                self.last_spec_pos[idx] = len(row) - 1
                self.logits_mask_on[idx] = True
                self.target_flag[idx] = False

                self.generated_target_kws[idx] = self.get_target_kws(row)
                self.generated_graph_vectors[idx] = self.get_graph_vectors_words(
                    context_kws[idx] + self.generated_target_kws[idx], device=device, generate=True)

            if self.logits_mask_on[idx] and last_token in outline_sep_ids and not self.computed[idx]:
                # The outline word is everything between the previous special token and this separator.
                word_pieces = row[self.last_spec_pos[idx] + 1:-1]
                word = helper.tokenizer.decode(word_pieces.tolist())
                self.generated_outline[idx].append(word)

                # Keep the graph vectors in sync with the newly sampled outline word.
                self.generated_graph_vectors[idx].append(
                    self.get_graph_vectors_words([word], device=device, generate=True)[0])

                target_kws = self.generated_target_kws[idx]
                assert target_kws
                words = set(self.generated_outline[idx]) | set(target_kws + context_kws[idx])
                # Allowed next words: 1-hop neighbours of context + target + already generated outline.
                words = graph.get_hops_set(words, hop=1)
                self.logits_mask[idx] = get_nodes_dis(words, device=input_ids.device)
                self.last_spec_pos[idx] = len(row) - 1

            elif last_token == bob_id:
                self.logits_mask_on[idx] = False
                self.computed[idx] = True

    def sample(
            self,
            input_ids: torch.LongTensor,
            logits_processor: Optional[LogitsProcessorList] = None,
            logits_warper: Optional[LogitsProcessorList] = None,
            max_length: Optional[int] = None,
            pad_token_id: Optional[int] = None,
            eos_token_id: Optional[int] = None,
            **model_kwargs
    ):
        r"""
        Generate sequences by multinomial sampling under the outline logits mask.

        Adapted from ``transformers.GenerationMixin.sample``. Before each forward
        pass, ``update_state`` inspects the last generated token of every row and
        maintains ``self.logits_mask`` / ``self.logits_mask_on``. These are passed
        to ``forward`` so that outline words can be restricted to the knowledge-graph
        neighbourhood.

        Parameters:
            input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
                The prompt. Its device is also used for the graph-vector computations.
            logits_processor (:obj:`LogitsProcessorList`, `optional`):
                Processors applied to the next-token logits at each step.
            logits_warper (:obj:`LogitsProcessorList`, `optional`):
                Warpers applied to the processed scores before multinomial sampling.
            max_length (:obj:`int`, `optional`):
                Maximum sequence length; defaults to ``self.config.max_length``.
            pad_token_id (:obj:`int`, `optional`):
                Padding id; defaults to ``self.config.pad_token_id``.
            eos_token_id (:obj:`int`, `optional`):
                End-of-sequence id; defaults to ``self.config.eos_token_id``.
            model_kwargs:
                Forwarded to ``prepare_inputs_for_generation`` / ``forward``. Must
                contain ``context_kws``: one list of context keywords per sample.

        Return:
            :obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`.
            The sequences either have length :obj:`max_length` or are shorter if
            every row produced :obj:`eos_token_id`.

        Raises:
            KeyError: if ``context_kws`` is not in ``model_kwargs``.
        """
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        # init sequence length tensors
        sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
            input_ids, max_length
        )

        context_kws = model_kwargs['context_kws']
        # Use the prompt's device (i.e. the model's) instead of a global GPU index, so CPU runs work too.
        device = input_ids.device
        self.init_state(input_ids)

        # auto-regressive generation
        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # Update the outline state from the last token before predicting the next one.
            self.update_state(input_ids, context_kws, device)
            outputs = self(**model_inputs, return_dict=True, generate=True, logits_mask=self.logits_mask,
                           logits_mask_on=self.logits_mask_on)
            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            scores = logits_processor(input_ids, next_token_logits)
            scores = logits_warper(input_ids, scores)

            # sample
            probs = F.softmax(scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished rows keep emitting the pad token
            if eos_token_id is not None:
                assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
                next_tokens = next_tokens * unfinished_sequences + (pad_token_id) * (1 - unfinished_sequences)

            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            cur_len = cur_len + 1

            if eos_token_id is not None:
                sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
                    sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
                )

            # stop when every row has produced EOS, or when max_length is reached
            if unfinished_sequences.max() == 0:
                break

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

        return input_ids


class StoriumDataset(Dataset):
    """Storium scenes turned into (prompt, target) token sequences for outline-guided generation.

    Each JSON item must provide ``entries`` (each with ``description``; the last
    one also with ``cards``), ``peak_idx``, ``bedding_kws``, ``ending_kws``
    (per-sentence keyword lists), ``context_kws`` and ``target_kws``.

    The ``only_*`` flags choose which part of the story is the generation target.
    By default the target sentence, outline, bedding and ending are all generated.
    With ``only_target`` only the target sentence is generated.
    """

    def __init__(self, in_file, chunk=None, only_target=False, only_outline=False, only_bedding=False,
                 only_ending=False,
                 only_outlinebedding=False):
        super().__init__()
        self.in_file = in_file
        with open(self.in_file, encoding='utf-8') as f:
            self.a = json.load(f)
        if chunk:
            # Keep only the first `chunk` items.
            self.a = self.a[:chunk]
        self.vocab_size = len(helper.tokenizer)
        self.only_target = only_target
        self.only_outline = only_outline
        self.only_bedding = only_bedding
        self.only_ending = only_ending
        self.only_outlinebedding = only_outlinebedding

    def __len__(self):
        return len(self.a)

    def _dis_from_token_ids(self, word_ids):
        """Softmax distribution over the vocabulary, peaked on ``word_ids``.

        With ``n`` ids, earlier ids score higher (``n/n, (n-1)/n, ..., 1/n``), and
        repeated ids accumulate their scores. Ids that are not present keep a
        ``-1e10`` logit, so their probability is ~0. With no ids the result is uniform.
        """
        count = len(word_ids)
        scores = torch.arange(count, 0, -1, dtype=torch.float)
        scores /= count
        res = torch.full([self.vocab_size], -1e10, dtype=torch.float)
        for token_id, score in zip(word_ids, scores):
            if res[token_id] < 0:  # still the -1e10 sentinel: first occurrence
                res[token_id] = score
            else:
                res[token_id] += score
        return torch.softmax(res, dim=-1)

    def get_dis_from_words_list(self, words):
        """Vocabulary distribution for a list of words (uniform when ``words`` is empty)."""
        if not words:
            return torch.ones(self.vocab_size, dtype=torch.float) / self.vocab_size
        # Every word is encoded with a leading space, concatenated into one string.
        con_words = ''.join(' ' + word for word in words)
        return self._dis_from_token_ids(helper.tokenizer.encode(con_words))

    def get_dis_from_sent(self, sent):
        """Vocabulary distribution for the tokens of sentence ``sent``."""
        return self._dis_from_token_ids(helper.tokenizer.encode(sent))

    @staticmethod
    def _join_outline_groups(keyword_groups):
        """Join per-sentence keyword groups: words with ``soo_token``, groups with ``soos_token``."""
        return helper.soos_token.join(helper.soo_token.join(group) for group in keyword_groups)

    @classmethod
    def _format_outline(cls, keyword_groups):
        """Outline string used by the ``only_*`` modes.

        A flat list of keyword strings is joined with ``soo_token``. Nested
        per-sentence groups are joined like the default mode; this is the layout
        the outline-mask code relies on. ``str.join`` on the nested lists would
        raise ``TypeError`` instead.
        """
        if all(isinstance(keyword, str) for keyword in keyword_groups):
            return helper.soo_token.join(keyword_groups)
        return cls._join_outline_groups(keyword_groups)

    @staticmethod
    def _keyword_graph_targets(outline_kws, context_kws):
        """Per outline keyword: in-graph label (0/1) and graph-neighbourhood distribution.

        The neighbourhood of keyword ``i`` is the 1-hop set around the preceding
        outline keywords plus the context keywords. The label is 1 when keyword
        ``i`` lies inside it. With ``args.debug``, ``[not_in_graph, in_graph]``
        counts per keyword are accumulated in ``args.keywords_masks``.
        """
        outline_mask = []
        nodes_dis = []
        for i, keyword in enumerate(outline_kws):
            words = graph.get_hops_set(outline_kws[:i] + context_kws, hop=1)
            in_graph = keyword in words
            outline_mask.append(1 if in_graph else 0)
            nodes_dis.append(get_nodes_dis(words))
            if args.debug:
                counts = args.keywords_masks.setdefault(keyword, [0, 0])
                counts[1 if in_graph else 0] += 1
        return outline_mask, nodes_dis

    @staticmethod
    def _align_to_tokens(token_ids, outline_kws, outline_mask, nodes_dis):
        """Spread per-keyword labels/distributions over the outline tokens.

        The walk covers the tokens after the first ``eot_token`` up to and
        including the first ``bob_token``:

        * Word pieces of keyword ``k`` get ``outline_mask[k]`` and ``nodes_dis[k]``.
        * ``soo``/``soos`` separators get label 0 and ``nodes_dis[k]``, and advance ``k``.
        * ``bob`` gets label 0 and ``nodes_dis[k]``, or an out-of-graph
          distribution when there are no outline keywords.

        Returns two empty lists when ``eot_token`` is absent.
        """
        eot_id = helper.get_token_id(helper.eot_token)
        bob_id = helper.get_token_id(helper.bob_token)
        sep_ids = (helper.get_token_id(helper.soo_token), helper.get_token_id(helper.soos_token))
        mask_piece, dis_piece = [], []
        if eot_id not in token_ids:
            return mask_piece, dis_piece
        kw_idx = 0
        for token in token_ids[token_ids.index(eot_id) + 1:]:
            if token in sep_ids:
                mask_piece.append(0)  # separators count as "not in graph"
                dis_piece.append(nodes_dis[kw_idx])
                kw_idx += 1
            elif token == bob_id:
                mask_piece.append(0)
                dis_piece.append(nodes_dis[kw_idx] if outline_kws
                                 else get_nodes_dis(words=None, in_graph=False))
                break
            else:
                mask_piece.append(outline_mask[kw_idx])
                dis_piece.append(nodes_dis[kw_idx])
        return mask_piece, dis_piece

    def __getitem__(self, idx):
        item = self.a[idx]

        # Prompt: the last entry's cards, then the descriptions of the earlier entries.
        prompt = ''.join(card['description'] for card in item['entries'][-1]['cards'])
        prompt += helper.eoc_token
        prompt += ' '.join(entry['description'] for entry in item['entries'][:-1])
        prompt += helper.eop_token

        sents = nltk.sent_tokenize(item['entries'][-1]['description'])
        peak_idx = item['peak_idx']
        target = helper.bot_token + sents[peak_idx] + helper.eot_token
        bedding = ' '.join(sents[:peak_idx])
        ending = ' '.join(sents[peak_idx + 1:])

        outline_groups = item['bedding_kws'] + item['ending_kws']
        target_kws = list(chain(*item['target_kws']))
        context_keywords_list = item['context_kws'] + target_kws

        output = ''
        if self.only_outline:
            prompt += target
            output += self._format_outline(outline_groups) + helper.bob_token
        elif self.only_bedding:
            prompt += target + self._format_outline(outline_groups) + helper.bob_token
            output += bedding + helper.boe_token
        elif self.only_outlinebedding:
            prompt += target
            output += self._format_outline(outline_groups) + helper.bob_token
            output += bedding + helper.boe_token
            output += ending + helper.tokenizer.eos_token
        elif self.only_ending:
            prompt += target + self._format_outline(outline_groups) + helper.bob_token
            prompt += bedding + helper.boe_token
            output += ending + helper.tokenizer.eos_token
        else:
            output += target
            if not self.only_target:
                output += self._join_outline_groups(outline_groups)
                output += helper.bob_token
                output += bedding + helper.boe_token
                output += ending + helper.tokenizer.eos_token

        input_ts = torch.tensor(helper.tokenizer.encode(prompt, add_special_tokens=False), dtype=torch.long)
        output_ts = torch.tensor(helper.tokenizer.encode(output, add_special_tokens=False), dtype=torch.long)

        outline_kws = list(chain(*outline_groups))
        outline_mask, nodes_dis = self._keyword_graph_targets(outline_kws, context_keywords_list)

        cat_tensor = torch.cat([input_ts, output_ts], dim=0)
        outline_mask_piece, node_dis_piece = self._align_to_tokens(
            cat_tensor.tolist(), outline_kws, outline_mask, nodes_dis)
        node_dis_ts = torch.stack(node_dis_piece, dim=0)
        outline_mask_ts = torch.tensor(outline_mask_piece, dtype=torch.long)

        # Sanity check: `.item()` raises unless eot and bob each occur exactly once.
        for marker in (helper.eot_token, helper.bob_token):
            (cat_tensor == helper.get_token_id(marker)).nonzero().item()

        return {'input': input_ts, 'output': output_ts, 'peak_idx': peak_idx,
                'nodes_dis': node_dis_ts,
                'context_kws': item['context_kws'], 'target_kws': target_kws, 'outline_kws': outline_kws,
                'outline_mask': outline_mask_ts}


def get_attention_mask(max_len, one_len):
    """Build a float32 mask of length ``max_len``: 1 for the first ``one_len`` positions, 0 after.

    Python slice semantics apply, so ``one_len >= max_len`` yields all ones.
    """
    mask = torch.zeros(max_len, dtype=torch.float)
    mask[:one_len] = 1
    return mask


def loss_mask(ts, start, end):
    """Overwrite ``ts`` with -100 outside ``[start, end)`` in place and return it.

    -100 is the label value ignored by the loss (``nn.CrossEntropyLoss`` default
    ``ignore_index``), so only the kept span is trained on.
    """
    ignored = -100
    for outside in (slice(None, start), slice(end, None)):
        ts[outside] = ignored
    return ts


def pad_collate(batch, max_seq_len=1024):
    """Collate ``StoriumDataset`` samples into one padded batch.

    Each sequence is ``input`` followed by ``output``, truncated to
    ``max_seq_len`` tokens (default 1024, presumably the GPT-2 context size) and
    right-padded with the tokenizer's pad id.

    The attention mask marks the real tokens. The labels copy ``input_ids`` on
    the ``output`` span and are -100 (ignored by the loss) elsewhere. Both use the
    untruncated lengths; slicing clips them to the padded length. The remaining
    per-sample fields are returned as plain lists.

    Args:
        batch: list of dicts produced by ``StoriumDataset.__getitem__``.
        max_seq_len: maximum number of tokens kept per sequence.

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``labels`` tensors and the
        list fields ``nodes_dis``, ``peak_idx``, ``context_kws``, ``target_kws``,
        ``outline_kws``, ``outline_mask``.
    """
    sequences = [torch.cat([x['input'], x['output']])[:max_seq_len] for x in batch]
    input_ids = pad_sequence(sequences, batch_first=True, padding_value=helper.tokenizer.pad_token_id)
    padded_len = input_ids.size(1)

    res = {'input_ids': input_ids}
    res['attention_mask'] = torch.stack(
        [get_attention_mask(padded_len, len(x['input']) + len(x['output'])) for x in batch])
    # clone() copies just this row; deepcopy of a row view would copy the whole batch storage.
    res['labels'] = torch.stack(
        [loss_mask(input_ids[idx].clone(), len(x['input']), len(x['input']) + len(x['output']))
         for idx, x in enumerate(batch)])
    for key in ('nodes_dis', 'peak_idx', 'context_kws', 'target_kws', 'outline_kws', 'outline_mask'):
        res[key] = [sample[key] for sample in batch]
    return res


class Gpt2(pl.LightningModule):
    """Lightning wrapper that trains and evaluates ``Gpt2OutLineModel``.

    The model config is read from ``<config_dir>/config.json``. The transformer
    weights are loaded from ``gpt2_en_ckpt_origin``, and the token embeddings are
    resized to the tokenizer's vocabulary.
    """

    def __init__(self, config_dir):
        super().__init__()
        self.model = Gpt2OutLineModel(config=AutoConfig.from_pretrained('{}/config.json'.format(config_dir)))
        self.model.transformer = GPT2Model.from_pretrained('gpt2_en_ckpt_origin')
        self.model.resize_token_embeddings(len(helper.tokenizer))
        print(len(helper.tokenizer))
        self.cnt = 0

    def get_inputs_embeds(self, input_ids, segment_ids=None):
        """Token embeddings of ``input_ids``, plus the embeddings of ``segment_ids`` when given."""
        # NOTE(review): relies on `self.model` exposing `wte`; confirm it is not only on `.transformer`.
        embeds = self.model.wte(input_ids)
        # `is not None`: truth-testing a multi-element tensor raises a RuntimeError.
        if segment_ids is not None:
            embeds = embeds + self.model.wte(segment_ids)
        return embeds

    def forward(self, x):
        return self.model(x)

    def _run_model(self, batch, **extra_kwargs):
        """Run the wrapped model on a ``pad_collate`` batch with ``return_dict=False``.

        With labels present, the result starts with
        ``(loss, classify_loss, lm_logits, cls_logits, ...)``.
        """
        return self.model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'],
                          labels=batch['labels'], logits_mask=batch['nodes_dis'],
                          outline_label=batch['outline_mask'], context_kws=batch['context_kws'],
                          target_kws=batch['target_kws'], outline_kws=batch['outline_kws'],
                          return_dict=False, **extra_kwargs)

    def training_step(self, batch, batch_idx):
        outputs = self._run_model(batch, train_step=batch_idx)
        self.log('train_loss', outputs[0].item())
        self.log('train_classify_loss', outputs[1].item())
        return outputs[0]

    def validation_step(self, batch, batch_idx):
        outputs = self._run_model(batch)
        self.log('val_loss', outputs[0].item())
        self.log('val_classify_loss', outputs[1].item())
        return outputs[0]

    def test_step(self, batch, batch_idx):
        """Log the losses and the classifier accuracy on outline separator tokens."""
        outputs = self._run_model(batch)

        eot_id = helper.get_token_id(helper.eot_token)
        bob_id = helper.get_token_id(helper.bob_token)
        # Outline span of each row: from its (single) eot token up to, but excluding, its bob token.
        outline_input_ids = torch.cat(
            [row[(row == eot_id).nonzero().item():(row == bob_id).nonzero().item()]
             for row in batch['input_ids']],
            dim=0)

        cls_logits = outputs[3]
        cls_labels = torch.cat(batch['outline_mask'], dim=0).view(-1)
        separator_mask = (outline_input_ids == helper.get_token_id(helper.soo_token)) | (
            outline_input_ids == helper.get_token_id(helper.soos_token))
        try:
            preds = cls_logits[separator_mask].argmax(-1)
            only_soo_soos_acc = (preds == cls_labels[separator_mask]).float().mean().item()
            self.log_dict({
                'sum_loss': outputs[0].item(),
                'classify_loss': outputs[1].item(),
                'only_soo_soos_acc': only_soo_soos_acc,
            })
        except Exception as err:
            # Metric logging is best-effort: report the problem instead of silently dropping it.
            print(f'test_step: skipped metric logging for batch {batch_idx}: {err!r}')

    def configure_optimizers(
            self,
    ):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-5)
        return optimizer


def train(args):
    """Fine-tune ``Gpt2('gpt2_en_ckpt_origin')`` on the Storium train split with Lightning.

    Data are read from ``./data/valid<suffix>.json`` and ``./data/train<suffix>.json``, where
    ``<suffix>`` is ``args.train_val_test_suffix``. Training runs on GPU ``args.gpu`` for at most
    5 epochs. Validation runs twice per epoch, and gradients are accumulated over 8 batches of size 1.
    The single best checkpoint by ``val_loss`` is kept under ``args.root_dir``, and training stops
    early once ``val_loss`` stops improving.
    """
    suffix = args.train_val_test_suffix
    loader_kwargs = dict(batch_size=1, num_workers=multiprocessing.cpu_count(), collate_fn=pad_collate)

    # Validation split first, then the (shuffled) training split.
    valid_dataloader = DataLoader(StoriumDataset('./data/valid{}.json'.format(suffix), chunk=None),
                                  shuffle=False, **loader_kwargs)
    train_dataloader = DataLoader(StoriumDataset('./data/train{}.json'.format(suffix), chunk=None),
                                  shuffle=True, **loader_kwargs)
    print('after load data')

    model = Gpt2('gpt2_en_ckpt_origin')
    callbacks = [
        ModelCheckpoint(monitor='val_loss', save_top_k=1, mode='min', verbose=True),
        EarlyStopping(monitor='val_loss', verbose=True, mode='min'),
    ]
    # NOTE(review): `gpus=` and `train_dataloader=` are older Lightning keyword names; newer
    # releases renamed/removed them — confirm against the pinned pytorch_lightning version.
    trainer = pl.Trainer(gpus=[args.gpu], max_epochs=5, val_check_interval=0.5,
                         callbacks=callbacks,
                         default_root_dir=args.root_dir,
                         accumulate_grad_batches=8)
    trainer.fit(model=model, train_dataloader=train_dataloader, val_dataloaders=valid_dataloader)


def test(args):
    """Run Lightning's test loop for a trained checkpoint on the Storium test split.

    Prints which ``only*`` filters are active, then builds the test dataset with those filters.
    The checkpoint is loaded from ``<root_dir>/lightning_logs/version_<version_num>/checkpoints``
    onto ``cuda:<gpu>``. If ``args.onlytarget`` is set, ``only_target`` is switched on in the
    wrapped model. Finally ``trainer.test`` is run and ``model.cnt`` is printed.
    """
    active_filters = (
        (args.onlytarget, 'only target!'),
        (args.onlyoutline, 'only outline!'),
        (args.onlybedding, 'only bedding!'),
        (args.onlyending, 'only ending'),
        (args.onlyoutlinebedding, 'only outlinebedding!'),
    )
    for enabled, message in active_filters:
        if enabled:
            print(message)

    test_dataset = StoriumDataset('./data/test{}.json'.format(args.train_val_test_suffix),
                                  only_target=args.onlytarget, only_outline=args.onlyoutline,
                                  only_bedding=args.onlybedding, only_ending=args.onlyending,
                                  only_outlinebedding=args.onlyoutlinebedding, chunk=None)
    # Earlier note in this file: with num_workers != 0 a lambda inside pad_collate may fail to
    # pickle under multiprocessing.
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False,
                                 num_workers=multiprocessing.cpu_count(),
                                 collate_fn=pad_collate)

    ckpt_dir = '{}/lightning_logs/version_{}/checkpoints'.format(args.root_dir, args.version_num)
    # Takes the first directory entry; assumes the directory holds exactly one checkpoint.
    ckpt_path = os.path.join(ckpt_dir, os.listdir(ckpt_dir)[0])
    ckpt = torch.load(ckpt_path, map_location="cuda:{}".format(args.gpu))

    model = Gpt2('gpt2_en_ckpt_origin')
    model.load_state_dict(ckpt['state_dict'])
    if args.onlytarget:
        model.model.only_target = True

    pl.Trainer(gpus=[args.gpu]).test(model, test_dataloader)

    print('cnt = ', model.cnt)


def generate(args):
    """Sample a continuation for every test scene and save prompts, samples and references as JSON.

    The checkpoint in ``<root_dir>/lightning_logs/version_<version_num>/checkpoints`` is loaded onto
    ``cuda:<args.gpu>``. The function then decodes
    ``StoriumDataset('./data/test<args.train_val_test_suffix>.json')`` in batches of 16 with
    nucleus sampling: ``top_p=0.9``, ``temperature=args.temperature``,
    ``no_repeat_ngram_size=args.norepeatngram`` and ``max_length=1024``.

    Each record written has these keys:

    * ``prompt``: the decoded input.
    * ``generated``: the decoded output with the first occurrence of the prompt text removed.
    * ``answer``: the decoded reference output.
    * ``cards``: the card descriptions of the scene's last entry, each followed by ``<end_card>``.

    Output goes to ``result/<script>_onecard_version<N><display_argument(args)>.json``. The
    ``result`` directory is created if missing. The process is terminated with ``sys.exit()``
    when done.
    """
    def get_ckpt_name(ckpt_dir):
        # Assumes a single checkpoint in the directory; os.listdir order is otherwise arbitrary.
        return os.path.join(ckpt_dir, os.listdir(ckpt_dir)[0])

    def get_file_name():
        name_with_postfix = os.path.basename(__file__)
        stem = name_with_postfix[:name_with_postfix.rindex('.')]
        return stem + f'_onecard_version{args.version_num}' + display_argument(args)

    def cards_of(scene):
        # Card descriptions of the scene's last entry, each terminated by '<end_card>'.
        return ''.join(card['description'] + '<end_card>'
                       for entry in scene['entries'][-1:]
                       for card in entry['cards'])

    ckpt = torch.load(get_ckpt_name('{}/lightning_logs/version_{}/checkpoints'.format(args.root_dir, args.version_num)),
                      map_location="cuda:{}".format(args.gpu))
    model = Gpt2('gpt2_en_ckpt_origin')
    model.load_state_dict(ckpt['state_dict'])
    device = torch.device("cuda:{}".format(args.gpu))
    model.to(device)
    model.model.to(device)
    model.eval()
    model.model.eval()
    test_dataset = StoriumDataset('./data/test{}.json'.format(args.train_val_test_suffix))
    data = test_dataset.a
    res = []

    def parallel_generate():
        batch_size = 16
        print(f"out file:{get_file_name()}")

        ending = len(test_dataset)
        for idx in tqdm(range(0, ending, batch_size)):
            end = min(ending, idx + batch_size)
            batch = [test_dataset[j] for j in range(idx, end)]
            # Round-trip through text so the tokenizer can pad the whole batch in one call.
            input_texts = [helper.tokenizer.decode(sample['input']) for sample in batch]

            # Left padding keeps every prompt flush against the first generated token.
            helper.tokenizer.padding_side = "left"
            inputs = helper.tokenizer(input_texts, return_tensors="pt", padding=True)
            context_kws = [sample['context_kws'] for sample in batch]
            output_seqs = model.model.generate(input_ids=inputs['input_ids'].to(device), max_length=1024,
                                               top_p=0.9, temperature=args.temperature, do_sample=True,
                                               attention_mask=inputs['attention_mask'].to(device),
                                               no_repeat_ngram_size=args.norepeatngram,
                                               use_cache=True, context_kws=context_kws)

            cards_text = [cards_of(data[j]) for j in range(idx, end)]
            answer = [helper.tokenizer.decode(sample['output'].tolist(), skip_special_tokens=True) for sample in batch]
            prompt = [helper.tokenizer.decode(sample['input'].tolist(), skip_special_tokens=True) for sample in batch]
            # Strip the prompt text (first occurrence only) from each decoded sequence.
            output_text = [
                helper.tokenizer.decode(seq.tolist(), skip_special_tokens=True).replace(prompt[k], '', 1)
                for k, seq in enumerate(output_seqs)]

            for j in range(end - idx):
                res.append(
                    {'prompt': prompt[j], 'generated': output_text[j], 'answer': answer[j], 'cards': cards_text[j]})

        # Create the output directory so a long sampling run is not lost to FileNotFoundError at the end.
        os.makedirs('result', exist_ok=True)
        with open(f'result/{get_file_name()}.json', 'w', encoding='utf-8') as f:
            json.dump(res, f, indent=1, ensure_ascii=False)
        print(f"finish generate to {get_file_name()}")

    parallel_generate()
    sys.exit()


def generate_with_true_target(args):
    """Generate continuations conditioned on each test sample's gold target prefix.

    The checkpoint in ``<root_dir>/lightning_logs/version_<version_num>/checkpoints`` is loaded onto
    ``cuda:<args.gpu>``. Each test sample's prompt is its ``input`` followed by its ``output``,
    truncated right after the first ``helper.eot_token``. Samples are decoded in batches of 16 with
    ``top_p=0.9``, ``temperature=args.temperature`` and ``max_length=1024``.

    Each record written has ``prompt``, ``generated``, ``answer`` and ``cards`` keys. Output goes to
    ``result/<script>_onecard<display_argument(args)>_truetarget.json``, and the ``result`` directory
    is created if missing. The process is terminated with ``sys.exit()`` when done.

    Raises:
        ValueError: if a sample's ``output`` contains no end-of-target token.
    """
    def get_ckpt_name(ckpt_dir):
        # Assumes a single checkpoint in the directory; os.listdir order is otherwise arbitrary.
        return os.path.join(ckpt_dir, os.listdir(ckpt_dir)[0])

    def get_file_name():
        name_with_postfix = os.path.basename(__file__)
        stem = name_with_postfix[:name_with_postfix.rindex('.')]
        return stem + '_onecard' + display_argument(args) + '_truetarget'

    def truncate_after_eot(ts):
        # Keep the sequence up to and including the first end-of-target token (ValueError if absent).
        endoftarget_idx = ts.tolist().index(helper.get_token_id(helper.eot_token))
        return ts[:endoftarget_idx + 1]

    def cards_of(scene):
        # Card descriptions of the scene's last entry, each terminated by '<end_card>'.
        return ''.join(card['description'] + '<end_card>'
                       for entry in scene['entries'][-1:]
                       for card in entry['cards'])

    ckpt = torch.load(get_ckpt_name('{}/lightning_logs/version_{}/checkpoints'.format(args.root_dir, args.version_num)),
                      map_location="cuda:{}".format(args.gpu))
    model = Gpt2('gpt2_en_ckpt_origin')
    model.load_state_dict(ckpt['state_dict'])
    device = torch.device("cuda:{}".format(args.gpu))
    model.to(device)
    model.model.to(device)
    model.eval()
    model.model.eval()
    test_dataset = StoriumDataset('./data/test{}.json'.format(args.train_val_test_suffix))
    data = test_dataset.a
    res = []

    def parallel_generate():
        batch_size = 16
        print(f"out file:{get_file_name()}")

        ending = len(test_dataset)
        for idx in tqdm(range(0, ending, batch_size)):
            end = min(ending, idx + batch_size)
            batch = [test_dataset[j] for j in range(idx, end)]
            input_texts = [helper.tokenizer.decode(torch.cat([sample['input'], truncate_after_eot(sample['output'])]))
                           for sample in batch]

            # Left padding keeps every prompt flush against the first generated token.
            helper.tokenizer.padding_side = "left"
            inputs = helper.tokenizer(input_texts, return_tensors="pt", padding=True)
            context_kws = [sample['context_kws'] for sample in batch]
            # NOTE(review): no_repeat_ngram_size is fixed at 3 here, unlike generate()/true_outline()
            # which use args.norepeatngram — kept as-is; confirm whether that is intentional.
            output_seqs = model.model.generate(input_ids=inputs['input_ids'].to(device), max_length=1024,
                                               top_p=0.9, temperature=args.temperature, do_sample=True,
                                               attention_mask=inputs['attention_mask'].to(device),
                                               no_repeat_ngram_size=3,
                                               use_cache=True, context_kws=context_kws)

            cards_text = [cards_of(data[j]) for j in range(idx, end)]
            answer = [helper.tokenizer.decode(sample['output'].tolist(), skip_special_tokens=True) for sample in batch]
            prompt = [helper.tokenizer.decode(sample['input'].tolist(), skip_special_tokens=True) for sample in batch]
            # Only the input part is stripped, so 'generated' still begins with the gold target text.
            output_text = [
                helper.tokenizer.decode(seq.tolist(), skip_special_tokens=True).replace(prompt[k], '', 1)
                for k, seq in enumerate(output_seqs)]

            for j in range(end - idx):
                res.append(
                    {'prompt': prompt[j], 'generated': output_text[j], 'answer': answer[j], 'cards': cards_text[j]})

        # Create the output directory so a long sampling run is not lost to FileNotFoundError at the end.
        os.makedirs('result', exist_ok=True)
        with open(f'result/{get_file_name()}.json', 'w', encoding='utf-8') as f:
            json.dump(res, f, indent=1, ensure_ascii=False)
        print(f"finish generate to {get_file_name()}")

    parallel_generate()
    sys.exit()


def debug(args):
    """Scratch entry point: read the first test samples, print ``args.keywords_masks``, then exit.

    ``args.keywords_masks`` is reset to an empty dict. The test ``StoriumDataset`` is built with the
    same ``only*`` filters as ``test``, and items 0..201 are fetched (presumably the dataset fills
    ``args.keywords_masks`` while building items — TODO confirm). The dict is then printed and the
    process exits. The checkpoint-loading code after ``exit()`` is unreachable and is kept only as a
    template.
    """
    def get_ckpt_name(ckpt_dir):
        return os.path.join(ckpt_dir, os.listdir(ckpt_dir)[0])

    args.keywords_masks = {}
    test_dataset = StoriumDataset('./data/test{}.json'.format(args.train_val_test_suffix),
                                  only_target=args.onlytarget, only_outline=args.onlyoutline,
                                  only_bedding=args.onlybedding, only_ending=args.onlyending,
                                  only_outlinebedding=args.onlyoutlinebedding)
    for seen, _sample in tqdm(enumerate(test_dataset)):
        if seen >= 201:
            break
    print(args.keywords_masks)
    exit()

    # Unreachable while the exit() above is in place.
    ckpt_dir = '{}/lightning_logs/version_{}/checkpoints'.format(args.root_dir, args.version_num)
    ckpt = torch.load(get_ckpt_name(ckpt_dir), map_location="cuda:{}".format(args.gpu))
    model = Gpt2('gpt2_en_ckpt_origin')
    model.load_state_dict(ckpt['state_dict'])
    device = torch.device("cuda:{}".format(args.gpu))
    for module in (model, model.model):
        module.to(device)
    for module in (model, model.model):
        module.eval()


def true_outline(args):
    """Generate continuations conditioned on each test sample's gold output up to the ``bob`` token.

    The checkpoint in ``<root_dir>/lightning_logs/version_<version_num>/checkpoints`` is loaded onto
    ``cuda:<args.gpu>``. The test dataset is built with
    ``ignore_no_outline_sample=args.ignore``. Each prompt is ``input`` followed by ``output``,
    truncated right after the first ``helper.bob_token`` (presumably the begin-of-bedding marker —
    confirm). Prompts are decoded in batches of 16 with top-k sampling (``top_k=10``;
    ``args.temperature`` is not applied) and ``no_repeat_ngram_size=args.norepeatngram``.

    Records with ``prompt``, ``generated``, ``answer`` and ``cards`` keys are written to
    ``result/<script>_trueoutline.json``, and the ``result`` directory is created if missing. The
    file name does not encode the version or arguments, so each run overwrites it. The process is
    terminated with ``sys.exit()`` when done.

    Raises:
        ValueError: if a sample's ``output`` contains no ``bob`` token.
    """
    def get_ckpt_name(ckpt_dir):
        # Assumes a single checkpoint in the directory; os.listdir order is otherwise arbitrary.
        return os.path.join(ckpt_dir, os.listdir(ckpt_dir)[0])

    def get_file_name():
        name_with_postfix = os.path.basename(__file__)
        return name_with_postfix[:name_with_postfix.rindex('.')] + '_trueoutline'

    def truncate_after_bob(ts):
        # Keep the sequence up to and including the first bob token (ValueError if absent).
        end_idx = ts.tolist().index(helper.get_token_id(helper.bob_token))
        return ts[:end_idx + 1]

    def cards_of(scene):
        # Card descriptions of the scene's last entry, each terminated by '<end_card>'.
        return ''.join(card['description'] + '<end_card>'
                       for entry in scene['entries'][-1:]
                       for card in entry['cards'])

    ckpt = torch.load(get_ckpt_name('{}/lightning_logs/version_{}/checkpoints'.format(args.root_dir, args.version_num)),
                      map_location="cuda:{}".format(args.gpu))
    model = Gpt2('gpt2_en_ckpt_origin')
    model.load_state_dict(ckpt['state_dict'])
    device = torch.device("cuda:{}".format(args.gpu))
    model.to(device)
    model.model.to(device)
    model.eval()
    model.model.eval()
    test_dataset = StoriumDataset('./data/test{}.json'.format(args.train_val_test_suffix),
                                  ignore_no_outline_sample=args.ignore)
    data = test_dataset.a
    res = []

    def parallel_generate():
        batch_size = 16

        ending = len(test_dataset)
        for idx in tqdm(range(0, ending, batch_size)):
            end = min(ending, idx + batch_size)
            batch = [test_dataset[j] for j in range(idx, end)]
            input_texts = [helper.tokenizer.decode(torch.cat([sample['input'], truncate_after_bob(sample['output'])]))
                           for sample in batch]

            # Left padding keeps every prompt flush against the first generated token.
            helper.tokenizer.padding_side = "left"
            inputs = helper.tokenizer(input_texts, return_tensors="pt", padding=True)
            context_kws = [sample['context_kws'] for sample in batch]
            output_seqs = model.model.generate(input_ids=inputs['input_ids'].to(device), max_length=1024,
                                               top_k=10, do_sample=True,
                                               attention_mask=inputs['attention_mask'].to(device),
                                               use_cache=True, context_kws=context_kws,
                                               no_repeat_ngram_size=args.norepeatngram)

            cards_text = [cards_of(data[j]) for j in range(idx, end)]
            answer = [helper.tokenizer.decode(sample['output'].tolist(), skip_special_tokens=True) for sample in batch]
            prompt = [helper.tokenizer.decode(sample['input'].tolist(), skip_special_tokens=True) for sample in batch]
            # Only the input part is stripped, so 'generated' still begins with the gold prefix text.
            output_text = [
                helper.tokenizer.decode(seq.tolist(), skip_special_tokens=True).replace(prompt[k], '', 1)
                for k, seq in enumerate(output_seqs)]

            for j in range(end - idx):
                res.append(
                    {'prompt': prompt[j], 'generated': output_text[j], 'answer': answer[j], 'cards': cards_text[j]})

        # Create the output directory so a long sampling run is not lost to FileNotFoundError at the end.
        os.makedirs('result', exist_ok=True)
        with open(f'result/{get_file_name()}.json', 'w', encoding='utf-8') as f:
            json.dump(res, f, indent=1, ensure_ascii=False)

    parallel_generate()
    sys.exit()


def generate_outline(args):
    """Predict an outline for every record of ``<args.root_dir>.json`` and save the augmented records.

    Each input record must provide ``prompt`` and ``generated`` strings. The model input is
    ``prompt`` plus the stripped part of ``generated`` before ``<|beginofbedding|>`` (all of
    ``generated`` when the marker is absent). This is left-padded, truncated to 1023 tokens, and
    followed by the ``helper.eot_token`` id.

    The wrapped model is called with ``return_outline_dis=True, topk=31``, and its second-to-last
    output is taken as the outline token ids. Those ids are converted to token strings, any leading
    GPT-2 space marker ('Ġ') is dropped, and the result is stored under ``predict_outline``.

    Output goes to ``<args.root_dir>_predict_outline.json``. The process is terminated with
    ``sys.exit()`` when done.
    """
    def get_ckpt_name(ckpt_dir):
        # Assumes a single checkpoint in the directory; os.listdir order is otherwise arbitrary.
        return os.path.join(ckpt_dir, os.listdir(ckpt_dir)[0])

    bedding_marker = '<|beginofbedding|>'
    space_marker = chr(288)  # 'Ġ', the byte-level BPE marker for a preceding space

    def outline_prefix(sample):
        generated = sample['generated']
        cut = generated.find(bedding_marker)
        if cut == -1:
            # str.find returns -1 when the marker is missing; slicing with it would silently drop
            # the last character, so keep the whole generation instead.
            cut = len(generated)
        return sample['prompt'] + generated[:cut].strip()

    ckpt = torch.load(get_ckpt_name('{}/lightning_logs/version_{}/checkpoints'.format(args.root_dir, args.version_num)),
                      map_location="cuda:{}".format(args.gpu))
    model = Gpt2('gpt2_en_ckpt_origin')
    model.load_state_dict(ckpt['state_dict'])
    device = torch.device("cuda:{}".format(args.gpu))
    model.to(device)
    model.model.to(device)
    model.eval()
    model.model.eval()

    print('start to generate outline!')

    def parallel_generate():
        batch_size = 2
        res = []

        # Read the whole input up front so the file handle is not held open during the GPU loop.
        with open('{}.json'.format(args.root_dir), encoding='utf-8') as f:
            a = json.load(f)

        ending = len(a)
        with torch.no_grad():
            for idx in tqdm(range(0, ending, batch_size)):
                end = min(ending, idx + batch_size)
                batch = a[idx:end]
                input_sents = [outline_prefix(sample) for sample in batch]

                helper.tokenizer.padding_side = "left"
                # max_length=1023 leaves room for the end-of-target token appended below.
                # NOTE(review): over-long inputs are cut on the tokenizer's truncation side (right by
                # default), which would drop the text nearest the appended token — confirm intent.
                inputs = helper.tokenizer(input_sents, return_tensors="pt", padding=True, max_length=1023,
                                          truncation=True)
                rows = inputs['input_ids'].size(0)
                inputs['input_ids'] = torch.cat(
                    [inputs['input_ids'],
                     torch.full((rows, 1), helper.get_token_id(helper.eot_token), dtype=torch.long)], dim=-1)
                inputs['attention_mask'] = torch.cat(
                    [inputs['attention_mask'], torch.full((rows, 1), 1, dtype=torch.long)], dim=-1)
                outputs = model.model(input_ids=inputs['input_ids'].to(device),
                                      attention_mask=inputs['attention_mask'].to(device), return_dict=False,
                                      return_outline_dis=True, topk=31)
                outline_ids = outputs[-2]

                tokens = [helper.tokenizer.convert_ids_to_tokens(ids) for ids in outline_ids]
                for outline in tokens:
                    for i, token in enumerate(outline):
                        if token.startswith(space_marker):
                            outline[i] = token[1:]

                for scene, outline in zip(batch, tokens):
                    scene['predict_outline'] = outline
                    res.append(scene)

        with open('{}_predict_outline.json'.format(args.root_dir), 'w', encoding='utf-8') as f:
            json.dump(res, f, indent=1, ensure_ascii=False)

    parallel_generate()
    sys.exit()

# Loaded once at import time, for every mode including --process. The global ``graph`` is
# presumably read elsewhere in this file — TODO confirm its consumers.
graph = get_conceptnet()
print('finish load graph!')

# (flag, keyword arguments for ArgumentParser.add_argument), registered in this order.
_CLI_OPTIONS = (
    ('--test', {'action': 'store_true'}),
    ('--onlytarget', {'action': 'store_true'}),
    ('--onlyoutline', {'action': 'store_true'}),
    ('--onlybedding', {'action': 'store_true'}),
    ('--onlyending', {'action': 'store_true'}),
    ('--onlyoutlinebedding', {'action': 'store_true'}),
    ('--generate', {'action': 'store_true'}),
    ('--generatewithtruetarget', {'action': 'store_true'}),
    ('--generateoutline', {'action': 'store_true'}),
    ('--trueoutline', {'action': 'store_true'}),
    ('--debug', {'action': 'store_true'}),
    ('--gpu', {'type': int, 'default': 0}),
    ('--norepeatngram', {'type': int, 'default': 0}),
    ('--ignore', {'help': 'whether ignore no outline sample', 'action': 'store_true'}),
    ('--process', {'action': 'store_true'}),
    ('--temperature', {'type': float, 'default': 1.0}),
    ('--infile', {'type': str, 'default': None}),
    ('--outfile', {'type': str, 'default': None}),
    ('--generatenum', {'type': int, 'default': None}),
    ('--version_num', {'type': int, 'default': 0}),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
del _flag, _options

# Parsed at import time (not only under __main__); the __main__ block below reads this global.
args = parser.parse_args()

if __name__ == '__main__':
    # --process: build an IDF-weighted BERTScorer (IDF from get_idf_sents()) and run preprocess.
    if args.process:
        model_type = 'sentence-transformers/roberta-large-nli-stsb-mean-tokens'
        layers = 24
        print('begin process')
        scorer = BERTScorer(lang='en', model_type=model_type,
                            rescale_with_baseline=False, idf=True, num_layers=layers,
                            nthreads=os.cpu_count(), idf_sents=get_idf_sents(), batch_size=64)
        print('finish load scorer')
        preprocess(scorer)
        exit()

    from config import Config

    # Config wraps the parsed CLI arguments and replaces them for every entry point below.
    configs = Config(args=args, file=__file__)
    configs.show()
    args = configs

    # Dispatch: the first matching flag wins; with no mode flag the script trains.
    if args.debug:
        debug(args)
    elif args.generate:
        generate(args)
    elif args.trueoutline:
        true_outline(args)
    elif args.generateoutline:
        generate_outline(args)
    elif args.generatewithtruetarget:
        generate_with_true_target(args)
    elif args.test:
        test(args)
    else:
        train(args)