#coding:utf-8
import numpy as np
import json
from model.dialog_model import DialogModel
from model.embedding import use_cuda
from preprocession import *
import torch
import warnings
import yaml
import os
import argparse 
import logging
import time
import inspect
from gpu_mem_track import MemTracker
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist
import random
import gc

# Build the command-line parser.
def _build_arg_parser():
    """Return the argparse parser holding every option of this script.

    Options are described as (flag, type, help, default) rows and are
    registered in table order, so the generated help text keeps that order.
    """
    option_table = (
        ('--data_path', str, 'data path', '/apdcephfs/share_47076/pengdasi/ConceptFlow/ConceptFlow-master/fake_two_hop_data/'),
        ('--embed_units', int, 'embed units', 300),
        ('--num_epoch', int, 'max epoch', 50),
        ('--batch_size', int, 'batch size', 30),
        ('--max_gradient_norm', int, 'max_gradient_norm', 5),
        ('--generated_path', str, 'result dir', '/apdcephfs/share_47076/pengdasi/ConceptFlow/ConceptFlow-master/dialog_training_process_fake/'),
        ('--lr_rate', float, 'lr rate', 0.0001),
        ('--data_dir', str, 'data dir', '/apdcephfs/share_47076/pengdasi/ConceptFlow/ConceptFlow-master/edit_data/'),
        ('--trans_units', int, 'trans_units', 100),
        ('--units', int, 'units', 512),
        ('--layers', int, 'layers', 2),
        ('--gnn_layers', int, 'gnn layers', 3),
        ('--symbols', int, 'symbols', 30000),
        ('--linear_dropout', float, 'linear_dropout', 0.2),
        ('--local_rank', int, 'local_rank', -1),
        ('--ckpt', str, 'ckpt file', None),
        ('--prev_ppx_word', float, 'prev_ppx_word', None),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, description, fallback in option_table:
        cli.add_argument(flag, type=value_type, help=description, default=fallback)
    return cli


parser = _build_arg_parser()
# Parsed once at import time; the functions below read this module-level object.
args = parser.parse_args()

# Silence every warning category for the rest of the run.
warnings.filterwarnings('ignore')


class DialogDataset(Dataset):
    """Dialog samples stored one JSON document per line, plus graph batching.

    Lines are kept as raw text and only decoded with ``json.loads`` when a
    sample is requested. ``collate_fn`` turns a list of decoded samples into
    the six tensors consumed by the model.
    """

    def __init__(self, path, net_info):
        """Read all lines of *path* and copy the lookup tables from *net_info*.

        Args:
            path: text file with one JSON-encoded sample per line.
            net_info: mapping providing 'csk_entities', 'csk_triples',
                'word2id', 'entity2id', 'relation2id' and 'dict_csk_entities'.
        """
        with open(path) as handle:
            self.dataset = list(handle)

        for table_name in ('csk_entities', 'csk_triples', 'word2id',
                           'entity2id', 'relation2id', 'dict_csk_entities'):
            setattr(self, table_name, net_info[table_name])

    def __len__(self):
        """Number of lines (samples) that were read from the file."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Decode and return the JSON sample stored at position *idx*."""
        return json.loads(self.dataset[idx])

    def collate_fn(self, data):
        """Batch decoded samples into padded id tensors and a local graph.

        Each sample is read through its 'post', 'response', 'nodes' (keys
        '0', '1', '2'), 'tris' and 'hit' entries. Token sequences go through
        the imported ``padding`` helper (presumably appends end/pad markers up
        to the requested length -- confirm in ``preprocession``).

        Returns:
            Tuple ``(query_text, answer_text, nodes, edges, graph_mask,
            one_hot_entity)``; the first four are LongTensors, the last two
            float tensors.
        """
        batch_size = len(data)
        encoder_len = max(len(sample['post']) for sample in data) + 1
        decoder_len = max(len(sample['response']) for sample in data) + 1
        graph_len = max(
            len(sample['nodes']['0']) + len(sample['nodes']['1']) + len(sample['nodes']['2'])
            for sample in data
        )

        posts_id = np.zeros((batch_size, encoder_len), dtype=int)
        responses_id = np.zeros((batch_size, decoder_len), dtype=int)
        graph = np.zeros((batch_size, graph_len), dtype=int)
        # Row/column 0 of the adjacency is an extra node; entities start at 1.
        edges = np.zeros((batch_size, graph_len + 1, graph_len + 1), dtype=int)
        match_entity = np.full((batch_size, decoder_len), -1, dtype=int)

        def lookup_word(token):
            # Out-of-vocabulary tokens fall back to the '_UNK' id.
            if token in self.word2id:
                return self.word2id[token]
            return self.word2id['_UNK']

        for row, sample in enumerate(data):
            # Post and response token ids.
            for col, token in enumerate(padding(sample['post'], encoder_len)):
                posts_id[row, col] = lookup_word(token)
            for col, token in enumerate(padding(sample['response'], decoder_len)):
                responses_id[row, col] = lookup_word(token)

            # Graph nodes: every known entity gets the next free local slot.
            edges[row, 0, 0] = 1
            local_slot = {}
            used = 0
            node_ids = set(sample['nodes']['0']) | set(sample['nodes']['1']) | set(sample['nodes']['2'])
            for csk_index in node_ids:
                entity_name = self.csk_entities[csk_index]
                if entity_name not in self.entity2id:
                    continue
                entity_id = self.entity2id[entity_name]
                graph[row][used] = entity_id
                local_slot[entity_id] = used
                used += 1
                # Codes 1/2/3: self loop, node 0 -> entity, entity -> node 0.
                edges[row, used, used] = 1
                edges[row, 0, used] = 2
                edges[row, used, 0] = 3

            # Graph edges from triples whose both ends are in the local graph.
            for tri_index in sample['tris']:
                fields = self.csk_triples[tri_index].split()
                # The first two fields carry a trailing separator character.
                subject = fields[0][:-1]
                relation = fields[1][:-1]
                target = fields[2]
                if not (subject in self.entity2id and target in self.entity2id):
                    continue
                subject_id = self.entity2id[subject]
                target_id = self.entity2id[target]
                if not (subject_id in local_slot and target_id in local_slot):
                    continue
                column = local_slot[subject_id] + 1
                line = local_slot[target_id] + 1
                edges[row, line, column] = self.relation2id[relation]

            # Local graph slot of the entity hit at each response position.
            for position, hit in enumerate(sample['hit']):
                if hit == -1:
                    continue
                hit_name = self.csk_entities[hit]
                if hit_name not in self.entity2id:
                    continue
                hit_id = self.entity2id[hit_name]
                if hit_id not in local_slot:
                    continue
                match_entity[row, position] = local_slot[hit_id]

        query_text = torch.LongTensor(posts_id)
        answer_text = torch.LongTensor(responses_id)
        node_tensor = torch.LongTensor(graph)
        edge_tensor = torch.LongTensor(edges)

        # Additive mask: a large negative value where no edge exists, else 0.
        padding_num = -2 ** 32 + 1
        blocked = torch.ones_like(edge_tensor, dtype=torch.float32) * padding_num
        allowed = torch.zeros_like(edge_tensor, dtype=torch.float32)
        graph_mask = torch.where(edge_tensor == 0, blocked, allowed)

        # One-hot (over graph slots) encoding of the matched entity per step.
        one_hot_entity = torch.zeros(batch_size, decoder_len, graph_len)
        for b, d in np.argwhere(match_entity != -1):
            one_hot_entity[int(b), int(d), int(match_entity[b, d])] = 1

        return query_text, answer_text, node_tensor, edge_tensor, graph_mask, one_hot_entity

def _assign_ids(items):
    """Map each item to the size of the mapping at the moment it is stored."""
    mapping = dict()
    for item in items:
        # Deliberately len()-based (not enumerate) so repeated items behave
        # exactly as they always have.
        mapping[item] = len(mapping)
    return mapping


def _read_stripped_lines(file_path):
    """Return every line of *file_path* with surrounding whitespace removed."""
    with open(file_path) as f:
        return [line.strip() for line in f]


def _read_tsv_matrix(file_path):
    """Load a tab-separated numeric file as a float32 array, one row per line."""
    with open(file_path) as f:
        rows = [line.strip().split('\t') for line in f]
    return np.array(rows, dtype=np.float32)


def get_vocab(path):
    """Load vocabularies, graph resources and pretrained embeddings.

    Files read from *path* (a directory prefix that is concatenated as-is,
    so it must end with a separator): 'resource.txt',
    'glove.840B.300d.txt', 'entity.txt', 'entity_transE.txt',
    'relation.txt' and 'relation_transE.txt'. The word vocabulary size is
    capped by the module-level ``args.symbols`` and missing word vectors are
    zero vectors of length ``args.embed_units``.

    Args:
        path: directory prefix containing the files listed above.

    Returns:
        Tuple ``(net_info, embed, entity_embed, relation_embed)`` where
        ``net_info`` is a dict with the keys 'csk_entities',
        'dict_csk_entities', 'csk_triples', 'dict_csk_triples', 'entity2id',
        'relation2id' and 'word2id', and the other three are float32 arrays.
    """
    logging.info("Load graph info...")
    with open(path + 'resource.txt') as f:
        # Only the first line is used; it holds one JSON document.
        d = json.loads(f.readline())
    csk_entities = d['csk_entities']
    dict_csk_entities = d['dict_csk_entities']
    raw_vocab = d['vocab_dict']
    csk_triples = d['csk_triples']
    dict_csk_triples = d['dict_csk_triples']

    logging.info("Creating word vocabulary...")
    # Special tokens first, then words ordered by descending raw_vocab value.
    vocab_list = ['_PAD', '_GO', '_EOS', '_UNK'] + sorted(raw_vocab, key=raw_vocab.get, reverse=True)
    if len(vocab_list) > args.symbols:
        vocab_list = vocab_list[:args.symbols]
    word2id = _assign_ids(vocab_list)

    logging.info("Loading word vectors...")
    vectors = {}
    error_line = 0
    with open(path + 'glove.840B.300d.txt', encoding='utf8', errors='ignore') as f:
        for i, line in enumerate(f):
            if i % 100000 == 0:
                logging.info("processing line %d", i)
            s = line.strip()
            # A valid row is one token followed by the 300 values of the 300d file.
            if len(s.split()) != 301:
                logging.info(i)
                error_line += 1
                continue
            cut = s.find(' ')
            word = s[:cut]
            # Keep only vectors that can be looked up below; storing the whole
            # GloVe table as strings wastes memory for no effect on the result.
            if word in word2id:
                vectors[word] = s[cut + 1:]
    logging.info("error line: %d", error_line)

    embed = []
    for word in vocab_list:
        if word in vectors:
            vector = vectors[word].split()
        else:
            vector = np.zeros((args.embed_units), dtype=np.float32)
        embed.append(vector)
    # String components are converted to float32 by this cast.
    embed = np.array(embed, dtype=np.float32)

    logging.info("Creating entity vocabulary...")
    entity2id = _assign_ids(['<NONE>'] + _read_stripped_lines(path + 'entity.txt'))

    logging.info("Loading entity vectors...")
    entity_embed = _read_tsv_matrix(path + 'entity_transE.txt')

    logging.info("Creating relation vocabulary...")
    relation_list = ['<NONE>', 'Selfto', 'TextFrom', 'TextTo'] + _read_stripped_lines(path + 'relation.txt')
    relation2id = _assign_ids(relation_list)

    logging.info("Loading relation vectors...")
    relation_embed = _read_tsv_matrix(path + 'relation_transE.txt')

    net_info = dict()
    net_info['csk_entities'] = csk_entities
    net_info['dict_csk_entities'] = dict_csk_entities
    net_info['csk_triples'] = csk_triples
    net_info['dict_csk_triples'] = dict_csk_triples
    net_info['entity2id'] = entity2id
    net_info['relation2id'] = relation2id
    net_info['word2id'] = word2id

    return net_info, embed, entity_embed, relation_embed


def inference(args, model, net_info):
    """Generate responses for the test set and write them to text files.

    Reads ``args.data_path + 'testset.txt'``, runs *model* batch by batch on
    the GPU and writes, under the ``args.generated_path`` prefix:
    'answers.txt', 'responses.txt', 'posts.txt', 'test_dialog.txt'
    (post, response and answer separated by tabs) and 'generated_res.txt'
    (one JSON object per generated response).

    Args:
        args: parsed options; ``data_path``, ``generated_path`` and
            ``batch_size`` are used here.
        model: callable returning ``(word_index, selector, _)`` for a batch.
        net_info: resource dict providing at least 'word2id'.

    Returns:
        None.
    """
    model.eval()

    # '_EOS' sits at index 2 of the vocabulary list built by get_vocab.
    eos_id = 2
    id2word = {index: word for word, index in net_info['word2id'].items()}

    def decode_until_eos(id_rows):
        """Convert rows of ids to token lists, stopping at the first EOS id."""
        decoded = []
        for row in id_rows:
            tokens = []
            for token_id in row:
                if token_id == eos_id:
                    break
                tokens.append(id2word[int(token_id)])
            decoded.append(tokens)
        return decoded

    def write_batch_res_text(word_index, selector):
        """Decode generated ids; also collect the tokens flagged by selector."""
        batch_size = len(word_index)
        decoder_len = len(word_index[0])

        text = []
        responses = []
        for i in range(batch_size):
            tokens = []
            for j in range(decoder_len):
                if word_index[i][j] == eos_id:
                    break
                tokens.append(id2word[word_index[i][j]])
            # A selector value of 1 marks the token at that step as an entity.
            entities = [tok for j, tok in enumerate(tokens) if selector[i][j] == 1]
            responses.append(tokens)
            text.append({'response_text': tokens, 'entity': entities})

        return responses, text

    def write_lines(file_name, lines):
        """Write *lines* (newline-terminated) to a file under generated_path."""
        # The context manager closes the file even if a write fails.
        with open(args.generated_path + file_name, 'w') as out:
            out.writelines(line + "\n" for line in lines)

    data_test = DialogDataset(args.data_path + 'testset.txt', net_info)
    logging.info('Load test data from %s' % (args.data_path + 'testset.txt'))

    eval_sampler = SequentialSampler(data_test)
    test_loader = DataLoader(data_test, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=data_test.collate_fn)

    all_responses = []
    all_answers = []
    all_posts = []
    all_text = []

    # No gradients are needed for generation; skipping autograd bookkeeping
    # keeps GPU memory flat across batches.
    with torch.no_grad():
        for iteration, batch_data in enumerate(test_loader):
            batch = tuple(input_tensor.cuda() for input_tensor in batch_data)
            query, answer, nodes, edges, graph_mask, one_hot_entity = batch
            word_index, selector, _ = model(query, answer, nodes, edges, graph_mask, one_hot_entity)

            if iteration % 100 == 0:
                logging.info("generate: %d" % iteration)
            responses, text = write_batch_res_text(word_index, selector)

            all_responses.extend(responses)
            all_answers.extend(decode_until_eos(answer.cpu().numpy()))
            all_posts.extend(decode_until_eos(query.cpu().numpy()))
            all_text.extend(text)

    answers_sens = [' '.join(x) for x in all_answers]
    responses_sens = [' '.join(x) for x in all_responses]
    posts_sens = [' '.join(x) for x in all_posts]

    write_lines('answers.txt', answers_sens)
    write_lines('responses.txt', responses_sens)
    write_lines('posts.txt', posts_sens)
    write_lines(
        'test_dialog.txt',
        [x + "\t" + y + "\t" + z for x, y, z in zip(posts_sens, responses_sens, answers_sens)],
    )
    write_lines('generated_res.txt', [json.dumps(entry) for entry in all_text])


def main():
    """Build the model, optionally restore a checkpoint, and run inference.

    Uses the module-level ``args``: resources are loaded from
    ``args.data_dir`` and, when ``args.ckpt`` is given, the model weights are
    restored from that file's 'model' entry.
    """
    logging.basicConfig(level=logging.INFO)
    net_info, word_embed, entity_embed, relation_embed = get_vocab(args.data_dir)

    model = DialogModel(args, word_embed, entity_embed, relation_embed).cuda()

    if args.ckpt is not None:
        ckpt_state = torch.load(args.ckpt)
        model.load_state_dict(ckpt_state['model'])
        logging.info("Loaded checkpoint %s", args.ckpt)
    else:
        # Without a checkpoint the model keeps its freshly constructed weights.
        logging.warning("No --ckpt given; running inference without restoring a checkpoint")

    # Attributes the model is given before generation starts.
    model.is_inference = True
    model.word2id = net_info['word2id']
    entity2id = net_info['entity2id']
    model.entity2id = entity2id
    model.id2entity = {index: name for name, index in entity2id.items()}

    inference(args, model, net_info)

# Run inference only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()