#coding:utf-8
import numpy as np
import json
from model.dialog_model import DialogModel
from model.embedding import use_cuda
from preprocession import *
import torch
import warnings
import yaml
import os
import argparse 
import logging
import time
import inspect
from gpu_mem_track import MemTracker
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist
import random
import gc

# Build the command-line parser; the parsed result is the module-level ``args``.
parser = argparse.ArgumentParser()
# Directory holding the training shards ('train_aa' ... 'train_aj') and 'testset.txt'.
parser.add_argument('--data_path', type=str, help='data path', default='/apdcephfs/share_47076/pengdasi/ConceptFlow/ConceptFlow-master/fake_two_hop_data/')
# Word-vector dimension expected from the GloVe file.
parser.add_argument('--embed_units', type=int, help='embed units', default=300)
parser.add_argument('--num_epoch', type=int, help='max epoch', default=50)
parser.add_argument('--batch_size', type=int, help='batch size', default=30)
# Gradient-clipping threshold. Parsed as float so fractional norms such as 0.5
# are accepted (type=int rejected them); integer values still parse.
parser.add_argument('--max_gradient_norm', type=float, help='max_gradient_norm', default=5)
# Output directory for 'result.txt' and the per-epoch 'epoch_N.pkl' checkpoints.
parser.add_argument('--result_dir_name', type=str, help='result dir', default='/apdcephfs/share_47076/pengdasi/ConceptFlow/ConceptFlow-master/dialog_training_process_fake/')
parser.add_argument('--lr_rate', type=float, help='lr rate', default=0.0001)
# Directory holding the vocabulary / embedding resources read by get_vocab().
parser.add_argument('--data_dir', type=str, help='data dir', default='/apdcephfs/share_47076/pengdasi/ConceptFlow/ConceptFlow-master/edit_data/')
parser.add_argument('--trans_units', type=int, help='trans_units', default=100)
parser.add_argument('--units', type=int, help='units', default=512)
parser.add_argument('--layers', type=int, help='layers', default=2)
parser.add_argument('--gnn_layers', type=int, help='gnn layers', default=3)
# Maximum word-vocabulary size, including the four special tokens.
parser.add_argument('--symbols', type=int, help='symbols', default=30000)
parser.add_argument('--linear_dropout', type=float, help='linear_dropout', default=0.2)

# -1 means a non-distributed run; any other value enables the DDP/NCCL path
# (presumably supplied by the torch.distributed launcher).
parser.add_argument('--local_rank', type=int, help='local_rank', default=-1)
# Optional checkpoint to resume from (model state, optimizer state and epoch).
parser.add_argument('--ckpt', type=str, help='ckpt file', default=None)
# Test-set word perplexity of the last finished epoch; seeds the early-stopping
# comparison when resuming.
parser.add_argument('--prev_ppx_word', type=float, help='prev_ppx_word', default=None)
args = parser.parse_args()

# Silence every Python warning for the rest of the run.
warnings.filterwarnings('ignore')


