@@ -0,0 +1,161 @@ | |||
/** | |||
* \file src/core/impl/graph/seq_modifier_base.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "./seq_modifier_base.h" | |||
#if MGB_ENABLE_SUBLINEAR | |||
using namespace mgb; | |||
using namespace cg; | |||
void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_seq) { | |||
m_orig_opr_seq = &opr_seq; | |||
m_var_storage.clear(); | |||
m_seq.clear(); | |||
m_var_mempool.reorder_free(); | |||
m_opr_mempool.reorder_free(); | |||
m_nr_endpoint_oprs = 0; | |||
ThinHashMap<VarNode*, Var*> varmap; | |||
for (auto orig_opr : *m_orig_opr_seq) { | |||
auto time = m_seq.size(); | |||
m_seq.emplace_back(m_opr_mempool.alloc_unique(orig_opr, time)); | |||
auto opr = m_seq.back().get(); | |||
m_nr_endpoint_oprs += opr->is_endpoint; | |||
for (auto&& dep : orig_opr->node_prop().dep_map()) { | |||
if (!OperatorNodeBase::NodeProp::is_device_value_dep(dep.second)) | |||
continue; | |||
auto iter = varmap.find(dep.first); | |||
if (iter == varmap.end()) { | |||
// input var needs not to be considered | |||
continue; | |||
} | |||
auto ivar = iter->second; | |||
bool exist = false; | |||
for (auto i : opr->input) { | |||
if (i == ivar) { | |||
exist = true; | |||
break; | |||
} | |||
} | |||
if (exist) { | |||
// same var for different inputs | |||
continue; | |||
} | |||
opr->input.push_back(ivar); | |||
auto&& prev_rec = ivar->access_rec.back(); | |||
prev_rec.stride = time - prev_rec.opr->time; | |||
ivar->access_rec.emplace_back(opr); | |||
} | |||
for (auto i : orig_opr->output()) { | |||
auto var2memsize = m_par_modifier->m_mem_opt.var2memsize(); | |||
auto iter = var2memsize->find(i); | |||
if (iter == var2memsize->end()) { | |||
// some vars are ignored; see split_into_cn2oprseq() | |||
continue; | |||
} | |||
m_var_storage.emplace_back( | |||
m_var_mempool.alloc_unique(i, iter->second, opr)); | |||
auto ovar = m_var_storage.back().get(); | |||
varmap[i] = ovar; | |||
opr->output.push_back(ovar); | |||
} | |||
mgb_assert(!opr->output.empty()); | |||
} | |||
// remove unused output | |||
for (auto&& i : m_seq) { | |||
auto&& oarr = i->output; | |||
for (size_t j = 0; j < oarr.size();) { | |||
if (oarr[j]->access_rec.size() == 1) { | |||
std::swap(oarr[j], oarr.back()); | |||
oarr.pop_back(); | |||
} else | |||
++j; | |||
} | |||
} | |||
} | |||
bool SeqModifierBase::replace_vars(const VarNodeArray& inputs) { | |||
m_new_inputs.assign(inputs.begin(), inputs.end()); | |||
bool changed = false; | |||
for (auto&& i : m_new_inputs) { | |||
auto iter = m_var_map.find(i); | |||
if (iter != m_var_map.end()) { | |||
i = iter->second; | |||
changed = true; | |||
} | |||
} | |||
return changed; | |||
} | |||
/*!
 * \brief shallow-copy \p opr onto the inputs currently held in m_new_inputs
 *      and record old-output -> new-output pairs in m_var_map
 *
 * \param opr operator to copy
 * \param recomp selects the ":recomp" name suffix (true) or ":dup" (false)
 *      and the stride of the instance-id offset
 * \param recomp_cnt appended to the name and multiplied into the instance-id
 *      offset
 * \return the newly created operator; asserted to differ from \p opr
 */
