Browse Source

fix(mge/core): avoid create RawTensor with zero-stride numpy ndarray

GitOrigin-RevId: a9b2940bdc
release-1.1
Megvii Engine Team 4 years ago
parent
commit
b80fade3f9
2 changed files with 11 additions and 0 deletions
  1. +2
    -0
      imperative/python/megengine/core/tensor/raw_tensor/__init__.py
  2. +9
    -0
      imperative/python/test/unit/functional/test_functional.py

+ 2
- 0
imperative/python/megengine/core/tensor/raw_tensor/__init__.py View File

@@ -100,6 +100,8 @@ def _(data: DeviceTensorND):
@as_raw_tensor.register(np.ndarray)
def _(array: np.ndarray, dtype=None, device=None):
device = None if device is None else as_device(device).to_c()
if 0 in array.strides:
array = array.squeeze().reshape(array.shape)
return RawTensor(put(array, dtype=dtype, device=device))




+ 9
- 0
imperative/python/test/unit/functional/test_functional.py View File

@@ -458,6 +458,15 @@ def test_conv_bias():
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")


def test_zero_stride_numpy_array():
inp = np.random.randn(3, 224, 224).astype(np.float32)
inp = inp[np.newaxis, :]

inp = tensor(inp, dtype=np.float32)
weight = tensor(np.random.randn(16, 3, 3, 3), dtype=np.float32)
out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)


def test_condtake():
x = np.array([[1, 2, 3], [4, 5, 6]])
y = np.array([[True, False, True], [False, True, True]])


Loading…
Cancel
Save