1. implement double backward for batchnorm
2. fix grad attach in nested grad manager
3. pad empty tensor for unsatisfied output_has_grad
4. support double backward for jit subgraph
5. support double backward for autodiff.Function
6. re-add debug flag MGE_LOG_OP_DISPATCH
GitOrigin-RevId: cd31ddc620
tags/v1.9.0
@@ -212,10 +212,7 @@ class Function: | |||||
if self.__single_output: | if self.__single_output: | ||||
outputs = (outputs,) | outputs = (outputs,) | ||||
for grad in reversed(group): | |||||
if grad._impl is None: | |||||
continue | |||||
outputs = core2.set_grad(grad._impl, normalized_backward, args, outputs) | |||||
outputs = core2.set_grad(normalized_backward, args, outputs) | |||||
if self.__single_output: | if self.__single_output: | ||||
(outputs,) = outputs | (outputs,) = outputs | ||||
return outputs | return outputs | ||||
@@ -209,7 +209,6 @@ def subgraph( | |||||
outputs = gen.send(None) | outputs = gen.send(None) | ||||
nr_outputs = len(outputs) | nr_outputs = len(outputs) | ||||
forward_fn = build(builder, outputs, [False] * nr_outputs) | forward_fn = build(builder, outputs, [False] * nr_outputs) | ||||
output_grads = [builder.input() for _ in range(nr_outputs)] | output_grads = [builder.input() for _ in range(nr_outputs)] | ||||
input_grads = gen.send(output_grads) | input_grads = gen.send(output_grads) | ||||
assert len(input_grads) == nr_inputs | assert len(input_grads) == nr_inputs | ||||
@@ -222,25 +221,49 @@ def subgraph( | |||||
] | ] | ||||
encoded_input_grads = [grad for grad in input_grads if grad is not None] | encoded_input_grads = [grad for grad in input_grads if grad is not None] | ||||
backward_fn = build( | backward_fn = build( | ||||
builder, encoded_input_grads, [False] * len(encoded_input_grads) | |||||
builder, encoded_input_grads, [True] * len(encoded_input_grads) | |||||
) | ) | ||||
class SubgraphOp(Function): | class SubgraphOp(Function): | ||||
def __init__(self): | def __init__(self): | ||||
self.inputs = None | self.inputs = None | ||||
self.output_shapes = None | |||||
def forward(self, *inputs): | def forward(self, *inputs): | ||||
self.inputs = inputs | self.inputs = inputs | ||||
return apply(forward_fn(), *inputs) | |||||
outputs = apply(forward_fn(), *inputs) | |||||
if len(outputs) > 1: | |||||
self.output_shapes = [output.shape for output in outputs] | |||||
return outputs | |||||
def backward(self, *output_grads): | def backward(self, *output_grads): | ||||
inputs = self.inputs | inputs = self.inputs | ||||
self.inputs = None | |||||
encoded_input_grads = apply(backward_fn(), *inputs, *output_grads) | |||||
input_grads = [ | |||||
encoded_input_grads[i] if i is not None else None | |||||
for i in indices | |||||
] | |||||
any_valid = False | |||||
all_valid = True | |||||
for output_grad in output_grads: | |||||
if output_grad is None: | |||||
all_valid = False | |||||
else: | |||||
any_valid = True | |||||
if not any_valid: | |||||
input_grads = [None] * len(indices) | |||||
else: | |||||
if not all_valid: | |||||
assert self.output_shapes is not None | |||||
from ...functional import zeros | |||||
output_grads = [ | |||||
zeros(self.output_shapes[i]) if grad is None else grad | |||||
for i, grad in enumerate(output_grads) | |||||
] | |||||
self = None | |||||
encoded_input_grads = apply( | |||||
backward_fn(), *inputs, *output_grads | |||||
) | |||||
input_grads = [ | |||||
encoded_input_grads[i] if i is not None else None | |||||
for i in indices | |||||
] | |||||
return input_grads | return input_grads | ||||
gen.close() | gen.close() | ||||
@@ -896,7 +896,7 @@ def prelu(inp: Tensor, weight: Tensor) -> Tensor: | |||||
@lru_cache(maxsize=None) | @lru_cache(maxsize=None) | ||||
def _get_leagk_relu_op(negative_slope, *, dtype=None, device=None): | |||||
def _get_leaky_relu_op(negative_slope, *, dtype=None, device=None): | |||||
@subgraph_fn( | @subgraph_fn( | ||||
"LeakyReLU", | "LeakyReLU", | ||||
dtype=dtype, | dtype=dtype, | ||||
@@ -925,7 +925,7 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor: | |||||
Refer to :class:`~.LeakyReLU` for more information. | Refer to :class:`~.LeakyReLU` for more information. | ||||
""" | """ | ||||
leakyReLU = _get_leagk_relu_op(negative_slope, dtype=inp.dtype, device=inp.device) | |||||
leakyReLU = _get_leaky_relu_op(negative_slope, dtype=inp.dtype, device=inp.device) | |||||
(oup,) = leakyReLU(inp) | (oup,) = leakyReLU(inp) | ||||
return oup | return oup | ||||
@@ -1399,7 +1399,7 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels): | |||||
f("fma3", input, inv_var_wt, | f("fma3", input, inv_var_wt, | ||||
f("+", f("*", neg_channel_mean, inv_var_wt), | f("+", f("*", neg_channel_mean, inv_var_wt), | ||||
bias)) | bias)) | ||||
return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False) | |||||
return (outvar, channel_mean, channel_var), (True, True, True) | |||||
@subgraph("SyncBnStage1Inference", dtype, device, 6) | @subgraph("SyncBnStage1Inference", dtype, device, 6) | ||||
def syncbn_stage1_inference(inputs, f, c): | def syncbn_stage1_inference(inputs, f, c): | ||||
@@ -1509,7 +1509,7 @@ def sync_batch_norm( | |||||
""" | """ | ||||
_eps_mode = eps_mode.lower() | _eps_mode = eps_mode.lower() | ||||
assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode) | assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode) | ||||
if _eps_mode == "additive" and not (is_distributed() and training): | |||||
if _eps_mode == "additive" and not (is_distributed() or training): | |||||
return batch_norm( | return batch_norm( | ||||
inp, | inp, | ||||
running_mean, | running_mean, | ||||
@@ -121,13 +121,13 @@ void GradKeyWrapper::enter() { | |||||
m_key = m_transformation->key(); | m_key = m_transformation->key(); | ||||
m_key->name(m_name); | m_key->name(m_name); | ||||
grad_key_map[m_key] = this; | grad_key_map[m_key] = this; | ||||
TransformationManager::get_instance().register_at<TransformationManager::Grad>( | |||||
m_transformation); | |||||
m_transformation_guard = | |||||
TransformationManager::get_instance() | |||||
.register_at<TransformationManager::Grad>(m_transformation); | |||||
} | } | ||||
void GradKeyWrapper::exit() { | void GradKeyWrapper::exit() { | ||||
TransformationManager::get_instance().unregister<TransformationManager::Grad>( | |||||
m_transformation); | |||||
m_transformation_guard.reset(); | |||||
grad_key_map.erase(m_key); | grad_key_map.erase(m_key); | ||||
m_key = {}; | m_key = {}; | ||||
m_transformation.reset(); | m_transformation.reset(); | ||||
@@ -29,6 +29,7 @@ struct GradKeyWrapper : NonCopyableObj { | |||||
std::string m_name; | std::string m_name; | ||||
std::shared_ptr<GradKey> m_key; | std::shared_ptr<GradKey> m_key; | ||||
std::shared_ptr<GradTransformation> m_transformation; | std::shared_ptr<GradTransformation> m_transformation; | ||||
std::unique_ptr<CleanupGuard<>> m_transformation_guard; | |||||
GradKeyWrapper(); | GradKeyWrapper(); | ||||
@@ -449,15 +449,24 @@ void init_tensor(py::module m) { | |||||
interpreter::Interpreter::inst().create_channel()) | interpreter::Interpreter::inst().create_channel()) | ||||
->get(); | ->get(); | ||||
interpreter_for_py = channel; | interpreter_for_py = channel; | ||||
transformations.register_at<Segment::Eval>( | |||||
std::make_shared<InterpreterTransformation>( | |||||
std::shared_ptr<Channel>(channel, [](Channel*) {}))); | |||||
transformations.register_at<Segment::Scalar>( | |||||
std::make_shared<ScalarTransformation>()); | |||||
transformations.register_at<Segment::DTypePromote>( | |||||
std::make_shared<DTypePromoteTransformation>()); | |||||
transformations.register_at<Segment::DimExpansion>( | |||||
std::make_shared<DimExpansionTransformation>()); | |||||
MGB_MARK_USED_VAR( | |||||
transformations | |||||
.register_at<Segment::Eval>( | |||||
std::make_shared<InterpreterTransformation>( | |||||
std::shared_ptr<Channel>(channel, [](Channel*) {}))) | |||||
.release()); | |||||
MGB_MARK_USED_VAR(transformations | |||||
.register_at<Segment::Scalar>( | |||||
std::make_shared<ScalarTransformation>()) | |||||
.release()); | |||||
MGB_MARK_USED_VAR(transformations | |||||
.register_at<Segment::DTypePromote>( | |||||
std::make_shared<DTypePromoteTransformation>()) | |||||
.release()); | |||||
MGB_MARK_USED_VAR(transformations | |||||
.register_at<Segment::DimExpansion>( | |||||
std::make_shared<DimExpansionTransformation>()) | |||||
.release()); | |||||
static py::exception<interpreter::AsyncError> py_async_error( | static py::exception<interpreter::AsyncError> py_async_error( | ||||
m, "AsyncError", PyExc_RuntimeError); | m, "AsyncError", PyExc_RuntimeError); | ||||
@@ -681,6 +690,9 @@ void init_tensor(py::module m) { | |||||
std::pair<size_t, std::shared_ptr<GraphProfiler>> profiler; | std::pair<size_t, std::shared_ptr<GraphProfiler>> profiler; | ||||
std::optional<TraceResult> trace_result; | std::optional<TraceResult> trace_result; | ||||
std::function<bool(py::object, py::object)> array_comparator; | std::function<bool(py::object, py::object)> array_comparator; | ||||
std::unique_ptr<CleanupGuard<>> tracing_guard; | |||||
std::unique_ptr<CleanupGuard<>> compiled_guard; | |||||
std::unique_ptr<CleanupGuard<>> lazy_eval_guard; | |||||
bool compare_value(ValueRef lhs, ValueRef rhs) { | bool compare_value(ValueRef lhs, ValueRef rhs) { | ||||
auto lvalue = lhs.cast_ref<HostValue>(); | auto lvalue = lhs.cast_ref<HostValue>(); | ||||
@@ -730,13 +742,16 @@ void init_tensor(py::module m) { | |||||
std::make_shared<GraphProfiler>(&current_graph)); | std::make_shared<GraphProfiler>(&current_graph)); | ||||
} | } | ||||
} | } | ||||
transformations.register_at<Segment::Trace>(self.compiled); | |||||
compiled_guard = | |||||
transformations.register_at<Segment::Trace>(self.compiled); | |||||
// start execute because InputCallback depends | // start execute because InputCallback depends | ||||
self.compiled->execute(); | self.compiled->execute(); | ||||
} else if (self.tracing) { | } else if (self.tracing) { | ||||
transformations.register_at<Segment::Trace>(self.tracing); | |||||
tracing_guard = | |||||
transformations.register_at<Segment::Trace>(self.tracing); | |||||
if (self.lazy_eval) { | if (self.lazy_eval) { | ||||
transformations.register_at<Segment::Eval>(self.lazy_eval); | |||||
lazy_eval_guard = | |||||
transformations.register_at<Segment::Eval>(self.lazy_eval); | |||||
} | } | ||||
} else { | } else { | ||||
mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled"); | mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled"); | ||||
@@ -746,16 +761,16 @@ void init_tensor(py::module m) { | |||||
void exit() { | void exit() { | ||||
auto& self = *this; | auto& self = *this; | ||||
if (self.tracing) { | if (self.tracing) { | ||||
transformations.unregister<Segment::Trace>(self.tracing); | |||||
tracing_guard.reset(); | |||||
self.trace_result = self.tracing->get_result(); | self.trace_result = self.tracing->get_result(); | ||||
self.tracing.reset(); | self.tracing.reset(); | ||||
if (self.lazy_eval) { | if (self.lazy_eval) { | ||||
auto lazy_eval = std::move(self.lazy_eval); | auto lazy_eval = std::move(self.lazy_eval); | ||||
transformations.unregister<Segment::Eval>(lazy_eval); | |||||
lazy_eval_guard.reset(); | |||||
lazy_eval->check_exception(); | lazy_eval->check_exception(); | ||||
} | } | ||||
} else if (self.compiled) { | } else if (self.compiled) { | ||||
transformations.unregister<Segment::Trace>(self.compiled); | |||||
compiled_guard.reset(); | |||||
self.compiled->wait(); | self.compiled->wait(); | ||||
} else { | } else { | ||||
mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled"); | mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled"); | ||||
@@ -829,17 +844,19 @@ void init_tensor(py::module m) { | |||||
[](Trace& self) { | [](Trace& self) { | ||||
mgb_assert(bool(self.tracing) ^ bool(self.compiled)); | mgb_assert(bool(self.tracing) ^ bool(self.compiled)); | ||||
if (self.tracing) { | if (self.tracing) { | ||||
transformations.unregister<Segment::Trace>(self.tracing); | |||||
self.tracing_guard.reset(); | |||||
} else if (self.compiled) { | } else if (self.compiled) { | ||||
transformations.unregister<Segment::Trace>(self.compiled); | |||||
self.compiled_guard.reset(); | |||||
} | } | ||||
}) | }) | ||||
.def("end_excluded_region", [](Trace& self) { | .def("end_excluded_region", [](Trace& self) { | ||||
mgb_assert(bool(self.tracing) ^ bool(self.compiled)); | mgb_assert(bool(self.tracing) ^ bool(self.compiled)); | ||||
if (self.tracing) { | if (self.tracing) { | ||||
transformations.register_at<Segment::Trace>(self.tracing); | |||||
self.tracing_guard = | |||||
transformations.register_at<Segment::Trace>(self.tracing); | |||||
} else if (self.compiled) { | } else if (self.compiled) { | ||||
transformations.register_at<Segment::Trace>(self.compiled); | |||||
self.compiled_guard = | |||||
transformations.register_at<Segment::Trace>(self.compiled); | |||||
} | } | ||||
}); | }); | ||||
@@ -900,11 +917,8 @@ void init_tensor(py::module m) { | |||||
GradKeyWrapper::get(output.cast<GradKeyValue>()))); | GradKeyWrapper::get(output.cast<GradKeyValue>()))); | ||||
}); | }); | ||||
m.def("set_grad", [](py::object py_key, py::function backward_fn, | |||||
std::vector<py::object> inputs, | |||||
m.def("set_grad", [](py::function backward_fn, std::vector<py::object> inputs, | |||||
std::vector<py::object> outputs) { | std::vector<py::object> outputs) { | ||||
mgb_assert(GradKeyWrapper::wrap_t::type().isinstance(py_key.ptr())); | |||||
auto* key = reinterpret_cast<GradKeyWrapper::wrap_t*>(py_key.ptr())->inst(); | |||||
GenericFunction generic_backward_fn = | GenericFunction generic_backward_fn = | ||||
[backward_fn](Span<ValueRef> output_grads) -> ValueRefList { | [backward_fn](Span<ValueRef> output_grads) -> ValueRefList { | ||||
py::list output_grad_tws; | py::list output_grad_tws; | ||||
@@ -937,8 +951,8 @@ void init_tensor(py::module m) { | |||||
values[i + inputs.size()] = | values[i + inputs.size()] = | ||||
outputs[i].cast<TensorWrapper>().m_tensor->data(); | outputs[i].cast<TensorWrapper>().m_tensor->data(); | ||||
} | } | ||||
auto wrapped_output_values = imperative::apply( | |||||
SetGrad(key->m_key, generic_backward_fn, inputs.size()), values); | |||||
auto wrapped_output_values = | |||||
imperative::apply(SetGrad(generic_backward_fn, inputs.size()), values); | |||||
std::vector<py::object> wrapped_outputs; | std::vector<py::object> wrapped_outputs; | ||||
mgb_assert(wrapped_output_values.size() == outputs.size()); | mgb_assert(wrapped_output_values.size() == outputs.size()); | ||||
for (auto&& output_value : wrapped_output_values) { | for (auto&& output_value : wrapped_output_values) { | ||||
@@ -956,8 +970,10 @@ void init_tensor(py::module m) { | |||||
mgb_assert(module_trace_hook); | mgb_assert(module_trace_hook); | ||||
module_trace_transformation = | module_trace_transformation = | ||||
std::make_shared<ModuleTraceTransformation>(module_trace_hook); | std::make_shared<ModuleTraceTransformation>(module_trace_hook); | ||||
transformations.register_at<Segment::ModuleTrace>( | |||||
module_trace_transformation); | |||||
MGB_MARK_USED_VAR(transformations | |||||
.register_at<Segment::ModuleTrace>( | |||||
module_trace_transformation) | |||||
.release()); | |||||
} | } | ||||
return module_trace_transformation; | return module_trace_transformation; | ||||
}; | }; | ||||
@@ -18,11 +18,13 @@ | |||||
#include "megbrain/imperative/dispatch.h" | #include "megbrain/imperative/dispatch.h" | ||||
#include "megbrain/imperative/transformation.h" | #include "megbrain/imperative/transformation.h" | ||||
#include "megbrain/imperative/utils/helper.h" | |||||
#include "megbrain/imperative/value.h" | #include "megbrain/imperative/value.h" | ||||
#include "megbrain/utils/small_vector.h" | #include "megbrain/utils/small_vector.h" | ||||
namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
struct TransformationManager { | struct TransformationManager { | ||||
public: | |||||
enum Segment { | enum Segment { | ||||
ModuleTrace, | ModuleTrace, | ||||
DTypePromote, | DTypePromote, | ||||
@@ -35,8 +37,21 @@ struct TransformationManager { | |||||
std::array<std::vector<std::shared_ptr<Transformation>>, 7> segments; | std::array<std::vector<std::shared_ptr<Transformation>>, 7> segments; | ||||
private: | |||||
template <Segment segment> | |||||
void unregister(std::shared_ptr<Transformation> transformation) noexcept { | |||||
mgb_assert(segment < segments.size()); | |||||
auto iter = std::find( | |||||
segments[segment].begin(), segments[segment].end(), transformation); | |||||
mgb_assert(iter != segments[segment].end()); | |||||
transformation->unregister(); | |||||
segments[segment].erase(iter); | |||||
} | |||||
public: | |||||
template <Segment segment> | template <Segment segment> | ||||
void register_at(std::shared_ptr<Transformation> transformation) { | |||||
[[nodiscard]] std::unique_ptr<CleanupGuard<>> register_at( | |||||
std::shared_ptr<Transformation> transformation) { | |||||
mgb_assert(segment < segments.size()); | mgb_assert(segment < segments.size()); | ||||
std::shared_ptr<Transformation> next; | std::shared_ptr<Transformation> next; | ||||
for (size_t i = segment; i < segments.size(); ++i) { | for (size_t i = segment; i < segments.size(); ++i) { | ||||
@@ -51,16 +66,8 @@ struct TransformationManager { | |||||
transformation->register_at(next->pos()); | transformation->register_at(next->pos()); | ||||
} | } | ||||
segments[segment].push_back(transformation); | segments[segment].push_back(transformation); | ||||
} | |||||
template <Segment segment> | |||||
void unregister(std::shared_ptr<Transformation> transformation) noexcept { | |||||
mgb_assert(segment < segments.size()); | |||||
auto iter = std::find( | |||||
segments[segment].begin(), segments[segment].end(), transformation); | |||||
mgb_assert(iter != segments[segment].end()); | |||||
transformation->unregister(); | |||||
segments[segment].erase(iter); | |||||
return std::make_unique<CleanupGuard<>>( | |||||
[this, transformation]() { unregister<segment>(transformation); }); | |||||
} | } | ||||
static TransformationManager& get_instance() { | static TransformationManager& get_instance() { | ||||
@@ -452,6 +452,8 @@ def test_2nd_grad_with_custom_gradient(): | |||||
return y | return y | ||||
def backward(self, dy): | def backward(self, dy): | ||||
if dy is None: | |||||
return None | |||||
dx = -MySin()(self.inp) * dy | dx = -MySin()(self.inp) * dy | ||||
return dx | return dx | ||||
@@ -14,6 +14,7 @@ import pytest | |||||
import megengine as mge | import megengine as mge | ||||
import megengine.distributed as dist | import megengine.distributed as dist | ||||
import megengine.functional as F | import megengine.functional as F | ||||
import megengine.module as M | |||||
from megengine.core._imperative_rt import CompNode, TensorAttr, imperative | from megengine.core._imperative_rt import CompNode, TensorAttr, imperative | ||||
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync | from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync | ||||
from megengine.core.autodiff.grad import Grad | from megengine.core.autodiff.grad import Grad | ||||
@@ -318,3 +318,41 @@ def test_throw_on_non_tensor_argument(): | |||||
func = NonTensorArg() | func = NonTensorArg() | ||||
with pytest.raises(TypeError, match=r"op .* expect type Tensor as inputs"): | with pytest.raises(TypeError, match=r"op .* expect type Tensor as inputs"): | ||||
func(x, 1) | func(x, 1) | ||||
def test_multiple_grad(): | |||||
data_shape = (9, 2, 6) | |||||
av = np.random.random(data_shape).astype(np.float32) | |||||
class MulFunc(Function): | |||||
def forward(self, a): | |||||
self.a = a | |||||
return a * 10 | |||||
def backward(self, grad_o): | |||||
return grad_o * 20 | |||||
class Simple(Module): | |||||
def __init__(self, a): | |||||
super().__init__() | |||||
self.a = Parameter(a, dtype=np.float32) | |||||
self.layer1 = MulFunc() | |||||
def forward(self): | |||||
x = self.layer1(self.a) | |||||
return x | |||||
net = Simple(av) | |||||
gm = ad.GradManager().attach(net.parameters()) | |||||
gm2 = ad.GradManager().attach(net.parameters()) | |||||
opt = optimizer.SGD(net.parameters(), lr=1.0) | |||||
opt.clear_grad() | |||||
with gm: | |||||
with gm2: | |||||
loss = net() | |||||
gm.backward(loss.sum()) | |||||
opt.step() | |||||
np.testing.assert_almost_equal(loss.numpy(), (av * 10)) | |||||
np.testing.assert_almost_equal(net.a.numpy(), (av - 20)) |
@@ -109,3 +109,46 @@ def test_subgraph(device, batch_size, channels, use_trace, symbolic, gopt_level, | |||||
_assert_allclose(out1.numpy(), out2.numpy()) | _assert_allclose(out1.numpy(), out2.numpy()) | ||||
_assert_allclose(grad1.numpy(), grad2.numpy()) | _assert_allclose(grad1.numpy(), grad2.numpy()) | ||||
@functools.lru_cache(maxsize=None) | |||||
def _get_mul_fn(dtype, device): | |||||
@subgraph_fn( | |||||
"Mul", | |||||
dtype=dtype, | |||||
device=device, | |||||
nr_inputs=2, | |||||
gopt_level=None, | |||||
jit_fusion=False, | |||||
custom_grad=True, | |||||
) | |||||
def mul(inputs, f, c): | |||||
x, y = inputs[0:2] | |||||
z = f("*", x, y) | |||||
(dz,) = yield (z,) | |||||
dx = f("*", dz, y) | |||||
dy = f("*", dz, x) | |||||
yield (dx, dy) | |||||
return mul | |||||
def test_subgraph_jit_backward(): | |||||
x_np = np.random.rand(3, 4, 5).astype("float32") | |||||
x1 = megengine.Tensor(x_np) | |||||
x2 = megengine.Tensor(x_np) | |||||
mul = _get_mul_fn(x1.dtype, x1.device) | |||||
gm = GradManager() | |||||
gm.attach([x1, x2]) | |||||
with gm: | |||||
y1 = x1 * x1 | |||||
y2 = mul(x2, x2) | |||||
gm.backward(y1) | |||||
with gm: | |||||
y1 = x1 * x1 | |||||
y2 = mul(x2, x2) | |||||
gm.backward(y1 + y2) | |||||
with gm: | |||||
y1 = x1 * x1 | |||||
y2 = mul(x2, x2) | |||||
gm.backward(y2) |
@@ -18,18 +18,44 @@ | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
namespace { | |||||
ValueRefList apply(const Operator& op, Span<ValueRef> inputs) { | |||||
ValueRefList apply_release(const Operator& op, Span<ValueRef> inputs) { | |||||
auto& context = Transformation::get_context(); | |||||
size_t& depth = context.next_transformation; | |||||
mgb_assert(depth < context.transformations.size()); | |||||
auto& transformation = *context.transformations[depth++]; | |||||
CleanupGuard _{[&] { --depth; }}; | |||||
return transformation.apply_transformation(op, inputs); | |||||
} | |||||
MGB_NOINLINE ValueRefList apply_debug(const Operator& op, Span<ValueRef> inputs) { | |||||
auto& context = Transformation::get_context(); | auto& context = Transformation::get_context(); | ||||
size_t& depth = context.next_transformation; | size_t& depth = context.next_transformation; | ||||
// TODO: add fallback transformation | |||||
bool fallback = depth >= context.transformations.size(); | |||||
if (mgb_unlikely(fallback)) { | |||||
return op.fallback(inputs); | |||||
mgb_assert(depth < context.transformations.size()); | |||||
static const char tabs[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t"; | |||||
const char* prefix = tabs + (sizeof(tabs) / sizeof(char)) - depth - 1; | |||||
mgb_log_debug( | |||||
"%s apply %s to %s", prefix, op.to_string().c_str(), | |||||
imperative::to_string(inputs).c_str()); | |||||
ValueRefList result; | |||||
auto& transformation = *context.transformations[depth++]; | |||||
CleanupGuard _{[&] { --depth; }}; | |||||
result = transformation.apply_transformation(op, inputs); | |||||
mgb_log_debug( | |||||
"%s returns %s", prefix, | |||||
imperative::to_string(Span<ValueRef>(result)).c_str()); | |||||
return result; | |||||
} | |||||
} // namespace | |||||
ValueRefList apply(const Operator& op, Span<ValueRef> inputs) { | |||||
static bool debug = MGB_GETENV("MGE_LOG_OP_DISPATCH"); | |||||
if (mgb_unlikely(debug)) { | |||||
return apply_debug(op, inputs); | |||||
} else { | } else { | ||||
auto& transformation = *context.transformations[depth++]; | |||||
CleanupGuard _{[&] { --depth; }}; | |||||
return transformation.apply_transformation(op, inputs); | |||||
return apply_release(op, inputs); | |||||
} | } | ||||
} | } | ||||
@@ -106,7 +106,8 @@ EncodedSubgraph OpDef::make_forward_graph( | |||||
} | } | ||||
std::string OpDef::to_string() const { | std::string OpDef::to_string() const { | ||||
std::string builder = trait()->make_name(*this) + "{"; | |||||
std::string builder = trait()->name; | |||||
builder += "{"; | |||||
for (auto&& [name, value] : props(*this)) { | for (auto&& [name, value] : props(*this)) { | ||||
builder += name; | builder += name; | ||||
builder += ": "; | builder += ": "; | ||||
@@ -196,7 +197,7 @@ std::string Subgraph::repr() const { | |||||
if (auto* p = op->try_cast_final<OprAttr>()) { | if (auto* p = op->try_cast_final<OprAttr>()) { | ||||
buf << p->type; | buf << p->type; | ||||
} else { | } else { | ||||
buf << op->make_name(); | |||||
buf << op->to_string(); | |||||
} | } | ||||
for (size_t i : ins) { | for (size_t i : ins) { | ||||
buf << " "; | buf << " "; | ||||
@@ -11,13 +11,94 @@ | |||||
#include "megbrain/opr/dnn/batch_norm.h" | #include "megbrain/opr/dnn/batch_norm.h" | ||||
#include "../op_trait.h" | #include "../op_trait.h" | ||||
#include "megbrain/imperative/graph_builder.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
#include "megbrain/imperative/ops/utility.h" | |||||
#include "megbrain/imperative/proxy_graph_detail.h" | |||||
#include "megbrain/imperative/subgraph_detail.h" | |||||
#include "megbrain/tensor.h" | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
namespace { | namespace { | ||||
EncodedSubgraph generate_batchnorm_backward_graph(DType dtype, CompNode device) { | |||||
Subgraph::Builder<LogicalTensorDesc> builder{ | |||||
[](std::shared_ptr<OpDef> op, SmallVector<LogicalTensorDesc> inputs, | |||||
size_t nr_outputs) { | |||||
auto [outputs, validated] = | |||||
OpDef::infer_output_attrs_fallible(*op, inputs); | |||||
mgb_assert(outputs.size() == nr_outputs, "nr_outputs mismatch"); | |||||
return outputs; | |||||
}}; | |||||
auto f = [&](auto&& op, auto... args) { | |||||
return builder.write_expr( | |||||
op, Subgraph::vars_t({(Subgraph::var_t)args...}), 1)[0]; | |||||
}; | |||||
auto prod = Reduce::make(megdnn::param::Reduce(Reduce::Mode::PRODUCT, 0)); | |||||
auto sum = Reduce::make(megdnn::param::Reduce(Reduce::Mode::SUM)); | |||||
auto sub = Elemwise::make(Elemwise::Mode::SUB); | |||||
auto mul = Elemwise::make(Elemwise::Mode::MUL); | |||||
auto div = Elemwise::make(Elemwise::Mode::TRUE_DIV); | |||||
auto floor_div = Elemwise::make(Elemwise::Mode::FLOOR_DIV); | |||||
auto broadcast = Broadcast::make(); | |||||
auto c = [&](TensorPtr tensor, DType dtype) { | |||||
auto result = builder.write_constant( | |||||
tensor, {TensorLayout{tensor->dtype()}, tensor->comp_node()}); | |||||
if (tensor->dtype() != dtype) { | |||||
result = f(TypeCvt::make(dtype), result); | |||||
} | |||||
return result; | |||||
}; | |||||
auto ci = [&](megdnn::dt_int32 value) { | |||||
return c(Tensor::make_scalar(DTypeScalar(value), device), dtype::Int32()); | |||||
}; | |||||
auto cf = [&](megdnn::dt_float32 value) { | |||||
return c(Tensor::make_scalar(DTypeScalar(value), device), dtype); | |||||
}; | |||||
auto desc = LogicalTensorDesc{TensorLayout{dtype}, device}; | |||||
auto x = builder.write_input(desc); | |||||
auto y_grad = builder.write_input(desc); | |||||
auto save_mean = builder.write_input(desc); | |||||
auto save_invstd = builder.write_input(desc); | |||||
auto weight = builder.write_input(desc); | |||||
auto reserved = builder.write_input(desc); | |||||
MGB_MARK_USED_VAR(reserved); | |||||
// assert x.ndim == 4 | |||||
auto input_shape = f(GetVarShape::make(), x); | |||||
auto channels = f(GetVarShape::make(1), x); | |||||
auto reduce_shape = f(Concat::make(0, device), ci(1), channels, ci(1), ci(1)); | |||||
auto input_elems = f(prod, input_shape); | |||||
auto reduce_size = f(floor_div, input_elems, channels); | |||||
auto reduce_size_f = f(TypeCvt::make(dtype), reduce_size); | |||||
auto mean = f(broadcast, save_mean, input_shape); | |||||
auto invstd = save_invstd; | |||||
auto norm = f(div, cf(1), reduce_size_f); | |||||
auto output_grad_sum = f(sum, y_grad, reduce_shape); | |||||
auto dot_p = f(sum, f(mul, y_grad, f(sub, x, mean)), reduce_shape); | |||||
auto mean_grad = f(broadcast, f(mul, output_grad_sum, norm), input_shape); | |||||
auto proj_scale = | |||||
f(broadcast, f(mul, f(mul, dot_p, norm), f(mul, invstd, invstd)), | |||||
input_shape); | |||||
auto grad_scale = f( | |||||
mul, f(broadcast, invstd, input_shape), f(broadcast, weight, input_shape)); | |||||
auto proj = f(mul, f(sub, x, mean), proj_scale); | |||||
auto x_grad = f(mul, f(sub, f(sub, y_grad, proj), mean_grad), grad_scale); | |||||
auto weight_grad = f(mul, dot_p, invstd); | |||||
auto bias_grad = output_grad_sum; | |||||
builder.add_outputs({weight_grad, bias_grad, x_grad}); | |||||
auto bn_backward = builder.encode(); | |||||
return bn_backward; | |||||
} | |||||
namespace bn { | |||||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | ||||
auto* node = &node_->cast_final_safe<opr::BatchNorm>(); | auto* node = &node_->cast_final_safe<opr::BatchNorm>(); | ||||
return BatchNorm::make(node->param()); | return BatchNorm::make(node->param()); | ||||
@@ -72,8 +153,60 @@ OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm) | |||||
.apply_on_var_node(apply_on_var_node) | .apply_on_var_node(apply_on_var_node) | ||||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | .infer_output_attrs_fallible(infer_output_attrs_fallible) | ||||
.fallback(); | .fallback(); | ||||
} // anonymous namespace | |||||
} // namespace bn | |||||
namespace bn_backward { | |||||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||||
auto* node = &node_->cast_final_safe<opr::BatchNormBackward>(); | |||||
return BatchNormBackward::make(node->param()); | |||||
} | |||||
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto& op = def.cast_final_safe<BatchNormBackward>(); | |||||
cg::SymbolVar x, y_grad, save_mean, save_variance, weight, reserve; | |||||
x = inputs[0]; | |||||
y_grad = inputs[1]; | |||||
save_mean = inputs[2]; | |||||
save_variance = inputs[3]; | |||||
weight = inputs[4]; | |||||
if (inputs.size() == 6) { | |||||
reserve = inputs[5]; | |||||
} | |||||
return opr::BatchNormBackward::make( | |||||
x, y_grad, save_mean, save_variance, weight, reserve, op.param())[0] | |||||
.node() | |||||
->owner_opr() | |||||
->usable_output(); | |||||
} | |||||
EncodedSubgraph make_backward_graph( | |||||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs, | |||||
const SmallVector<bool>& input_requires_grad, | |||||
const SmallVector<bool>& output_has_grad) { | |||||
def.cast_final_safe<BatchNormBackward>(); | |||||
size_t nr_inputs = 6; | |||||
size_t nr_outputs = 3; | |||||
mgb_assert(inputs.size() == nr_inputs); | |||||
mgb_assert(input_requires_grad.size() == nr_inputs); | |||||
mgb_assert(output_has_grad.size() == nr_outputs); | |||||
auto dtype = inputs[0].layout.dtype; | |||||
auto device = inputs[0].comp_node; | |||||
auto bn_backward = generate_batchnorm_backward_graph(dtype, device); | |||||
auto bn_double_backward = subgraph_detail::make_backward_graph_from_forward( | |||||
bn_backward, inputs, input_requires_grad, output_has_grad); | |||||
return bn_double_backward; | |||||
} | |||||
OP_TRAIT_REG(BatchNormBackward, BatchNormBackward, opr::BatchNormBackward) | |||||
.make_from_op_node(make_from_op_node) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.make_backward_graph(make_backward_graph) | |||||
.fallback(); | |||||
} // namespace bn_backward | |||||
} // anonymous namespace | |||||
} // namespace imperative | } // namespace imperative | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -762,7 +762,9 @@ EncodedSubgraph make_backward_graph( | |||||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs, | const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs, | ||||
const SmallVector<bool>& input_requires_grad, | const SmallVector<bool>& input_requires_grad, | ||||
const SmallVector<bool>& output_has_grad) { | const SmallVector<bool>& output_has_grad) { | ||||
return {}; | |||||
return OpDef::make_backward_graph( | |||||
*def.cast_final_safe<JITFusionOp>().op, inputs, input_requires_grad, | |||||
output_has_grad); | |||||
} | } | ||||
OP_TRAIT_REG(JITFusionOp, JITFusionOp) | OP_TRAIT_REG(JITFusionOp, JITFusionOp) | ||||
@@ -96,10 +96,11 @@ SmallVector<LayoutConstraintCallback> get_input_layout_constraint( | |||||
return res; | return res; | ||||
} | } | ||||
static EncodedSubgraph make_backward_graph_from_forward( | |||||
EncodedSubgraph make_backward_graph_from_forward( | |||||
const EncodedSubgraph& forward_graph, | |||||
const SmallVector<LogicalTensorDesc>& inputs, | const SmallVector<LogicalTensorDesc>& inputs, | ||||
const SmallVector<bool>& input_requires_grad, | const SmallVector<bool>& input_requires_grad, | ||||
const SmallVector<bool>& output_has_grad, EncodedSubgraph forward_graph) { | |||||
const SmallVector<bool>& output_has_grad) { | |||||
using namespace std::placeholders; | using namespace std::placeholders; | ||||
using var_t = Subgraph::var_t; | using var_t = Subgraph::var_t; | ||||
using vars_t = Subgraph::vars_t; | using vars_t = Subgraph::vars_t; | ||||
@@ -179,7 +180,7 @@ EncodedSubgraph make_backward_graph( | |||||
const SmallVector<bool>& output_has_grad) { | const SmallVector<bool>& output_has_grad) { | ||||
auto forward_graph = OpDef::make_forward_graph(def, inputs); | auto forward_graph = OpDef::make_forward_graph(def, inputs); | ||||
return make_backward_graph_from_forward( | return make_backward_graph_from_forward( | ||||
inputs, input_requires_grad, output_has_grad, forward_graph); | |||||
forward_graph, inputs, input_requires_grad, output_has_grad); | |||||
} | } | ||||
} // namespace subgraph_detail | } // namespace subgraph_detail | ||||
@@ -139,7 +139,7 @@ ValueRefList InterpreterTransformation::apply_transformation( | |||||
return {ValueRef()}; | return {ValueRef()}; | ||||
} | } | ||||
} else { | } else { | ||||
return imperative::apply(op, inputs); | |||||
return op.fallback(inputs); | |||||
} | } | ||||
} | } | ||||
@@ -62,7 +62,8 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( | |||||
std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs) | std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs) | ||||
: backward_graph(backward_graph), | : backward_graph(backward_graph), | ||||
output_mask_offset(inputs.size()), | output_mask_offset(inputs.size()), | ||||
grad_mask_offset(inputs.size() + outputs.size()) { | |||||
grad_mask_offset(inputs.size() + outputs.size()), | |||||
op(op) { | |||||
auto& save_for_backward = backward_graph->save_for_backward; | auto& save_for_backward = backward_graph->save_for_backward; | ||||
mgb_assert(save_for_backward.size() == inputs.size() + 2 * outputs.size()); | mgb_assert(save_for_backward.size() == inputs.size() + 2 * outputs.size()); | ||||
size_t count = std::count_if( | size_t count = std::count_if( | ||||
@@ -92,6 +93,13 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( | |||||
closure.push_back(outputs[i]); | closure.push_back(outputs[i]); | ||||
} | } | ||||
} | } | ||||
if (outputs.size() > 1) { | |||||
output_descs.reserve(outputs.size()); | |||||
for (auto&& output : outputs) { | |||||
auto symbolic_shape = imperative::apply(*GetVarShape::make(), output)[0]; | |||||
output_descs.push_back({symbolic_shape, output.dtype(), output.device()}); | |||||
} | |||||
} | |||||
} | } | ||||
void BackwardGraphWithClosure::operator()( | void BackwardGraphWithClosure::operator()( | ||||
Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) { | Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) { | ||||
@@ -100,23 +108,46 @@ void BackwardGraphWithClosure::operator()( | |||||
for (auto&& value : closure) { | for (auto&& value : closure) { | ||||
args[nargs++] = value; | args[nargs++] = value; | ||||
} | } | ||||
bool null_grad = false; | |||||
size_t null_grad = 0; | |||||
size_t valid_grad = 0; | |||||
for (size_t i = 0; i < grads.size(); ++i) { | for (size_t i = 0; i < grads.size(); ++i) { | ||||
if (backward_graph->save_for_backward[grad_mask_offset + i]) { | if (backward_graph->save_for_backward[grad_mask_offset + i]) { | ||||
if (grads[i]) { | if (grads[i]) { | ||||
mgb_assert(!null_grad, "null_grad"); | |||||
valid_grad++; | |||||
args[nargs++] = grads[i]; | args[nargs++] = grads[i]; | ||||
} else { | } else { | ||||
null_grad = true; | |||||
null_grad++; | |||||
nargs++; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
if (null_grad) { | |||||
if (valid_grad == 0) { | |||||
return; | return; | ||||
} | } | ||||
auto igrads_ = imperative::apply(backward_graph->backward, Span(args, nargs)); | |||||
SmallVector<ValueRef> igrads = {igrads_.begin(), igrads_.end()}; | |||||
igrads_.clear(); | |||||
if (null_grad > 0) { | |||||
// Builds a zero-filled value described by `desc`: a single zeroed element of
// desc.dtype is created on desc.device and then broadcast to desc.shape.
auto zeros_like = [](const OutputDesc& desc) { | |||||
HostTensorStorage storage(*desc.device); | |||||
// Storage for exactly one element, cleared to all-zero bytes.
storage.ensure_size(desc.dtype->size()); | |||||
std::memset(storage.ptr(), 0, desc.dtype->size()); | |||||
// Default-constructed ValueShape() -- presumably a scalar shape; confirm.
auto t = imperative::apply( | |||||
CreateTensor( | |||||
CreateTensor::Unique, *desc.device, *desc.dtype, | |||||
ValueShape()), | |||||
HostStorage::make(storage))[0]; | |||||
// Expand the single zero element to the recorded output shape.
auto res = imperative::apply(*Broadcast::make(), t, desc.shape)[0]; | |||||
return res; | |||||
}; | |||||
nargs = closure.size(); | |||||
for (size_t i = 0; i < grads.size(); ++i) { | |||||
if (backward_graph->save_for_backward[grad_mask_offset + i]) { | |||||
if (!grads[i]) { | |||||
args[nargs] = zeros_like(output_descs[i]); | |||||
} | |||||
nargs++; | |||||
} | |||||
} | |||||
} | |||||
auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs)); | |||||
auto&& iter = igrads.begin(); | auto&& iter = igrads.begin(); | ||||
for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) { | for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) { | ||||
if (p) { | if (p) { | ||||
@@ -221,9 +252,11 @@ void GradKey::backward() { | |||||
if (!dest) { | if (!dest) { | ||||
continue; | continue; | ||||
} | } | ||||
if (!dest.m_producer_record.next && dest->callback && dest->m_grad) { | |||||
if (!dest.m_producer_record.next && dest->callback) { | |||||
// I'm the last grad producer, invoke callback | // I'm the last grad producer, invoke callback | ||||
dest->callback(dest->m_grad); | |||||
if (dest->m_grad) { | |||||
dest->callback(dest->m_grad); | |||||
} | |||||
} | } | ||||
} | } | ||||
grad_fn->clear(); | grad_fn->clear(); | ||||
@@ -394,16 +427,22 @@ ValueRefList GradTransformation::apply_transformation( | |||||
return imperative::apply(op, inputs); | return imperative::apply(op, inputs); | ||||
} | } | ||||
if (auto* attach_grad = op.as<AttachGrad>()) { | if (auto* attach_grad = op.as<AttachGrad>()) { | ||||
if (!has_key(attach_grad->key())) { | |||||
auto& tensor = inputs[0]; | |||||
if (auto&& grad_value = tensor.as_ref(m_value_type)) { | |||||
mgb_assert(!has_key(attach_grad->key())); | |||||
auto output = fallback()[0]; | |||||
return record_grad(m_value_type.make(output, m_key, grad_value->slot())); | |||||
} else if (!has_key(attach_grad->key())) { | |||||
return fallback(); | return fallback(); | ||||
} else { | |||||
GenericFunction callback = | |||||
(GenericFunction&)inputs[1].cast<FunctionValue>(); | |||||
auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) { | |||||
auto ret = callback({&grad, 1}); | |||||
assert(ret.empty()); | |||||
}); | |||||
return {record_grad(output)}; | |||||
} | } | ||||
auto tensor = inputs[0]; | |||||
GenericFunction callback = (GenericFunction&)inputs[1].cast<FunctionValue>(); | |||||
auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) { | |||||
auto ret = callback({&grad, 1}); | |||||
assert(ret.empty()); | |||||
}); | |||||
return {record_grad(output)}; | |||||
} else if (auto* grad_backward = op.as<GradBackward>()) { | } else if (auto* grad_backward = op.as<GradBackward>()) { | ||||
if (!has_key(grad_backward->key())) { | if (!has_key(grad_backward->key())) { | ||||
return fallback(); | return fallback(); | ||||
@@ -431,10 +470,10 @@ ValueRefList GradTransformation::apply_transformation( | |||||
mgb_assert(inputs.size() > nr_inputs); | mgb_assert(inputs.size() > nr_inputs); | ||||
size_t nr_outputs = inputs.size() - nr_inputs; | size_t nr_outputs = inputs.size() - nr_inputs; | ||||
Span<ValueRef> inputs_ = {inputs.data(), nr_inputs}; | Span<ValueRef> inputs_ = {inputs.data(), nr_inputs}; | ||||
Span<ValueRef> outputs_ = {inputs.data() + nr_inputs, nr_outputs}; | |||||
backward.m_input_has_grad = SmallVector(nr_inputs, true); | |||||
backward.m_output_attrs = | |||||
SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true}); | |||||
auto outputs_ = fallback(); | |||||
backward.m_input_has_grad.resize(nr_inputs, true); | |||||
backward.m_output_attrs.resize( | |||||
nr_outputs, CustomBackward::OutputAttr{true, true}); | |||||
backward.m_backward = [fn = set_grad->grad_fn()](Span<ValueRef> inputs) { | backward.m_backward = [fn = set_grad->grad_fn()](Span<ValueRef> inputs) { | ||||
auto result = fn(inputs); | auto result = fn(inputs); | ||||
return SmallVector<ValueRef>(result.begin(), result.end()); | return SmallVector<ValueRef>(result.begin(), result.end()); | ||||
@@ -31,6 +31,7 @@ class Subgraph::Builder { | |||||
using infer_fn_t = std::function<descs_t(op_t, descs_t, size_t)>; | using infer_fn_t = std::function<descs_t(op_t, descs_t, size_t)>; | ||||
using encoded_graph_t = EncodedSubgraph; | using encoded_graph_t = EncodedSubgraph; | ||||
using var_map_t = std::unordered_map<var_t, var_t>; | using var_map_t = std::unordered_map<var_t, var_t>; | ||||
using mask_t = SmallVector<bool>; | |||||
vars_t m_inputs; | vars_t m_inputs; | ||||
SmallVector<std::pair<var_t, TensorPtr>> m_constants; | SmallVector<std::pair<var_t, TensorPtr>> m_constants; | ||||
vars_t m_outputs; | vars_t m_outputs; | ||||
@@ -94,6 +95,7 @@ public: | |||||
descs_t get_descs(vars_t vars) { | descs_t get_descs(vars_t vars) { | ||||
descs_t descs; | descs_t descs; | ||||
for (auto&& var : vars) { | for (auto&& var : vars) { | ||||
mgb_assert(var, "invalid var"); | |||||
descs.push_back(get_desc(var)); | descs.push_back(get_desc(var)); | ||||
} | } | ||||
return descs; | return descs; | ||||
@@ -128,4 +130,4 @@ public: | |||||
expr_iter_t end() { return m_exprs.end(); } | expr_iter_t end() { return m_exprs.end(); } | ||||
}; | }; | ||||
} // namespace imperative | } // namespace imperative | ||||
} // namespace mgb | |||||
} // namespace mgb |
@@ -38,7 +38,6 @@ struct ShapeInfer final : OpDefImplBase<ShapeInfer> { | |||||
std::shared_ptr<OpDef> op; | std::shared_ptr<OpDef> op; | ||||
SmallVector<CompNode> devices; | SmallVector<CompNode> devices; | ||||
SmallVector<DType> dtypes; | SmallVector<DType> dtypes; | ||||
EncodedSubgraph graph; | |||||
ShapeInfer() = default; | ShapeInfer() = default; | ||||
ShapeInfer( | ShapeInfer( | ||||
std::shared_ptr<OpDef> op, SmallVector<CompNode> devices, | std::shared_ptr<OpDef> op, SmallVector<CompNode> devices, | ||||
@@ -39,6 +39,11 @@ EncodedSubgraph make_backward_graph( | |||||
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | ||||
const OpDef& def, const SmallVector<TensorPtr>& inputs); | const OpDef& def, const SmallVector<TensorPtr>& inputs); | ||||
EncodedSubgraph make_backward_graph_from_forward( | |||||
const EncodedSubgraph& forward, const SmallVector<LogicalTensorDesc>& inputs, | |||||
const SmallVector<bool>& input_requires_grad, | |||||
const SmallVector<bool>& output_has_grad); | |||||
} // namespace subgraph_detail | } // namespace subgraph_detail | ||||
} // namespace imperative | } // namespace imperative | ||||
} // namespace mgb | |||||
} // namespace mgb |
@@ -29,6 +29,15 @@ struct BackwardGraphWithClosure { | |||||
SmallVector<ValueRef> closure; | SmallVector<ValueRef> closure; | ||||
size_t output_mask_offset; | size_t output_mask_offset; | ||||
size_t grad_mask_offset; | size_t grad_mask_offset; | ||||
std::shared_ptr<OpDef> op; | |||||
// Shape/dtype/device record for one forward output, used to synthesize a
// zero gradient when that output's incoming grad is missing.
struct OutputDesc { | |||||
// Shape of the output as a value (filled from a GetVarShape application).
ValueRef shape; | |||||
// Element type of the output.
DTypeValue::ref_t dtype; | |||||
// Device (comp node) the output lives on.
CompNodeValue::ref_t device; | |||||
}; | |||||
SmallVector<OutputDesc> output_descs; | |||||
BackwardGraphWithClosure( | BackwardGraphWithClosure( | ||||
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph, | std::shared_ptr<OptimizedBackwardGraphResult> backward_graph, | ||||
@@ -356,20 +365,22 @@ public: | |||||
// Operator that carries a gradient function together with the number of
// leading forward inputs in its argument list; the remaining arguments are
// the forward outputs the gradient function is attached to.
class SetGrad : public OperatorImpl<SetGrad> { | class SetGrad : public OperatorImpl<SetGrad> { | ||||
private: | private: | ||||
std::shared_ptr<GradKey> m_key; | |||||
GenericFunction m_grad_fn; | GenericFunction m_grad_fn; | ||||
size_t m_nr_inputs; | size_t m_nr_inputs; | ||||
public: | public: | ||||
SetGrad(std::shared_ptr<GradKey> key, GenericFunction grad_fn, size_t nr_inputs) | |||||
: m_key(key), m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {} | |||||
SetGrad(GenericFunction grad_fn, size_t nr_inputs) | |||||
: m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {} | |||||
// Accessors for the stored gradient function and forward-input count.
GenericFunction grad_fn() const { return m_grad_fn; } | GenericFunction grad_fn() const { return m_grad_fn; } | ||||
size_t nr_inputs() const { return m_nr_inputs; } | size_t nr_inputs() const { return m_nr_inputs; } | ||||
std::string to_string() const override { | |||||
return ssprintf("SetGradValue{key=%s}", m_key->name().c_str()); | |||||
std::string to_string() const override { return ssprintf("SetGradValue{}"); } | |||||
// Fallback behaviour: return the arguments after the first m_nr_inputs
// unchanged (assumes Span::sub takes (offset, count) -- TODO confirm).
ValueRefList fallback(Span<ValueRef> inputs) const override { | |||||
auto outputs = inputs.sub(m_nr_inputs, inputs.size() - m_nr_inputs); | |||||
return {outputs.begin(), outputs.end()}; | |||||
} | } | ||||
}; | }; | ||||
@@ -15,12 +15,14 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <sstream> | #include <sstream> | ||||
#include "megbrain/utils/metahelper.h" | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
template <typename T> | |||||
class CleanupGuard { | |||||
template <typename T = std::function<void()>> | |||||
class CleanupGuard : public NonCopyableObj { | |||||
private: | private: | ||||
T m_callback; | T m_callback; | ||||
@@ -37,4 +39,4 @@ inline std::string quoted(std::string str) { | |||||
} // namespace imperative | } // namespace imperative | ||||
} // namespace mgb | |||||
} // namespace mgb |
@@ -89,6 +89,8 @@ def SlidingWindowTranspose : MgbHashableOp<"SlidingWindowTranspose", [SlidingWin | |||||
def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>; | def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>; | ||||
def BatchNormBackward : MgbHashableOp<"BatchNormBackward", [BNParam]>; | |||||
def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>; | def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>; | ||||
def Correlation: MgbHashableOp<"Correlation", [CorrelationParam]>; | def Correlation: MgbHashableOp<"Correlation", [CorrelationParam]>; | ||||