Browse Source

fix(mge/module): fix prelu error when use_symbolic_shape is true

GitOrigin-RevId: 25b9c4d41d
release-1.7
Megvii Engine Team 3 years ago
parent
commit
ac86d64474
2 changed files with 19 additions and 7 deletions
  1. +0
    -6
      imperative/python/megengine/module/activation.py
  2. +19
    -1
      imperative/python/test/unit/module/test_activation.py

+ 0
- 6
imperative/python/megengine/module/activation.py View File

@@ -239,12 +239,6 @@ class PReLU(Module):
self.weight = Parameter(data=[init]) self.weight = Parameter(data=[init])


def forward(self, inputs): def forward(self, inputs):
assert self.weight.shape == (1,) or self.weight.shape == (
1,
int(inputs.shape[1]),
1,
1,
), "invalid weight's shape"
return prelu(inputs, self.weight) return prelu(inputs, self.weight)






+ 19
- 1
imperative/python/test/unit/module/test_activation.py View File

@@ -7,9 +7,11 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np import numpy as np
import pytest


import megengine as mge import megengine as mge
from megengine.module import LeakyReLU
from megengine.jit.tracing import set_symbolic_shape
from megengine.module import LeakyReLU, PReLU




def test_leaky_relu(): def test_leaky_relu():
@@ -21,3 +23,19 @@ def test_leaky_relu():


np_output = np.maximum(0, data) + negative_slope * np.minimum(0, data) np_output = np.maximum(0, data) + negative_slope * np.minimum(0, data)
np.testing.assert_equal(output.numpy(), np_output) np.testing.assert_equal(output.numpy(), np_output)


@pytest.mark.parametrize("shape", [(1, 64, 15, 15), (64,)])
@pytest.mark.parametrize("use_symbolic", [False, True])
def test_prelu(shape, use_symbolic):
old_flag = set_symbolic_shape(use_symbolic)
data = np.random.random(size=shape)

num_channel = 1 if len(shape) == 1 else shape[1]
prelu = PReLU(num_parameters=num_channel, init=0.25)
output = prelu(mge.Tensor(data))

np_output = np.maximum(data, 0) + prelu.weight.numpy() * np.minimum(data, 0)
set_symbolic_shape(old_flag)

np.testing.assert_allclose(output.numpy(), np_output, atol=1e-5)

Loading…
Cancel
Save