GitOrigin-RevId: 3b95da02c8
release-1.7
@@ -76,7 +76,7 @@ public:
     }
 };
 
-cg::OperatorNodeBase* apply_on_var_node(
+VarNodeArray apply_on_var_node(
         const OpDef& def, const VarNodeArray& inputs) {
     auto&& attr = def.cast_final_safe<OprAttr>();
     auto config = attr.config;
@@ -85,7 +85,7 @@ cg::OperatorNodeBase* apply_on_var_node(
     auto registry = serialization::OprRegistry::find_by_name(attr.type);
     mgb_assert(registry, "operator %s not found", attr.type.c_str());
     OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
-    return registry->loader(ctx, inputs, config);
+    return registry->loader(ctx, inputs, config).usable_output();
 }
 
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
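The imperative wrapper above now returns the vars produced by the loaded operator instead of the operator node itself, taking usable_output() from the loader's accessor. Below is a minimal standalone sketch of the output() vs usable_output() distinction this relies on, assuming (as the VOLATILE_CONTENT filtering later in this patch suggests) that usable outputs are the ones not marked as internal, such as workspaces. The types and names are stand-ins, not MegEngine code.

    #include <cassert>
    #include <string>
    #include <vector>

    // Stand-in for a graph variable; the flag mimics VarNode::Flag::VOLATILE_CONTENT.
    struct Var {
        std::string name;
        bool volatile_content;
    };

    // Stand-in for an operator node exposing all outputs vs. only the usable ones.
    struct FakeOpr {
        std::vector<Var> outs;
        const std::vector<Var>& output() const { return outs; }
        std::vector<Var> usable_output() const {
            std::vector<Var> usable;
            for (auto&& v : outs)
                if (!v.volatile_content)
                    usable.push_back(v);
            return usable;
        }
    };

    int main() {
        FakeOpr bn{{{"y", false}, {"saved_mean", false}, {"workspace", true}}};
        assert(bn.output().size() == 3);         // every output, including internals
        assert(bn.usable_output().size() == 2);  // workspace-like outputs filtered out
        return 0;
    }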
@@ -200,7 +200,7 @@ TEST(TestImperative, BatchNormGrad) {
     LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn};
     LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn};
     {
-        auto op = OprAttr::make("BatchNorm");
+        auto op = OprAttr::make("BatchNormV1");
         auto&& attr = op->cast_final_safe<OprAttr>();
         Param param;
         param.fwd_mode = Param::FwdMode::TRAINING;
@@ -210,7 +210,7 @@ TEST(TestImperative, BatchNormGrad) {
                 {false, false, false, false, false, true});
     }
     {
-        auto op = OprAttr::make("BatchNorm");
+        auto op = OprAttr::make("BatchNormV1");
         auto&& attr = op->cast_final_safe<OprAttr>();
         Param param;
         param.fwd_mode = Param::FwdMode::TRAINING;
@@ -59,7 +59,7 @@ TEST(TestImperative, Reduce) {
 }
 
 TEST(TestImperative, BatchNorm) {
-    auto op = OprAttr::make("BatchNorm");
+    auto op = OprAttr::make("BatchNormV1");
     auto&& attr = op->cast_final_safe<OprAttr>();
     using Param = opr::BatchNorm::Param;
     Param param;
@@ -16,14 +16,13 @@
 #include "megbrain/opr/dnn/correlation.h"
 #include "megbrain/opr/dnn/fake_quant.h"
 #include "megbrain/opr/dnn/images2neibs.h"
-#include "megbrain/opr/dnn/sliding_window_transpose.h"
-#include "megbrain/opr/dnn/adaptive_pooling.h"
 #include "megbrain/opr/dnn/local.h"
 #include "megbrain/opr/dnn/lrn.h"
 #include "megbrain/opr/dnn/lsq.h"
 #include "megbrain/opr/dnn/pooling.h"
 #include "megbrain/opr/dnn/roi_align.h"
 #include "megbrain/opr/dnn/roi_pooling.h"
+#include "megbrain/opr/dnn/sliding_window_transpose.h"
 #include "megbrain/opr/dnn/tqt.h"
 #include "megbrain/serialization/sereg.h"
 #include "megdnn/opr_param_defs.h"
@@ -390,6 +389,7 @@ struct OprMaker<opr::BatchNorm, 0> {
     }
 };
 
+// OprMaker in MGB_SEREG_OPR only supports unique-output oprs
 template <>
 struct OprMaker<opr::BatchNormBackward, 6> {
     using Param = opr::BatchNormBackward::Param;
@@ -398,8 +398,8 @@ struct OprMaker<opr::BatchNormBackward, 6> {
             ComputingGraph& graph,
             const OperatorNodeConfig& config) {
         MGB_MARK_USED_VAR(graph);
-        return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], param,
-                       config)[0]
+        return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5],
+                       param, config)[0]
                 .node()
                 ->owner_opr();
     }
@@ -575,8 +575,10 @@ MGB_SEREG_OPR(Convolution3DBackwardFilter, 0);
 using ConvBiasForwardV4 = ConvBiasForward;
 MGB_SEREG_OPR(ConvBiasForwardV4, 0);
 
-MGB_SEREG_OPR(BatchNorm, 0);
-MGB_SEREG_OPR(BatchNormBackward, 6);
+using BatchNormV1 = BatchNorm;
+using BatchNormBackwardV1 = BatchNormBackward;
+MGB_SEREG_OPR(BatchNormV1, 0);
+MGB_SEREG_OPR(BatchNormBackwardV1, 6);
 
 using LocalShareForwardV1 = LocalShareForward;
 using LocalShareBackwardDataV1 = LocalShareBackwardData;
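The BatchNormV1 alias only changes the name under which the operator is registered; serialization macros of this kind stringify and hash the token they are given (see MGB_HASH_STR(#_name) further down in this patch), while the C++ type stays the same. A standalone sketch of that idea; REG_NAME is an invented stand-in, not the real macro.

    #include <cstdio>
    #include <type_traits>

    struct BatchNorm {};            // stand-in for the real operator class
    using BatchNormV1 = BatchNorm;  // alias introduced only to change the registered name

    // A registration macro typically stringifies its argument, so the alias name is
    // what ends up in the registry and in dumped models.
    #define REG_NAME(_cls) #_cls

    int main() {
        static_assert(std::is_same<BatchNorm, BatchNormV1>::value,
                      "the alias does not introduce a new type");
        std::printf("registered as: %s\n", REG_NAME(BatchNormV1));  // prints "BatchNormV1"
        return 0;
    }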
@@ -39,7 +39,7 @@ namespace {
     return inst;
 }
 
-cg::OperatorNodeBase* dynamic_loader(
+OprWithOutputAccessor dynamic_loader(
         OprLoadContext &ctx, const cg::VarNodeArray &inputs,
         const OperatorNodeConfig &config) {
     auto name = ctx.load_buf_with_len();
@@ -171,4 +171,20 @@ std::vector<std::pair<size_t, std::string>> OprRegistry::dump_registries() {
 }
 #endif
 
+namespace {
+const VarNodeArray& default_accessor(const VarNodeArray& outputs) {
+    return outputs;
+}
+}
+
+OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr) : m_opr(opr) {
+    m_accessor = &default_accessor;
+};
+OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor)
+        : OprWithOutputAccessor(opr) {
+    if (accessor) {
+        m_accessor = accessor;
+    }
+};
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
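The two constructors above implement a simple rule: with no accessor (or an empty one) the outputs pass through unchanged, otherwise the supplied accessor rewrites them. A minimal standalone sketch of that behaviour, using std::function and a vector of strings in place of thin_function and VarNodeArray; it is an illustration, not MegEngine code.

    #include <cassert>
    #include <functional>
    #include <string>
    #include <utility>
    #include <vector>

    using Outputs = std::vector<std::string>;                 // stand-in for VarNodeArray
    using Accessor = std::function<Outputs(const Outputs&)>;

    class WithOutputAccessor {
        Outputs m_raw;        // stand-in for m_opr->output()
        Accessor m_accessor;

    public:
        explicit WithOutputAccessor(Outputs raw)
                : m_raw{std::move(raw)},
                  m_accessor{[](const Outputs& o) { return o; }} {}  // identity default
        WithOutputAccessor(Outputs raw, Accessor accessor)
                : WithOutputAccessor{std::move(raw)} {
            if (accessor)                        // empty accessor keeps the default
                m_accessor = std::move(accessor);
        }
        Outputs output() const { return m_accessor(m_raw); }
    };

    int main() {
        // default accessor: outputs pass through unchanged
        WithOutputAccessor plain({"y", "mean", "variance"});
        assert(plain.output().size() == 3);

        // custom accessor: expose only the first output
        WithOutputAccessor remapped(
                {"y", "mean", "variance"},
                [](const Outputs& o) { return Outputs{o.front()}; });
        assert(remapped.output().size() == 1);
        return 0;
    }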
@@ -207,7 +207,7 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
     registry->dumper(dumper, opr);
     OprLoadContextMemory loader{opr.owner_graph(), dumper};
-    return registry->loader(loader, inputs, config);
+    return registry->loader(loader, inputs, config).opr();
 }
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -782,7 +782,8 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
     }
 
     // call loader
-    auto opr = registry->loader(*this, inputs, config);
+    auto accessor = registry->loader(*this, inputs, config);
+    auto opr = accessor.opr();
 
     // check opr type; note that:
     // 1. registry->type may be empty for dynamic opr loaders or legacy oprs
@@ -794,7 +795,7 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
             opr ? opr->dyn_typeinfo()->name : nullptr, registry->type->name);
 
     // record output vars; read output names
     size_t i = 0;
-    for (auto ovar : opr->output()) {
+    for (auto ovar : accessor.output()) {
         if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
             m_id2varnode.push_back(ovar);
             if (fbopr->output_name()) {
@@ -19,16 +19,36 @@ namespace serialization {
     class OprDumpContext;
     class OprLoadContext;
     class OprShallowCopyContext;
 
+    class OprWithOutputAccessor {
+        cg::OperatorNodeBase* m_opr;
+        using Accessor = thin_function<const VarNodeArray(const VarNodeArray&)>;
+        Accessor m_accessor;
+    public:
+        OprWithOutputAccessor(cg::OperatorNodeBase* opr);
+        OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor);
+        VarNode* output(size_t idx) const { return output().at(idx); }
+        VarNodeArray output() const { return m_accessor(m_opr->output()); }
+        VarNodeArray usable_output() const { return m_accessor(m_opr->usable_output()); }
+        cg::OperatorNodeBase* opr() { return m_opr; }
+    };
+
     //! dump opr internal params to OprDumpContext
     using OprDumper = thin_function<void(
             OprDumpContext &ctx, const cg::OperatorNodeBase &opr)>;
 
     //! load and restore operator from OprLoadContext
+    //! is also used by GraphLoadConfig.
     using OprLoader = thin_function<cg::OperatorNodeBase*(
             OprLoadContext &ctx, const cg::VarNodeArray &inputs,
             const OperatorNodeConfig &config)>;
 
+    //! loader that can change opr output map for compatibility
+    using OprLoaderWrapper = thin_function<OprWithOutputAccessor(
+            OprLoadContext &ctx, const cg::VarNodeArray &inputs,
+            const OperatorNodeConfig &config)>;
+
     //! shallow copy function for a single operator
     using OprShallowCopy = thin_function<cg::OperatorNodeBase*(
             const OprShallowCopyContext &ctx,
@@ -41,7 +61,7 @@ namespace serialization {
         uint64_t persist_type_id;
         std::string name;
         OprDumper dumper;
-        OprLoader loader;
+        OprLoaderWrapper loader;
         OprShallowCopy shallow_copy;  //!< set to empty to use default impl
 
         uint64_t unversioned_type_id;
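The OprLoaderWrapper comment above says the accessor exists so a loader can change the operator's output map for compatibility. A standalone sketch of the kind of remapping that means: a newer operator version inserts an extra output, and the accessor restores the output order that old serialized graphs expect. The names and the extra "reserve" output are illustrative only.

    #include <cassert>
    #include <functional>
    #include <string>
    #include <vector>

    using Outputs = std::vector<std::string>;                 // stand-in for VarNodeArray
    using Accessor = std::function<Outputs(const Outputs&)>;  // as in OprWithOutputAccessor

    int main() {
        // hypothetical new output layout with an extra internal output in the middle
        Outputs new_layout{"y", "reserve", "mean", "variance"};

        // accessor that presents the three outputs old graphs were serialized against
        Accessor old_view = [](const Outputs& o) { return Outputs{o[0], o[2], o[3]}; };

        Outputs seen_by_old_loader = old_view(new_layout);
        assert(seen_by_old_loader == (Outputs{"y", "mean", "variance"}));
        return 0;
    }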
@@ -167,16 +167,22 @@ namespace { \
 /*!
  * \brief register opr serialization methods
  */
-#define MGB_SEREG_OPR(_cls, _arity)                                         \
-    namespace {                                                             \
-        struct _OprReg##_cls {                                              \
-            static void entry() {                                           \
-                using Impl = ::mgb::serialization::OprLoadDumpImpl<         \
-                        _cls, _arity>;                                      \
-                MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, Impl::load);  \
-            }                                                               \
-        };                                                                  \
-    }                                                                       \
+#define MGB_SEREG_OPR(_cls, _arity)                                           \
+    namespace {                                                               \
+    namespace ser = ::mgb::serialization;                                     \
+    struct _OprReg##_cls {                                                    \
+        using Impl = ser::OprLoadDumpImpl<_cls, _arity>;                      \
+        static ser::OprWithOutputAccessor wrap_loader(                        \
+                ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
+                const mgb::cg::OperatorNodeConfig& config) {                  \
+            return ser::OprWithOutputAccessor(                                \
+                    Impl::load(ctx, inputs, config));                         \
+        }                                                                     \
+        static void entry() {                                                 \
+            MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader);       \
+        }                                                                     \
+    };                                                                        \
+    }                                                                         \
     MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls)
 
 //! use to check type is complete or not, midout need a complete type
@@ -187,33 +193,35 @@ template <class T>
 struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {};
 
 //! call OprRegistry::add with only loader, used for backward compatibility
-#define MGB_SEREG_OPR_COMPAT(_name, _load)                                      \
-    namespace {                                                                 \
-        static_assert(IsComplete<_name>(),                                      \
-                      "need a complete type for MGB_SEREG_OPR_COMPAT");         \
-        struct _OprReg##_name {                                                 \
-            static cg::OperatorNodeBase* compat_loader(                         \
-                    serialization::OprLoadContext& ctx,                         \
-                    const cg::VarNodeArray& inputs,                             \
-                    const OperatorNodeConfig& config) {                         \
-                return _load(                                                   \
-                        static_cast<serialization::OprLoadContextRawPOD&>(ctx), \
-                        inputs, config);                                        \
-            }                                                                   \
-            static void entry() {                                               \
-                ::mgb::serialization::OprRegistry::add(                         \
-                        {nullptr,                                               \
-                         MGB_HASH_STR(#_name),                                  \
-                         _MGB_SEREG_OPR_NAME_FROM_CLS(_name),                   \
-                         nullptr,                                               \
-                         compat_loader,                                         \
-                         {},                                                    \
-                         {}});                                                  \
-            }                                                                   \
-        };                                                                      \
-    }                                                                           \
+#define MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, _accessor)            \
+    namespace {                                                                \
+    static_assert(IsComplete<_name>(),                                         \
+            "need a complete type for MGB_SEREG_OPR_COMPAT");                  \
+    namespace ser = ::mgb::serialization;                                      \
+    struct _OprReg##_name {                                                    \
+        static ser::OprWithOutputAccessor compat_loader(                       \
+                ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
+                const mgb::cg::OperatorNodeConfig& config) {                   \
+            auto&& ctx_ = static_cast<ser::OprLoadContextRawPOD&>(ctx);        \
+            return ser::OprWithOutputAccessor(_load(ctx_, inputs, config),     \
+                    _accessor);                                                \
+        }                                                                      \
+        static void entry() {                                                  \
+            ser::OprRegistry::add({nullptr,                                    \
+                    MGB_HASH_STR(#_name),                                      \
+                    _MGB_SEREG_OPR_NAME_FROM_CLS(_name),                       \
+                    nullptr,                                                   \
+                    compat_loader,                                             \
+                    {},                                                        \
+                    {}});                                                      \
+        }                                                                      \
+    };                                                                         \
+    }                                                                          \
     MGB_SEREG_OPR_INTL_CALL_ENTRY(_name, _OprReg##_name)
 
+#define MGB_SEREG_OPR_COMPAT(_name, _load) \
+    MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, nullptr)
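A hypothetical sketch of how the new macro might be used for a legacy operator whose output layout changed; this patch does not contain such a call, and Foo, legacy_foo_loader and foo_output_remap are invented names. The loader body is elided; only the shape of the registration is shown, assumed to sit inside namespace mgb like the other sereg code.

    namespace {
    // legacy loader that builds the operator from the old on-disk parameter format
    cg::OperatorNodeBase* legacy_foo_loader(
            serialization::OprLoadContextRawPOD& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config);

    // expose only the outputs old graphs knew about, in their old order
    VarNodeArray foo_output_remap(const VarNodeArray& outs) {
        VarNodeArray ret;
        ret.push_back(outs[0]);
        ret.push_back(outs[2]);
        return ret;
    }
    }  // anonymous namespace

    MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(Foo, legacy_foo_loader, foo_output_remap);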
 
 /*!
  * \brief use \p _copy to implement shallow copy for given operator
  */