
feat(imperative): region restricted conv support bias in python

GitOrigin-RevId: 9a2c1ee27a
master
Megvii Engine Team · 2 years ago
parent commit c1c1d6d160
3 changed files with 55 additions and 21 deletions
  1. imperative/python/megengine/functional/nn.py (+6, -0)
  2. imperative/python/megengine/module/conv.py (+6, -3)
  3. imperative/python/test/unit/functional/test_functional.py (+43, -18)

imperative/python/megengine/functional/nn.py (+6, -0)

@@ -1980,6 +1980,7 @@ def region_restricted_conv(
     weight: Tensor,
     rin: Tensor,
     rout: Tensor,
+    bias: Optional[Tensor] = None,
     stride: Union[int, Tuple[int, int, int]] = 1,
     padding: Union[int, Tuple[int, int, int]] = 0,
     dilation: Union[int, Tuple[int, int, int]] = 1,
@@ -1994,6 +1995,9 @@ def region_restricted_conv(
     Args:
         inp: feature map of the convolution operation.
         weight: convolution kernel.
+        rin: input mask
+        rout: output mask
+        bias: bias added to the result of convolution (if given).
         stride: stride of the 2D region restricted convolution operation. Default: 1
         padding: size of the paddings added to the input on both sides of its
             spatial dimensions. Only zero-padding is supported. Default: 0
@@ -2027,6 +2031,8 @@ def region_restricted_conv(
         sparse=sparse_type,
     )
     (output,) = apply(op, inp, weight, rin, rout)
+    if bias is not None:
+        output += bias
     return output
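
For readers unfamiliar with the functional API, here is a minimal usage sketch of the extended signature. The shapes mirror the naive test below; the mask shapes (N, IH, IW) for rin and (N, OH, OW) for rout are an assumption, since the diff does not spell them out:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    # Hypothetical shapes: N=1, groups=2, one in/out channel per group,
    # 2x2 input, 2x2 kernel -> 1x1 output.
    src = tensor(np.arange(8, dtype=np.float32).reshape(1, 2, 2, 2))
    # grouped weight layout: (groups, out_ch // groups, in_ch // groups, kh, kw)
    weight = tensor(np.ones((2, 1, 1, 2, 2), dtype=np.float32))
    rin = tensor(np.ones((1, 2, 2), dtype=np.int32))   # assumed input mask (N, IH, IW)
    rout = tensor(np.ones((1, 1, 1), dtype=np.int32))  # assumed output mask (N, OH, OW)
    bias = tensor(np.ones((1, 2, 1, 1), dtype=np.float32))  # broadcasts over N, H, W

    out = F.region_restricted_conv(src, weight, rin, rout, bias=bias, groups=2)
    # the bias is added after the convolution, so each output element is conv + 1

Note that the bias is applied as a plain elementwise add on the op's output, so any tensor broadcastable to the output shape works.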






imperative/python/megengine/module/conv.py (+6, -3)

@@ -1040,6 +1040,7 @@ class RegionRestrictedConv(_ConvNd):
             ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
             and the shape of weight should be ``(groups, out_channel // groups,
             in_channels // groups, height, width)``. Default: 1
+        bias: whether to add a bias onto the result of convolution. Default: True
         conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
         compute_mode: When set to "default", no special requirements will be
             placed on the precision of intermediate results. When set to "float32",
@@ -1071,6 +1072,7 @@ class RegionRestrictedConv(_ConvNd):
         out_channels: int,
         kernel_size: Union[int, Tuple[int, int]],
         groups: int,
+        bias: bool = True,
         stride: Union[int, Tuple[int, int]] = 1,
         padding: Union[int, Tuple[int, int]] = 0,
         dilation: Union[int, Tuple[int, int]] = 1,
@@ -1095,7 +1097,7 @@ class RegionRestrictedConv(_ConvNd):
             0,
             dilation,
             groups,
-            False,
+            bias,
             **kwargs,
         )

@@ -1133,7 +1135,7 @@ class RegionRestrictedConv(_ConvNd):
             (self.padding[1], self.padding[1]),
         )

-    def calc_conv(self, inp, weight, rin, rout):
+    def calc_conv(self, inp, weight, rin, rout, bias):
         assert self.padding_mode in [
             "zeros",
             "reflect",
@@ -1144,6 +1146,7 @@ class RegionRestrictedConv(_ConvNd):
             weight,
             rin,
             rout,
+            bias,
             self.stride,
             self.padding,
             self.dilation,
@@ -1153,4 +1156,4 @@ class RegionRestrictedConv(_ConvNd):
         )

     def forward(self, inp, rin, rout):
-        return self.calc_conv(inp, self.weight, rin, rout)
+        return self.calc_conv(inp, self.weight, rin, rout, self.bias)
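
The module-level counterpart, as a sketch under the same mask-shape assumption; note that the new flag defaults to True, whereas the module previously hard-coded False into the _ConvNd constructor:

    import numpy as np
    import megengine.module as M
    from megengine import tensor

    # bias=True (the new default) makes _ConvNd allocate self.bias, which
    # forward() now threads through to F.region_restricted_conv.
    conv = M.RegionRestrictedConv(
        in_channels=2, out_channels=2, kernel_size=2, groups=2, bias=True
    )
    inp = tensor(np.ones((1, 2, 2, 2), dtype=np.float32))
    rin = tensor(np.ones((1, 2, 2), dtype=np.int32))   # assumed mask shape
    rout = tensor(np.ones((1, 1, 1), dtype=np.int32))  # assumed mask shape
    out = conv(inp, rin, rout)  # convolution result plus self.bias

Existing code that constructs RegionRestrictedConv and relied on the old bias-free behavior should pass bias=False explicitly.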

imperative/python/test/unit/functional/test_functional.py (+43, -18)

@@ -930,7 +930,8 @@ def test_batch_conv_bias():
     run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)


-def test_region_restricted_conv_forward_backward_naive():
+@pytest.mark.parametrize("bias", [True, False])
+def test_region_restricted_conv_forward_backward_naive(bias):
     import megengine as mge
     import megengine.module as M
     from megengine.autodiff import GradManager
@@ -943,15 +944,22 @@ def test_region_restricted_conv_forward_backward_naive():
     cpu_src = tensor(src_1, device=handle)
     cpu_filter = tensor(filter_1, device=handle)
     gm = GradManager().attach([cpu_src, cpu_filter])
+    cpu_bias = (
+        tensor(np.ones((1, 2, 1, 1), dtype=np.float32), device=handle) if bias else None
+    )
     with gm:
         cpu_out = F.region_restricted_conv(
             cpu_src,
             cpu_filter,
             tensor(rin_1, device=handle),
             tensor(rout_1, device=handle),
+            bias=cpu_bias,
             groups=2,
         )
         gm.backward(cpu_out, tensor(np.ones((1, 2, 1, 1)), device=handle))
+    if cpu_bias is not None:
+        cpu_out = cpu_out - cpu_bias
+    np.testing.assert_allclose(cpu_out, np.array([14, 126]).reshape(1, 2, 1, 1))
     np.testing.assert_allclose(
         cpu_src.grad, np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape(1, 2, 2, 2)
     )
@@ -963,7 +971,8 @@ def test_region_restricted_conv_forward_backward_naive():
 @pytest.mark.skipif(
     not is_cuda_available(), reason="rrconv cuda kernel requires cuda available"
 )
-def test_region_restricted_conv_forward_backward_cuda():
+@pytest.mark.parametrize("bias", [True, False])
+def test_region_restricted_conv_forward_backward_cuda(bias):
     import megengine as mge
     import megengine.module as M
     from megengine.autodiff import GradManager
@@ -998,18 +1007,23 @@ def test_region_restricted_conv_forward_backward_cuda():
         filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0")
         rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
         rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
+        bias_cpu = (
+            tensor(np.ones(diff_shape).astype(np.float32), device="cpu0")
+            if bias
+            else None
+        )
         gm = GradManager().attach([src, filter])
         with gm:
             expected_out = F.region_restricted_conv(
-                src, filter, rin, rout, groups=GROUP
+                src, filter, rin, rout, bias=bias_cpu, groups=GROUP
             )
             gm.backward(
                 expected_out,
                 tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"),
             )
-        return src, filter
+        return src, filter, expected_out

-    expected_src, expected_filter = get_groundtruth()
+    expected_src, expected_filter, expected_out = get_groundtruth()

     src = tensor(
         np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
@@ -1018,18 +1032,25 @@ def test_region_restricted_conv_forward_backward_cuda():
     filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle)
     rin = tensor(np.ones(rin_shape).astype(np.int32), device=handle)
     rout = tensor(np.ones(rout_shape).astype(np.int32), device=handle)
+    bias_gpu = (
+        tensor(np.ones(diff_shape).astype(np.float32), device=handle) if bias else None
+    )
     gm = GradManager().attach([src, filter])
     with gm:
-        gpu_out = F.region_restricted_conv(src, filter, rin, rout, groups=GROUP)
+        gpu_out = F.region_restricted_conv(
+            src, filter, rin, rout, bias=bias_gpu, groups=GROUP
+        )
         gm.backward(gpu_out, tensor(np.ones(diff_shape), device=handle))
     np.testing.assert_allclose(src.grad, expected_src.grad)
     np.testing.assert_allclose(filter.grad, expected_filter.grad)
+    np.testing.assert_allclose(gpu_out, expected_out)


 @pytest.mark.skipif(
     not is_cuda_available(), reason="rrconv cuda kernel requires cuda available"
 )
-def test_region_restricted_conv_forward_backward_uint8():
+@pytest.mark.parametrize("bias", [True, False])
+def test_region_restricted_conv_forward_backward_uint8(bias):
     import megengine as mge
     import megengine.module as M
     from megengine.autodiff import GradManager
@@ -1063,18 +1084,23 @@ def test_region_restricted_conv_forward_backward_uint8():
         filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0")
         rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
         rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
+        bias_cpu = (
+            tensor(np.ones(diff_shape).astype(np.float32), device="cpu0")
+            if bias
+            else None
+        )
         gm = GradManager().attach([src, filter])
         with gm:
             expected_out = F.region_restricted_conv(
-                src, filter, rin, rout, groups=GROUP
+                src, filter, rin, rout, bias=bias_cpu, groups=GROUP
             )
             gm.backward(
                 expected_out,
                 tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"),
             )
-        return src, filter
+        return src, filter, expected_out

-    expected_src, expected_filter = get_groundtruth()
+    expected_src, expected_filter, expected_out = get_groundtruth()

     # forward and dgrad/wgrad
     src = tensor(
@@ -1084,23 +1110,22 @@ def test_region_restricted_conv_forward_backward_uint8():
     filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle)
     rin = tensor(np.ones(rin_shape).astype(np.uint8), device=handle)
     rout = tensor(np.ones(rout_shape).astype(np.uint8), device=handle)
+    bias_gpu = (
+        tensor(np.ones(diff_shape).astype(np.float32), device=handle) if bias else None
+    )

     gm = GradManager().attach([src, filter])
     with gm:
-        gpu_out = F.region_restricted_conv(src, filter, rin, rout, groups=GROUP)
+        gpu_out = F.region_restricted_conv(
+            src, filter, rin, rout, bias=bias_gpu, groups=GROUP
+        )
         gm.backward(
             gpu_out, tensor(np.ones(diff_shape, dtype=np.float32), device=handle)
         )
     # assert uint8 gpu result close to cpu result
     np.testing.assert_allclose(src.grad, expected_src.grad)
     np.testing.assert_allclose(filter.grad, expected_filter.grad)
-
-
-def test_region_restricted_conv():
-    test_region_restricted_conv_forward_backward_naive()
-    if is_cuda_available():
-        test_region_restricted_conv_forward_backward_cuda()
-        test_region_restricted_conv_forward_backward_uint8()
+    np.testing.assert_allclose(gpu_out, expected_out)


 def test_conv2d_autocast():
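
One subtlety in the naive test: the reference values [14, 126] are the bias-free ground truth, so when the bias parameter is active the test subtracts the bias back out before comparing. The check is equivalent to this sketch (names reuse the test's locals):

    # bias-free ground truth from the original test
    ref = np.array([14, 126], dtype=np.float32).reshape(1, 2, 1, 1)
    if cpu_bias is not None:
        # conv(...) + bias, minus the bias, must reproduce the bias-free output
        np.testing.assert_allclose(cpu_out - cpu_bias, ref)
    else:
        np.testing.assert_allclose(cpu_out, ref)

With the test_region_restricted_conv wrapper removed, pytest now collects each variant directly via the parametrize marker, yielding test IDs such as test_region_restricted_conv_forward_backward_naive[True].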

