From dd39265e9510e7e8b8d3e92ea1f5034c9416af9b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 7 Sep 2020 00:25:39 +0800 Subject: [PATCH] fix(mgb/dtype): enable TypeCvt for bool when trace(symbolic=True) GitOrigin-RevId: 4e0fc63369b623e6e9e9eca396ec03f87b56452f --- imperative/python/test/unit/test_tracing.py | 13 +++++++++++++ src/core/impl/dtype.cpp | 1 + 2 files changed, 14 insertions(+) diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index 722b582b..9a55ae52 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -16,6 +16,7 @@ import megengine import megengine.core.tensor.megbrain_graph as G import megengine.module as M from megengine import cgtools, tensor +from megengine.core._trace_option import set_tensor_shape from megengine.core.ops import builtin as ops from megengine.core.tensor import megbrain_graph as G from megengine.core.tensor.core import apply @@ -274,3 +275,15 @@ def test_optimize_for_inference(): res = G.load_comp_graph_from_file(out) computing_input = res.output_vars_list[0].owner.inputs[0] assert computing_input.dtype == np.float16 + + +def test_trace_cvt_bool(): + set_tensor_shape(True) + x = tensor([0], dtype=np.int32) + + @trace(symbolic=True) + def f(x): + return x.shape[0] == 0 + + for i in range(3): + np.testing.assert_equal(f(x).numpy()[0], False) diff --git a/src/core/impl/dtype.cpp b/src/core/impl/dtype.cpp index a91a9f9f..6c8e9c9b 100644 --- a/src/core/impl/dtype.cpp +++ b/src/core/impl/dtype.cpp @@ -136,6 +136,7 @@ void mgb::static_cast_dtype(T* dest, DType src_type, const void* storage, nr_elem, src_type); MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb #define cb(_name, _bits) \ case DTypeTrait::enumv: \