/**
 * \file src/opr/impl/utility.sereg.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/utility.h"
#include "megbrain/serialization/sereg.h"

#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/serialization/internal/mgb_cpp_opr_generated.h"
#endif

namespace mgb {
namespace serialization {

/**
 * Custom dump/load for opr::AssertEqual (registered below with arity 0).
 *
 * Only the operator param is written on dump. On load the operator is
 * rebuilt from either two or three inputs; any other input count trips
 * mgb_assert.
 */
template <>
struct OprLoadDumpImpl<opr::AssertEqual, 0> {
    //! Write the param of \p opr_, which must be an opr::AssertEqual.
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<opr::AssertEqual>();
        ctx.write_param(opr.param());
    }

    //! Read the param back and recreate the operator on \p inputs.
    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        auto param = ctx.read_param<opr::AssertEqual::Param>();
        SymbolVar out;
        if (inputs.size() == 2) {
            // from python
            out = opr::AssertEqual::make(inputs[0], inputs[1], param, config);
        } else {
            // from sereg or copy
            mgb_assert(inputs.size() == 3);
            out = opr::AssertEqual::make(
                    inputs[0], inputs[1], inputs[2], param, config);
        }
        return out.node()->owner_opr();
    }
};

#if !MGB_BUILD_SLIM_SERVING
/**
 * Custom dump/load for opr::VirtualDep (registered below with arity 0).
 *
 * Nothing is written on dump; load rebuilds the operator from all of its
 * inputs and the given config.
 */
template <>
struct OprLoadDumpImpl<opr::VirtualDep, 0> {
    //! Intentionally empty: no data is dumped for this operator.
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {}

    //! Recreate the operator over every var in \p inputs.
    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        return opr::VirtualDep::make(to_symbol_var_array(inputs), config)
                .node()
                ->owner_opr();
    }
};

#if MGB_ENABLE_FBS_SERIALIZATION
namespace fbs {

/**
 * Conversion between opr::Sleep::Param and its flatbuffer form
 * param::MGBSleep.
 */
template <>
struct ParamConverter<opr::Sleep::Param> {
    using FlatBufferType = param::MGBSleep;

    //! Build a Sleep param from the seconds/device/host fields of \p fb.
    static opr::Sleep::Param to_param(const param::MGBSleep* fb) {
        return {fb->seconds(), {fb->device(), fb->host()}};
    }

    //! Serialize \p p into \p builder; note the argument order of
    //! CreateMGBSleep is (device, host, seconds).
    static flatbuffers::Offset<param::MGBSleep> to_flatbuffer(
            flatbuffers::FlatBufferBuilder& builder,
            const opr::Sleep::Param& p) {
        return param::CreateMGBSleep(
                builder, p.type.device, p.type.host, p.seconds);
    }
};

}  // namespace fbs
#endif  // MGB_ENABLE_FBS_SERIALIZATION
#endif  // !MGB_BUILD_SLIM_SERVING

}  // namespace serialization

namespace opr {

MGB_SEREG_OPR(MarkDynamicVar, 1);
MGB_SEREG_OPR(MarkNoBroadcastElemwise, 1);
MGB_SEREG_OPR(Identity, 1);
MGB_SEREG_OPR(AssertEqual, 0);

#if MGB_ENABLE_GRAD
MGB_SEREG_OPR(VirtualGrad, 2);

//! Shallow copy of SetGrad: rebuild on the single new input, reusing the
//! grad getter of the source operator.
cg::OperatorNodeBase* opr_shallow_copy_set_grad(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    mgb_assert(inputs.size() == 1);
    auto&& opr = opr_.cast_final_safe<SetGrad>();
    return SetGrad::make(inputs[0], opr.grad_getter(), config)
            .node()
            ->owner_opr();
}
MGB_REG_OPR_SHALLOW_COPY(SetGrad, opr_shallow_copy_set_grad);

//! Shallow copy of VirtualLoss: insert a new VirtualLoss built from
//! \p inputs and \p config into the graph that owns inputs[0].
cg::OperatorNodeBase* opr_shallow_copy_virtual_loss(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    return inputs[0]->owner_graph()->insert_opr(
            std::make_unique<VirtualLoss>(inputs, config));
}
MGB_REG_OPR_SHALLOW_COPY(VirtualLoss, opr_shallow_copy_virtual_loss);

//! Shallow copy of InvalidGrad: insert a new InvalidGrad on the single new
//! input, carrying over grad_opr() and inp_idx() from the source operator.
//! \p config is not forwarded to the new operator.
cg::OperatorNodeBase* opr_shallow_copy_invalid_grad(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    mgb_assert(inputs.size() == 1);
    auto&& opr = opr_.cast_final_safe<InvalidGrad>();
    return inputs[0]->owner_opr()->owner_graph()->insert_opr(
            std::make_unique<InvalidGrad>(
                    inputs[0], opr.grad_opr(), opr.inp_idx()));
}
MGB_REG_OPR_SHALLOW_COPY(InvalidGrad, opr_shallow_copy_invalid_grad);
#endif  // MGB_ENABLE_GRAD

//! Shallow copy of CallbackInjector: rebuild over all new inputs with the
//! param of the source operator.
cg::OperatorNodeBase* opr_shallow_copy_callback_injector(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto&& opr = opr_.cast_final_safe<CallbackInjector>();
    return CallbackInjector::make(
                   cg::to_symbol_var_array(inputs), opr.param(), config)
            .node()
            ->owner_opr();
}
MGB_REG_OPR_SHALLOW_COPY(CallbackInjector, opr_shallow_copy_callback_injector);

//! Shallow copy of RequireInputDynamicStorage: rebuild on the single new
//! input; nothing is read from the source operator.
cg::OperatorNodeBase* opr_shallow_copy_require_input_dynamic_storage(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    mgb_assert(inputs.size() == 1);
    return RequireInputDynamicStorage::make(inputs[0], config)
            .node()
            ->owner_opr();
}
MGB_REG_OPR_SHALLOW_COPY(
        RequireInputDynamicStorage,
        opr_shallow_copy_require_input_dynamic_storage);

#if !MGB_BUILD_SLIM_SERVING
MGB_SEREG_OPR(Sleep, 1);
MGB_SEREG_OPR(VirtualDep, 0);
#endif

MGB_SEREG_OPR(PersistentOutputStorage, 1);

//! Shallow copy of ShapeHint: rebuild on the single new input with the
//! shape() and is_const() values of the source operator.
cg::OperatorNodeBase* opr_shallow_copy_shape_hint(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto&& opr = opr_.cast_final_safe<ShapeHint>();
    mgb_assert(inputs.size() == 1);
    return ShapeHint::make(inputs[0], opr.shape(), opr.is_const(), config)
            .node()
            ->owner_opr();
}
MGB_REG_OPR_SHALLOW_COPY(ShapeHint, opr_shallow_copy_shape_hint);

}  // namespace opr
}  // namespace mgb

// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}