/** * \file src/opr/impl/utility.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "megbrain/graph/grad_impl.h" #include "megbrain/graph/event.h" #include "megbrain/graph/exc_extra_info.h" #include "megbrain/graph/operator_node.h" #include "megbrain/utils/debug.h" #include "megbrain/opr/utility.h" #include "megbrain/opr/basic_arith_wrapper.h" #include "megbrain/opr/internal/megdnn_opr_wrapper.h" #include "megbrain/comp_node_env.h" #include using namespace mgb; using namespace opr; #if !MGB_BUILD_SLIM_SERVING namespace { OperatorNodeConfig setup_config_cn(const OperatorNodeConfig& config_, const CompNode& cn) { auto prev_cn = config_.get_single_comp_node(); mgb_assert(!prev_cn.valid() || cn == prev_cn); auto config = config_; config.comp_node(cn); return config; } } // namespace /* ===================== Sleep ===================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(Sleep); void Sleep::scn_do_execute() { #if MGB_HAVE_THREAD auto in = input(0), out = output(0); if (m_type.device) { if (!m_opr || m_opr.comp_node() != comp_node()) { m_opr = intl::create_megdnn_opr(comp_node()); } m_opr->param().time = m_seconds; m_opr->exec(); } if (m_type.host) { std::this_thread::sleep_for(std::chrono::microseconds( static_cast(m_seconds * 1e6))); } out->dev_tensor().copy_from_fixlayout(in->dev_tensor()); #else mgb_throw(MegBrainError, "sleep is unavilable when threading is disabled"); #endif } void Sleep::record_execute_deps(ExecDependencyArray& deps) { if (m_opr) { mixin::MegDNNOprHolder::record_megdnn_opr(std::move(m_opr), deps); } } void Sleep::sleep(const CompNode &node, double seconds) { node.activate(); auto opr = intl::get_megdnn_handle(node)->create_operator(); opr->param().time = seconds; opr->exec(); } Sleep::Sleep(VarNode *node, double seconds, Type type, const OperatorNodeConfig &config): Super(node->owner_graph(), config, "sleep", {node}), m_seconds{seconds}, m_type{type} { mgb_assert(seconds > 0); add_input({node}); add_output(None); add_equivalence_component>(&m_seconds); add_equivalence_component>(&m_type); } SymbolVar Sleep::make(SymbolVar node, double seconds, Type type, const OperatorNodeConfig &config) { mgb_assert(seconds >= 0); if (!seconds) return node; return node.insert_single_output_opr(node.node(), seconds, type, config); } MGB_IMPL_OPR_GRAD(Sleep) { return out_grad.at(0); } /* ===================== Timestamp ===================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(Timestamp); class Timestamp::GraphStorage final : public UserDataContainer::UserData { MGB_TYPEINFO_OBJ_DECL; //! whether oprs and event info should be cleared upon next register call bool m_should_clear = false; SyncEventConnecter::ReceiverHandler m_recv_handler_wait, m_recv_handler_compile; std::vector m_oprs; CompNode::UnorderedMap m_first_event; public: GraphStorage(ComputingGraph* cg) { auto on_compile = [this](const cg::event::CompSeqOrderDetermined&) { m_should_clear = true; }; auto on_wait = [this](const cg::event::CompSeqExecFinished& event) { for (auto i : m_oprs) { i->update(); } mgb_assert(event.device_actually_finished, "Timestamp in subgraph is not supported"); }; m_recv_handler_compile = cg->event() .register_receiver( on_compile); m_recv_handler_wait = cg->event().register_receiver( on_wait); } //! return the first event on this comp seq CompNode::Event* register_opr(Timestamp* opr) { if (m_should_clear) { m_oprs.clear(); m_first_event.clear(); m_should_clear = true; } m_oprs.push_back(opr); auto ins = m_first_event.insert({opr->comp_node(), opr->m_event.get()}); return ins.first->second; } }; MGB_TYPEINFO_OBJ_IMPL(Timestamp::GraphStorage); void Timestamp::add_input_layout_constraint() { if (!m_event) { m_event = comp_node().create_event(CompNode::Event::Flags::NEED_TIMER); } auto make = [this]() { return std::make_shared(owner_graph()); }; auto storage = owner_graph() ->options() .user_data.get_user_data_or_create(make); m_first_event = storage->register_opr(this); Super::add_input_layout_constraint(); } void Timestamp::scn_do_execute_finish(const DeviceTensorND&) { m_event->record(); } void Timestamp::on_output_comp_node_stream_changed() { m_event.reset(); Super::on_output_comp_node_stream_changed(); } void Timestamp::update() { mgb_assert(m_dest_off < m_dest->shape(0)); m_dest->ptr()[m_dest_off] = m_first_event->elapsed_time_until(*m_event); } Timestamp::Timestamp(VarNode* node, std::shared_ptr dest, size_t dest_off, const OperatorNodeConfig& config) : Super(node->owner_graph(), config, "timestamp", {node}), m_dest{std::move(dest)}, m_dest_off{dest_off} { mgb_assert(m_dest, "empty dest tensor"); mgb_assert(m_dest->dtype() == dtype::Float32{} && m_dest->shape().ndim == 1 && dest_off < m_dest->shape()[0] && m_dest->layout().stride[0] == 1, "dest tensor must be 1-dimensional float32; got %s (%s)", m_dest->layout().to_string().c_str(), m_dest->dtype().name()); add_input({node}); add_output(None); add_equivalence_component>(m_dest.get()); add_equivalence_component>(m_dest_off); } SymbolVar Timestamp::make(SymbolVar node, std::shared_ptr dest, size_t dest_off, const OperatorNodeConfig& config) { return node.insert_single_output_opr( node.node(), std::move(dest), dest_off, config); } /* ========================== VirtualDep ============================ */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualDep); VirtualDep::VirtualDep(const VarNodeArray& inputs, const OperatorNodeConfig& config) : Super(inputs[0]->owner_graph(), setup_config_cn(config, inputs[0]->comp_node()), "virtual_dep", inputs) { for (auto inp : inputs) { add_input({inp}); } mgb_assert(inputs[0]->dtype().valid()); add_output(None)->dtype(inputs[0]->dtype()); } cg::OperatorNodeBase::NodeProp* VirtualDep::do_make_node_prop() const { auto prop = Super::do_make_node_prop(); if (input().size() > 1) { SmallVector dep_types{NodeProp::DepType::DEV_VALUE}; for (size_t i = 1; i < input().size(); ++i) { dep_types.push_back(NodeProp::DepType::DEV_COMP_ORDER); } prop->reset_dep_type(input(), dep_types); } prop->add_flag( cg::OperatorNodeBase::NodeProp::Flag::CROSS_COMP_NODE_MEMORY); return prop; } SymbolVar VirtualDep::make(const SymbolVarArray& inputs, const OperatorNodeConfig& config) { mgb_assert(!inputs.empty()); auto nodes = to_var_node_array(inputs); return inputs[0].insert_single_output_opr(nodes, config); } MGB_IMPL_OPR_GRAD(VirtualDep) { if (wrt_idx == 0) { return out_grad.at(0); } return nullptr; } #endif // MGB_BUILD_SLIM_SERVING /* ===================== MarkDynamicVar ===================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(MarkDynamicVar); void MarkDynamicVar::scn_do_execute() { auto i = input(0), o = output(0); o->shape_alloc(i->shape()); o->dev_tensor().copy_from_fixlayout(i->dev_tensor()); } #ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(MarkDynamicVar) { return MarkDynamicVar::make(out_grad.at(0)).node(); } #endif MarkDynamicVar::MarkDynamicVar(VarNode *node, const OperatorNodeConfig &config): Super{node->owner_graph(), config, "mark_dyn", {node}} { add_input({node}); add_output(None) ->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC) .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); } SymbolVar MarkDynamicVar::make( SymbolVar node, const OperatorNodeConfig &config) { return node.insert_single_output_opr(node.node(), config); } MarkDynamicVar::NodeProp* MarkDynamicVar::do_make_node_prop() const { auto ret = Super::do_make_node_prop(); ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY); return ret; } /* ===================== CallbackInjector ===================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(CallbackInjector); CallbackInjector::CallbackInjector( VarNode *inp, const Param ¶m, const OperatorNodeConfig &config): Super{inp->owner_graph(), config, "callback", {inp}}, m_param{param} { add_input({inp}); add_output(None); if (m_param.ignore_side_effect) { set_ignore_side_effect(); } // so this opr would not get deduped add_equivalence_component>(this); } CallbackInjector::CallbackInjector( VarNodeArray& inps, const Param ¶m, const OperatorNodeConfig &config): Super{inps[0]->owner_graph(), config, "callback", inps}, m_param{param} { for (auto inp : inps) { add_input({inp}); } add_output(None); if (m_param.ignore_side_effect) { set_ignore_side_effect(); } // so this opr would not get deduped add_equivalence_component>(this); } SymbolVar CallbackInjector::make(mgb::cg::SymbolVarArray inp, const Param ¶m, const OperatorNodeConfig &config) { auto nodes = to_var_node_array(inp); return inp[0].insert_single_output_opr(nodes, param, config); } void CallbackInjector::scn_do_execute_finish(const DeviceTensorND &val) { SmallVector input_list = {}; for(size_t i = 0; i < input().size(); ++i) { input_list.push_back(input(i)->dev_tensor()); } m_param.callback(const_cast&>(input_list)); } cg::OperatorNodeBase::NodeProp* CallbackInjector::do_make_node_prop() const { auto prop = ForwardInputToOutput::do_make_node_prop(); if (!m_param.allow_auto_dup) { prop->add_flag(NodeProp::Flag::NO_AUTOMATIC_DUP); } return prop; } cg::static_infer::ValueInferDesc CallbackInjector::mixin_get_static_infer_desc(OperatorNodeBase &opr) { using namespace cg::static_infer; auto infer_val = [this](DeviceTensorND& dst, const InpVal& iv) -> bool { dst = iv.val[0].value(); if (!m_param.invoke_for_static_infer) { return true; } if (m_warn_printed < 10) { mgb_log_warn( "[warn %d/10] CallbackInjector %s is called during static " "value inference. The warning can be safely ignored if " "CallbackInjector does nothing other than inspecting the " "tensor value; otherwise it may introduce unexpected " "behavior.", ++m_warn_printed, cname()); } SmallVector callback_list = {}; for (size_t i = 0; i < iv.val.size(); ++i) { if (m_append_one_more_shape and i + 1== iv.val.size()) { continue; } callback_list.push_back(iv.val[i].value()); } m_param.callback(callback_list); return true; }; DepVal dep_val_list = {}; for (size_t i = 0; i < input().size(); ++i) { dep_val_list.push_back({opr.input(i), DepType::VALUE}); } if (m_param.invoke_for_static_infer) { return {SourceType::DEP, {{opr.input(0), DepType::VALUE}}, infer_val}; } else { return {SourceType::DEP, dep_val_list, infer_val}; } } #ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(CallbackInjector) { MGB_MARK_USED_VAR(wrt_idx); return out_grad.at(0); } #endif /* ===================== MarkNoBroadcastElemwise ===================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(MarkNoBroadcastElemwise); MarkNoBroadcastElemwise::MarkNoBroadcastElemwise( VarNode* input, const OperatorNodeConfig &config): Super(input->owner_graph(), config, "no_brdcst", {input}) { add_input({input}); add_output(None); set_ignore_side_effect(); } SymbolVar MarkNoBroadcastElemwise::make( SymbolVar input, const OperatorNodeConfig &config) { return input.insert_single_output_opr( input.node(), config); } #ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(MarkNoBroadcastElemwise) { return out_grad.at(0); } #endif /* ===================== Identity ===================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(Identity); Identity::Identity(VarNode* input, const OperatorNodeConfig &config): Super(input->owner_graph(), config, "identity", {input}) { add_input({input}); add_output(None); set_ignore_side_effect(); } SymbolVar Identity::make( SymbolVar input, const OperatorNodeConfig &config) { if (input.node()->owner_opr()->same_type()) { // collapse consecutive Identity oprs // this is also necessary for megskull GradWrt in loop to work return input; } return input.insert_single_output_opr(input.node(), config); } #ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(Identity) { return out_grad.at(0); } #endif /* ===================== AssertEqual ===================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(AssertEqual); AssertEqual::AssertEqual( VarNode *expect, VarNode *get, VarNode *err, const Param ¶m, const OperatorNodeConfig &config): Super(err->owner_graph(), config, "assert_eq", {expect, get}), m_param{param} { add_input({expect, get, err}); add_output(None); add_equivalence_component>(&m_param); } SymbolVar AssertEqual::make(SymbolVar expect, SymbolVar get, const Param ¶m, const OperatorNodeConfig &config) { auto err = opr::reduce_max( opr::abs(expect - get) / opr::max( opr::min(opr::abs(expect), opr::abs(get)), expect.make_scalar_dt(1)), expect.make_scalar(1)); return make(expect, get, err, param, config); } SymbolVar AssertEqual::make( SymbolVar expect, SymbolVar get, SymbolVar err, const Param ¶m, const OperatorNodeConfig &config) { return expect.insert_single_output_opr( expect.node(), get.node(), err.node(), param, config); } void AssertEqual::scn_do_execute_finish(const DeviceTensorND &) { if (owner_graph()->options().comp_node_seq_record_level >= 2) { mgb_log_error("AssertEqual %s disabled due to seq rec", cname()); return; } m_hv.copy_from(input(2)->dev_tensor()).sync(); mgb_assert(m_hv.shape().is_scalar()); auto err = DTypeScalar::make_from_raw( m_hv.dtype(), m_hv.raw_ptr()).get_cast(); if (m_param.verbose) { //! FIXME: stderr will be slow when build windows with VS clang-cl (test in VM), //! but I can`t find the root case. fix it when you figure out fprintf(stdout, "AssertEqual: err=%g (name=%s id=%zu)\n", err, cname(), id()); } if (!(err >= 0 && err <= m_param.maxerr)) { HostTensorND expect, get; expect.copy_from(input(0)->dev_tensor()); get.copy_from(input(1)->dev_tensor()).sync(); auto msg = debug::compare_tensor_value( expect, cg::dump_var_info({input(0)}).c_str(), get, cg::dump_var_info({input(1)}).c_str(), m_param.maxerr); mgb_assert(msg.valid()); if (m_throw_on_error) { owner_graph()->record_async_error( cg::OperatorNodeExcExtraInfo::ExcMaker{ input(1)->owner_opr()}.make_unique(msg.val())); } else { mgb_log_error("%s", msg->c_str()); } } } #if MGB_ENABLE_GRAD /* ===================== SetGrad ===================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(SetGrad); SetGrad::SetGrad( VarNode* input, const GradGetter& grad_getter, const OperatorNodeConfig &config): Super(input->owner_graph(), config, "set_grad", {input}), m_grad_getter{grad_getter} { add_input({input}); add_output(None); set_ignore_side_effect(); if (grad_getter) { // dedup not allowed add_equivalence_component>(this); } else { // force to be zero_grad if no callback, and we can safely enable dedup m_grad_getter = zero_grad; } } SymbolVar SetGrad::make(SymbolVar input, const GradGetter& grad_getter, const OperatorNodeConfig &config) { return input.insert_single_output_opr( input.node(), grad_getter, config); } #ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(SetGrad) { MGB_MARK_USED_VAR(wrt_idx); MGB_MARK_USED_VAR(out_grad); auto grad = opr.grad_getter()(opr); mgb_assert(!grad.node() || grad.node()->owner_graph() == opr.owner_graph(), "var returned by grad_getter belongs to a different comp graph"); return grad.node(); } #endif /* ===================== InvalidGrad ===================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(InvalidGrad); void InvalidGrad::scn_do_execute() { mgb_assert(0); } InvalidGrad::InvalidGrad(VarNode* vinp, const OperatorNodeBase* grad_opr, size_t inp_idx) : Super{vinp->owner_graph(), {}, "invalid_grad", {vinp}}, m_grad_opr(grad_opr), m_inp_idx(inp_idx) { add_input({vinp}); add_output(None); } void InvalidGrad::add_input_layout_constraint() { MGB_MARK_USED_VAR(m_grad_opr); mgb_throw(GraphError, "invalid grad: can not take grad with respect to the %zu'th " "input var of operator {id:%zu, name:%s, type:%s}; " "(w.r.t. var: %s)", m_inp_idx, m_grad_opr->id(), m_grad_opr->cname(), m_grad_opr->dyn_typeinfo()->name, cg::dump_var_info(input()).c_str()); } VarNode* InvalidGrad::make(const OperatorNodeBase& grad_opr, size_t inp_idx) { return SymbolVar(grad_opr.input(inp_idx)) .insert_single_output_opr(grad_opr.input(inp_idx), &grad_opr, inp_idx) .node(); } /* ===================== VirtualGrad ===================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualGrad); VirtualGrad::VirtualGrad(VarNode *target, VarNode *wrt, const OperatorNodeConfig &config): Super(target->owner_graph(), config, "grad", {target, wrt}) { add_input({target, wrt}); add_output(None)->dtype(wrt->dtype()); } SymbolVar VirtualGrad::make(SymbolVar target, SymbolVar wrt, Param, const OperatorNodeConfig &config) { return target.insert_single_output_opr( target.node(), wrt.node(), config); } void VirtualGrad::do_execute(ExecEnv &) { mgb_throw(MegBrainError, "VirtualGrad opr must be removed by " "gopt::ExpandVirtualGradPass"); } void VirtualGrad::init_output_comp_node() { output(0)->comp_node(input(1)->comp_node()); } void VirtualGrad::init_output_static_infer_desc() { using namespace cg::static_infer; auto &&mgr = owner_graph()->static_infer_manager(); auto ovar = output(0), ivar = input(1); mgr.register_shape_infer(ovar, ShapeInferDesc::make_identity(ivar)); } void VirtualGrad::on_output_comp_node_stream_changed() { } VirtualGrad::NodeProp* VirtualGrad::do_make_node_prop() const { auto ret = Super::do_make_node_prop(); ret->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY); return ret; } /* ===================== VirtualLoss ===================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualLoss); VirtualLoss::VirtualLoss(const VarNodeArray& inputs, const OperatorNodeConfig& config) : Super(inputs.at(0)->owner_graph(), config, "internal_grad", {inputs.at(0)}) { mgb_assert(inputs.size() % 2 == 0); for (size_t i = 0, it = inputs.size() / 2; i < it; ++i) { auto yi = inputs[i], gradi = inputs[i + it]; mgb_assert(yi && gradi); auto&& shp0 = yi->shape(); auto&& shp1 = gradi->shape(); mgb_assert((!shp0.ndim && !shp1.ndim) || shp0.eq_shape(shp1), "grad shape mismatch: %s vs %s", shp0.to_string().c_str(), shp1.to_string().c_str()); mgb_assert(yi->comp_node() == gradi->comp_node()); add_input({yi}); } for (size_t i = inputs.size() / 2; i < inputs.size(); ++i) { add_input({inputs[i]}); } add_output(None)->dtype(dtype::Float32{}); } SymbolVar VirtualLoss::make(const SymbolVarArray& ys, const SymbolVarArray& y_grads, Param, const OperatorNodeConfig& config) { mgb_assert(ys.size() == y_grads.size() && !ys.empty()); VarNodeArray inputs = to_var_node_array(ys); // sort for better dedup auto cmp = [](VarNode* a, VarNode* b) { return a->id() < b->id(); }; std::sort(inputs.begin(), inputs.end(), cmp); ThinHashMap var2grad; for (size_t i = 0; i < inputs.size(); ++i) { var2grad[ys[i].node()] = y_grads[i].node(); } inputs.resize(inputs.size() * 2); for (size_t i = 0, it = inputs.size() / 2; i < it; ++i) { inputs[i + it] = var2grad.at(inputs[i]); } return ys[0].insert_single_output_opr(inputs, config); } void VirtualLoss::do_execute(ExecEnv&) { mgb_throw_if( #if MGB_BUILD_SLIM_SERVING true, #else !owner_graph()->options().eager_evaluation, #endif MegBrainError, "InternalGradLoss should never be executed"); } void VirtualLoss::init_output_comp_node() { output(0)->comp_node(input(0)->comp_node()); } void VirtualLoss::init_output_static_infer_desc() { using namespace cg::static_infer; auto&& mgr = owner_graph()->static_infer_manager(); mgr.register_shape_infer(output(0), ShapeInferDesc::make_const({1})); } void VirtualLoss::on_output_comp_node_stream_changed() {} VirtualLoss::NodeProp* VirtualLoss::do_make_node_prop() const { auto ret = Super::do_make_node_prop(); ret->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY); return ret; } #ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(VirtualLoss) { mgb_assert(out_grad.size() == 1); auto mid = opr.input().size() / 2; if (wrt_idx < mid) { return opr.input(wrt_idx + mid); } return nullptr; } #endif #else VarNode* InvalidGrad::make(const OperatorNodeBase&, size_t) { mgb_throw(MegBrainError, "grad disabled at compile time"); } #endif // MGB_ENABLE_GRAD /* ================== PersistentOutputStorage =================== */ class PersistentOutputStorage::StorageHolder final : public UserDataContainer::UserData { MGB_TYPEINFO_OBJ_DECL; using Key = std::pair; struct KeyHash { size_t operator()(const Key& key) const { return hash_pair_combine(HashTrait::eval(key.first), key.second); } }; std::mutex m_mtx; std::unordered_map m_storage; public: void set_tensor(DeviceTensorND& dst, int key, CompNode comp_node, const TensorLayout& layout) { MGB_LOCK_GUARD(m_mtx); DeviceTensorStorage* storage; Maybe local_storage; if (key == -1) { storage = &local_storage.emplace(dst.storage()); } else { storage = &m_storage[{comp_node, key}]; } if (!storage->comp_node_valid()) { storage->comp_node(comp_node); } auto s = layout.span().dist_byte(); if (s > storage->size()) { if (storage->size()) { // exponential growth if size gets increased s = s * 3 / 2; } storage->ensure_size(s); } dst.reset(*storage, layout); } }; MGB_DYN_TYPE_OBJ_FINAL_IMPL(PersistentOutputStorage); MGB_TYPEINFO_OBJ_IMPL(PersistentOutputStorage::StorageHolder); class PersistentOutputStorage::DevValueExecDep final : public ExecDependency { DeviceTensorStorage m_val; public: explicit DevValueExecDep(DeviceTensorStorage val) : m_val{std::move(val)} {} }; PersistentOutputStorage::PersistentOutputStorage( VarNode* inp, const Param& param, const OperatorNodeConfig& config) : Super{inp->owner_graph(), config, "persist", {}}, m_param{param} { add_input({inp}); add_output(None) ->add_flag(VarNode::Flag::NO_MEM_RECLAIM) .add_flag(VarNode::Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC); } SymbolVar PersistentOutputStorage::make(SymbolVar inp, const Param& param, const OperatorNodeConfig& config) { return inp.insert_single_output_opr(inp.node(), param, config); } void PersistentOutputStorage::record_execute_deps(ExecDependencyArray& deps) { mgb_assert(!m_dev_tensor.empty()); deps.emplace_back( std::make_unique(m_dev_tensor.storage())); } void PersistentOutputStorage::scn_do_execute() { auto &&od = output(0)->dev_tensor(), &&id = input(0)->dev_tensor(); mgb_assert(od.raw_ptr() == m_dev_tensor.raw_ptr()); od.copy_from_fixlayout(id); } void PersistentOutputStorage::init_output_mem_plan(bool dynamic) { mgb_throw_if( dynamic, GraphError, "PersistentOutputStorage can not be used in dynamic storage case"); auto cn = comp_node(); auto ovar = output(0); mgb_assert(cg::is_static_var_storage(ovar)); // note that this method is called after static shape infer, so it is safe // to access var shapes here auto&& shape = ovar->shape(); if (!m_dev_tensor.shape().eq_shape(shape) || m_dev_tensor.comp_node() != cn) { TensorLayout layout{shape, ovar->dtype(), ovar->format()}; auto holder = owner_graph() ->options() .user_data.get_user_data_or_create(); holder->set_tensor(m_dev_tensor, m_param.share_key, cn, layout); } ovar->init_mem_plan(&m_dev_tensor); } /* ================ RequireInputDynamicStorage ================== */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RequireInputDynamicStorage); RequireInputDynamicStorage::RequireInputDynamicStorage( VarNode* input, const OperatorNodeConfig& config) : Super{input->owner_graph(), config, "require_input_dynamic_storage", {input}} { input->add_flag(VarNode::Flag::NO_SYS_STATIC_MEM_ALLOC); add_input({input}); add_output(None); } SymbolVar RequireInputDynamicStorage::make(const SymbolVar input, const OperatorNodeConfig& config) { return input.insert_single_output_opr( input.node(), config); } // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}