GitOrigin-RevId: 1c728d6ab9
release-1.10
@@ -50,36 +50,36 @@ class autocast:
self._origin_enabled = None
self._origin_high = None
self._origin_low = None
self._origin_compute_mode = None
self._origin_configs = None
def __enter__(self):
self._origin_enabled = amp._enabled
amp._enabled = self.enabled
amp._set_amp_dtype_autocast(self.enabled)
if not self.enabled:
return
if self.enabled:
self._origin_enabled = amp._enabled
self._origin_high = amp._get_amp_high_prec_dtype()
self._origin_low = amp._get_amp_low_prec_dtype()
amp._enabled = self.enabled
amp._set_amp_dtype_autocast(self.enabled)
amp._set_amp_high_prec_dtype(self.high_prec_dtype)
amp._set_amp_low_prec_dtype(self.low_prec_dtype)
self._origin_high = amp._get_amp_high_prec_dtype()
self._origin_low = amp._get_amp_low_prec_dtype()
amp._set_amp_high_prec_dtype(self.high_prec_dtype)
amp._set_amp_low_prec_dtype(self.low_prec_dtype)
self._origin_configs = _config._reset_execution_config(compute_mode="float32")
self._origin_configs = _config._reset_execution_config(
compute_mode="float32"
)
def __exit__(self, *args):
amp._enabled = self._origin_enabled
amp._set_amp_dtype_autocast(self._origin_enabled)
if not self.enabled:
return
amp._set_amp_high_prec_dtype(self._origin_high)
amp._set_amp_low_prec_dtype(self._origin_low)
if self.enabled:
amp._enabled = self._origin_enabled
amp._set_amp_dtype_autocast(self._origin_enabled)
amp._set_amp_high_prec_dtype(self._origin_high)
amp._set_amp_low_prec_dtype(self._origin_low)
_config._reset_execution_config(*self._origin_compute_mode)
def __call__(self, func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not self.enabled:
return func(*args, **kwargs)
with self:
return func(*args, **kwargs)
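The enabled flag now gates the whole dtype/config save-and-restore in __enter__/__exit__, and __call__ lets the same object act as a decorator. A minimal usage sketch, assuming the public megengine.amp.autocast API of this release:

import numpy as np
import megengine.functional as F
from megengine import amp, tensor

x = tensor(np.random.randn(1, 2, 3, 4).astype("float32"))

# Context-manager form: amp dtypes and the float32 compute_mode override are
# saved on __enter__ and restored on __exit__ only when enabled is True.
with amp.autocast(enabled=True):
    y = F.relu(x)

# Decorator form: __call__ above wraps the function body in the same context.
@amp.autocast(enabled=True)
def forward(inp):
    return F.relu(inp)

out = forward(x)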
@@ -10,6 +10,7 @@ from copy import deepcopy
from .. import functional as F
from ..module import Module
from ..tensor import Tensor
from ..core import _config
def _is_nchw_format(param: Tensor):
@@ -26,10 +27,12 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
else:
raise ValueError("Unsupport tensor ndim {}".format(x.ndim))
# TODO: use initialization from tensor after fixing format setting
if inplace:
x[...] = Tensor(x.numpy().transpose(*pattern), format="nhwc")
else:
x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
if x.format != "nhwc":
if inplace:
data = x.numpy().transpose(*pattern)
x[...] = Tensor(data, format="nhwc")
else:
x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
return x
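convert_tensor_format now skips tensors that are already NHWC before transposing with the chosen pattern. A pure-numpy sketch of the layout permutations; the 4D and 5D patterns here are taken from the FormatTransformation::to hunk further down ({0, 2, 3, 1} for NCHW to NHWC, and the inverse of the {0, 1, 4, 2, 3} fast path for 5D group-conv weights), so treat the helper name as illustrative:

import numpy as np

def nchw_to_nhwc_pattern(ndim):
    if ndim == 4:
        return (0, 2, 3, 1)
    elif ndim == 5:
        # inverse of the {0, 1, 4, 2, 3} NHWC->NCHW fast path used for
        # 5D group-conv weights in the C++ hunk below
        return (0, 1, 3, 4, 2)
    raise ValueError("Unsupport tensor ndim {}".format(ndim))

x = np.arange(24).reshape(1, 2, 3, 4)          # NCHW activation
assert x.transpose(*nchw_to_nhwc_pattern(x.ndim)).shape == (1, 3, 4, 2)

w = np.ones((2, 2, 2, 4, 4))                   # group-conv weight (G, O/G, I/G, H, W)
assert w.transpose(*nchw_to_nhwc_pattern(w.ndim)).shape == (2, 2, 4, 4, 2)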
@@ -144,7 +144,9 @@ class GradScaler:
def _check_gradients(self, grads, scale):
if len(grads) == 0:
return False
return _check_non_finite(grads, scale)
rst = _check_non_finite(grads, scale)
rst = rst.numpy()
return rst
def update(self, new_scale: float = None):
r"""Update the scale factor according to whether encountered overflow grad.
@@ -182,7 +182,6 @@ def _reset_execution_config(
deterministic_kernel=None,
async_level=None,
compute_mode=None,
bn_format=None,
auto_format_convert=None,
):
global _benchmark_kernel, _deterministic_kernel, __compute_mode
@@ -234,11 +233,11 @@ def _override(
def train():
"""
orig_flags = _reset_execution_config(
benchmark_kernel,
deterministic_kernel,
async_level,
compute_mode,
auto_format_convert,
benchmark_kernel=benchmark_kernel,
deterministic_kernel=deterministic_kernel,
async_level=async_level,
compute_mode=compute_mode,
auto_format_convert=auto_format_convert,
)
try:
yield
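The call into _reset_execution_config switches to keyword arguments, so removing the bn_format parameter cannot silently shift the remaining positional values. A minimal stand-in for the same save/override/restore pattern; the dict and helper names are hypothetical, not megengine internals:

import contextlib

_state = {"benchmark_kernel": False, "compute_mode": "default",
          "auto_format_convert": True}

@contextlib.contextmanager
def override(**kwargs):
    # save only the fields being overridden, skip None (meaning "keep current")
    saved = {k: _state[k] for k in kwargs}
    _state.update({k: v for k, v in kwargs.items() if v is not None})
    try:
        yield
    finally:
        _state.update(saved)

with override(compute_mode="float32"):
    assert _state["compute_mode"] == "float32"
assert _state["compute_mode"] == "default"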
@@ -64,7 +64,9 @@ class Grad:
continue
grad.suppress()
print("before backward")
self._impl.backward(ys, dys)
print("after backward")
for grad in group:
if grad is self:
@@ -24,6 +24,7 @@ from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from .._imperative_rt.ops import jit_supported
from .._wrap import as_device
from ..autodiff.grad import Function
from .. import _config
from ..ops import builtin
from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype
from .dtype import is_dtype_equal, is_quantize
@@ -1226,12 +1226,16 @@ def batch_norm(
bias = make_full_if_none(bias, 0)
if not training:
op = builtin.BatchNorm(fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps)
op = builtin.BatchNorm(
fwd_mode=BatchNorm.FwdMode.INFERENCE, param_dim="dim_1c11", epsilon=eps
)
ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
return ret
else:
op = builtin.BatchNorm(avg_factor=1 - momentum, epsilon=eps)
op = builtin.BatchNorm(
avg_factor=1 - momentum, param_dim="dim_1c11", epsilon=eps
)
if has_mean or has_var:
running_mean = make_full_if_none(running_mean, 0)
running_var = make_full_if_none(running_var, 1)
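Both branches now pass param_dim="dim_1c11", i.e. per-channel statistics kept in a broadcastable (1, C, 1, 1) layout (the name suggests that shape; hedged). A pure-numpy sketch of the inference-mode arithmetic this layout implies for an NCHW input:

import numpy as np

x = np.arange(24, dtype="float32").reshape(1, 2, 3, 4)   # NCHW input
mean = x.mean(axis=(0, 2, 3), keepdims=True)             # dim_1c11 layout: (1, 2, 1, 1)
var = x.var(axis=(0, 2, 3), keepdims=True)
weight = np.ones((1, 2, 1, 1), dtype="float32")
bias = np.zeros((1, 2, 1, 1), dtype="float32")
eps = 1e-5
out = (x - mean) / np.sqrt(var + eps) * weight + bias    # broadcasts per channel
assert out.shape == (1, 2, 3, 4)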
@@ -19,7 +19,6 @@ class _BatchNorm(Module):
affine=True,
track_running_stats=True,
freeze=False,
param_dim="dim_1c11",
**kwargs
):
super(_BatchNorm, self).__init__(**kwargs)
@@ -30,7 +29,6 @@ class _BatchNorm(Module):
self.track_running_stats = track_running_stats
self._track_running_stats_saved = track_running_stats
self.freeze = freeze
self.param_dim = param_dim
if self.freeze:
assert (
self._track_running_stats_saved
@@ -104,7 +102,6 @@ class _BatchNorm(Module):
or ((self.running_mean is None) and (self.running_var is None)),
momentum=exponential_average_factor,
eps=self.eps,
param_dim=self.param_dim,
)
return output
@@ -8,6 +8,7 @@ from typing import Union
import numpy as np
from ..core import _config
from ..core._imperative_rt.core2 import (
get_auto_format_convert,
pop_scope,
@@ -96,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
"optimizer can only optimize Parameters, but one of the params is "
+ str(type(param))
)
param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
param._reset(Tensor(param, no_cache=True))
for name, default in self._defaults.items():
if default is required and name not in param_group:
@@ -119,10 +120,11 @@ class Optimizer(metaclass=ABCMeta):
def _add_state(self, param, state_name, initializer=None):
if initializer is None:
initializer = np.zeros(param.shape, dtype=np.float32)
with _config._override(auto_format_convert=False):
initializer = np.zeros(param.shape, dtype=np.float32)
state_dict = self._state.setdefault(param, {})
assert state_name not in state_dict
state = Tensor(initializer, no_cache=True)
state = Tensor(initializer, no_cache=True, format=param.format)
state_dict[state_name] = state
@abstractmethod
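The optimizer state is now sized from the layout-independent shape (read with auto_format_convert disabled) and created in the parameter's own format. A hedged sketch reusing only APIs that appear elsewhere in this diff (Tensor(..., format=...), config._override, param.format):

import numpy as np
from megengine import Tensor, config

data = np.ones((1, 2, 3, 4), dtype="float32")                 # logical NCHW
param = Tensor(data.transpose(0, 2, 3, 1), format="nhwc")     # stored as NHWC

with config._override(auto_format_convert=False):
    # raw, un-converted shape, matching what the state tensor must hold
    init = np.zeros(param.shape, dtype=np.float32)

state = Tensor(init, no_cache=True, format=param.format)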
@@ -5,6 +5,7 @@ from typing import Iterable, Union
from ..functional.inplace import _inplace_add_
from ..tensor import Parameter, tensor
from .optimizer import Optimizer
from ..core import _config
class SGD(Optimizer):
@@ -10,7 +10,7 @@ import pytest
import megengine.functional as F
import megengine.module as M
from megengine import Parameter, Tensor, amp, tensor
from megengine import Parameter, Tensor, amp, config
class MyModule(M.Module):
@@ -39,6 +39,22 @@ class MyModule(M.Module):
@pytest.mark.parametrize("is_inplace", [False, True])
def test_convert_module(is_inplace):
m = MyModule()
expected_shape = {
"i.bn.weight": (1, 1, 1, 4),
"i.bn.bias": (1, 1, 1, 4),
"i.bn.running_mean": (1, 1, 1, 4),
"i.bn.running_var": (1, 1, 1, 4),
"conv.weight": (2, 2, 4, 4, 2),
"conv.bias": (1, 1, 1, 4),
"bn.weight": (1, 1, 1, 4),
"bn.bias": (1, 1, 1, 4),
"bn.running_mean": (1, 1, 1, 4),
"bn.running_var": (1, 1, 1, 4),
"param": (1, 1, 1, 3),
"buff": (1, 1, 1, 3),
}
m = amp.convert_module_format(m, is_inplace)
for name, param in m.named_tensors():
assert param.format == "nhwc"
with config._override(auto_format_convert=False):
assert param.shape == expected_shape[name], name
@@ -3,6 +3,7 @@ import pytest
import megengine as mge
import megengine.functional as F
import megengine.module as M
from megengine import tensor
from megengine.autodiff import GradManager
from megengine.jit import trace
@@ -36,9 +37,9 @@ def _compare_nchw_nhwc(data, func, is_symbolic=None):
x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
if is_symbolic is not None:
func = trace(func, symbolic=is_symbolic)
# out1 = func(x1)
out1 = func(x1)
out2 = func(x2)
# np.testing.assert_almost_equal(out1, out2, decimal=5)
np.testing.assert_almost_equal(out1, out2, decimal=5)
@pytest.mark.parametrize("is_symbolic", [None])
@@ -322,30 +323,91 @@ def test_pooling2d(pooling, is_symbolic):
_compare_nchw_nhwc(data, func, is_symbolic)
@pytest.mark.parametrize("is_symbolic", [None])
def test_backward(is_symbolic):
data = np.arange(0, 24).reshape((1, 2, 3, 4))
x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc")
b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc")
gm = GradManager().attach([w, b])
def _compare_backward(inps, model, is_symbolic=None):
def func(*inps):
return model(*inps)
def func(x, w, b):
return F.conv2d(x, w, b)
if is_symbolic is not None:
func = trace(func, symbolic=is_symbolic)
gm = GradManager().attach(model.parameters())
with gm:
if is_symbolic is not None:
func = trace(func, symbolic=is_symbolic)
x = func(x, w, b)
assert x.format == "nhwc"
# test manually convert to NHWC, usually used in detection head
x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2)
gm.backward(x)
print("finish backward", x.format)
# backward grad has no format
np.testing.assert_equal(
w.grad.numpy(), np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
)
np.testing.assert_equal(
b.grad.numpy(), np.array([12, 12, 12]).reshape((1, 1, 1, 3))
)
rst = func(*inps)
gm.backward(rst)
expected_grads = [param.grad for param in model.parameters()]
inps = [mge.amp.convert_tensor_format(inp) for inp in inps]
model = mge.amp.convert_module_format(model)
gm = GradManager().attach(model.parameters())
with gm:
rst = func(*inps)
gm.backward(rst)
actual_grads = [param.grad for param in model.parameters()]
for expected, actual in zip(expected_grads, actual_grads):
# print(param.grad)
np.testing.assert_equal(expected.numpy(), actual.numpy())
@pytest.mark.parametrize("is_symbolic", [None])
def test_backward_conv2d_dimshuffle(is_symbolic):
class Net(M.Module):
def __init__(self):
super().__init__()
self.conv = M.Conv2d(2, 3, 1)
def forward(self, inp):
# test manually convert to NHWC, usually used in detection head
return F.transpose(self.conv(inp), (0, 2, 3, 1)).reshape(1, 18, 2)
inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
# x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
# w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc")
# b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc")
# grads = [
# np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
# np.array([12, 12, 12]).reshape((1, 1, 1, 3)),
# ]
_compare_backward([inp], Net(), is_symbolic)
@pytest.mark.parametrize("is_symbolic", [None])
def test_backward_groupconv2d_bn(is_symbolic):
class Net(M.Module):
def __init__(self):
super().__init__()
self.conv = M.Conv2d(2, 2, 1, groups=2)
self.bn = M.BatchNorm2d(2)
def forward(self, inp):
# test manually convert to NHWC, usually used in detection head
return self.bn(self.conv(inp))
inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
_compare_backward([inp], Net(), is_symbolic)
# def func(x, w, b, bn_w, bn_b):
# x = F.conv2d(x, w, b, groups=2)
# x = F.batch_norm(
# x,
# running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
# running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
# weight=bn_w,
# bias=bn_b,
# training=True,
# inplace=True,
# )
# return x
# data = np.arange(0, 24).reshape((1, 2, 3, 4))
# x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
# w = tensor(np.ones((2, 1, 1, 1, 1)), format="nhwc")
# b = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
# bn_w = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
# bn_b = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
# grads = [
# np.array([66, 210]).reshape((2, 1, 1, 1, 1)),
# np.array([12, 12]).reshape((1, 1, 1, 2)),
# np.array([12, 12]).reshape((1, 1, 1, 2)),
# np.array([12, 12]).reshape((1, 1, 1, 2)),
# ]
# _compare_backward(x, func, [w, b, bn_w, bn_b], grads, is_symbolic)
@@ -1,6 +1,8 @@
#include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/transformations/grad.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
namespace mgb {
namespace imperative {
@@ -17,7 +19,12 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
const std::string& scope) const {
std::vector<int32_t> pattern;
if (tensor.format() == FT::NHWC && target == FT::NCHW) {
pattern = {0, 3, 1, 2};
// FIXME(czh): temporary fast path for group conv 5D weight.
if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
pattern = {0, 1, 4, 2, 3};
} else {
pattern = {0, 3, 1, 2};
}
} else if (tensor.format() == FT::NCHW && target == FT::NHWC) {
pattern = {0, 2, 3, 1};
} else {
@@ -65,12 +72,22 @@ inline ValueRefList FormatTransformation::wrap_outputs(
namespace {
ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
mgb_assert(shape.ndim == 4);
auto out = ValueShape(shape);
out[3] = shape[2];
out[2] = shape[1];
out[1] = shape[3];
return out;
if (shape.ndim == 4) {
out[1] = shape[3];
out[2] = shape[1];
out[3] = shape[2];
return out;
} else if (shape.ndim == 5) {
out[2] = shape[4];
out[3] = shape[2];
out[4] = shape[3];
return out;
} else {
mgb_throw(
MegBrainError, "Unsupported shape ndim %u in GetAttr(Shape).",
shape.ndim);
}
}
using FormatRule = std::function<ValueRefList(
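A pure-Python rendering of the extended shape conversion above (4D NHWC plus the 5D group-conv weight case), handy for checking the index mapping against the expected shapes used in the tests earlier in this diff:

def nhwc_to_nchw_shape(shape):
    # mirrors convert_nhwc2nchw_shape: same first axes, channels moved back
    if len(shape) == 4:
        return (shape[0], shape[3], shape[1], shape[2])
    elif len(shape) == 5:
        return (shape[0], shape[1], shape[4], shape[2], shape[3])
    raise ValueError("Unsupported shape ndim {}".format(len(shape)))

assert nhwc_to_nchw_shape((1, 3, 4, 2)) == (1, 2, 3, 4)
assert nhwc_to_nchw_shape((2, 2, 4, 4, 2)) == (2, 2, 2, 4, 4)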
@@ -278,10 +295,10 @@ ValueRefList setsubtensor_rule(
inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation& t) {
FT format(FT::DEFAULT);
for (auto& inp : inputs) {
auto&& inp_ref = inp.as_ref(t.value_type());
if (inp_ref && inp_ref->format() != FT::DEFAULT) {
mgb_assert(format == FT::DEFAULT || inp_ref->format() == format);
format = inp_ref->format().type();
auto&& inp_format = inp.cast(t.value_type()).format();
if (inp_format != FT::DEFAULT) {
mgb_assert(format == FT::DEFAULT || inp_format == format);
format = inp_format.type();
}
}
return format;
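get_inputs_format computes a single consensus format for a span of inputs: non-default formats must agree, and the first one seen wins. A small Python model of that logic (names illustrative):

def get_inputs_format(formats, default="default"):
    result = default
    for fmt in formats:
        if fmt != default:
            # mixed non-default formats are rejected, mirroring the mgb_assert
            assert result == default or fmt == result
            result = fmt
    return result

assert get_inputs_format(["default", "nhwc", "nhwc"]) == "nhwc"
assert get_inputs_format(["default", "default"]) == "default"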
@@ -308,13 +325,6 @@ ValueRefList concat_rule(
format);
}
ValueRefList elemwise_rule(
const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
FT format = get_inputs_format(inputs, t);
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
}
ValueRefList identity_rule_helper(
const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
// mgb_assert(inputs.size() == 1);
@@ -336,24 +346,49 @@ ValueRefList batchnorm_rule(
return identity_rule_helper(op, inputs, t);
}
ValueRefList checknonfinite_rule(
const CheckNonFinite& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
auto&& inputs_ = t.unwrap_inputs(inputs);
auto&& outputs_ = imperative::apply(op, inputs_);
return t.wrap_outputs(outputs_);
}
// clang-format off
#define FOREACH_IDENTITY_OP(cb) \
cb(Copy) \
cb(FastpathCopy) \
cb(TypeCvt) \
cb(Dropout) \
#define FOREACH_MULTI_INPS_NO_PARAM_OP(cb) \
cb(Elemwise) \
cb(CompiledOp) \
cb(SubgraphOp)
#define FOREACH_IDENTITY_OP(cb) \
cb(Copy) \
cb(FastpathCopy) \
cb(TypeCvt) \
cb(Dropout) \
cb(Identity)
#define FOREACH_FORMAT_OP(cb) \
cb(AdaptivePooling) \
cb(WarpAffine) \
#define FOREACH_FORMAT_OP(cb) \
cb(AdaptivePooling) \
cb(WarpAffine) \
cb(Resize)
#define FOREACH_FORMAT_POLICY_OP(cb)\
cb(Pooling) \
#define FOREACH_FORMAT_POLICY_OP(cb) \
cb(Pooling) \
cb(Convolution)
// clang-format on
// multi inputs op without params
#define CREATE_MULTI_INPS_NO_PARAM_OP_RULE(Op) \
ValueRefList Op##_rule( \
const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
const FormatTransformation& t) { \
FT format = get_inputs_format(inputs, t); \
return t.wrap_outputs( \
imperative::apply(_op, t.unwrap_inputs(inputs)), format); \
}
FOREACH_MULTI_INPS_NO_PARAM_OP(CREATE_MULTI_INPS_NO_PARAM_OP_RULE)
#undef CREATE_MULTI_INPS_NO_PARAM_OP_RULE
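A Python analogue of the X-macro above: one multi-input rule body stamped out for Elemwise, CompiledOp and SubgraphOp and stored in a per-op registry. The dict and helper are illustrative only, not megengine internals:

def make_multi_inps_no_param_rule(op_name):
    def rule(input_formats):
        # consensus of the input formats, as in get_inputs_format above
        fmt = "default"
        for f in input_formats:
            if f != "default":
                assert fmt in ("default", f)
                fmt = f
        return op_name, fmt
    return rule

FORMAT_RULES = {
    name: make_multi_inps_no_param_rule(name)
    for name in ("Elemwise", "CompiledOp", "SubgraphOp")
}
assert FORMAT_RULES["Elemwise"](["default", "nhwc"]) == ("Elemwise", "nhwc")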
// identity op
#define CREATE_IDENTITY_OP_RULE(Op) \
ValueRefList Op##_rule( \
@@ -409,8 +444,9 @@ struct FormatRuleRegistry {
register_format_rule(setsubtensor_rule<SetSubtensor>);
register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>);
register_format_rule(concat_rule);
register_format_rule(elemwise_rule);
register_format_rule(batchnorm_rule);
register_format_rule(checknonfinite_rule);
FOREACH_MULTI_INPS_NO_PARAM_OP(REGISTER_OP_RULE)
FOREACH_IDENTITY_OP(REGISTER_OP_RULE)
FOREACH_FORMAT_OP(REGISTER_OP_RULE)
FOREACH_FORMAT_POLICY_OP(REGISTER_OP_RULE)
@@ -455,27 +491,73 @@ ValueRefList FormatTransformation::apply_transformation(
return imperative::apply(op, unwrap_inputs(inputs));
}
} else if (op.is<GetFormat>()) {
bool is_formatted_tensor = inputs.item().is(m_value_type);
if (is_formatted_tensor) {
return {FormatValue::make(inputs[0].cast(m_value_type).format())};
auto&& inp_ref = inputs[0].as_ref(m_value_type);
if (inp_ref) {
return {FormatValue::make(inp_ref->format())};
} else {
mgb_log_warn(
"Not FormattedTensorValue input for GetFormat op: %s",
inputs[0].to_string().c_str());
"Not FormattedTensorValue input for GetFormat op: %s, %s",
op.to_string().c_str(), inputs[0].to_string().c_str());
return {FormatValue::make(FT::DEFAULT)};
}
} else if (op.is<Operator::IdentityLike>()) {
bool is_formatted_tensor = inputs.item().is(m_value_type);
if (is_formatted_tensor) {
auto&& format = inputs[0].cast(m_value_type).format();
auto&& inp_ref = inputs[0].as_ref(m_value_type);
if (inp_ref) {
auto&& format = inp_ref->format();
return wrap_outputs(
imperative::apply(op, unwrap_inputs(inputs)), format.type());
} else {
mgb_log_warn(
"Not FormattedTensorValue input for IdentityLike op: %s",
inputs[0].to_string().c_str());
"Not FormattedTensorValue input for IdentityLike op: %s, %s",
op.to_string().c_str(), inputs[0].to_string().c_str());
return imperative::apply(op, inputs);
}
} else if (op.is<AttachGrad>()) {
auto&& inp_ref = inputs[0].as_ref(m_value_type);
if (inp_ref) {
auto format = inp_ref->format();
GenericFunction callback =
(GenericFunction&)inputs[1].cast<FunctionValue>();
GenericFunction new_callback =
[this, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
auto wrapped_inputs = SmallVector<ValueRef>{
this->value_type().make(inputs_.item(), format.type())};
auto ret = callback(wrapped_inputs);
return ret;
};
auto&& outputs = imperative::apply(
op, inp_ref->value(), FunctionValue::make(new_callback));
return wrap_outputs(outputs, format.type());
} else {
mgb_log_warn(
"Not FormattedTensorValue input for AttachGrad op: %s, %s",
op.to_string().c_str(), inputs[0].to_string().c_str());
return imperative::apply(op, inputs);
}
} else if (auto* set_grad = op.as<SetGrad>()) {
size_t nr_inputs = set_grad->nr_inputs();
size_t nr_outputs = inputs.size() - nr_inputs;
Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
Span<ValueRef> outputs_ = {inputs.data() + nr_inputs, nr_outputs};
// run original apply.
// grads needn't to unwrap and wrap, which will be unwrapped in GradTrans
auto&& outputs = imperative::apply(op, unwrap_inputs(inputs));
// handle output's formats
auto wrapped_outputs = ValueRefList(nr_outputs);
for (size_t i = 0; i < nr_outputs; ++i) {
if (auto output_ref = outputs_[i].as_ref(m_value_type)) {
wrapped_outputs[i] =
m_value_type.make(outputs[i], output_ref->format().type());
} else {
mgb_log_warn(
"Not FormattedTensorValue outputs for SetGrad op: %s, %s",
op.to_string().c_str(), inputs_[i].to_string().c_str());
wrapped_outputs[i] = m_value_type.make(outputs[i], FT::DEFAULT);
}
}
return wrapped_outputs;
} else {
return imperative::apply(op, unwrap_inputs(inputs));
}
@@ -47,7 +47,10 @@ public:
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is(m_value_type));
//mgb_assert(!value.is(m_value_type));
if (auto format_val = value.as_ref(m_value_type)) {
return format_val->value();
}
return value;
}
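unwrap() now tolerates wrapped values instead of asserting they never occur: if the value is a formatted wrapper it is peeled, otherwise it passes through unchanged. A tiny Python model of that behaviour; the FormattedValue class here is purely illustrative:

from dataclasses import dataclass
from typing import Any

@dataclass
class FormattedValue:
    inner: Any
    format: str = "nhwc"

def unwrap(value):
    # peel the wrapper when present, pass plain values through untouched
    return value.inner if isinstance(value, FormattedValue) else value

assert unwrap(FormattedValue(42)) == 42
assert unwrap(42) == 42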
@@ -377,6 +377,8 @@ public:
SetGrad(GenericFunction grad_fn, size_t nr_inputs)
: m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}
std::shared_ptr<GradKey> key() const { return m_key; }
GenericFunction grad_fn() const { return m_grad_fn; }
size_t nr_inputs() const { return m_nr_inputs; }