#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @python: 3.6

# Author: Wei Li
# Email: wei008@e.ntu.edu.sg


import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score, \
    classification_report, precision_recall_fscore_support

from torch.nn.utils.rnn import pad_sequence
from torch.nn import utils as nn_utils
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from dataloader_gcn_dis import IEMOCAPDataset
import torch.nn.functional as F
import argparse
from transformers import get_linear_schedule_with_warmup
from GCN_functions import batch_graphify, MaskedEdgeAttention,batch_emgraph
from torch_geometric.nn import RGCNConv
import pickle
import copy
import scipy.sparse as sp

from prepare_data import *
#np.random.seed(1234)
#torch.random.manual_seed(1234)
#torch.cuda.manual_seed(1234)


def get_train_valid_sampler(trainset, valid=0.1):
    """Split the dataset indices into a training and a validation sampler.

    The first ``int(valid * len(trainset))`` indices are assigned to the
    validation sampler and all remaining indices to the training sampler.

    Args:
        trainset: any sized dataset (only ``len()`` is used).
        valid: fraction of the dataset reserved for validation.

    Returns:
        ``(train_sampler, valid_sampler)`` as ``SubsetRandomSampler`` objects.
    """
    n_samples = len(trainset)
    n_valid = int(valid * n_samples)
    all_indices = list(range(n_samples))

    valid_indices = all_indices[:n_valid]
    train_indices = all_indices[n_valid:]

    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(valid_indices)
    return train_sampler, valid_sampler


def get_IEMOCAP_loaders(path, batch_size=2, valid=0.0, num_workers=0, word_idx=None, max_sen_len=30, pin_memory=False):
    """Create the train, validation and test ``DataLoader`` objects.

    Train and validation loaders share one ``IEMOCAPDataset`` instance and
    differ only in their sampler; the test loader wraps a second instance
    built with ``train=False`` and uses no sampler.

    Args:
        path: dataset file handed to ``IEMOCAPDataset``.
        batch_size: batch size used by all three loaders.
        valid: fraction of the training set reserved for validation.
        num_workers: worker process count for every loader.
        word_idx: forwarded unchanged to ``IEMOCAPDataset``.
        max_sen_len: forwarded unchanged to ``IEMOCAPDataset``.
        pin_memory: forwarded to every ``DataLoader``.

    Returns:
        ``(train_loader, valid_loader, test_loader)``.
    """
    # Options common to the three loaders.
    shared_opts = dict(batch_size=batch_size,
                       num_workers=num_workers,
                       pin_memory=pin_memory)

    train_data = IEMOCAPDataset(path=path, word_idx=word_idx, max_sen_len=max_sen_len)
    sampler_for_train, sampler_for_valid = get_train_valid_sampler(train_data, valid)

    train_loader = DataLoader(train_data,
                              sampler=sampler_for_train,
                              collate_fn=train_data.collate_fn,
                              **shared_opts)
    valid_loader = DataLoader(train_data,
                              sampler=sampler_for_valid,
                              collate_fn=train_data.collate_fn,
                              **shared_opts)

    test_data = IEMOCAPDataset(path=path, word_idx=word_idx, max_sen_len=max_sen_len, train=False)
    test_loader = DataLoader(test_data,
                             collate_fn=test_data.collate_fn,
                             **shared_opts)

    return train_loader, valid_loader, test_loader


class MaskedNLLLoss(nn.Module):
    """Negative log-likelihood loss averaged over unmasked positions only.

    Masked positions (mask value 0) have their log-probabilities zeroed before
    the summed NLL is taken, so they contribute nothing to the numerator. The
    sum is then normalised by the number of unmasked positions or, when class
    weights are given, by the total weight of the unmasked targets.

    Args:
        weight: optional 1-D tensor of per-class weights passed to
            ``nn.NLLLoss``; ``None`` disables class weighting.
    """

    def __init__(self, weight=None):
        super(MaskedNLLLoss, self).__init__()
        self.weight = weight
        self.loss = nn.NLLLoss(weight=weight,
                               reduction='sum')

    def forward(self, pred, target, mask):
        """Return the masked, normalised NLL.

        pred -> batch*seq_len, n_classes (log-probabilities)
        target -> batch*seq_len
        mask -> batch, seq_len (1 for valid positions, 0 for padding)

        If every mask entry is 0 the normaliser is 0 and the result is not
        finite; callers are expected to pass at least one valid position.
        """
        # reshape (rather than view) also accepts non-contiguous masks.
        mask_ = mask.reshape(-1, 1)  # batch*seq_len, 1
        summed = self.loss(pred * mask_, target)
        if self.weight is None:
            return summed / torch.sum(mask)
        # squeeze(1) keeps the 1-D shape even for a single position.
        return summed / torch.sum(self.weight[target] * mask_.squeeze(1))


class GraphNetwork(torch.nn.Module):
    """Thin wrapper around a single relational graph convolution (``RGCNConv``).

    ``num_classes``, ``dropout`` and ``no_cuda`` are accepted by the
    constructor but are not used by this class; likewise ``edge_norm``,
    ``seq_lengths`` and ``umask`` are accepted by ``forward`` and ignored.
    """

    def __init__(self, num_features, num_classes, hidden_size=64, dropout=0.5,
                 relations=2, no_cuda=False):
        super().__init__()
        # One R-GCN layer: num_features -> hidden_size, using 30 bases.
        self.conv1 = RGCNConv(num_features,
                              hidden_size,
                              num_relations=relations,
                              num_bases=30)

    def forward(self, x, edge_index, edge_type, edge_norm=None, seq_lengths=None, umask=None):
        """Apply the relational convolution to node features ``x``."""
        return self.conv1(x, edge_index, edge_type)


