|
|
@@ -253,6 +253,19 @@ def test_expand_dims(is_varnode): |
|
|
|
np.testing.assert_equal(y, yy.numpy()) |
|
|
|
|
|
|
|
|
|
|
|
def test_expand_dims_for_scalar(): |
|
|
|
x = np.array(1, dtype="float32") |
|
|
|
xx = make_tensor(x, None) |
|
|
|
for axis in [0, -1, (0, 1), (-1, -2), (0, -1)]: |
|
|
|
y = np.expand_dims(x, axis) |
|
|
|
yy = F.expand_dims(xx, axis) |
|
|
|
np.testing.assert_equal(y, yy.numpy()) |
|
|
|
|
|
|
|
for axis in [1, -2, (1, 2), (-2, -3)]: |
|
|
|
np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis) |
|
|
|
np.testing.assert_raises(AssertionError, F.expand_dims, xx, axis) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("is_varnode", [True, False]) |
|
|
|
def test_elemwise_dtype_promotion(is_varnode): |
|
|
|
if is_varnode: |
|
|
|