1. implement double backward for batchnorm
2. fix grad attach in nested grad manager
3. pad empty tensor for unsatisfied output_has_grad
4. support double backward for jit subgraph
5. support double backward for autodiff.Function
6. re-add debug flag MGE_LOG_OP_DISPATCH
GitOrigin-RevId: cd31ddc620
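Taken together, items 1-5 make second-order gradients work through BatchNorm, JIT-fused
subgraphs, and custom autodiff.Function ops, including when two grad managers are nested.
A minimal sketch of the usage pattern these changes target, modeled on the tests further
down; names and shapes are illustrative, and it assumes the outer manager records the
inner manager's backward pass:

    import numpy as np
    import megengine
    import megengine.autodiff as ad

    x = megengine.Tensor(np.random.rand(4).astype("float32"))
    gm_outer = ad.GradManager().attach([x])
    gm_inner = ad.GradManager().attach([x])
    with gm_outer:
        with gm_inner:
            y = (x * x * x).sum()       # dy/dx = 3*x**2, d2y/dx2 = 6*x
            gm_inner.backward(y)        # first-order backward, recorded by gm_outer
        gm_outer.backward(x.grad.sum())  # differentiate through the first backward
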
@@ -212,10 +212,7 @@ class Function:
        if self.__single_output:
            outputs = (outputs,)
        for grad in reversed(group):
            if grad._impl is None:
                continue
            outputs = core2.set_grad(grad._impl, normalized_backward, args, outputs)
        outputs = core2.set_grad(normalized_backward, args, outputs)
        if self.__single_output:
            (outputs,) = outputs
        return outputs

@@ -209,7 +209,6 @@ def subgraph(
        outputs = gen.send(None)
        nr_outputs = len(outputs)
        forward_fn = build(builder, outputs, [False] * nr_outputs)
        output_grads = [builder.input() for _ in range(nr_outputs)]
        input_grads = gen.send(output_grads)
        assert len(input_grads) == nr_inputs

@@ -222,25 +221,49 @@ def subgraph(
        ]
        encoded_input_grads = [grad for grad in input_grads if grad is not None]
        backward_fn = build(
            builder, encoded_input_grads, [False] * len(encoded_input_grads)
            builder, encoded_input_grads, [True] * len(encoded_input_grads)
        )

        class SubgraphOp(Function):
            def __init__(self):
                self.inputs = None
                self.output_shapes = None

            def forward(self, *inputs):
                self.inputs = inputs
                return apply(forward_fn(), *inputs)
                outputs = apply(forward_fn(), *inputs)
                if len(outputs) > 1:
                    self.output_shapes = [output.shape for output in outputs]
                return outputs

            def backward(self, *output_grads):
                inputs = self.inputs
                self.inputs = None
                encoded_input_grads = apply(backward_fn(), *inputs, *output_grads)
                input_grads = [
                    encoded_input_grads[i] if i is not None else None
                    for i in indices
                ]
                any_valid = False
                all_valid = True
                for output_grad in output_grads:
                    if output_grad is None:
                        all_valid = False
                    else:
                        any_valid = True
                if not any_valid:
                    input_grads = [None] * len(indices)
                else:
                    if not all_valid:
                        assert self.output_shapes is not None
                        from ...functional import zeros

                        output_grads = [
                            zeros(self.output_shapes[i]) if grad is None else grad
                            for i, grad in enumerate(output_grads)
                        ]
                    self = None
                    encoded_input_grads = apply(
                        backward_fn(), *inputs, *output_grads
                    )
                    input_grads = [
                        encoded_input_grads[i] if i is not None else None
                        for i in indices
                    ]
                return input_grads

        gen.close()
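
The padding logic in backward() above, in miniature: when some but not all output grads
are present, the missing ones become zero tensors of the recorded output shapes, and when
none are present the backward is skipped outright. A standalone sketch with numpy standing
in for MegEngine tensors:

    import numpy as np

    def pad_output_grads(output_grads, output_shapes):
        # No output received a gradient: skip backward entirely.
        if all(g is None for g in output_grads):
            return None
        # Otherwise replace each missing gradient with zeros of the matching
        # shape, so the backward graph always sees a full set of inputs.
        return [
            np.zeros(shape, dtype="float32") if g is None else g
            for g, shape in zip(output_grads, output_shapes)
        ]
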
@@ -896,7 +896,7 @@ def prelu(inp: Tensor, weight: Tensor) -> Tensor:
@lru_cache(maxsize=None)
def _get_leagk_relu_op(negative_slope, *, dtype=None, device=None):
def _get_leaky_relu_op(negative_slope, *, dtype=None, device=None):
    @subgraph_fn(
        "LeakyReLU",
        dtype=dtype,

@@ -925,7 +925,7 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
    Refer to :class:`~.LeakyReLU` for more information.
    """
    leakyReLU = _get_leagk_relu_op(negative_slope, dtype=inp.dtype, device=inp.device)
    leakyReLU = _get_leaky_relu_op(negative_slope, dtype=inp.dtype, device=inp.device)
    (oup,) = leakyReLU(inp)
    return oup

@@ -1399,7 +1399,7 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
            f("fma3", input, inv_var_wt,
              f("+", f("*", neg_channel_mean, inv_var_wt),
                bias))
        return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False)
        return (outvar, channel_mean, channel_var), (True, True, True)

    @subgraph("SyncBnStage1Inference", dtype, device, 6)
    def syncbn_stage1_inference(inputs, f, c):

@@ -1509,7 +1509,7 @@ def sync_batch_norm(
    """
    _eps_mode = eps_mode.lower()
    assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode)
    if _eps_mode == "additive" and not (is_distributed() and training):
    if _eps_mode == "additive" and not (is_distributed() or training):
        return batch_norm(
            inp,
            running_mean,
@@ -121,13 +121,13 @@ void GradKeyWrapper::enter() {
    m_key = m_transformation->key();
    m_key->name(m_name);
    grad_key_map[m_key] = this;
    TransformationManager::get_instance().register_at<TransformationManager::Grad>(
            m_transformation);
    m_transformation_guard =
            TransformationManager::get_instance()
                    .register_at<TransformationManager::Grad>(m_transformation);
}

void GradKeyWrapper::exit() {
    TransformationManager::get_instance().unregister<TransformationManager::Grad>(
            m_transformation);
    m_transformation_guard.reset();
    grad_key_map.erase(m_key);
    m_key = {};
    m_transformation.reset();

@@ -29,6 +29,7 @@ struct GradKeyWrapper : NonCopyableObj {
    std::string m_name;
    std::shared_ptr<GradKey> m_key;
    std::shared_ptr<GradTransformation> m_transformation;
    std::unique_ptr<CleanupGuard<>> m_transformation_guard;

    GradKeyWrapper();

@@ -449,15 +449,24 @@ void init_tensor(py::module m) {
                         interpreter::Interpreter::inst().create_channel())
                         ->get();
    interpreter_for_py = channel;
    transformations.register_at<Segment::Eval>(
            std::make_shared<InterpreterTransformation>(
                    std::shared_ptr<Channel>(channel, [](Channel*) {})));
    transformations.register_at<Segment::Scalar>(
            std::make_shared<ScalarTransformation>());
    transformations.register_at<Segment::DTypePromote>(
            std::make_shared<DTypePromoteTransformation>());
    transformations.register_at<Segment::DimExpansion>(
            std::make_shared<DimExpansionTransformation>());
    MGB_MARK_USED_VAR(
            transformations
                    .register_at<Segment::Eval>(
                            std::make_shared<InterpreterTransformation>(
                                    std::shared_ptr<Channel>(channel, [](Channel*) {})))
                    .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::Scalar>(
                                      std::make_shared<ScalarTransformation>())
                              .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::DTypePromote>(
                                      std::make_shared<DTypePromoteTransformation>())
                              .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::DimExpansion>(
                                      std::make_shared<DimExpansionTransformation>())
                              .release());

    static py::exception<interpreter::AsyncError> py_async_error(
            m, "AsyncError", PyExc_RuntimeError);

@@ -681,6 +690,9 @@ void init_tensor(py::module m) {
        std::pair<size_t, std::shared_ptr<GraphProfiler>> profiler;
        std::optional<TraceResult> trace_result;
        std::function<bool(py::object, py::object)> array_comparator;
        std::unique_ptr<CleanupGuard<>> tracing_guard;
        std::unique_ptr<CleanupGuard<>> compiled_guard;
        std::unique_ptr<CleanupGuard<>> lazy_eval_guard;

        bool compare_value(ValueRef lhs, ValueRef rhs) {
            auto lvalue = lhs.cast_ref<HostValue>();

@@ -730,13 +742,16 @@ void init_tensor(py::module m) {
                            std::make_shared<GraphProfiler>(&current_graph));
                }
            }
            transformations.register_at<Segment::Trace>(self.compiled);
            compiled_guard =
                    transformations.register_at<Segment::Trace>(self.compiled);
            // start execute because InputCallback depends
            self.compiled->execute();
        } else if (self.tracing) {
            transformations.register_at<Segment::Trace>(self.tracing);
            tracing_guard =
                    transformations.register_at<Segment::Trace>(self.tracing);
            if (self.lazy_eval) {
                transformations.register_at<Segment::Eval>(self.lazy_eval);
                lazy_eval_guard =
                        transformations.register_at<Segment::Eval>(self.lazy_eval);
            }
        } else {
            mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");

@@ -746,16 +761,16 @@ void init_tensor(py::module m) {
        void exit() {
            auto& self = *this;
            if (self.tracing) {
                transformations.unregister<Segment::Trace>(self.tracing);
                tracing_guard.reset();
                self.trace_result = self.tracing->get_result();
                self.tracing.reset();
                if (self.lazy_eval) {
                    auto lazy_eval = std::move(self.lazy_eval);
                    transformations.unregister<Segment::Eval>(lazy_eval);
                    lazy_eval_guard.reset();
                    lazy_eval->check_exception();
                }
            } else if (self.compiled) {
                transformations.unregister<Segment::Trace>(self.compiled);
                compiled_guard.reset();
                self.compiled->wait();
            } else {
                mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");

@@ -829,17 +844,19 @@ void init_tensor(py::module m) {
                    [](Trace& self) {
                        mgb_assert(bool(self.tracing) ^ bool(self.compiled));
                        if (self.tracing) {
                            transformations.unregister<Segment::Trace>(self.tracing);
                            self.tracing_guard.reset();
                        } else if (self.compiled) {
                            transformations.unregister<Segment::Trace>(self.compiled);
                            self.compiled_guard.reset();
                        }
                    })
            .def("end_excluded_region", [](Trace& self) {
                mgb_assert(bool(self.tracing) ^ bool(self.compiled));
                if (self.tracing) {
                    transformations.register_at<Segment::Trace>(self.tracing);
                    self.tracing_guard =
                            transformations.register_at<Segment::Trace>(self.tracing);
                } else if (self.compiled) {
                    transformations.register_at<Segment::Trace>(self.compiled);
                    self.compiled_guard =
                            transformations.register_at<Segment::Trace>(self.compiled);
                }
            });

@@ -900,11 +917,8 @@ void init_tensor(py::module m) {
                GradKeyWrapper::get(output.cast<GradKeyValue>())));
    });
    m.def("set_grad", [](py::object py_key, py::function backward_fn,
                         std::vector<py::object> inputs,
    m.def("set_grad", [](py::function backward_fn, std::vector<py::object> inputs,
                         std::vector<py::object> outputs) {
        mgb_assert(GradKeyWrapper::wrap_t::type().isinstance(py_key.ptr()));
        auto* key = reinterpret_cast<GradKeyWrapper::wrap_t*>(py_key.ptr())->inst();
        GenericFunction generic_backward_fn =
                [backward_fn](Span<ValueRef> output_grads) -> ValueRefList {
            py::list output_grad_tws;

@@ -937,8 +951,8 @@ void init_tensor(py::module m) {
            values[i + inputs.size()] =
                    outputs[i].cast<TensorWrapper>().m_tensor->data();
        }
        auto wrapped_output_values = imperative::apply(
                SetGrad(key->m_key, generic_backward_fn, inputs.size()), values);
        auto wrapped_output_values =
                imperative::apply(SetGrad(generic_backward_fn, inputs.size()), values);
        std::vector<py::object> wrapped_outputs;
        mgb_assert(wrapped_output_values.size() == outputs.size());
        for (auto&& output_value : wrapped_output_values) {

@@ -956,8 +970,10 @@ void init_tensor(py::module m) {
            mgb_assert(module_trace_hook);
            module_trace_transformation =
                    std::make_shared<ModuleTraceTransformation>(module_trace_hook);
            transformations.register_at<Segment::ModuleTrace>(
                    module_trace_transformation);
            MGB_MARK_USED_VAR(transformations
                                      .register_at<Segment::ModuleTrace>(
                                              module_trace_transformation)
                                      .release());
        }
        return module_trace_transformation;
    };
@@ -18,11 +18,13 @@

#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/value.h"
#include "megbrain/utils/small_vector.h"

namespace mgb::imperative::python {

struct TransformationManager {
public:
    enum Segment {
        ModuleTrace,
        DTypePromote,

@@ -35,8 +37,21 @@ struct TransformationManager {
    std::array<std::vector<std::shared_ptr<Transformation>>, 7> segments;

private:
    template <Segment segment>
    void unregister(std::shared_ptr<Transformation> transformation) noexcept {
        mgb_assert(segment < segments.size());
        auto iter = std::find(
                segments[segment].begin(), segments[segment].end(), transformation);
        mgb_assert(iter != segments[segment].end());
        transformation->unregister();
        segments[segment].erase(iter);
    }

public:
    template <Segment segment>
    void register_at(std::shared_ptr<Transformation> transformation) {
    [[nodiscard]] std::unique_ptr<CleanupGuard<>> register_at(
            std::shared_ptr<Transformation> transformation) {
        mgb_assert(segment < segments.size());
        std::shared_ptr<Transformation> next;
        for (size_t i = segment; i < segments.size(); ++i) {

@@ -51,16 +66,8 @@ struct TransformationManager {
            transformation->register_at(next->pos());
        }
        segments[segment].push_back(transformation);
    }

    template <Segment segment>
    void unregister(std::shared_ptr<Transformation> transformation) noexcept {
        mgb_assert(segment < segments.size());
        auto iter = std::find(
                segments[segment].begin(), segments[segment].end(), transformation);
        mgb_assert(iter != segments[segment].end());
        transformation->unregister();
        segments[segment].erase(iter);
        return std::make_unique<CleanupGuard<>>(
                [this, transformation]() { unregister<segment>(transformation); });
    }

    static TransformationManager& get_instance() {
@@ -452,6 +452,8 @@ def test_2nd_grad_with_custom_gradient():
            return y

        def backward(self, dy):
            if dy is None:
                return None
            dx = -MySin()(self.inp) * dy
            return dx

@@ -14,6 +14,7 @@ import pytest

import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
import megengine.module as M
from megengine.core._imperative_rt import CompNode, TensorAttr, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
from megengine.core.autodiff.grad import Grad

@@ -318,3 +318,41 @@ def test_throw_on_non_tensor_argument():
    func = NonTensorArg()
    with pytest.raises(TypeError, match=r"op .* expect type Tensor as inputs"):
        func(x, 1)


def test_multiple_grad():
    data_shape = (9, 2, 6)
    av = np.random.random(data_shape).astype(np.float32)

    class MulFunc(Function):
        def forward(self, a):
            self.a = a
            return a * 10

        def backward(self, grad_o):
            return grad_o * 20

    class Simple(Module):
        def __init__(self, a):
            super().__init__()
            self.a = Parameter(a, dtype=np.float32)
            self.layer1 = MulFunc()

        def forward(self):
            x = self.layer1(self.a)
            return x

    net = Simple(av)
    gm = ad.GradManager().attach(net.parameters())
    gm2 = ad.GradManager().attach(net.parameters())
    opt = optimizer.SGD(net.parameters(), lr=1.0)

    opt.clear_grad()
    with gm:
        with gm2:
            loss = net()
            gm.backward(loss.sum())
    opt.step()

    np.testing.assert_almost_equal(loss.numpy(), (av * 10))
    np.testing.assert_almost_equal(net.a.numpy(), (av - 20))
@@ -109,3 +109,46 @@ def test_subgraph(device, batch_size, channels, use_trace, symbolic, gopt_level,
        _assert_allclose(out1.numpy(), out2.numpy())
        _assert_allclose(grad1.numpy(), grad2.numpy())


@functools.lru_cache(maxsize=None)
def _get_mul_fn(dtype, device):
    @subgraph_fn(
        "Mul",
        dtype=dtype,
        device=device,
        nr_inputs=2,
        gopt_level=None,
        jit_fusion=False,
        custom_grad=True,
    )
    def mul(inputs, f, c):
        x, y = inputs[0:2]
        z = f("*", x, y)
        (dz,) = yield (z,)
        dx = f("*", dz, y)
        dy = f("*", dz, x)
        yield (dx, dy)

    return mul


def test_subgraph_jit_backward():
    x_np = np.random.rand(3, 4, 5).astype("float32")
    x1 = megengine.Tensor(x_np)
    x2 = megengine.Tensor(x_np)
    mul = _get_mul_fn(x1.dtype, x1.device)
    gm = GradManager()
    gm.attach([x1, x2])
    with gm:
        y1 = x1 * x1
        y2 = mul(x2, x2)
        gm.backward(y1)
    with gm:
        y1 = x1 * x1
        y2 = mul(x2, x2)
        gm.backward(y1 + y2)
    with gm:
        y1 = x1 * x1
        y2 = mul(x2, x2)
        gm.backward(y2)
@@ -18,18 +18,44 @@

namespace mgb {
namespace imperative {
namespace {

ValueRefList apply(const Operator& op, Span<ValueRef> inputs) {
ValueRefList apply_release(const Operator& op, Span<ValueRef> inputs) {
    auto& context = Transformation::get_context();
    size_t& depth = context.next_transformation;
    mgb_assert(depth < context.transformations.size());
    auto& transformation = *context.transformations[depth++];
    CleanupGuard _{[&] { --depth; }};
    return transformation.apply_transformation(op, inputs);
}

MGB_NOINLINE ValueRefList apply_debug(const Operator& op, Span<ValueRef> inputs) {
    auto& context = Transformation::get_context();
    size_t& depth = context.next_transformation;
    // TODO: add fallback transformation
    bool fallback = depth >= context.transformations.size();
    if (mgb_unlikely(fallback)) {
        return op.fallback(inputs);
    }
    mgb_assert(depth < context.transformations.size());
    static const char tabs[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t";
    const char* prefix = tabs + (sizeof(tabs) / sizeof(char)) - depth - 1;
    mgb_log_debug(
            "%s apply %s to %s", prefix, op.to_string().c_str(),
            imperative::to_string(inputs).c_str());
    ValueRefList result;
    auto& transformation = *context.transformations[depth++];
    CleanupGuard _{[&] { --depth; }};
    result = transformation.apply_transformation(op, inputs);
    mgb_log_debug(
            "%s returns %s", prefix,
            imperative::to_string(Span<ValueRef>(result)).c_str());
    return result;
}

} // namespace

ValueRefList apply(const Operator& op, Span<ValueRef> inputs) {
    static bool debug = MGB_GETENV("MGE_LOG_OP_DISPATCH");
    if (mgb_unlikely(debug)) {
        return apply_debug(op, inputs);
    } else {
        auto& transformation = *context.transformations[depth++];
        CleanupGuard _{[&] { --depth; }};
        return transformation.apply_transformation(op, inputs);
        return apply_release(op, inputs);
    }
}
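
Item 6 in use: apply() reads MGE_LOG_OP_DISPATCH once, via a function-local static, on the
first dispatch, so the variable must be in the environment before the first op runs. A
hypothetical way to turn it on from Python (it also assumes the log level is low enough
for mgb_log_debug output to appear):

    import os
    os.environ["MGE_LOG_OP_DISPATCH"] = "1"  # must precede the first dispatch

    import megengine
    # Every dispatched op now logs an indented "apply ..." / "returns ..."
    # pair, one tab per transformation level in the dispatch stack.
    print(megengine.Tensor([1.0, 2.0]) + 1)
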
@@ -106,7 +106,8 @@ EncodedSubgraph OpDef::make_forward_graph(
}

std::string OpDef::to_string() const {
    std::string builder = trait()->make_name(*this) + "{";
    std::string builder = trait()->name;
    builder += "{";
    for (auto&& [name, value] : props(*this)) {
        builder += name;
        builder += ": ";

@@ -196,7 +197,7 @@ std::string Subgraph::repr() const {
        if (auto* p = op->try_cast_final<OprAttr>()) {
            buf << p->type;
        } else {
            buf << op->make_name();
            buf << op->to_string();
        }
        for (size_t i : ins) {
            buf << " ";

@@ -11,13 +11,94 @@

#include "megbrain/opr/dnn/batch_norm.h"

#include "../op_trait.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/tensor.h"

namespace mgb {
namespace imperative {
namespace {

EncodedSubgraph generate_batchnorm_backward_graph(DType dtype, CompNode device) {
    Subgraph::Builder<LogicalTensorDesc> builder{
            [](std::shared_ptr<OpDef> op, SmallVector<LogicalTensorDesc> inputs,
               size_t nr_outputs) {
                auto [outputs, validated] =
                        OpDef::infer_output_attrs_fallible(*op, inputs);
                mgb_assert(outputs.size() == nr_outputs, "nr_outputs mismatch");
                return outputs;
            }};
    auto f = [&](auto&& op, auto... args) {
        return builder.write_expr(
                op, Subgraph::vars_t({(Subgraph::var_t)args...}), 1)[0];
    };
    auto prod = Reduce::make(megdnn::param::Reduce(Reduce::Mode::PRODUCT, 0));
    auto sum = Reduce::make(megdnn::param::Reduce(Reduce::Mode::SUM));
    auto sub = Elemwise::make(Elemwise::Mode::SUB);
    auto mul = Elemwise::make(Elemwise::Mode::MUL);
    auto div = Elemwise::make(Elemwise::Mode::TRUE_DIV);
    auto floor_div = Elemwise::make(Elemwise::Mode::FLOOR_DIV);
    auto broadcast = Broadcast::make();
    auto c = [&](TensorPtr tensor, DType dtype) {
        auto result = builder.write_constant(
                tensor, {TensorLayout{tensor->dtype()}, tensor->comp_node()});
        if (tensor->dtype() != dtype) {
            result = f(TypeCvt::make(dtype), result);
        }
        return result;
    };
    auto ci = [&](megdnn::dt_int32 value) {
        return c(Tensor::make_scalar(DTypeScalar(value), device), dtype::Int32());
    };
    auto cf = [&](megdnn::dt_float32 value) {
        return c(Tensor::make_scalar(DTypeScalar(value), device), dtype);
    };
    auto desc = LogicalTensorDesc{TensorLayout{dtype}, device};
    auto x = builder.write_input(desc);
    auto y_grad = builder.write_input(desc);
    auto save_mean = builder.write_input(desc);
    auto save_invstd = builder.write_input(desc);
    auto weight = builder.write_input(desc);
    auto reserved = builder.write_input(desc);
    MGB_MARK_USED_VAR(reserved);
    // assert x.ndim == 4
    auto input_shape = f(GetVarShape::make(), x);
    auto channels = f(GetVarShape::make(1), x);
    auto reduce_shape = f(Concat::make(0, device), ci(1), channels, ci(1), ci(1));
    auto input_elems = f(prod, input_shape);
    auto reduce_size = f(floor_div, input_elems, channels);
    auto reduce_size_f = f(TypeCvt::make(dtype), reduce_size);
    auto mean = f(broadcast, save_mean, input_shape);
    auto invstd = save_invstd;
    auto norm = f(div, cf(1), reduce_size_f);
    auto output_grad_sum = f(sum, y_grad, reduce_shape);
    auto dot_p = f(sum, f(mul, y_grad, f(sub, x, mean)), reduce_shape);
    auto mean_grad = f(broadcast, f(mul, output_grad_sum, norm), input_shape);
    auto proj_scale =
            f(broadcast, f(mul, f(mul, dot_p, norm), f(mul, invstd, invstd)),
              input_shape);
    auto grad_scale = f(
            mul, f(broadcast, invstd, input_shape), f(broadcast, weight, input_shape));
    auto proj = f(mul, f(sub, x, mean), proj_scale);
    auto x_grad = f(mul, f(sub, f(sub, y_grad, proj), mean_grad), grad_scale);
    auto weight_grad = f(mul, dot_p, invstd);
    auto bias_grad = output_grad_sum;
    builder.add_outputs({weight_grad, bias_grad, x_grad});
    auto bn_backward = builder.encode();
    return bn_backward;
}

namespace bn {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::BatchNorm>();
    return BatchNorm::make(node->param());
@@ -72,8 +153,60 @@ OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();
} // anonymous namespace
} // namespace bn

namespace bn_backward {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::BatchNormBackward>();
    return BatchNormBackward::make(node->param());
}

VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto& op = def.cast_final_safe<BatchNormBackward>();
    cg::SymbolVar x, y_grad, save_mean, save_variance, weight, reserve;
    x = inputs[0];
    y_grad = inputs[1];
    save_mean = inputs[2];
    save_variance = inputs[3];
    weight = inputs[4];
    if (inputs.size() == 6) {
        reserve = inputs[5];
    }
    return opr::BatchNormBackward::make(
                   x, y_grad, save_mean, save_variance, weight, reserve, op.param())[0]
            .node()
            ->owner_opr()
            ->usable_output();
}

EncodedSubgraph make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    def.cast_final_safe<BatchNormBackward>();
    size_t nr_inputs = 6;
    size_t nr_outputs = 3;
    mgb_assert(inputs.size() == nr_inputs);
    mgb_assert(input_requires_grad.size() == nr_inputs);
    mgb_assert(output_has_grad.size() == nr_outputs);
    auto dtype = inputs[0].layout.dtype;
    auto device = inputs[0].comp_node;
    auto bn_backward = generate_batchnorm_backward_graph(dtype, device);
    auto bn_double_backward = subgraph_detail::make_backward_graph_from_forward(
            bn_backward, inputs, input_requires_grad, output_has_grad);
    return bn_double_backward;
}

OP_TRAIT_REG(BatchNormBackward, BatchNormBackward, opr::BatchNormBackward)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .make_backward_graph(make_backward_graph)
        .fallback();

} // namespace bn_backward

} // anonymous namespace
} // namespace imperative
} // namespace mgb
@@ -762,7 +762,9 @@ EncodedSubgraph make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    return {};
    return OpDef::make_backward_graph(
            *def.cast_final_safe<JITFusionOp>().op, inputs, input_requires_grad,
            output_has_grad);
}

OP_TRAIT_REG(JITFusionOp, JITFusionOp)

@@ -96,10 +96,11 @@ SmallVector<LayoutConstraintCallback> get_input_layout_constraint(
    return res;
}

static EncodedSubgraph make_backward_graph_from_forward(
EncodedSubgraph make_backward_graph_from_forward(
        const EncodedSubgraph& forward_graph,
        const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad, EncodedSubgraph forward_graph) {
        const SmallVector<bool>& output_has_grad) {
    using namespace std::placeholders;
    using var_t = Subgraph::var_t;
    using vars_t = Subgraph::vars_t;

@@ -179,7 +180,7 @@ EncodedSubgraph make_backward_graph(
        const SmallVector<bool>& output_has_grad) {
    auto forward_graph = OpDef::make_forward_graph(def, inputs);
    return make_backward_graph_from_forward(
            inputs, input_requires_grad, output_has_grad, forward_graph);
            forward_graph, inputs, input_requires_grad, output_has_grad);
}

} // namespace subgraph_detail

@@ -139,7 +139,7 @@ ValueRefList InterpreterTransformation::apply_transformation(
            return {ValueRef()};
        }
    } else {
        return imperative::apply(op, inputs);
        return op.fallback(inputs);
    }
}
@@ -62,7 +62,8 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
        std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs)
        : backward_graph(backward_graph),
          output_mask_offset(inputs.size()),
          grad_mask_offset(inputs.size() + outputs.size()) {
          grad_mask_offset(inputs.size() + outputs.size()),
          op(op) {
    auto& save_for_backward = backward_graph->save_for_backward;
    mgb_assert(save_for_backward.size() == inputs.size() + 2 * outputs.size());
    size_t count = std::count_if(

@@ -92,6 +93,13 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
            closure.push_back(outputs[i]);
        }
    }
    if (outputs.size() > 1) {
        output_descs.reserve(outputs.size());
        for (auto&& output : outputs) {
            auto symbolic_shape = imperative::apply(*GetVarShape::make(), output)[0];
            output_descs.push_back({symbolic_shape, output.dtype(), output.device()});
        }
    }
}

void BackwardGraphWithClosure::operator()(
        Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {

@@ -100,23 +108,46 @@ void BackwardGraphWithClosure::operator()(
    for (auto&& value : closure) {
        args[nargs++] = value;
    }
    bool null_grad = false;
    size_t null_grad = 0;
    size_t valid_grad = 0;
    for (size_t i = 0; i < grads.size(); ++i) {
        if (backward_graph->save_for_backward[grad_mask_offset + i]) {
            if (grads[i]) {
                mgb_assert(!null_grad, "null_grad");
                valid_grad++;
                args[nargs++] = grads[i];
            } else {
                null_grad = true;
                null_grad++;
                nargs++;
            }
        }
    }
    if (null_grad) {
    if (valid_grad == 0) {
        return;
    }
    if (null_grad > 0) {
        auto zeros_like = [](const OutputDesc& desc) {
            HostTensorStorage storage(*desc.device);
            storage.ensure_size(desc.dtype->size());
            std::memset(storage.ptr(), 0, desc.dtype->size());
            auto t = imperative::apply(
                    CreateTensor(
                            CreateTensor::Unique, *desc.device, *desc.dtype,
                            ValueShape()),
                    HostStorage::make(storage))[0];
            auto res = imperative::apply(*Broadcast::make(), t, desc.shape)[0];
            return res;
        };
        nargs = closure.size();
        for (size_t i = 0; i < grads.size(); ++i) {
            if (backward_graph->save_for_backward[grad_mask_offset + i]) {
                if (!grads[i]) {
                    args[nargs] = zeros_like(output_descs[i]);
                }
                nargs++;
            }
        }
    }
    auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs));
    auto igrads_ = imperative::apply(backward_graph->backward, Span(args, nargs));
    SmallVector<ValueRef> igrads = {igrads_.begin(), igrads_.end()};
    igrads_.clear();
    auto&& iter = igrads.begin();
    for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) {
        if (p) {

@@ -221,9 +252,11 @@ void GradKey::backward() {
            if (!dest) {
                continue;
            }
            if (!dest.m_producer_record.next && dest->callback && dest->m_grad) {
            if (!dest.m_producer_record.next && dest->callback) {
                // I'm the last grad producer, invoke callback
                dest->callback(dest->m_grad);
                if (dest->m_grad) {
                    dest->callback(dest->m_grad);
                }
            }
        }
        grad_fn->clear();
@@ -394,16 +427,22 @@ ValueRefList GradTransformation::apply_transformation(
        return imperative::apply(op, inputs);
    }
    if (auto* attach_grad = op.as<AttachGrad>()) {
        if (!has_key(attach_grad->key())) {
        auto& tensor = inputs[0];
        if (auto&& grad_value = tensor.as_ref(m_value_type)) {
            mgb_assert(!has_key(attach_grad->key()));
            auto output = fallback()[0];
            return record_grad(m_value_type.make(output, m_key, grad_value->slot()));
        } else if (!has_key(attach_grad->key())) {
            return fallback();
        } else {
            GenericFunction callback =
                    (GenericFunction&)inputs[1].cast<FunctionValue>();
            auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) {
                auto ret = callback({&grad, 1});
                assert(ret.empty());
            });
            return {record_grad(output)};
        }
        auto tensor = inputs[0];
        GenericFunction callback = (GenericFunction&)inputs[1].cast<FunctionValue>();
        auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) {
            auto ret = callback({&grad, 1});
            assert(ret.empty());
        });
        return {record_grad(output)};
    } else if (auto* grad_backward = op.as<GradBackward>()) {
        if (!has_key(grad_backward->key())) {
            return fallback();

@@ -431,10 +470,10 @@ ValueRefList GradTransformation::apply_transformation(
        mgb_assert(inputs.size() > nr_inputs);
        size_t nr_outputs = inputs.size() - nr_inputs;
        Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
        Span<ValueRef> outputs_ = {inputs.data() + nr_inputs, nr_outputs};
        backward.m_input_has_grad = SmallVector(nr_inputs, true);
        backward.m_output_attrs =
                SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true});
        auto outputs_ = fallback();
        backward.m_input_has_grad.resize(nr_inputs, true);
        backward.m_output_attrs.resize(
                nr_outputs, CustomBackward::OutputAttr{true, true});
        backward.m_backward = [fn = set_grad->grad_fn()](Span<ValueRef> inputs) {
            auto result = fn(inputs);
            return SmallVector<ValueRef>(result.begin(), result.end());
@@ -31,6 +31,7 @@ class Subgraph::Builder {
    using infer_fn_t = std::function<descs_t(op_t, descs_t, size_t)>;
    using encoded_graph_t = EncodedSubgraph;
    using var_map_t = std::unordered_map<var_t, var_t>;
    using mask_t = SmallVector<bool>;
    vars_t m_inputs;
    SmallVector<std::pair<var_t, TensorPtr>> m_constants;
    vars_t m_outputs;

@@ -94,6 +95,7 @@ public:
    descs_t get_descs(vars_t vars) {
        descs_t descs;
        for (auto&& var : vars) {
            mgb_assert(var, "invalid var");
            descs.push_back(get_desc(var));
        }
        return descs;

@@ -128,4 +130,4 @@ public:
    expr_iter_t end() { return m_exprs.end(); }
};

} // namespace imperative
} // namespace mgb
} // namespace mgb

@@ -38,7 +38,6 @@ struct ShapeInfer final : OpDefImplBase<ShapeInfer> {
    std::shared_ptr<OpDef> op;
    SmallVector<CompNode> devices;
    SmallVector<DType> dtypes;
    EncodedSubgraph graph;
    ShapeInfer() = default;
    ShapeInfer(
            std::shared_ptr<OpDef> op, SmallVector<CompNode> devices,

@@ -39,6 +39,11 @@ EncodedSubgraph make_backward_graph(
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs);

EncodedSubgraph make_backward_graph_from_forward(
        const EncodedSubgraph& forward, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad);

} // namespace subgraph_detail
} // namespace imperative
} // namespace mgb
} // namespace mgb
@@ -29,6 +29,15 @@ struct BackwardGraphWithClosure {
    SmallVector<ValueRef> closure;
    size_t output_mask_offset;
    size_t grad_mask_offset;
    std::shared_ptr<OpDef> op;

    struct OutputDesc {
        ValueRef shape;
        DTypeValue::ref_t dtype;
        CompNodeValue::ref_t device;
    };
    SmallVector<OutputDesc> output_descs;

    BackwardGraphWithClosure(
            std::shared_ptr<OptimizedBackwardGraphResult> backward_graph,

@@ -356,20 +365,22 @@ public:
class SetGrad : public OperatorImpl<SetGrad> {
private:
    std::shared_ptr<GradKey> m_key;
    GenericFunction m_grad_fn;
    size_t m_nr_inputs;

public:
    SetGrad(std::shared_ptr<GradKey> key, GenericFunction grad_fn, size_t nr_inputs)
            : m_key(key), m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}
    SetGrad(GenericFunction grad_fn, size_t nr_inputs)
            : m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}

    GenericFunction grad_fn() const { return m_grad_fn; }
    size_t nr_inputs() const { return m_nr_inputs; }

    std::string to_string() const override {
        return ssprintf("SetGradValue{key=%s}", m_key->name().c_str());
    }
    std::string to_string() const override { return ssprintf("SetGradValue{}"); }

    ValueRefList fallback(Span<ValueRef> inputs) const override {
        auto outputs = inputs.sub(m_nr_inputs, inputs.size() - m_nr_inputs);
        return {outputs.begin(), outputs.end()};
    }
};

@@ -15,12 +15,14 @@

#include <memory>
#include <sstream>

#include "megbrain/utils/metahelper.h"

namespace mgb {
namespace imperative {

template <typename T>
class CleanupGuard {
template <typename T = std::function<void()>>
class CleanupGuard : public NonCopyableObj {
private:
    T m_callback;

@@ -37,4 +39,4 @@ inline std::string quoted(std::string str) {

} // namespace imperative
} // namespace mgb
} // namespace mgb
@@ -89,6 +89,8 @@ def SlidingWindowTranspose : MgbHashableOp<"SlidingWindowTranspose", [SlidingWin

def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>;

def BatchNormBackward : MgbHashableOp<"BatchNormBackward", [BNParam]>;

def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>;
def Correlation: MgbHashableOp<"Correlation", [CorrelationParam]>; | |||