GitOrigin-RevId: 8eacd5e77c
release-1.10
@@ -153,4 +153,6 @@ struct EnsureHashConstexpr { | |||
//! Compile-time hash of a string literal. sizeof(v) includes the trailing
//! '\0', and EnsureHashConstexpr forces the hash to be evaluated at compile
//! time.
#define MGB_HASH_STR(v) \
    ::mgb::EnsureHashConstexpr<::mgb::XXHash64CT::hash(v, sizeof(v), 20160701)>::val

//! Runtime hash of a std::string. Hashes size()+1 bytes (the '\0' from
//! c_str()) so the result matches MGB_HASH_STR on the same text.
#define MGB_HASH_RUNTIME(v) XXHash64CT::hash((v).c_str(), (v).size() + 1, 20160701)
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -52,6 +52,21 @@ mgb::cg::OperatorNodeBase* custom_loader( | |||
} \ | |||
MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls) | |||
//! Register the custom-op serializer for the V2 (flatbuffers) format.
//! entry() adds custom_dumper/custom_loader for versions
//! [_version_min, _version_max]; no converter is needed (nullptr).
#define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max)               \
    namespace {                                                               \
    struct _OprRegV2##cls {                                                   \
        static void entry() {                                                 \
            MGB_SEREG_OPR_INTL_CALL_ADD_V2(                                   \
                    cls, ::mgb::serialization::custom_dumper,                 \
                    ::mgb::serialization::custom_loader, nullptr, _version_min, \
                    _version_max);                                            \
        }                                                                     \
    };                                                                        \
    }                                                                         \
    MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(cls, _OprRegV2##cls)
// Register CustomOpNode with both serialization formats: the legacy V1
// registry, and the V2 (flatbuffers) registry covering [2, CURRENT_VERSION].
using namespace mgb;
using CustomOpNode = opr::CustomOpNode;
CUSTOM_OP_SEREG_REG(CustomOpNode);
CUSTOM_OP_SEREG_REG_V2(CustomOpNode, 2, CURRENT_VERSION);
@@ -0,0 +1,228 @@ | |||
#include "megbrain/graph/symbol_var.h" | |||
#include "megdnn/oprs/general.h" | |||
#if MGB_ENABLE_FBS_SERIALIZATION | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/opr/dnn/softmax.h" | |||
#include "megbrain/serialization/oss_opr_load_dump.h" | |||
#include "megbrain/serialization/sereg.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megdnn/oprs/nn.h" | |||
namespace mgb { | |||
namespace serialization { | |||
//! V2 (flatbuffers) serializer for Softmax.
//! Softmax post-dates older format versions, so replace_opr can rewrite the
//! operator into an equivalent subgraph of Reduce/Elemwise primitives when a
//! compatible dump is required.
template <>
struct OprLoadDumpImplV2<opr::Softmax, 1> {
    using Opr = opr::Softmax;
    using PersisParam = opr::Softmax::Param;

    //! dump only the param (the softmax axis); there is no other state
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
        ctx.write_param<PersisParam>(opr.cast_final_safe<Opr>().param());
    }

    //! expand Softmax into the numerically stable primitive form:
    //! y = exp(x - max(x)) / sum(exp(x - max(x))), reductions along `axis`
    static cg::OperatorNodeBase* replace_opr(
            cg::OperatorNodeBase* opr, const VarNodeArray& inputs) {
        int32_t axis = opr->cast_final_safe<Opr>().param().axis;
        auto input_var = inputs[0];
        // subtract the per-axis max before exp to avoid overflow
        auto max_reduce_out =
                opr::Reduce::make(input_var, {megdnn::Reduce::Mode::MAX, axis});
        auto elemwise_sub_out = opr::Elemwise::make(
                {input_var, max_reduce_out}, {megdnn::Elemwise::Mode::SUB});
        auto elemwise_exp_out =
                opr::Elemwise::make({elemwise_sub_out}, {megdnn::Elemwise::Mode::EXP});
        auto sum_reduce_out =
                opr::Reduce::make(elemwise_exp_out, {megdnn::Reduce::Mode::SUM, axis});
        auto out = opr::Elemwise::make(
                {elemwise_exp_out, sum_reduce_out}, {megdnn::Elemwise::Mode::TRUE_DIV});
        return out.node()->owner_opr();
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
        // the param was written at index 0 by dump()
        auto param = fbs_ctx.read_param<PersisParam>(0);
        return Opr::make(inputs[0], param, config).node()->owner_opr();
    }
};
//! Shared V2 load/dump implementation for convolution-like operators that
//! carry an execution policy. Maker0..Maker2 each try to construct the
//! operator for a particular input arity; the first maker that accepts the
//! input count wins (MakeConvCallerEmpty always declines).
template <
        class Opr, class Maker0, class MegDNNConv,
        class Maker1 = MakeConvCallerEmpty<MegDNNConv>,
        class Maker2 = MakeConvCallerEmpty<MegDNNConv>,
        typename ConvParam = megdnn::param::Convolution>
struct WithPolicyOprLoadDumpImpl {
    //! write the conv param (index 0) followed by the execution policy
    //! (index 1); load() reads them back in the same order
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<ConvParam>(opr.param());
        ctx.write_param<megdnn::param::ExecutionPolicy>(
                opr.execution_policy_transient());
    }

    //! try each maker in turn; exactly one must produce the operator
    static VarNode* make(
            const cg::VarNodeArray& inputs, const ConvParam& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        VarNode* ret =
                Maker0::template make<Opr>(inputs, param, execution_policy, config);
        if (!ret) {
            ret = Maker1::template make<Opr>(inputs, param, execution_policy, config);
        }
        if (!ret) {
            ret = Maker2::template make<Opr>(inputs, param, execution_policy, config);
        }
        mgb_assert(ret);
        return ret;
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
        auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
                fbs_ctx.get_current_opr_data());
        auto conv_param = fbs_ctx.read_param<ConvParam>(0);
        // dumps without an execution policy fall back to a default policy
        megdnn::param::ExecutionPolicy policy;
        if (fopr->additional_params() && fopr->additional_params()->size()) {
            policy = fbs_ctx.read_param<megdnn::param::ExecutionPolicy>(1);
        }
        return make(inputs, conv_param, policy, config)->owner_opr();
    }
};
//! Per-operator V2 serializers: each specialization instantiates
//! WithPolicyOprLoadDumpImpl with the maker/param combination matching the
//! operator's possible input arities.

