Browse Source

fix(imperative): update region restricted conv testcase for bias shape

GitOrigin-RevId: 8bf66e7312
master
Megvii Engine Team 2 years ago
parent
commit
d14ae856ce
1 changed files with 8 additions and 4 deletions
  1. +8
    -4
      imperative/python/test/unit/functional/test_functional.py

+ 8
- 4
imperative/python/test/unit/functional/test_functional.py View File

@@ -1008,7 +1008,7 @@ def test_region_restricted_conv_forward_backward_cuda(bias):
rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
bias_cpu = (
tensor(np.ones(diff_shape).astype(np.float32), device="cpu0")
tensor(np.ones((1, GROUP * OCPG, 1, 1)).astype(np.float32), device="cpu0")
if bias
else None
)
@@ -1033,7 +1033,9 @@ def test_region_restricted_conv_forward_backward_cuda(bias):
rin = tensor(np.ones(rin_shape).astype(np.int32), device=handle)
rout = tensor(np.ones(rout_shape).astype(np.int32), device=handle)
bias_gpu = (
tensor(np.ones(diff_shape).astype(np.float32), device=handle) if bias else None
tensor(np.ones((1, GROUP * OCPG, 1, 1)).astype(np.float32), device=handle)
if bias
else None
)
gm = GradManager().attach([src, filter])
with gm:
@@ -1085,7 +1087,7 @@ def test_region_restricted_conv_forward_backward_uint8(bias):
rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
bias_cpu = (
tensor(np.ones(diff_shape).astype(np.float32), device="cpu0")
tensor(np.ones((1, GROUP * OCPG, 1, 1)).astype(np.float32), device="cpu0")
if bias
else None
)
@@ -1111,7 +1113,9 @@ def test_region_restricted_conv_forward_backward_uint8(bias):
rin = tensor(np.ones(rin_shape).astype(np.uint8), device=handle)
rout = tensor(np.ones(rout_shape).astype(np.uint8), device=handle)
bias_gpu = (
tensor(np.ones(diff_shape).astype(np.float32), device=handle) if bias else None
tensor(np.ones((1, GROUP * OCPG, 1, 1)).astype(np.float32), device=handle)
if bias
else None
)

gm = GradManager().attach([src, filter])


Loading…
Cancel
Save