Browse Source

fix(midout): fix brain opr midout 2/2 (also see a6aa1574)

fix extern_c_opr midout

GitOrigin-RevId: 7de4f650d1
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
dbe2b89331
7 changed files with 23 additions and 11 deletions
  1. +1
    -1
      python_module/src/cpp/opr_defs.cpp
  2. +1
    -0
      src/serialization/impl/extern_c_opr.cpp
  3. +4
    -3
      src/serialization/impl/extern_c_opr.sereg.h
  4. +5
    -4
      src/serialization/include/megbrain/serialization/extern_c_opr_io.h
  5. +9
    -0
      src/serialization/include/megbrain/serialization/sereg.h
  6. +2
    -2
      src/serialization/test/extern_c_opr.cpp
  7. +1
    -1
      src/serialization/test/extern_c_opr_v23.cpp

+ 1
- 1
python_module/src/cpp/opr_defs.cpp View File

@@ -224,7 +224,7 @@ SymbolVarArray _Opr::extern_c_opr_placeholder(
} }
} }


auto opr = serialization::ExternCOprRunner::make_placeholder(
auto opr = opr::ExternCOprRunner::make_placeholder(
inputs, cpp_output_shapes, dump_name, PyBytes_AsString(data_bytes), inputs, cpp_output_shapes, dump_name, PyBytes_AsString(data_bytes),
PyBytes_Size(data_bytes), config, cpp_output_dtypes); PyBytes_Size(data_bytes), config, cpp_output_dtypes);
SymbolVarArray ret; SymbolVarArray ret;


+ 1
- 0
src/serialization/impl/extern_c_opr.cpp View File

@@ -18,6 +18,7 @@


using namespace mgb; using namespace mgb;
using namespace serialization; using namespace serialization;
using namespace opr;


namespace { namespace {




+ 4
- 3
src/serialization/impl/extern_c_opr.sereg.h View File

@@ -16,18 +16,19 @@ namespace mgb {


namespace serialization { namespace serialization {
template <> template <>
struct OprLoadDumpImpl<ExternCOprRunner, 0> {
struct OprLoadDumpImpl<opr::ExternCOprRunner, 0> {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
ExternCOprRunner::dump(ctx, opr);
opr::ExternCOprRunner::dump(ctx, opr);
} }


static cg::OperatorNodeBase* load(OprLoadContext& ctx, static cg::OperatorNodeBase* load(OprLoadContext& ctx,
const cg::VarNodeArray& inputs, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) { const OperatorNodeConfig& config) {
return ExternCOprRunner::load(ctx, inputs, config);
return opr::ExternCOprRunner::load(ctx, inputs, config);
} }
}; };


using ExternCOprRunner = opr::ExternCOprRunner;
MGB_SEREG_OPR(ExternCOprRunner, 0); MGB_SEREG_OPR(ExternCOprRunner, 0);
MGB_REG_OPR_SHALLOW_COPY(ExternCOprRunner, ExternCOprRunner::shallow_copy); MGB_REG_OPR_SHALLOW_COPY(ExternCOprRunner, ExternCOprRunner::shallow_copy);
} // namespace serialization } // namespace serialization


+ 5
- 4
src/serialization/include/megbrain/serialization/extern_c_opr_io.h View File

@@ -16,7 +16,7 @@
#include "megbrain/serialization/opr_registry.h" #include "megbrain/serialization/opr_registry.h"


namespace mgb { namespace mgb {
namespace serialization {
namespace opr {


//! an operator to run extern C oprs //! an operator to run extern C oprs
MGB_DEFINE_OPR_CLASS(ExternCOprRunner, MGB_DEFINE_OPR_CLASS(ExternCOprRunner,
@@ -68,10 +68,11 @@ public:
static bool unregister_loader(const char* name); static bool unregister_loader(const char* name);


//! impl for serialization dump //! impl for serialization dump
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr);
static void dump(serialization::OprDumpContext& ctx,
const cg::OperatorNodeBase& opr);


//! impl for serialization load //! impl for serialization load
static cg::OperatorNodeBase* load(OprLoadContext& ctx,
static cg::OperatorNodeBase* load(serialization::OprLoadContext& ctx,
const cg::VarNodeArray& inputs, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config); const OperatorNodeConfig& config);


@@ -88,7 +89,7 @@ public:
static TensorShape tensor_shape_from_c(const MGBTensorShape& shape); static TensorShape tensor_shape_from_c(const MGBTensorShape& shape);
}; };


} // namespace serialization
} // namespace opr
} // namespace mgb } // namespace mgb


// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 9
- 0
src/serialization/include/megbrain/serialization/sereg.h View File

@@ -179,9 +179,18 @@ namespace { \
} \ } \
MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls)


