From 273c0e874548a361149e73166b70cb963af5c878 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 1 Mar 2022 15:16:53 +0800 Subject: [PATCH] fix(autodiff): fix some bugs in relation to 2nd order grad 1. implement double backward for batchnorm 2. fix grad attach in nested grad manager 3. pad empty tensor for unsatisfied output_has_grad 4. support double backward for jit subgraph 5. support double backward for autodiff.Function 6. readd debug flag MGE_LOG_OP_DISPATCH GitOrigin-RevId: cd31ddc620a35e0582c9721df7290c972fa3c610 --- imperative/python/megengine/core/autodiff/grad.py | 5 +- imperative/python/megengine/core/tensor/utils.py | 41 ++++-- imperative/python/megengine/functional/nn.py | 8 +- imperative/python/src/grad.cpp | 8 +- imperative/python/src/grad.h | 1 + imperative/python/src/tensor.cpp | 70 +++++++---- imperative/python/src/transformation.h | 29 +++-- .../python/test/unit/autodiff/test_grad_manger.py | 2 + imperative/python/test/unit/core/test_autodiff.py | 1 + imperative/python/test/unit/core/test_function.py | 38 ++++++ imperative/python/test/unit/core/test_subgraph.py | 43 +++++++ imperative/src/impl/dispatch.cpp | 42 +++++-- imperative/src/impl/op_def.cpp | 5 +- imperative/src/impl/ops/batch_norm.cpp | 137 ++++++++++++++++++++- imperative/src/impl/ops/utility.cpp | 4 +- imperative/src/impl/subgraph_detail.cpp | 7 +- imperative/src/impl/transformations/eval.cpp | 2 +- imperative/src/impl/transformations/grad.cpp | 83 +++++++++---- .../include/megbrain/imperative/graph_builder.h | 4 +- .../src/include/megbrain/imperative/ops/utility.h | 1 - .../include/megbrain/imperative/subgraph_detail.h | 7 +- .../megbrain/imperative/transformations/grad.h | 21 +++- .../src/include/megbrain/imperative/utils/helper.h | 8 +- src/core/include/megbrain/ir/ops.td | 2 + 24 files changed, 460 insertions(+), 109 deletions(-) diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py index 5f4c011b..a45d24f2 100644 --- a/imperative/python/megengine/core/autodiff/grad.py +++ b/imperative/python/megengine/core/autodiff/grad.py @@ -212,10 +212,7 @@ class Function: if self.__single_output: outputs = (outputs,) - for grad in reversed(group): - if grad._impl is None: - continue - outputs = core2.set_grad(grad._impl, normalized_backward, args, outputs) + outputs = core2.set_grad(normalized_backward, args, outputs) if self.__single_output: (outputs,) = outputs return outputs diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index b816f41a..145396dc 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -209,7 +209,6 @@ def subgraph( outputs = gen.send(None) nr_outputs = len(outputs) forward_fn = build(builder, outputs, [False] * nr_outputs) - output_grads = [builder.input() for _ in range(nr_outputs)] input_grads = gen.send(output_grads) assert len(input_grads) == nr_inputs @@ -222,25 +221,49 @@ def subgraph( ] encoded_input_grads = [grad for grad in input_grads if grad is not None] backward_fn = build( - builder, encoded_input_grads, [False] * len(encoded_input_grads) + builder, encoded_input_grads, [True] * len(encoded_input_grads) ) class SubgraphOp(Function): def __init__(self): self.inputs = None + self.output_shapes = None def forward(self, *inputs): self.inputs = inputs - return apply(forward_fn(), *inputs) + outputs = apply(forward_fn(), *inputs) + if len(outputs) > 1: + self.output_shapes = [output.shape for 
output in outputs] + return outputs def backward(self, *output_grads): inputs = self.inputs - self.inputs = None - encoded_input_grads = apply(backward_fn(), *inputs, *output_grads) - input_grads = [ - encoded_input_grads[i] if i is not None else None - for i in indices - ] + any_valid = False + all_valid = True + for output_grad in output_grads: + if output_grad is None: + all_valid = False + else: + any_valid = True + if not any_valid: + input_grads = [None] * len(indices) + else: + if not all_valid: + assert self.output_shapes is not None + from ...functional import zeros + + output_grads = [ + zeros(self.output_shapes[i]) if grad is None else grad + for i, grad in enumerate(output_grads) + ] + self = None + encoded_input_grads = apply( + backward_fn(), *inputs, *output_grads + ) + input_grads = [ + encoded_input_grads[i] if i is not None else None + for i in indices + ] return input_grads gen.close() diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 346766ee..56a22971 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -896,7 +896,7 @@ def prelu(inp: Tensor, weight: Tensor) -> Tensor: @lru_cache(maxsize=None) -def _get_leagk_relu_op(negative_slope, *, dtype=None, device=None): +def _get_leaky_relu_op(negative_slope, *, dtype=None, device=None): @subgraph_fn( "LeakyReLU", dtype=dtype, @@ -925,7 +925,7 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor: Refer to :class:`~.LeakyReLU` for more information. """ - leakyReLU = _get_leagk_relu_op(negative_slope, dtype=inp.dtype, device=inp.device) + leakyReLU = _get_leaky_relu_op(negative_slope, dtype=inp.dtype, device=inp.device) (oup,) = leakyReLU(inp) return oup @@ -1399,7 +1399,7 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels): f("fma3", input, inv_var_wt, f("+", f("*", neg_channel_mean, inv_var_wt), bias)) - return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False) + return (outvar, channel_mean, channel_var), (True, True, True) @subgraph("SyncBnStage1Inference", dtype, device, 6) def syncbn_stage1_inference(inputs, f, c): @@ -1509,7 +1509,7 @@ def sync_batch_norm( """ _eps_mode = eps_mode.lower() assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode) - if _eps_mode == "additive" and not (is_distributed() and training): + if _eps_mode == "additive" and not (is_distributed() or training): return batch_norm( inp, running_mean, diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index fd9c70e1..6cf41f12 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -121,13 +121,13 @@ void GradKeyWrapper::enter() { m_key = m_transformation->key(); m_key->name(m_name); grad_key_map[m_key] = this; - TransformationManager::get_instance().register_at( - m_transformation); + m_transformation_guard = + TransformationManager::get_instance() + .register_at(m_transformation); } void GradKeyWrapper::exit() { - TransformationManager::get_instance().unregister( - m_transformation); + m_transformation_guard.reset(); grad_key_map.erase(m_key); m_key = {}; m_transformation.reset(); diff --git a/imperative/python/src/grad.h b/imperative/python/src/grad.h index 914b0ced..3b03ade6 100644 --- a/imperative/python/src/grad.h +++ b/imperative/python/src/grad.h @@ -29,6 +29,7 @@ struct GradKeyWrapper : NonCopyableObj { std::string m_name; std::shared_ptr m_key; std::shared_ptr m_transformation; + std::unique_ptr> 
m_transformation_guard; GradKeyWrapper(); diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 75fea2ba..dac67b00 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -449,15 +449,24 @@ void init_tensor(py::module m) { interpreter::Interpreter::inst().create_channel()) ->get(); interpreter_for_py = channel; - transformations.register_at( - std::make_shared( - std::shared_ptr(channel, [](Channel*) {}))); - transformations.register_at( - std::make_shared()); - transformations.register_at( - std::make_shared()); - transformations.register_at( - std::make_shared()); + MGB_MARK_USED_VAR( + transformations + .register_at( + std::make_shared( + std::shared_ptr(channel, [](Channel*) {}))) + .release()); + MGB_MARK_USED_VAR(transformations + .register_at( + std::make_shared()) + .release()); + MGB_MARK_USED_VAR(transformations + .register_at( + std::make_shared()) + .release()); + MGB_MARK_USED_VAR(transformations + .register_at( + std::make_shared()) + .release()); static py::exception py_async_error( m, "AsyncError", PyExc_RuntimeError); @@ -681,6 +690,9 @@ void init_tensor(py::module m) { std::pair> profiler; std::optional trace_result; std::function array_comparator; + std::unique_ptr> tracing_guard; + std::unique_ptr> compiled_guard; + std::unique_ptr> lazy_eval_guard; bool compare_value(ValueRef lhs, ValueRef rhs) { auto lvalue = lhs.cast_ref(); @@ -730,13 +742,16 @@ void init_tensor(py::module m) { std::make_shared(¤t_graph)); } } - transformations.register_at(self.compiled); + compiled_guard = + transformations.register_at(self.compiled); // start execute because InputCallback depends self.compiled->execute(); } else if (self.tracing) { - transformations.register_at(self.tracing); + tracing_guard = + transformations.register_at(self.tracing); if (self.lazy_eval) { - transformations.register_at(self.lazy_eval); + lazy_eval_guard = + transformations.register_at(self.lazy_eval); } } else { mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled"); @@ -746,16 +761,16 @@ void init_tensor(py::module m) { void exit() { auto& self = *this; if (self.tracing) { - transformations.unregister(self.tracing); + tracing_guard.reset(); self.trace_result = self.tracing->get_result(); self.tracing.reset(); if (self.lazy_eval) { auto lazy_eval = std::move(self.lazy_eval); - transformations.unregister(lazy_eval); + lazy_eval_guard.reset(); lazy_eval->check_exception(); } } else if (self.compiled) { - transformations.unregister(self.compiled); + compiled_guard.reset(); self.compiled->wait(); } else { mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled"); @@ -829,17 +844,19 @@ void init_tensor(py::module m) { [](Trace& self) { mgb_assert(bool(self.tracing) ^ bool(self.compiled)); if (self.tracing) { - transformations.unregister(self.tracing); + self.tracing_guard.reset(); } else if (self.compiled) { - transformations.unregister(self.compiled); + self.compiled_guard.reset(); } }) .def("end_excluded_region", [](Trace& self) { mgb_assert(bool(self.tracing) ^ bool(self.compiled)); if (self.tracing) { - transformations.register_at(self.tracing); + self.tracing_guard = + transformations.register_at(self.tracing); } else if (self.compiled) { - transformations.register_at(self.compiled); + self.compiled_guard = + transformations.register_at(self.compiled); } }); @@ -900,11 +917,8 @@ void init_tensor(py::module m) { GradKeyWrapper::get(output.cast()))); }); - m.def("set_grad", [](py::object py_key, py::function 
backward_fn, - std::vector inputs, + m.def("set_grad", [](py::function backward_fn, std::vector inputs, std::vector outputs) { - mgb_assert(GradKeyWrapper::wrap_t::type().isinstance(py_key.ptr())); - auto* key = reinterpret_cast(py_key.ptr())->inst(); GenericFunction generic_backward_fn = [backward_fn](Span output_grads) -> ValueRefList { py::list output_grad_tws; @@ -937,8 +951,8 @@ void init_tensor(py::module m) { values[i + inputs.size()] = outputs[i].cast().m_tensor->data(); } - auto wrapped_output_values = imperative::apply( - SetGrad(key->m_key, generic_backward_fn, inputs.size()), values); + auto wrapped_output_values = + imperative::apply(SetGrad(generic_backward_fn, inputs.size()), values); std::vector wrapped_outputs; mgb_assert(wrapped_output_values.size() == outputs.size()); for (auto&& output_value : wrapped_output_values) { @@ -956,8 +970,10 @@ void init_tensor(py::module m) { mgb_assert(module_trace_hook); module_trace_transformation = std::make_shared(module_trace_hook); - transformations.register_at( - module_trace_transformation); + MGB_MARK_USED_VAR(transformations + .register_at( + module_trace_transformation) + .release()); } return module_trace_transformation; }; diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h index bafc4369..4999deb0 100644 --- a/imperative/python/src/transformation.h +++ b/imperative/python/src/transformation.h @@ -18,11 +18,13 @@ #include "megbrain/imperative/dispatch.h" #include "megbrain/imperative/transformation.h" +#include "megbrain/imperative/utils/helper.h" #include "megbrain/imperative/value.h" #include "megbrain/utils/small_vector.h" namespace mgb::imperative::python { struct TransformationManager { +public: enum Segment { ModuleTrace, DTypePromote, @@ -35,8 +37,21 @@ struct TransformationManager { std::array>, 7> segments; +private: + template + void unregister(std::shared_ptr transformation) noexcept { + mgb_assert(segment < segments.size()); + auto iter = std::find( + segments[segment].begin(), segments[segment].end(), transformation); + mgb_assert(iter != segments[segment].end()); + transformation->unregister(); + segments[segment].erase(iter); + } + +public: template - void register_at(std::shared_ptr transformation) { + [[nodiscard]] std::unique_ptr> register_at( + std::shared_ptr transformation) { mgb_assert(segment < segments.size()); std::shared_ptr next; for (size_t i = segment; i < segments.size(); ++i) { @@ -51,16 +66,8 @@ struct TransformationManager { transformation->register_at(next->pos()); } segments[segment].push_back(transformation); - } - - template - void unregister(std::shared_ptr transformation) noexcept { - mgb_assert(segment < segments.size()); - auto iter = std::find( - segments[segment].begin(), segments[segment].end(), transformation); - mgb_assert(iter != segments[segment].end()); - transformation->unregister(); - segments[segment].erase(iter); + return std::make_unique>( + [this, transformation]() { unregister(transformation); }); } static TransformationManager& get_instance() { diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py index 393d689f..09d0160a 100644 --- a/imperative/python/test/unit/autodiff/test_grad_manger.py +++ b/imperative/python/test/unit/autodiff/test_grad_manger.py @@ -452,6 +452,8 @@ def test_2nd_grad_with_custom_gradient(): return y def backward(self, dy): + if dy is None: + return None dx = -MySin()(self.inp) * dy return dx diff --git 
a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py index 355667af..6d51fe44 100644 --- a/imperative/python/test/unit/core/test_autodiff.py +++ b/imperative/python/test/unit/core/test_autodiff.py @@ -14,6 +14,7 @@ import pytest import megengine as mge import megengine.distributed as dist import megengine.functional as F +import megengine.module as M from megengine.core._imperative_rt import CompNode, TensorAttr, imperative from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync from megengine.core.autodiff.grad import Grad diff --git a/imperative/python/test/unit/core/test_function.py b/imperative/python/test/unit/core/test_function.py index 8d471bff..9d63f222 100644 --- a/imperative/python/test/unit/core/test_function.py +++ b/imperative/python/test/unit/core/test_function.py @@ -318,3 +318,41 @@ def test_throw_on_non_tensor_argument(): func = NonTensorArg() with pytest.raises(TypeError, match=r"op .* expect type Tensor as inputs"): func(x, 1) + + +def test_multiple_grad(): + data_shape = (9, 2, 6) + av = np.random.random(data_shape).astype(np.float32) + + class MulFunc(Function): + def forward(self, a): + self.a = a + return a * 10 + + def backward(self, grad_o): + return grad_o * 20 + + class Simple(Module): + def __init__(self, a): + super().__init__() + self.a = Parameter(a, dtype=np.float32) + self.layer1 = MulFunc() + + def forward(self): + x = self.layer1(self.a) + return x + + net = Simple(av) + gm = ad.GradManager().attach(net.parameters()) + gm2 = ad.GradManager().attach(net.parameters()) + opt = optimizer.SGD(net.parameters(), lr=1.0) + + opt.clear_grad() + with gm: + with gm2: + loss = net() + gm.backward(loss.sum()) + opt.step() + + np.testing.assert_almost_equal(loss.numpy(), (av * 10)) + np.testing.assert_almost_equal(net.a.numpy(), (av - 20)) diff --git a/imperative/python/test/unit/core/test_subgraph.py b/imperative/python/test/unit/core/test_subgraph.py index 81af45cb..10f22eaf 100644 --- a/imperative/python/test/unit/core/test_subgraph.py +++ b/imperative/python/test/unit/core/test_subgraph.py @@ -109,3 +109,46 @@ def test_subgraph(device, batch_size, channels, use_trace, symbolic, gopt_level, _assert_allclose(out1.numpy(), out2.numpy()) _assert_allclose(grad1.numpy(), grad2.numpy()) + + +@functools.lru_cache(maxsize=None) +def _get_mul_fn(dtype, device): + @subgraph_fn( + "Mul", + dtype=dtype, + device=device, + nr_inputs=2, + gopt_level=None, + jit_fusion=False, + custom_grad=True, + ) + def mul(inputs, f, c): + x, y = inputs[0:2] + z = f("*", x, y) + (dz,) = yield (z,) + dx = f("*", dz, y) + dy = f("*", dz, x) + yield (dx, dy) + + return mul + + +def test_subgraph_jit_backward(): + x_np = np.random.rand(3, 4, 5).astype("float32") + x1 = megengine.Tensor(x_np) + x2 = megengine.Tensor(x_np) + mul = _get_mul_fn(x1.dtype, x1.device) + gm = GradManager() + gm.attach([x1, x2]) + with gm: + y1 = x1 * x1 + y2 = mul(x2, x2) + gm.backward(y1) + with gm: + y1 = x1 * x1 + y2 = mul(x2, x2) + gm.backward(y1 + y2) + with gm: + y1 = x1 * x1 + y2 = mul(x2, x2) + gm.backward(y2) diff --git a/imperative/src/impl/dispatch.cpp b/imperative/src/impl/dispatch.cpp index e132a8e1..7f94d314 100644 --- a/imperative/src/impl/dispatch.cpp +++ b/imperative/src/impl/dispatch.cpp @@ -18,18 +18,44 @@ namespace mgb { namespace imperative { +namespace { -ValueRefList apply(const Operator& op, Span inputs) { +ValueRefList apply_release(const Operator& op, Span inputs) { + auto& context = Transformation::get_context(); + size_t& 
depth = context.next_transformation; + mgb_assert(depth < context.transformations.size()); + auto& transformation = *context.transformations[depth++]; + CleanupGuard _{[&] { --depth; }}; + return transformation.apply_transformation(op, inputs); +} + +MGB_NOINLINE ValueRefList apply_debug(const Operator& op, Span inputs) { auto& context = Transformation::get_context(); size_t& depth = context.next_transformation; - // TODO: add fallback transformation - bool fallback = depth >= context.transformations.size(); - if (mgb_unlikely(fallback)) { - return op.fallback(inputs); + mgb_assert(depth < context.transformations.size()); + static const char tabs[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t"; + const char* prefix = tabs + (sizeof(tabs) / sizeof(char)) - depth - 1; + mgb_log_debug( + "%s apply %s to %s", prefix, op.to_string().c_str(), + imperative::to_string(inputs).c_str()); + ValueRefList result; + auto& transformation = *context.transformations[depth++]; + CleanupGuard _{[&] { --depth; }}; + result = transformation.apply_transformation(op, inputs); + mgb_log_debug( + "%s returns %s", prefix, + imperative::to_string(Span(result)).c_str()); + return result; +} + +} // namespace + +ValueRefList apply(const Operator& op, Span inputs) { + static bool debug = MGB_GETENV("MGE_LOG_OP_DISPATCH"); + if (mgb_unlikely(debug)) { + return apply_debug(op, inputs); } else { - auto& transformation = *context.transformations[depth++]; - CleanupGuard _{[&] { --depth; }}; - return transformation.apply_transformation(op, inputs); + return apply_release(op, inputs); } } diff --git a/imperative/src/impl/op_def.cpp b/imperative/src/impl/op_def.cpp index 49348789..4ab971dd 100644 --- a/imperative/src/impl/op_def.cpp +++ b/imperative/src/impl/op_def.cpp @@ -106,7 +106,8 @@ EncodedSubgraph OpDef::make_forward_graph( } std::string OpDef::to_string() const { - std::string builder = trait()->make_name(*this) + "{"; + std::string builder = trait()->name; + builder += "{"; for (auto&& [name, value] : props(*this)) { builder += name; builder += ": "; @@ -196,7 +197,7 @@ std::string Subgraph::repr() const { if (auto* p = op->try_cast_final()) { buf << p->type; } else { - buf << op->make_name(); + buf << op->to_string(); } for (size_t i : ins) { buf << " "; diff --git a/imperative/src/impl/ops/batch_norm.cpp b/imperative/src/impl/ops/batch_norm.cpp index 852b8828..f28a163e 100644 --- a/imperative/src/impl/ops/batch_norm.cpp +++ b/imperative/src/impl/ops/batch_norm.cpp @@ -11,13 +11,94 @@ #include "megbrain/opr/dnn/batch_norm.h" #include "../op_trait.h" +#include "megbrain/imperative/graph_builder.h" #include "megbrain/imperative/ops/autogen.h" +#include "megbrain/imperative/ops/utility.h" +#include "megbrain/imperative/proxy_graph_detail.h" +#include "megbrain/imperative/subgraph_detail.h" +#include "megbrain/tensor.h" namespace mgb { namespace imperative { - namespace { +EncodedSubgraph generate_batchnorm_backward_graph(DType dtype, CompNode device) { + Subgraph::Builder builder{ + [](std::shared_ptr op, SmallVector inputs, + size_t nr_outputs) { + auto [outputs, validated] = + OpDef::infer_output_attrs_fallible(*op, inputs); + mgb_assert(outputs.size() == nr_outputs, "nr_outputs mismatch"); + return outputs; + }}; + auto f = [&](auto&& op, auto... 
args) { + return builder.write_expr( + op, Subgraph::vars_t({(Subgraph::var_t)args...}), 1)[0]; + }; + + auto prod = Reduce::make(megdnn::param::Reduce(Reduce::Mode::PRODUCT, 0)); + auto sum = Reduce::make(megdnn::param::Reduce(Reduce::Mode::SUM)); + auto sub = Elemwise::make(Elemwise::Mode::SUB); + auto mul = Elemwise::make(Elemwise::Mode::MUL); + auto div = Elemwise::make(Elemwise::Mode::TRUE_DIV); + auto floor_div = Elemwise::make(Elemwise::Mode::FLOOR_DIV); + auto broadcast = Broadcast::make(); + + auto c = [&](TensorPtr tensor, DType dtype) { + auto result = builder.write_constant( + tensor, {TensorLayout{tensor->dtype()}, tensor->comp_node()}); + if (tensor->dtype() != dtype) { + result = f(TypeCvt::make(dtype), result); + } + return result; + }; + auto ci = [&](megdnn::dt_int32 value) { + return c(Tensor::make_scalar(DTypeScalar(value), device), dtype::Int32()); + }; + auto cf = [&](megdnn::dt_float32 value) { + return c(Tensor::make_scalar(DTypeScalar(value), device), dtype); + }; + + auto desc = LogicalTensorDesc{TensorLayout{dtype}, device}; + auto x = builder.write_input(desc); + auto y_grad = builder.write_input(desc); + auto save_mean = builder.write_input(desc); + auto save_invstd = builder.write_input(desc); + auto weight = builder.write_input(desc); + auto reserved = builder.write_input(desc); + MGB_MARK_USED_VAR(reserved); + + // assert x.ndim == 4 + auto input_shape = f(GetVarShape::make(), x); + auto channels = f(GetVarShape::make(1), x); + auto reduce_shape = f(Concat::make(0, device), ci(1), channels, ci(1), ci(1)); + auto input_elems = f(prod, input_shape); + auto reduce_size = f(floor_div, input_elems, channels); + auto reduce_size_f = f(TypeCvt::make(dtype), reduce_size); + auto mean = f(broadcast, save_mean, input_shape); + auto invstd = save_invstd; + auto norm = f(div, cf(1), reduce_size_f); + auto output_grad_sum = f(sum, y_grad, reduce_shape); + auto dot_p = f(sum, f(mul, y_grad, f(sub, x, mean)), reduce_shape); + auto mean_grad = f(broadcast, f(mul, output_grad_sum, norm), input_shape); + auto proj_scale = + f(broadcast, f(mul, f(mul, dot_p, norm), f(mul, invstd, invstd)), + input_shape); + auto grad_scale = f( + mul, f(broadcast, invstd, input_shape), f(broadcast, weight, input_shape)); + auto proj = f(mul, f(sub, x, mean), proj_scale); + auto x_grad = f(mul, f(sub, f(sub, y_grad, proj), mean_grad), grad_scale); + auto weight_grad = f(mul, dot_p, invstd); + auto bias_grad = output_grad_sum; + + builder.add_outputs({weight_grad, bias_grad, x_grad}); + + auto bn_backward = builder.encode(); + return bn_backward; +} + +namespace bn { + std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { auto* node = &node_->cast_final_safe(); return BatchNorm::make(node->param()); @@ -72,8 +153,60 @@ OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm) .apply_on_var_node(apply_on_var_node) .infer_output_attrs_fallible(infer_output_attrs_fallible) .fallback(); -} // anonymous namespace +} // namespace bn + +namespace bn_backward { + +std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { + auto* node = &node_->cast_final_safe(); + return BatchNormBackward::make(node->param()); +} + +VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto& op = def.cast_final_safe(); + cg::SymbolVar x, y_grad, save_mean, save_variance, weight, reserve; + x = inputs[0]; + y_grad = inputs[1]; + save_mean = inputs[2]; + save_variance = inputs[3]; + weight = inputs[4]; + if (inputs.size() == 6) { + reserve = inputs[5]; + } + return 
opr::BatchNormBackward::make( + x, y_grad, save_mean, save_variance, weight, reserve, op.param())[0] + .node() + ->owner_opr() + ->usable_output(); +} + +EncodedSubgraph make_backward_graph( + const OpDef& def, const SmallVector& inputs, + const SmallVector& input_requires_grad, + const SmallVector& output_has_grad) { + def.cast_final_safe(); + size_t nr_inputs = 6; + size_t nr_outputs = 3; + mgb_assert(inputs.size() == nr_inputs); + mgb_assert(input_requires_grad.size() == nr_inputs); + mgb_assert(output_has_grad.size() == nr_outputs); + auto dtype = inputs[0].layout.dtype; + auto device = inputs[0].comp_node; + auto bn_backward = generate_batchnorm_backward_graph(dtype, device); + auto bn_double_backward = subgraph_detail::make_backward_graph_from_forward( + bn_backward, inputs, input_requires_grad, output_has_grad); + return bn_double_backward; +} + +OP_TRAIT_REG(BatchNormBackward, BatchNormBackward, opr::BatchNormBackward) + .make_from_op_node(make_from_op_node) + .apply_on_var_node(apply_on_var_node) + .make_backward_graph(make_backward_graph) + .fallback(); +} // namespace bn_backward + +} // anonymous namespace } // namespace imperative } // namespace mgb diff --git a/imperative/src/impl/ops/utility.cpp b/imperative/src/impl/ops/utility.cpp index 37d4cfaf..870444e4 100644 --- a/imperative/src/impl/ops/utility.cpp +++ b/imperative/src/impl/ops/utility.cpp @@ -762,7 +762,9 @@ EncodedSubgraph make_backward_graph( const OpDef& def, const SmallVector& inputs, const SmallVector& input_requires_grad, const SmallVector& output_has_grad) { - return {}; + return OpDef::make_backward_graph( + *def.cast_final_safe().op, inputs, input_requires_grad, + output_has_grad); } OP_TRAIT_REG(JITFusionOp, JITFusionOp) diff --git a/imperative/src/impl/subgraph_detail.cpp b/imperative/src/impl/subgraph_detail.cpp index ecf00447..724f6cd6 100644 --- a/imperative/src/impl/subgraph_detail.cpp +++ b/imperative/src/impl/subgraph_detail.cpp @@ -96,10 +96,11 @@ SmallVector get_input_layout_constraint( return res; } -static EncodedSubgraph make_backward_graph_from_forward( +EncodedSubgraph make_backward_graph_from_forward( + const EncodedSubgraph& forward_graph, const SmallVector& inputs, const SmallVector& input_requires_grad, - const SmallVector& output_has_grad, EncodedSubgraph forward_graph) { + const SmallVector& output_has_grad) { using namespace std::placeholders; using var_t = Subgraph::var_t; using vars_t = Subgraph::vars_t; @@ -179,7 +180,7 @@ EncodedSubgraph make_backward_graph( const SmallVector& output_has_grad) { auto forward_graph = OpDef::make_forward_graph(def, inputs); return make_backward_graph_from_forward( - inputs, input_requires_grad, output_has_grad, forward_graph); + forward_graph, inputs, input_requires_grad, output_has_grad); } } // namespace subgraph_detail diff --git a/imperative/src/impl/transformations/eval.cpp b/imperative/src/impl/transformations/eval.cpp index 5e357e75..d0478921 100644 --- a/imperative/src/impl/transformations/eval.cpp +++ b/imperative/src/impl/transformations/eval.cpp @@ -139,7 +139,7 @@ ValueRefList InterpreterTransformation::apply_transformation( return {ValueRef()}; } } else { - return imperative::apply(op, inputs); + return op.fallback(inputs); } } diff --git a/imperative/src/impl/transformations/grad.cpp b/imperative/src/impl/transformations/grad.cpp index 49b386ed..b7660236 100644 --- a/imperative/src/impl/transformations/grad.cpp +++ b/imperative/src/impl/transformations/grad.cpp @@ -62,7 +62,8 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( 
std::shared_ptr op, Span inputs, Span outputs) : backward_graph(backward_graph), output_mask_offset(inputs.size()), - grad_mask_offset(inputs.size() + outputs.size()) { + grad_mask_offset(inputs.size() + outputs.size()), + op(op) { auto& save_for_backward = backward_graph->save_for_backward; mgb_assert(save_for_backward.size() == inputs.size() + 2 * outputs.size()); size_t count = std::count_if( @@ -92,6 +93,13 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( closure.push_back(outputs[i]); } } + if (outputs.size() > 1) { + output_descs.reserve(outputs.size()); + for (auto&& output : outputs) { + auto symbolic_shape = imperative::apply(*GetVarShape::make(), output)[0]; + output_descs.push_back({symbolic_shape, output.dtype(), output.device()}); + } + } } void BackwardGraphWithClosure::operator()( Span grads, std::function receiver) { @@ -100,23 +108,46 @@ void BackwardGraphWithClosure::operator()( for (auto&& value : closure) { args[nargs++] = value; } - bool null_grad = false; + size_t null_grad = 0; + size_t valid_grad = 0; for (size_t i = 0; i < grads.size(); ++i) { if (backward_graph->save_for_backward[grad_mask_offset + i]) { if (grads[i]) { - mgb_assert(!null_grad, "null_grad"); + valid_grad++; args[nargs++] = grads[i]; } else { - null_grad = true; + null_grad++; + nargs++; } } } - if (null_grad) { + if (valid_grad == 0) { return; } - auto igrads_ = imperative::apply(backward_graph->backward, Span(args, nargs)); - SmallVector igrads = {igrads_.begin(), igrads_.end()}; - igrads_.clear(); + if (null_grad > 0) { + auto zeros_like = [](const OutputDesc& desc) { + HostTensorStorage storage(*desc.device); + storage.ensure_size(desc.dtype->size()); + std::memset(storage.ptr(), 0, desc.dtype->size()); + auto t = imperative::apply( + CreateTensor( + CreateTensor::Unique, *desc.device, *desc.dtype, + ValueShape()), + HostStorage::make(storage))[0]; + auto res = imperative::apply(*Broadcast::make(), t, desc.shape)[0]; + return res; + }; + nargs = closure.size(); + for (size_t i = 0; i < grads.size(); ++i) { + if (backward_graph->save_for_backward[grad_mask_offset + i]) { + if (!grads[i]) { + args[nargs] = zeros_like(output_descs[i]); + } + nargs++; + } + } + } + auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs)); auto&& iter = igrads.begin(); for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) { if (p) { @@ -221,9 +252,11 @@ void GradKey::backward() { if (!dest) { continue; } - if (!dest.m_producer_record.next && dest->callback && dest->m_grad) { + if (!dest.m_producer_record.next && dest->callback) { // I'm the last grad producer, invoke callback - dest->callback(dest->m_grad); + if (dest->m_grad) { + dest->callback(dest->m_grad); + } } } grad_fn->clear(); @@ -394,16 +427,22 @@ ValueRefList GradTransformation::apply_transformation( return imperative::apply(op, inputs); } if (auto* attach_grad = op.as()) { - if (!has_key(attach_grad->key())) { + auto& tensor = inputs[0]; + if (auto&& grad_value = tensor.as_ref(m_value_type)) { + mgb_assert(!has_key(attach_grad->key())); + auto output = fallback()[0]; + return record_grad(m_value_type.make(output, m_key, grad_value->slot())); + } else if (!has_key(attach_grad->key())) { return fallback(); + } else { + GenericFunction callback = + (GenericFunction&)inputs[1].cast(); + auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) { + auto ret = callback({&grad, 1}); + assert(ret.empty()); + }); + return {record_grad(output)}; } - auto tensor = inputs[0]; - GenericFunction 
callback = (GenericFunction&)inputs[1].cast(); - auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) { - auto ret = callback({&grad, 1}); - assert(ret.empty()); - }); - return {record_grad(output)}; } else if (auto* grad_backward = op.as()) { if (!has_key(grad_backward->key())) { return fallback(); @@ -431,10 +470,10 @@ ValueRefList GradTransformation::apply_transformation( mgb_assert(inputs.size() > nr_inputs); size_t nr_outputs = inputs.size() - nr_inputs; Span inputs_ = {inputs.data(), nr_inputs}; - Span outputs_ = {inputs.data() + nr_inputs, nr_outputs}; - backward.m_input_has_grad = SmallVector(nr_inputs, true); - backward.m_output_attrs = - SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true}); + auto outputs_ = fallback(); + backward.m_input_has_grad.resize(nr_inputs, true); + backward.m_output_attrs.resize( + nr_outputs, CustomBackward::OutputAttr{true, true}); backward.m_backward = [fn = set_grad->grad_fn()](Span inputs) { auto result = fn(inputs); return SmallVector(result.begin(), result.end()); diff --git a/imperative/src/include/megbrain/imperative/graph_builder.h b/imperative/src/include/megbrain/imperative/graph_builder.h index 048865b0..2c5ce742 100644 --- a/imperative/src/include/megbrain/imperative/graph_builder.h +++ b/imperative/src/include/megbrain/imperative/graph_builder.h @@ -31,6 +31,7 @@ class Subgraph::Builder { using infer_fn_t = std::function; using encoded_graph_t = EncodedSubgraph; using var_map_t = std::unordered_map; + using mask_t = SmallVector; vars_t m_inputs; SmallVector> m_constants; vars_t m_outputs; @@ -94,6 +95,7 @@ public: descs_t get_descs(vars_t vars) { descs_t descs; for (auto&& var : vars) { + mgb_assert(var, "invalid var"); descs.push_back(get_desc(var)); } return descs; @@ -128,4 +130,4 @@ public: expr_iter_t end() { return m_exprs.end(); } }; } // namespace imperative -} // namespace mgb \ No newline at end of file +} // namespace mgb diff --git a/imperative/src/include/megbrain/imperative/ops/utility.h b/imperative/src/include/megbrain/imperative/ops/utility.h index 79dd970f..c182945d 100644 --- a/imperative/src/include/megbrain/imperative/ops/utility.h +++ b/imperative/src/include/megbrain/imperative/ops/utility.h @@ -38,7 +38,6 @@ struct ShapeInfer final : OpDefImplBase { std::shared_ptr op; SmallVector devices; SmallVector dtypes; - EncodedSubgraph graph; ShapeInfer() = default; ShapeInfer( std::shared_ptr op, SmallVector devices, diff --git a/imperative/src/include/megbrain/imperative/subgraph_detail.h b/imperative/src/include/megbrain/imperative/subgraph_detail.h index d180fb39..e55f19ad 100644 --- a/imperative/src/include/megbrain/imperative/subgraph_detail.h +++ b/imperative/src/include/megbrain/imperative/subgraph_detail.h @@ -39,6 +39,11 @@ EncodedSubgraph make_backward_graph( SmallVector get_input_layout_constraint( const OpDef& def, const SmallVector& inputs); +EncodedSubgraph make_backward_graph_from_forward( + const EncodedSubgraph& forward, const SmallVector& inputs, + const SmallVector& input_requires_grad, + const SmallVector& output_has_grad); + } // namespace subgraph_detail } // namespace imperative -} // namespace mgb \ No newline at end of file +} // namespace mgb diff --git a/imperative/src/include/megbrain/imperative/transformations/grad.h b/imperative/src/include/megbrain/imperative/transformations/grad.h index 0a1970ae..7c5cec9d 100644 --- a/imperative/src/include/megbrain/imperative/transformations/grad.h +++ b/imperative/src/include/megbrain/imperative/transformations/grad.h @@ 
-29,6 +29,15 @@ struct BackwardGraphWithClosure { SmallVector closure; size_t output_mask_offset; size_t grad_mask_offset; + std::shared_ptr op; + + struct OutputDesc { + ValueRef shape; + DTypeValue::ref_t dtype; + CompNodeValue::ref_t device; + }; + + SmallVector output_descs; BackwardGraphWithClosure( std::shared_ptr backward_graph, @@ -356,20 +365,22 @@ public: class SetGrad : public OperatorImpl { private: - std::shared_ptr m_key; GenericFunction m_grad_fn; size_t m_nr_inputs; public: - SetGrad(std::shared_ptr key, GenericFunction grad_fn, size_t nr_inputs) - : m_key(key), m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {} + SetGrad(GenericFunction grad_fn, size_t nr_inputs) + : m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {} GenericFunction grad_fn() const { return m_grad_fn; } size_t nr_inputs() const { return m_nr_inputs; } - std::string to_string() const override { - return ssprintf("SetGradValue{key=%s}", m_key->name().c_str()); + std::string to_string() const override { return ssprintf("SetGradValue{}"); } + + ValueRefList fallback(Span inputs) const override { + auto outputs = inputs.sub(m_nr_inputs, inputs.size() - m_nr_inputs); + return {outputs.begin(), outputs.end()}; } }; diff --git a/imperative/src/include/megbrain/imperative/utils/helper.h b/imperative/src/include/megbrain/imperative/utils/helper.h index 36c325e7..142a7e7c 100644 --- a/imperative/src/include/megbrain/imperative/utils/helper.h +++ b/imperative/src/include/megbrain/imperative/utils/helper.h @@ -15,12 +15,14 @@ #include #include +#include "megbrain/utils/metahelper.h" + namespace mgb { namespace imperative { -template -class CleanupGuard { +template > +class CleanupGuard : public NonCopyableObj { private: T m_callback; @@ -37,4 +39,4 @@ inline std::string quoted(std::string str) { } // namespace imperative -} // namespace mgb \ No newline at end of file +} // namespace mgb diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index b244e1a3..c1cccb38 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -89,6 +89,8 @@ def SlidingWindowTranspose : MgbHashableOp<"SlidingWindowTranspose", [SlidingWin def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>; +def BatchNormBackward : MgbHashableOp<"BatchNormBackward", [BNParam]>; + def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>; def Correlation: MgbHashableOp<"Correlation", [CorrelationParam]>;
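
Note: the following is a minimal usage sketch, not code added by this patch. It shows the nested-GradManager pattern that items 2 and 5 of the commit message (grad attach in a nested grad manager, double backward for autodiff.Function) are meant to support, modeled on test_2nd_grad_with_custom_gradient and test_multiple_grad above. The custom Function MySin, the input shape, and the use of F.sin/F.cos are illustrative assumptions, not part of the change.

import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager
from megengine.core.autodiff.grad import Function


class MySin(Function):
    def forward(self, x):
        self.x = x
        return F.sin(x)

    def backward(self, dy):
        # With this patch, backward may be invoked with dy=None when the
        # corresponding output receives no gradient.
        if dy is None:
            return None
        # F.cos is itself differentiable, so an enclosing GradManager can
        # record this backward pass and differentiate it a second time.
        return F.cos(self.x) * dy


x = mge.Tensor(np.random.rand(8).astype("float32"))
gm_outer = GradManager().attach([x])  # differentiates the first-order gradient
gm_inner = GradManager().attach([x])  # produces the first-order gradient

with gm_outer:
    with gm_inner:
        y = MySin()(x)
        gm_inner.backward(y)   # x.grad now holds the first-order gradient cos(x)
    gm_outer.backward(x.grad)  # second-order pass; -sin(x) is accumulated into x.grad

Item 3 of the message (pad empty tensor for unsatisfied output_has_grad) is what allows such backward functions to run even when only some output gradients are provided, as seen in SubgraphOp.backward and BackwardGraphWithClosure above.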