import torch
from transformers import BartForConditionalGeneration # transformers.src.transfomers.models.
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,


)
import torch.nn.functional as F

import torch
import torch.nn as nn
from bart_classify_model.attention import BertAttention
from transformers import BertConfig

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Build decoder inputs by moving every token one position to the right.

    Column 0 of the result is filled with `decoder_start_token_id`; column ``j``
    (``j >= 1``) receives ``input_ids[:, j - 1]``, so the last input column is
    dropped. Any ``-100`` value left in the result is replaced by `pad_token_id`.

    Raises:
        ValueError: if `pad_token_id` is ``None``.
    """
    shifted = torch.zeros_like(input_ids)
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = input_ids[:, :-1]

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    # Labels may carry -100 placeholders; swap them for the padding id.
    placeholder_positions = shifted.eq(-100)
    return shifted.masked_fill(placeholder_positions, pad_token_id)


class SelfAttentionLayer(nn.Module):
    """
    Attention block with post-LayerNorm residuals.

    Multi-head attention over (q, k, v) is added back onto `q` and normalised,
    then a two-layer ReLU feed-forward network is added back and normalised.

    Args:
        d_model: feature width of the inputs and of the output.
        num_heads: number of attention heads.
        dff: hidden width of the feed-forward network.
        rate: dropout probability used to build `dropout1` / `dropout2`.
            Both modules are created, but `forward` does not apply them.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()

        # batch_first=True: tensors are laid out as (batch, sequence, feature).
        self.mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)

        widen = nn.Linear(d_model, dff)
        activation = nn.ReLU()
        narrow = nn.Linear(dff, d_model)
        self.feed_forward = nn.Sequential(widen, activation, narrow)

        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)

        # Kept as attributes; not used in forward.
        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)

    def forward(self, q, k, v, key_padding_mask=None, attn_mask=None):
        """
        Run one attention + feed-forward step.

        The attention weights returned by the attention module are discarded.
        Returns a tensor with the same shape as `q`.
        """
        attended = self.mha(
            q,
            k,
            v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
        )[0]
        after_attention = self.layernorm1(q + attended)

        transformed = self.feed_forward(after_attention)
        return self.layernorm2(after_attention + transformed)






class classify_model(BartForConditionalGeneration):
    def __init__(self, config, args=None, **kwargs):
        """
        Build the BART backbone plus the token-tagging and topic heads.

        Args:
            config: model configuration. `d_model`, `encoder_attention_heads`
                and `dropout` are read from it to size the extra heads.
            args: optional namespace. `args.cross_attention_layer_num` sets how
                many `SelfAttentionLayer` blocks are placed in `topic_layer`.
                When `args` is ``None`` or lacks that attribute, three layers
                are built, which is the number `predict_topic` indexes
                (`topic_layer[0]` .. `topic_layer[2]`).
            **kwargs: accepted for call-site compatibility; not used here.
        """
        super().__init__(config)

        # Both heads are applied to encoder hidden states, so their input width
        # must follow the backbone instead of being pinned to 1024.
        d_model = config.d_model

        self.class_num = 2
        # Per-token tagger producing 4 logits.
        self.tagging_layer = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(d_model * 4, 4),
        )

        self.config = config
        num_heads = config.encoder_attention_heads
        dff = 2 * d_model
        rate = config.dropout

        # `args` is optional in the signature, so do not dereference it blindly.
        layer_count = getattr(args, "cross_attention_layer_num", None)
        if layer_count is None:
            layer_count = 3
        self.topic_layer = nn.ModuleList(
            [SelfAttentionLayer(d_model, num_heads, dff, rate) for _ in range(layer_count)]
        )

        # Per-token binary classifier applied to the topic-layer output.
        self.topic_classify_layer = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(d_model * 4, 2),
        )

        # Weighted loss for the 4-way tagger: class 0 counts half as much as the others.
        self.preference_one_fn = torch.nn.CrossEntropyLoss(torch.tensor([1.0, 2.0, 2.0, 2.0]))
        # Initialize weights and apply final processing
        self.post_init()

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """
        Assemble the keyword arguments for one decoding step.

        When `past_key_values` is given, only the last column of
        `decoder_input_ids` is forwarded. `input_ids` is always ``None`` in the
        result because `encoder_outputs` is passed instead.

        Raises:
            KeyError: if `path_attention_mask` is not among the extra keyword
                arguments.
        """
        if past_key_values is None:
            step_decoder_ids = decoder_input_ids
        else:
            # A cache is present, so only the newest token has to be decoded.
            step_decoder_ids = decoder_input_ids[:, -1:]

        model_inputs = dict(
            input_ids=None,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            decoder_input_ids=step_decoder_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
        )
        # Model-specific extra input, forwarded unchanged.
        model_inputs["path_attention_mask"] = kwargs["path_attention_mask"]
        return model_inputs
    
    def topic_matrix_decode(
            self,
            topic_list ,
            sentence_id_mask,
    ):
        """
        Turn each topic (a collection of sentence ids) into a position mask.

        For every topic, the entries of `sentence_id_mask` equal to one of the
        topic's sentence ids are counted (a sentence id listed twice counts
        twice), then the first and last entries along the leading dimension
        are forced to 1.

        A topic without sentence ids yields a mask with only those two
        boundary entries set. Previously such a topic either raised
        UnboundLocalError or silently repeated the preceding topic's mask.

        Args:
            topic_list: iterable of topics, each an iterable of sentence ids.
            sentence_id_mask: tensor holding a sentence id per position.

        Returns:
            A list with one nested Python list (from `Tensor.tolist()`) per topic.
        """
        all_topic_attention_mask = []
        for each_topic_struct in topic_list:
            # Fresh zero mask per topic so nothing leaks over from the previous one.
            this_path_attention_mask = torch.zeros_like(sentence_id_mask, dtype=torch.int32)
            for sent_id in each_topic_struct:
                this_path_attention_mask = this_path_attention_mask + (sentence_id_mask == sent_id).int()

            # The boundary positions are always kept.
            this_path_attention_mask[-1] = 1
            this_path_attention_mask[0] = 1

            all_topic_attention_mask.append(this_path_attention_mask.tolist())

        return all_topic_attention_mask

    def topic_decode(
            self,
            predict_topic_sentence_matrix,
    ):
        """
        Convert a 2-D 0/1 relation matrix into a list of topics.

        Each row contributes the list of column indices whose entry equals 1.
        Rows without any 1 are ignored and duplicate rows are merged. A topic
        that is fully contained in another topic is then dropped, which is the
        same rule `delete_replicate` applies.

        The original pruning loop tested whether a later (never shorter) entry
        was contained in an earlier one. After de-duplication and the
        ascending-length sort that can never be true, so nothing was removed.

        Args:
            predict_topic_sentence_matrix: 2-D array-like exposing `.shape`
                and row indexing.

        Returns:
            Topics as lists of ascending column indices, ordered by length and
            then by content so the order is reproducible.
        """
        max_topic_num, _ = predict_topic_sentence_matrix.shape

        # One candidate per row; a set of tuples removes duplicate rows.
        candidates = set()
        for each_topic_idx in range(max_topic_num):
            related_ids = tuple(
                i
                for i, related in enumerate(predict_topic_sentence_matrix[each_topic_idx])
                if related == 1
            )
            if related_ids:
                candidates.add(related_ids)

        # Shortest first; ties are broken by the ids rather than by set iteration order.
        ordered = sorted(candidates, key=lambda ids: (len(ids), ids))
        member_sets = [set(ids) for ids in ordered]

        # Only a later (longer) entry can strictly contain an earlier one.
        pop_list = sorted(
            (
                idx
                for idx, members in enumerate(member_sets)
                if any(members < other for other in member_sets[idx + 1:])
            ),
            reverse=True,
        )
        print("pop_list",pop_list)

        dropped = set(pop_list)
        topic_list = [list(ids) for idx, ids in enumerate(ordered) if idx not in dropped]
        print("topic_list",topic_list)
        return topic_list
    
    def delete_replicate( self, topic_list):
        """
        Drop empty topics and topics that are contained in another topic.

        Topics are ordered by length (shortest first, original order kept
        among equal lengths). A topic is discarded when its members are a
        subset of any topic later in that order, so of several identical
        topics only the last one survives.

        The input list is no longer modified. Previously the empty entries
        were removed from the caller's list in place.

        Args:
            topic_list: list of topics, each a list of hashable sentence ids.

        Returns:
            A new list containing the surviving topic objects, shortest first.
        """
        ordered = sorted((topic for topic in topic_list if topic != []), key=len)
        # Build each membership set once instead of inside the pairwise loop.
        member_sets = [set(topic) for topic in ordered]

        return [
            topic
            for idx, topic in enumerate(ordered)
            if not any(member_sets[idx] <= later for later in member_sets[idx + 1:])
        ]

    def predict_topic(
        self,
        tokenizer,
        input_ids = None,
        attention_mask= None,
        sentence_mask=None,
        reply_as_topic = None,
        tagging_label = None,
        new_topic_struct=None,
        max_sentence_num = None,
        ancient_list = None,

    ):
        """
        Predict topic groups for one input and build one position mask per topic.

        Steps, in order:
          1. Encode `input_ids` and tag every token with `tagging_layer`.
          2. For each sentence id in ``range(max_sentence_num)``, attend from
             the tagged tokens to that sentence's tokens through
             `topic_layer[0..2]`, then classify every token with
             `topic_classify_layer`.
          3. Locate contiguous spans of tag class 1; a sentence belongs to a
             span's topic when at least one span token was classified 1 for it.
          4. Prune the topics with `delete_replicate`, fall back to
             `reply_as_topic` when none remain, and convert them to masks with
             `topic_matrix_decode`.

        Args:
            tokenizer: used for `decode`, `encode`, `mask_token`, `mask_token_id`.
            input_ids: token ids; row 0 is decoded for the debug print.
                Presumably shape (1, seq_len) — TODO confirm.
            attention_mask: forwarded to the encoder.
            sentence_mask: sentence id of each position.
            reply_as_topic: topic list used when no topic is predicted.
            tagging_label: gold tag codes, used only to compute `total_right`.
            new_topic_struct: only printed.
            max_sentence_num: number of sentence ids to consider.
            ancient_list: not used by the active code.

        Returns:
            Tuple ``(all_topic_attention_mask, total_right, topic_list)``: the
            mask tensor built from `topic_matrix_decode`, the count of positions
            whose predicted tag equals the mapped gold tag, and the topic list.
        """

        encoder = self.get_encoder()

        encoder_outputs = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,

            return_dict=True,

        )

        encoder_hidden_state = encoder_outputs[0]
        seq_len, dim = encoder_hidden_state.shape[-2:]

        """  """

        # Per-token tag class (argmax over the tagger logits). `.squeeze()` drops
        # every size-1 dimension — assumes a single example per call, TODO confirm.
        tagged_topic = self.tagging_layer(encoder_hidden_state).argmax(dim=-1) .squeeze()
        # Debug output: untagged positions are multiplied to token id 0 before decoding.
        tag_word = tokenizer.decode( (tagged_topic > 0) * input_ids[0] )
        print("tag_word",tag_word)
        # Map gold codes to classes: (0, 100) -> 1, (100, 200) -> 2, > 200 -> 3, anything else -> 0.
        # NOTE(review): codes exactly equal to 100 or 200 land in class 0 — confirm intended.
        my_tagging_label = ((tagging_label > 0) & (tagging_label < 100)) * 1 + ((tagging_label > 100) & (tagging_label < 200)) * 2 + ((tagging_label > 200) ) * 3
        # Computed but not used below.
        has_mask = (tagging_label > 0 ).int()
        # Number of positions (class 0 included) where prediction and mapped label agree.
        total_right = ((tagged_topic == my_tagging_label ).int()  ).sum()
        print("total_right",total_right)

        # One copy of the encoder states per sentence id. The reshape only succeeds
        # when the encoder output holds a single example.
        encoder_hidden_state = encoder_hidden_state.repeat(max_sentence_num, 1, 1).reshape( max_sentence_num, seq_len, dim )

        # 1 where a token received any non-zero tag, replicated per sentence id.
        tagged_topic_mask = (tagged_topic  > 0).int()
        tagged_topic_mask = tagged_topic_mask.repeat( max_sentence_num, 1).reshape( max_sentence_num, seq_len )
        # Queries: hidden states with untagged positions zeroed.
        q = (encoder_hidden_state * tagged_topic_mask.unsqueeze(dim=-1) )  .reshape( max_sentence_num,seq_len, dim )


        sentence_max_num = max_sentence_num

        special_token = tokenizer.mask_token
        # Second id of the encoded mask token; presumably the first id is a
        # start-of-sequence token added by the tokenizer — verify.
        special_token_id = tokenizer.encode(special_token)[1]
        if special_token_id != tokenizer.mask_token_id:
            print("special_token_id")
        # Row s marks the positions visible for sentence s: the sentence's own tokens,
        # the first and last positions, and every mask-token position.
        all_sentence_mask = torch.zeros( [sentence_max_num , seq_len]).to( encoder_hidden_state.device)
        for sent_num in range(sentence_max_num):
            all_sentence_mask[sent_num] = (sentence_mask == sent_num).int()
            all_sentence_mask[sent_num][0] = 1
            all_sentence_mask[sent_num][-1] = 1
            all_sentence_mask[sent_num] += (input_ids == special_token_id).int().squeeze()
            # Binarise again, since the addition above can produce values greater than 1.
            all_sentence_mask[sent_num] = (all_sentence_mask[sent_num] > 0 ).int()

        # Keys/values: hidden states with positions outside the sentence mask zeroed.
        k = v = (encoder_hidden_state * all_sentence_mask.unsqueeze(dim=-1)) .reshape( max_sentence_num,seq_len, dim )


        # Outer product: entry (i, j) is 1 when query position i is tagged and key
        # position j is visible for the sentence; then tiled once per attention head.
        cross_atten_mask = tagged_topic_mask.unsqueeze(dim=-1) * all_sentence_mask.unsqueeze(dim=-2)
        cross_atten_mask = cross_atten_mask.reshape(-1,seq_len, seq_len)
        cross_atten_mask = cross_atten_mask.repeat(1, self.config.encoder_attention_heads , 1).reshape(-1,seq_len, seq_len)



        # Tagged-to-tagged variant of the mask; built but not used by the active code.
        cross_atten_mask2 = tagged_topic_mask.unsqueeze(dim=-1) * tagged_topic_mask.unsqueeze(dim=-2)
        cross_atten_mask2 = cross_atten_mask2.reshape(-1,seq_len, seq_len)
        cross_atten_mask2 = cross_atten_mask2.repeat(1, self.config.encoder_attention_heads , 1).reshape(-1,seq_len, seq_len)

        # Three attention layers in sequence over the same keys/values; after each one
        # the untagged query rows are zeroed again. Needs at least three entries in
        # `self.topic_layer`.
        # NOTE(review): the mask is passed as a float tensor. nn.MultiheadAttention adds
        # float masks to the attention scores, so a 0/1 mask biases rather than blocks
        # positions — confirm this is intended.
        q = self.topic_layer[0](q,k,v, attn_mask = cross_atten_mask.float() ).reshape(-1 ,seq_len, dim) * tagged_topic_mask.reshape(-1,seq_len).unsqueeze(dim=-1)
        q = self.topic_layer[1](q,k,v, attn_mask = cross_atten_mask.float() ).reshape(-1 ,seq_len, dim) * tagged_topic_mask.reshape(-1,seq_len).unsqueeze(dim=-1)
        q = self.topic_layer[2](q,k,v, attn_mask = cross_atten_mask.float() ).reshape(-1 ,seq_len, dim) * tagged_topic_mask.reshape(-1,seq_len).unsqueeze(dim=-1)


        attention_output = q.reshape(max_sentence_num ,seq_len, dim) * tagged_topic_mask.unsqueeze(dim=-1)


        sentence_tpoic_output = self.topic_classify_layer( attention_output )


        # Row s, column t: predicted class for token t with respect to sentence s.
        # squeeze(dim=0) only has an effect when max_sentence_num == 1.
        predict_topic_sentence_matrix = sentence_tpoic_output.argmax(dim=-1).squeeze(dim=0)

        # Span boundaries of tag class 1: diff == 1 marks a 0 -> 1 step, diff == -1 a 1 -> 0
        # step; adding 1 turns the diff index into the position after the step.
        # NOTE(review): a span that begins at position 0 or runs to the final position
        # produces no matching step, so starts and ends can be paired incorrectly by
        # zip below — confirm inputs rule this out.
        diff = torch.diff((tagged_topic  == 1).int())
        topic_start_point = torch.nonzero(diff == 1)
        topic_start_point = topic_start_point + 1
        print("topic_start_point",topic_start_point)
        topic_end_point = torch.nonzero(diff == -1)
        topic_end_point = topic_end_point + 1

        # Start with every sentence in every span's topic, then remove the sentences
        # for which no token of the span was classified as 1.
        topic_list = [ list(range(max_sentence_num)) for _ in topic_start_point]
        for sentence_id, each_sentence_topic_relation in  enumerate(predict_topic_sentence_matrix):
            for each_start_point_idx , (start_pos, end_pos)  in enumerate( zip(topic_start_point, topic_end_point) ):

                if 1 not in each_sentence_topic_relation[start_pos :  end_pos ]:
                    topic_list[each_start_point_idx ].remove( sentence_id )
        print("topic_list",topic_list)

        # Remove empty topics and topics contained in another topic.
        topic_list = self.delete_replicate(topic_list)

        print(",",topic_list)
        print("", new_topic_struct)


        # Nothing predicted: fall back to the caller-provided topics.
        if topic_list == []:
            print("")
            topic_list = reply_as_topic

        all_topic_attention_mask = self.topic_matrix_decode( topic_list ,sentence_mask )
        all_topic_attention_mask = torch.tensor(all_topic_attention_mask)

        return all_topic_attention_mask, total_right,topic_list

    def forward(
        self,
        input_ids = None,
        attention_mask= None,
        decoder_input_ids = None,
        decoder_attention_mask = None,
        head_mask = None,
        decoder_head_mask = None,
        cross_attn_head_mask = None,
        encoder_outputs = None,
        past_key_values = None,
        inputs_embeds = None,
        decoder_inputs_embeds = None,
        labels = None,
        use_cache = None,
        output_attentions= None,
        output_hidden_states= None,
        return_dict= None,
        speaker_attention_mask=None,
        reply_attention_mask=None,

        dialog_mask = None,
        reply_label=None,

        path_attention_mask=None, 
        tokenizer = None,

        sentence_mask=None,
        all_topic_matrix_label = None,
        tagging_label = None,

        all_ancient_list = None,
        all_sentence_topic_mask = None,

        all_sentence_cross_attention_mask = None,
        all_topic_cross_attention_mask = None,

        tag_ratio = None,
        generation_ratio = None,
        topic_ratio = None,

        is_train = False,
        triplets_num=None,
        ce_weight = None,
    ) :
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict


        # if is_train:        
        #     print("labels",labels)
        if labels is not None:

            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

            encoder = self.get_encoder()
            decoder = self.get_decoder()

            ##########################################################################

            if len(input_ids.shape) == 1:
                input_ids = input_ids.unsqueeze(dim = 0)
                # attention_mask = attention_mask.unsqueeze(dim = 0)

            #copy
            # input_ids = input_ids.repeat(path_attention_mask.shape[0],1).reshape(path_attention_mask.shape[0],input_ids.shape[-1])
            # #attention mask
            # attention_mask = attention_mask.repeat(path_attention_mask.shape[0],1).reshape(path_attention_mask.shape[0],input_ids.shape[-1])

            if decoder_input_ids is None and decoder_inputs_embeds is None:
                if input_ids is None:
                    raise ValueError(
                        "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                        "passed, `input_ids` cannot be `None`. Please pass either "
                        "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                    )

                decoder_input_ids = shift_tokens_right( #teacher forcing
                    input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
                )

            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            use_cache = use_cache if use_cache is not None else self.config.use_cache
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict


            encoder_outputs = encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                # reply_attention_mask=reply_attention_mask,
                # speaker_attention_mask=speaker_attention_mask,    
            )

            if labels is not None:
                encoder_hidden_state = encoder_outputs[0]
                tag_result = self.tagging_layer(encoder_hidden_state)
                prob = 1 - (F.softmax(tag_result,dim = -1)[:,:,0])
                # print("tag_result",tag_result.shape)
                # print("prob",prob.shape)
            # bs, seq_len, dim = encoder_outputs[0].shape
            # decoder_input = encoder_outputs[0].repeat(path_attention_mask.shape[0],1,1).reshape(path_attention_mask.shape[0],seq_len,dim)

            # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
            decoder_outputs = decoder(
                input_ids=decoder_input_ids, #teacher forcing
                attention_mask=decoder_attention_mask,
                encoder_hidden_states=encoder_outputs[0], #decoder_input,
                encoder_attention_mask=path_attention_mask, #path_attention_mask. ，
                head_mask=decoder_head_mask,
                cross_attn_head_mask=cross_attn_head_mask,
                past_key_values=past_key_values,
                inputs_embeds=decoder_inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                prob = prob,
                path_attention_mask = path_attention_mask,
            )

            # if not return_dict:
            #     return decoder_outputs + encoder_outputs


            outputs = Seq2SeqModelOutput(
                last_hidden_state=decoder_outputs.last_hidden_state,
                past_key_values=decoder_outputs.past_key_values,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )


            ####################################################################################################
            lm_logits = self.lm_head(outputs[0])
            # print("lm_logits",lm_logits.shape)
            lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

            # output = lm_logits.argmax(dim=-1)
            # # print("path_attention_mask",path_attention_mask[0])
            # labels[labels[:, :] == -100 ] = tokenizer.pad_token_id
            # for x in range( len ( output ) ):
            #     print("", tokenizer.decode(input_ids[x] ))
            #     # print("", tokenizer.decode( output[x] ))
            #     print("",tokenizer.decode(labels[x]) )
            
            # # for x in labels:
                

            # labels[labels[:, :] == tokenizer.pad_token_id] = -100

            masked_lm_loss = None
            # print("")
            if labels is not None:
                
                labels = labels.to(lm_logits.device)
                loss_fct = torch.nn.CrossEntropyLoss( )

                # loss_fct2 = torch.nn.CrossEntropyLoss(reduction='none')
                # loss_list  = loss_fct2(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)).reshape(labels.shape[0],labels.shape[1] )
                # # print("loss_list",loss_list)
                # # print("loss_list",loss_list.shape)
                # # print("ce_weight",ce_weight.shape)
                # loss1 = (loss_list * ce_weight).sum(dim=-1).sum(dim=-1) / (labels.shape[0] * labels.shape[1]) #
                # # print("loss1",loss1)
                # masked_lm_loss =  loss1
                masked_lm_loss = generation_ratio * loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
                # print("masked_lm_loss",masked_lm_loss)


                

                
                #encoder output
                #.detach().clone()

                #one stage
                """  """
                #，topicloss
                # next_concat_encoder_hidden_state = torch.cat( (encoder_hidden_state, torch.roll(encoder_hidden_state, shifts=-1, dims=-2) ), dim=-1) #[3, 235, 1024]
                # tag_result = self.tagging_layer(next_concat_encoder_hidden_state)
                

                my_tagging_label = ((tagging_label > 0) & (tagging_label < 100)) * 1 + ((tagging_label > 100) & (tagging_label < 200)) * 2 + ((tagging_label > 200) ) * 3 + (tagging_label == -100) * -100
                #             1                   -100
                # print("my_tagging_label",my_tagging_label[0])
                # print("",tokenizer.decode(  (input_ids   * (tagging_label > 0).int())[0] ))
                tag_loss = self.preference_one_fn( tag_result.reshape(-1,4) , my_tagging_label.reshape(-1)) #tagging_label-100，tagging_label > 0

                masked_lm_loss += tag_ratio * tag_loss
                # print("tag_loss",tag_loss.item())
                
                

                #loss
                bs, seq_len , dim = encoder_hidden_state.shape
                # print("all_sentence_topic_mask",all_sentence_topic_mask.shape) #[3, 10, 235]
                # print("encoder_hidden_state",encoder_hidden_state.shape) #[3, 235, 1024]
                
                encoder_hidden_state = encoder_hidden_state.unsqueeze(dim=1).repeat(1, 10,1,1).reshape(bs, 10 ,seq_len, dim) #bs * seq_len * 1024
                #[3, 10 ,235, 1024]
                # print("encoder_hidden_state",encoder_hidden_state.shape)
                # print("all_sentence_cross_attention_mask",all_sentence_cross_attention_mask[0][1])
                # print("all_topic_cross_attention_mask",all_topic_cross_attention_mask[0][1])
                k = v = (encoder_hidden_state * all_sentence_cross_attention_mask.unsqueeze(dim=-1)).reshape(-1, seq_len, dim)
                #one stage 
                q = (encoder_hidden_state * all_topic_cross_attention_mask.unsqueeze(dim=-1) ).reshape(-1, seq_len, dim) #

                # print("all_topic_cross_attention_mask",all_topic_cross_attention_mask.shape)
                # print("all_sentence_cross_attention_mask",all_sentence_cross_attention_mask.shape)
                cross_atten_mask = all_topic_cross_attention_mask.unsqueeze(dim=-1) * all_sentence_cross_attention_mask.unsqueeze(dim=-2)
                # print("cross_atten_mask1",cross_atten_mask.shape)
                cross_atten_mask = cross_atten_mask.reshape(-1,seq_len, seq_len)
                # print("cross_atten_mask2",cross_atten_mask.shape)
                cross_atten_mask = cross_atten_mask.repeat(1, self.config.encoder_attention_heads , 1).reshape(-1,seq_len, seq_len)
                # attention_output = self.topic_layer(q,k,v)

                cross_atten_mask2 = all_topic_cross_attention_mask.unsqueeze(dim=-1) * all_topic_cross_attention_mask.unsqueeze(dim=-2)
                # print("cross_atten_mask1",cross_atten_mask.shape)
                cross_atten_mask2 = cross_atten_mask2.reshape(-1,seq_len, seq_len)
                # print("cross_atten_mask2",cross_atten_mask.shape)
                cross_atten_mask2 = cross_atten_mask2.repeat(1, self.config.encoder_attention_heads , 1).reshape(-1,seq_len, seq_len)
                # for m in self.topic_layer:
                #     q = m(q,k,v, attn_mask = cross_atten_mask.float() ).reshape(-1 ,seq_len, dim) * all_topic_cross_attention_mask.reshape(-1,seq_len).unsqueeze(dim=-1)
                # q = self.topic_layer[0](q,k,v, attn_mask = cross_atten_mask.float() ).reshape(-1 ,seq_len, dim) * all_topic_cross_attention_mask.reshape(-1,seq_len).unsqueeze(dim=-1)
                # q = self.topic_layer[1](q,q,q, attn_mask = cross_atten_mask2.float() ) .reshape(-1 ,seq_len, dim) * all_topic_cross_attention_mask.reshape(-1,seq_len).unsqueeze(dim=-1) #.reshape(-1 ,seq_len, dim) * all_topic_cross_attention_mask.reshape(-1,seq_len).unsqueeze(dim=-1)
                # q = self.topic_layer[2](q,k,v, attn_mask = cross_atten_mask.float() ).reshape(-1 ,seq_len, dim) * all_topic_cross_attention_mask.reshape(-1,seq_len).unsqueeze(dim=-1)
                # q = self.topic_layer[3](q,q,q, attn_mask = cross_atten_mask2.float() ) * all_topic_cross_attention_mask.reshape(-1,seq_len).unsqueeze(dim=-1) #.reshape(-1 ,seq_len, dim) * all_topic_cross_attention_mask.reshape(-1,seq_len).unsqueeze(dim=-1)


                # q = self.topic_layer[0](q,k,v, attn_mask = cross_atten_mask.float() ).reshape(-1 ,seq_len, dim) * all_topic_cross_attention_mask.reshape(-1,seq_len).unsqueeze(dim=-1)
                # q = torch.mean(q.reshape(bs, 10,seq_len,dim), dim=1)
                # q = self.topic_layer[1](q,q,q, key_padding_mask = all_topic_cross_attention_mask[:,0]) 
                # q = q.unsqueeze(dim=1).repeat(1,10,1,1).reshape(bs*10, seq_len, dim)
                # q = self.topic_layer[2](q,k,v, attn_mask = cross_atten_mask.float() ).reshape(-1 ,seq_len, dim) * all_topic_cross_attention_mask.reshape(-1,seq_len).unsqueeze(dim=-1)
                # # q = self.topic_layer[3](q,q,q, attn_mask = cross_atten_mask2.float() ) * all_topic_cross_attention_mask.reshape(-1,seq_len).unsqueeze(dim=-1) #.reshape(-1 ,seq_len, dim) * all_topic_cross_attention_mask.reshape(-1,seq_len).unsqueeze(dim=-1)


                q = self.topic_layer[0](q,k,v, attn_mask = cross_atten_mask.float() ).reshape(-1 ,seq_len, dim) * all_topic_cross_attention_mask.reshape(-1,seq_len).unsqueeze(dim=-1)
                # q = torch.mean(q.reshape(bs, 10,seq_len,dim), dim=1)
                q = self.topic_layer[1](q,k,v, attn_mask = cross_atten_mask.float() ).reshape(-1 ,seq_len, dim) * all_topic_cross_attention_mask.reshape(-1,seq_len).unsqueeze(dim=-1)
                # q = q.unsqueeze(dim=1).repeat(1,10,1,1).reshape(bs*10, seq_len, dim)
                q = self.topic_layer[2](q,k,v, attn_mask = cross_atten_mask.float() ).reshape(-1 ,seq_len, dim) * all_topic_cross_attention_mask.reshape(-1,seq_len).unsqueeze(dim=-1)
                    # q = torch.mean(q.reshape(bs, 10,seq_len,dim), dim=1)
                    
                    # print("q",q.shape)
                    # q = q.unsqueeze(dim=1).repeat(1, 10,1,1).reshape(bs 10 ,seq_len, dim)
                    # print("q",q.shape)
                    # print("all_topic_cross_attention_mask",all_topic_cross_attention_mask.shape)
                # attention_output = self.topic_layer(hidden_states = q, encoder_hidden_states = k)
                # print("attention_output",attention_output)
                #one stage
                attention_output = q.reshape(bs, 10 ,seq_len, dim) * all_topic_cross_attention_mask.unsqueeze(dim=-1) # 

                # print("attention_output",attention_output[0][0])
                sentence_tpoic_output = self.topic_classify_layer( attention_output )

                # print("sentence_tpoic_output",sentence_tpoic_output.argmax(dim=-1)[0][0])
                # print("all_topic_matrix_label",all_topic_matrix_label[0][0])
                # all_topic_matrix_label = all_topic_matrix_label.to(encoder_hidden_state.device)
                # print("all_topic_matrix_label",all_topic_matrix_label.shape)
                # print("topic_output.view(-1, 2)",F.softmax(topic_output[0], dim=-1) )
                # topic_loss  = self.preference_one_fn(sentence_tpoic_output.view(-1, 2), all_topic_matrix_label.view(-1))
                topic_loss  = loss_fct (sentence_tpoic_output.view(-1, 2), all_topic_matrix_label.view(-1))

                # print("topic_loss",topic_loss.item())
                masked_lm_loss += topic_ratio * topic_loss
                




            if not return_dict:
                output = (lm_logits,) + outputs[1:]
                return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            
        

            return Seq2SeqLMOutput(
                loss=masked_lm_loss,
                logits=lm_logits,
                past_key_values=outputs.past_key_values,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                cross_attentions=outputs.cross_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )
        else:
            # # print("testpath_attention_mask",path_attention_mask)
            # #decoder   encoder_outputs
            # use_cache = False
            # if decoder_input_ids is None and decoder_inputs_embeds is None:
            #     decoder_input_ids = shift_tokens_right(
            #         labels, self.config.pad_token_id, self.config.decoder_start_token_id
            #     )

            # decoder = self.get_decoder()

            # ##########################################################################

            # if decoder_input_ids is None and decoder_inputs_embeds is None:
            #     if input_ids is None:
            #         raise ValueError(
            #             "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
            #             "passed, `input_ids` cannot be `None`. Please pass either "
            #             "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
            #         )

            #     decoder_input_ids = shift_tokens_right( #teacher forcing
            #         input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
            #     )

            # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            # output_hidden_states = (
            #     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            # )
            # use_cache = use_cache if use_cache is not None else self.config.use_cache
            # return_dict = return_dict if return_dict is not None else self.config.use_return_dict


            # # print("encoder_outputs",encoder_outputs)

            # encoder_hidden_states = encoder_outputs["last_hidden_state"]
            # decoder_outputs = self.decoder(
            #     input_ids=decoder_input_ids, #teacher forcing
            #     # attention_mask=decoder_attention_mask,
            #     encoder_hidden_states=encoder_hidden_states, #generate 
            #     encoder_attention_mask=path_attention_mask, #path_attention_mask. ，
            #     head_mask=decoder_head_mask,
            #     cross_attn_head_mask=cross_attn_head_mask,
            #     past_key_values=past_key_values,
            #     inputs_embeds=decoder_inputs_embeds,
            #     use_cache=use_cache,
            #     output_attentions=output_attentions,
            #     output_hidden_states=output_hidden_states,
            #     return_dict=return_dict,
            # )

            # # if not return_dict:
            # #     return decoder_outputs + encoder_outputs


            # outputs = Seq2SeqModelOutput(
            #     last_hidden_state=decoder_outputs.last_hidden_state,
            #     past_key_values=decoder_outputs.past_key_values,
            #     decoder_hidden_states=decoder_outputs.hidden_states,
            #     decoder_attentions=decoder_outputs.attentions,
            #     cross_attentions=decoder_outputs.cross_attentions,
            #     encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            #     encoder_hidden_states=encoder_outputs.hidden_states,
            #     encoder_attentions=encoder_outputs.attentions,
            # )


            # ####################################################################################################
            # lm_logits = self.lm_head(outputs[0])
            # # print("lm_logits",lm_logits.shape)
            # lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

            # masked_lm_loss = None
            # if not return_dict:
            #     output = (lm_logits,) + outputs[1:]
            #     return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            
        

            # return Seq2SeqLMOutput(
            #     loss=masked_lm_loss,
            #     logits=lm_logits,
            #     past_key_values=outputs.past_key_values,
            #     decoder_hidden_states=outputs.decoder_hidden_states,
            #     decoder_attentions=outputs.decoder_attentions,
            #     cross_attentions=outputs.cross_attentions,
            #     encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            #     encoder_hidden_states=outputs.encoder_hidden_states,
            #     encoder_attentions=outputs.encoder_attentions,
            # )


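            # Run the tagging layer on the encoder output and pass prob = 1 - softmax(tag_result)[..., 0] to self.model.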
            encoder_hidden_state = encoder_outputs[0]
            tag_result = self.tagging_layer(encoder_hidden_state)
            prob = 1 - F.softmax(tag_result,dim = -1)[:,:,0]
            # prob = prob.unsqueeze(dim=0)
            # print("path_attention_mask",path_attention_mask)
            outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                encoder_outputs=encoder_outputs,
                decoder_attention_mask=decoder_attention_mask,
                head_mask=head_mask,
                decoder_head_mask=decoder_head_mask,
                cross_attn_head_mask=cross_attn_head_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                decoder_inputs_embeds=decoder_inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                # reply_attention_mask=reply_attention_mask,
                # speaker_attention_mask=speaker_attention_mask,
                path_attention_mask = path_attention_mask,
                prob = prob,
            )

            lm_logits = self.lm_head(outputs[0])
            lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

            if not return_dict:
                output = (lm_logits,) + outputs[1:]
                return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
                
            

            return Seq2SeqLMOutput(
                loss=None,
                logits=lm_logits,
                past_key_values=outputs.past_key_values,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                cross_attentions=outputs.cross_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )