@@ -0,0 +1,161 @@ | |||||
/** | |||||
* \file src/core/impl/graph/seq_modifier_base.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "./seq_modifier_base.h" | |||||
#if MGB_ENABLE_SUBLINEAR | |||||
using namespace mgb; | |||||
using namespace cg; | |||||
void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_seq) { | |||||
m_orig_opr_seq = &opr_seq; | |||||
m_var_storage.clear(); | |||||
m_seq.clear(); | |||||
m_var_mempool.reorder_free(); | |||||
m_opr_mempool.reorder_free(); | |||||
m_nr_endpoint_oprs = 0; | |||||
ThinHashMap<VarNode*, Var*> varmap; | |||||
for (auto orig_opr : *m_orig_opr_seq) { | |||||
auto time = m_seq.size(); | |||||
m_seq.emplace_back(m_opr_mempool.alloc_unique(orig_opr, time)); | |||||
auto opr = m_seq.back().get(); | |||||
m_nr_endpoint_oprs += opr->is_endpoint; | |||||
for (auto&& dep : orig_opr->node_prop().dep_map()) { | |||||
if (!OperatorNodeBase::NodeProp::is_device_value_dep(dep.second)) | |||||
continue; | |||||
auto iter = varmap.find(dep.first); | |||||
if (iter == varmap.end()) { | |||||
// input var needs not to be considered | |||||
continue; | |||||
} | |||||
auto ivar = iter->second; | |||||
bool exist = false; | |||||
for (auto i : opr->input) { | |||||
if (i == ivar) { | |||||
exist = true; | |||||
break; | |||||
} | |||||
} | |||||
if (exist) { | |||||
// same var for different inputs | |||||
continue; | |||||
} | |||||
opr->input.push_back(ivar); | |||||
auto&& prev_rec = ivar->access_rec.back(); | |||||
prev_rec.stride = time - prev_rec.opr->time; | |||||
ivar->access_rec.emplace_back(opr); | |||||
} | |||||
for (auto i : orig_opr->output()) { | |||||
auto var2memsize = m_par_modifier->m_mem_opt.var2memsize(); | |||||
auto iter = var2memsize->find(i); | |||||
if (iter == var2memsize->end()) { | |||||
// some vars are ignored; see split_into_cn2oprseq() | |||||
continue; | |||||
} | |||||
m_var_storage.emplace_back( | |||||
m_var_mempool.alloc_unique(i, iter->second, opr)); | |||||
auto ovar = m_var_storage.back().get(); | |||||
varmap[i] = ovar; | |||||
opr->output.push_back(ovar); | |||||
} | |||||
mgb_assert(!opr->output.empty()); | |||||
} | |||||
// remove unused output | |||||
for (auto&& i : m_seq) { | |||||
auto&& oarr = i->output; | |||||
for (size_t j = 0; j < oarr.size();) { | |||||
if (oarr[j]->access_rec.size() == 1) { | |||||
std::swap(oarr[j], oarr.back()); | |||||
oarr.pop_back(); | |||||
} else | |||||
++j; | |||||
} | |||||
} | |||||
} | |||||
bool SeqModifierBase::replace_vars(const VarNodeArray& inputs) { | |||||
m_new_inputs.assign(inputs.begin(), inputs.end()); | |||||
bool changed = false; | |||||
for (auto&& i : m_new_inputs) { | |||||
auto iter = m_var_map.find(i); | |||||
if (iter != m_var_map.end()) { | |||||
i = iter->second; | |||||
changed = true; | |||||
} | |||||
} | |||||
return changed; | |||||
} | |||||
OperatorNodeBase* SeqModifierBase::copy_opr_from_new_inputs( | |||||
OperatorNodeBase* opr, bool recomp, size_t recomp_cnt) { | |||||
auto config = opr->config(); | |||||
// update operator instance id to bybass the shallow copy's cache if | |||||
// it's a dup-opr-copying due to discarding. | |||||
// Don't update instance id by `this` pointer if it's a recomp-opr-copying | |||||
// because: | |||||
// 0) recomp-opr would be copied iff its input vars is changed | |||||
// 1) some pair of recomp-opr and dup-opr have the same inputs, params | |||||
// and config, we use instance id to differentiate them. | |||||
config.name(opr->name() + (recomp ? ":recomp" : ":dup") + std::to_string(recomp_cnt)); | |||||
config.update_instance_id(reinterpret_cast<void*>( | |||||
reinterpret_cast<size_t>(this) + | |||||
((static_cast<size_t>(recomp) + 1) << 10) * recomp_cnt)); | |||||
// Note: if all outputs of op were placed on the same comp_node, since its | |||||
// stream maybe changed during seq_comp_node_opt, output's comp_node has | |||||
// higher priority than opr->config() | |||||
auto out_cn = opr->output(0)->comp_node(); | |||||
for (auto i : opr->output()) { | |||||
auto cn = i->comp_node(); | |||||
if (out_cn != cn) { | |||||
out_cn = {}; | |||||
break; | |||||
} | |||||
} | |||||
if (out_cn.valid()) | |||||
config.comp_node(out_cn); | |||||
auto opr_new = serialization::copy_opr_shallow(*opr, m_new_inputs, config); | |||||
mgb_assert(opr_new != opr); | |||||
auto&& out0 = opr->output(); | |||||
auto&& out1 = opr_new->output(); | |||||
mgb_assert(out0.size() == out1.size()); | |||||
bool stream_changed = false; | |||||
for (size_t i = 0; i < out0.size(); ++i) { | |||||
auto &&cn0 = out0[i]->comp_node(), | |||||
&&cn1 = out1[i]->comp_node(); | |||||
if (cn0 != cn1) { | |||||
mgb_assert(recomp); | |||||
mgb_assert(cn0.locator().type == cn1.locator().type && | |||||
cn0.locator().device == cn1.locator().device); | |||||
out1[i]->comp_node(cn0); | |||||
stream_changed = true; | |||||
} | |||||
m_var_map[out0[i]] = out1[i]; | |||||
} | |||||
if (stream_changed) { | |||||
opr_new->on_output_comp_node_stream_changed(); | |||||
} | |||||
return opr_new; | |||||
} | |||||
#endif // MGB_ENABLE_SUBLINEAR | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,237 @@ | |||||
/** | |||||
* \file src/core/impl/graph/seq_modifier_base.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "./memory_optimizer.h" | |||||
#include "megbrain/comp_node_env.h" | |||||
#include "megbrain/graph/cg.h" | |||||
#include "megbrain/plugin/opr_footprint.h" | |||||
#include "megbrain/serialization/opr_shallow_copy.h" | |||||
#include "megbrain/system.h" | |||||
#include "megbrain/utils/async_worker.h" | |||||
#include "megbrain/utils/arith_helper.h" | |||||
#include "megbrain/utils/mempool.h" | |||||
#include "megbrain/utils/timer.h" | |||||
#if MGB_ENABLE_SUBLINEAR | |||||
namespace mgb { | |||||
namespace cg { | |||||
/*! | |||||
* \brief modifying computing sequence, with basically the same idea of Training | |||||
* Deep Nets with Sublinear Memory Cost and Dynamic Tensor Rematerialization | |||||
*/ | |||||
class SeqModifierBase { | |||||
public: | |||||
/*! | |||||
* describes modifications that should be applied to an operator sequnce: | |||||
* maps from an opr to the oprs that should be duplicated and inserted | |||||
* before it. | |||||
*/ | |||||
using SeqModifyAction = std::unordered_map<OperatorNodeBase*, OprNodeArray>; | |||||
struct Var; | |||||
struct Opr; | |||||
class ModifyActionPlannerBase { | |||||
const SeqModifierBase* const m_par_modifier; | |||||
const OprNodeArray* m_orig_opr_seq; | |||||
MemPool<Var> m_var_mempool; | |||||
MemPool<Opr> m_opr_mempool; | |||||
std::vector<MemPool<Var>::UniquePtr> m_var_storage; | |||||
std::vector<MemPool<Opr>::UniquePtr> m_seq; | |||||
size_t m_nr_endpoint_oprs = 0; | |||||
public: | |||||
//! special creation time used for oprs duplicated from others | |||||
static constexpr size_t DUPOPR_TIME = | |||||
std::numeric_limits<size_t>::max() - 1; | |||||
const SeqModifierBase* const par_modifier() { | |||||
return m_par_modifier; | |||||
} | |||||
const OprNodeArray* const orig_opr_seq() { | |||||
return m_orig_opr_seq; | |||||
} | |||||
MemPool<Var>& var_mempool() { | |||||
return m_var_mempool; | |||||
} | |||||
MemPool<Opr>& opr_mempool() { | |||||
return m_opr_mempool; | |||||
} | |||||
std::vector<MemPool<Var>::UniquePtr>& var_storage() { | |||||
return m_var_storage; | |||||
} | |||||
std::vector<MemPool<Opr>::UniquePtr>& seq() { | |||||
return m_seq; | |||||
} | |||||
size_t& nr_endpoint_oprs() { | |||||
return m_nr_endpoint_oprs; | |||||
} | |||||
ModifyActionPlannerBase(SeqModifierBase* par) | |||||
: m_par_modifier{par} {} | |||||
~ModifyActionPlannerBase() noexcept { | |||||
m_opr_mempool.disable_freelist(); | |||||
m_var_mempool.disable_freelist(); | |||||
} | |||||
//! init m_orig_opr_seq from opr_seq, should be called first. | |||||
void init_seq(const OprNodeArray& opr_seq); | |||||
}; | |||||
SeqModifierBase(ComputingGraphImpl* owner) : m_mem_opt(owner), m_owner_graph(owner) {} | |||||
MemoryOptimizerHelper& mem_opt() { | |||||
return m_mem_opt; | |||||
} | |||||
ComputingGraphImpl* const owner_graph() { | |||||
return m_owner_graph; | |||||
} | |||||
ThinHashMap<VarNode*, VarNode*>& var_map() { | |||||
return m_var_map; | |||||
} | |||||
/*! | |||||
* \brief copy opr and set inputs to m_new_inputs, and add outputs in | |||||
* m_var_map | |||||
* \return new operator | |||||
*/ | |||||
OperatorNodeBase* copy_opr_from_new_inputs(OperatorNodeBase* opr, bool recomp, size_t recomp_cnt=0); | |||||
/*! | |||||
* \brief replace input vars according to m_var_map, and store results in | |||||
* m_new_inputs; | |||||
* \return whether any var is changed | |||||
*/ | |||||
bool replace_vars(const VarNodeArray& inputs); | |||||
//! see memory_optimizer set_priority_before_opt | |||||
void set_priority_before_opt(const VarNodeArray& endpoints) { | |||||
m_mem_opt.set_priority_before_opt(endpoints); | |||||
} | |||||
//! see memory_optimizer restore_graph_option | |||||
void restore_graph_option() { | |||||
m_mem_opt.restore_graph_option(); | |||||
} | |||||
private: | |||||
MemoryOptimizerHelper m_mem_opt; | |||||
ComputingGraphImpl* const m_owner_graph = nullptr; | |||||
//! map from original var to replaced var | |||||
ThinHashMap<VarNode*, VarNode*> m_var_map; | |||||
VarNodeArray m_new_inputs; //!< setup by replace_vars | |||||
}; | |||||
struct SeqModifierBase::Opr { | |||||
OperatorNodeBase* const orig_opr; | |||||
std::vector<Var*> input, output; | |||||
const size_t time; //!< index in opr sequence | |||||
const bool is_endpoint; | |||||
double estimate_compute_time = 1; | |||||
//! input vars that have been discarded and need to be recomputed before | |||||
//! this opr; for internal use by apply_discard_plan() | |||||
std::vector<Var*> inputs_to_recompute; | |||||
//! new oprs to be inserted before this opr; setup by apply_discard_plan() | |||||
std::vector<MemPool<Opr>::UniquePtr> oprs_insert_before; | |||||
//! [begin, end) interval of *time* for oprs belonging to this block; setup | |||||
//! by make_discard_plan() | |||||
size_t block_begin_time = 0, block_end_time = 0; | |||||
Opr(OperatorNodeBase* opr, size_t t) | |||||
: orig_opr{opr}, | |||||
time{t}, | |||||
is_endpoint{opr->owner_graph() | |||||
->options() | |||||
.opr_attribute.get_sublinear_memory_endpoint( | |||||
opr)} {} | |||||
}; | |||||
struct SeqModifierBase::Var { | |||||
VarNode* const orig_var; | |||||
size_t size; //!< memory usage in bytes of this var | |||||
size_t recomp_id = 0; | |||||
double last_access_time = 0; | |||||
//! write or read access of a var | |||||
struct AccessRecord { | |||||
Opr* const opr; | |||||
const size_t time; | |||||
size_t stride; | |||||
explicit AccessRecord(Opr* o = nullptr) | |||||
: opr{o}, time{o->time}, stride{0} {} | |||||
}; | |||||
//! access_rec[0] is the creation opr, and others are reader oprs | |||||
std::vector<AccessRecord> access_rec; | |||||
/*! | |||||
* An index in access_rec | |||||
* | |||||
* if valid, then the var should be discarded after | |||||
* discard_tailing_access->opr finishes | |||||
* | |||||
* setup by make_discard_plan | |||||
*/ | |||||
Maybe<size_t> discard_tailing_access; | |||||
/*! | |||||
* An index in access_rec | |||||
* maintained during make_discard_plan(), for the next access relative to | |||||
* current operator | |||||
*/ | |||||
Maybe<size_t> next_access; | |||||
AccessRecord* visit_discard_tailing_access() { | |||||
return discard_tailing_access.valid() | |||||
? &access_rec.at(discard_tailing_access.val()) | |||||
: nullptr; | |||||
} | |||||
AccessRecord* visit_next_access() { | |||||
return next_access.valid() ? &access_rec.at(next_access.val()) | |||||
: nullptr; | |||||
} | |||||
auto owner_opr() const { return access_rec[0].opr; } | |||||
auto last_access_opr() const { return access_rec.back().opr; } | |||||
Var(VarNode* var, size_t s, Opr* opr) : orig_var{var}, size{s} { | |||||
access_rec.emplace_back(opr); | |||||
} | |||||
}; | |||||
} // namespace cg | |||||
} // namespace mgb | |||||
#endif // MGB_ENABLE_SUBLINEAR | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -61,108 +61,15 @@ bool is_bad_opr(OperatorNodeBase* opr) { | |||||
} | } | ||||
} // namespace | } // namespace | ||||
/* ====================== Abstract Opr & Var ====================== */ | |||||
struct SeqModifierForSublinearMemory::Opr { | |||||
OperatorNodeBase* const orig_opr; | |||||
std::vector<Var*> input, output; | |||||
const size_t time; //!< index in opr sequence | |||||
const bool is_endpoint; | |||||
//! input vars that have been discarded and need to be recomputed before | |||||
//! this opr; for internal use by apply_discard_plan() | |||||
std::vector<Var*> inputs_to_recompute; | |||||
//! new oprs to be inserted before this opr; setup by apply_discard_plan() | |||||
std::vector<MemPool<Opr>::UniquePtr> oprs_insert_before; | |||||
//! [begin, end) interval of *time* for oprs belonging to this block; setup | |||||
//! by make_discard_plan() | |||||
size_t block_begin_time = 0, block_end_time = 0; | |||||
Opr(OperatorNodeBase* opr, size_t t) | |||||
: orig_opr{opr}, | |||||
time{t}, | |||||
is_endpoint{opr->owner_graph() | |||||
->options() | |||||
.opr_attribute.get_sublinear_memory_endpoint( | |||||
opr)} {} | |||||
}; | |||||
struct SeqModifierForSublinearMemory::Var { | |||||
//! write or read access of a var | |||||
struct AccessRecord { | |||||
Opr* const opr; | |||||
const size_t time; | |||||
size_t stride; //!< time distance until next read; 0 for last access | |||||
explicit AccessRecord(Opr* o = nullptr) | |||||
: opr{o}, time{o->time}, stride{0} {} | |||||
}; | |||||
VarNode* const orig_var; | |||||
const size_t size; //!< memory usage in bytes of this var | |||||
//! access_rec[0] is the creation opr, and others are reader oprs | |||||
std::vector<AccessRecord> access_rec; | |||||
/*! | |||||
* An index in access_rec | |||||
* | |||||
* if valid, then the var should be discarded after | |||||
* discard_tailing_access->opr finishes | |||||
* | |||||
* setup by make_discard_plan | |||||
*/ | |||||
Maybe<size_t> discard_tailing_access; | |||||
/*! | |||||
* An index in access_rec | |||||
* maintained during make_discard_plan(), for the next access relative to | |||||
* current operator | |||||
*/ | |||||
Maybe<size_t> next_access; | |||||
AccessRecord* visit_discard_tailing_access() { | |||||
return discard_tailing_access.valid() | |||||
? &access_rec.at(discard_tailing_access.val()) | |||||
: nullptr; | |||||
} | |||||
AccessRecord* visit_next_access() { | |||||
return next_access.valid() ? &access_rec.at(next_access.val()) | |||||
: nullptr; | |||||
} | |||||
auto owner_opr() const { return access_rec[0].opr; } | |||||
auto last_access_opr() const { return access_rec.back().opr; } | |||||
Var(VarNode* var, size_t s, Opr* opr) : orig_var{var}, size{s} { | |||||
access_rec.emplace_back(opr); | |||||
} | |||||
}; | |||||
/* ====================== ModifyActionPlanner ====================== */ | /* ====================== ModifyActionPlanner ====================== */ | ||||
class SeqModifierForSublinearMemory::ModifyActionPlanner { | |||||
//! special creation time used for oprs duplicated from others | |||||
static constexpr size_t DUPOPR_TIME = | |||||
std::numeric_limits<size_t>::max() - 1; | |||||
class SeqModifierForSublinearMemory::ModifyActionPlanner : public ModifyActionPlannerBase { | |||||
using VarArray = std::vector<Var*>; | using VarArray = std::vector<Var*>; | ||||
using VarSet = ThinHashSet<Var*>; | using VarSet = ThinHashSet<Var*>; | ||||
using OprArray = std::vector<Opr*>; | using OprArray = std::vector<Opr*>; | ||||
const SeqModifierForSublinearMemory* const m_par_modifier; | |||||
const OprNodeArray* m_orig_opr_seq; | |||||
MemPool<Var> m_var_mempool; | |||||
MemPool<Opr> m_opr_mempool; | |||||
std::vector<MemPool<Var>::UniquePtr> m_var_storage; | |||||
std::vector<MemPool<Opr>::UniquePtr> m_seq; | |||||
size_t m_nr_endpoint_oprs = 0; | |||||
VarSet m_prev_block_discard_vars; | VarSet m_prev_block_discard_vars; | ||||
std::vector<OprArray> m_blocks; | std::vector<OprArray> m_blocks; | ||||
SeqModifyAction m_action; | |||||
//! split_point_set to block | //! split_point_set to block | ||||
void split_into_blocks(const SplitPointSet& split_point_set); | void split_into_blocks(const SplitPointSet& split_point_set); | ||||
@@ -188,14 +95,7 @@ class SeqModifierForSublinearMemory::ModifyActionPlanner { | |||||
public: | public: | ||||
ModifyActionPlanner(SeqModifierForSublinearMemory* par) | ModifyActionPlanner(SeqModifierForSublinearMemory* par) | ||||
: m_par_modifier{par} {} | |||||
~ModifyActionPlanner() noexcept { | |||||
m_opr_mempool.disable_freelist(); | |||||
m_var_mempool.disable_freelist(); | |||||
} | |||||
//! init m_orig_opr_seq from opr_seq, should be called first. | |||||
void init_seq(const OprNodeArray& opr_seq); | |||||
: ModifyActionPlannerBase{par} {} | |||||
//! generate split point set from thresh | //! generate split point set from thresh | ||||
SplitPointSet get_split_point_set(size_t block_size_thresh); | SplitPointSet get_split_point_set(size_t block_size_thresh); | ||||
@@ -213,7 +113,7 @@ public: | |||||
void SeqModifierForSublinearMemory::ModifyActionPlanner::get_prev_action( | void SeqModifierForSublinearMemory::ModifyActionPlanner::get_prev_action( | ||||
SeqModifyAction& action) { | SeqModifyAction& action) { | ||||
action.clear(); | action.clear(); | ||||
for (auto&& opr : m_seq) { | |||||
for (auto&& opr : seq()) { | |||||
auto&& arr = opr->oprs_insert_before; | auto&& arr = opr->oprs_insert_before; | ||||
if (arr.empty()) | if (arr.empty()) | ||||
continue; | continue; | ||||
@@ -261,8 +161,8 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set( | |||||
cur_block_alive_vars.clear(); | cur_block_alive_vars.clear(); | ||||
}; | }; | ||||
for (size_t i = 0; i < m_seq.size(); ++i) { | |||||
auto opr = m_seq[i].get(); | |||||
for (size_t i = 0; i < seq().size(); ++i) { | |||||
auto opr = seq()[i].get(); | |||||
for (auto i : opr->output) | for (auto i : opr->output) | ||||
add_alive(i); | add_alive(i); | ||||
@@ -272,8 +172,8 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set( | |||||
remove_alive(i); | remove_alive(i); | ||||
} | } | ||||
if (i + 1 < m_seq.size() && (cur_block_usage < block_size_thresh || | |||||
(m_nr_endpoint_oprs && !opr->is_endpoint))) | |||||
if (i + 1 < seq().size() && (cur_block_usage < block_size_thresh || | |||||
(nr_endpoint_oprs() && !opr->is_endpoint))) | |||||
continue; | continue; | ||||
flush_block_member(i); | flush_block_member(i); | ||||
@@ -281,81 +181,6 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set( | |||||
return split_point_set; | return split_point_set; | ||||
} | } | ||||
void SeqModifierForSublinearMemory::ModifyActionPlanner::init_seq( | |||||
const OprNodeArray& opr_seq) { | |||||
m_orig_opr_seq = &opr_seq; | |||||
m_var_storage.clear(); | |||||
m_seq.clear(); | |||||
m_var_mempool.reorder_free(); | |||||
m_opr_mempool.reorder_free(); | |||||
m_nr_endpoint_oprs = 0; | |||||
ThinHashMap<VarNode*, Var*> varmap; | |||||
for (auto orig_opr : *m_orig_opr_seq) { | |||||
auto time = m_seq.size(); | |||||
m_seq.emplace_back(m_opr_mempool.alloc_unique(orig_opr, time)); | |||||
auto opr = m_seq.back().get(); | |||||
m_nr_endpoint_oprs += opr->is_endpoint; | |||||
for (auto&& dep : orig_opr->node_prop().dep_map()) { | |||||
if (!OperatorNodeBase::NodeProp::is_device_value_dep(dep.second)) | |||||
continue; | |||||
auto iter = varmap.find(dep.first); | |||||
if (iter == varmap.end()) { | |||||
// input var needs not to be considered | |||||
continue; | |||||
} | |||||
auto ivar = iter->second; | |||||
bool exist = false; | |||||
for (auto i : opr->input) { | |||||
if (i == ivar) { | |||||
exist = true; | |||||
break; | |||||
} | |||||
} | |||||
if (exist) { | |||||
// same var for different inputs | |||||
continue; | |||||
} | |||||
opr->input.push_back(ivar); | |||||
auto&& prev_rec = ivar->access_rec.back(); | |||||
prev_rec.stride = time - prev_rec.opr->time; | |||||
ivar->access_rec.emplace_back(opr); | |||||
} | |||||
for (auto i : orig_opr->output()) { | |||||
auto var2memsize = m_par_modifier->m_mem_opt.var2memsize(); | |||||
auto iter = var2memsize->find(i); | |||||
if (iter == var2memsize->end()) { | |||||
// some vars are ignored; see split_into_cn2oprseq() | |||||
continue; | |||||
} | |||||
m_var_storage.emplace_back( | |||||
m_var_mempool.alloc_unique(i, iter->second, opr)); | |||||
auto ovar = m_var_storage.back().get(); | |||||
varmap[i] = ovar; | |||||
opr->output.push_back(ovar); | |||||
} | |||||
mgb_assert(!opr->output.empty()); | |||||
} | |||||
// remove unused output | |||||
for (auto&& i : m_seq) { | |||||
auto&& oarr = i->output; | |||||
for (size_t j = 0; j < oarr.size();) { | |||||
if (oarr[j]->access_rec.size() == 1) { | |||||
std::swap(oarr[j], oarr.back()); | |||||
oarr.pop_back(); | |||||
} else | |||||
++j; | |||||
} | |||||
} | |||||
} | |||||
size_t SeqModifierForSublinearMemory::ModifyActionPlanner:: | size_t SeqModifierForSublinearMemory::ModifyActionPlanner:: | ||||
calc_bottleneck_from_discard_plan() { | calc_bottleneck_from_discard_plan() { | ||||
size_t cur_usage = 0, max_usage = 0; | size_t cur_usage = 0, max_usage = 0; | ||||
@@ -394,7 +219,7 @@ size_t SeqModifierForSublinearMemory::ModifyActionPlanner:: | |||||
++time; | ++time; | ||||
}; | }; | ||||
for (auto&& opr : m_seq) { | |||||
for (auto&& opr : seq()) { | |||||
for (auto&& i : opr->oprs_insert_before) | for (auto&& i : opr->oprs_insert_before) | ||||
process_opr(i.get()); | process_opr(i.get()); | ||||
process_opr(opr.get()); | process_opr(opr.get()); | ||||
@@ -480,7 +305,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { | |||||
mgb_assert(opr->time < block_end); | mgb_assert(opr->time < block_end); | ||||
auto new_opr_storage = m_opr_mempool.alloc_unique( | |||||
auto new_opr_storage = opr_mempool().alloc_unique( | |||||
opr->orig_opr, static_cast<size_t>(DUPOPR_TIME)); | opr->orig_opr, static_cast<size_t>(DUPOPR_TIME)); | ||||
auto new_opr = new_opr_storage.get(); | auto new_opr = new_opr_storage.get(); | ||||
@@ -497,7 +322,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { | |||||
Var* new_var = nullptr; | Var* new_var = nullptr; | ||||
for (auto i : opr->output) { | for (auto i : opr->output) { | ||||
auto&& ovar = m_var_mempool.alloc_unique(i->orig_var, i->size, | |||||
auto&& ovar = var_mempool().alloc_unique(i->orig_var, i->size, | |||||
new_opr); | new_opr); | ||||
new_opr->output.push_back(ovar.get()); | new_opr->output.push_back(ovar.get()); | ||||
if (i == var) | if (i == var) | ||||
@@ -507,7 +332,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { | |||||
auto ins = var_map.insert({i, ovar.get()}); | auto ins = var_map.insert({i, ovar.get()}); | ||||
mgb_assert(ins.second); | mgb_assert(ins.second); | ||||
m_var_storage.emplace_back(std::move(ovar)); | |||||
var_storage().emplace_back(std::move(ovar)); | |||||
} | } | ||||
mgb_assert(new_var); | mgb_assert(new_var); | ||||
return new_var; | return new_var; | ||||
@@ -515,7 +340,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { | |||||
add_dep(var); | add_dep(var); | ||||
}; | }; | ||||
for (auto&& _raw_opr : m_seq) { | |||||
for (auto&& _raw_opr : seq()) { | |||||
auto opr = _raw_opr.get(); | auto opr = _raw_opr.get(); | ||||
for (auto i : opr->inputs_to_recompute) | for (auto i : opr->inputs_to_recompute) | ||||
@@ -640,8 +465,8 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::split_into_blocks( | |||||
m_blocks.clear(); | m_blocks.clear(); | ||||
std::vector<Opr*> cur_block_member; | std::vector<Opr*> cur_block_member; | ||||
size_t i, j; | size_t i, j; | ||||
for (i = j = 0; i < m_seq.size() && j < split_point_set->size(); ++i) { | |||||
auto opr = m_seq[i].get(); | |||||
for (i = j = 0; i < seq().size() && j < split_point_set->size(); ++i) { | |||||
auto opr = seq()[i].get(); | |||||
cur_block_member.push_back(opr); | cur_block_member.push_back(opr); | ||||
if (i != split_point_set->at(j)) | if (i != split_point_set->at(j)) | ||||
continue; | continue; | ||||
@@ -649,7 +474,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::split_into_blocks( | |||||
cur_block_member.clear(); | cur_block_member.clear(); | ||||
j++; | j++; | ||||
} | } | ||||
mgb_assert(i >= m_seq.size()); | |||||
mgb_assert(i >= seq().size()); | |||||
mgb_assert(j >= split_point_set->size()); | mgb_assert(j >= split_point_set->size()); | ||||
} | } | ||||
@@ -1081,7 +906,7 @@ void SeqModifierForSublinearMemory::InternalDeleter::operator()( | |||||
} | } | ||||
void SeqModifierForSublinearMemory::reset_opr_seq(const OprNodeArray& oprseq) { | void SeqModifierForSublinearMemory::reset_opr_seq(const OprNodeArray& oprseq) { | ||||
m_var_map.clear(); | |||||
var_map().clear(); | |||||
m_opr2replace_info.clear(); | m_opr2replace_info.clear(); | ||||
auto config = | auto config = | ||||
MemoryOptimizerHelper::SubGraphConfig() | MemoryOptimizerHelper::SubGraphConfig() | ||||
@@ -1099,7 +924,7 @@ void SeqModifierForSublinearMemory::reset_opr_seq(const OprNodeArray& oprseq) { | |||||
.add_bad_var_flag(VarNode::Flag::NO_SYS_MEM_ALLOC) | .add_bad_var_flag(VarNode::Flag::NO_SYS_MEM_ALLOC) | ||||
.add_bad_var_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE); | .add_bad_var_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE); | ||||
auto cn2oprseq = m_mem_opt.split_into_cn2oprseq(oprseq, config); | |||||
auto cn2oprseq = mem_opt().split_into_cn2oprseq(oprseq, config); | |||||
if (cn2oprseq->empty()) { | if (cn2oprseq->empty()) { | ||||
// empty graph | // empty graph | ||||
@@ -1175,7 +1000,7 @@ void SeqModifierForSublinearMemory::apply_action(SeqModifyAction& action, | |||||
// each operator should be set no more than once | // each operator should be set no more than once | ||||
auto set_priority = [&](OperatorNodeBase* opr) { | auto set_priority = [&](OperatorNodeBase* opr) { | ||||
mgb_assert(modified_opr.insert(opr).second); | mgb_assert(modified_opr.insert(opr).second); | ||||
m_mem_opt.set_priority(opr, cur_priority++); | |||||
mem_opt().set_priority(opr, cur_priority++); | |||||
}; | }; | ||||
auto on_opr_visited = [&](OperatorNodeBase* opr) { | auto on_opr_visited = [&](OperatorNodeBase* opr) { | ||||
@@ -1218,80 +1043,13 @@ void SeqModifierForSublinearMemory::apply_action(SeqModifyAction& action, | |||||
mgb_assert(action.empty()); | mgb_assert(action.empty()); | ||||
} | } | ||||
bool SeqModifierForSublinearMemory::replace_vars(const VarNodeArray& inputs) { | |||||
m_new_inputs.assign(inputs.begin(), inputs.end()); | |||||
bool changed = false; | |||||
for (auto&& i : m_new_inputs) { | |||||
auto iter = m_var_map.find(i); | |||||
if (iter != m_var_map.end()) { | |||||
i = iter->second; | |||||
changed = true; | |||||
} | |||||
} | |||||
return changed; | |||||
} | |||||
OperatorNodeBase* SeqModifierForSublinearMemory::copy_opr_from_new_inputs( | |||||
OperatorNodeBase* opr, bool recomp) { | |||||
auto config = opr->config(); | |||||
// update operator instance id to bybass the shallow copy's cache if | |||||
// it's a dup-opr-copying due to discarding. | |||||
// Don't update instance id by `this` pointer if it's a recomp-opr-copying | |||||
// because: | |||||
// 0) recomp-opr would be copied iff its input vars is changed | |||||
// 1) some pair of recomp-opr and dup-opr have the same inputs, params | |||||
// and config, we use instance id to differentiate them. | |||||
config.name(opr->name() + (recomp ? ":recomp" : ":dup")); | |||||
if (!recomp) { | |||||
config.update_instance_id(this); | |||||
} | |||||
// Note: if all outputs of op were placed on the same comp_node, since its | |||||
// stream maybe changed during seq_comp_node_opt, output's comp_node has | |||||
// higher priority than opr->config() | |||||
auto out_cn = opr->output(0)->comp_node(); | |||||
for (auto i : opr->output()) { | |||||
auto cn = i->comp_node(); | |||||
if (out_cn != cn) { | |||||
out_cn = {}; | |||||
break; | |||||
} | |||||
} | |||||
if (out_cn.valid()) | |||||
config.comp_node(out_cn); | |||||
auto opr_new = serialization::copy_opr_shallow(*opr, m_new_inputs, config); | |||||
mgb_assert(opr_new != opr); | |||||
auto&& out0 = opr->output(); | |||||
auto&& out1 = opr_new->output(); | |||||
mgb_assert(out0.size() == out1.size()); | |||||
bool stream_changed = false; | |||||
for (size_t i = 0; i < out0.size(); ++i) { | |||||
auto &&cn0 = out0[i]->comp_node(), | |||||
&&cn1 = out1[i]->comp_node(); | |||||
if (cn0 != cn1) { | |||||
mgb_assert(recomp); | |||||
mgb_assert(cn0.locator().type == cn1.locator().type && | |||||
cn0.locator().device == cn1.locator().device); | |||||
out1[i]->comp_node(cn0); | |||||
stream_changed = true; | |||||
} | |||||
m_var_map[out0[i]] = out1[i]; | |||||
} | |||||
if (stream_changed) { | |||||
opr_new->on_output_comp_node_stream_changed(); | |||||
} | |||||
return opr_new; | |||||
} | |||||
void SeqModifierForSublinearMemory::modify_endpoint_vars( | void SeqModifierForSublinearMemory::modify_endpoint_vars( | ||||
VarNodeArray& endpoints) { | VarNodeArray& endpoints) { | ||||
auto comp_seq = MemoryOptimizerHelper::CompSeq(m_owner_graph, endpoints); | |||||
auto comp_seq = MemoryOptimizerHelper::CompSeq(owner_graph(), endpoints); | |||||
reset_opr_seq(*comp_seq.m_seq); | reset_opr_seq(*comp_seq.m_seq); | ||||
for (auto&& i : endpoints) { | for (auto&& i : endpoints) { | ||||
auto iter = m_var_map.find(i); | |||||
if (iter != m_var_map.end()) { | |||||
auto iter = var_map().find(i); | |||||
if (iter != var_map().end()) { | |||||
i = iter->second; | i = iter->second; | ||||
} | } | ||||
} | } | ||||
@@ -1357,8 +1115,8 @@ SeqModifierForSublinearMemory::prev_min_bottleneck() { | |||||
SeqModifierForSublinearMemory::SeqModifierForSublinearMemory( | SeqModifierForSublinearMemory::SeqModifierForSublinearMemory( | ||||
ComputingGraphImpl* owner, Config* config_p) | ComputingGraphImpl* owner, Config* config_p) | ||||
: m_config(config_p), m_mem_opt(owner), m_owner_graph(owner) {} | |||||
: SeqModifierBase(owner), m_config(config_p) {} | |||||
#endif // !MGB_ENABLE_SUBLINEAR | #endif // !MGB_ENABLE_SUBLINEAR | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -12,6 +12,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "./memory_optimizer.h" | #include "./memory_optimizer.h" | ||||
#include "./seq_modifier_base.h" | |||||
#include "megbrain/graph/cg.h" | #include "megbrain/graph/cg.h" | ||||
#include "megbrain/utils/async_worker.h" | #include "megbrain/utils/async_worker.h" | ||||
@@ -23,28 +24,31 @@ namespace cg { | |||||
* \brief modifying computing sequence, with basically the same idea of Training | * \brief modifying computing sequence, with basically the same idea of Training | ||||
* Deep Nets with Sublinear Memory Cost | * Deep Nets with Sublinear Memory Cost | ||||
*/ | */ | ||||
class SeqModifierForSublinearMemory { | |||||
/*! | |||||
* describes modifications that should be applied to an operator sequnce: | |||||
* maps from an opr to the oprs that should be duplicated and inserted | |||||
* before it. | |||||
*/ | |||||
using SeqModifyAction = std::unordered_map<OperatorNodeBase*, OprNodeArray>; | |||||
using SplitPointSet = std::shared_ptr<std::vector<size_t>>; | |||||
class SeqModifierForSublinearMemory : public SeqModifierBase { | |||||
//! Config options | //! Config options | ||||
using Config = mgb::cg::ComputingGraph::Options::SublinearMemConfig; | using Config = mgb::cg::ComputingGraph::Options::SublinearMemConfig; | ||||
Config* m_config; | Config* m_config; | ||||
public: | |||||
SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g); | |||||
//! replace endpoint vars by the ones that require more computing | |||||
void modify_endpoint_vars(VarNodeArray& endpoints); | |||||
//! check whether actual opr_seq is what we expect; throw InternalError | |||||
void sanity_check(const OprNodeArray& opr_seq); | |||||
const CompNode::UnorderedMap<size_t>& prev_min_bottleneck(); | |||||
private: | |||||
using SplitPointSet = std::shared_ptr<std::vector<size_t>>; | |||||
//! get modifications to be taken under some specific constraints | //! get modifications to be taken under some specific constraints | ||||
class ModifyActionPlanner; | class ModifyActionPlanner; | ||||
//! search best modify action for opr seq on a single comp node | //! search best modify action for opr seq on a single comp node | ||||
class ActionSearcherSingleCN; | class ActionSearcherSingleCN; | ||||
struct Opr; | |||||
struct Var; | |||||
struct InternalDeleter { | struct InternalDeleter { | ||||
void operator()(ActionSearcherSingleCN*) const; | void operator()(ActionSearcherSingleCN*) const; | ||||
void operator()(ModifyActionPlanner*) const; | void operator()(ModifyActionPlanner*) const; | ||||
@@ -67,32 +71,8 @@ class SeqModifierForSublinearMemory { | |||||
//! thread pool to run ModifyActionPlanner | //! thread pool to run ModifyActionPlanner | ||||
FutureThreadPool<void> m_planner_thread_pool; | FutureThreadPool<void> m_planner_thread_pool; | ||||
//! map from original var to replaced var | |||||
ThinHashMap<VarNode*, VarNode*> m_var_map; | |||||
VarNodeArray m_new_inputs; //!< setup by replace_vars | |||||
MemoryOptimizerHelper m_mem_opt; | |||||
ComputingGraphImpl* const m_owner_graph = nullptr; | |||||
CompNode::UnorderedMap<size_t> m_prev_min_bottleneck; | CompNode::UnorderedMap<size_t> m_prev_min_bottleneck; | ||||
/*! | |||||
* \brief replace input vars according to m_var_map, and store results in | |||||
* m_new_inputs; | |||||
* \return whether any var is changed | |||||
*/ | |||||
bool replace_vars(const VarNodeArray& inputs); | |||||
/*! | |||||
* \brief copy opr and set inputs to m_new_inputs, and add outputs in | |||||
* m_var_map | |||||
* \return new operator | |||||
*/ | |||||
OperatorNodeBase* copy_opr_from_new_inputs(OperatorNodeBase* opr, | |||||
bool recomp); | |||||
//! restore computing sequence and modify operator priority | //! restore computing sequence and modify operator priority | ||||
void reset_opr_seq(const OprNodeArray& oprseq); | void reset_opr_seq(const OprNodeArray& oprseq); | ||||
@@ -107,27 +87,6 @@ class SeqModifierForSublinearMemory { | |||||
return std::make_shared<SplitPointSet::element_type>( | return std::make_shared<SplitPointSet::element_type>( | ||||
std::forward<Args>(args)...); | std::forward<Args>(args)...); | ||||
} | } | ||||
public: | |||||
SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g); | |||||
//! see memory_optimizer set_priority_before_opt | |||||
void set_priority_before_opt(const VarNodeArray& endpoints) { | |||||
m_mem_opt.set_priority_before_opt(endpoints); | |||||
} | |||||
//! see memory_optimizer restore_graph_option | |||||
void restore_graph_option() { | |||||
m_mem_opt.restore_graph_option(); | |||||
} | |||||
//! replace endpoint vars by the ones that require more computing | |||||
void modify_endpoint_vars(VarNodeArray& endpoints); | |||||
//! check whether actual opr_seq is what we expect; throw InternalError | |||||
void sanity_check(const OprNodeArray& opr_seq); | |||||
const CompNode::UnorderedMap<size_t>& prev_min_bottleneck(); | |||||
}; | }; | ||||
} // namespace cg | } // namespace cg | ||||