diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 3072837a..4218e52d 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -108,7 +108,7 @@ def full(shape, value, dtype="float32", device=None): if device is None: device = get_default_device() (x,) = Const(value, dtype=dtype, device=device)() - if len(shape) == 0: # scalar + if shape is (): # scalar.shape return x return broadcast_to(x, shape) diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index dc43cd45..bd8b016b 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -739,3 +739,10 @@ def test_cvt_color(): x = tensor(inp) y = F.img_proc.cvt_color(x, mode="RGB2GRAY") np.testing.assert_allclose(y.numpy(), out, atol=1e-5) + + +@pytest.mark.parametrize("val", [2, [2,], [2, 3]]) +def test_ones(val): + shp = tensor(val) + np_shp = np.array(val) + np.testing.assert_equal(F.ones(shp), np.ones(np_shp))