This reverts commit d67043ba82.

GitOrigin-RevId: 2be2e2ef16
@@ -11,7 +11,6 @@
#include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/utils/mempool.h"
@@ -33,14 +32,14 @@ struct GradSlotWeakPtr {
    size_t idx;
};

struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<OptimizedBackwardGraphResult>>, CompNodeDepedentObject {
struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
    std::shared_ptr<void> on_comp_node_finalize() override {
        clear();
        return {};
    }
} backward_graph_cache;

std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
std::shared_ptr<BackwardGraphResult> make_backward_graph(
        ApplyContext& ctx, const apply_result_t& outputs) {
    // hash
    static_assert(alignof(size_t) % alignof(bool) == 0);
@@ -73,23 +72,23 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
        inputs[i].layout.dtype = ctx.args[i]->dtype();
        input_requires_grad[i] = python::input_requires_grad(ctx, i);
    }
    std::shared_ptr<OptimizedBackwardGraphResult> ret;
    auto bg = proxy_graph_detail::make_backward_graph(
            *ctx.op, inputs, input_requires_grad, output_has_grad);
    if (bg.backward) {
        ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
    auto result = std::make_shared<BackwardGraphResult>(
            proxy_graph_detail::make_backward_graph(
                    *ctx.op, inputs, input_requires_grad, output_has_grad));
    if (!result->backward) {
        result.reset();
    }
    backward_graph_cache.emplace(key, ret);
    return ret;
    backward_graph_cache.emplace(key, result);
    return result;
}

struct BackwardGraphWithClosure {
    std::shared_ptr<OptimizedBackwardGraphResult> backward_graph;
    std::shared_ptr<BackwardGraphResult> backward_graph;
    SmallVector<std::shared_ptr<Tensor>> closure;
    size_t output_mask_offset;
    size_t grad_mask_offset;

    BackwardGraphWithClosure(std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_,
    BackwardGraphWithClosure(std::shared_ptr<BackwardGraphResult> backward_graph_,
                             ApplyContext& ctx, const apply_result_t& outputs)
            : backward_graph(backward_graph_),
              output_mask_offset(ctx.nargs),
@@ -108,18 +107,9 @@ struct BackwardGraphWithClosure {
        // b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
        auto& save_for_backward = backward_graph->save_for_backward;
        mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
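        // note: save_for_backward is laid out as [inputs..., outputs..., output grads...],
        // which is why its expected size is ctx.nargs + 2 * outputs.size()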
        size_t count = std::count_if(save_for_backward.begin(),
                                     save_for_backward.end(),
                                     ranges::identity{});
        if (backward_graph->precomp) {
            auto&& irng = ranges::span(ctx.args, ctx.nargs);
            auto&& orng = views::transform(outputs, [](auto&& i){return i.get();});
            auto precomp = apply(backward_graph->precomp, views::concat(irng, orng));
            closure.reserve(precomp.size() + count);
            std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
        } else {
            closure.reserve(count);
        }
        closure.reserve(std::count_if(save_for_backward.begin(),
                                      save_for_backward.end(),
                                      ranges::identity{}));
        for (size_t i = 0; i < ctx.nargs; ++i) {
            if (save_for_backward[i]) {
                closure.push_back(ctx.args[i]->shared_from_this());
@@ -212,7 +212,7 @@ decltype(auto) resolve_arrow(T&& p) {
    if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
        return resolve_arrow(p.operator->());
    } else {
        return std::forward<T>(p);
        return p;
    }
}
}
@@ -1,114 +0,0 @@
/**
 * \file imperative/src/impl/backward_graph_opt.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/backward_graph_opt.h" | |||
#include "megbrain/imperative/ops/backward_graph.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
using namespace mgb; | |||
using namespace imperative; | |||
OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src) | |||
: input_has_grad(src.input_has_grad) { | |||
if (!src.backward->same_type<BackwardGraph>()) { | |||
// backward graph only contains a single op | |||
backward = src.backward; | |||
save_for_backward = src.save_for_backward; | |||
return; | |||
} | |||
save_for_backward.resize(src.save_for_backward.size(), false); | |||
precomp.reset(new BackwardGraph); | |||
backward.reset(new BackwardGraph); | |||
auto&& graph = src.backward->cast_final_safe<BackwardGraph>().graph(); | |||
auto&& mask = src.save_for_backward; | |||
size_t input_size = src.input_has_grad.size(); | |||
size_t output_size = (mask.size() - input_size) / 2; | |||
mgb_assert(input_size + output_size * 2 == mask.size()); | |||
auto& fgraph = precomp->cast_final<BackwardGraph>().graph(); | |||
auto& bgraph = backward->cast_final<BackwardGraph>().graph(); | |||
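    // fgraph collects the exprs hoisted into the forward pass (exposed as `precomp`);
    // bgraph keeps the exprs that still have to run during backward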

    // optimization: move ops (e.g. GetVarShape) to forward to
    // reduce memory footprint
    struct VInfo {
        bool appears_in_backward = false;
    };
    std::unordered_map<size_t, VInfo> vinfo;

    // step 1.1: ops not in the whitelist must run in backward,
    // so mark their inputs as always appearing in backward
    for (auto&& [op, iv, ov] : graph.exprs) {
        if (!op->same_type<GetVarShape>()) {
            for (auto&& v : iv) {
                vinfo[v].appears_in_backward = true;
            }
        }
    }

    // step 1.2: inputs only available in backward (i.e. grads)
    // should be marked as always appearing in backward
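    // (mask entries at index >= input_size + output_size are the output gradients;
    // j walks graph.inputs, which only contains the saved entries)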
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (!mask[i]) continue;
        if (i >= input_size + output_size) {
            vinfo[graph.inputs[j]].appears_in_backward = true;
        }
        ++j;
    }

    // step 2: try to move ops to forward if not all of their inputs
    // are marked as always appearing in backward (otherwise there is no memory saving)
    for (auto&& expr : graph.exprs) {
        auto&& [op, iv, ov] = expr;
        if (std::all_of(iv.begin(), iv.end(), [&](auto&& v){return vinfo[v].appears_in_backward;})) {
            bgraph.exprs.push_back(expr);
            for (auto&& v : ov) {
                vinfo[v].appears_in_backward = true;
            }
            // logically we should also mark all inputs as appearing in backward,
            // but that is clearly a no-op here
        } else {
            fgraph.exprs.push_back(expr);
            for (auto&& v : ov) {
                if (vinfo[v].appears_in_backward) {
                    // appears_in_backward won't change after this point,
                    // so it is safe to set fgraph.outputs based on the current value
                    fgraph.outputs.push_back(v);
                }
            }
        }
    }

    // initialize remaining parts
    fgraph.constants = graph.constants;
    fgraph.inputs.reserve(input_size + output_size);
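    // unsaved positions get a sentinel id (1000000000 + i) below; presumably this
    // just keeps the input list positionally aligned without colliding with real var ids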
    for (size_t i = 0, j = 0; i < input_size + output_size; ++i) {
        if (!mask[i]) {
            fgraph.inputs.push_back(1000000000 + i);
            continue;
        }
        fgraph.inputs.push_back(graph.inputs[j++]);
    }

    bgraph.constants = graph.constants;
    bgraph.outputs = graph.outputs;
    bgraph.inputs = fgraph.outputs;
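    // the backward graph first consumes whatever the precomp graph produced,
    // then the saved tensors appended in the loop below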
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (mask[i]) {
            auto&& v = graph.inputs[j++];
            if (vinfo[v].appears_in_backward) {
                save_for_backward[i] = true;
                bgraph.inputs.push_back(v);
            }
        }
    }
}
@@ -1,25 +0,0 @@
/**
 * \file imperative/src/include/megbrain/imperative/backward_graph_opt.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./op_def.h"

namespace mgb::imperative {

struct OptimizedBackwardGraphResult {
    std::shared_ptr<OpDef> precomp;
    std::shared_ptr<OpDef> backward;
    std::vector<bool> save_for_backward;
    std::vector<bool> input_has_grad;

    OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph);
};

} // namespace mgb::imperative
@@ -13,68 +13,11 @@
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/backward_graph_opt.h"

using namespace mgb;
using namespace cg;
using namespace imperative;
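// Helpers used by the tests below: prepare_backward_graph_inputs filters
// (inputs, outputs, grads) through bg.save_for_backward, and expand_grads
// scatters the backward outputs back to input positions via input_has_grad.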
template <typename T>
T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, const T& outputs, const T& grads) {
    T ret;
    size_t i = 0;
    for (auto&& t : inputs) {
        if (bg.save_for_backward[i++]) {
            ret.push_back(t);
        }
    }
    for (auto&& t : outputs) {
        if (bg.save_for_backward[i++]) {
            ret.push_back(t);
        }
    }
    for (auto&& t : grads) {
        if (bg.save_for_backward[i++]) {
            ret.push_back(t);
        }
    }
    return ret;
}

template <typename T, typename U>
T expand_grads(const U& bg, const T& outputs) {
    T ret(bg.input_has_grad.size());
    for (size_t i = 0, j = 0; i < bg.input_has_grad.size(); ++i) {
        if (bg.input_has_grad[i]) {
            ret[i] = outputs[j++];
        }
    }
    return ret;
}

template <typename T>
T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, const T& precomp, const T& inputs, const T& outputs, const T& grads) {
    T ret = precomp;
    size_t i = 0;
    for (auto&& t : inputs) {
        if (bg.save_for_backward[i++]) {
            ret.push_back(t);
        }
    }
    for (auto&& t : outputs) {
        if (bg.save_for_backward[i++]) {
            ret.push_back(t);
        }
    }
    for (auto&& t : grads) {
        if (bg.save_for_backward[i++]) {
            ret.push_back(t);
        }
    }
    return ret;
}

TEST(TestImperative, BackwardGraphBasic) {
    HostTensorGenerator<> gen;
    SmallVector<HostTensorND> hvs;
@@ -178,65 +121,27 @@ TEST(TestImperative, BackwardGraphIdentity) {
}

TEST(TestImperative, BatchNormGrad) {
    auto cn = CompNode::load("xpux");
    using Param = opr::BatchNorm::Param;
    size_t N=2, C=3, H=5, W=5;
    LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn};
    LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn};
    {
        auto op = OprAttr::make("BatchNorm");
        auto&& attr = op->cast_final_safe<OprAttr>();
        Param param;
        param.fwd_mode = Param::FwdMode::TRAINING;
        attr.param.write_pod(param);
        OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat},
                {true, true, true, false, false}, {false, false, false, false, true});
    }
    {
        auto op = OprAttr::make("BatchNorm");
        auto&& attr = op->cast_final_safe<OprAttr>();
        Param param;
        param.fwd_mode = Param::FwdMode::TRAINING;
        attr.param.write_pod(param);
        OpDef::make_backward_graph(attr, {inp, stat, stat},
                {true, true, true}, {false, false, true});
    }
}

TEST(TestImperative, OptimizedBackwardGraphBasic) {
    auto cn = CompNode::load("xpux");
    LogicalTensorDesc desc = {TensorLayout(dtype::Float32()), cn};
    HostTensorGenerator<> gen;
    auto op = std::shared_ptr<OpDef>(Elemwise::make(Elemwise::Mode::ADD));
    auto bg = OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true});
    auto obg = OptimizedBackwardGraphResult(bg);
    ASSERT_EQ(obg.save_for_backward.size(), 4);
    ASSERT_FALSE(obg.save_for_backward[0]);
    ASSERT_FALSE(obg.save_for_backward[1]);
    ASSERT_FALSE(obg.save_for_backward[2]);
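    // for ADD, backward needs only the output gradient plus the input shapes
    // (computed by GetVarShape and hoisted into obg.precomp), so neither the
    // inputs a, b nor the output c has to be saved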
    auto a_hv = gen({42});
    auto b_hv = gen({5, 42});
    auto dc_hv = gen({5, 42});
    auto a_tn = Tensor::make(*a_hv);
    auto b_tn = Tensor::make(*b_hv);
    auto dc_tn = Tensor::make(*dc_hv);
    auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0];
    auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
    auto grads = expand_grads(bg, OpDef::apply_on_physical_tensor(*bg.backward, backward_graph_inputs));
    auto precomp = OpDef::apply_on_physical_tensor(*obg.precomp, {a_tn, b_tn, c_tn});
    ASSERT_EQ(precomp.size(), 2);
    ASSERT_EQ(precomp[0]->shape().ndim, 1);
    ASSERT_LE(precomp[0]->shape()[0], 2);
    ASSERT_EQ(precomp[1]->shape().ndim, 1);
    ASSERT_LE(precomp[1]->shape()[0], 2);
    auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
    auto grads2 = expand_grads(obg, OpDef::apply_on_physical_tensor(*obg.backward, backward_inputs));
    ASSERT_EQ(grads2.size(), 2);
    MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value());
    MGB_ASSERT_TENSOR_EQ(grads[1]->get_value(), grads2[1]->get_value());
auto cn = CompNode::load("xpux"); | |||
using Param = opr::BatchNorm::Param; | |||
size_t N=2, C=3, H=5, W=5; | |||
LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; | |||
LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; | |||
{ | |||
auto op = OprAttr::make("BatchNorm"); | |||
auto&& attr = op->cast_final_safe<OprAttr>(); | |||
Param param; | |||
param.fwd_mode = Param::FwdMode::TRAINING; | |||
attr.param.write_pod(param); | |||
OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat}, | |||
{true, true ,true, false, false}, {false, false, false, false, true}); | |||
} | |||
{ | |||
auto op = OprAttr::make("BatchNorm"); | |||
auto&& attr = op->cast_final_safe<OprAttr>(); | |||
Param param; | |||
param.fwd_mode = Param::FwdMode::TRAINING; | |||
attr.param.write_pod(param); | |||
OpDef::make_backward_graph(attr, {inp, stat, stat}, | |||
{true, true ,true}, {false, false, true}); | |||
} | |||
} |