Browse Source

fix(subgraph): fix device recognition and scalar propagate

GitOrigin-RevId: fd2fe8bec9
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
21f5a7fcc0
4 changed files with 12 additions and 14 deletions
  1. +10
    -2
      imperative/python/megengine/core/tensor/utils.py
  2. +0
    -1
      imperative/python/megengine/functional/nn.py
  3. +1
    -10
      imperative/python/megengine/functional/tensor.py
  4. +1
    -1
      imperative/python/src/common.cpp

+ 10
- 2
imperative/python/megengine/core/tensor/utils.py View File

@@ -230,7 +230,7 @@ for name, mode in [
def subgraph(
name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False
):
if device.physical_name.startswith("cpu"):
if not device.physical_name.startswith("gpu"):
gopt_level = None # disable jit and compile
jit_fusion = False

@@ -370,7 +370,15 @@ def subgraph_fn(
jit_fusion=jit_fusion,
custom_grad=custom_grad,
)(func)
return lambda *args: apply(op(), *args)

def wrapped_func(*args):
if custom_grad:
outputs = op()(*args)
else:
outputs = apply(op(), *args)
return outputs

return wrapped_func
else:
return interpret_subgraph(func, dtype, device)



+ 0
- 1
imperative/python/megengine/functional/nn.py View File

@@ -988,7 +988,6 @@ def _get_softplus_op(dtype=None, device=None):
device=device,
nr_inputs=1,
jit_fusion=True,
# gopt_level=0,
custom_grad=True,
)
def softplus(inputs, f, c):


+ 1
- 10
imperative/python/megengine/functional/tensor.py View File

@@ -18,14 +18,7 @@ from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
from ..core.ops.special import Const
from ..core.tensor.array_method import _broadcast, _remove_axis
from ..core.tensor.utils import (
astensor1d,
convert_inputs,
get_device,
isscalar,
setscalar,
subgraph_fn,
)
from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
from ..device import get_default_device
from ..tensor import Tensor
from .elemwise import ceil
@@ -821,8 +814,6 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:

where = _get_where_op(dtype=dtype, device=device)
(oup,) = where(mask, x, y)
if isscalar(mask):
setscalar(oup)
return oup




+ 1
- 1
imperative/python/src/common.cpp View File

@@ -67,7 +67,7 @@ void init_common(py::module m) {
[](const CompNode& cn) { return cn.to_string_logical(); })
.def_property_readonly(
"physical_name",
[](const CompNode& cn) { return cn.to_string(); })
[](const CompNode& cn) { return cn.to_string_physical(); })
.def_property_readonly(
"get_mem_status_bytes",
[](const CompNode& cn) {


Loading…
Cancel
Save