@@ -88,6 +88,7 @@ def test_check_output(self):
             )
 
     def test_check_grad_x(self):
+        check_prim_pir_grad = True
         if self.dtype not in ("uint16", "float16"):
             self.check_grad_with_place(
                 core.CPUPlace(),
@@ -96,7 +97,7 @@ def test_check_grad_x(self):
                 user_defined_grad_outputs=self.out_grad,
                 check_prim=True,
                 only_check_prim=True,
-                check_prim_pir=self.check_prim_pir,
+                check_prim_pir=check_prim_pir_grad,
             )
         if paddle.is_compiled_with_cuda():
             self.check_grad_with_place(
@@ -106,10 +107,11 @@ def test_check_grad_x(self):
                 user_defined_grad_outputs=self.out_grad,
                 check_prim=True,
                 only_check_prim=True,
-                check_prim_pir=self.check_prim_pir,
+                check_prim_pir=check_prim_pir_grad,
             )
 
     def test_check_grad_scale_bias(self):
+        check_prim_pir_grad = True
         if self.data_format == "NCHW" and self.training is False:
             self.enable_cinn = False
         if self.dtype == "float32":
@@ -130,7 +132,7 @@ def test_check_grad_scale_bias(self):
                 user_defined_grad_outputs=self.out_grad,
                 check_prim=True,
                 only_check_prim=True,
-                check_prim_pir=self.check_prim_pir,
+                check_prim_pir=check_prim_pir_grad,
             )
         if paddle.is_compiled_with_cuda():
             self.check_grad_with_place(
@@ -140,7 +142,7 @@ def test_check_grad_scale_bias(self):
                 user_defined_grad_outputs=self.out_grad,
                 check_prim=True,
                 only_check_prim=True,
-                check_prim_pir=self.check_prim_pir,
+                check_prim_pir=check_prim_pir_grad,
             )
 
     def initConfig(self):