GitOrigin-RevId: fd2fe8bec9
tags/v1.9.0
@@ -230,7 +230,7 @@ for name, mode in [ | |||||
def subgraph( | def subgraph( | ||||
name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False | name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False | ||||
): | ): | ||||
if device.physical_name.startswith("cpu"): | |||||
if not device.physical_name.startswith("gpu"): | |||||
gopt_level = None # disable jit and compile | gopt_level = None # disable jit and compile | ||||
jit_fusion = False | jit_fusion = False | ||||
@@ -370,7 +370,15 @@ def subgraph_fn( | |||||
jit_fusion=jit_fusion, | jit_fusion=jit_fusion, | ||||
custom_grad=custom_grad, | custom_grad=custom_grad, | ||||
)(func) | )(func) | ||||
return lambda *args: apply(op(), *args) | |||||
def wrapped_func(*args): | |||||
if custom_grad: | |||||
outputs = op()(*args) | |||||
else: | |||||
outputs = apply(op(), *args) | |||||
return outputs | |||||
return wrapped_func | |||||
else: | else: | ||||
return interpret_subgraph(func, dtype, device) | return interpret_subgraph(func, dtype, device) | ||||
@@ -988,7 +988,6 @@ def _get_softplus_op(dtype=None, device=None): | |||||
device=device, | device=device, | ||||
nr_inputs=1, | nr_inputs=1, | ||||
jit_fusion=True, | jit_fusion=True, | ||||
# gopt_level=0, | |||||
custom_grad=True, | custom_grad=True, | ||||
) | ) | ||||
def softplus(inputs, f, c): | def softplus(inputs, f, c): | ||||
@@ -18,14 +18,7 @@ from ..core.ops import builtin | |||||
from ..core.ops.builtin import Copy, Identity | from ..core.ops.builtin import Copy, Identity | ||||
from ..core.ops.special import Const | from ..core.ops.special import Const | ||||
from ..core.tensor.array_method import _broadcast, _remove_axis | from ..core.tensor.array_method import _broadcast, _remove_axis | ||||
from ..core.tensor.utils import ( | |||||
astensor1d, | |||||
convert_inputs, | |||||
get_device, | |||||
isscalar, | |||||
setscalar, | |||||
subgraph_fn, | |||||
) | |||||
from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn | |||||
from ..device import get_default_device | from ..device import get_default_device | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
from .elemwise import ceil | from .elemwise import ceil | ||||
@@ -821,8 +814,6 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: | |||||
where = _get_where_op(dtype=dtype, device=device) | where = _get_where_op(dtype=dtype, device=device) | ||||
(oup,) = where(mask, x, y) | (oup,) = where(mask, x, y) | ||||
if isscalar(mask): | |||||
setscalar(oup) | |||||
return oup | return oup | ||||
@@ -67,7 +67,7 @@ void init_common(py::module m) { | |||||
[](const CompNode& cn) { return cn.to_string_logical(); }) | [](const CompNode& cn) { return cn.to_string_logical(); }) | ||||
.def_property_readonly( | .def_property_readonly( | ||||
"physical_name", | "physical_name", | ||||
[](const CompNode& cn) { return cn.to_string(); }) | |||||
[](const CompNode& cn) { return cn.to_string_physical(); }) | |||||
.def_property_readonly( | .def_property_readonly( | ||||
"get_mem_status_bytes", | "get_mem_status_bytes", | ||||
[](const CompNode& cn) { | [](const CompNode& cn) { | ||||