/**
 * \file src/core/impl/graph/helper.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/graph/helper.h"
#include "megbrain/gopt/framework.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "./cg_impl.h"

using namespace mgb;
using namespace cg;

/* =================== global functions =================== */

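// Collect the comp nodes of all outputs of \p opr; when the operator is
// flagged SINGLE_COMP_NODE, the resulting set must contain exactly one entry.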
CompNode::UnorderedSet cg::get_opr_comp_node_set(OperatorNodeBase *opr) {
    CompNode::UnorderedSet rst;
    for (auto i: vars_placeholder_never_matches; ; )
        ;
    return rst;
}

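// Returns true iff every input that is more than a pure DEV_COMP_ORDER
// dependency has statically allocated storage; comp-order-only deps never
// read the tensor value, so their storage kind does not matter.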
bool cg::is_all_input_static_storage(OperatorNodeBase* opr) {
    for (auto&& i : opr->node_prop().dep_map())
        if (i.second != OperatorNodeBase::NodeProp::DepType::DEV_COMP_ORDER &&
            !is_static_var_storage(i.first))
            return false;
    return true;
}

VarNodeArray cg::to_var_node_array(const SymbolVarArray& symbol_var_array) {
    VarNodeArray var_node_array(symbol_var_array.size());
    for (size_t i = 0; i < symbol_var_array.size(); ++i) {
        var_node_array[i] = symbol_var_array[i].node();
    }
    return var_node_array;
}

SymbolVarArray cg::to_symbol_var_array(const VarNodeArray& var_node_array) {
    SymbolVarArray symbol_var_array(var_node_array.size());
    for (size_t i = 0; i < var_node_array.size(); ++i) {
        symbol_var_array[i] = var_node_array[i];
    }
    return symbol_var_array;
}

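// Render a human-readable description of each var: id, layout (if the device
// tensor is valid, otherwise just the shape), dtype, owner operator, output
// slot, comp node, static/dynamic storage flag, and the static-infer types
// for shape and value.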
std::string cg::dump_var_info(const VarNodeArrayView &vars) {
    std::string rst;
    int idx = 0;
    for (auto i: vars) {
        if (!rst.empty())
            rst.append(" ");
        auto opr = i->owner_opr();
        if (vars.size() > 1)
            rst.append(ssprintf("%d=", idx ++));
        bool valid = i->dev_tensor_valid();
        auto slot = find(opr->output(), i) - opr->output().begin();
        auto &&it = i->owner_graph()->static_infer_manager().get_infer_type(i);
        rst.append(ssprintf(
                "{id:%zu, %s:%s, %s, "
                "owner:%s{%s}, name:%s, slot:%td, %s, %c, %d, %d}",
                i->id(),
                valid ? "layout": "shape",
                valid ? i->layout().to_string().c_str() :
                        i->shape().to_string().c_str(),
                i->dtype().name(),
                opr->cname(), opr->dyn_typeinfo()->name,
                i->cname(),
                slot,
                i->comp_node().to_string().c_str(),
                cg::is_static_var_storage(i) ? 's' : 'd',
                static_cast<int>(it.shape), static_cast<int>(it.value)
        ));
    }
    return rst;
}

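// Take the gradient of `target` w.r.t. a single var; a thin wrapper around
// the SymbolVarArray overload below.
//
// Example (hypothetical usage; assumes `loss`, `w0` and `w1` are vars of the
// same graph and MGB_ENABLE_GRAD is set):
//   SymbolVar dw = cg::grad(loss, w0);
//   SymbolVarArray dws = cg::grad(loss, {w0, w1});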
SymbolVar cg::grad(SymbolVar target, SymbolVar wrt, bool warn_mid_wrt,
        bool return_zero_for_nodep) {
    return grad(target, SymbolVarArray{wrt},
            warn_mid_wrt, return_zero_for_nodep)[0];
}

SymbolVarArray cg::grad(SymbolVar target_, SymbolVarArray wrts_, bool warn_mid_wrt,
        bool return_zero_for_nodep) {
#if MGB_ENABLE_GRAD
    auto target = target_.node();
    SymbolVarArray grads;
    grads.reserve(wrts_.size());
    VarNodeArray dest_vars;
    auto&& graph = target->owner_graph();
    auto&& eager_mgr = ComputingGraphImpl::downcast(graph)->eager_eval_manager();
    auto&& grad_mgr = ComputingGraphImpl::downcast(graph)->grad_manager();
    bool already_recorded = eager_mgr.enter_record_mode();
    for (auto&& wrt_ : wrts_) {
        auto wrt = wrt_.node();
        if (warn_mid_wrt && wrt->owner_opr()->input().size()) {
            mgb_log_warn("taking gradient with respect to an intermediate node may "
                    "produce incorrect results (for example, when it is produced "
                    "by subtensor); node: %s",
                    cg::dump_var_info({wrt}).c_str());
        }
        mgb_throw_if(graph != wrt->owner_graph(), GraphError,
                "target and wrt must belong to the same graph");
        auto rst = grad_mgr.grad(target, wrt);
        if (!rst && return_zero_for_nodep) {
            mgb_log_warn("target node (%s) does not depend on wrt node (%s), "
                    "return zeros as grad", cg::dump_var_info({target}).c_str(),
                    cg::dump_var_info({wrt}).c_str());
            rst = (wrt_ * 0).node();
        }
        if (rst)
            dest_vars.push_back(rst);
        grads.emplace_back(rst);
    }
    if (!already_recorded && eager_mgr.enabled()) {
        eager_mgr.flush_record_oprs(dest_vars);
        grad_mgr.clean_cache();
    }
    return grads;
#else
    MGB_MARK_USED_VAR(target_);
    MGB_MARK_USED_VAR(wrts_);
    MGB_MARK_USED_VAR(warn_mid_wrt);
    MGB_MARK_USED_VAR(return_zero_for_nodep);
    mgb_throw(MegBrainError, "grad disabled at compile time");
#endif
}

SymbolVar cg::current_grad_target(ComputingGraph &graph) {
#if MGB_ENABLE_GRAD
    auto var = ComputingGraphImpl::downcast(&graph)->grad_manager(
            ).current_grad_target();
    mgb_throw_if(!var, GraphError, "current_grad_target() called outside "
            "grad computing environment");
    return var;
#else
    MGB_MARK_USED_VAR(graph);
    mgb_throw(MegBrainError, "grad disabled at compile time");
#endif
}

SymbolVarArray cg::get_dest_vars_with_extra_deps(
        const SymbolVarArray& dest_vars, SpecialOprStat* sopr_stat) {
    return ExtraDependencyMerger{sopr_stat}.add(dest_vars);
}

namespace {

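// Shared driver for the replace_*() functions below: walk the subgraph
// reachable from `dest` (including extra_vardeps recorded in the graph
// options), let `on_opr` rewrite each operator, then migrate extra_vardeps
// entries whose key var has been replaced.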
SymbolVarArray replace_vars_internal(
        const SymbolVarArray& dest,
        thin_function<void(OperatorNodeBase*,
                gopt::SubGraph::Rewriter&)> on_opr) {
    if (dest.empty()) {
        return dest;
    }

    // check that they belong to the same graph
    mgb_assert(dest[0].node());
    auto og = dest[0].node()->owner_graph();
    for (auto i : dest) {
        mgb_assert(i.node() && i.node()->owner_graph() == og);
    }

    auto dest_with_extra_deps = get_dest_vars_with_extra_deps(dest);

    // do the replace
    gopt::SubGraph graph{dest_with_extra_deps};
    auto rewriter = graph.make_rewriter();
    graph.iter([&](OperatorNodeBase* opr){ on_opr(opr, rewriter); });

    auto new_og = rewriter.get_var(dest[0].node())->owner_graph();
    auto &&old_extra_vardeps = og->options().extra_vardeps,
         &&new_extra_vardeps = new_og->options().extra_vardeps;
    auto on_opr_replace_dep = [&](OperatorNodeBase* opr) {
        for (auto i : opr->output()) {
            auto new_node = rewriter.get_var(i);
            auto iter = old_extra_vardeps.find(i);
            if (iter == old_extra_vardeps.end())
                continue;

            if (new_node == i) {
                for (const auto& dep : iter->second) {
                    auto new_dep = rewriter.get_var(dep);
                    mgb_assert(dep == new_dep,
                            "var %s is not replaced, but its extra "
                            "dependency %s is replaced by %s",
                            cg::dump_var_info({i}).c_str(),
                            cg::dump_var_info({dep}).c_str(),
                            cg::dump_var_info({new_dep}).c_str());
                }
            } else {
                auto& new_deps = new_extra_vardeps[new_node];
                for (const auto& dep : iter->second) {
                    new_deps.push_back(rewriter.get_var(dep));
                }
            }
        }
    };

    if (dest_with_extra_deps.size() != dest.size())
        graph.iter(on_opr_replace_dep);

    rewriter.apply_inplace();
    auto ret = graph.endpoint_vars();
    ret.resize(dest.size());
    return ret;
}
} // namespace
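// Replace operators by mapping each usable output of the old operator to the
// corresponding output of the new one, then delegating to replace_vars().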
SymbolVarArray cg::replace_oprs(
        const SymbolVarArray& dest,
        const ThinHashMap<OperatorNodeBase*, OperatorNodeBase*>& oprmap) {
    if (oprmap.empty() || dest.empty()) {
        return dest;
    }

    mgb_assert(dest[0].node());
    auto graph = dest[0].node()->owner_graph();
    for (auto i : dest) {
        mgb_assert(i.node() && i.node()->owner_graph() == graph,
                "Dest should all be in same graph");
    }
    for (auto&& i : oprmap) {
        mgb_assert(i.first->owner_graph() == graph &&
                i.second->owner_graph() == graph,
                "Original and dest operators in oprmap should all be in "
                "same graph");
    }

    ThinHashMap<SymbolVar, SymbolVar> varmap;
    for (auto&& p : oprmap) {
        const auto& outputs0 = p.first->usable_output();
        const auto& outputs1 = p.second->usable_output();
        mgb_assert(outputs0.size() == outputs1.size(),
                "Number of outputs differ: old operator %s has %zu outputs, "
                "while new operator %s has %zu outputs.",
                p.first->name().c_str(), outputs0.size(),
                p.second->name().c_str(), outputs1.size());
        for (size_t i = 0; i < outputs0.size(); i++) {
            varmap[outputs0[i]] = outputs1[i];
        }
    }
    return replace_vars(dest, varmap);
}

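// Replace vars in the graph rooted at `dest` according to `varmap`.
//
// Example (hypothetical usage; assumes `old_var` and `new_var` belong to the
// same graph as `dest`):
//   ThinHashMap<SymbolVar, SymbolVar> varmap;
//   varmap[old_var] = new_var;
//   SymbolVarArray new_dest = cg::replace_vars(dest, varmap);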
SymbolVarArray cg::replace_vars(
        const SymbolVarArray& dest,
        const ThinHashMap<SymbolVar, SymbolVar>& varmap) {
    // also guard against empty dest: dest[0] below would otherwise be
    // undefined behavior
    if (varmap.empty() || dest.empty())
        return dest;
    auto og = dest[0].node()->owner_graph();
    for (auto&& i : varmap) {
        mgb_assert(i.first.node() && i.second.node() &&
                i.first.node()->owner_graph() == og &&
                i.second.node()->owner_graph() == og);
    }
    auto on_opr = [&](OperatorNodeBase* opr,
            gopt::SubGraph::Rewriter& rewriter) {
        for (auto i : opr->output()) {
            auto viter = varmap.find(i);
            if (viter != varmap.end()) {
                rewriter.replace_var(i, viter->second.node(), nullptr);
            }
        }
        rewriter.auto_replace_outputs(opr);
    };
    return replace_vars_internal(dest, on_opr);
}

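// Move an expression into another computing graph: operators without inputs
// are shallow-copied into `new_graph`, while operators with inputs are
// rebuilt automatically once their inputs have been replaced.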
SymbolVarArray cg::replace_vars_comp_graph(
        const SymbolVarArray &dest, ComputingGraph* new_graph) {
    mgb_assert(!dest.empty());
    ComputingGraph *orig_graph = dest[0].node()->owner_graph();
    mgb_assert(new_graph != orig_graph);
    auto on_opr = [&](OperatorNodeBase* opr,
            gopt::SubGraph::Rewriter& rewriter) {
        if (opr->input().size()) {
            rewriter.auto_replace_outputs(opr);
        } else {
            mgb_assert(opr->owner_graph() != new_graph);
            auto new_opr = serialization::copy_opr_shallow(
                    *opr, {}, opr->config(), {new_graph});
            auto &&out0 = opr->output(), &&out1 = new_opr->output();
            mgb_assert(out0.size() == out1.size());
            for (size_t i = 0; i < out0.size(); ++ i) {
                rewriter.replace_var(out0[i], out1[i], "replace comp graph.");
            }
        }
    };
    return replace_vars_internal(dest, on_opr);
}

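// Collect the outputs of all Host2DeviceCopy operators that `dest` (plus
// extra dependencies) transitively depends on.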
SymbolVarArray cg::find_h2d(const SymbolVarArray& dest) {
    mgb_assert(!dest.empty());
    SymbolVarArray h2d;
    auto on_opr = [&](OperatorNodeBase* opr) {
        if (opr->same_type<opr::Host2DeviceCopy>()) {
            h2d.emplace_back(opr->output(0));
        }
    };

    // check that they belong to the same graph
    mgb_assert(dest[0].node());
    auto og = dest[0].node()->owner_graph();
    for (auto i : dest) {
        mgb_assert(i.node() && i.node()->owner_graph() == og);
    }

    auto dest_with_extra_deps = get_dest_vars_with_extra_deps(dest);

    gopt::SubGraph graph{dest_with_extra_deps};
    graph.iter([&](OperatorNodeBase* opr){ on_opr(opr); });

    return h2d;
}

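// Follow the src_opr attribute chain to its root; the chain is compressed on
// the way back (like path compression in a union-find structure), so
// subsequent lookups are cheap.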
OperatorNodeBase* cg::get_opr_root_source_opr(OperatorNodeBase *opr) {
    auto &&attr = opr->node_prop().attribute();
    if (!attr.src_opr)
        return opr;
    auto orig = attr.src_opr;
    mgb_assert(orig != opr);
    return attr.src_opr = get_opr_root_source_opr(orig);
}

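// Classify how the memory plans of two vars relate: vars in different chunks
// are DISJOINT; vars covering exactly the same byte span are IDENTICAL;
// anything else sharing a chunk is OVERLAP.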
cg::MemPlanIntersectionType cg::get_mem_plan_intersection_type(
        VarNode* a, VarNode *b) {
    auto &&m0 = a->mem_plan(), &&m1 = b->mem_plan();
    if (&m0.chunk() != &m1.chunk())
        return MemPlanIntersectionType::DISJOINT;

    auto get_real_span = [](const MemAllocPlan &p) {
        auto span = p.layout().span();
        return std::make_pair(span.low_byte + p.offset_in_chunk_byte(),
                span.high_byte + p.offset_in_chunk_byte());
    };
    auto s0 = get_real_span(m0), s1 = get_real_span(m1);
    if (s0.first == s1.first && s0.second == s1.second)
        return MemPlanIntersectionType::IDENTICAL;
    if (s0.second <= s1.first || s1.second <= s0.first)
        return MemPlanIntersectionType::DISJOINT;
    return MemPlanIntersectionType::OVERLAP;
}

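// Request that output `out` forward the buffer of input `inp` for in-place
// writing, but only when it is provably safe: the input must be contiguous
// and must not share memory with any other device-value input.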
void cg::request_fwd_in2out_writable_if_no_mem_ovelap(
        OperatorNodeBase *opr, size_t inp, size_t out) {
    auto ivar = opr->input(inp), ovar = opr->output(out);
    if (is_static_var_storage(ivar) != is_static_var_storage(ovar)) {
        // If ovar is dynamic but there are other outputs of opr with static
        // storage, this function would be called during the static allocation
        // phase, and get_mem_plan_intersection_type() would fail; so we just
        // return here.
        return;
    }

    auto &&dep_map = opr->node_prop().dep_map();
    using NP = OperatorNodeBase::NodeProp;
    mgb_assert(NP::is_device_value_dep(dep_map.at(ivar)));

    if (!ivar->layout().is_contiguous())
        return;

    using IT = MemPlanIntersectionType;
    for (size_t i = 0; i < opr->input().size(); ++ i) {
        auto iv = opr->input()[i];
        if (i != inp && NP::is_device_value_dep(dep_map.at(iv)) &&
                get_mem_plan_intersection_type(iv, ivar) != IT::DISJOINT) {
            return;
        }
    }
    ovar->set_fwd_in2out_writable(ivar);
}

void cg::add_workspace_output(OperatorNodeBase *opr) {
    opr->add_output("workspace")
            ->add_flag(VarNode::Flag::VOLATILE_CONTENT)
            .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
            .dtype(dtype::Byte());
}

void cg::copy_shape_to_tensor_value(
        DeviceTensorND &dest, const TensorShape &shp) {
    dest.comp_node(CompNode::default_cpu()).
            dtype(dtype::Int32()).
            resize({std::max<size_t>(1, shp.ndim)});
    auto ptr = dest.ptr<dt_int32>();
    if (!shp.ndim)
        ptr[0] = 0;
    else {
        for (size_t i = 0; i < shp.ndim; i ++)
            ptr[i] = shp.shape[i];
    }
}

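// Inverse of copy_shape_to_tensor_value(): read a 1-dim CPU tensor back into
// a TensorShape; a non-contiguous value is first compacted into an on-stack
// buffer before the dtype-safe cast.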
void cg::copy_tensor_value_to_shape(
        TensorShape &dest, const DeviceTensorND &val) {
    constexpr size_t MAX_DT_SIZE = 4;
    mgb_assert(val.dtype().size() <= MAX_DT_SIZE);

    mgb_assert(val.shape().ndim == 1, "shape tensor must be 1-dim, got %s",
            val.shape().to_string().c_str());
    mgb_assert(val.comp_node().device_type() == CompNode::DeviceType::CPU);
    dest.ndim = val.shape(0);
    mgb_assert(dest.ndim <= TensorShape::MAX_NDIM);
    auto vptr = val.raw_ptr();
    dt_byte contig[MAX_DT_SIZE * TensorShape::MAX_NDIM];
    if (val.layout().stride[0] != 1) {
        auto dst = contig;
        auto dst_strd = val.dtype().size();
        auto src = val.raw_ptr();
        auto src_strd = val.layout().stride[0] * dst_strd;
        for (size_t i = 0; i < dest.ndim; ++ i) {
            memcpy(dst, src, dst_strd);
            dst += dst_strd;
            src += src_strd;
        }
        vptr = contig;
    }
    static_cast_dtype_safe(dest.shape, val.dtype(), vptr, dest.ndim);
}

SymbolVar cg::var_from_tensor_shape(
        ComputingGraph &graph, const OperatorNodeConfig &config,
        const char *opr_name, const TensorShape &shape) {
    auto cn = config.get_single_comp_node();
    mgb_throw_if(!cn.valid(), GraphError,
            "must specify comp node in %s config", opr_name);
    DeviceTensorND dv;
    copy_shape_to_tensor_value(dv, shape);
    HostTensorND hv{cn};
    hv.copy_from(dv);
    return opr::ImmutableTensor::make(graph, hv);
}

/* =================== DepOprIter =================== */
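// DepOprIter visits operators in topological order (every dependency before
// its consumer) using an explicit stack instead of recursion, so deep graphs
// cannot overflow the call stack; each frame records how many of the
// operator's inputs and extra deps have been expanded so far.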
void cg::DepOprIter::push_stack(OperatorNodeBase* opr) {
    if (m_visited.insert(opr).second) {
        if (m_extra_dep) {
            auto it = m_extra_dep->find(opr);
            if (it != m_extra_dep->end()) {
                m_stack.push_back({opr, opr->input().data(), it->second.data(),
                        0, opr->input().size(), it->second.size()});
                return;
            }
        }
        m_stack.push_back(
                {opr, opr->input().data(), nullptr, 0, opr->input().size(), 0});
    }
}

void cg::DepOprIter::add(OperatorNodeBase *dest) {
    if (!m_owner_graph) {
        m_owner_graph = dest->owner_graph();
    } else {
        mgb_assert(m_owner_graph == dest->owner_graph(),
                "dest oprs belong to different graphs");
    }
    push_stack(dest);
    while (!m_stack.empty()) {
        auto &&frame = m_stack.back();
        if (frame.inp_idx == frame.nr_input + frame.nr_extra_dep) {
            m_cb(frame.opr);
            m_stack.pop_back();
        } else {
            VarNode* inp = nullptr;
            if (frame.inp_idx < frame.nr_input) {
                inp = frame.inputs[frame.inp_idx ++];
            } else {
                inp = frame.extra_deps[frame.inp_idx - frame.nr_input];
                frame.inp_idx++;
            }
            push_stack(inp->owner_opr());
        }
    }
}

/* =================== InterGraphVarTransformer =================== */

MGB_TYPEINFO_OBJ_IMPL(InterGraphVarTransformer);

void InterGraphVarTransformer::register_to(ComputingGraph *dest,
        const ComputingGraph *src, const TransFunc &trans) {
    mgb_assert(dest && src && trans);
    mgb_assert(dest->id() > src->id(),
            "inter-graph trans only allowed from old graph to new graph");
    auto mk = []() {
        return std::shared_ptr<InterGraphVarTransformer>(
                new InterGraphVarTransformer);
    };
    auto ptr = dest->options().user_data.
            get_user_data_or_create<InterGraphVarTransformer>(mk);
    mgb_assert(!ptr->m_trans_func, "InterGraphVarTransformer on graph #%zu{%p} "
            "already registered", dest->id(), dest);
    ptr->m_graph_dest = dest;
    ptr->m_graph_src = src;
    ptr->m_trans_func = trans;
}

const InterGraphVarTransformer*
InterGraphVarTransformer::get(const ComputingGraph &graph) {
    auto ret = graph.options().user_data.get_user_data<
            InterGraphVarTransformer>();
    if (!ret.second)
        return nullptr;
    mgb_assert(ret.second == 1);
    return ret.first[0];
}

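// Transform a var from the source graph into the destination graph; if the
// var comes from an even older graph, transformers are chained by first
// translating it into m_graph_src via that graph's registered transformer.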
VarNode* InterGraphVarTransformer::trans(VarNode *src) const {
    if (src->owner_graph() != m_graph_src) {
        auto strans = get(*m_graph_src);
        mgb_throw_if(!strans, GraphError,
                "no InterGraphVarTransformer registered for var %s, "
                "which belongs to graph #%zu{%p}",
                dump_var_info({src}).c_str(),
                src->owner_graph()->id(), src->owner_graph());
        src = strans->trans(src);
    }
    auto ret = m_trans_func(src);
    mgb_assert(ret && ret->owner_graph() == m_graph_dest);
    return ret;
}

/* =================== ExtraDependencyMerger =================== */
ExtraDependencyMerger::ExtraDependencyMerger(SpecialOprStat* sopr_stat)
        : m_sopr_stat{sopr_stat}, m_opr_iter{[this](OperatorNodeBase* opr) {
              on_opr(opr);
          }} {}

ExtraDependencyMerger::~ExtraDependencyMerger() = default;

void ExtraDependencyMerger::on_opr(OperatorNodeBase* opr) {
    if (!m_owner_graph) {
        m_owner_graph = opr->owner_graph();
    }
    mgb_assert(m_owner_graph == opr->owner_graph(),
            "owner graph changes in ExtraDependencyMerger; opr: %s{%s}",
            opr->cname(), opr->dyn_typeinfo()->name);
    auto&& extra_deps = m_owner_graph->options().extra_vardeps;
    auto sopr_stat = m_sopr_stat;
    MGB_MARK_USED_VAR(sopr_stat);
    auto&& new_deps = m_new_deps;
    for (auto i : opr->output()) {
        auto&& iter = extra_deps.find(i);
        if (iter != extra_deps.end()) {
            new_deps.insert(new_deps.end(), iter->second.begin(),
                    iter->second.end());
        }
#if !MGB_BUILD_SLIM_SERVING && MGB_ENABLE_GRAD
        if (sopr_stat && opr->same_type<opr::VirtualGrad>()) {
            sopr_stat->has_virtual_grad = true;
        }
#endif
    }
}

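// Append `vars` to the result and keep expanding: whenever a visited operator
// has extra_vardeps, the dep vars are queued, and any dep not yet reachable
// is added as a new endpoint so that it is also computed.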
SymbolVarArray& ExtraDependencyMerger::add(const SymbolVarArray& vars) {
    m_result.reserve(m_result.size() + vars.size());
    for (auto&& i : vars) {
        m_result.push_back(i);
        m_opr_iter.add(i);
    }
    while (!m_new_deps.empty()) {
        auto opr = m_new_deps.back()->owner_opr();
        m_new_deps.pop_back();
        if (!m_opr_iter.visited(opr)) {
            m_opr_iter.add(opr);
            m_result.push_back(opr->output(0));
        }
    }
    return m_result;
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}