class Attention(nn.Module):
    """Attention pooling that collapses the ``sen_len`` axis of its input.

    A two-layer scorer (Linear -> tanh -> Linear) assigns one score per
    position; the scores are soft-maxed over ``sen_len`` and used as weights
    for a weighted sum of the input vectors.
    """

    def __init__(self, n_hidden, sen_len):
        super().__init__()
        self.n_hidden = n_hidden
        self.sen_len = sen_len
        self.linear1 = nn.Linear(n_hidden, n_hidden)
        self.linear2 = nn.Linear(n_hidden, 1)

    def forward(self, x):
        '''
        input shape: [N, sen_len, n_hidden]
        output shape: [N, n_hidden]
        '''
        hidden = self.n_hidden
        # Score every position independently: (N * sen_len, n_hidden) -> (N * sen_len, 1)
        flat_positions = x.reshape(-1, hidden)
        scores = self.linear2(torch.tanh(self.linear1(flat_positions)))
        # Normalise the scores over the sen_len axis: (N, 1, sen_len)
        weights = F.softmax(scores.reshape(-1, 1, self.sen_len), dim=-1)
        # Weighted sum of the input vectors: (N, 1, n_hidden) -> (N, n_hidden)
        pooled = weights @ x
        return pooled.reshape(-1, hidden)

