GitOrigin-RevId: 6edd577a70
release-1.10
@@ -50,8 +50,6 @@ class autocast:
        self._origin_enabled = None
        self._origin_high = None
        self._origin_low = None
        self._origin_compute_mode = None
        self._origin_configs = None
    def __enter__(self):
@@ -75,7 +73,7 @@ class autocast:
        amp._set_amp_high_prec_dtype(self._origin_high)
        amp._set_amp_low_prec_dtype(self._origin_low)
        _config._reset_execution_config(*self._origin_compute_mode)
        _config._reset_execution_config(*self._origin_configs)
    def __call__(self, func):
        @functools.wraps(func)
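The hunk above only shows the state bookkeeping that replaces `_origin_compute_mode` with `_origin_configs`; for context, a minimal usage sketch of the two ways `autocast` is entered, matching the `__enter__`/`__exit__` and `__call__` machinery shown here (the tensors and function below are illustrative placeholders, not part of the patch):

    import megengine as mge
    import megengine.functional as F

    # as a context manager: __enter__/__exit__ save and restore the amp configs
    with mge.amp.autocast():
        y = F.relu(mge.tensor([1.0, -1.0]))

    # as a decorator: __call__ wraps the function in the same context
    @mge.amp.autocast()
    def forward(x):
        return F.relu(x)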
@@ -15,11 +15,14 @@ from ..core import _config
def _is_nchw_format(param: Tensor):
    # TODO: use better condition
    return (len(param.shape) == 4 or len(param.shape) == 5) and param.format != "nhwc"
    return (param.ndim == 4 or param.ndim == 5) and param.format != "nhwc"
def convert_tensor_format(x: Tensor, inplace: bool = True):
    """Convert NCHW Tensor to NHWC Tensor."""
    if not _is_nchw_format(x):
        return x
    if x.ndim == 4:
        pattern = (0, 2, 3, 1)
    elif x.ndim == 5:
@@ -29,8 +32,9 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
    # TODO: use initialization from tensor after fixing format setting
    if x.format != "nhwc":
        if inplace:
            # reset will destroy backward grad
            # hostvalue should still be valid, so no d2h cost.
            data = x.numpy().transpose(*pattern)
            # reset will destroy existing backward grad
            x[...] = Tensor(data, format="nhwc")
        else:
            # use mge interface to maintain grad
@@ -45,7 +49,5 @@ def convert_module_format(module: Module, inplace: bool = True):
        module = deepcopy(module)
    for name, param in module.named_tensors():
        if _is_nchw_format(param):
            # hostvalue should still be valid, so no d2h cost.
            convert_tensor_format(param, inplace=True)
            convert_tensor_format(param, inplace=True)
    return module
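For readers unfamiliar with these helpers, a minimal usage sketch of the conversion path this hunk touches, assuming a small Conv2d module and NCHW input (module and shapes are illustrative):

    import numpy as np
    import megengine as mge
    import megengine.module as M

    model = M.Conv2d(3, 8, 3)
    # convert parameters to NHWC so NHWC kernels can be selected
    model = mge.amp.convert_module_format(model, inplace=True)

    inp = mge.tensor(np.ones((1, 3, 16, 16), dtype="float32"))
    inp = mge.amp.convert_tensor_format(inp)  # NCHW -> NHWC
    out = model(inp)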
@@ -64,9 +64,7 @@ class Grad:
                continue
            grad.suppress()
        print("before backward")
        self._impl.backward(ys, dys)
        print("after backward")
        for grad in group:
            if grad is self:
@@ -245,8 +245,6 @@ def conv2d(
    sparse_type = "dense" if groups == 1 else "group"
    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
    with _config._override(auto_format_convert=False):
        print(compute_mode, inp.shape, inp.format, weight.shape, weight.format)
        op = builtin.Convolution(
            stride_h=stride_h,
            stride_w=stride_w,
@@ -320,7 +320,7 @@ py::object _Const(py::handle value, py::handle dtype, py::handle device) {
        }
    }
    py::object device_obj = device2obj(device, true);
    py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none());
    py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none(), py::none());
    return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
}
@@ -35,6 +35,7 @@ def test_basic():
    b.format = "nhwc"
    assert b.format == "nhwc"
def _compare_nchw_nhwc(data, func, is_symbolic=None):
    x1 = tensor(data)
    x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
@@ -335,21 +336,42 @@ def _compare_backward(inps, model, is_symbolic=None):
    gm = GradManager().attach(model.parameters())
    with gm:
        rst = func(*inps)
        gm.backward(rst)
    expected_grads = [param.grad for param in model.parameters()]
        with mge.amp.autocast():
            rst = func(*inps)
            gm.backward(rst)
    expected_grads = [param.grad.numpy() for param in gm.attached_tensors()]
    for param in gm.attached_tensors():
        param.grad = None
    inps = [mge.amp.convert_tensor_format(inp) for inp in inps]
    model = mge.amp.convert_module_format(model)
    gm = GradManager().attach(model.parameters())
    with gm:
        rst = func(*inps)
        gm.backward(rst)
    actual_grads = [param.grad for param in model.parameters()]
        with mge.amp.autocast():
            rst = func(*inps)
            gm.backward(rst)
    actual_grads = [param.grad.numpy() for param in gm.attached_tensors()]
    for expected, actual in zip(expected_grads, actual_grads):
        # print(param.grad)
        np.testing.assert_equal(expected.numpy(), actual.numpy())
        assert expected is not None
        assert actual is not None
        np.testing.assert_almost_equal(expected, actual, decimal=5)
@pytest.mark.parametrize("is_symbolic", [None])
def test_backward_basic(is_symbolic):
    class Net(M.Module):
        def __init__(self):
            super().__init__()
            self.w = mge.Parameter([[2.0], [4.0], [6.0]])
            self.b = mge.Parameter(-1.0)
        def forward(self, inp):
            return F.matmul(inp, self.w) + self.b
    inp = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
    _compare_backward([inp], Net(), is_symbolic)
@pytest.mark.parametrize("is_symbolic", [None])
@@ -379,14 +401,15 @@ def test_backward_groupconv2d_bn(is_symbolic):
    class Net(M.Module):
        def __init__(self):
            super().__init__()
            self.conv = M.Conv2d(2, 2, 1, groups=2)
            self.bn = M.BatchNorm2d(2)
            self.conv0 = M.Conv2d(32, 256, 3, groups=32, stride=2)
            self.conv1 = M.Conv2d(256, 2048, 3, groups=32, stride=2)
            # self.bn = M.BatchNorm2d(2048)
        def forward(self, inp):
            # test manually convert to NHWC, usually used in detection head
            return self.bn(self.conv(inp))
            return self.conv1(self.conv0(inp))
    inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
    inp = mge.tensor(np.ones(shape=(32, 32, 56, 56)).astype("float32"))
    _compare_backward([inp], Net(), is_symbolic)
    # def func(x, w, b, bn_w, bn_b):
    #     x = F.conv2d(x, w, b, groups=2)
@@ -260,6 +260,7 @@ void ChannelImpl::dispatch_default_cpu(
    CompNode output_cn;
    {
        MGB_LOCK_GUARD(m_mutex);
        //mgb_log_warn(">>> MGB_LOCK_GUARD dispatch_default_cpu");
        for (auto&& info : input_infos) {
            auto input_cn = info->desc.comp_node;
            if (!output_cn.valid()) {
@@ -277,6 +278,7 @@ void ChannelImpl::dispatch_default_cpu(
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
            }
        }
        //mgb_log_warn("<<< MGB_LOCK_GUARD dispatch_default_cpu");
    }
    SmallVector<DeviceTensorND> output_tensornds;
@@ -530,7 +532,9 @@ void ChannelImpl::sync() {
void ChannelImpl::sync_impl() {
    m_worker.wait_all_task_finish();
    MGB_LOCK_GUARD(m_mutex);
    //mgb_log_warn(">>> MGB_LOCK_GUARD sync_impl");
    check_worker_exc_unsafe();
    //mgb_log_warn("<<< MGB_LOCK_GUARD sync_impl");
}
void ChannelImpl::close() {
@@ -689,6 +693,7 @@ ChannelImpl::~ChannelImpl() {
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
    auto& state = get_worker_state();
    MGB_LOCK_GUARD(m_mutex);
    //mgb_log_warn(">>> MGB_LOCK_GUARD produce_tensor");
    m_dtr.update_used_time(dest);
    MGB_RECORD_EVENT(
            TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
@@ -715,16 +720,19 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
        m_dtr.insert_candidate(dest);
    }
    notify_tensor_unsafe(dest);
    //mgb_log_warn("<<< MGB_LOCK_GUARD produce_tensor");
}
void ChannelImpl::release_tensor(TensorInfo* dest) {
    MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
    MGB_LOCK_GUARD(m_mutex);
    //mgb_log_warn(">>> MGB_LOCK_GUARD release_tensor");
    dest->ptr.reset();
    auto& state = get_worker_state();
    if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(dest);
    }
    //mgb_log_warn("<<< MGB_LOCK_GUARD release_tensor");
}
void ChannelImpl::regenerate(TensorInfo* dest) {
@@ -1000,6 +1008,7 @@ bool ChannelImpl::check_available() {
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    //mgb_log_warn(">>> MGB_LOCK_GUARD wait_tensor");
    mgb_assert(!m_waitee, "duplicate waitee");
    m_waitee = info;
    m_waitee_id = Profiler::next_id();
@@ -1010,6 +1019,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    if (require_host && !host_available()) {
        // avoid dead lock
        lock.unlock();
        //mgb_log_warn("<<< MGB_LOCK_GUARD wait_tensor unlock");
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), GetValue{info},
@@ -1021,18 +1031,21 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
            });
        }
        lock.lock();
        //mgb_log_warn(">>> MGB_LOCK_GUARD wait_tensor lock");
        wait_host = true;
    }
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        return require_host ? host_available() : static_cast<bool>(info->ptr);
    });
    //mgb_log_warn("after cv wait");
    MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
    m_waitee = nullptr;
    if (wait_host) {
        auto err = info->ptr->comp_node().check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
    //mgb_log_warn("<<< MGB_LOCK_GUARD wait_tensor");
    return info->ptr;
}
@@ -1040,6 +1053,7 @@ void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
    if (info == m_waitee) {
        MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
        m_cv.notify_all();
        //mgb_log_warn("cv notify_all");
    }
}
@@ -1102,6 +1116,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
    using namespace ranges::views;
    auto& state = get_worker_state();
    auto& options = state.options;
    //mgb_log_warn("process_one_task %s", to_string<Command>(icmd).c_str());
    // TODO: remove std::visit for support osx 10.12
    auto cmd_visitor = [&](const auto& cmd) {
        using T = std::decay_t<decltype(cmd)>;
@@ -1123,9 +1138,11 @@ void ChannelImpl::process_one_task(Command& icmd) {
            for (auto& i : cmd.inputs) {
                if (mgb_unlikely(i->invalid)) {
                    MGB_LOCK_GUARD(m_mutex);
                    //mgb_log_warn(">>> MGB_LOCK_GUARD ApplyOp");
                    for (auto& i : cmd.outputs) {
                        i->invalid = true;
                    }
                    //mgb_log_warn("<<< MGB_LOCK_GUARD ApplyOp");
                    return;
                }
            }
@@ -1210,8 +1227,10 @@ void ChannelImpl::process_one_task(Command& icmd) {
            }
            cmd.dest->ptr->fetch_value();
            MGB_LOCK_GUARD(m_mutex);
            //mgb_log_warn(">>> MGB_LOCK_GUARD GetValue");
            notify_tensor_unsafe(cmd.dest);
            imperative_log_profile_end("GetValue");
            //mgb_log_warn("<<< MGB_LOCK_GUARD GetValue");
        } else if constexpr (std::is_same_v<T, Drop>) {
            if (cmd.dest->invalid)
                return;
@@ -1271,6 +1290,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
                cmd_visitor(cmd);
            } catch (...) {
                MGB_LOCK_GUARD(m_mutex);
                //mgb_log_warn(">>> MGB_LOCK_GUARD catch exception");
                if constexpr (std::is_same_v<T, ApplyOp>) {
                    for (auto oup : cmd.outputs) {
                        oup->invalid = true;
@@ -1283,6 +1303,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
                if (m_waitee) {
                    notify_tensor_unsafe(m_waitee);
                }
                //mgb_log_warn("<<< MGB_LOCK_GUARD catch exception");
            }
            },
            icmd.data);
@@ -33,9 +33,8 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
                tensor.format().to_string().c_str(),
                Format(target).to_string().c_str());
    }
    auto output = imperative::apply(
            *Dimshuffle::make(pattern, scope),
            SmallVector<ValueRef>{tensor.value()})[0];
    auto output =
            imperative::apply(*Dimshuffle::make(pattern, scope), {tensor.value()})[0];
    return m_value_type.make(output, target);
}
@@ -90,6 +89,27 @@ ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
    }
}
std::vector<int32_t> convert_nchw2nhwc_vector(const std::vector<int32_t>& shape) {
    auto out = std::vector<int32_t>(shape);
    if (shape.size() == 4) {
        out[1] = shape[2];
        out[2] = shape[3];
        out[3] = shape[1];
        return out;
    } else if (shape.size() == 5) {
        // GIOHW -> GIHWO
        out[2] = shape[3];
        out[3] = shape[4];
        out[4] = shape[2];
        return out;
    } else {
        mgb_throw(
                MegBrainError,
                "Unsupported shape ndim %u in convert NCHW shape to NHWC.",
                shape.size());
    }
}
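For the shape bookkeeping, a small Python mirror of the permutation that convert_nchw2nhwc_vector performs (the helper name and the assert below are illustrative, not part of the patch):

    def nchw_to_nhwc_shape(shape):
        out = list(shape)
        if len(shape) == 4:    # (N, C, H, W) -> (N, H, W, C)
            out[1], out[2], out[3] = shape[2], shape[3], shape[1]
        elif len(shape) == 5:  # GIOHW -> GIHWO, dims 0 and 1 are kept
            out[2], out[3], out[4] = shape[3], shape[4], shape[2]
        else:
            raise ValueError("only 4-d and 5-d shapes are supported")
        return out

    assert nchw_to_nhwc_shape([1, 3, 16, 16]) == [1, 16, 16, 3]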
using FormatRule = std::function<ValueRefList(
        const OpDef&, Span<ValueRef>&, const bool&, const FormatTransformation&)>;
static std::unordered_map<Typeinfo*, FormatRule> format_rules;
@@ -156,22 +176,38 @@ ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) {
ValueRefList reshape_rule(
        const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    mgb_assert(inputs.size() == 2);
    mgb_assert(inputs.size() >= 1);
    auto& src = inputs[0].cast(t.value_type());
    if (auto_convert && src.format() == FT::NHWC) {
        auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
        if (shape.layout().total_nr_elems() == 4) {
            // output is still NHWC format
            auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
            auto outputs = imperative::apply(
                    op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
            return t.wrap_outputs(outputs, FT::NHWC);
        } else {
            // will not maintain src's format
            auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
            auto outputs = imperative::apply(
                    op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
            return t.wrap_outputs(outputs);
        if (inputs.size() == 1) {
            if (op.shape.size() == 4) {
                // output is still NHWC format
                auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
                auto outputs = imperative::apply(
                        *Reshape::make(op.axis, nhwc_shape), {t.unwrap_input(inputs[0])});
                return t.wrap_outputs(outputs, FT::NHWC);
            } else {
                // will not maintain src's format
                auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
                auto outputs = imperative::apply(op, {nchw_src});
                return t.wrap_outputs(outputs);
            }
        } else if (inputs.size() == 2) {
            auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
            if (shape.layout().total_nr_elems() == 4) {
                // output is still NHWC format
                auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
                auto outputs = imperative::apply(
                        op,
                        SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
                return t.wrap_outputs(outputs, FT::NHWC);
            } else {
                // will not maintain src's format
                auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
                auto outputs = imperative::apply(
                        op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
                return t.wrap_outputs(outputs);
            }
        }
    }
    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
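In rough Python terms, the branch added above dispatches like this (an illustrative restatement under the assumption that the target shape is given in user-facing NCHW order; the real rule also handles the two-input form, where the shape arrives as a tensor):

    def reshape_target(src_format, target_shape):
        if src_format == "nhwc" and len(target_shape) == 4:
            n, c, h, w = target_shape
            return (n, h, w, c), "nhwc"     # reshape directly in NHWC, keep the format
        return tuple(target_shape), "nchw"  # otherwise fall back to NCHW and drop the format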
@@ -180,22 +216,38 @@ ValueRefList reshape_rule(
ValueRefList broadcast_rule(
        const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    mgb_assert(inputs.size() == 2);
    mgb_assert(inputs.size() >= 1);
    auto& src = inputs[0].cast(t.value_type());
    if (auto_convert && src.format() == FT::NHWC) {
        auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
        if (shape.layout().total_nr_elems() == 4) {
            // output is still NHWC format
            auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
            auto outputs = imperative::apply(
                    op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
            return t.wrap_outputs(outputs, FT::NHWC);
        } else {
            // will not maintain src's format
            auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
            auto outputs = imperative::apply(
                    op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
            return t.wrap_outputs(outputs);
        if (inputs.size() == 1) {
            if (op.shape.size() == 4) {
                // output is still NHWC format
                auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
                auto outputs = imperative::apply(
                        *Broadcast::make(nhwc_shape), {t.unwrap_input(inputs[0])});
                return t.wrap_outputs(outputs, FT::NHWC);
            } else {
                // will not maintain src's format
                auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
                auto outputs = imperative::apply(op, {nchw_src});
                return t.wrap_outputs(outputs);
            }
        } else if (inputs.size() == 2) {
            auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
            if (shape.layout().total_nr_elems() == 4) {
                // output is still NHWC format
                auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
                auto outputs = imperative::apply(
                        op,
                        SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
                return t.wrap_outputs(outputs, FT::NHWC);
            } else {
                // will not maintain src's format
                auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
                auto outputs = imperative::apply(
                        op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
                return t.wrap_outputs(outputs);
            }
        }
    }
    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
@@ -240,8 +292,7 @@ ValueRefList subtensor_rule(
    // only support NHWC2NCHW convert, otherwise maintain src's format
    if (!(auto_convert && src.format() == FT::NHWC)) {
        return {t.wrap_output(
                imperative::apply(op, t.unwrap_inputs(inputs))[0],
                src.format())};
                imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
    }
    auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
    auto outputs = imperative::apply(
@@ -263,8 +314,7 @@ ValueRefList setsubtensor_rule(
    // only support NHWC2NCHW convert, otherwise maintain src's format
    if (!(auto_convert && src.format() == FT::NHWC)) {
        return {t.wrap_output(
                imperative::apply(op, t.unwrap_inputs(inputs))[0],
                src.format())};
                imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
    }
    // value has been broadcasted to src's fake NCHW shape.
    auto& value = inputs[1].cast(t.value_type());
@@ -329,8 +379,7 @@ ValueRefList identity_rule_helper(
        const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
    // mgb_assert(inputs.size() == 1);
    auto& src = inputs[0].cast(t.value_type());
    return t.wrap_outputs(
            imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
}
ValueRefList batchnorm_rule(
@@ -457,6 +506,7 @@ struct FormatRuleRegistry {
ValueRefList FormatTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    //mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
    if (auto* apply_op = op.as<ApplyOp>()) {
        // all inputs should be FormattedTensorValue
        auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
@@ -485,7 +535,7 @@ ValueRefList FormatTransformation::apply_transformation(
                }
                case GetAttr::Value: {
                    auto nchw_src = unwrap_input(to(src, FT::NCHW, ""));
                    return imperative::apply(op, SmallVector<ValueRef>{nchw_src});
                    return imperative::apply(op, {nchw_src});
                }
                default:
                    return imperative::apply(op, unwrap_inputs(inputs));
@@ -508,8 +558,7 @@ ValueRefList FormatTransformation::apply_transformation(
        auto&& inp_ref = inputs[0].as_ref(m_value_type);
        if (inp_ref) {
            auto&& format = inp_ref->format();
            return wrap_outputs(
                    imperative::apply(op, unwrap_inputs(inputs)), format);
            return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
        } else {
            mgb_log_warn(
                    "Not FormattedTensorValue input for IdentityLike op: %s, %s",
@@ -522,6 +571,7 @@ ValueRefList FormatTransformation::apply_transformation(
            auto format = inp_ref->format();
            GenericFunction callback =
                    (GenericFunction&)inputs[1].cast<FunctionValue>();
            // make param grads as FormattedTensor
            GenericFunction new_callback =
                    [this, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
                auto wrapped_inputs = SmallVector<ValueRef>{
@@ -531,6 +581,7 @@ ValueRefList FormatTransformation::apply_transformation(
            };
            auto&& outputs = imperative::apply(
                    op, inp_ref->value(), FunctionValue::make(new_callback));
            // make params(GradValue) as FormattedTensor
            return wrap_outputs(outputs, format);
        } else {
            mgb_log_warn(
@@ -539,6 +590,7 @@ ValueRefList FormatTransformation::apply_transformation(
            return imperative::apply(op, inputs);
        }
    } else if (auto* set_grad = op.as<SetGrad>()) {
        // make grads in Function backward as FormattedTensor
        size_t nr_inputs = set_grad->nr_inputs();
        size_t nr_outputs = inputs.size() - nr_inputs;
        Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
@@ -377,8 +377,6 @@ public:
    SetGrad(GenericFunction grad_fn, size_t nr_inputs)
            : m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}
    std::shared_ptr<GradKey> key() const { return m_key; }
    GenericFunction grad_fn() const { return m_grad_fn; }
    size_t nr_inputs() const { return m_nr_inputs; }