// 2D convolution family
template <>
struct OprLoadDumpImplV2<opr::Convolution, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::Convolution, MakeConvCaller2<megdnn::Convolution>,
                  megdnn::Convolution> {};
template <>
struct OprLoadDumpImplV2<opr::ConvolutionBackwardData, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::ConvolutionBackwardData, MakeConvCaller2<megdnn::Convolution>,
                  megdnn::Convolution, MakeConvCaller3<megdnn::Convolution>> {};
template <>
struct OprLoadDumpImplV2<opr::ConvolutionBackwardFilter, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::ConvolutionBackwardFilter, MakeConvCaller3<megdnn::Convolution>,
                  megdnn::Convolution> {};

// 3D convolution family
template <>
struct OprLoadDumpImplV2<opr::Convolution3D, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::Convolution3D, MakeConvCaller2<megdnn::Convolution3D>,
                  megdnn::Convolution3D, MakeConvCallerEmpty<megdnn::Convolution3D>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImplV2<opr::Convolution3DBackwardData, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::Convolution3DBackwardData,
                  MakeConvCaller2<megdnn::Convolution3D>, megdnn::Convolution3D,
                  MakeConvCaller3<megdnn::Convolution3D>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImplV2<opr::Convolution3DBackwardFilter, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::Convolution3DBackwardFilter,
                  MakeConvCaller3<megdnn::Convolution3D>, megdnn::Convolution3D,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};

// conv-bias family
template <>
struct OprLoadDumpImplV2<opr::ConvBiasForward, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::ConvBiasForward, MakeConvCaller2<megdnn::ConvBiasForward>,
                  megdnn::ConvBiasForward, MakeConvCaller3<megdnn::ConvBiasForward>,
                  MakeConvCaller4<megdnn::ConvBiasForward>, megdnn::param::ConvBias> {};
template <>
struct OprLoadDumpImplV2<opr::BatchConvBiasForward, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::BatchConvBiasForward,
                  MakeConvCaller2<megdnn::BatchConvBiasForward>,
                  megdnn::BatchConvBiasForward,
                  MakeConvCaller3<megdnn::BatchConvBiasForward>,
                  MakeConvCaller4<megdnn::BatchConvBiasForward>,
                  megdnn::param::BatchConvBias> {};