class ECPEC(nn.Module):
    """Joint model with four classification heads computed for one conversation.

    Pipeline implemented by ``forward``:
      1. a word-level LSTM (``wgru``) turns each utterance into one vector;
      2. a context LSTM (``egru``) is run over the utterance sequence in the
         original and in the reversed order, and both outputs are concatenated;
      3. a first linear head (``Wem``) predicts a label per utterance, which is
         passed to ``batch_emgraph`` to build an utterance graph;
      4. an R-GCN (``graph_net_``) refines the utterance features, which feed
         the ``Wsa`` and ``Wcd`` heads;
      5. every ordered utterance pair becomes a node of a second graph that is
         processed by ``graph_net_me`` and classified by ``Wecpe``.

    NOTE(review): ``forward`` calls ``.cuda()`` unconditionally, so the model
    cannot run on CPU as written.
    NOTE(review): ``Wsa``/``Wcd`` expect ``2 * input_dim`` features but receive
    the ``graph_hidden_size`` (200) output of ``graph_net_``, and ``Wecpe``
    expects ``4 * input_dim`` but receives ``2 * graph_hidden_size`` (400)
    features, so the shapes only line up when ``input_dim == 100``.
    """

    def __init__(self, input_dim, n_class, dropout, acti,sen_len):
        """Build all sub-modules.

        Args:
            input_dim: feature size of the word embeddings / utterance vectors.
            n_class: number of output classes of every head.
            dropout: dropout rate of ``self.dropout`` (forward-direction context).
            acti: one of ``'sigmoid'``, ``'relu'``, ``'tanh'``; any other value
                leaves ``self.ac`` undefined and ``forward`` will then fail
                with ``AttributeError``.
            sen_len: sentence length handed to the ``Attention`` module.
        """
        super(ECPEC, self).__init__()
        self.input_dim = input_dim
        self.n_class = n_class

        # Fuses [previous utterance ; current utterance] back to input_dim.
        self.We = nn.Linear(2 * input_dim, input_dim)
        # Not referenced by forward() in this class.
        self.Wc = nn.Linear(2*input_dim, input_dim)

        # Output heads applied to the graph-refined features in forward().
        self.Wsa = nn.Linear(2*input_dim, self.n_class)
        self.Wcd = nn.Linear(2*input_dim, self.n_class)
        self.Wecpe = nn.Linear(4*input_dim, self.n_class)

        # Head applied to the concatenated context features before the graph.
        self.Wem = nn.Linear(2*input_dim, self.n_class)


        # Despite the "gru" names these are LSTMs. egru is used for both the
        # forward-order and the reversed-order pass; cgru is not referenced by
        # forward() in this class.
        self.egru = nn.LSTM(input_size=input_dim, hidden_size=int(input_dim / 2), num_layers=1, bidirectional=True)
        self.cgru = nn.LSTM(input_size=input_dim, hidden_size=int(input_dim / 2), num_layers=1, bidirectional=True)

        # Only used by a commented-out line in forward().
        self.attention = Attention(input_dim, sen_len)

        # Word-level encoder: 2 layers x 2 directions (hence 4 initial states in forward()).
        self.wgru = nn.LSTM(input_size=input_dim, hidden_size=int(input_dim / 2), num_layers=2, batch_first=True,
                            bidirectional=True)

        # No else branch: an unrecognised `acti` leaves self.ac unset.
        if acti=='sigmoid':
            self.ac = nn.Sigmoid()
        elif acti=='relu':
            self.ac = nn.ReLU()
        elif acti=='tanh':
            self.ac = nn.Tanh()
        self.ac_linear = nn.ReLU()

        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(dropout)
        self.dropouta = nn.Dropout(0.2)
        self.dropoutb = nn.Dropout(0.2)
        self.nocuda = False
        self.ln = torch.nn.LayerNorm([30,input_dim],elementwise_affine=True)


        self.graph_hidden_size = 200
        self.max_seq_len = 110
        self.relations = 2
        # Size of the position encoding concatenated to every pair feature in
        # forward(); presumably must match the caller's position embedding
        # dimension -- TODO confirm.
        pos_emb_dim=50

        self.att = MaskedEdgeAttention(2 * input_dim, self.max_seq_len, self.nocuda)

        # Utterance-level graph and pair-level graph.
        self.graph_net_ = GraphNetwork(2 * input_dim, self.n_class, self.graph_hidden_size, 0.5, self.relations, self.nocuda)
        self.graph_net_me = GraphNetwork(4 * input_dim+pos_emb_dim, self.n_class, 2 * self.graph_hidden_size, 0.5, self.relations, self.nocuda)
        self.wp = 10
        self.wf = 10




    def kernel_generator(self, rel_pos):
        """Return ``exp(-(p_i - p_j) ** 2)`` for all pairs of entries of ``rel_pos[0]``.

        The only call site in this class is commented out.
        NOTE(review): the tensor is moved to CUDA and then wrapped in
        ``torch.LongTensor(...)``; verify this works before re-enabling.
        """
        n_couple = rel_pos.size(1)
        rel_pos_ = rel_pos[0].cuda()
        # Column vector of positions repeated n_couple times side by side.
        kernel_left = torch.LongTensor(torch.cat([rel_pos_.reshape(-1, 1)] * n_couple, dim=1))

        # Pairwise position differences.
        kernel = kernel_left - kernel_left.transpose(0, 1)
        return torch.exp(-(torch.pow(kernel, 2)))

    def _reverse_seq(self, X, mask):
        """Reverse each sequence over its valid (unmasked) prefix and re-pad.

        X -> seq_len, batch, dim
        mask -> batch, seq_len

        For every batch item the first ``sum(mask)`` steps are flipped along the
        time axis; the flipped sequences are padded back together with
        ``pad_sequence`` (time-major output).
        """
        X_ = X.transpose(0, 1)
        mask_sum = torch.sum(mask, 1).int()

        xfs = []
        for x, c in zip(X_, mask_sum):
            xf = torch.flip(x[:c], [0])
            xfs.append(xf)

        return pad_sequence(xfs)

    def forward(self, x,pos_encode, sen_lengths, mask):
        """Compute the log-probabilities of all four heads.

        Args:
            x: padded word embeddings. It is packed with ``batch_first=True``,
                so dim 0 indexes utterances, dim 1 words and dim 2 features.
                (The original notes mention both "seq, batch, dim" and
                "batch, seq, sen_len, dim"; presumably the caller already
                squeezed a batch of size 1 -- TODO confirm.)
            pos_encode: position features concatenated column-wise to the
                ``n * n`` pair features, so it needs ``n * n`` rows.
            sen_lengths: sentence (word-count) length of every utterance.
            mask: document-length mask; its row sum is used as the number of
                valid utterances.

        Returns:
            ``(results_sa, results_cd, results_emd, results_ecpe)``, each
            reshaped to ``(-1, n_class)``. The first three have one row per
            utterance, the last one row per ordered utterance pair.
        """

        x = x.float()

        #batch, seq, sen_len, dim=x.size()
        seq_lengths = sen_lengths.squeeze().cpu()

        #x = x.reshape(-1, sen_len, dim)
        # Word-level encoding of every utterance.
        pack = nn_utils.rnn.pack_padded_sequence(x, seq_lengths, batch_first=True, enforce_sorted=False)
        # Zero initial states: 4 = 2 layers x 2 directions.
        hw = torch.zeros((4, x.size(0), int(x.size(2) / 2))).cuda()
        cw = torch.zeros((4, x.size(0), int(x.size(2) / 2))).cuda()
        out,_ = self.wgru(pack,(hw, cw))
        unpacked = nn_utils.rnn.pad_packed_sequence(out, batch_first=True)
        # Flat index of the last valid time step of every utterance in the
        # (utterances * max_len, dim) view below.
        index = torch.LongTensor([seq_lengths[i]+unpacked[0].size(1)*i-1 for i in range(unpacked[0].size(0))])
        # U: (n_utterances, 1, dim) -- consumed time-major by egru, i.e. as a
        # sequence of utterances with batch size 1.
        U = unpacked[0].contiguous().view(-1, unpacked[0].size(2))[index].unsqueeze(1)

        #U =self.attention(out).unsqueeze(1)



        # the conversational sentiment analysis task
        # Zero initial states for the 1-layer bidirectional context LSTM.
        he = torch.zeros((2, U.size(1), int(U.size(2) / 2))).cuda()
        ce = torch.zeros((2, U.size(1), int(U.size(2) / 2))).cuda()

        # Forward order: pair every utterance with its predecessor (the first
        # utterance is paired with itself).
        f1 = U
        f2 = torch.cat((U[0].unsqueeze(0), U[0:-1]), dim=0)
        conf = self.dropout(self.ac(self.We(torch.cat((f2, f1), dim=2))))
        forward, (hne,cne) = self.egru(conf, (he,ce))
        # Reversed order: same construction on the flipped sequence, sharing
        # self.We and self.egru with the forward-order pass.
        rever_U = self._reverse_seq(U, mask)

        fx1 = rever_U
        fx2 = torch.cat((rever_U[0].unsqueeze(0), rever_U[0:-1]), dim=0)
        conb = self.dropouta(self.ac(self.We(torch.cat((fx2, fx1), dim=2))))
        backward, (hne,cne) = self.egru(conb, (he,ce))
        # Flip the reversed-order output back to the original utterance order.
        backward = self._reverse_seq(backward, mask)
        ConE = torch.cat((forward, backward), dim=2)
        #
        #results_em = F.gumbel_softmax(self.Wem(ConE.squeeze(1)), tau=1, hard=False,dim=1)
        # Per-utterance logits; their argmax drives the graph construction below.
        results_em = self.Wem(ConE.squeeze(1))
        pre_em = torch.argmax(results_em, 1)
        types_ = torch.ones(ConE.size(0), ConE.size(0)).cuda()
        vertex_, edge_, edge_type_ = batch_emgraph(ConE, mask, ConE.size(0),pre_em, type=types_)
        #vertex_ = self.dropoutb(vertex_)
        gcnf_ = self.graph_net_(vertex_, edge_, edge_type_.long())
        #
        emoe = gcnf_.unsqueeze(1)
        #emoc = gcnf_.unsqueeze(1)

        # conversational sentiment analysis task
        #feature_sa = self.transa(emoe)
        # pair
        # Build one feature row per ordered pair (i, j): row i * doc_len + j is
        # [features of utterance i ; features of utterance j].
        doc_len,batch, feat_dim = emoe.size()
        P_left = torch.cat([emoe.squeeze(1)] * doc_len, dim=1)
        P_left = P_left.reshape(doc_len * doc_len, feat_dim)
        P_right = torch.cat([emoe.squeeze(1)] * doc_len, dim=0)
        P_ = torch.cat([P_left, P_right], dim=1).cuda()

        #base_idx = np.arange(1, doc_len + 1)
        #emo_pos = np.concatenate([base_idx.reshape(-1, 1)] * doc_len, axis=1).reshape(1, -1)[0]
        #cau_pos = np.concatenate([base_idx] * doc_len, axis=0)
        #rel_pos = cau_pos - emo_pos
        #rel_pos = torch.LongTensor(rel_pos + self.K).cuda()
        #rel_pos_emb = self.pos_layer(rel_pos)
        #kernel = self.kernel_generator(rel_pos.unsqueeze(0))
        #rel_pos_emb = torch.matmul(kernel, rel_pos_emb)
        P = torch.cat([P_, pos_encode], dim=1).cuda()

        # Dense adjacency matrix over the n * n pair nodes.
        # `np` is not imported explicitly in this file; it is presumably
        # provided by `from prepare_data import *`.
        matrix_size = len(pre_em)
        matrix = torch.LongTensor(np.zeros((matrix_size*matrix_size, matrix_size*matrix_size))).cuda()

        # Iterate over the emotion clauses and the cause clauses.
        for i in range(matrix_size):
            # Original note: the emotion clause is the centre sentence and is
            # linked to the (emotion clause, cause clause) pairs within
            # distance 3. The ranges below span indices i-2..i+2 and j-2..j+2.
            for j in range(matrix_size):
                if pre_em[i] == 1 and pre_em[j] == 1 :
                    #for z in range(max(i*matrix_size, i*matrix_size - 3), min(matrix_size*matrix_size, i*matrix_size +4)):
                    for z in range(max(0, i - 2), min(matrix_size, i +3)):
                        # NOTE: this loop variable shadows the input tensor `x`,
                        # which is not used again after this point.
                        for x in range(max(0, j - 2), min(matrix_size, j +3)):
                            matrix[i*matrix_size+j,z*matrix_size+x] = 1
                            matrix[z*matrix_size+x,i*matrix_size+j] = 1
        # Self loops on every pair node.
        for i in range(matrix_size*matrix_size):
            matrix[i,i] = 1
        # Convert the dense adjacency matrix to an edge list; the stored values
        # (all 1) are passed to the R-GCN as edge types.
        tmp_coo = sp.coo_matrix(matrix.cpu())
        me_values = tmp_coo.data
        me_indices = np.vstack((tmp_coo.row,tmp_coo.col))
        me_v= torch.LongTensor(me_values).cuda()
        me_i = torch.LongTensor(me_indices).cuda()
        gcnf_me = self.graph_net_me(P, me_i, me_v)



        results_sa = torch.log_softmax(self.Wsa(emoe.squeeze(1)), dim=1)
        results_sa = results_sa.contiguous().view(-1, self.n_class)

        # the conversational cause utterance detection task
        #feature_cd = self.trancd(emoc)
        results_cd = torch.log_softmax(self.Wcd(emoe.squeeze(1)), dim=1)
        results_cd = results_cd.contiguous().view(-1, self.n_class)

        # Log-probabilities of the pre-graph head that produced pre_em.
        results_emd = torch.log_softmax(results_em, dim=1)
        results_emd = results_emd.contiguous().view(-1, self.n_class)

        # Pair-level head over the n * n pair nodes.
        results_ecpe = torch.log_softmax(self.Wecpe(gcnf_me), dim=1)
        results_ecpe = results_ecpe.contiguous().view(-1, self.n_class)

        return results_sa, results_cd,results_emd,results_ecpe


