#include "megbrain/opr/misc.h" #include "megbrain/serialization/sereg.h" namespace mgb { namespace serialization { template <> struct OprMaker { using Opr = opr::Argsort; using Param = Opr::Param; static cg::OperatorNodeBase* make( const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph, const OperatorNodeConfig& config) { MGB_MARK_USED_VAR(graph); auto out = Opr::make(inputs[0], param, config); return out[0].node()->owner_opr(); } }; template <> struct OprMaker { using Opr = opr::CondTake; using Param = Opr::Param; static cg::OperatorNodeBase* make( const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph, const OperatorNodeConfig& config) { MGB_MARK_USED_VAR(graph); auto out = Opr::make(inputs[0], inputs[1], param, config); return out[0].node()->owner_opr(); } }; template <> struct OprMaker { using Opr = opr::TopK; using Param = Opr::Param; static cg::OperatorNodeBase* make( const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph, const OperatorNodeConfig& config) { MGB_MARK_USED_VAR(graph); auto out = Opr::make(inputs[0], inputs[1], param, config); return out[0].node()->owner_opr(); } }; template <> struct OprMaker { using Opr = opr::CheckNonFinite; using Param = Opr::Param; static cg::OperatorNodeBase* make( const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph, const OperatorNodeConfig& config) { MGB_MARK_USED_VAR(graph); auto out = Opr::make(inputs, param, config); return out[0].node()->owner_opr(); } }; } // namespace serialization namespace opr { MGB_SEREG_OPR(Argmax, 1); MGB_SEREG_OPR(Argmin, 1); MGB_SEREG_OPR(Argsort, 1); MGB_SEREG_OPR(ArgsortBackward, 3); MGB_SEREG_OPR(CondTake, 2); MGB_SEREG_OPR(TopK, 2); //! current cumsum version using CumsumV1 = opr::Cumsum; MGB_SEREG_OPR(CumsumV1, 1); MGB_SEREG_OPR(Cumprod, 1); #if MGB_CUDA MGB_SEREG_OPR(NvOf, 1); #endif MGB_SEREG_OPR(CheckNonFinite, 0); } // namespace opr } // namespace mgb // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}