//! Compile-time check of whether \p T is a complete type (midout needs a
//! complete type). The primary template is selected when sizeof(T) is not a
//! valid expression and reports false.
template <typename T, typename Enable = void>
struct IsComplete : std::integral_constant<bool, false> {};

//! Selected only when sizeof(T) is well-formed, i.e. \p T is complete at the
//! point of instantiation; reports true.
template <typename T>
struct IsComplete<T, decltype(static_cast<void>(sizeof(T)))>
        : std::integral_constant<bool, true> {};

//! call OprRegistry::add with only loader, used for backward compatibility //! call OprRegistry::add with only loader, used for backward compatibility
#define MGB_SEREG_OPR_COMPAT(_name, _load) \ #define MGB_SEREG_OPR_COMPAT(_name, _load) \
namespace { \ namespace { \
static_assert(IsComplete<_name>(), \
"need a complete type for MGB_SEREG_OPR_COMPAT"); \
struct _OprReg##_name { \ struct _OprReg##_name { \
static cg::OperatorNodeBase* compat_loader( \ static cg::OperatorNodeBase* compat_loader( \
serialization::OprLoadContext& ctx, \ serialization::OprLoadContext& ctx, \


+ 2
- 2
src/serialization/test/extern_c_opr.cpp View File

@@ -182,7 +182,7 @@ std::vector<uint8_t> create_graph_dump(float bias, float extra_scale,
auto x = opr::Host2DeviceCopy::make(*graph, host_x); auto x = opr::Host2DeviceCopy::make(*graph, host_x);
if (sleep) if (sleep)
x = opr::Sleep::make(x, sleep); x = opr::Sleep::make(x, sleep);
x = serialization::ExternCOprRunner::make_placeholder(
x = opr::ExternCOprRunner::make_placeholder(
{x}, {TensorShape{1}}, {x}, {TensorShape{1}},
dtype == MGB_DTYPE_FLOAT32 dtype == MGB_DTYPE_FLOAT32
? "bias_adder_dump" ? "bias_adder_dump"
@@ -280,7 +280,7 @@ TEST(TestExternCOpr, Dedup) {
auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x); auto x = opr::Host2DeviceCopy::make(*graph, host_x);
auto make_opr = [x](float bias) { auto make_opr = [x](float bias) {
return ExternCOprRunner::make_from_desc(
return opr::ExternCOprRunner::make_from_desc(
{x.node()}, MGBOprDescImpl<>::make(bias)); {x.node()}, MGBOprDescImpl<>::make(bias));
}; };
auto y0 = make_opr(0.5), y1 = make_opr(0.6), y2 = make_opr(0.5); auto y0 = make_opr(0.5), y1 = make_opr(0.6), y2 = make_opr(0.5);


+ 1
- 1
src/serialization/test/extern_c_opr_v23.cpp View File

@@ -115,7 +115,7 @@ std::vector<uint8_t> create_graph_dump(float bias, float extra_scale,
auto x = opr::Host2DeviceCopy::make(*graph, host_x); auto x = opr::Host2DeviceCopy::make(*graph, host_x);
if (sleep) if (sleep)
x = opr::Sleep::make(x, sleep); x = opr::Sleep::make(x, sleep);
x = serialization::ExternCOprRunner::make_placeholder(
x = opr::ExternCOprRunner::make_placeholder(
{x}, {TensorShape{1}}, "bias_adder_dump_v23", &bias, sizeof(bias)) {x}, {TensorShape{1}}, "bias_adder_dump_v23", &bias, sizeof(bias))
->output(0); ->output(0);
if (extra_scale) if (extra_scale)


Loading…
Cancel
Save