class Classifier(nn.Module):
    """Placeholder module: it owns no parameters and its forward pass yields ``None``."""

    def __init__(self):
        super().__init__()

    def forward(self, *input):
        """Accept any positional inputs and return ``None``."""
        return None


def evaluate(pred, label):
    """Compute precision, recall and F1 for binary (0/1) pair predictions.

    Args:
        pred: array-like of predicted labels (1 = proposed pair).
        label: array-like of gold labels (1 = annotated pair), same length.

    Returns:
        ``(precision, recall, f1)``, each rounded to 4 decimals. A metric
        whose denominator is zero (nothing proposed, nothing annotated, or
        precision + recall == 0) is reported as ``0.0`` instead of NaN/inf,
        so that callers comparing F1 scores across epochs keep working.
    """
    pred = np.asarray(pred)
    label = np.asarray(label)

    num_proposed_pairs = float(np.sum(pred))
    # suitable for training phase 3 and phase 2
    num_annotated_pairs = float(np.sum(label))
    # A pair is correct when both the prediction and the gold label are 1.
    num_correct_pairs = float(np.sum((pred + label) == 2))

    precision = num_correct_pairs / num_proposed_pairs if num_proposed_pairs else 0.0
    recall = num_correct_pairs / num_annotated_pairs if num_annotated_pairs else 0.0

    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0

    return round(precision, 4), round(recall, 4), round(f1, 4)
