Browse Source

fix(imperative): allow rng op infer shape fallible

GitOrigin-RevId: 687844500c
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
9779bc7f6d
2 changed files with 22 additions and 7 deletions
  1. +3
    -1
      imperative/python/test/unit/functional/test_functional.py
  2. +19
    -6
      imperative/src/impl/ops/rng.cpp

+ 3
- 1
imperative/python/test/unit/functional/test_functional.py View File

@@ -71,7 +71,8 @@ def test_dropout():
with gm:
out = F.nn.dropout(data, rate, training=True)
gm.backward(out, tensor(np.ones(shape, dtype=np.float32)))
assert not out.numpy().all()
if len(shape) != 0:
assert not out.numpy().all()
np.testing.assert_allclose(out.numpy(), data.grad.numpy(), 1e-7, 1e-7)

def test_multiple_dropout(shape, rate):
@@ -99,6 +100,7 @@ def test_dropout():
out4 = F.nn.dropout(data, rate, training=True)
assert not (out1.numpy() == out4.numpy()).all()

test_dropout_with_shape([], 0.4)
test_dropout_with_shape([13, 17, 63, 21], 0.4)
test_dropout_with_shape([16, 32, 64], 0.3)
test_multiple_dropout([1024], 0.2)


+ 19
- 6
imperative/src/impl/ops/rng.cpp View File

@@ -559,25 +559,33 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}
dest.comp_node = inputs[0].comp_node;
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
return {{dest}, true};
return {{dest}, inputs[0].layout.ndim != 0};
}

template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
ShuffleRNG>(const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
bool success = inputs[0].layout.ndim != 0;

SmallVector<LogicalTensorDesc> dests(2);
dests[0].comp_node = inputs[0].comp_node;
dests[0].layout = TensorLayout(inputs[0].layout);
dests[0].layout.dtype = inputs[0].layout.dtype;
dests[1].comp_node = inputs[0].comp_node;
dests[1].layout =
TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32());
return {dests, true};
if (success) {
dests[1].layout =
TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32());
} else {
dests[1].layout = TensorLayout(dtype::Int32());
}
return {dests, success};
}

template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dropout>(
const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
bool success = inputs[0].layout.ndim != 0;

SmallVector<LogicalTensorDesc> dests(2);
auto cn = inputs[0].comp_node;
dests[0].comp_node = cn;
@@ -590,8 +598,13 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dro
inputs[0].layout);
};
dests[1].comp_node = cn;
dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
return {dests, true};
if (success) {
dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
} else {
dests[1].layout = TensorLayout(dtype::Byte());
}

return {dests, success};
}

template <typename Op>


Loading…
Cancel
Save