// local-share convolution family
template <>
struct OprLoadDumpImplV2<opr::LocalShare, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::LocalShare, MakeLocalShareCaller2<megdnn::LocalShare>,
                  megdnn::LocalShare, MakeLocalShareCallerEmpty<megdnn::LocalShare>,
                  MakeLocalShareCallerEmpty<megdnn::LocalShare>,
                  megdnn::param::LocalShare> {};
template <>
struct OprLoadDumpImplV2<opr::LocalShareBackwardData, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::LocalShareBackwardData,
                  MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare,
                  MakeLocalShareCallerEmpty<megdnn::LocalShare>,
                  MakeLocalShareCallerEmpty<megdnn::LocalShare>,
                  megdnn::param::LocalShare> {};
template <>
struct OprLoadDumpImplV2<opr::LocalShareBackwardFilter, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::LocalShareBackwardFilter,
                  MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare,
                  MakeLocalShareCallerEmpty<megdnn::LocalShare>,
                  MakeLocalShareCallerEmpty<megdnn::LocalShare>,
                  megdnn::param::LocalShare> {};

// deformable convolution family
template <>
struct OprLoadDumpImplV2<opr::DeformableConvForward, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::DeformableConvForward,
                  MakeConvCaller4<megdnn::DeformableConvForward>, megdnn::Convolution> {
};
template <>
struct OprLoadDumpImplV2<opr::DeformableConvBackwardData, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::DeformableConvBackwardData,
                  MakeConvCaller5<megdnn::DeformableConvBackwardData>,
                  megdnn::Convolution> {};