def train_or_eval_model(model, loss_function, loss_function_cd,loss_function_em,loss_function_ecpe,dataloader, epoch, prop,prop2, embedding,pos_embedding, optimizer=None, train=False):
    """Run one pass over ``dataloader`` in training or evaluation mode.

    For every batch the four model heads are evaluated, their losses are summed
    with equal weight, and (when ``train`` is true) one optimizer step is taken.
    Evaluation runs with autograd disabled.

    Args:
        model: network returning ``(log_sa, log_cd, log_em, log_ecpe)``.
        loss_function, loss_function_cd, loss_function_em, loss_function_ecpe:
            masked losses called as ``loss(log_probs, target, mask)``.
        dataloader: yields 14-element batches; the last element holds the
            conversation id(s), of which the first is used as result key.
        epoch, prop, prop2: accepted for interface compatibility; unused.
        embedding: module mapping token ids to word vectors.
        pos_embedding: module mapping shifted relative positions to vectors.
        optimizer: required when ``train`` is true.
        train: whether to update the model.

    Returns:
        A 23-tuple ``(avg_loss, avg_accuracy_e, avg_accuracy_c, labelse, predse,
        labelsc, predsc, maskse, labelecpe, predecpe, maskecpe, avg_fscore_e,
        avg_fscore_c, avg_fscore_md, saTask, cdTask, predem, labelem, emTask,
        p_avg_ecpe, r_avg_ecpe, f_avg_ecpe, best_model_wts)`` where
        ``best_model_wts`` is a deep copy of ``model.state_dict()`` taken after
        the pass has finished.

        NOTE(review): when the dataloader yields no batch, a 12-element tuple
        of NaNs/empty lists is returned instead (kept from the original code);
        callers that unpack 23 values will fail in that case.

    Requires CUDA: every batch tensor is moved with ``.cuda()``.
    """
    losses = []
    # conversational sentiment analysis
    predse, labelse, maskse = [], [], []
    # cause utterance detection
    predsc, labelsc, masksc = [], [], []
    # mid event
    predem, labelem = [], []
    # ecpe
    predecpe, labelecpe, maskecpe = [], [], []
    # per-conversation predictions of the utterance-level subtasks
    saTask, cdTask, emTask = {}, {}, {}

    assert not train or optimizer is not None, 'an optimizer is required when train=True'
    if train:
        model.train()
    else:
        model.eval()

    # Gradients are only needed when the model is being updated.
    with torch.set_grad_enabled(train):
        for data in dataloader:
            if train:
                optimizer.zero_grad()

            (textf, visuf, acouf, qmask, umask, label_e, label_c, causeLabel,
             ecpe_label, textid, sen_len, event_mid, rel_pos, vid_b) = data

            umask = umask.cuda()
            label_e = label_e.cuda()
            label_c = label_c.cuda()
            ecpe_label = ecpe_label.cuda()
            textid = textid.cuda()
            sen_len = sen_len.cuda()
            event_mid = event_mid.cuda()
            rel_pos = rel_pos.cuda()

            # Original note: shape is batch, seq, sen_len, dim; because the
            # batch size is 1, squeeze() leaves seq, sen_len, dim.
            word_encode = embedding(textid).squeeze()
            # Shift relative positions by 120 so they are valid embedding
            # indices; presumably this must equal K used to size
            # pos_embedding in the caller -- TODO confirm.
            pos_encode = pos_embedding(rel_pos + 120).squeeze()

            log_sa, log_cd, log_em, log_ecpe = model(word_encode, pos_encode, sen_len, umask)  # batch*seq_len, n_classes
            labelse_ = label_e.view(-1)  # batch*seq_len
            labelsc_ = label_c.view(-1)  # batch*seq_len
            labelevent_mid = event_mid.view(-1)
            ecpe_label_ = ecpe_label.view(-1)

            # Every pair counts for the pair-level loss (mask of ones).
            ecpe_mask = torch.ones(len(ecpe_label_))

            loss_sa = loss_function(log_sa, labelse_, umask)
            loss_cd = loss_function_cd(log_cd, labelsc_, umask)
            loss_em = loss_function_em(log_em, labelevent_mid, umask)
            loss_ecpe = loss_function_ecpe(log_ecpe, ecpe_label_, ecpe_mask.unsqueeze(0).cuda())

            #loss = prop*loss_sa + prop2*loss_cd+(1-prop-prop2)*(-loss_em)
            loss = loss_sa + loss_cd + loss_em + loss_ecpe

            umask_flat = umask.view(-1).cpu().numpy()

            # conversational sentiment analysis
            pred_e = torch.argmax(log_sa, 1)  # batch*seq_len
            predse.append(pred_e.data.cpu().numpy())
            labelse.append(labelse_.data.cpu().numpy())
            maskse.append(umask_flat)

            # emotion utterance detection
            pred_c = torch.argmax(log_cd, 1)  # batch*seq_len
            predsc.append(pred_c.data.cpu().numpy())
            labelsc.append(labelsc_.data.cpu().numpy())
            masksc.append(umask_flat)

            # mid event
            pred_em = torch.argmax(log_em, 1)
            predem.append(pred_em.data.cpu().numpy())
            labelem.append(labelevent_mid.data.cpu().numpy())

            # ecpe
            pred_ecpe = torch.argmax(log_ecpe, 1)
            predecpe.append(pred_ecpe.data.cpu().numpy())
            labelecpe.append(ecpe_label_.data.cpu().numpy())
            maskecpe.append(ecpe_mask.numpy())

            # Weight the batch loss by its number of valid utterances so the
            # average below is per utterance.
            losses.append(loss.item() * umask_flat.sum())

            vid = vid_b[0]
            saTask[vid] = list(pred_e.data.cpu().numpy())
            cdTask[vid] = list(pred_c.data.cpu().numpy())
            emTask[vid] = list(pred_em.data.cpu().numpy())

            if train:
                loss.backward()
                optimizer.step()
                #scheduler.step()

    if not predse:
        return float('nan'), float('nan'), float('nan'), float('nan'), float('nan'), [], [], [], float('nan'), float('nan'), float('nan'), float('nan')

    # Snapshot the weights once, after the pass (previously a deep copy was
    # taken for every batch, and in training mode it missed the last update).
    best_model_wts = copy.deepcopy(model.state_dict())

    predse = np.concatenate(predse)
    labelse = np.concatenate(labelse)
    maskse = np.concatenate(maskse)

    # cause utterance detection
    predsc = np.concatenate(predsc)
    labelsc = np.concatenate(labelsc)
    masksc = np.concatenate(masksc)

    # mid event
    predem = np.concatenate(predem)
    labelem = np.concatenate(labelem)

    # ecpe
    predecpe = np.concatenate(predecpe)
    labelecpe = np.concatenate(labelecpe)
    maskecpe = np.concatenate(maskecpe)

    avg_loss = round(np.sum(losses) / np.sum(maskse), 4)
    avg_accuracy_e = round(accuracy_score(labelse, predse, sample_weight=maskse) * 100, 2)
    avg_fscore_e = round(f1_score(labelse, predse, sample_weight=maskse, average='weighted') * 100, 2)

    # cause utterance detection
    avg_accuracy_c = round(accuracy_score(labelsc, predsc, sample_weight=masksc) * 100, 2)
    avg_fscore_c = round(f1_score(labelsc, predsc, sample_weight=masksc, average='weighted') * 100, 2)

    # mid event
    avg_fscore_md = round(f1_score(labelem, predem, sample_weight=masksc, average='weighted') * 100, 2)

    # evaluate() expects (pred, label); the previous call passed them swapped,
    # which exchanged the reported precision and recall (F1 is symmetric).
    p_avg_ecpe, r_avg_ecpe, f_avg_ecpe = evaluate(predecpe, labelecpe)

    return avg_loss, avg_accuracy_e, avg_accuracy_c ,labelse, predse, labelsc, predsc, maskse,labelecpe,predecpe, maskecpe,avg_fscore_e, avg_fscore_c,avg_fscore_md,saTask, cdTask,predem,labelem,emTask,p_avg_ecpe, r_avg_ecpe, f_avg_ecpe,best_model_wts