class DialogDataset(Dataset):
    """Dialog samples stored as one JSON object per line of a text file.

    Lines are kept as raw strings and decoded lazily in ``__getitem__``;
    ``collate_fn`` turns a list of decoded samples into padded tensors using
    the vocabularies held in ``net_info``.
    """

    def __init__(self, path, net_info):
        """Read all lines of ``path`` and keep references to the lookups.

        Args:
            path: text file with one JSON-encoded sample per line.
            net_info: dict providing 'csk_entities', 'csk_triples', 'word2id',
                'entity2id', 'relation2id' and 'dict_csk_entities'.
        """
        # Decode explicitly so the result does not depend on the process locale.
        with open(path, encoding='utf-8') as f:
            self.dataset = list(f)

        self.csk_entities = net_info['csk_entities']
        self.csk_triples = net_info['csk_triples']
        self.word2id = net_info['word2id']
        self.entity2id = net_info['entity2id']
        self.relation2id = net_info['relation2id']
        self.dict_csk_entities = net_info['dict_csk_entities']

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx`` decoded from JSON."""
        return json.loads(self.dataset[idx])

    def collate_fn(self, data):
        """Collate decoded samples into padded model inputs.

        Each sample must provide 'post', 'response', 'nodes' (with keys
        '0', '1', '2'), 'tris' and 'hit'.

        Returns:
            (query_text, answer_text, nodes, edges, graph_mask, one_hot_entity):
            query_text: LongTensor (batch, max_post_len + 1) of word ids.
            answer_text: LongTensor (batch, max_response_len + 1) of word ids.
            nodes: LongTensor (batch, graph_len) of entity ids, 0-padded.
            edges: LongTensor (batch, graph_len + 1, graph_len + 1). Slot 0 is
                a virtual node and slot k + 1 is ``nodes[:, k]``. It holds 1 on
                the diagonal of used slots, 2 at [0, k + 1], 3 at [k + 1, 0],
                and a relation id at [object, subject] for each kept triple.
            graph_mask: float tensor shaped like ``edges``: 0 where an edge
                exists, a very large negative number elsewhere.
            one_hot_entity: float tensor (batch, max_response_len + 1,
                graph_len) marking the graph node matched at each response
                position (all zeros where nothing matched).
        """
        batch_size = len(data)
        encoder_len = max(len(item['post']) for item in data) + 1
        decoder_len = max(len(item['response']) for item in data) + 1
        # Upper bound on nodes per sample; duplicates across hops and entities
        # missing from entity2id are dropped below, so rows may stay 0-padded.
        graph_len = max(
            len(item['nodes']['0']) + len(item['nodes']['1']) + len(item['nodes']['2'])
            for item in data
        )

        posts_id = np.zeros((batch_size, encoder_len), dtype=int)
        responses_id = np.zeros((batch_size, decoder_len), dtype=int)
        graph = np.zeros((batch_size, graph_len), dtype=int)
        edges = np.zeros((batch_size, graph_len + 1, graph_len + 1), dtype=int)
        # -1 marks "no matched graph node" for a response position.
        match_entity = np.full((batch_size, decoder_len), -1, dtype=int)

        unk_id = self.word2id['_UNK']
        for idx, item in enumerate(data):
            # Word ids; `padding` comes from `preprocession` and presumably
            # pads to the requested length (TODO confirm).
            for i, word in enumerate(padding(item['post'], encoder_len)):
                posts_id[idx, i] = self.word2id.get(word, unk_id)
            for i, word in enumerate(padding(item['response'], decoder_len)):
                responses_id[idx, i] = self.word2id.get(word, unk_id)

            # Graph nodes: slot 0 is the virtual node linked to every entity.
            edges[idx, 0, 0] = 1
            g2l = {}  # entity id -> 0-based local node index
            nodes = set(item['nodes']['0']) | set(item['nodes']['1']) | set(item['nodes']['2'])
            n_nodes = 0
            for entity_index in nodes:
                name = self.csk_entities[entity_index]
                if name not in self.entity2id:
                    continue
                entity = self.entity2id[name]
                graph[idx, n_nodes] = entity
                g2l[entity] = n_nodes
                n_nodes += 1
                edges[idx, n_nodes, n_nodes] = 1
                edges[idx, 0, n_nodes] = 2
                edges[idx, n_nodes, 0] = 3

            # Triples are stored as "subject, relation, object" strings; the
            # trailing separator is stripped from the first two fields.
            for tri_index in item['tris']:
                parts = self.csk_triples[tri_index].split()
                sbj, rel, obj = parts[0][:-1], parts[1][:-1], parts[2]
                sbj_id = self.entity2id.get(sbj)
                obj_id = self.entity2id.get(obj)
                if sbj_id is None or obj_id is None:
                    continue
                if sbj_id not in g2l or obj_id not in g2l:
                    continue
                edges[idx, g2l[obj_id] + 1, g2l[sbj_id] + 1] = self.relation2id[rel]

            # Response positions whose word hits an entity present in the graph.
            for i, hit in enumerate(item['hit']):
                if hit == -1:
                    continue
                entity_id = self.entity2id.get(self.csk_entities[hit])
                if entity_id is None or entity_id not in g2l:
                    continue
                match_entity[idx, i] = g2l[entity_id]

        query_text = torch.LongTensor(posts_id)
        answer_text = torch.LongTensor(responses_id)

        padding_num = -2 ** 32 + 1
        nodes = torch.LongTensor(graph)
        edges = torch.LongTensor(edges)
        very_neg_num = torch.ones_like(edges, dtype=torch.float32) * padding_num
        zero_num = torch.zeros_like(edges, dtype=torch.float32)
        graph_mask = torch.where(edges == 0, very_neg_num, zero_num)

        # Vectorised scatter of the matched node indices, replacing a Python
        # loop over every (batch, position) pair.
        one_hot = np.zeros((batch_size, decoder_len, graph_len), dtype=np.float32)
        rows, cols = np.nonzero(match_entity != -1)
        one_hot[rows, cols, match_entity[rows, cols]] = 1
        one_hot_entity = torch.from_numpy(one_hot)

        return query_text, answer_text, nodes, edges, graph_mask, one_hot_entity

def _read_lines(path):
    """Return every line of the UTF-8 text file ``path`` with whitespace stripped."""
    with open(path, encoding='utf-8') as f:
        return [line.strip() for line in f]


def _index_tokens(tokens):
    """Map each token to its position in ``tokens`` (a later duplicate wins).

    Using the position (rather than the running dict size) keeps ids aligned
    with list order even if a token repeats, so ids never collide.
    """
    return {token: i for i, token in enumerate(tokens)}


def _read_tab_matrix(path):
    """Load a tab-separated numeric text file as a float32 2-D array."""
    with open(path, encoding='utf-8') as f:
        return np.array([line.strip().split('\t') for line in f], dtype=np.float32)


def get_vocab(path, symbols=None, embed_units=None):
    """Load vocabularies and pretrained embeddings from directory ``path``.

    Files read: 'resource.txt' (a JSON object on its first line with keys
    'csk_entities', 'dict_csk_entities', 'vocab_dict', 'csk_triples' and
    'dict_csk_triples'), 'glove.840B.300d.txt', 'entity.txt',
    'entity_transE.txt', 'relation.txt' and 'relation_transE.txt'.

    Args:
        path: directory containing the files above.
        symbols: maximum word-vocabulary size, including the four special
            tokens. Defaults to ``args.symbols``.
        embed_units: dimension of the word vectors. GloVe lines with a
            different number of values are skipped. Defaults to
            ``args.embed_units``.

    Returns:
        (net_info, embed, entity_embed, relation_embed):
        net_info: dict with the raw 'csk_entities', 'dict_csk_entities',
            'csk_triples', 'dict_csk_triples' plus the 'word2id', 'entity2id'
            and 'relation2id' lookups.
        embed: float32 array with one row per vocabulary word, in vocabulary
            order; words without a GloVe vector get a zero row.
        entity_embed / relation_embed: float32 arrays parsed from the
            tab-separated TransE files.
    """
    if symbols is None:
        symbols = args.symbols
    if embed_units is None:
        embed_units = args.embed_units

    logging.info("Load graph info...")
    with open(os.path.join(path, 'resource.txt'), encoding='utf-8') as f:
        d = json.loads(f.readline())
    csk_entities = d['csk_entities']
    dict_csk_entities = d['dict_csk_entities']
    raw_vocab = d['vocab_dict']
    csk_triples = d['csk_triples']
    dict_csk_triples = d['dict_csk_triples']

    logging.info("Creating word vocabulary...")
    # Special tokens first, then words by descending corpus frequency.
    vocab_list = ['_PAD', '_GO', '_EOS', '_UNK'] + sorted(raw_vocab, key=raw_vocab.get, reverse=True)
    vocab_list = vocab_list[:symbols]
    word2id = _index_tokens(vocab_list)

    logging.info("Loading word vectors...")
    expected_fields = embed_units + 1  # the word followed by its vector values
    vectors = {}
    error_line = 0
    with open(os.path.join(path, 'glove.840B.300d.txt'), encoding='utf8', errors='ignore') as f:
        for i, line in enumerate(f):
            if i % 100000 == 0:
                logging.info("processing line %d", i)
            s = line.strip()
            if len(s.split()) != expected_fields:
                logging.info("skip malformed word-vector line %d", i)
                error_line += 1
                continue
            sep = s.find(' ')
            word = s[:sep]
            # Keep only vocabulary words: the full GloVe file has millions of
            # entries that would otherwise all stay in memory.
            if word in word2id:
                vectors[word] = s[sep + 1:]
    logging.info("error line: %d", error_line)
    if not vectors:
        logging.warning("no word vectors of dimension %d matched the vocabulary", embed_units)

    zero_vector = np.zeros(embed_units, dtype=np.float32)
    embed = np.array(
        [vectors[word].split() if word in vectors else zero_vector for word in vocab_list],
        dtype=np.float32,
    )

    logging.info("Creating entity vocabulary...")
    entity_list = ['<NONE>'] + _read_lines(os.path.join(path, 'entity.txt'))
    entity2id = _index_tokens(entity_list)

    logging.info("Loading entity vectors...")
    entity_embed = _read_tab_matrix(os.path.join(path, 'entity_transE.txt'))

    logging.info("Creating relation vocabulary...")
    # Ids 1-3 are reserved for the synthetic edges built in DialogDataset.collate_fn.
    relation_list = ['<NONE>', 'Selfto', 'TextFrom', 'TextTo'] + _read_lines(os.path.join(path, 'relation.txt'))
    relation2id = _index_tokens(relation_list)

    logging.info("Loading relation vectors...")
    relation_embed = _read_tab_matrix(os.path.join(path, 'relation_transE.txt'))

    net_info = {
        'csk_entities': csk_entities,
        'dict_csk_entities': dict_csk_entities,
        'csk_triples': csk_triples,
        'dict_csk_triples': dict_csk_triples,
        'entity2id': entity2id,
        'relation2id': relation2id,
        'word2id': word2id,
    }
    return net_info, embed, entity_embed, relation_embed


def train(args, model, net_info, model_optimizer, local_rank, start_epoch):
    """Run the distributed training loop with per-epoch evaluation.

    Each epoch visits the ten shards ``train_aa`` ... ``train_aj`` under
    ``args.data_path``, in an order shuffled deterministically from the epoch
    number. After each epoch the model is evaluated on the test set. Rank 0
    appends the result to ``result.txt`` and saves ``epoch_N.pkl`` into
    ``args.result_dir_name``. Training stops early once the test word
    perplexity gets worse than the previous epoch's.

    Args:
        args: parsed command-line options.
        model: a DistributedDataParallel-wrapped model (``model.module`` is
            saved); its forward returns (loss, ppx, ppx_word, ppx_entity,
            word_neg_num, entity_neg_num).
        net_info: vocabulary dict passed through to DialogDataset.
        model_optimizer: optimizer over ``model.parameters()``.
        local_rank: this process's rank; only rank 0 writes files.
        start_epoch: first (0-based) epoch to run, e.g. when resuming.
    """
    # May be None: then no early-stopping comparison is made until one
    # epoch has been evaluated (avoids an unbound name when resuming).
    prev_ppx_word = args.prev_ppx_word

    file_part_list = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
    s1 = ''.join(file_part_list)

    count = 0  # global step counter across epochs, used for periodic logging
    for epoch in range(start_epoch, args.num_epoch):
        logging.info("epoch: %d" % (epoch + 1))

        total_ppx = 0.0
        total_ppx_word = 0.0
        total_ppx_entity = 0.0
        total_word_cut = 0
        total_entity_cut = 0
        # Samples this process actually saw this epoch. A DistributedSampler
        # yields len(sampler) samples per shard, not len(dataset). This keeps
        # the normaliser consistent with evaluate(), which divides summed
        # values by the number of samples.
        data_size = 0

        # Same seed on every rank -> every rank visits shards in the same order.
        g = torch.Generator()
        g.manual_seed(epoch)
        idx_list = torch.randperm(len(file_part_list), generator=g).tolist()
        new_file_part_list = [file_part_list[i] for i in idx_list]
        s2 = ''.join(new_file_part_list)
        logging.info('Process %d, original order: %s, shuffle order: %s' % (local_rank, s1, s2))

        model.train()
        start_time = time.time()

        for c in new_file_part_list:
            path = args.data_path + 'train_a' + c
            data_train = DialogDataset(path, net_info)
            logging.info('Load train data from %s' % path)

            train_sampler = DistributedSampler(data_train)
            train_loader = DataLoader(data_train, sampler=train_sampler, batch_size=args.batch_size, collate_fn=data_train.collate_fn)
            train_sampler.set_epoch(epoch)
            data_size += len(train_sampler)

            for batch_data in train_loader:
                batch = tuple(input_tensor.cuda() for input_tensor in batch_data)
                query, answer, nodes, edges, graph_mask, one_hot_entity = batch
                loss, ppx, ppx_word, ppx_entity, word_neg_num, entity_neg_num \
                    = model(query, answer, nodes, edges, graph_mask, one_hot_entity)

                total_ppx += float(ppx)
                total_ppx_word += float(ppx_word)
                total_ppx_entity += float(ppx_entity)
                total_word_cut += int(word_neg_num)
                total_entity_cut += int(entity_neg_num)

                model_optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_gradient_norm)
                model_optimizer.step()

                if count % 1000 == 0:
                    iteration_time = time.time() - start_time
                    logging.info("iteration: %d, loss: %f, time: %s" % (count, loss.item(), str(iteration_time)))
                    start_time = time.time()
                count += 1

            logging.info('train data from %s finished!' % path)

            # Release the shard before the next one is loaded.
            del train_loader, train_sampler, data_train
            gc.collect()

        logging.info("ppx for epoch %d: %f" % (epoch + 1, np.exp(total_ppx / data_size)))
        logging.info("ppx word for epoch %d: %f" % (epoch + 1, np.exp(total_ppx_word / (data_size - total_word_cut))))
        logging.info("ppx entity for epoch %d: %f" % (epoch + 1, np.exp(total_ppx_entity / (data_size - total_entity_cut))))

        ppx, ppx_word, ppx_entity = evaluate(args, model, net_info)

        if local_rank == 0:
            with open(args.result_dir_name + 'result.txt', 'a') as ppx_f:
                ppx_f.write("epoch " + str(epoch + 1) + " ppx: " + str(ppx) + " ppx_word: " + str(ppx_word) + " ppx_entity: " +
                            str(ppx_entity) + '\n')

        # Early stopping on test word perplexity. The comparison is skipped on
        # epoch 0 and whenever no previous value is known.
        if epoch > 0 and prev_ppx_word is not None and ppx_word > prev_ppx_word:
            break
        prev_ppx_word = ppx_word

        if local_rank == 0:
            state = {'model': model.module.state_dict(), 'optimizer': model_optimizer.state_dict(), 'epoch': epoch + 1}
            torch.save(state, args.result_dir_name + 'epoch_' + str(epoch + 1) + '.pkl')

def evaluate(args, model, net_info):
    """Compute test-set perplexities of ``model``.

    Reads ``args.data_path + 'testset.txt'``. Each process iterates the whole
    test set sequentially (no sharding) under ``torch.no_grad()``. The model
    is left in eval mode.

    Returns:
        (ppx, ppx_word, ppx_entity): ``exp`` of the summed per-batch values
        returned by the model divided by the number of test samples. For the
        word/entity variants, the summed cut counts are subtracted from the
        sample count first.
    """
    model.eval()

    test_path = args.data_path + 'testset.txt'
    data_test = DialogDataset(test_path, net_info)
    logging.info('Load test data from %s' % test_path)

    eval_sampler = SequentialSampler(data_test)
    test_loader = DataLoader(data_test, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=data_test.collate_fn)

    total_ppx = 0.0
    total_ppx_word = 0.0
    total_ppx_entity = 0.0
    total_word_cut = 0
    total_entity_cut = 0

    # No gradients are needed here; without no_grad every forward pass would
    # build (and hold) an autograd graph.
    with torch.no_grad():
        for iteration, batch_data in enumerate(test_loader):
            batch = tuple(input_tensor.cuda() for input_tensor in batch_data)
            query, answer, nodes, edges, graph_mask, one_hot_entity = batch
            loss, ppx, ppx_word, ppx_entity, word_neg_num, entity_neg_num \
                = model(query, answer, nodes, edges, graph_mask, one_hot_entity)

            total_ppx += float(ppx)
            total_ppx_word += float(ppx_word)
            total_ppx_entity += float(ppx_entity)
            total_word_cut += int(word_neg_num)
            total_entity_cut += int(entity_neg_num)

            if iteration % 100 == 0:
                logging.info("iteration for evaluate: %d, Loss: %f" % (iteration, loss.item()))

    num_samples = len(data_test)
    ppx = np.exp(total_ppx / num_samples)
    ppx_word = np.exp(total_ppx_word / (num_samples - total_word_cut))
    ppx_entity = np.exp(total_ppx_entity / (num_samples - total_entity_cut))

    logging.info("ppx on test set %f" % ppx)
    logging.info("ppx_word on test set %f" % ppx_word)
    logging.info("ppx_entity on test set %f" % ppx_entity)

    return ppx, ppx_word, ppx_entity

def main():
    """Script entry point: load resources, build the DDP model and train it.

    All settings come from the module-level ``args``.
    """
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    net_info, word_embed, entity_embed, relation_embed = get_vocab(args.data_dir)

    logging.info(args.local_rank)
    is_distributed = args.local_rank != -1

    if not is_distributed:
        # NOTE(review): this branch only selects a device. The DDP wrapper
        # below and train() (DistributedSampler, model.module) still need an
        # initialised process group, so single-process runs look unsupported.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method='env://')

    # Ranks other than 0 block here; rank 0 builds the model (and loads the
    # checkpoint) first, then releases them with its matching barrier below.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    model = DialogModel(args, word_embed, entity_embed, relation_embed).to(device)
    start_epoch = 0
    ckpt_state = None

    if args.ckpt is not None:
        # Checkpoints are saved by rank 0 from its GPU; map_location loads each
        # rank's copy directly onto its own device instead of rank 0's GPU.
        ckpt_state = torch.load(args.ckpt, map_location=device)
        model.load_state_dict(ckpt_state['model'])
        start_epoch = int(ckpt_state['epoch'])

    if args.local_rank == 0:
        torch.distributed.barrier()

    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
    model_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr_rate)

    if ckpt_state is not None:
        model_optimizer.load_state_dict(ckpt_state['optimizer'])

    # exist_ok avoids a FileExistsError race when several ranks create it.
    os.makedirs(args.result_dir_name, exist_ok=True)

    train(args, model, net_info, model_optimizer, args.local_rank, start_epoch)


if __name__ == '__main__':
    main()
