
perf(mge): add memory optimization for backward graph

precompute ops in forward to reduce saved tensor size

GitOrigin-RevId: d67043ba82
release-1.2
Megvii Engine Team 4 years ago
commit 278b2baa8c
5 changed files with 282 additions and 38 deletions
  1. imperative/python/src/grad.cpp (+24, -14)
  2. imperative/python/src/tensor.h (+1, -1)
  3. imperative/src/impl/backward_graph_opt.cpp (+114, -0)
  4. imperative/src/include/megbrain/imperative/backward_graph_opt.h (+25, -0)
  5. imperative/src/test/backward_graph.cpp (+118, -23)
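
In rough terms, the change splits each generated backward graph into a cheap part (currently just GetVarShape ops) that is evaluated immediately during the forward pass, and a residual part that still needs the incoming gradients; only the precomputed results and the tensors the residual part actually reads stay alive until backward. A back-of-the-envelope illustration using the shapes from the new test below (a: (42), b: (5, 42), float32); the byte counts are illustrative, not measurements from MegEngine:

#include <cstddef>
#include <cstdio>

int main() {
    // Before: the broadcasted add a + b would typically keep a and b alive so
    // that backward can reduce the incoming gradient dc to their shapes.
    std::size_t before = (42 + 5 * 42) * sizeof(float);
    // After: GetVarShape runs during forward; only two tiny shape vectors
    // (at most 2 entries each in this test) survive until backward.
    std::size_t after = 2 * 2 * sizeof(int);
    std::printf("bytes kept for backward: %zu -> %zu\n", before, after);
    return 0;
}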

imperative/python/src/grad.cpp (+24, -14)

@@ -11,6 +11,7 @@


#include "./grad.h" #include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h" #include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/ops/utility.h"
#include "megbrain/utils/mempool.h" #include "megbrain/utils/mempool.h"
@@ -32,14 +33,14 @@ struct GradSlotWeakPtr {
    size_t idx;
};

-struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
+struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<OptimizedBackwardGraphResult>>, CompNodeDepedentObject {
    std::shared_ptr<void> on_comp_node_finalize() override {
        clear();
        return {};
    }
} backward_graph_cache;

-std::shared_ptr<BackwardGraphResult> make_backward_graph(
+std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
        ApplyContext& ctx, const apply_result_t& outputs) {
    // hash
    static_assert(alignof(size_t) % alignof(bool) == 0);
@@ -72,23 +73,23 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
        inputs[i].layout.dtype = ctx.args[i]->dtype();
        input_requires_grad[i] = python::input_requires_grad(ctx, i);
    }
-    auto result = std::make_shared<BackwardGraphResult>(
-            proxy_graph_detail::make_backward_graph(
-                    *ctx.op, inputs, input_requires_grad, output_has_grad));
-    if (!result->backward) {
-        result.reset();
+    std::shared_ptr<OptimizedBackwardGraphResult> ret;
+    auto bg = proxy_graph_detail::make_backward_graph(
+            *ctx.op, inputs, input_requires_grad, output_has_grad);
+    if (bg.backward) {
+        ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
    }
-    backward_graph_cache.emplace(key, result);
-    return result;
+    backward_graph_cache.emplace(key, ret);
+    return ret;
}

struct BackwardGraphWithClosure {
-    std::shared_ptr<BackwardGraphResult> backward_graph;
+    std::shared_ptr<OptimizedBackwardGraphResult> backward_graph;
    SmallVector<std::shared_ptr<Tensor>> closure;
    size_t output_mask_offset;
    size_t grad_mask_offset;

-    BackwardGraphWithClosure(std::shared_ptr<BackwardGraphResult> backward_graph_,
+    BackwardGraphWithClosure(std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_,
                             ApplyContext& ctx, const apply_result_t& outputs)
            : backward_graph(backward_graph_),
              output_mask_offset(ctx.nargs),
@@ -107,9 +108,18 @@ struct BackwardGraphWithClosure {
        // b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
        auto& save_for_backward = backward_graph->save_for_backward;
        mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
-        closure.reserve(std::count_if(save_for_backward.begin(),
-                                      save_for_backward.end(),
-                                      ranges::identity{}));
+        size_t count = std::count_if(save_for_backward.begin(),
+                                     save_for_backward.end(),
+                                     ranges::identity{});
+        if (backward_graph->precomp) {
+            auto&& irng = ranges::span(ctx.args, ctx.nargs);
+            auto&& orng = views::transform(outputs, [](auto&& i){return i.get();});
+            auto precomp = apply(backward_graph->precomp, views::concat(irng, orng));
+            closure.reserve(precomp.size() + count);
+            std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
+        } else {
+            closure.reserve(count);
+        }
        for (size_t i = 0; i < ctx.nargs; ++i) {
            if (save_for_backward[i]) {
                closure.push_back(ctx.args[i]->shared_from_this());
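
For orientation, save_for_backward carries one flag per forward input, per forward output and per output gradient, in that order (hence the nargs + 2 * outputs.size() assertion), and the closure now holds the precomputed tensors first, followed by whatever the mask selects. A minimal stand-alone sketch of that assembly, with simplified types rather than the MegEngine ones:

#include <cstddef>
#include <vector>

using TensorHandle = int;  // stand-in for std::shared_ptr<Tensor>

std::vector<TensorHandle> build_closure(const std::vector<bool>& save_for_backward,
                                        const std::vector<TensorHandle>& precomp,
                                        const std::vector<TensorHandle>& inputs,
                                        const std::vector<TensorHandle>& outputs) {
    // precomputed results always come first
    std::vector<TensorHandle> closure(precomp.begin(), precomp.end());
    std::size_t i = 0;
    for (const TensorHandle& t : inputs)
        if (save_for_backward[i++]) closure.push_back(t);
    for (const TensorHandle& t : outputs)
        if (save_for_backward[i++]) closure.push_back(t);
    // the remaining flags (one per output gradient) are consumed later,
    // once the incoming gradients exist
    return closure;
}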


imperative/python/src/tensor.h (+1, -1)

@@ -212,7 +212,7 @@ decltype(auto) resolve_arrow(T&& p) {
        if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
            return resolve_arrow(p.operator->());
        } else {
-            return p;
+            return std::forward<T>(p);
        }
    }
}
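
The one-line change above swaps `return p;` for `return std::forward<T>(p);`, the usual way to hand a forwarding-reference parameter back out without losing the caller's value category. A generic illustration of the idiom, not taken from the patch:

#include <string>
#include <utility>

// Returns its argument by value: an rvalue argument is moved into the result,
// an lvalue argument is copied.
template <typename T>
std::string pass_through(T&& s) {
    return std::forward<T>(s);
}

int main() {
    std::string named = "kept";
    std::string a = pass_through(named);               // copies
    std::string b = pass_through(std::string("temp")); // moves
    return a.size() == b.size() ? 0 : 1;
}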


imperative/src/impl/backward_graph_opt.cpp (+114, -0)

@@ -0,0 +1,114 @@
/**
* \file imperative/src/impl/backward_graph_opt.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/autogen.h"

using namespace mgb;
using namespace imperative;

OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src)
: input_has_grad(src.input_has_grad) {
if (!src.backward->same_type<BackwardGraph>()) {
// backward graph only contains a single op
backward = src.backward;
save_for_backward = src.save_for_backward;
return;
}
save_for_backward.resize(src.save_for_backward.size(), false);
precomp.reset(new BackwardGraph);
backward.reset(new BackwardGraph);

auto&& graph = src.backward->cast_final_safe<BackwardGraph>().graph();
auto&& mask = src.save_for_backward;
size_t input_size = src.input_has_grad.size();
size_t output_size = (mask.size() - input_size) / 2;
mgb_assert(input_size + output_size * 2 == mask.size());

auto& fgraph = precomp->cast_final<BackwardGraph>().graph();
auto& bgraph = backward->cast_final<BackwardGraph>().graph();

// optimization: move ops (e.g. GetVarShape) to forward to
// reduce memory footprint

struct VInfo {
bool appears_in_backward = false;
};
std::unordered_map<size_t, VInfo> vinfo;

// step 1.1: ops not in whitelist must run in backward.
// mark their inputs as always appears in backward
for (auto&& [op, iv, ov] : graph.exprs) {
if (!op->same_type<GetVarShape>()) {
for (auto&& v : iv) {
vinfo[v].appears_in_backward = true;
}
}
}
// step 1.2: inputs only available in backward (i.e. grads)
// should be marked as always appears in backward
for (size_t i = 0, j = 0; i < mask.size(); ++i) {
if (!mask[i]) continue;
if (i >= input_size + output_size) {
vinfo[graph.inputs[j]].appears_in_backward = true;
}
++j;
}

// step 2: try to move ops to forward, if not all their inputs
// are marked always appears in backward (otherwise no memory saving)
for (auto&& expr : graph.exprs) {
auto&& [op, iv, ov] = expr;
if (std::all_of(iv.begin(), iv.end(), [&](auto&& v){return vinfo[v].appears_in_backward;})) {
bgraph.exprs.push_back(expr);
for (auto&& v : ov) {
vinfo[v].appears_in_backward = true;
}
// logically should also mark all inputs as appears in backward
// but clearly that's a no-op.
} else {
fgraph.exprs.push_back(expr);
for (auto&& v : ov) {
if (vinfo[v].appears_in_backward) {
// appears_in_backward won't change after this point
// so it is safe to set fgraph.outputs based on current value
fgraph.outputs.push_back(v);
}
}
}
}

// initialize remaining parts

fgraph.constants = graph.constants;
fgraph.inputs.reserve(input_size + output_size);
for (size_t i = 0, j = 0; i < input_size + output_size; ++i) {
if (!mask[i]) {
fgraph.inputs.push_back(1000000000 + i);
continue;
}
fgraph.inputs.push_back(graph.inputs[j++]);
}

bgraph.constants = graph.constants;
bgraph.outputs = graph.outputs;
bgraph.inputs = fgraph.outputs;
for (size_t i = 0, j = 0; i < mask.size(); ++i) {
if (mask[i]) {
auto&& v = graph.inputs[j++];
if (vinfo[v].appears_in_backward) {
save_for_backward[i] = true;
bgraph.inputs.push_back(v);
}
}
}
}
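
Stripped of the BackwardGraph plumbing, the partition rule above amounts to the stand-alone sketch below (illustrative types, not MegEngine's): values read by non-whitelisted ops or first defined in backward are pinned to the backward graph; an op whose inputs are not all pinned, which after step 1.1 can only be a whitelisted op such as GetVarShape, is hoisted into the precomputed forward part, so the tensor it reads no longer needs to be saved. For the broadcasted add in the new test, backward presumably reduces dc to the shapes of a and b; both GetVarShape calls get hoisted and only their small shape outputs end up in the closure.

#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

struct Expr {
    std::string op;                   // e.g. "GetVarShape", "Reduce"
    std::vector<std::size_t> inputs;  // value ids read by the op
    std::vector<std::size_t> outputs; // value ids produced by the op
};

// exprs are assumed topologically ordered; grad_inputs are the value ids that
// only become available in backward (the incoming gradients).
void split(const std::vector<Expr>& exprs, const std::vector<std::size_t>& grad_inputs,
           std::vector<Expr>& precomp, std::vector<Expr>& backward) {
    std::unordered_map<std::size_t, bool> in_backward;
    // step 1: inputs of non-whitelisted ops, and gradients, must live in backward
    for (const auto& e : exprs)
        if (e.op != "GetVarShape")
            for (std::size_t v : e.inputs) in_backward[v] = true;
    for (std::size_t v : grad_inputs) in_backward[v] = true;
    // step 2: keep an op in backward only if every input lives there anyway;
    // otherwise hoist it, so its inputs can be dropped after forward
    for (const auto& e : exprs) {
        bool all_pinned = true;
        for (std::size_t v : e.inputs) all_pinned = all_pinned && in_backward[v];
        if (all_pinned) {
            backward.push_back(e);
            for (std::size_t v : e.outputs) in_backward[v] = true;
        } else {
            precomp.push_back(e);
        }
    }
}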

imperative/src/include/megbrain/imperative/backward_graph_opt.h (+25, -0)

@@ -0,0 +1,25 @@
/**
* \file imperative/src/include/megbrain/imperative/backward_graph_opt.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "./op_def.h"

namespace mgb::imperative {

struct OptimizedBackwardGraphResult {
std::shared_ptr<OpDef> precomp;
std::shared_ptr<OpDef> backward;
std::vector<bool> save_for_backward;
std::vector<bool> input_has_grad;

OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph);
};

} // namespace mgb::imperative

imperative/src/test/backward_graph.cpp (+118, -23)

@@ -13,11 +13,68 @@
#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/backward_graph_opt.h"


using namespace mgb; using namespace mgb;
using namespace cg; using namespace cg;
using namespace imperative; using namespace imperative;


template <typename T>
T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, const T& outputs, const T& grads) {
T ret;
size_t i = 0;
for (auto&& t : inputs) {
if (bg.save_for_backward[i++]) {
ret.push_back(t);
}
}
for (auto&& t : outputs) {
if (bg.save_for_backward[i++]) {
ret.push_back(t);
}
}
for (auto&& t : grads) {
if (bg.save_for_backward[i++]) {
ret.push_back(t);
}
}
return ret;
}

template <typename T, typename U>
T expand_grads(const U& bg, const T& outputs) {
T ret(bg.input_has_grad.size());
for (size_t i = 0, j = 0; i < bg.input_has_grad.size(); ++i) {
if (bg.input_has_grad[i]) {
ret[i] = outputs[j++];
}
}
return ret;
}

template <typename T>
T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, const T& precomp, const T& inputs, const T& outputs, const T& grads) {
T ret = precomp;
size_t i = 0;
for (auto&& t : inputs) {
if (bg.save_for_backward[i++]) {
ret.push_back(t);
}
}
for (auto&& t : outputs) {
if (bg.save_for_backward[i++]) {
ret.push_back(t);
}
}
for (auto&& t : grads) {
if (bg.save_for_backward[i++]) {
ret.push_back(t);
}
}
return ret;
}

TEST(TestImperative, BackwardGraphBasic) {
    HostTensorGenerator<> gen;
    SmallVector<HostTensorND> hvs;
@@ -121,27 +178,65 @@ TEST(TestImperative, BackwardGraphIdentity) {
}


TEST(TestImperative, BatchNormGrad) {
    auto cn = CompNode::load("xpux");
    using Param = opr::BatchNorm::Param;
    size_t N=2, C=3, H=5, W=5;
    LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn};
    LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn};
    {
        auto op = OprAttr::make("BatchNorm");
        auto&& attr = op->cast_final_safe<OprAttr>();
        Param param;
        param.fwd_mode = Param::FwdMode::TRAINING;
        attr.param.write_pod(param);
        OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat},
                {true, true, true, false, false}, {false, false, false, false, true});
    }
    {
        auto op = OprAttr::make("BatchNorm");
        auto&& attr = op->cast_final_safe<OprAttr>();
        Param param;
        param.fwd_mode = Param::FwdMode::TRAINING;
        attr.param.write_pod(param);
        OpDef::make_backward_graph(attr, {inp, stat, stat},
                {true, true, true}, {false, false, true});
    }
}

TEST(TestImperative, OptimizedBackwardGraphBasic) {
auto cn = CompNode::load("xpux");
LogicalTensorDesc desc = {TensorLayout(dtype::Float32()), cn};
HostTensorGenerator<> gen;
auto op = std::shared_ptr<OpDef>(Elemwise::make(Elemwise::Mode::ADD));
auto bg = OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true});
auto obg = OptimizedBackwardGraphResult(bg);
ASSERT_EQ(obg.save_for_backward.size(), 4);
ASSERT_FALSE(obg.save_for_backward[0]);
ASSERT_FALSE(obg.save_for_backward[1]);
ASSERT_FALSE(obg.save_for_backward[2]);

auto a_hv = gen({42});
auto b_hv = gen({5, 42});
auto dc_hv = gen({5, 42});
auto a_tn = Tensor::make(*a_hv);
auto b_tn = Tensor::make(*b_hv);
auto dc_tn = Tensor::make(*dc_hv);
auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0];

auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads = expand_grads(bg, OpDef::apply_on_physical_tensor(*bg.backward, backward_graph_inputs));

auto precomp = OpDef::apply_on_physical_tensor(*obg.precomp, {a_tn, b_tn, c_tn});
ASSERT_EQ(precomp.size(), 2);
ASSERT_EQ(precomp[0]->shape().ndim, 1);
ASSERT_LE(precomp[0]->shape()[0], 2);
ASSERT_EQ(precomp[1]->shape().ndim, 1);
ASSERT_LE(precomp[1]->shape()[0], 2);

auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads2 = expand_grads(obg, OpDef::apply_on_physical_tensor(*obg.backward, backward_inputs));

ASSERT_EQ(grads2.size(), 2);
MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value());
MGB_ASSERT_TENSOR_EQ(grads[1]->get_value(), grads2[1]->get_value());
}
