Browse Source

feat(mgb/serialization): add Accessor for OprLoader to fix BN output compatibility

GitOrigin-RevId: 3b95da02c8
release-1.7
Megvii Engine Team 3 years ago
parent
commit
270f1aa28f
9 changed files with 98 additions and 51 deletions
  1. +2
    -2
      imperative/src/impl/ops/opr_attr.cpp
  2. +2
    -2
      imperative/src/test/backward_graph.cpp
  3. +1
    -1
      imperative/src/test/imperative.cpp
  4. +8
    -6
      src/opr/impl/dnn/dnn.sereg.h
  5. +17
    -1
      src/serialization/impl/opr_registry.cpp
  6. +1
    -1
      src/serialization/impl/opr_shallow_copy.cpp
  7. +3
    -2
      src/serialization/impl/serializer_oss.cpp
  8. +21
    -1
      src/serialization/include/megbrain/serialization/opr_registry.h
  9. +43
    -35
      src/serialization/include/megbrain/serialization/sereg.h

+ 2
- 2
imperative/src/impl/ops/opr_attr.cpp View File

@@ -76,7 +76,7 @@ public:
}
};

cg::OperatorNodeBase* apply_on_var_node(
VarNodeArray apply_on_var_node(
const OpDef& def, const VarNodeArray& inputs) {
auto&& attr = def.cast_final_safe<OprAttr>();
auto config = attr.config;
@@ -85,7 +85,7 @@ cg::OperatorNodeBase* apply_on_var_node(
auto registry = serialization::OprRegistry::find_by_name(attr.type);
mgb_assert(registry, "operator %s not found", attr.type.c_str());
OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
return registry->loader(ctx, inputs, config);
return registry->loader(ctx, inputs, config).usable_output();
}

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {


+ 2
- 2
imperative/src/test/backward_graph.cpp View File

@@ -200,7 +200,7 @@ TEST(TestImperative, BatchNormGrad) {
LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn};
LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn};
{
auto op = OprAttr::make("BatchNorm");
auto op = OprAttr::make("BatchNormV1");
auto&& attr = op->cast_final_safe<OprAttr>();
Param param;
param.fwd_mode = Param::FwdMode::TRAINING;
@@ -210,7 +210,7 @@ TEST(TestImperative, BatchNormGrad) {
{false, false, false, false, false, true});
}
{
auto op = OprAttr::make("BatchNorm");
auto op = OprAttr::make("BatchNormV1");
auto&& attr = op->cast_final_safe<OprAttr>();
Param param;
param.fwd_mode = Param::FwdMode::TRAINING;


+ 1
- 1
imperative/src/test/imperative.cpp View File

@@ -59,7 +59,7 @@ TEST(TestImperative, Reduce) {
}

TEST(TestImperative, BatchNorm) {
auto op = OprAttr::make("BatchNorm");
auto op = OprAttr::make("BatchNormV1");
auto&& attr = op->cast_final_safe<OprAttr>();
using Param = opr::BatchNorm::Param;
Param param;


+ 8
- 6
src/opr/impl/dnn/dnn.sereg.h View File

@@ -16,14 +16,13 @@
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
@@ -390,6 +389,7 @@ struct OprMaker<opr::BatchNorm, 0> {
}
};

// OprMaker in MGB_SEREG_OPR only support unique output opr
template <>
struct OprMaker<opr::BatchNormBackward, 6> {
using Param = opr::BatchNormBackward::Param;
@@ -398,8 +398,8 @@ struct OprMaker<opr::BatchNormBackward, 6> {
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], param,
config)[0]
return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5],
param, config)[0]
.node()
->owner_opr();
}
@@ -575,8 +575,10 @@ MGB_SEREG_OPR(Convolution3DBackwardFilter, 0);
using ConvBiasForwardV4 = ConvBiasForward;
MGB_SEREG_OPR(ConvBiasForwardV4, 0);

MGB_SEREG_OPR(BatchNorm, 0);
MGB_SEREG_OPR(BatchNormBackward, 6);
using BatchNormV1 = BatchNorm;
using BatchNormBackwardV1 = BatchNormBackward;
MGB_SEREG_OPR(BatchNormV1, 0);
MGB_SEREG_OPR(BatchNormBackwardV1, 6);

using LocalShareForwardV1 = LocalShareForward;
using LocalShareBackwardDataV1 = LocalShareBackwardData;


+ 17
- 1
src/serialization/impl/opr_registry.cpp View File

@@ -39,7 +39,7 @@ namespace {
return inst;
}

cg::OperatorNodeBase* dynamic_loader(
OprWithOutputAccessor dynamic_loader(
OprLoadContext &ctx, const cg::VarNodeArray &inputs,
const OperatorNodeConfig &config) {
auto name = ctx.load_buf_with_len();
@@ -171,4 +171,20 @@ std::vector<std::pair<size_t, std::string>> OprRegistry::dump_registries() {
}
#endif

namespace {
const VarNodeArray& default_accessor(const VarNodeArray& outputs) {
return outputs;
}
}

OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr) : m_opr(opr){
m_accessor = &default_accessor;
};
OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor)
: OprWithOutputAccessor(opr) {
if (accessor) {
m_accessor = accessor;
}
};

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 1
- 1
src/serialization/impl/opr_shallow_copy.cpp View File

@@ -207,7 +207,7 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
registry->dumper(dumper, opr);

