import json
import pickle
import numpy as np
# from aser.extract.aser_extractor import SeedRuleASERExtractor, DiscourseASERExtractor
from scipy.sparse import csr_matrix
import torch
from transformers import RobertaTokenizer,RobertaModel

# Local RoBERTa-base checkpoint used by Data_preprocess.cal_feature.
# Loaded once at import time; the model is moved to the GPU with .cuda(), so
# importing this module requires a CUDA device.
_ROBERTA_DIR = "pre_train/roberta_base/"
tokenizer = RobertaTokenizer.from_pretrained(_ROBERTA_DIR)
model = RobertaModel.from_pretrained(_ROBERTA_DIR).cuda()

class Data_preprocess:
    """Build a heterogeneous conversation graph from a pickled dialog dataset.

    ``data_list`` is a list of dialogs. Each dialog is a list of utterance
    dicts. The code reads the keys ``'text'``, ``'dataset'`` (``'train'`` /
    ``'dev'`` / ``'test'``), ``'speaker'`` and ``'event'`` (an iterable of event
    strings). ``'label'`` is optional.

    Node ids are assigned dialog by dialog, in this order:

    * the dialog's utterances;
    * one node per event, in utterance order;
    * one node per distinct speaker, in order of first appearance.

    After construction the instance exposes:

    * ``node_list`` -- one dict per node, indexed by node id;
    * ``edges`` -- five ``id_num x id_num`` CSR adjacency matrices (see
      ``build_edges``);
    * ``labels`` -- ``[train, dev, test]`` integer arrays of shape ``(k, 2)``
      holding ``(node id, class index)`` rows (see ``build_label``).

    Node features are not computed in ``__init__``. Call
    ``build_node_features``, which runs the RoBERTa encoder.
    """

    def __init__(self, data_path='newdata/DailyDialog/data_list.pkl', data_list=None):
        """Load the dataset and build nodes, edges and labels.

        Args:
            data_path: pickle file holding the dialog list. It is ignored when
                ``data_list`` is given.
            data_list: an already-loaded dialog list, e.g. for in-memory use
                or tests.
        """
        if data_list is None:
            # NOTE(review): pickle.load can execute arbitrary code -- only
            # point data_path at files you produced yourself.
            with open(data_path, 'rb') as f:
                data_list = pickle.load(f)
        self.data_list = data_list

        self.id_num = 0  # number of nodes allocated so far (== next free id)
        self.node_list = []  # node dicts, indexed by node id
        self.utte_id = []  # per dialog: id of its first utterance node
        self.utte_num = []  # per dialog: number of utterances
        self.speaker_num = []  # per dialog: number of distinct speakers
        self.speakers = []  # per dialog: distinct speakers, in node-id order
        self.utterance_event_source = []  # utterance ids (parallel to target)
        self.utterance_event_target = []  # event ids
        self.utterance_speaker_source = []  # speaker ids (parallel to target)
        self.utterance_speaker_target = []  # utterance ids
        self.build_node_edges()

        self.utterance_source = []
        self.utterance_target = []
        self.utterance_edges()

        self.edges = []
        self.build_edges()

        self.labels = []  # [train, dev, test] label arrays
        self.build_label()

    @staticmethod
    def _speaker_list(dialog):
        """Return the dialog's distinct speakers in order of first appearance.

        ``list(set(...))`` would order speakers by hash, and string hashes
        change between interpreter runs (hash randomization). Speaker node ids
        would then not be reproducible. ``dict.fromkeys`` keeps insertion order.
        """
        return list(dict.fromkeys(utterance['speaker'] for utterance in dialog))

    def cal_feature(self, text):
        """Encode ``text`` with the module-level RoBERTa ``tokenizer``/``model``.

        Returns:
            The model's ``pooler_output`` for this single input, as a list of
            Python floats. The inputs are moved to CUDA to match the model.
        """
        # truncation=True caps the input at the tokenizer's model_max_length
        # (when the tokenizer config defines one), so over-long texts do not
        # overflow the position embeddings.
        encoded_input = tokenizer(text, return_tensors='pt', truncation=True)
        encoded_input = {key: value.cuda() for key, value in encoded_input.items()}
        # Inference only: no_grad avoids building (and holding on to) an
        # autograd graph for every call.
        with torch.no_grad():
            output = model(**encoded_input).pooler_output
        return output[0].cpu().tolist()

    def build_node_features(self):
        """Compute one feature vector per node, in node-id order.

        Utterance and event nodes are encoded with ``cal_feature``. A speaker
        node gets the element-wise mean of its utterances' vectors. The
        traversal mirrors ``build_node_edges`` exactly, so
        ``hidden_features[k]`` belongs to node ``k``.

        Returns:
            A list of feature vectors (lists of floats) of length ``id_num``.
        """
        hidden_features = []
        for i, dialog in enumerate(self.data_list):
            # Position of this dialog's first utterance (== its node id).
            first_utterance = len(hidden_features)

            # Utterance nodes.
            for utterance in dialog:
                hidden_features.append(self.cal_feature(utterance['text']))

            # Event nodes.
            for utterance in dialog:
                for event in utterance['event']:
                    hidden_features.append(self.cal_feature(event))

            # Speaker nodes: average of that speaker's utterance features.
            for speaker in self._speaker_list(dialog):
                own = [hidden_features[first_utterance + j]
                       for j, utterance in enumerate(dialog)
                       if utterance['speaker'] == speaker]
                hidden_features.append(list(np.mean(own, axis=0)))

            print("build_node_features...~~~~~~~~~~~~~~~~~~~~~~" + str(i) + "~~~~~~~~~~~~~~~~~~~~~~")
        return hidden_features

    def build_node_edges(self):
        """Create utterance, event and speaker nodes plus their incidence lists.

        It fills ``node_list``, ``utte_id``, ``utte_num``, ``speaker_num``,
        ``speakers``, and the utterance-event and speaker-utterance
        source/target lists. No features are computed here.
        """
        for i, dialog in enumerate(self.data_list):
            first_utterance = self.id_num
            self.utte_id.append(first_utterance)

            # Utterance nodes.
            for utterance in dialog:
                node = {
                    'id': self.id_num,
                    'content': utterance['text'],
                    'type': "utterance",
                    'dataset': utterance['dataset'],
                }
                if 'label' in utterance:
                    node['label'] = utterance['label']
                node['speaker'] = utterance['speaker']
                self.node_list.append(node)
                self.id_num += 1
            self.utte_num.append(len(dialog))

            # Event nodes, each linked to the utterance it came from.
            for j, utterance in enumerate(dialog):
                for event in utterance['event']:
                    self.node_list.append({'id': self.id_num, 'content': event, 'type': "event"})
                    self.utterance_event_source.append(first_utterance + j)
                    self.utterance_event_target.append(self.id_num)
                    self.id_num += 1

            # Speaker nodes, each linked to every utterance that speaker made.
            speakers = self._speaker_list(dialog)
            self.speaker_num.append(len(speakers))
            self.speakers.append(speakers)
            for speaker in speakers:
                self.node_list.append({'id': self.id_num, 'content': speaker, 'type': "speaker"})
                for j, utterance in enumerate(dialog):
                    if utterance['speaker'] == speaker:
                        self.utterance_speaker_source.append(self.id_num)
                        self.utterance_speaker_target.append(first_utterance + j)
                self.id_num += 1

            print("build_node...~~~~~~~~~~~~~~~~~~~~~~" + str(i) + "~~~~~~~~~~~~~~~~~~~~~~")
            print(self.id_num)

    def build_node(self):
        """Legacy variant of build_node_edges that also encodes features inline.

        Not called by ``__init__``. It differs from ``build_node_edges`` in
        three ways:

        * it skips events whose string form is ``'[]'``;
        * it strips ``[``/``]`` from event text;
        * it appends RoBERTa features to ``self.features_list``. Speaker nodes
          get the mean of their utterances' features.

        It appends to the same bookkeeping lists as ``build_node_edges``.
        Calling it on an instance built through ``__init__`` therefore adds a
        second copy of every node.
        """
        if not hasattr(self, 'features_list'):
            self.features_list = []
        for i, dialog in enumerate(self.data_list):
            first_utterance = self.id_num
            # Feature index of the first utterance; it may differ from its node
            # id if features_list and id_num are out of sync.
            first_feature = len(self.features_list)
            self.utte_id.append(first_utterance)

            # Utterance nodes, encoded as they are created.
            for utterance in dialog:
                node = {
                    'id': self.id_num,
                    'content': utterance['text'],
                    'type': "utterance",
                    'dataset': utterance['dataset'],
                }
                if 'label' in utterance:
                    node['label'] = utterance['label']
                node['speaker'] = utterance['speaker']
                self.node_list.append(node)
                self.features_list.append(self.cal_feature(node['content']))
                self.id_num += 1
            self.utte_num.append(len(dialog))

            # Event nodes; '[]' placeholders are skipped and brackets stripped.
            for j, utterance in enumerate(dialog):
                for event in utterance['event']:
                    text = str(event)
                    if text == '[]':
                        continue
                    node = {
                        'id': self.id_num,
                        'content': text.replace('[', '').replace(']', ''),
                        'type': "event",
                    }
                    self.features_list.append(self.cal_feature(node['content']))
                    self.utterance_event_source.append(first_utterance + j)
                    self.utterance_event_target.append(self.id_num)
                    self.node_list.append(node)
                    self.id_num += 1

            # Speaker nodes: feature is the mean of the speaker's utterance features.
            speakers = self._speaker_list(dialog)
            self.speaker_num.append(len(speakers))
            self.speakers.append(speakers)
            for speaker in speakers:
                self.node_list.append({'id': self.id_num, 'content': speaker, 'type': "speaker"})
                own = []
                for j, utterance in enumerate(dialog):
                    if utterance['speaker'] == speaker:
                        own.append(self.features_list[first_feature + j])
                        self.utterance_speaker_source.append(self.id_num)
                        self.utterance_speaker_target.append(first_utterance + j)
                self.features_list.append(list(np.mean(own, axis=0)))
                self.id_num += 1

            print("build_node...~~~~~~~~~~~~~~~~~~~~~~" + str(i) + "~~~~~~~~~~~~~~~~~~~~~~")
            print(self.id_num)

    def utterance_edges(self):
        """Add directed utterance->utterance edges within each dialog.

        For every utterance after the first, add one edge from the most recent
        earlier utterance of each speaker in the dialog. Speakers are taken in
        ``self.speakers`` order. Only the forward-in-time direction is added.

        The latest utterance per speaker is tracked incrementally. Each dialog
        therefore costs O(utterances x speakers) rather than rescanning the
        whole prefix for every target.
        """
        for i, first in enumerate(self.utte_id):
            position = {speaker: p for p, speaker in enumerate(self.speakers[i])}
            latest = [-1] * self.speaker_num[i]  # -1: speaker has not spoken yet
            for j in range(first, first + self.utte_num[i]):
                for source in latest:
                    if source != -1:
                        self.utterance_source.append(source)
                        self.utterance_target.append(j)
                latest[position[self.node_list[j]['speaker']]] = j
            print("utterance_edges.................."+str(i)+"......................")

    def _adjacency(self, rows, cols):
        """Return an ``id_num x id_num`` CSR matrix with a 1 at each (row, col)."""
        values = np.ones(len(rows), dtype=int)
        return csr_matrix((values, (rows, cols)), shape=(self.id_num, self.id_num))

    def build_edges(self):
        """Append the five adjacency matrices to ``self.edges``.

        Order (rows -> columns):

        0. utterance -> utterance (see ``utterance_edges``)
        1. utterance -> event
        2. event -> utterance (transpose of 1)
        3. speaker -> utterance
        4. utterance -> speaker (transpose of 3)

        ``csr_matrix`` sums any duplicate (row, col) pairs.
        """
        self.edges.append(self._adjacency(self.utterance_source, self.utterance_target))
        self.edges.append(self._adjacency(self.utterance_event_source, self.utterance_event_target))
        self.edges.append(self._adjacency(self.utterance_event_target, self.utterance_event_source))
        self.edges.append(self._adjacency(self.utterance_speaker_source, self.utterance_speaker_target))
        self.edges.append(self._adjacency(self.utterance_speaker_target, self.utterance_speaker_source))

    def _split_labels(self, label_map, default=None):
        """Collect ``[node id, class index]`` pairs per split for labelled utterances.

        Args:
            label_map: maps raw label strings to class indices.
            default: class index for labels missing from ``label_map``. When it
                is ``None``, such labels raise ``ValueError`` instead of
                producing malformed rows.

        Utterance nodes without a ``'label'`` key are skipped. Nodes whose
        ``'dataset'`` is not ``'train'``, ``'dev'`` or ``'test'`` are ignored.

        Returns:
            ``(train_node, val_node, test_node)`` lists of ``[id, class]`` pairs.

        Raises:
            ValueError: an unknown label was found and ``default`` is ``None``.
        """
        splits = {'train': [], 'dev': [], 'test': []}
        for node in self.node_list:
            if node['type'] != "utterance" or 'label' not in node:
                continue
            split = splits.get(node['dataset'])
            if split is None:
                continue
            label = node['label']
            if label in label_map:
                index = label_map[label]
            elif default is not None:
                index = default
            else:
                raise ValueError("unknown label %r on node %d" % (label, node['id']))
            split.append([node['id'], index])
        print("labeling... train=%d dev=%d test=%d"
              % (len(splits['train']), len(splits['dev']), len(splits['test'])))
        return splits['train'], splits['dev'], splits['test']

    def build_EmoryNLP_label(self):
        """EmoryNLP labels: Neutral, Joyful, Powerful, Mad, Sad, Scared -> 0..5; anything else -> 6."""
        label_map = {'Neutral': 0, 'Joyful': 1, 'Powerful': 2, 'Mad': 3, 'Sad': 4, 'Scared': 5}
        return self._split_labels(label_map, default=6)

    def build_IEMOCAP_label(self):
        """IEMOCAP labels: neu, fru, sad, ang, hap -> 0..4; anything else -> 5."""
        label_map = {'neu': 0, 'fru': 1, 'sad': 2, 'ang': 3, 'hap': 4}
        return self._split_labels(label_map, default=5)

    def build_DailyDialog_label(self):
        """DailyDialog labels: none, anger, disgust, fear, happiness, sadness, surprise -> 0..6.

        Any other label raises ``ValueError``.
        """
        label_map = {'none': 0, 'anger': 1, 'disgust': 2, 'fear': 3,
                     'happiness': 4, 'sadness': 5, 'surprise': 6}
        return self._split_labels(label_map)

    def build_node_label(self):
        """Labels neutral, happiness, sadness, anger, surprise, disgust -> 0..5; anything else -> 6."""
        label_map = {'neutral': 0, 'happiness': 1, 'sadness': 2, 'anger': 3,
                     'surprise': 4, 'disgust': 5}
        return self._split_labels(label_map, default=6)

    def build_label(self):
        """Append ``[train, dev, test]`` label arrays to ``self.labels``.

        Each array has shape ``(k, 2)`` with rows ``(node id, class index)``.
        An empty split still has shape ``(0, 2)``. The DailyDialog label set is
        used here. Swap in build_IEMOCAP_label, build_EmoryNLP_label or
        build_node_label for the other label sets.
        """
        for nodes in self.build_DailyDialog_label():
            self.labels.append(np.array(nodes, dtype=int).reshape(-1, 2))




if __name__ == "__main__":
    # Build nodes, edges and labels from the pickled DailyDialog data.
    datapre = Data_preprocess()

    # Persist the structural outputs first. They come from __init__ alone and
    # do not invoke the RoBERTa encoder.
    structural_outputs = (
        ("newdata/DailyDialog/edges.pkl", datapre.edges),
        ("newdata/DailyDialog/labels.pkl", datapre.labels),
    )
    for out_path, payload in structural_outputs:
        with open(out_path, "wb") as f:
            pickle.dump(payload, f)

    # Encode every node (RoBERTa on the GPU) and save the stacked feature matrix.
    node_features = np.asarray(datapre.build_node_features())
    with open("newdata/DailyDialog/node_features.pkl", "wb") as f:
        pickle.dump(node_features, f)