if __name__ == '__main__':
    # Training entry point: parses the options, builds the model, embeddings,
    # losses and data loaders, trains for the requested number of epochs and
    # saves the weights of the epoch with the best pair-extraction F1 on the
    # test loader.

    # Imported here so the timing code below does not depend on ``time``
    # leaking in through ``from prepare_data import *``.
    import time

    parser = argparse.ArgumentParser()
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='does not use GPU')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate')
    # Passed to Adam as weight_decay, i.e. an L2 penalty.
    parser.add_argument('--l2', type=float, default=0.0001, metavar='L2',
                        help='L2 regularization weight')
    parser.add_argument('--rec-dropout', type=float, default=0.1,
                        metavar='rec_dropout', help='rec_dropout rate')
    parser.add_argument('--dropout', type=float, default=0.20, metavar='dropout',
                        help='dropout rate')
    parser.add_argument('--batchsize', type=int, default=1, metavar='BS',
                        help='batch size')
    parser.add_argument('--epochs', type=int, default=40, metavar='E',
                        help='number of epochs')
    parser.add_argument('--warmup_proportion', type=float, default=0.1,
                        help='warmup proportion')
    parser.add_argument('--class-weight', action='store_true', default=False,
                        help='class weight')
    parser.add_argument('--activation', type=str, default='sigmoid',
                        help='activation function')
    parser.add_argument('--proportion', type=float, default=0.8, help='Proportion of Loss')
    parser.add_argument('--proportion2', type=float, default=0.1, help='Proportion of Loss')
    parser.add_argument('--max_sen_len', type=int, default=20, help='max sentence length')
    parser.add_argument('--tensorboard', action='store_true', default=False,
                        help='Enables tensorboard log')
    args = parser.parse_args()

    print(args)

    args.cuda = torch.cuda.is_available() and not args.no_cuda
    if args.cuda:
        print('Running on GPU')
    else:
        print('Running on CPU')

    batch_size = args.batchsize

    n_classes = 2
    cuda = args.cuda
    n_epochs = args.epochs
    dropout = args.dropout
    prop = args.proportion
    prop2 = args.proportion2
    acti = args.activation
    max_sen_len = args.max_sen_len

    D_m = 100

    model = ECPEC(D_m, n_classes, dropout, acti,max_sen_len)

    # word2vec loading

    w2v_path = '/home/btwang/ECPEC-main/Joint-GCN/key_words.txt'
    w2v_file = '/home/btwang/ECPEC-main/Dataset/glove.6B.100d.txt'
    word_idx_rev, word_id_mapping, word_embedding = load_w2v(D_m, w2v_path, w2v_file)
    word_embedding = torch.from_numpy(word_embedding)
    embedding = torch.nn.Embedding.from_pretrained(word_embedding, freeze=True).cuda()
    # Relative positions are shifted by K before lookup, so the table needs
    # 2 * K + 1 rows.
    K=120
    pos_emb_dim=50
    # NOTE(review): pos_embedding is a separate module. Its parameters are not
    # handed to the optimizer below and are not part of model.state_dict(), so
    # it stays at its random initialisation and is not stored in the
    # checkpoint. Confirm whether that is intended.
    pos_embedding = torch.nn.Embedding(2*K + 1, pos_emb_dim).cuda()

    if cuda:
        model.cuda()

    loss_weights = torch.FloatTensor([
        1 / 0.227883,
        1 / 0.772117,
    ])

    loss_weights_c = torch.FloatTensor([
        1 / 0.420654,
        1 / 0.579346,
    ])

    if args.class_weight:
        if cuda:
            loss_weights = loss_weights.cuda()
            loss_weights_c = loss_weights_c.cuda()
        loss_function = MaskedNLLLoss(loss_weights)
        # The cause / event / pair losses all use loss_weights_c. The CPU
        # branch previously fell back to loss_weights by mistake.
        loss_function_cd = MaskedNLLLoss(loss_weights_c)
        loss_function_em = MaskedNLLLoss(loss_weights_c)
        loss_function_ecpe = MaskedNLLLoss(loss_weights_c)
    else:
        loss_function = MaskedNLLLoss()
        loss_function_cd = MaskedNLLLoss()
        loss_function_em = MaskedNLLLoss()
        loss_function_ecpe = MaskedNLLLoss()

    # params = ([p for p in model.parameters()] + [model.log_var_a] + [model.log_var_b])
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)



    train_loader, valid_loader, test_loader = \
        get_IEMOCAP_loaders(r'/home/btwang/ECPEC-main/Dataset/IEMOCAP_emotion_cause_features.pkl',
                            batch_size=batch_size,
                            valid=0.0,
                            num_workers=2,
                            word_idx=word_id_mapping,
                            max_sen_len=max_sen_len)
    # Learning-rate warmup (disabled)
    # total number of steps
    #num_steps_all = len(train_loader) // args.epochs
    # number of warmup steps
    #warmup_steps = int(num_steps_all * args.warmup_proportion)
    #scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_steps_all)
    best_loss, best_label, best_pred, best_mask,best_f_avg_ecpe = None, None, None, None, None

    for e in range(n_epochs):
        start_time = time.time()
        #train_loss, train_acc, train_acc_c,_, _, _, _, _, _, _, _,train_fscore, train_fscore_c ,sa_pred, cd_pred,_,_,em_pred,train_p_avg_ecpe, train_r_avg_ecpe, train_f_avg_ecpe = train_or_eval_model(model,
        (train_loss, train_acc, train_acc_c, _, _, _, _, _, _, _, _,
         train_fscore, train_fscore_c, train_fscore_md, sa_pred, cd_pred, _, _, em_pred,
         train_p_avg_ecpe, train_r_avg_ecpe, train_f_avg_ecpe,
         train_best_model_wts) = train_or_eval_model(model,
                                                     loss_function,
                                                     loss_function_cd,
                                                     loss_function_em,
                                                     loss_function_ecpe,
                                                     train_loader,
                                                     e,
                                                     prop,
                                                     prop2,
                                                     embedding,
                                                     pos_embedding,
                                                     optimizer,
                                                     True)
        #valid_loss, valid_acc, valid_acc_c,_, _, _, _, _, _, _, _, _,val_fscore, val_fscore_c,_, _, _ = train_or_eval_model(model,
        #                                                                                                  loss_function,
        #                                                                                                  loss_function_cd,
        #                                                                                                  loss_function_em,
        #                                                                                                  loss_function_ecpe,
        #                                                                                                  valid_loader,
        #                                                                                                  e,
        #                                                                                                  prop,
        #                                                                                                  prop2,
        #                                                                                                  embedding)
        #test_loss, test_acc, test_acc_c ,test_labele, test_prede, test_labelc, test_predc, test_mask, test_labelecpe, test_predecpe, test_mask_ecpe,test_fscore, test_fscore_c,saTask, cdTask,test_predem,test_labelem,emTask,test_avg_ecpe, test_r_avg_ecpe, test_f_avg_ecpe = train_or_eval_model(
        (test_loss, test_acc, test_acc_c, test_labele, test_prede, test_labelc, test_predc,
         test_mask, test_labelecpe, test_predecpe, test_mask_ecpe,
         test_fscore, test_fscore_c, test_fscore_md, saTask, cdTask,
         test_predem, test_labelem, emTask,
         test_avg_ecpe, test_r_avg_ecpe, test_f_avg_ecpe,
         test_best_model_wts) = train_or_eval_model(
            model, loss_function, loss_function_cd,loss_function_em,loss_function_ecpe, test_loader, e, prop,prop2, embedding,pos_embedding)

        #if best_loss == None or best_loss > test_loss:
        #    best_loss, best_labele, best_prede, best_labelc, best_predc, best_mask, best_labelecpe, best_predecpe, best_mask_cepe, saTrain, cdTrain,emTrain ,best_sa, best_cd,best_em,best_avg_ecpe, best_r_avg_ecpe, best_f_avg_ecpe = \
        #        test_loss, test_labele, test_prede, test_labelc, test_predc, test_mask,test_labelecpe, test_predecpe, test_mask_ecpe, sa_pred, cd_pred,em_pred, saTask, cdTask,emTask,test_avg_ecpe, test_r_avg_ecpe, test_f_avg_ecpe

        # Keep the results and weights of the epoch with the best pair F1 on
        # the test loader.
        if best_f_avg_ecpe is None or test_f_avg_ecpe > best_f_avg_ecpe:
            #best_loss, best_labele, best_prede, best_labelc, best_predc, best_mask, best_labelecpe, best_predecpe, best_mask_cepe, saTrain, cdTrain,emTrain ,best_sa, best_cd,best_em,best_avg_ecpe, best_r_avg_ecpe, best_f_avg_ecpe = \
            #    test_loss, test_labele, test_prede, test_labelc, test_predc, test_mask,test_labelecpe, test_predecpe, test_mask_ecpe, sa_pred, cd_pred,em_pred, saTask, cdTask,emTask,test_avg_ecpe, test_r_avg_ecpe, test_f_avg_ecpe
            best_loss = test_loss
            best_labele, best_prede = test_labele, test_prede
            best_labelc, best_predc = test_labelc, test_predc
            best_mask = test_mask
            best_predem, best_labelem = test_predem, test_labelem
            best_labelecpe, best_predecpe = test_labelecpe, test_predecpe
            best_mask_cepe = test_mask_ecpe
            saTrain, cdTrain, emTrain = sa_pred, cd_pred, em_pred
            best_sa, best_cd, best_em = saTask, cdTask, emTask
            best_avg_ecpe, best_r_avg_ecpe = test_avg_ecpe, test_r_avg_ecpe
            best_f_avg_ecpe = test_f_avg_ecpe
            best_model_wts = test_best_model_wts
        # No validation pass is run; report zeros in its place.
        valid_loss, valid_acc, val_fscore=0,0,0
        print('epoch {} train_loss {} train_acc {} train_acc_c {} train_fscore{} train_fscore_c {} train_fscore_md {} train_fscore_ecpe {} valid_loss {} valid_acc {} val_fscore {} test_loss {} test_acc {} test_acc_c {} test_fscore {} test_fscore_c {} test_fscore_md {} test_fscore_ecpe {} time {}'. \
            format(e + 1, train_loss, train_acc, train_acc_c, train_fscore, train_fscore_c,train_fscore_md,train_f_avg_ecpe, valid_loss, valid_acc, val_fscore, \
                   test_loss, test_acc, test_acc_c, test_fscore, test_fscore_c,test_fscore_md,test_f_avg_ecpe, round(time.time() - start_time, 2)))

    print('Test performance..', '\n')
    print('eeLoss {} accuracy {}'.format(best_loss,
                                       round(accuracy_score(best_labele, best_prede, sample_weight=best_mask) * 100,
                                             2)),'\n')
    print(classification_report(best_labele, best_prede, sample_weight=best_mask, digits=4),'\n')
    print(confusion_matrix(best_labele, best_prede, sample_weight=best_mask),'\n')

    print('ceLoss {} accuracy {}'.format(best_loss,
                                       round(accuracy_score(best_labelc, best_predc, sample_weight=best_mask) * 100,
                                             2)),'\n')
    print(classification_report(best_labelc, best_predc, sample_weight=best_mask, digits=4),'\n')
    print(confusion_matrix(best_labelc, best_predc, sample_weight=best_mask),'\n')



    print('midLoss {} accuracy {}'.format(best_loss,
                                       round(accuracy_score(best_labelem, best_predem, sample_weight=best_mask) * 100,
                                             2)),'\n')
    print(classification_report(best_labelem, best_predem, sample_weight=best_mask, digits=4),'\n')
    print(confusion_matrix(best_labelem, best_predem, sample_weight=best_mask),'\n')

    #print('Loss {} accuracy {}'.format(best_loss,
    #                                   round(accuracy_score(best_labelecpe, best_predecpe, sample_weight=best_mask_cepe) * 100,
    #                                         2)),'\n')
    #print(classification_report(best_labelecpe, best_predecpe, sample_weight=best_mask_cepe, digits=4),'\n')
    #print(confusion_matrix(best_labelecpe, best_predecpe, sample_weight=best_mask_cepe),'\n')


    print('ecpeLoss {} precision {} recall {} fscore {} '.format(best_loss, best_avg_ecpe, best_r_avg_ecpe, best_f_avg_ecpe))
    path = '/home/btwang/ECPEC-main/pkl_gcn_full/cccn_model.pth'
    torch.save(best_model_wts, path)
    #path_out = '/home/btwang/ECPEC-main/Dataset/IEMOCAP_emotion_cause_features.pkl'
    #videoIDs, videoSpeakers, videoLabels, causeLabels, causeLabels2, causeLabels3, videoText, \
    #videoAudio, videoVisual, videoSentence, trainVid, \
    #testVid = pickle.load(open(path_out, 'rb'), encoding='latin1')
#
#
    #path = '/home/btwang/ECPEC-main/pkl_gcn_full/ECPEC_phase_two_gcn_' + str(args.dropout) + '_' + str(args.proportion) + '_' + str(
    #    args.activation) + 'me1.pkl'
    #f = open(path, 'wb')
    #data = [videoIDs, videoSpeakers, videoLabels, causeLabels, causeLabels2, causeLabels3, saTrain, cdTrain,emTrain, best_sa,
    #        best_cd,best_em, videoText, videoAudio, videoVisual, videoSentence, trainVid, testVid]
    #pickle.dump(data, f)
    #f.close()
