From 9779bc7f6d883e99efdc68511217942e122d1706 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 2 Mar 2022 11:20:58 +0800 Subject: [PATCH] fix(imperative): allow rng op infer shape fallible GitOrigin-RevId: 687844500cc2cab18de576b1484215c72329e4b8 --- .../python/test/unit/functional/test_functional.py | 4 +++- imperative/src/impl/ops/rng.cpp | 25 ++++++++++++++++------ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index a1bd6553..1f948f4d 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -71,7 +71,8 @@ def test_dropout(): with gm: out = F.nn.dropout(data, rate, training=True) gm.backward(out, tensor(np.ones(shape, dtype=np.float32))) - assert not out.numpy().all() + if len(shape) != 0: + assert not out.numpy().all() np.testing.assert_allclose(out.numpy(), data.grad.numpy(), 1e-7, 1e-7) def test_multiple_dropout(shape, rate): @@ -99,6 +100,7 @@ def test_dropout(): out4 = F.nn.dropout(data, rate, training=True) assert not (out1.numpy() == out4.numpy()).all() + test_dropout_with_shape([], 0.4) test_dropout_with_shape([13, 17, 63, 21], 0.4) test_dropout_with_shape([16, 32, 64], 0.3) test_multiple_dropout([1024], 0.2) diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp index 1f969f43..a59e44f8 100644 --- a/imperative/src/impl/ops/rng.cpp +++ b/imperative/src/impl/ops/rng.cpp @@ -559,25 +559,33 @@ std::tuple, bool> infer_output_attrs_fallible( } dest.comp_node = inputs[0].comp_node; dest.layout = _InferLayout::do_infer(inputs[0], xxx_rng_def); - return {{dest}, true}; + return {{dest}, inputs[0].layout.ndim != 0}; } template <> std::tuple, bool> infer_output_attrs_fallible< ShuffleRNG>(const OpDef& def, const SmallVector& inputs) { + bool success = inputs[0].layout.ndim != 0; + SmallVector dests(2); dests[0].comp_node = inputs[0].comp_node; dests[0].layout = TensorLayout(inputs[0].layout); dests[0].layout.dtype = inputs[0].layout.dtype; dests[1].comp_node = inputs[0].comp_node; - dests[1].layout = - TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32()); - return {dests, true}; + if (success) { + dests[1].layout = + TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32()); + } else { + dests[1].layout = TensorLayout(dtype::Int32()); + } + return {dests, success}; } template <> std::tuple, bool> infer_output_attrs_fallible( const OpDef& op, const SmallVector& inputs) { + bool success = inputs[0].layout.ndim != 0; + SmallVector dests(2); auto cn = inputs[0].comp_node; dests[0].comp_node = cn; @@ -590,8 +598,13 @@ std::tuple, bool> infer_output_attrs_fallible