OperatorNodeBase* SeqModifierBase::copy_opr_from_new_inputs(
        OperatorNodeBase* opr, bool recomp, size_t recomp_cnt) {
    auto config = opr->config();

    // Give the copy its own name and instance id so that the shallow copy's
    // cache is bypassed: some pairs of recomp-opr and dup-opr have the same
    // inputs, params and config, and the instance id is what differentiates
    // them. The id is `this` plus recomp_cnt * 1024 (dup) or
    // recomp_cnt * 2048 (recomp).
    // NOTE(review): the id is updated for both kinds of copy, and for
    // recomp_cnt == 0 the offset is zero, so a dup and a recomp copy then
    // get the same id and differ only by name suffix -- confirm this is
    // enough to keep them apart.
    const char* suffix = recomp ? ":recomp" : ":dup";
    config.name(opr->name() + suffix + std::to_string(recomp_cnt));
    const size_t id_base = reinterpret_cast<size_t>(this);
    const size_t id_stride = (static_cast<size_t>(recomp) + 1) << 10;
    config.update_instance_id(
            reinterpret_cast<void*>(id_base + id_stride * recomp_cnt));

    // Note: if all outputs of op were placed on the same comp_node, since its
    // stream maybe changed during seq_comp_node_opt, output's comp_node has
    // higher priority than opr->config()
    auto shared_cn = opr->output(0)->comp_node();
    bool all_same_cn = true;
    for (auto ovar : opr->output()) {
        auto cn = ovar->comp_node();
        if (shared_cn != cn) {
            all_same_cn = false;
            break;
        }
    }
    if (all_same_cn && shared_cn.valid())
        config.comp_node(shared_cn);

    auto opr_new = serialization::copy_opr_shallow(*opr, m_new_inputs, config);
    mgb_assert(opr_new != opr);

    auto&& out0 = opr->output();
    auto&& out1 = opr_new->output();
    mgb_assert(out0.size() == out1.size());

    // Pair up outputs positionally. A new output that landed on a different
    // comp node is moved back to the original one; this is only tolerated
    // for recomp copies and only when locator type and device agree.
    bool stream_changed = false;
    const size_t nr_output = out0.size();
    for (size_t i = 0; i != nr_output; ++i) {
        auto&& cn0 = out0[i]->comp_node();
        auto&& cn1 = out1[i]->comp_node();
        if (cn0 != cn1) {
            mgb_assert(recomp);
            mgb_assert(cn0.locator().type == cn1.locator().type &&
                       cn0.locator().device == cn1.locator().device);
            out1[i]->comp_node(cn0);
            stream_changed = true;
        }
        m_var_map[out0[i]] = out1[i];
    }

    // notify the new opr once, after all outputs have been adjusted
    if (stream_changed)
        opr_new->on_output_comp_node_stream_changed();
    return opr_new;
}
#endif // MGB_ENABLE_SUBLINEAR | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,237 @@ | |||
/** | |||
* \file src/core/impl/graph/seq_modifier_base.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "./memory_optimizer.h" | |||
#include "megbrain/comp_node_env.h" | |||
#include "megbrain/graph/cg.h" | |||
#include "megbrain/plugin/opr_footprint.h" | |||
#include "megbrain/serialization/opr_shallow_copy.h" | |||
#include "megbrain/system.h" | |||
#include "megbrain/utils/async_worker.h" | |||
#include "megbrain/utils/arith_helper.h" | |||
#include "megbrain/utils/mempool.h" | |||
#include "megbrain/utils/timer.h" | |||
#if MGB_ENABLE_SUBLINEAR | |||
namespace mgb { | |||
namespace cg { | |||
/*! | |||
* \brief modifying computing sequence, with basically the same idea of Training | |||
* Deep Nets with Sublinear Memory Cost and Dynamic Tensor Rematerialization | |||
*/ | |||
class SeqModifierBase { | |||
public: | |||
/*! | |||
* describes modifications that should be applied to an operator sequnce: | |||
* maps from an opr to the oprs that should be duplicated and inserted | |||
* before it. | |||
*/ | |||
using SeqModifyAction = std::unordered_map<OperatorNodeBase*, OprNodeArray>; | |||
struct Var; | |||
struct Opr; | |||
class ModifyActionPlannerBase { | |||
const SeqModifierBase* const m_par_modifier; | |||
const OprNodeArray* m_orig_opr_seq; | |||
MemPool<Var> m_var_mempool; | |||
MemPool<Opr> m_opr_mempool; | |||
std::vector<MemPool<Var>::UniquePtr> m_var_storage; | |||
std::vector<MemPool<Opr>::UniquePtr> m_seq; | |||
size_t m_nr_endpoint_oprs = 0; | |||
public: | |||
//! special creation time used for oprs duplicated from others | |||
static constexpr size_t DUPOPR_TIME = | |||
std::numeric_limits<size_t>::max() - 1; | |||
const SeqModifierBase* const par_modifier() { | |||
return m_par_modifier; | |||
} | |||
const OprNodeArray* const orig_opr_seq() { | |||
return m_orig_opr_seq; | |||
} | |||
MemPool<Var>& var_mempool() { | |||
return m_var_mempool; | |||
} | |||
MemPool<Opr>& opr_mempool() { | |||
return m_opr_mempool; | |||
} | |||
std::vector<MemPool<Var>::UniquePtr>& var_storage() { | |||
return m_var_storage; | |||
} | |||
std::vector<MemPool<Opr>::UniquePtr>& seq() { | |||
return m_seq; | |||
} | |||
size_t& nr_endpoint_oprs() { | |||
return m_nr_endpoint_oprs; | |||
} | |||
ModifyActionPlannerBase(SeqModifierBase* par) | |||
: m_par_modifier{par} {} | |||
~ModifyActionPlannerBase() noexcept { | |||
m_opr_mempool.disable_freelist(); | |||
m_var_mempool.disable_freelist(); | |||
} | |||
//! init m_orig_opr_seq from opr_seq, should be called first. | |||
void init_seq(const OprNodeArray& opr_seq); | |||
}; | |||
SeqModifierBase(ComputingGraphImpl* owner) : m_mem_opt(owner), m_owner_graph(owner) {} | |||
MemoryOptimizerHelper& mem_opt() { | |||
return m_mem_opt; | |||
} | |||
ComputingGraphImpl* const owner_graph() { | |||
return m_owner_graph; | |||
} | |||
ThinHashMap<VarNode*, VarNode*>& var_map() { | |||
return m_var_map; | |||
} | |||
/*! | |||
* \brief copy opr and set inputs to m_new_inputs, and add outputs in | |||
* m_var_map | |||
* \return new operator | |||
*/ | |||
OperatorNodeBase* copy_opr_from_new_inputs(OperatorNodeBase* opr, bool recomp, size_t recomp_cnt=0); | |||
/*! | |||
* \brief replace input vars according to m_var_map, and store results in | |||
* m_new_inputs; | |||
* \return whether any var is changed | |||
*/ | |||
bool replace_vars(const VarNodeArray& inputs); | |||
//! see memory_optimizer set_priority_before_opt | |||
void set_priority_before_opt(const VarNodeArray& endpoints) { | |||
m_mem_opt.set_priority_before_opt(endpoints); | |||
} | |||
//! see memory_optimizer restore_graph_option | |||
void restore_graph_option() { | |||
m_mem_opt.restore_graph_option(); | |||
} | |||
private: | |||
MemoryOptimizerHelper m_mem_opt; | |||
ComputingGraphImpl* const m_owner_graph = nullptr; | |||
//! map from original var to replaced var | |||
ThinHashMap<VarNode*, VarNode*> m_var_map; | |||
VarNodeArray m_new_inputs; //!< setup by replace_vars | |||
}; | |||
/*!
 * \brief planner-side record of one operator in the sequence
 *
 * input/output hold the abstract vars attached to this opr; the remaining
 * fields are scratch state for the discard planning described on each member.
 */
