Browse Source

fix(mge/functional): fix F.ones when input is a tensor of scalar type

GitOrigin-RevId: 6d01d6b58d
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
0c37a588ca
2 changed files with 8 additions and 1 deletions
  1. +1
    -1
      imperative/python/megengine/functional/tensor.py
  2. +7
    -0
      imperative/python/test/unit/functional/test_functional.py

+ 1
- 1
imperative/python/megengine/functional/tensor.py View File

@@ -108,7 +108,7 @@ def full(shape, value, dtype="float32", device=None):
if device is None:
device = get_default_device()
(x,) = Const(value, dtype=dtype, device=device)()
if len(shape) == 0: # scalar
if shape is (): # scalar.shape
return x
return broadcast_to(x, shape)



+ 7
- 0
imperative/python/test/unit/functional/test_functional.py View File

@@ -739,3 +739,10 @@ def test_cvt_color():
x = tensor(inp)
y = F.img_proc.cvt_color(x, mode="RGB2GRAY")
np.testing.assert_allclose(y.numpy(), out, atol=1e-5)


@pytest.mark.parametrize("val", [2, [2,], [2, 3]])
def test_ones(val):
shp = tensor(val)
np_shp = np.array(val)
np.testing.assert_equal(F.ones(shp), np.ones(np_shp))

Loading…
Cancel
Save