From 21f5a7fcc08e9f1e13d994735fe4d95d8f6d62c2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 27 Sep 2021 14:58:21 +0800 Subject: [PATCH] fix(subgraph): fix device recognition and scalar propagate GitOrigin-RevId: fd2fe8bec9d9730e8689dbad314e54a7ecbc8bde --- imperative/python/megengine/core/tensor/utils.py | 12 ++++++++++-- imperative/python/megengine/functional/nn.py | 1 - imperative/python/megengine/functional/tensor.py | 11 +---------- imperative/python/src/common.cpp | 2 +- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index fdf0e344..a3d76402 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -230,7 +230,7 @@ for name, mode in [ def subgraph( name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False ): - if device.physical_name.startswith("cpu"): + if not device.physical_name.startswith("gpu"): gopt_level = None # disable jit and compile jit_fusion = False @@ -370,7 +370,15 @@ def subgraph_fn( jit_fusion=jit_fusion, custom_grad=custom_grad, )(func) - return lambda *args: apply(op(), *args) + + def wrapped_func(*args): + if custom_grad: + outputs = op()(*args) + else: + outputs = apply(op(), *args) + return outputs + + return wrapped_func else: return interpret_subgraph(func, dtype, device) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 42b108eb..6729bcb8 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -988,7 +988,6 @@ def _get_softplus_op(dtype=None, device=None): device=device, nr_inputs=1, jit_fusion=True, - # gopt_level=0, custom_grad=True, ) def softplus(inputs, f, c): diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 529ba499..1937ed0b 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -18,14 +18,7 @@ from ..core.ops import builtin from ..core.ops.builtin import Copy, Identity from ..core.ops.special import Const from ..core.tensor.array_method import _broadcast, _remove_axis -from ..core.tensor.utils import ( - astensor1d, - convert_inputs, - get_device, - isscalar, - setscalar, - subgraph_fn, -) +from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn from ..device import get_default_device from ..tensor import Tensor from .elemwise import ceil @@ -821,8 +814,6 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: where = _get_where_op(dtype=dtype, device=device) (oup,) = where(mask, x, y) - if isscalar(mask): - setscalar(oup) return oup diff --git a/imperative/python/src/common.cpp b/imperative/python/src/common.cpp index 80483c97..6a9ac154 100644 --- a/imperative/python/src/common.cpp +++ b/imperative/python/src/common.cpp @@ -67,7 +67,7 @@ void init_common(py::module m) { [](const CompNode& cn) { return cn.to_string_logical(); }) .def_property_readonly( "physical_name", - [](const CompNode& cn) { return cn.to_string(); }) + [](const CompNode& cn) { return cn.to_string_physical(); }) .def_property_readonly( "get_mem_status_bytes", [](const CompNode& cn) {