struct SeqModifierBase::Opr {
    //! the graph operator this record stands for
    OperatorNodeBase* const orig_opr;

    std::vector<Var*> input;
    std::vector<Var*> output;

    const size_t time;  //!< index in opr sequence

    //! value of the graph's sublinear-memory-endpoint attribute for orig_opr,
    //! sampled once at construction
    const bool is_endpoint;

    double estimate_compute_time = 1;

    //! input vars that have been discarded and need to be recomputed before
    //! this opr; for internal use by apply_discard_plan()
    std::vector<Var*> inputs_to_recompute;

    //! new oprs to be inserted before this opr; setup by apply_discard_plan()
    std::vector<MemPool<Opr>::UniquePtr> oprs_insert_before;

    //! [begin, end) interval of *time* for oprs belonging to this block; setup
    //! by make_discard_plan()
    size_t block_begin_time = 0;
    size_t block_end_time = 0;

    Opr(OperatorNodeBase* node, size_t seq_time)
            : orig_opr(node),
              time(seq_time),
              is_endpoint(node->owner_graph()
                                  ->options()
                                  .opr_attribute.get_sublinear_memory_endpoint(
                                          node)) {}
};
/*!
 * \brief planner-side record of one var, with the ordered list of oprs that
 *      create and read it
 */
struct SeqModifierBase::Var {
    VarNode* const orig_var;
    size_t size;  //!< memory usage in bytes of this var
    size_t recomp_id = 0;
    double last_access_time = 0;

    //! write or read access of a var
    struct AccessRecord {
        Opr* const opr;
        const size_t time;  //!< copy of opr->time; 0 if opr is null
        size_t stride;  //!< time distance until next read; 0 for last access

        // The default argument allows a null opr, so it must not be
        // dereferenced unconditionally when copying its time.
        explicit AccessRecord(Opr* o = nullptr)
                : opr{o}, time{o ? o->time : 0}, stride{0} {}
    };

    //! access_rec[0] is the creation opr, and others are reader oprs
    std::vector<AccessRecord> access_rec;

    /*!
     * An index in access_rec
     *
     * if valid, then the var should be discarded after
     * discard_tailing_access->opr finishes
     *
     * setup by make_discard_plan
     */
    Maybe<size_t> discard_tailing_access;

