import tqdm
import torch
from torch import nn
from torch import optim

from models import KBCModel
from regularizers import Regularizer


class KBCOptimizer:
    """Run training epochs for a knowledge-base-completion model.

    Every ``RULE_EPOCH_PERIOD``-th epoch (after the first ones) is a "rule
    epoch": the per-example fitting loss is multiplied by the last column of
    each example before being averaged. All other epochs use the plain mean
    cross-entropy loss.
    """

    # A rule epoch is any epoch ``e`` with ``e % RULE_EPOCH_PERIOD == 0`` and
    # ``e > RULE_EPOCH_MIN_EXCLUSIVE`` (i.e. 5, 10, 15, ... with the defaults).
    RULE_EPOCH_PERIOD = 5
    RULE_EPOCH_MIN_EXCLUSIVE = 1

    def __init__(
            self, model: "KBCModel", regularizer: list, optimizer: optim.Optimizer, batch_size: int = 256,
            verbose: bool = True
    ):
        """Store the training components.

        Args:
            model: Model whose ``forward`` returns ``(predictions, factors)``.
            regularizer: Sequence of regularizers. The first one is applied to
                the factors in :meth:`epoch`; the second one, when present, is
                stored as ``regularizer2`` (``None`` otherwise).
            optimizer: Optimizer that updates the model parameters.
            batch_size: Number of examples per optimisation step.
            verbose: When ``False`` the progress bar is disabled.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            IndexError: If ``regularizer`` is empty.
        """
        if batch_size < 1:
            # A zero batch size would make the batching loop spin forever.
            raise ValueError(f'batch_size must be a positive integer, got {batch_size!r}')

        self.model = model
        self.regularizer = regularizer[0]
        # Only the first regularizer is used by epoch(); tolerate callers that
        # supply a single one instead of failing on the missing second entry.
        self.regularizer2 = regularizer[1] if len(regularizer) > 1 else None
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.verbose = verbose

    @classmethod
    def _is_rule_epoch(cls, e: int) -> bool:
        """Return True when epoch ``e`` weights the loss by the last example column."""
        return e % cls.RULE_EPOCH_PERIOD == 0 and e > cls.RULE_EPOCH_MIN_EXCLUSIVE

    def _model_device(self) -> torch.device:
        """Return the device of the model's first parameter, falling back to CUDA."""
        parameters = getattr(self.model, 'parameters', None)
        if callable(parameters):
            first_parameter = next(iter(parameters()), None)
            if first_parameter is not None:
                return first_parameter.device
        # No parameters to inspect: keep the historical behaviour of moving
        # every batch to the current CUDA device.
        return torch.device('cuda')

    def epoch(self, examples: torch.LongTensor, e: int = 0, weight=None) -> torch.Tensor:
        """Train the model for one pass over ``examples``.

        Args:
            examples: 2-D tensor with one example per row. Column 2 is used as
                the target class. On rule epochs the last column multiplies the
                per-example loss (presumably a rule confidence score; verify
                against the data pipeline).
            e: Index of the current epoch; selects between the plain and the
                rule-weighted loss.
            weight: Optional per-class weight forwarded to
                ``nn.CrossEntropyLoss``.

        Returns:
            The total loss (fit + regularisation) of the last processed batch.

        Raises:
            ValueError: If ``examples`` contains no rows.
        """
        n_examples = examples.shape[0]
        if n_examples == 0:
            # Without this guard the method would fail at ``return`` with an
            # UnboundLocalError because no batch ever defines the loss.
            raise ValueError('epoch() requires at least one training example')

        self.model.train()
        shuffled = examples[torch.randperm(n_examples), :]

        # Rule epochs need the un-reduced loss so that each example can be
        # scaled individually. Earlier experiments (shifting the scale by a
        # constant, or clamping it to at most 1) were dropped in favour of the
        # plain multiplication used below.
        rule_epoch = self._is_rule_epoch(e)
        criterion = nn.CrossEntropyLoss(
            reduction='none' if rule_epoch else 'mean', weight=weight
        )
        device = self._model_device()

        with tqdm.tqdm(total=n_examples, unit='ex', disable=not self.verbose) as bar:
            bar.set_description('train loss')
            for b_begin in range(0, n_examples, self.batch_size):
                input_batch = shuffled[b_begin:b_begin + self.batch_size].to(device)

                predictions, factors = self.model.forward(input_batch.long())
                truth = input_batch[:, 2].long()

                l_fit = criterion(predictions, truth)
                if rule_epoch:
                    # Scale each example's loss by its last column, then average.
                    l_fit = torch.mean(l_fit * input_batch[:, -1])

                l_reg = self.regularizer.forward(factors)
                loss = l_fit + l_reg

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                bar.update(input_batch.shape[0])
                bar.set_postfix(loss=f'{loss.item():.1f}', reg=f'{l_reg.item():.1f}')

        return loss