template <>
struct OprLoadDumpImplV2<opr::DeformableConvBackwardFilter, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::DeformableConvBackwardFilter,
                  MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
                  megdnn::Convolution> {};
} // namespace serialization | |||
namespace opr { | |||
#define SERGE_OPR_V2_CONVERTER(_cls, _arity, _converter) \ | |||
MGB_SEREG_OPR_V2(_cls, _arity, _converter, VERSION_2, CURRENT_VERSION); | |||
#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \ | |||
MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION); | |||
SERGE_OPR_V2_CONVERTER( | |||
Softmax, 1, | |||
(mgb::serialization::OprLoadDumpImplV2<opr::Softmax, 1>::replace_opr)); | |||
SERGE_OPR_V2_NO_CONVERTER(ConvBiasForward, 0) | |||
SERGE_OPR_V2_NO_CONVERTER(BatchConvBiasForward, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(Convolution, 0) | |||
SERGE_OPR_V2_NO_CONVERTER(ConvolutionBackwardData, 0) | |||
SERGE_OPR_V2_NO_CONVERTER(ConvolutionBackwardFilter, 0) | |||
SERGE_OPR_V2_NO_CONVERTER(Convolution3D, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardData, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardFilter, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(LocalShareForward, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardData, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardFilter, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(DeformableConvForward, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(DeformableConvBackwardData, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(DeformableConvBackwardFilter, 0); | |||
#undef SERGE_OPR_V2_CONVERTER | |||
#undef SERGE_OPR_V2_NO_CONVERTER | |||
} // namespace opr | |||
} // namespace mgb | |||
#endif | |||
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,197 @@ | |||
#if MGB_ENABLE_FBS_SERIALIZATION | |||
#include "megbrain/comp_node_env.h" | |||
#include "megbrain/opr/dnn/softmax.h" | |||
#include "megbrain/opr/io.h" | |||
#include "megbrain/serialization/oss_opr_load_dump.h" | |||
#include "megbrain/serialization/sereg.h" | |||
#include "megbrain/serialization/internal/mgb_cpp_opr_generated.h" | |||
#include "megbrain/serialization/internal/schema_v2_generated.h" | |||
namespace mgb { | |||
namespace serialization { | |||
//! V2 (flatbuffers) serializer for ImmutableTensor: the value is stored
//! inline with the operator record, anonymously (no tensor name).
template <>
struct OprLoadDumpImplV2<opr::ImmutableTensor, 0> {
    using Opr = opr::ImmutableTensor;
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        using Meth = OprDumpContext::TensorWriteMethod;
        auto&& opr = opr_.cast_final_safe<Opr>();
        // copy the device value to host and wait for the copy to finish
        ctx.dump_tensor(
                {}, HostTensorND{}.copy_from(opr.value()).sync(),
                Meth::VALUE_ANONYMOUS);
    }
    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
        // ImmutableTensor has no graph inputs; its value comes from the dump
        mgb_assert(inputs.empty());
        auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
                fbs_ctx.get_current_opr_data());
        if (fopr->tensors() && fopr->tensors()->size() > 0) {
            auto val = fbs_ctx.load_tensor();
            return Opr::make(fbs_ctx.graph(), *val, config).node()->owner_opr();
        } else {
            mgb_throw(SerializationError, "ImmutableTensor load with no tensor data.");
        }
    }
};
//! V2 (flatbuffers) serializer for Host2DeviceCopy: dumps the param followed
//! by the host tensor. The tensor's value is written only when
//! param().dump_default_value is set; otherwise only metadata is stored.
template <>
struct OprLoadDumpImplV2<opr::Host2DeviceCopy, 0> {
    using Opr = opr::Host2DeviceCopy;
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param(opr.param());
        using Meth = OprDumpContext::TensorWriteMethod;
        ctx.dump_tensor(
                opr.name(), *opr.host_data(),
                opr.param().dump_default_value ? Meth::VALUE_INPUT : Meth::META_INPUT);
    }
    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        mgb_assert(inputs.empty());
        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
        // read back in dump order: param at index 0, then the tensor
        auto param = fbs_ctx.read_param<Opr::Param>(0);
        auto tensor = fbs_ctx.load_tensor();
        return Opr::make(fbs_ctx.graph(), tensor, param, config).node()->owner_opr();
    }
};
//! V2 (flatbuffers) serializer for SharedDeviceTensorWithFormat: the device
//! value is dumped via a host copy and restored into a new device tensor
//! with the original layout.
template <>
struct OprLoadDumpImplV2<opr::SharedDeviceTensorWithFormat, 0> {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        using Meth = OprDumpContext::TensorWriteMethod;
        auto&& opr = opr_.cast_final_safe<opr::SharedDeviceTensorWithFormat>();
        HostTensorND val;
        val.copy_from(opr.get_dev_tensor()).sync();
        ctx.dump_tensor({}, val, Meth::VALUE_ANONYMOUS);
    }
    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        mgb_assert(inputs.empty());
        auto val = ctx.load_tensor();
        auto dev_val =
                std::make_shared<DeviceTensorND>(val->comp_node(), val->layout());
        // async host->device copy; synced below before the graph can run
        dev_val->copy_from_fixlayout(*val);
        auto out_var =
                opr::SharedDeviceTensorWithFormat::make(ctx.graph(), dev_val, config);
        dev_val->sync();
        return out_var.node()->owner_opr();
    }
};
//! V2 (flatbuffers) serializer for MultipleDeviceTensorHolder: each held
//! value is dumped as a shared tensor named after the corresponding output.
template <>
struct OprLoadDumpImplV2<opr::MultipleDeviceTensorHolder, 0> {
    using Opr = opr::MultipleDeviceTensorHolder;
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        using Meth = OprDumpContext::TensorWriteMethod;
        auto&& opr = opr_.cast_final_safe<Opr>();
        uint32_t nr_val = opr.values().size();
        for (uint32_t i = 0; i < nr_val; ++i) {
            HostTensorND val;
            val.copy_from(*opr.values()[i]).sync();
            // VALUE_SHARED lets identical tensors be deduplicated on load
            ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED);
        }
    }
    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        mgb_assert(inputs.empty());
        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
        // the tensor count comes from the flatbuffer operator record
        auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
                fbs_ctx.get_current_opr_data());
        uint32_t nr = 0;
        if (fopr && fopr->tensors()) {
            nr = fopr->tensors()->size();
        }
        Opr::ValueArray values(nr);
        for (auto&& i : values) {
            i = ctx.load_tensor_shared();
        }
        return Opr::make(ctx.graph(), std::move(values), config)[0].node()->owner_opr();
    }
};
//! V2 (flatbuffers) serializer for MultipleDeviceTensorWithFormatHolder:
//! like MultipleDeviceTensorHolder, but each tensor carries a non-default
//! layout format that must be re-applied after loading.
template <>
struct OprLoadDumpImplV2<opr::MultipleDeviceTensorWithFormatHolder, 0> {
    using Opr = opr::MultipleDeviceTensorWithFormatHolder;
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        using Meth = OprDumpContext::TensorWriteMethod;
        auto&& opr = opr_.cast_final_safe<Opr>();
        uint32_t nr_val = opr.values().size();
        for (uint32_t i = 0; i < nr_val; ++i) {
            HostTensorND val;
            auto value = *opr.values()[i];
            val.copy_from(value).sync();
            ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED);
        }
    }
    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        mgb_assert(inputs.empty());
        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
        // the tensor count comes from the flatbuffer operator record
        auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
                fbs_ctx.get_current_opr_data());
        uint32_t nr = 0;
        if (fopr && fopr->tensors()) {
            nr = fopr->tensors()->size();
        }
        Opr::ValueArray values(nr);
        for (auto&& i : values) {
            i = ctx.load_tensor_shared();
            //! set tensor format
            TensorLayout layout_with_format = i->layout();
            if (i->storage().comp_node().mem_node() ==
                CompNode::default_cpu().mem_node()) {
                mgb_assert(
                        i->storage().ptr(),
                        "storage should not be nullptr if mem_node is "
                        "default_cpu");
                // host-resident data: re-copy through a HostTensorND so the
                // formatted layout is applied to the actual bytes
                HostTensorND src{i->storage().comp_node(), layout_with_format};
                src.copy_from_fixlayout(*i).sync();
                *i = DeviceTensorND::make_proxy(src);
            } else {
                //! actually only layout of this tensor will be used later, see
                //! src/serialization/impl/batched_device_value_loader.cpp:49. But we
                //! have no way to reset layout only, so just construct a invalid
                //! storage instead
                auto size = layout_with_format.span().dist_byte();
                DeviceTensorStorage storage;
                storage.reset(i->comp_node(), size, nullptr);
                i->reset(storage, layout_with_format);
            }
        }
        return Opr::make(ctx.graph(), std::move(values), config)[0].node()->owner_opr();
    }
};
} // namespace serialization | |||
namespace opr { | |||
#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \ | |||
MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION); | |||
SERGE_OPR_V2_NO_CONVERTER(ImmutableTensor, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(Host2DeviceCopy, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(SharedDeviceTensorWithFormat, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(MultipleDeviceTensorWithFormatHolder, 0); | |||
SERGE_OPR_V2_NO_CONVERTER(MultipleDeviceTensorHolder, 0); | |||
} // namespace opr | |||
} // namespace mgb | |||
#endif | |||
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -135,6 +135,16 @@ void LoopSerializer::reg_all() { | |||
MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker); | |||
MGB_SEREG_OPR_INTL_CALL_ADD( | |||
CounterProvider, dump_counter_provider, load_counter_provider); | |||
MGB_SEREG_OPR_INTL_CALL_ADD_V2( | |||
opr::Loop, dump_loop, load_loop, nullptr, 2, | |||
CURRENT_VERSION); | |||
MGB_SEREG_OPR_INTL_CALL_ADD_V2( | |||
InputMaker, dump_input_maker, load_input_maker, nullptr, 2, | |||
CURRENT_VERSION); | |||
MGB_SEREG_OPR_INTL_CALL_ADD_V2( | |||
CounterProvider, dump_counter_provider, load_counter_provider, nullptr, 2, | |||
CURRENT_VERSION); | |||
} | |||
void LoopSerializer::dump_loop(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { | |||
@@ -20,6 +20,11 @@ struct StaticData { | |||
ThinHashMap<Typeinfo*, OprRegistry*> type2reg; | |||
std::unordered_map<std::string, OprRegistry*> name2reg; | |||
ThinHashMap<size_t, OprRegistry*> unversioned_id2reg; | |||
//! versioned OprRegistryV2, version_id_reg_map is used for Operator | |||
//! load/shallow copy and version_type_reg_map is used for Operator dump | |||
ThinHashMap<uint8_t, ThinHashMap<size_t, OprRegistryV2>> version_id_reg_map; | |||
ThinHashMap<uint8_t, ThinHashMap<Typeinfo*, OprRegistryV2*>> version_type_reg_map; | |||
}; | |||
StaticData& static_data() { | |||
@@ -47,6 +52,20 @@ const OprRegistry* dynamic_registry() { | |||
return ret; | |||
} | |||
//! Registry record shared by all dynamically-registered operators in the V2
//! format; created lazily on first use and cached in a function-local static.
//! NOTE(review): the cached-pointer check is not synchronized — presumably
//! all registration happens during static initialization; confirm if called
//! from multiple threads.
const OprRegistryV2* dynamic_registry_v2() {
    static const OprRegistryV2* ret = nullptr;
    if (ret)
        return ret;
    auto id = MGB_HASH_STR("dynamic");
    // only the loader matters for the dynamic record; type/dumper/converter
    // are left empty
    OprRegistryV2::versioned_add(
            {nullptr, id, {}, {}, dynamic_loader, {}}, CURRENT_VERSION,
            CURRENT_VERSION);
    ret = OprRegistryV2::versioned_find_by_id(id, CURRENT_VERSION);
    mgb_assert(ret);
    return ret;
}
class _Init { | |||
public: | |||
_Init() { | |||
@@ -63,8 +82,7 @@ void OprRegistry::add(const OprRegistry& record) { | |||
auto registry_ins = sd.id2reg.emplace(persist_id, record); | |||
mgb_assert( | |||
registry_ins.second || persist_id == dynamic_registry()->persist_type_id, | |||
"duplicated operator persist_type_id: %s", | |||
std::to_string(persist_id).c_str()); | |||
"duplicated operator name : %s", record.name.c_str()); | |||
OprRegistry* persis_record_ptr; | |||
if (registry_ins.second) { | |||
@@ -129,6 +147,73 @@ const OprRegistry* OprRegistry::find_by_unversioned_id(size_t unversioned_id) { | |||
return iter == uid2reg.end() ? nullptr : iter->second; | |||
} | |||
//! Look up the record registered for exactly the given version; returns
//! nullptr when either the version or the id is unknown.
const OprRegistryV2* OprRegistryV2::versioned_find_by_id(
        const size_t id, uint8_t version) {
    auto&& version_map = static_data().version_id_reg_map;
    auto version_it = version_map.find(version);
    if (version_it == version_map.end())
        return nullptr;
    auto&& id_map = version_it->second;
    auto entry_it = id_map.find(id);
    if (entry_it == id_map.end())
        return nullptr;
    return &entry_it->second;
}
//! Find the record with the highest version not exceeding the given one;
//! returns nullptr when no matching record exists.
const OprRegistryV2* OprRegistryV2::versioned_find_by_typeinfo(
        Typeinfo* type, uint8_t version) {
    const auto& type_reg_map = static_data().version_type_reg_map;
    // walk downward so the newest compatible registration wins
    for (int candidate = version; candidate > 0; --candidate) {
        auto version_it = type_reg_map.find(candidate);
        if (version_it == type_reg_map.end())
            continue;
        auto entry_it = version_it->second.find(type);
        if (entry_it != version_it->second.end())
            return entry_it->second;
    }
    return nullptr;
}
void OprRegistryV2::versioned_add( | |||
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version) { | |||
mgb_assert(max_version >= min_version); | |||
auto&& sd = static_data(); | |||
auto id = record.type_id; | |||
uint64_t type_id = id; | |||
//! record.type->name is nullptr when MGB_VERBOSE_TYPEINFO_NAME==0 | |||
if (record.type && record.type->name) { | |||
type_id = MGB_HASH_RUNTIME(std::string(record.type->name)); | |||
} | |||
for (uint8_t version = min_version; version <= max_version; version++) { | |||
auto&& registry_map = sd.version_id_reg_map[version]; | |||
auto versioned_record = record; | |||
versioned_record.version = version; | |||
mgb_assert( | |||
registry_map.find(id) == registry_map.end() || | |||
id == dynamic_registry_v2()->type_id, | |||
"dduplicated OprRegistryV2 of %s\n", record.name.c_str()); | |||
auto registry_ins = registry_map.emplace(id, versioned_record); | |||
if (!registry_ins.second) { | |||
//! the registry is dynamic | |||
mgb_assert(!record.converter); | |||
registry_map[id] = versioned_record; | |||
} | |||
//! sometimes the register id and the hash typeinfo is not same, just as | |||
//! dynamic Operator | |||
if (id != type_id) { | |||
mgb_assert( | |||
registry_map.find(type_id) == registry_map.end(), | |||
"dduplicated OprRegistryV2 of %s\n", record.name.c_str()); | |||
registry_map.emplace(type_id, versioned_record); | |||
} | |||
auto&& registry_type_map = sd.version_type_reg_map[version]; | |||
registry_type_map.emplace(record.type, ®istry_map[id]); | |||
} | |||
} | |||
void OprRegistry::add_using_dynamic_loader( | |||
Typeinfo* type, const std::string& name, const OprDumper& dumper) { | |||
// dynamic oprs are implemented by mapping different opr types to the same | |||
@@ -140,6 +225,11 @@ void OprRegistry::add_using_dynamic_loader( | |||
{}, | |||
{}, | |||
dynamic_registry()->unversioned_type_id}); | |||
mgb_assert(type, "type must be not nullptr"); | |||
OprRegistryV2::versioned_add( | |||
{type, dynamic_registry_v2()->type_id, type->name, dumper, | |||
dynamic_registry_v2()->loader, nullptr}, | |||
CURRENT_VERSION, CURRENT_VERSION); | |||
} | |||
#if MGB_ENABLE_DEBUG_UTIL | |||
@@ -9,10 +9,12 @@ void call_sereg() {} | |||
#include "../../opr/impl/blas.sereg.h" | |||
#include "../../opr/impl/cond.sereg.h" | |||
#include "../../opr/impl/dnn/dnn.sereg.h" | |||
#include "../../opr/impl/dnn/dnn.sereg.v2.h" | |||
#include "./extern_c_opr.sereg.h" | |||
#include "../../opr/impl/imgproc.sereg.h" | |||
#include "../../opr/impl/indexing.sereg.h" | |||
#include "../../opr/impl/io.sereg.h" | |||
#include "../../opr/impl/io.sereg.v2.h" | |||
#include "../../opr/impl/loop/forward.sereg.h" | |||
#include "../../opr/impl/loop/grad.sereg.h" | |||
#include "../../opr/impl/misc.sereg.h" | |||
@@ -53,7 +53,6 @@ struct OprRegistry { | |||
uint64_t unversioned_type_id; | |||
MGE_WIN_DECLSPEC_FUC static void add(const OprRegistry& record); | |||
/*! | |||
* \brief register an operator to use dynamic loader | |||
* | |||
@@ -89,6 +88,39 @@ struct OprRegistry { | |||
#endif | |||
}; | |||
//! Convert a modified (newer) operator into a compatible operator when a
//! compatible dump/load is required; takes the operator and its inputs and
//! returns the replacement operator.
using OprConvertToCompatible = thin_function<cg::OperatorNodeBase*(
        cg::OperatorNodeBase*, const VarNodeArray&)>;

//! registry record of a single operator for the V2 (flatbuffers) format
struct OprRegistryV2 {
    Typeinfo* type;        //!< may be nullptr for dynamically registered oprs
    uint64_t type_id;      //!< persistent id used as the registry key
    std::string name;
    OprDumper dumper;
    OprLoaderWrapper loader;
    OprConvertToCompatible converter;  //!< optional; empty when not needed
    uint8_t version = 2;   //!< format version this record was registered for
    MGE_WIN_DECLSPEC_FUC uint8_t get_version() const { return version; }

    //! register opr load/dump into the per-version registry maps, covering
    //! every version in [min_version, max_version]
    MGE_WIN_DECLSPEC_FUC static void versioned_add(
            const OprRegistryV2& record, uint8_t min_version, uint8_t max_version);

    //! find the record registered for exactly the given version
    MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_id(
            const size_t id, uint8_t version);

    //! find the record with the highest version not exceeding the given one
    MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_typeinfo(
            Typeinfo* type, uint8_t version);

#if MGB_ENABLE_DEBUG_UTIL
    //! dump registered oprs
    MGE_WIN_DECLSPEC_FUC static std::vector<std::pair<size_t, std::string>>
    dump_registries();
#endif
};
} // namespace serialization | |||
} // namespace mgb | |||
@@ -3,6 +3,7 @@ | |||
#include "megbrain/serialization/opr_load_dump.h" | |||
#include "megbrain/serialization/opr_registry.h" | |||
#include "megbrain/serialization/opr_shallow_copy.h" | |||
#include "megbrain/serialization/oss_opr_load_dump.h" | |||
#include "megbrain/utils/hash_ct.h" | |||
namespace mgb { | |||
@@ -66,6 +67,9 @@ struct OprLoadDumpImpl { | |||
} | |||
}; | |||
//! V2 (flatbuffers) load/dump implementation; defaults to the V1 behavior
//! and is specialized per operator where the V2 format differs.
template <class Opr, size_t arity>
struct OprLoadDumpImplV2 : public OprLoadDumpImpl<Opr, arity> {};
#define IMPL_OPR_MAKER(_arity, _args...) \ | |||
template <class Opr> \ | |||
struct OprMaker<Opr, _arity> { \ | |||
@@ -124,6 +128,12 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||
__caller_OprReg##_cls##_ins; \ | |||
} | |||
//! Instantiate an OprRegistryCaller so that _impl::entry() runs during
//! static initialization, registering _cls with the V2 registry.
#define MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _impl)                    \
    namespace {                                                          \
    [[gnu::unused]] ::mgb::serialization::OprRegistryCaller<_cls, _impl> \
            __caller_V2_OprReg##_cls##_ins;                              \
    }
// Trim the terminating null character and a "V0" like suffix from the string | |||
// then hash it. | |||
// TODO: Get rid of this. | |||
@@ -138,17 +148,35 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||
: 0), \ | |||
20160701)>::val | |||
//! call OprRegistry::add | |||
#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \ | |||
do { \ | |||
::mgb::serialization::OprRegistry::add( \ | |||
{_cls::typeinfo(), \ | |||
MGB_HASH_STR(#_cls), \ | |||
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \ | |||
_dump, \ | |||
_load, \ | |||
{}, \ | |||
MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \ | |||
//! call OprRegistry::add for the old serialization format, and
//! OprRegistryV2::versioned_add for the new format in a way that stays
//! compatible with the old one; the converter is nullptr and the V2 record
//! only covers version 1
#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load)                        \
    do {                                                                       \
        ::mgb::serialization::OprRegistry::add(                                \
                {_cls::typeinfo(),                                             \
                 MGB_HASH_STR(#_cls),                                          \
                 _MGB_SEREG_OPR_NAME_FROM_CLS(_cls),                           \
                 _dump,                                                        \
                 _load,                                                        \
                 {},                                                           \
                 MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)});             \
        ::mgb::serialization::OprRegistryV2::versioned_add(                    \
                {_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \
                 _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr},   \
                ::mgb::VERSION_1, ::mgb::VERSION_1);                           \
    } while (0)
//! call OprRegistryV2::versioned_add for the new serialization format;
//! _convert (may be nullptr) rewrites the operator into a compatible form
//! when loading across format versions
#define MGB_SEREG_OPR_INTL_CALL_ADD_V2(                                       \
        _cls, _dump, _load, _convert, _version_min, _version_max)             \
    do {                                                                      \
        ::mgb::serialization::OprRegistryV2::versioned_add(                   \
                {_cls::typeinfo(), MGB_HASH_STR(#_cls),                       \
                 _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, _convert}, \
                _version_min, _version_max);                                  \
    } while (0)
/*! | |||
@@ -171,6 +199,27 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||
} \ | |||
MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) | |||
//! Register a V2 serializer for _cls: the new dump/load functions should be
//! implemented in OprLoadDumpImplV2<_cls, _arity>. _converter is optional;
//! pass nullptr if no conversion is needed.
#define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max)  \
    namespace {                                                                 \
    namespace ser = ::mgb::serialization;                                       \
    struct _OprRegV2##_cls {                                                    \
        using Impl = ser::OprLoadDumpImplV2<_cls, _arity>;                      \
        /* adapt Impl::load to the accessor-returning loader signature */       \
        static ser::OprWithOutputAccessor wrap_loader(                          \
                ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs,  \
                const mgb::cg::OperatorNodeConfig& config) {                    \
            return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \
        }                                                                       \
        static void entry() {                                                   \
            MGB_SEREG_OPR_INTL_CALL_ADD_V2(                                     \
                    _cls, Impl::dump, wrap_loader, _converter, _version_min,    \
                    _version_max);                                              \
        }                                                                       \
    };                                                                          \
    }                                                                           \
    MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _OprRegV2##_cls)
//! Detect whether a type is complete (midout needs a complete type); this
//! primary template handles the incomplete case.
template <class T, class = void>
struct IsComplete : std::false_type {};