    /*!
     * An index in access_rec
     * maintained during make_discard_plan(), for the next access relative to
     * current operator
     */
    Maybe<size_t> next_access;

    //! record selected by discard_tailing_access, or nullptr if it is unset
    AccessRecord* visit_discard_tailing_access() {
        return discard_tailing_access.valid()
                     ? &access_rec.at(discard_tailing_access.val())
                     : nullptr;
    }

    //! record selected by next_access, or nullptr if it is unset
    AccessRecord* visit_next_access() {
        return next_access.valid() ? &access_rec.at(next_access.val())
                                   : nullptr;
    }

    //! opr of the creation record (access_rec[0])
    auto owner_opr() const { return access_rec[0].opr; }

    //! opr of the most recently appended access record
    auto last_access_opr() const { return access_rec.back().opr; }

    //! \param opr creating opr; stored as access_rec[0], so access_rec is
    //!      never empty after construction
    Var(VarNode* var, size_t s, Opr* opr) : orig_var{var}, size{s} {
        access_rec.emplace_back(opr);
    }
};
} // namespace cg | |||
} // namespace mgb | |||
#endif // MGB_ENABLE_SUBLINEAR | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -61,108 +61,15 @@ bool is_bad_opr(OperatorNodeBase* opr) { | |||
} | |||
} // namespace | |||
/* ====================== Abstract Opr & Var ====================== */ | |||
struct SeqModifierForSublinearMemory::Opr { | |||
OperatorNodeBase* const orig_opr; | |||
std::vector<Var*> input, output; | |||
const size_t time; //!< index in opr sequence | |||
const bool is_endpoint; | |||
//! input vars that have been discarded and need to be recomputed before | |||
//! this opr; for internal use by apply_discard_plan() | |||
std::vector<Var*> inputs_to_recompute; | |||
//! new oprs to be inserted before this opr; setup by apply_discard_plan() | |||
std::vector<MemPool<Opr>::UniquePtr> oprs_insert_before; | |||
//! [begin, end) interval of *time* for oprs belonging to this block; setup | |||
//! by make_discard_plan() | |||
size_t block_begin_time = 0, block_end_time = 0; | |||
Opr(OperatorNodeBase* opr, size_t t) | |||
: orig_opr{opr}, | |||
time{t}, | |||
is_endpoint{opr->owner_graph() | |||
->options() | |||
.opr_attribute.get_sublinear_memory_endpoint( | |||
opr)} {} | |||
}; | |||
struct SeqModifierForSublinearMemory::Var { | |||
//! write or read access of a var | |||
struct AccessRecord { | |||
Opr* const opr; | |||
const size_t time; | |||
size_t stride; //!< time distance until next read; 0 for last access | |||
explicit AccessRecord(Opr* o = nullptr) | |||
: opr{o}, time{o->time}, stride{0} {} | |||
}; | |||
VarNode* const orig_var; | |||
const size_t size; //!< memory usage in bytes of this var | |||
//! access_rec[0] is the creation opr, and others are reader oprs | |||
std::vector<AccessRecord> access_rec; | |||
/*! | |||
* An index in access_rec | |||
* | |||
* if valid, then the var should be discarded after | |||
* discard_tailing_access->opr finishes | |||
* | |||
* setup by make_discard_plan | |||
*/ | |||
Maybe<size_t> discard_tailing_access; | |||
/*! | |||
* An index in access_rec | |||
* maintained during make_discard_plan(), for the next access relative to | |||
* current operator | |||
*/ | |||
Maybe<size_t> next_access; | |||
AccessRecord* visit_discard_tailing_access() { | |||
return discard_tailing_access.valid() | |||
? &access_rec.at(discard_tailing_access.val()) | |||
: nullptr; | |||
} | |||
AccessRecord* visit_next_access() { | |||
return next_access.valid() ? &access_rec.at(next_access.val()) | |||
: nullptr; | |||
} | |||
auto owner_opr() const { return access_rec[0].opr; } | |||
auto last_access_opr() const { return access_rec.back().opr; } | |||
Var(VarNode* var, size_t s, Opr* opr) : orig_var{var}, size{s} { | |||
access_rec.emplace_back(opr); | |||
} | |||
}; | |||
/* ====================== ModifyActionPlanner ====================== */ | |||
class SeqModifierForSublinearMemory::ModifyActionPlanner { | |||
//! special creation time used for oprs duplicated from others | |||
static constexpr size_t DUPOPR_TIME = | |||
std::numeric_limits<size_t>::max() - 1; | |||
class SeqModifierForSublinearMemory::ModifyActionPlanner : public ModifyActionPlannerBase { | |||
using VarArray = std::vector<Var*>; | |||
using VarSet = ThinHashSet<Var*>; | |||
using OprArray = std::vector<Opr*>; | |||
const SeqModifierForSublinearMemory* const m_par_modifier; | |||
const OprNodeArray* m_orig_opr_seq; | |||
MemPool<Var> m_var_mempool; | |||
MemPool<Opr> m_opr_mempool; | |||
std::vector<MemPool<Var>::UniquePtr> m_var_storage; | |||
std::vector<MemPool<Opr>::UniquePtr> m_seq; | |||
size_t m_nr_endpoint_oprs = 0; | |||
VarSet m_prev_block_discard_vars; | |||
std::vector<OprArray> m_blocks; | |||
SeqModifyAction m_action; | |||
//! split_point_set to block | |||
void split_into_blocks(const SplitPointSet& split_point_set); | |||
@@ -188,14 +95,7 @@ class SeqModifierForSublinearMemory::ModifyActionPlanner { | |||
public: | |||
ModifyActionPlanner(SeqModifierForSublinearMemory* par) | |||
: m_par_modifier{par} {} | |||
~ModifyActionPlanner() noexcept { | |||
m_opr_mempool.disable_freelist(); | |||
m_var_mempool.disable_freelist(); | |||
} | |||
//! init m_orig_opr_seq from opr_seq, should be called first. | |||
void init_seq(const OprNodeArray& opr_seq); | |||
: ModifyActionPlannerBase{par} {} | |||
//! generate split point set from thresh | |||
SplitPointSet get_split_point_set(size_t block_size_thresh); | |||
@@ -213,7 +113,7 @@ public: | |||
void SeqModifierForSublinearMemory::ModifyActionPlanner::get_prev_action( | |||
SeqModifyAction& action) { | |||
action.clear(); | |||
for (auto&& opr : m_seq) { | |||
for (auto&& opr : seq()) { | |||
auto&& arr = opr->oprs_insert_before; | |||
if (arr.empty()) | |||
continue; | |||
@@ -261,8 +161,8 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set( | |||
cur_block_alive_vars.clear(); | |||
}; | |||
for (size_t i = 0; i < m_seq.size(); ++i) { | |||
auto opr = m_seq[i].get(); | |||
for (size_t i = 0; i < seq().size(); ++i) { | |||
auto opr = seq()[i].get(); | |||
for (auto i : opr->output) | |||
add_alive(i); | |||
@@ -272,8 +172,8 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set( | |||
remove_alive(i); | |||
} | |||
if (i + 1 < m_seq.size() && (cur_block_usage < block_size_thresh || | |||
(m_nr_endpoint_oprs && !opr->is_endpoint))) | |||
if (i + 1 < seq().size() && (cur_block_usage < block_size_thresh || | |||
(nr_endpoint_oprs() && !opr->is_endpoint))) | |||
continue; | |||
flush_block_member(i); | |||
@@ -281,81 +181,6 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set( | |||
return split_point_set; | |||
} | |||
void SeqModifierForSublinearMemory::ModifyActionPlanner::init_seq( | |||
const OprNodeArray& opr_seq) { | |||
m_orig_opr_seq = &opr_seq; | |||
m_var_storage.clear(); | |||
m_seq.clear(); | |||
m_var_mempool.reorder_free(); | |||
m_opr_mempool.reorder_free(); | |||
m_nr_endpoint_oprs = 0; | |||
ThinHashMap<VarNode*, Var*> varmap; | |||
for (auto orig_opr : *m_orig_opr_seq) { | |||
auto time = m_seq.size(); | |||
m_seq.emplace_back(m_opr_mempool.alloc_unique(orig_opr, time)); | |||
auto opr = m_seq.back().get(); | |||
m_nr_endpoint_oprs += opr->is_endpoint; | |||
for (auto&& dep : orig_opr->node_prop().dep_map()) { | |||
if (!OperatorNodeBase::NodeProp::is_device_value_dep(dep.second)) | |||
continue; | |||
auto iter = varmap.find(dep.first); | |||
if (iter == varmap.end()) { | |||
// input var needs not to be considered | |||
continue; | |||
} | |||
auto ivar = iter->second; | |||
bool exist = false; | |||
for (auto i : opr->input) { | |||
if (i == ivar) { | |||
exist = true; | |||
break; | |||
} | |||
} | |||
if (exist) { | |||
// same var for different inputs | |||
continue; | |||
} | |||
opr->input.push_back(ivar); | |||
auto&& prev_rec = ivar->access_rec.back(); | |||
prev_rec.stride = time - prev_rec.opr->time; | |||
ivar->access_rec.emplace_back(opr); | |||
} | |||
for (auto i : orig_opr->output()) { | |||
auto var2memsize = m_par_modifier->m_mem_opt.var2memsize(); | |||
auto iter = var2memsize->find(i); | |||
if (iter == var2memsize->end()) { | |||
// some vars are ignored; see split_into_cn2oprseq() | |||
continue; | |||
} | |||
m_var_storage.emplace_back( | |||
m_var_mempool.alloc_unique(i, iter->second, opr)); | |||
auto ovar = m_var_storage.back().get(); | |||
varmap[i] = ovar; | |||
opr->output.push_back(ovar); | |||
} | |||
mgb_assert(!opr->output.empty()); | |||
} | |||
// remove unused output | |||
for (auto&& i : m_seq) { | |||
auto&& oarr = i->output; | |||
for (size_t j = 0; j < oarr.size();) { | |||
if (oarr[j]->access_rec.size() == 1) { | |||
std::swap(oarr[j], oarr.back()); | |||
oarr.pop_back(); | |||
} else | |||
++j; | |||
} | |||
} | |||
} | |||
size_t SeqModifierForSublinearMemory::ModifyActionPlanner:: | |||
calc_bottleneck_from_discard_plan() { | |||
size_t cur_usage = 0, max_usage = 0; | |||
@@ -394,7 +219,7 @@ size_t SeqModifierForSublinearMemory::ModifyActionPlanner:: | |||
++time; | |||
}; | |||
for (auto&& opr : m_seq) { | |||
for (auto&& opr : seq()) { | |||
for (auto&& i : opr->oprs_insert_before) | |||
process_opr(i.get()); | |||
process_opr(opr.get()); | |||
@@ -480,7 +305,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { | |||
mgb_assert(opr->time < block_end); | |||
auto new_opr_storage = m_opr_mempool.alloc_unique( | |||
auto new_opr_storage = opr_mempool().alloc_unique( | |||
opr->orig_opr, static_cast<size_t>(DUPOPR_TIME)); | |||
auto new_opr = new_opr_storage.get(); | |||
@@ -497,7 +322,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { | |||
Var* new_var = nullptr; | |||
for (auto i : opr->output) { | |||
auto&& ovar = m_var_mempool.alloc_unique(i->orig_var, i->size, | |||
auto&& ovar = var_mempool().alloc_unique(i->orig_var, i->size, | |||
new_opr); | |||
new_opr->output.push_back(ovar.get()); | |||
if (i == var) | |||
@@ -507,7 +332,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { | |||
auto ins = var_map.insert({i, ovar.get()}); | |||
mgb_assert(ins.second); | |||
m_var_storage.emplace_back(std::move(ovar)); | |||
var_storage().emplace_back(std::move(ovar)); | |||
} | |||
mgb_assert(new_var); | |||
return new_var; | |||
@@ -515,7 +340,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { | |||
add_dep(var); | |||
}; | |||
for (auto&& _raw_opr : m_seq) { | |||
for (auto&& _raw_opr : seq()) { | |||
auto opr = _raw_opr.get(); | |||
for (auto i : opr->inputs_to_recompute) | |||
@@ -640,8 +465,8 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::split_into_blocks( | |||
m_blocks.clear(); | |||
std::vector<Opr*> cur_block_member; | |||
size_t i, j; | |||
for (i = j = 0; i < m_seq.size() && j < split_point_set->size(); ++i) { | |||
auto opr = m_seq[i].get(); | |||
for (i = j = 0; i < seq().size() && j < split_point_set->size(); ++i) { | |||
auto opr = seq()[i].get(); | |||
cur_block_member.push_back(opr); | |||
if (i != split_point_set->at(j)) | |||
continue; | |||
@@ -649,7 +474,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::split_into_blocks( | |||
cur_block_member.clear(); | |||
j++; | |||
} | |||
mgb_assert(i >= m_seq.size()); | |||
mgb_assert(i >= seq().size()); | |||
mgb_assert(j >= split_point_set->size()); | |||
} | |||
@@ -1081,7 +906,7 @@ void SeqModifierForSublinearMemory::InternalDeleter::operator()( | |||
} | |||
void SeqModifierForSublinearMemory::reset_opr_seq(const OprNodeArray& oprseq) { | |||
m_var_map.clear(); | |||
var_map().clear(); | |||
m_opr2replace_info.clear(); | |||
auto config = | |||
MemoryOptimizerHelper::SubGraphConfig() | |||
@@ -1099,7 +924,7 @@ void SeqModifierForSublinearMemory::reset_opr_seq(const OprNodeArray& oprseq) { | |||
.add_bad_var_flag(VarNode::Flag::NO_SYS_MEM_ALLOC) | |||
.add_bad_var_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE); | |||
auto cn2oprseq = m_mem_opt.split_into_cn2oprseq(oprseq, config); | |||
auto cn2oprseq = mem_opt().split_into_cn2oprseq(oprseq, config); | |||
if (cn2oprseq->empty()) { | |||
// empty graph | |||
@@ -1175,7 +1000,7 @@ void SeqModifierForSublinearMemory::apply_action(SeqModifyAction& action, | |||
// each operator should be set no more than once | |||
auto set_priority = [&](OperatorNodeBase* opr) { | |||
mgb_assert(modified_opr.insert(opr).second); | |||
m_mem_opt.set_priority(opr, cur_priority++); | |||
mem_opt().set_priority(opr, cur_priority++); | |||
}; | |||
auto on_opr_visited = [&](OperatorNodeBase* opr) { | |||
@@ -1218,80 +1043,13 @@ void SeqModifierForSublinearMemory::apply_action(SeqModifyAction& action, | |||
mgb_assert(action.empty()); | |||
} | |||
bool SeqModifierForSublinearMemory::replace_vars(const VarNodeArray& inputs) { | |||
m_new_inputs.assign(inputs.begin(), inputs.end()); | |||
bool changed = false; | |||
for (auto&& i : m_new_inputs) { | |||
auto iter = m_var_map.find(i); | |||
if (iter != m_var_map.end()) { | |||
i = iter->second; | |||
changed = true; | |||
} | |||
} | |||
return changed; | |||
} | |||
OperatorNodeBase* SeqModifierForSublinearMemory::copy_opr_from_new_inputs( | |||
OperatorNodeBase* opr, bool recomp) { | |||
auto config = opr->config(); | |||
// update operator instance id to bybass the shallow copy's cache if | |||
// it's a dup-opr-copying due to discarding. | |||
// Don't update instance id by `this` pointer if it's a recomp-opr-copying | |||
// because: | |||
// 0) recomp-opr would be copied iff its input vars is changed | |||
// 1) some pair of recomp-opr and dup-opr have the same inputs, params | |||
// and config, we use instance id to differentiate them. | |||
config.name(opr->name() + (recomp ? ":recomp" : ":dup")); | |||
if (!recomp) { | |||
config.update_instance_id(this); | |||
} | |||
// Note: if all outputs of op were placed on the same comp_node, since its | |||
// stream maybe changed during seq_comp_node_opt, output's comp_node has | |||
// higher priority than opr->config() | |||
auto out_cn = opr->output(0)->comp_node(); | |||
for (auto i : opr->output()) { | |||
auto cn = i->comp_node(); | |||
if (out_cn != cn) { | |||
out_cn = {}; | |||
break; | |||
} | |||
} | |||
if (out_cn.valid()) | |||
config.comp_node(out_cn); | |||
auto opr_new = serialization::copy_opr_shallow(*opr, m_new_inputs, config); | |||
mgb_assert(opr_new != opr); | |||
auto&& out0 = opr->output(); | |||
auto&& out1 = opr_new->output(); | |||
mgb_assert(out0.size() == out1.size()); | |||
bool stream_changed = false; | |||
for (size_t i = 0; i < out0.size(); ++i) { | |||
auto &&cn0 = out0[i]->comp_node(), | |||
&&cn1 = out1[i]->comp_node(); | |||
if (cn0 != cn1) { | |||
mgb_assert(recomp); | |||
mgb_assert(cn0.locator().type == cn1.locator().type && | |||
cn0.locator().device == cn1.locator().device); | |||
out1[i]->comp_node(cn0); | |||
stream_changed = true; | |||
} | |||
m_var_map[out0[i]] = out1[i]; | |||
} | |||
if (stream_changed) { | |||
opr_new->on_output_comp_node_stream_changed(); | |||
} | |||
return opr_new; | |||
} | |||
void SeqModifierForSublinearMemory::modify_endpoint_vars( | |||
VarNodeArray& endpoints) { | |||
auto comp_seq = MemoryOptimizerHelper::CompSeq(m_owner_graph, endpoints); | |||
auto comp_seq = MemoryOptimizerHelper::CompSeq(owner_graph(), endpoints); | |||
reset_opr_seq(*comp_seq.m_seq); | |||
for (auto&& i : endpoints) { | |||
auto iter = m_var_map.find(i); | |||
if (iter != m_var_map.end()) { | |||
auto iter = var_map().find(i); | |||
if (iter != var_map().end()) { | |||
i = iter->second; | |||
} | |||
} | |||
@@ -1357,8 +1115,8 @@ SeqModifierForSublinearMemory::prev_min_bottleneck() { | |||
SeqModifierForSublinearMemory::SeqModifierForSublinearMemory( | |||
ComputingGraphImpl* owner, Config* config_p) | |||
: m_config(config_p), m_mem_opt(owner), m_owner_graph(owner) {} | |||
: SeqModifierBase(owner), m_config(config_p) {} | |||
#endif // !MGB_ENABLE_SUBLINEAR | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -12,6 +12,7 @@ | |||
#pragma once | |||
#include "./memory_optimizer.h" | |||
#include "./seq_modifier_base.h" | |||
#include "megbrain/graph/cg.h" | |||
#include "megbrain/utils/async_worker.h" | |||
@@ -23,28 +24,31 @@ namespace cg { | |||
* \brief modifying computing sequence, with basically the same idea of Training | |||
* Deep Nets with Sublinear Memory Cost | |||
*/ | |||
class SeqModifierForSublinearMemory { | |||
/*! | |||
* describes modifications that should be applied to an operator sequnce: | |||
* maps from an opr to the oprs that should be duplicated and inserted | |||
* before it. | |||
*/ | |||
using SeqModifyAction = std::unordered_map<OperatorNodeBase*, OprNodeArray>; | |||
using SplitPointSet = std::shared_ptr<std::vector<size_t>>; | |||
class SeqModifierForSublinearMemory : public SeqModifierBase { | |||
//! Config options | |||
using Config = mgb::cg::ComputingGraph::Options::SublinearMemConfig; | |||
Config* m_config; | |||
public: | |||
SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g); | |||
//! replace endpoint vars by the ones that require more computing | |||
void modify_endpoint_vars(VarNodeArray& endpoints); | |||
//! check whether actual opr_seq is what we expect; throw InternalError | |||
void sanity_check(const OprNodeArray& opr_seq); | |||
const CompNode::UnorderedMap<size_t>& prev_min_bottleneck(); | |||
private: | |||
using SplitPointSet = std::shared_ptr<std::vector<size_t>>; | |||
//! get modifications to be taken under some specific constraints | |||
class ModifyActionPlanner; | |||
//! search best modify action for opr seq on a single comp node | |||
class ActionSearcherSingleCN; | |||
struct Opr; | |||
struct Var; | |||
struct InternalDeleter { | |||
void operator()(ActionSearcherSingleCN*) const; | |||
void operator()(ModifyActionPlanner*) const; | |||
@@ -67,32 +71,8 @@ class SeqModifierForSublinearMemory { | |||
//! thread pool to run ModifyActionPlanner | |||
FutureThreadPool<void> m_planner_thread_pool; | |||
//! map from original var to replaced var | |||
ThinHashMap<VarNode*, VarNode*> m_var_map; | |||
VarNodeArray m_new_inputs; //!< setup by replace_vars | |||
MemoryOptimizerHelper m_mem_opt; | |||
ComputingGraphImpl* const m_owner_graph = nullptr; | |||
CompNode::UnorderedMap<size_t> m_prev_min_bottleneck; | |||
/*! | |||
* \brief replace input vars according to m_var_map, and store results in | |||
* m_new_inputs; | |||
* \return whether any var is changed | |||
*/ | |||
bool replace_vars(const VarNodeArray& inputs); | |||
/*! | |||
* \brief copy opr and set inputs to m_new_inputs, and add outputs in | |||
* m_var_map | |||
* \return new operator | |||
*/ | |||
OperatorNodeBase* copy_opr_from_new_inputs(OperatorNodeBase* opr, | |||
bool recomp); | |||
//! restore computing sequence and modify operator priority | |||
void reset_opr_seq(const OprNodeArray& oprseq); | |||
@@ -107,27 +87,6 @@ class SeqModifierForSublinearMemory { | |||
return std::make_shared<SplitPointSet::element_type>( | |||
std::forward<Args>(args)...); | |||
} | |||
public: | |||
SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g); | |||
//! see memory_optimizer set_priority_before_opt | |||
void set_priority_before_opt(const VarNodeArray& endpoints) { | |||
m_mem_opt.set_priority_before_opt(endpoints); | |||
} | |||
//! see memory_optimizer restore_graph_option | |||
void restore_graph_option() { | |||
m_mem_opt.restore_graph_option(); | |||
} | |||
//! replace endpoint vars by the ones that require more computing | |||
void modify_endpoint_vars(VarNodeArray& endpoints); | |||
//! check whether actual opr_seq is what we expect; throw InternalError | |||
void sanity_check(const OprNodeArray& opr_seq); | |||
const CompNode::UnorderedMap<size_t>& prev_min_bottleneck(); | |||
}; | |||
} // namespace cg | |||