fix extern_c_opr midout
GitOrigin-RevId: 7de4f650d1
tags/v0.3.2
@@ -224,7 +224,7 @@ SymbolVarArray _Opr::extern_c_opr_placeholder( | |||||
} | } | ||||
} | } | ||||
auto opr = serialization::ExternCOprRunner::make_placeholder( | |||||
auto opr = opr::ExternCOprRunner::make_placeholder( | |||||
inputs, cpp_output_shapes, dump_name, PyBytes_AsString(data_bytes), | inputs, cpp_output_shapes, dump_name, PyBytes_AsString(data_bytes), | ||||
PyBytes_Size(data_bytes), config, cpp_output_dtypes); | PyBytes_Size(data_bytes), config, cpp_output_dtypes); | ||||
SymbolVarArray ret; | SymbolVarArray ret; | ||||
@@ -18,6 +18,7 @@ | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace serialization; | using namespace serialization; | ||||
using namespace opr; | |||||
namespace { | namespace { | ||||
@@ -16,18 +16,19 @@ namespace mgb { | |||||
namespace serialization { | namespace serialization { | ||||
template <> | template <> | ||||
struct OprLoadDumpImpl<ExternCOprRunner, 0> { | |||||
struct OprLoadDumpImpl<opr::ExternCOprRunner, 0> { | |||||
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { | static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { | ||||
ExternCOprRunner::dump(ctx, opr); | |||||
opr::ExternCOprRunner::dump(ctx, opr); | |||||
} | } | ||||
static cg::OperatorNodeBase* load(OprLoadContext& ctx, | static cg::OperatorNodeBase* load(OprLoadContext& ctx, | ||||
const cg::VarNodeArray& inputs, | const cg::VarNodeArray& inputs, | ||||
const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
return ExternCOprRunner::load(ctx, inputs, config); | |||||
return opr::ExternCOprRunner::load(ctx, inputs, config); | |||||
} | } | ||||
}; | }; | ||||
using ExternCOprRunner = opr::ExternCOprRunner; | |||||
MGB_SEREG_OPR(ExternCOprRunner, 0); | MGB_SEREG_OPR(ExternCOprRunner, 0); | ||||
MGB_REG_OPR_SHALLOW_COPY(ExternCOprRunner, ExternCOprRunner::shallow_copy); | MGB_REG_OPR_SHALLOW_COPY(ExternCOprRunner, ExternCOprRunner::shallow_copy); | ||||
} // namespace serialization | } // namespace serialization | ||||
@@ -16,7 +16,7 @@ | |||||
#include "megbrain/serialization/opr_registry.h" | #include "megbrain/serialization/opr_registry.h" | ||||
namespace mgb { | namespace mgb { | ||||
namespace serialization { | |||||
namespace opr { | |||||
//! an operator to run extern C oprs | //! an operator to run extern C oprs | ||||
MGB_DEFINE_OPR_CLASS(ExternCOprRunner, | MGB_DEFINE_OPR_CLASS(ExternCOprRunner, | ||||
@@ -68,10 +68,11 @@ public: | |||||
static bool unregister_loader(const char* name); | static bool unregister_loader(const char* name); | ||||
//! impl for serialization dump | //! impl for serialization dump | ||||
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr); | |||||
static void dump(serialization::OprDumpContext& ctx, | |||||
const cg::OperatorNodeBase& opr); | |||||
//! impl for serialization load | //! impl for serialization load | ||||
static cg::OperatorNodeBase* load(OprLoadContext& ctx, | |||||
static cg::OperatorNodeBase* load(serialization::OprLoadContext& ctx, | |||||
const cg::VarNodeArray& inputs, | const cg::VarNodeArray& inputs, | ||||
const OperatorNodeConfig& config); | const OperatorNodeConfig& config); | ||||
@@ -88,7 +89,7 @@ public: | |||||
static TensorShape tensor_shape_from_c(const MGBTensorShape& shape); | static TensorShape tensor_shape_from_c(const MGBTensorShape& shape); | ||||
}; | }; | ||||
} // namespace serialization | |||||
} // namespace opr | |||||
} // namespace mgb | } // namespace mgb | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -179,9 +179,18 @@ namespace { \ | |||||
} \ | } \ | ||||
MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) | MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) | ||||
//! use to check type is complete or not, midout need a complete type | |||||
template <class T, class = void> | |||||
struct IsComplete : std::false_type {}; | |||||
template <class T> | |||||
struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {}; | |||||
//! call OprRegistry::add with only loader, used for backward compatibility | //! call OprRegistry::add with only loader, used for backward compatibility | ||||
#define MGB_SEREG_OPR_COMPAT(_name, _load) \ | #define MGB_SEREG_OPR_COMPAT(_name, _load) \ | ||||
namespace { \ | namespace { \ | ||||
static_assert(IsComplete<_name>(), \ | |||||
"need a complete type for MGB_SEREG_OPR_COMPAT"); \ | |||||
struct _OprReg##_name { \ | struct _OprReg##_name { \ | ||||
static cg::OperatorNodeBase* compat_loader( \ | static cg::OperatorNodeBase* compat_loader( \ | ||||
serialization::OprLoadContext& ctx, \ | serialization::OprLoadContext& ctx, \ | ||||
@@ -182,7 +182,7 @@ std::vector<uint8_t> create_graph_dump(float bias, float extra_scale, | |||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x); | auto x = opr::Host2DeviceCopy::make(*graph, host_x); | ||||
if (sleep) | if (sleep) | ||||
x = opr::Sleep::make(x, sleep); | x = opr::Sleep::make(x, sleep); | ||||
x = serialization::ExternCOprRunner::make_placeholder( | |||||
x = opr::ExternCOprRunner::make_placeholder( | |||||
{x}, {TensorShape{1}}, | {x}, {TensorShape{1}}, | ||||
dtype == MGB_DTYPE_FLOAT32 | dtype == MGB_DTYPE_FLOAT32 | ||||
? "bias_adder_dump" | ? "bias_adder_dump" | ||||
@@ -280,7 +280,7 @@ TEST(TestExternCOpr, Dedup) { | |||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x); | auto x = opr::Host2DeviceCopy::make(*graph, host_x); | ||||
auto make_opr = [x](float bias) { | auto make_opr = [x](float bias) { | ||||
return ExternCOprRunner::make_from_desc( | |||||
return opr::ExternCOprRunner::make_from_desc( | |||||
{x.node()}, MGBOprDescImpl<>::make(bias)); | {x.node()}, MGBOprDescImpl<>::make(bias)); | ||||
}; | }; | ||||
auto y0 = make_opr(0.5), y1 = make_opr(0.6), y2 = make_opr(0.5); | auto y0 = make_opr(0.5), y1 = make_opr(0.6), y2 = make_opr(0.5); | ||||
@@ -115,7 +115,7 @@ std::vector<uint8_t> create_graph_dump(float bias, float extra_scale, | |||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x); | auto x = opr::Host2DeviceCopy::make(*graph, host_x); | ||||
if (sleep) | if (sleep) | ||||
x = opr::Sleep::make(x, sleep); | x = opr::Sleep::make(x, sleep); | ||||
x = serialization::ExternCOprRunner::make_placeholder( | |||||
x = opr::ExternCOprRunner::make_placeholder( | |||||
{x}, {TensorShape{1}}, "bias_adder_dump_v23", &bias, sizeof(bias)) | {x}, {TensorShape{1}}, "bias_adder_dump_v23", &bias, sizeof(bias)) | ||||
->output(0); | ->output(0); | ||||
if (extra_scale) | if (extra_scale) | ||||