GitOrigin-RevId: 3b95da02c8
release-1.7
@@ -76,7 +76,7 @@ public: | |||
} | |||
}; | |||
cg::OperatorNodeBase* apply_on_var_node( | |||
VarNodeArray apply_on_var_node( | |||
const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& attr = def.cast_final_safe<OprAttr>(); | |||
auto config = attr.config; | |||
@@ -85,7 +85,7 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
auto registry = serialization::OprRegistry::find_by_name(attr.type); | |||
mgb_assert(registry, "operator %s not found", attr.type.c_str()); | |||
OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()}; | |||
return registry->loader(ctx, inputs, config); | |||
return registry->loader(ctx, inputs, config).usable_output(); | |||
} | |||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) { | |||
@@ -200,7 +200,7 @@ TEST(TestImperative, BatchNormGrad) { | |||
LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; | |||
LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; | |||
{ | |||
auto op = OprAttr::make("BatchNorm"); | |||
auto op = OprAttr::make("BatchNormV1"); | |||
auto&& attr = op->cast_final_safe<OprAttr>(); | |||
Param param; | |||
param.fwd_mode = Param::FwdMode::TRAINING; | |||
@@ -210,7 +210,7 @@ TEST(TestImperative, BatchNormGrad) { | |||
{false, false, false, false, false, true}); | |||
} | |||
{ | |||
auto op = OprAttr::make("BatchNorm"); | |||
auto op = OprAttr::make("BatchNormV1"); | |||
auto&& attr = op->cast_final_safe<OprAttr>(); | |||
Param param; | |||
param.fwd_mode = Param::FwdMode::TRAINING; | |||
@@ -59,7 +59,7 @@ TEST(TestImperative, Reduce) { | |||
} | |||
TEST(TestImperative, BatchNorm) { | |||
auto op = OprAttr::make("BatchNorm"); | |||
auto op = OprAttr::make("BatchNormV1"); | |||
auto&& attr = op->cast_final_safe<OprAttr>(); | |||
using Param = opr::BatchNorm::Param; | |||
Param param; | |||
@@ -16,14 +16,13 @@ | |||
#include "megbrain/opr/dnn/correlation.h" | |||
#include "megbrain/opr/dnn/fake_quant.h" | |||
#include "megbrain/opr/dnn/images2neibs.h" | |||
#include "megbrain/opr/dnn/sliding_window_transpose.h" | |||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||
#include "megbrain/opr/dnn/local.h" | |||
#include "megbrain/opr/dnn/lrn.h" | |||
#include "megbrain/opr/dnn/lsq.h" | |||
#include "megbrain/opr/dnn/pooling.h" | |||
#include "megbrain/opr/dnn/roi_align.h" | |||
#include "megbrain/opr/dnn/roi_pooling.h" | |||
#include "megbrain/opr/dnn/sliding_window_transpose.h" | |||
#include "megbrain/opr/dnn/tqt.h" | |||
#include "megbrain/serialization/sereg.h" | |||
#include "megdnn/opr_param_defs.h" | |||
@@ -390,6 +389,7 @@ struct OprMaker<opr::BatchNorm, 0> { | |||
} | |||
}; | |||
// OprMaker in MGB_SEREG_OPR only support unique output opr | |||
template <> | |||
struct OprMaker<opr::BatchNormBackward, 6> { | |||
using Param = opr::BatchNormBackward::Param; | |||
@@ -398,8 +398,8 @@ struct OprMaker<opr::BatchNormBackward, 6> { | |||
ComputingGraph& graph, | |||
const OperatorNodeConfig& config) { | |||
MGB_MARK_USED_VAR(graph); | |||
return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], param, | |||
config)[0] | |||
return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], | |||
param, config)[0] | |||
.node() | |||
->owner_opr(); | |||
} | |||
@@ -575,8 +575,10 @@ MGB_SEREG_OPR(Convolution3DBackwardFilter, 0); | |||
using ConvBiasForwardV4 = ConvBiasForward; | |||
MGB_SEREG_OPR(ConvBiasForwardV4, 0); | |||
MGB_SEREG_OPR(BatchNorm, 0); | |||
MGB_SEREG_OPR(BatchNormBackward, 6); | |||
using BatchNormV1 = BatchNorm; | |||
using BatchNormBackwardV1 = BatchNormBackward; | |||
MGB_SEREG_OPR(BatchNormV1, 0); | |||
MGB_SEREG_OPR(BatchNormBackwardV1, 6); | |||
using LocalShareForwardV1 = LocalShareForward; | |||
using LocalShareBackwardDataV1 = LocalShareBackwardData; | |||
@@ -39,7 +39,7 @@ namespace { | |||
return inst; | |||
} | |||
cg::OperatorNodeBase* dynamic_loader( | |||
OprWithOutputAccessor dynamic_loader( | |||
OprLoadContext &ctx, const cg::VarNodeArray &inputs, | |||
const OperatorNodeConfig &config) { | |||
auto name = ctx.load_buf_with_len(); | |||
@@ -171,4 +171,20 @@ std::vector<std::pair<size_t, std::string>> OprRegistry::dump_registries() { | |||
} | |||
#endif | |||
namespace { | |||
const VarNodeArray& default_accessor(const VarNodeArray& outputs) { | |||
return outputs; | |||
} | |||
} | |||
OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr) : m_opr(opr){ | |||
m_accessor = &default_accessor; | |||
}; | |||
OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor) | |||
: OprWithOutputAccessor(opr) { | |||
if (accessor) { | |||
m_accessor = accessor; | |||
} | |||
}; | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -207,7 +207,7 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl( | |||
registry->dumper(dumper, opr); | |||
OprLoadContextMemory loader{opr.owner_graph(), dumper}; | |||
return registry->loader(loader, inputs, config); | |||
return registry->loader(loader, inputs, config).opr(); | |||
} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -782,7 +782,8 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr( | |||
} | |||
// call loader | |||
auto opr = registry->loader(*this, inputs, config); | |||
auto accessor = registry->loader(*this, inputs, config); | |||
auto opr = accessor.opr(); | |||
// check opr type; note that: | |||
// 1. registry->type may be empty for dynamic opr loaders or legacy oprs | |||
@@ -794,7 +795,7 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr( | |||
opr ? opr->dyn_typeinfo()->name : nullptr, registry->type->name); | |||
// record output vars; read output names | |||
size_t i = 0; | |||
for (auto ovar : opr->output()) { | |||
for (auto ovar : accessor.output()) { | |||
if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | |||
m_id2varnode.push_back(ovar); | |||
if (fbopr->output_name()) { | |||
@@ -19,16 +19,36 @@ namespace serialization { | |||
class OprDumpContext; | |||
class OprLoadContext; | |||
class OprShallowCopyContext; | |||
class OprWithOutputAccessor { | |||
cg::OperatorNodeBase* m_opr; | |||
using Accessor = thin_function<const VarNodeArray(const VarNodeArray&)>; | |||
Accessor m_accessor; | |||
public: | |||
OprWithOutputAccessor(cg::OperatorNodeBase* opr); | |||
OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor); | |||
VarNode* output(size_t idx) const { return output().at(idx); } | |||
VarNodeArray output() const { return m_accessor(m_opr->output()); } | |||
VarNodeArray usable_output() const { return m_accessor(m_opr->usable_output()); } | |||
cg::OperatorNodeBase* opr() { return m_opr; } | |||
}; | |||
//! dump opr internal params to OprDumpContext | |||
using OprDumper = thin_function<void( | |||
OprDumpContext &ctx, const cg::OperatorNodeBase &opr)>; | |||
//! load and restore operator from OprLoadContext | |||
//! is also used by GraphLoadConfig. | |||
using OprLoader = thin_function<cg::OperatorNodeBase*( | |||
OprLoadContext &ctx, const cg::VarNodeArray &inputs, | |||
const OperatorNodeConfig &config)>; | |||
//! loader that can change opr output map for compatibility | |||
using OprLoaderWrapper = thin_function<OprWithOutputAccessor( | |||
OprLoadContext &ctx, const cg::VarNodeArray &inputs, | |||
const OperatorNodeConfig &config)>; | |||
//! shallow copy function for a single operator | |||
using OprShallowCopy = thin_function<cg::OperatorNodeBase*( | |||
const OprShallowCopyContext &ctx, | |||
@@ -41,7 +61,7 @@ namespace serialization { | |||
uint64_t persist_type_id; | |||
std::string name; | |||
OprDumper dumper; | |||
OprLoader loader; | |||
OprLoaderWrapper loader; | |||
OprShallowCopy shallow_copy; //!< set to empty to use default impl | |||
uint64_t unversioned_type_id; | |||
@@ -167,16 +167,22 @@ namespace { \ | |||
/*! | |||
* \brief register opr serialization methods | |||
*/ | |||
#define MGB_SEREG_OPR(_cls, _arity) \ | |||
namespace { \ | |||
struct _OprReg##_cls { \ | |||
static void entry() { \ | |||
using Impl = ::mgb::serialization::OprLoadDumpImpl< \ | |||
_cls, _arity>; \ | |||
MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, Impl::load); \ | |||
} \ | |||
}; \ | |||
} \ | |||
#define MGB_SEREG_OPR(_cls, _arity) \ | |||
namespace { \ | |||
namespace ser = ::mgb::serialization; \ | |||
struct _OprReg##_cls { \ | |||
using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \ | |||
static ser::OprWithOutputAccessor wrap_loader( \ | |||
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | |||
const mgb::cg::OperatorNodeConfig& config) { \ | |||
return ser::OprWithOutputAccessor( \ | |||
Impl::load(ctx, inputs, config)); \ | |||
} \ | |||
static void entry() { \ | |||
MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader); \ | |||
} \ | |||
}; \ | |||
} \ | |||
MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) | |||
//! use to check type is complete or not, midout need a complete type | |||
@@ -187,33 +193,35 @@ template <class T> | |||
struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {}; | |||
//! call OprRegistry::add with only loader, used for backward compatibility | |||
#define MGB_SEREG_OPR_COMPAT(_name, _load) \ | |||
namespace { \ | |||
static_assert(IsComplete<_name>(), \ | |||
"need a complete type for MGB_SEREG_OPR_COMPAT"); \ | |||
struct _OprReg##_name { \ | |||
static cg::OperatorNodeBase* compat_loader( \ | |||
serialization::OprLoadContext& ctx, \ | |||
const cg::VarNodeArray& inputs, \ | |||
const OperatorNodeConfig& config) { \ | |||
return _load( \ | |||
static_cast<serialization::OprLoadContextRawPOD&>(ctx), \ | |||
inputs, config); \ | |||
} \ | |||
static void entry() { \ | |||
::mgb::serialization::OprRegistry::add( \ | |||
{nullptr, \ | |||
MGB_HASH_STR(#_name), \ | |||
_MGB_SEREG_OPR_NAME_FROM_CLS(_name), \ | |||
nullptr, \ | |||
compat_loader, \ | |||
{}, \ | |||
{}}); \ | |||
} \ | |||
}; \ | |||
} \ | |||
#define MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, _accessor) \ | |||
namespace { \ | |||
static_assert(IsComplete<_name>(), \ | |||
"need a complete type for MGB_SEREG_OPR_COMPAT"); \ | |||
namespace ser = ::mgb::serialization; \ | |||
struct _OprReg##_name { \ | |||
static ser::OprWithOutputAccessor compat_loader( \ | |||
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | |||
const mgb::cg::OperatorNodeConfig& config) { \ | |||
auto&& ctx_ = static_cast<ser::OprLoadContextRawPOD&>(ctx); \ | |||
return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), \ | |||
_accessor); \ | |||
} \ | |||
static void entry() { \ | |||
ser::OprRegistry::add({nullptr, \ | |||
MGB_HASH_STR(#_name), \ | |||
_MGB_SEREG_OPR_NAME_FROM_CLS(_name), \ | |||
nullptr, \ | |||
compat_loader, \ | |||
{}, \ | |||
{}}); \ | |||
} \ | |||
}; \ | |||
} \ | |||
MGB_SEREG_OPR_INTL_CALL_ENTRY(_name, _OprReg##_name) | |||
#define MGB_SEREG_OPR_COMPAT(_name, _load) \ | |||
MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, nullptr) | |||
/*! | |||
* \brief use \p _copy to implement shallow copy for given operator | |||
*/ | |||