GitOrigin-RevId: 6edd577a70
release-1.10
@@ -50,8 +50,6 @@ class autocast:
         self._origin_enabled = None
         self._origin_high = None
         self._origin_low = None
-        self._origin_compute_mode = None
         self._origin_configs = None
     def __enter__(self):
@@ -75,7 +73,7 @@ class autocast:
         amp._set_amp_high_prec_dtype(self._origin_high)
         amp._set_amp_low_prec_dtype(self._origin_low)
-        _config._reset_execution_config(*self._origin_compute_mode)
+        _config._reset_execution_config(*self._origin_configs)
     def __call__(self, func):
         @functools.wraps(func)
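
Taken together, the two hunks above make autocast track a single saved tuple, self._origin_configs, instead of a separate self._origin_compute_mode, and __exit__ now restores it via _config._reset_execution_config(*self._origin_configs). A minimal usage sketch (illustrative only, assuming the default autocast arguments):

    # usage sketch, not part of the diff
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor([[1.0, 3.0, 5.0]])
    w = mge.Parameter([[2.0], [4.0], [6.0]])

    with mge.amp.autocast():      # __enter__ saves the current AMP dtypes / execution config
        y = F.matmul(x, w)
    # __exit__ restores the saved values, including *self._origin_configs

    @mge.amp.autocast()           # __call__ also lets autocast wrap a function
    def forward(x, w):
        return F.matmul(x, w)
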
@@ -15,11 +15,14 @@ from ..core import _config
 def _is_nchw_format(param: Tensor):
     # TODO: use better condition
-    return (len(param.shape) == 4 or len(param.shape) == 5) and param.format != "nhwc"
+    return (param.ndim == 4 or param.ndim == 5) and param.format != "nhwc"
 def convert_tensor_format(x: Tensor, inplace: bool = True):
     """Convert NCHW Tensor to NHWC Tensor."""
+    if not _is_nchw_format(x):
+        return x
     if x.ndim == 4:
         pattern = (0, 2, 3, 1)
     elif x.ndim == 5:
@@ -29,8 +32,9 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
     # TODO: use initialization from tensor after fixing format setting
     if x.format != "nhwc":
         if inplace:
+            # reset will destroy backward grad
+            # hostvalue should still be valid, so no d2h cost.
             data = x.numpy().transpose(*pattern)
-            # reset will destroy existed backward grad
             x[...] = Tensor(data, format="nhwc")
         else:
             # use mge interface to maintain grad
@@ -45,7 +49,5 @@ def convert_module_format(module: Module, inplace: bool = True):
         module = deepcopy(module)
     for name, param in module.named_tensors():
-        if _is_nchw_format(param):
-            # hostvalue should still be valid, so no d2h cost.
-            convert_tensor_format(param, inplace=True)
+        convert_tensor_format(param, inplace=True)
     return module
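
The NCHW check now lives inside convert_tensor_format, so convert_module_format simply calls it for every named tensor; tensors that are not in NCHW form are returned unchanged, and the in-place branch rebuilds the tensor from its host value (no d2h cost) while discarding any existing backward grad. The 4-d layout change itself is a plain transpose, sketched here with numpy only (the 5-d weight case moves the channel axis last in the same spirit, cf. the GIOHW -> GIHWO helper later in this diff):

    # numpy-only sketch of the NCHW -> NHWC permutation used above
    import numpy as np

    x_nchw = np.arange(24, dtype="float32").reshape(1, 2, 3, 4)   # N, C, H, W
    x_nhwc = x_nchw.transpose(0, 2, 3, 1)                         # pattern (0, 2, 3, 1)
    assert x_nhwc.shape == (1, 3, 4, 2)                           # N, H, W, C
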
@@ -64,9 +64,7 @@ class Grad:
                 continue
             grad.suppress()
-        print("before backward")
         self._impl.backward(ys, dys)
-        print("after backward")
         for grad in group:
             if grad is self:
@@ -245,8 +245,6 @@ def conv2d(
     sparse_type = "dense" if groups == 1 else "group"
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
-    with _config._override(auto_format_convert=False):
-        print(compute_mode, inp.shape, inp.format, weight.shape, weight.format)
     op = builtin.Convolution(
         stride_h=stride_h,
         stride_w=stride_w,
@@ -320,7 +320,7 @@ py::object _Const(py::handle value, py::handle dtype, py::handle device) {
         }
     }
     py::object device_obj = device2obj(device, true);
-    py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none());
+    py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none(), py::none());
     return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
 }
@@ -35,6 +35,7 @@ def test_basic():
     b.format = "nhwc"
     assert b.format == "nhwc"
 def _compare_nchw_nhwc(data, func, is_symbolic=None):
     x1 = tensor(data)
     x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
@@ -335,21 +336,42 @@ def _compare_backward(inps, model, is_symbolic=None):
     gm = GradManager().attach(model.parameters())
     with gm:
-        rst = func(*inps)
-        gm.backward(rst)
-    expected_grads = [param.grad for param in model.parameters()]
+        with mge.amp.autocast():
+            rst = func(*inps)
+            gm.backward(rst)
+    expected_grads = [param.grad.numpy() for param in gm.attached_tensors()]
+    for param in gm.attached_tensors():
+        param.grad = None
     inps = [mge.amp.convert_tensor_format(inp) for inp in inps]
     model = mge.amp.convert_module_format(model)
     gm = GradManager().attach(model.parameters())
     with gm:
-        rst = func(*inps)
-        gm.backward(rst)
-    actual_grads = [param.grad for param in model.parameters()]
+        with mge.amp.autocast():
+            rst = func(*inps)
+            gm.backward(rst)
+    actual_grads = [param.grad.numpy() for param in gm.attached_tensors()]
     for expected, actual in zip(expected_grads, actual_grads):
-        # print(param.grad)
-        np.testing.assert_equal(expected.numpy(), actual.numpy())
+        assert expected is not None
+        assert actual is not None
+        np.testing.assert_almost_equal(expected, actual, decimal=5)
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_backward_basic(is_symbolic):
+    class Net(M.Module):
+        def __init__(self):
+            super().__init__()
+            self.w = mge.Parameter([[2.0], [4.0], [6.0]])
+            self.b = mge.Parameter(-1.0)
+        def forward(self, inp):
+            return F.matmul(inp, self.w) + self.b
+    inp = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
+    _compare_backward([inp], Net(), is_symbolic)
 @pytest.mark.parametrize("is_symbolic", [None])
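
The new test_backward_basic above pins the comparison to easily checked numbers: with inp = [[1, 3, 5]], w = [[2], [4], [6]] and b = -1 the forward result is 43, and (assuming gm.backward(rst) seeds the output gradient with ones, as in the other tests) the gradients that _compare_backward compares are dw = inp.T and db = 1. A quick numpy check of those numbers:

    # worked numbers for test_backward_basic, plain numpy only
    import numpy as np

    inp = np.array([[1.0, 3.0, 5.0]])
    w = np.array([[2.0], [4.0], [6.0]])
    b = -1.0

    rst = inp @ w + b                 # forward: [[43.]]
    assert rst.item() == 43.0

    dw = inp.T                        # gradient w.r.t. w for an all-ones output grad
    db = 1.0                          # gradient w.r.t. b
    np.testing.assert_equal(dw, np.array([[1.0], [3.0], [5.0]]))
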
@@ -379,14 +401,15 @@ def test_backward_groupconv2d_bn(is_symbolic):
     class Net(M.Module):
         def __init__(self):
             super().__init__()
-            self.conv = M.Conv2d(2, 2, 1, groups=2)
-            self.bn = M.BatchNorm2d(2)
+            self.conv0 = M.Conv2d(32, 256, 3, groups=32, stride=2)
+            self.conv1 = M.Conv2d(256, 2048, 3, groups=32, stride=2)
+            # self.bn = M.BatchNorm2d(2048)
         def forward(self, inp):
             # test manually convert to NHWC, usually used in detection head
-            return self.bn(self.conv(inp))
+            return self.conv1(self.conv0(inp))
-    inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
+    inp = mge.tensor(np.ones(shape=(32, 32, 56, 56)).astype("float32"))
     _compare_backward([inp], Net(), is_symbolic)
     # def func(x, w, b, bn_w, bn_b):
     #     x = F.conv2d(x, w, b, groups=2)
@@ -260,6 +260,7 @@ void ChannelImpl::dispatch_default_cpu(
     CompNode output_cn;
     {
         MGB_LOCK_GUARD(m_mutex);
+        //mgb_log_warn(">>> MGB_LOCK_GUARD dispatch_default_cpu");
         for (auto&& info : input_infos) {
             auto input_cn = info->desc.comp_node;
             if (!output_cn.valid()) {
@@ -277,6 +278,7 @@ void ChannelImpl::dispatch_default_cpu(
                 input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
             }
         }
+        //mgb_log_warn("<<< MGB_LOCK_GUARD dispatch_default_cpu");
     }
     SmallVector<DeviceTensorND> output_tensornds;
@@ -530,7 +532,9 @@ void ChannelImpl::sync() {
 void ChannelImpl::sync_impl() {
     m_worker.wait_all_task_finish();
     MGB_LOCK_GUARD(m_mutex);
+    //mgb_log_warn(">>> MGB_LOCK_GUARD sync_impl");
     check_worker_exc_unsafe();
+    //mgb_log_warn("<<< MGB_LOCK_GUARD sync_impl");
 }
 void ChannelImpl::close() {
@@ -689,6 +693,7 @@ ChannelImpl::~ChannelImpl() {
 void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
     auto& state = get_worker_state();
     MGB_LOCK_GUARD(m_mutex);
+    //mgb_log_warn(">>> MGB_LOCK_GUARD produce_tensor");
     m_dtr.update_used_time(dest);
     MGB_RECORD_EVENT(
             TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
@@ -715,16 +720,19 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
         m_dtr.insert_candidate(dest);
     }
     notify_tensor_unsafe(dest);
+    //mgb_log_warn("<<< MGB_LOCK_GUARD produce_tensor");
 }
 void ChannelImpl::release_tensor(TensorInfo* dest) {
     MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
     MGB_LOCK_GUARD(m_mutex);
+    //mgb_log_warn(">>> MGB_LOCK_GUARD release_tensor");
     dest->ptr.reset();
     auto& state = get_worker_state();
     if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
         m_dtr.erase_candidate(dest);
     }
+    //mgb_log_warn("<<< MGB_LOCK_GUARD release_tensor");
 }
 void ChannelImpl::regenerate(TensorInfo* dest) {
@@ -1000,6 +1008,7 @@ bool ChannelImpl::check_available() {
 TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
     std::unique_lock<decltype(m_mutex)> lock(m_mutex);
+    //mgb_log_warn(">>> MGB_LOCK_GUARD wait_tensor");
     mgb_assert(!m_waitee, "duplicate waitee");
     m_waitee = info;
     m_waitee_id = Profiler::next_id();
@@ -1010,6 +1019,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
     if (require_host && !host_available()) {
         // avoid dead lock
         lock.unlock();
+        //mgb_log_warn("<<< MGB_LOCK_GUARD wait_tensor unlock");
         if (Profiler::is_profiling()) {
             m_worker.add_task(
                     {Profiler::next_id(), GetValue{info},
@@ -1021,18 +1031,21 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
             });
         }
         lock.lock();
+        //mgb_log_warn(">>> MGB_LOCK_GUARD wait_tensor lock");
         wait_host = true;
     }
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
         return require_host ? host_available() : static_cast<bool>(info->ptr);
     });
+    //mgb_log_warn("after cv wait");
     MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
     m_waitee = nullptr;
     if (wait_host) {
         auto err = info->ptr->comp_node().check_async_error();
         mgb_assert(!err, "%s", err->what());
     }
+    //mgb_log_warn("<<< MGB_LOCK_GUARD wait_tensor");
     return info->ptr;
 }
@@ -1040,6 +1053,7 @@ void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
     if (info == m_waitee) {
         MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
         m_cv.notify_all();
+        //mgb_log_warn("cv notify_all");
     }
 }
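
The commented-out mgb_log_warn calls added throughout interpreter.cpp bracket every acquisition of m_mutex and the m_cv wait/notify pair: wait_tensor blocks on the condition variable until the worker thread has produced the requested tensor and notify_tensor_unsafe wakes it. The shape of that handshake, reduced to a minimal sketch with Python's stdlib threading (not MegEngine code):

    # minimal sketch of the wait/notify pattern the trace lines bracket
    import threading

    cond = threading.Condition()   # plays the role of m_mutex + m_cv
    produced = {"ptr": None}       # plays the role of info->ptr

    def wait_tensor():
        with cond:                                              # unique_lock(m_mutex)
            cond.wait_for(lambda: produced["ptr"] is not None)  # m_cv.wait(lock, pred)
            return produced["ptr"]

    def produce_tensor(value):
        with cond:                                              # MGB_LOCK_GUARD(m_mutex)
            produced["ptr"] = value
            cond.notify_all()                                   # m_cv.notify_all()

    worker = threading.Thread(target=produce_tensor, args=("tensor ready",))
    worker.start()
    assert wait_tensor() == "tensor ready"
    worker.join()
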
@@ -1102,6 +1116,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
     using namespace ranges::views;
     auto& state = get_worker_state();
     auto& options = state.options;
+    //mgb_log_warn("process_one_task %s", to_string<Command>(icmd).c_str());
     // TODO: remove std::visit for support osx 10.12
     auto cmd_visitor = [&](const auto& cmd) {
         using T = std::decay_t<decltype(cmd)>;
@@ -1123,9 +1138,11 @@ void ChannelImpl::process_one_task(Command& icmd) {
             for (auto& i : cmd.inputs) {
                 if (mgb_unlikely(i->invalid)) {
                     MGB_LOCK_GUARD(m_mutex);
+                    //mgb_log_warn(">>> MGB_LOCK_GUARD ApplyOp");
                     for (auto& i : cmd.outputs) {
                         i->invalid = true;
                     }
+                    //mgb_log_warn("<<< MGB_LOCK_GUARD ApplyOp");
                     return;
                 }
             }
@@ -1210,8 +1227,10 @@ void ChannelImpl::process_one_task(Command& icmd) {
             }
             cmd.dest->ptr->fetch_value();
             MGB_LOCK_GUARD(m_mutex);
+            //mgb_log_warn(">>> MGB_LOCK_GUARD GetValue");
             notify_tensor_unsafe(cmd.dest);
             imperative_log_profile_end("GetValue");
+            //mgb_log_warn("<<< MGB_LOCK_GUARD GetValue");
         } else if constexpr (std::is_same_v<T, Drop>) {
             if (cmd.dest->invalid)
                 return;
@@ -1271,6 +1290,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
                 cmd_visitor(cmd);
             } catch (...) {
                 MGB_LOCK_GUARD(m_mutex);
+                //mgb_log_warn(">>> MGB_LOCK_GUARD catch exception");
                 if constexpr (std::is_same_v<T, ApplyOp>) {
                     for (auto oup : cmd.outputs) {
                         oup->invalid = true;
@@ -1283,6 +1303,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
                 if (m_waitee) {
                     notify_tensor_unsafe(m_waitee);
                 }
+                //mgb_log_warn("<<< MGB_LOCK_GUARD catch exception");
             }
         },
         icmd.data);
@@ -33,9 +33,8 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
                 tensor.format().to_string().c_str(),
                 Format(target).to_string().c_str());
     }
-    auto output = imperative::apply(
-            *Dimshuffle::make(pattern, scope),
-            SmallVector<ValueRef>{tensor.value()})[0];
+    auto output =
+            imperative::apply(*Dimshuffle::make(pattern, scope), {tensor.value()})[0];
     return m_value_type.make(output, target);
 }
@@ -90,6 +89,27 @@ ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
     }
 }
+std::vector<int32_t> convert_nchw2nhwc_vector(const std::vector<int32_t>& shape) {
+    auto out = std::vector<int32_t>(shape);
+    if (shape.size() == 4) {
+        out[1] = shape[2];
+        out[2] = shape[3];
+        out[3] = shape[1];
+        return out;
+    } else if (shape.size() == 5) {
+        // GIOHW -> GIHWO
+        out[2] = shape[3];
+        out[3] = shape[4];
+        out[4] = shape[2];
+        return out;
+    } else {
+        mgb_throw(
+                MegBrainError,
+                "Unsupported shape ndim %u in convert NCHW shape to NHWC.",
+                shape.size());
+    }
+}
 using FormatRule = std::function<ValueRefList(
         const OpDef&, Span<ValueRef>&, const bool&, const FormatTransformation&)>;
 static std::unordered_map<Typeinfo*, FormatRule> format_rules;
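
convert_nchw2nhwc_vector is the new helper that lets the format rules permute a target shape passed as an op attribute (std::vector<int32_t>) rather than as a shape tensor. Its index shuffling reads more easily as a short Python analogue (illustrative only; the function name below is made up):

    # Python analogue of convert_nchw2nhwc_vector above
    def nchw_to_nhwc_shape(shape):
        if len(shape) == 4:          # NCHW -> NHWC
            n, c, h, w = shape
            return [n, h, w, c]
        if len(shape) == 5:          # GIOHW -> GIHWO (group-conv weights)
            g, i, o, h, w = shape
            return [g, i, h, w, o]
        raise ValueError("unsupported ndim %d" % len(shape))

    assert nchw_to_nhwc_shape([8, 3, 32, 32]) == [8, 32, 32, 3]
    assert nchw_to_nhwc_shape([2, 1, 4, 3, 3]) == [2, 1, 3, 3, 4]
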
@@ -156,22 +176,38 @@ ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) {
 ValueRefList reshape_rule(
         const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert,
         const FormatTransformation& t) {
-    mgb_assert(inputs.size() == 2);
+    mgb_assert(inputs.size() >= 1);
     auto& src = inputs[0].cast(t.value_type());
     if (auto_convert && src.format() == FT::NHWC) {
-        auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
-        if (shape.layout().total_nr_elems() == 4) {
-            // output is still NHWC format
-            auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
-            auto outputs = imperative::apply(
-                    op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
-            return t.wrap_outputs(outputs, FT::NHWC);
-        } else {
-            // will not maintain src's format
-            auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
-            auto outputs = imperative::apply(
-                    op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
-            return t.wrap_outputs(outputs);
+        if (inputs.size() == 1) {
+            if (op.shape.size() == 4) {
+                // output is still NHWC format
+                auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
+                auto outputs = imperative::apply(
+                        *Reshape::make(op.axis, nhwc_shape), {t.unwrap_input(inputs[0])});
+                return t.wrap_outputs(outputs, FT::NHWC);
+            } else {
+                // will not maintain src's format
+                auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
+                auto outputs = imperative::apply(op, {nchw_src});
+                return t.wrap_outputs(outputs);
+            }
+        } else if (inputs.size() == 2) {
+            auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
+            if (shape.layout().total_nr_elems() == 4) {
+                // output is still NHWC format
+                auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
+                auto outputs = imperative::apply(
+                        op,
+                        SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
+                return t.wrap_outputs(outputs, FT::NHWC);
+            } else {
+                // will not maintain src's format
+                auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
+                auto outputs = imperative::apply(
+                        op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
+                return t.wrap_outputs(outputs);
+            }
         }
     }
     return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
@@ -180,22 +216,38 @@ ValueRefList reshape_rule(
 ValueRefList broadcast_rule(
         const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert,
         const FormatTransformation& t) {
-    mgb_assert(inputs.size() == 2);
+    mgb_assert(inputs.size() >= 1);
     auto& src = inputs[0].cast(t.value_type());
     if (auto_convert && src.format() == FT::NHWC) {
-        auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
-        if (shape.layout().total_nr_elems() == 4) {
-            // output is still NHWC format
-            auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
-            auto outputs = imperative::apply(
-                    op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
-            return t.wrap_outputs(outputs, FT::NHWC);
-        } else {
-            // will not maintain src's format
-            auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
-            auto outputs = imperative::apply(
-                    op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
-            return t.wrap_outputs(outputs);
+        if (inputs.size() == 1) {
+            if (op.shape.size() == 4) {
+                // output is still NHWC format
+                auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
+                auto outputs = imperative::apply(
+                        *Broadcast::make(nhwc_shape), {t.unwrap_input(inputs[0])});
+                return t.wrap_outputs(outputs, FT::NHWC);
+            } else {
+                // will not maintain src's format
+                auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
+                auto outputs = imperative::apply(op, {nchw_src});
+                return t.wrap_outputs(outputs);
+            }
+        } else if (inputs.size() == 2) {
+            auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
+            if (shape.layout().total_nr_elems() == 4) {
+                // output is still NHWC format
+                auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
+                auto outputs = imperative::apply(
+                        op,
+                        SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
+                return t.wrap_outputs(outputs, FT::NHWC);
+            } else {
+                // will not maintain src's format
+                auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
+                auto outputs = imperative::apply(
+                        op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
+                return t.wrap_outputs(outputs);
+            }
         }
     }
     return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
@@ -240,8 +292,7 @@ ValueRefList subtensor_rule(
     // only support NHWC2NCHW convert, otherwise maintain src's format
     if (!(auto_convert && src.format() == FT::NHWC)) {
         return {t.wrap_output(
-                imperative::apply(op, t.unwrap_inputs(inputs))[0],
-                src.format())};
+                imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
     }
     auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
     auto outputs = imperative::apply(
@@ -263,8 +314,7 @@ ValueRefList setsubtensor_rule(
     // only support NHWC2NCHW convert, otherwise maintain src's format
     if (!(auto_convert && src.format() == FT::NHWC)) {
         return {t.wrap_output(
-                imperative::apply(op, t.unwrap_inputs(inputs))[0],
-                src.format())};
+                imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
     }
     // value has been broadcasted to src's fake NCHW shape.
     auto& value = inputs[1].cast(t.value_type());
@@ -329,8 +379,7 @@ ValueRefList identity_rule_helper(
         const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
     // mgb_assert(inputs.size() == 1);
     auto& src = inputs[0].cast(t.value_type());
-    return t.wrap_outputs(
-            imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
+    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
 }
 ValueRefList batchnorm_rule(
@@ -457,6 +506,7 @@ struct FormatRuleRegistry {
 ValueRefList FormatTransformation::apply_transformation(
         const Operator& op, Span<ValueRef> inputs) {
+    //mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
     if (auto* apply_op = op.as<ApplyOp>()) {
         // all inputs should be FormattedTensorValue
         auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
@@ -485,7 +535,7 @@ ValueRefList FormatTransformation::apply_transformation(
             }
             case GetAttr::Value: {
                 auto nchw_src = unwrap_input(to(src, FT::NCHW, ""));
-                return imperative::apply(op, SmallVector<ValueRef>{nchw_src});
+                return imperative::apply(op, {nchw_src});
             }
             default:
                 return imperative::apply(op, unwrap_inputs(inputs));
@@ -508,8 +558,7 @@ ValueRefList FormatTransformation::apply_transformation(
         auto&& inp_ref = inputs[0].as_ref(m_value_type);
         if (inp_ref) {
             auto&& format = inp_ref->format();
-            return wrap_outputs(
-                    imperative::apply(op, unwrap_inputs(inputs)), format);
+            return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
         } else {
             mgb_log_warn(
                     "Not FormattedTensorValue input for IdentityLike op: %s, %s",
@@ -522,6 +571,7 @@ ValueRefList FormatTransformation::apply_transformation(
             auto format = inp_ref->format();
             GenericFunction callback =
                     (GenericFunction&)inputs[1].cast<FunctionValue>();
+            // make param grads as FormattedTensor
             GenericFunction new_callback =
                     [this, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
                 auto wrapped_inputs = SmallVector<ValueRef>{
@@ -531,6 +581,7 @@ ValueRefList FormatTransformation::apply_transformation(
             };
             auto&& outputs = imperative::apply(
                     op, inp_ref->value(), FunctionValue::make(new_callback));
+            // make params(GradValue) as FormattedTensor
            return wrap_outputs(outputs, format);
         } else {
             mgb_log_warn(
@@ -539,6 +590,7 @@ ValueRefList FormatTransformation::apply_transformation(
             return imperative::apply(op, inputs);
         }
     } else if (auto* set_grad = op.as<SetGrad>()) {
+        // make grads in Function backward as FormattedTensor
         size_t nr_inputs = set_grad->nr_inputs();
         size_t nr_outputs = inputs.size() - nr_inputs;
         Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
@@ -377,8 +377,6 @@ public:
     SetGrad(GenericFunction grad_fn, size_t nr_inputs)
             : m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}
-    std::shared_ptr<GradKey> key() const { return m_key; }
     GenericFunction grad_fn() const { return m_grad_fn; }
     size_t nr_inputs() const { return m_nr_inputs; }