OprLoadContextMemory loader{opr.owner_graph(), dumper};
return registry->loader(loader, inputs, config);
return registry->loader(loader, inputs, config).opr();
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 3
- 2
src/serialization/impl/serializer_oss.cpp View File

@@ -782,7 +782,8 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
}

// call loader
auto opr = registry->loader(*this, inputs, config);
auto accessor = registry->loader(*this, inputs, config);
auto opr = accessor.opr();

// check opr type; note that:
// 1. registry->type may be empty for dynamic opr loaders or legacy oprs
@@ -794,7 +795,7 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
opr ? opr->dyn_typeinfo()->name : nullptr, registry->type->name);
// record output vars; read output names
size_t i = 0;
for (auto ovar : opr->output()) {
for (auto ovar : accessor.output()) {
if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
m_id2varnode.push_back(ovar);
if (fbopr->output_name()) {


+ 21
- 1
src/serialization/include/megbrain/serialization/opr_registry.h View File

@@ -19,16 +19,36 @@ namespace serialization {
class OprDumpContext;
class OprLoadContext;
class OprShallowCopyContext;
class OprWithOutputAccessor {
cg::OperatorNodeBase* m_opr;
using Accessor = thin_function<const VarNodeArray(const VarNodeArray&)>;
Accessor m_accessor;

public:
OprWithOutputAccessor(cg::OperatorNodeBase* opr);
OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor);
VarNode* output(size_t idx) const { return output().at(idx); }
VarNodeArray output() const { return m_accessor(m_opr->output()); }
VarNodeArray usable_output() const { return m_accessor(m_opr->usable_output()); }
cg::OperatorNodeBase* opr() { return m_opr; }
};


//! dump opr internal params to OprDumpContext
using OprDumper = thin_function<void(
OprDumpContext &ctx, const cg::OperatorNodeBase &opr)>;

//! load and restore operator from OprLoadContext
//! is also used by GraphLoadConfig.
using OprLoader = thin_function<cg::OperatorNodeBase*(
OprLoadContext &ctx, const cg::VarNodeArray &inputs,
const OperatorNodeConfig &config)>;

//! loader that can change opr output map for compatibility
using OprLoaderWrapper = thin_function<OprWithOutputAccessor(
OprLoadContext &ctx, const cg::VarNodeArray &inputs,
const OperatorNodeConfig &config)>;

//! shallow copy function for a single operator
using OprShallowCopy = thin_function<cg::OperatorNodeBase*(
const OprShallowCopyContext &ctx,
@@ -41,7 +61,7 @@ namespace serialization {
uint64_t persist_type_id;
std::string name;
OprDumper dumper;
OprLoader loader;
OprLoaderWrapper loader;
OprShallowCopy shallow_copy; //!< set to empty to use default impl
uint64_t unversioned_type_id;



+ 43
- 35
src/serialization/include/megbrain/serialization/sereg.h View File

@@ -167,16 +167,22 @@ namespace { \
/*!
* \brief register opr serialization methods
*/
#define MGB_SEREG_OPR(_cls, _arity) \
namespace { \
struct _OprReg##_cls { \
static void entry() { \
using Impl = ::mgb::serialization::OprLoadDumpImpl< \
_cls, _arity>; \
MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, Impl::load); \
} \
}; \
} \
#define MGB_SEREG_OPR(_cls, _arity) \
namespace { \
namespace ser = ::mgb::serialization; \
struct _OprReg##_cls { \
using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \
static ser::OprWithOutputAccessor wrap_loader( \
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
const mgb::cg::OperatorNodeConfig& config) { \
return ser::OprWithOutputAccessor( \
Impl::load(ctx, inputs, config)); \
} \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader); \
} \
}; \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls)

//! use to check type is complete or not, midout need a complete type
@@ -187,33 +193,35 @@ template <class T>
struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {};

//! call OprRegistry::add with only loader, used for backward compatibility
#define MGB_SEREG_OPR_COMPAT(_name, _load) \
namespace { \
static_assert(IsComplete<_name>(), \
"need a complete type for MGB_SEREG_OPR_COMPAT"); \
struct _OprReg##_name { \
static cg::OperatorNodeBase* compat_loader( \
serialization::OprLoadContext& ctx, \
const cg::VarNodeArray& inputs, \
const OperatorNodeConfig& config) { \
return _load( \
static_cast<serialization::OprLoadContextRawPOD&>(ctx), \
inputs, config); \
} \
static void entry() { \
::mgb::serialization::OprRegistry::add( \
{nullptr, \
MGB_HASH_STR(#_name), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_name), \
nullptr, \
compat_loader, \
{}, \
{}}); \
} \
}; \
} \
#define MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, _accessor) \
namespace { \
static_assert(IsComplete<_name>(), \
"need a complete type for MGB_SEREG_OPR_COMPAT"); \
namespace ser = ::mgb::serialization; \
struct _OprReg##_name { \
static ser::OprWithOutputAccessor compat_loader( \
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
const mgb::cg::OperatorNodeConfig& config) { \
auto&& ctx_ = static_cast<ser::OprLoadContextRawPOD&>(ctx); \
return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), \
_accessor); \
} \
static void entry() { \
ser::OprRegistry::add({nullptr, \
MGB_HASH_STR(#_name), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_name), \
nullptr, \
compat_loader, \
{}, \
{}}); \
} \
}; \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY(_name, _OprReg##_name)

#define MGB_SEREG_OPR_COMPAT(_name, _load) \
MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, nullptr)

/*!
* \brief use \p _copy to implement shallow copy for given operator
*/


Loading…
Cancel
Save