GitOrigin-RevId: 8eacd5e77c
release-1.10
@@ -153,4 +153,6 @@ struct EnsureHashConstexpr { | |||||
#define MGB_HASH_STR(v) \ | #define MGB_HASH_STR(v) \ | ||||
::mgb::EnsureHashConstexpr<::mgb::XXHash64CT::hash(v, sizeof(v), 20160701)>::val | ::mgb::EnsureHashConstexpr<::mgb::XXHash64CT::hash(v, sizeof(v), 20160701)>::val | ||||
//! runtime hash of a std::string; includes the terminating NUL byte so the
//! result matches MGB_HASH_STR applied to an equivalent string literal.
//! Fully qualified (::mgb::) so the macro also works when expanded outside
//! namespace mgb, consistent with MGB_HASH_STR above.
#define MGB_HASH_RUNTIME(v) \
    ::mgb::XXHash64CT::hash((v).c_str(), (v).size() + 1, 20160701)
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -52,6 +52,21 @@ mgb::cg::OperatorNodeBase* custom_loader( | |||||
} \ | } \ | ||||
MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls) | MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls) | ||||
//! register V2 (versioned) serialization for a custom opr class `cls` over the
//! version range [_version_min, _version_max]; dump/load are routed through the
//! generic custom_dumper/custom_loader and no opr converter is supplied
#define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max)     \
    namespace {                                                     \
    struct _OprRegV2##cls {                                         \
        static void entry() {                                       \
            MGB_SEREG_OPR_INTL_CALL_ADD_V2(                         \
                    cls, ::mgb::serialization::custom_dumper,       \
                    ::mgb::serialization::custom_loader, nullptr,   \
                    _version_min, _version_max);                    \
        }                                                           \
    };                                                              \
    }                                                               \
    MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(cls, _OprRegV2##cls)
using namespace mgb; | using namespace mgb; | ||||
using CustomOpNode = opr::CustomOpNode; | using CustomOpNode = opr::CustomOpNode; | ||||
CUSTOM_OP_SEREG_REG(CustomOpNode); | CUSTOM_OP_SEREG_REG(CustomOpNode); | ||||
CUSTOM_OP_SEREG_REG_V2(CustomOpNode, 2, CURRENT_VERSION); |
@@ -0,0 +1,228 @@ | |||||
#include "megbrain/graph/symbol_var.h" | |||||
#include "megdnn/oprs/general.h" | |||||
#if MGB_ENABLE_FBS_SERIALIZATION | |||||
#include "megbrain/opr/basic_arith.h" | |||||
#include "megbrain/opr/dnn/softmax.h" | |||||
#include "megbrain/serialization/oss_opr_load_dump.h" | |||||
#include "megbrain/serialization/sereg.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
#include "megdnn/oprs/nn.h" | |||||
namespace mgb { | |||||
namespace serialization { | |||||
//! V2 serialization for Softmax (param version 1).
//! Besides plain param dump/load, provides replace_opr() so a dumped Softmax
//! can be rewritten into primitive oprs for loaders that lack Softmax support.
template <>
struct OprLoadDumpImplV2<opr::Softmax, 1> {
    using Opr = opr::Softmax;
    using PersisParam = opr::Softmax::Param;

    //! dump: only the opr param (the softmax axis) is persisted
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
        ctx.write_param<PersisParam>(opr.cast_final_safe<Opr>().param());
    }

    //! expand Softmax into primitives along `axis`:
    //! softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))
    static cg::OperatorNodeBase* replace_opr(
            cg::OperatorNodeBase* opr, const VarNodeArray& inputs) {
        int32_t axis = opr->cast_final_safe<Opr>().param().axis;
        auto input_var = inputs[0];
        // subtracting the per-axis max before exp keeps the computation
        // numerically stable
        auto max_reduce_out =
                opr::Reduce::make(input_var, {megdnn::Reduce::Mode::MAX, axis});
        auto elemwise_sub_out = opr::Elemwise::make(
                {input_var, max_reduce_out}, {megdnn::Elemwise::Mode::SUB});
        auto elemwise_exp_out =
                opr::Elemwise::make({elemwise_sub_out}, {megdnn::Elemwise::Mode::EXP});
        auto sum_reduce_out =
                opr::Reduce::make(elemwise_exp_out, {megdnn::Reduce::Mode::SUM, axis});
        auto out = opr::Elemwise::make(
                {elemwise_exp_out, sum_reduce_out}, {megdnn::Elemwise::Mode::TRUE_DIV});
        return out.node()->owner_opr();
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
        // dump() wrote the param first, so it is read back at index 0
        auto param = fbs_ctx.read_param<PersisParam>(0);
        return Opr::make(inputs[0], param, config).node()->owner_opr();
    }
};
//! shared V2 load/dump implementation for oprs that carry an ExecutionPolicy
//! (convolution-like oprs). Maker0..Maker2 are tried in order inside make()
//! until one returns a non-null VarNode; the unused maker slots default to
//! MakeConvCallerEmpty (presumably a no-op maker — confirm at its definition).
template <
        class Opr, class Maker0, class MegDNNConv,
        class Maker1 = MakeConvCallerEmpty<MegDNNConv>,
        class Maker2 = MakeConvCallerEmpty<MegDNNConv>,
        typename ConvParam = megdnn::param::Convolution>
struct WithPolicyOprLoadDumpImpl {
    //! dump the conv param followed by the execution policy; load() relies on
    //! exactly this write order (param indices 0 and 1)
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<ConvParam>(opr.param());
        ctx.write_param<megdnn::param::ExecutionPolicy>(
                opr.execution_policy_transient());
    }

    //! build the opr through the first maker that handles the given inputs;
    //! asserts if none of the three makers accepts them
    static VarNode* make(
            const cg::VarNodeArray& inputs, const ConvParam& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        VarNode* ret =
                Maker0::template make<Opr>(inputs, param, execution_policy, config);
        if (!ret) {
            ret = Maker1::template make<Opr>(inputs, param, execution_policy, config);
        }
        if (!ret) {
            ret = Maker2::template make<Opr>(inputs, param, execution_policy, config);
        }
        mgb_assert(ret);
        return ret;
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
        auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
                fbs_ctx.get_current_opr_data());
        auto conv_param = fbs_ctx.read_param<ConvParam>(0);
        megdnn::param::ExecutionPolicy policy;
        // the execution policy is optional in the dumped model; fall back to a
        // default-constructed policy when it is absent
        if (fopr->additional_params() && fopr->additional_params()->size()) {
            policy = fbs_ctx.read_param<megdnn::param::ExecutionPolicy>(1);
        }
        return make(inputs, conv_param, policy, config)->owner_opr();
    }
};
//! V2 load/dump specializations for every convolution-like opr, all built on
//! WithPolicyOprLoadDumpImpl. The MakeConvCaller*/MakeLocalShareCaller* lists
//! select which input arities each opr accepts (tried in order at load time).
template <>
struct OprLoadDumpImplV2<opr::Convolution, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::Convolution, MakeConvCaller2<megdnn::Convolution>,
                  megdnn::Convolution> {};
template <>
struct OprLoadDumpImplV2<opr::ConvolutionBackwardData, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::ConvolutionBackwardData, MakeConvCaller2<megdnn::Convolution>,
                  megdnn::Convolution, MakeConvCaller3<megdnn::Convolution>> {};
template <>
struct OprLoadDumpImplV2<opr::ConvolutionBackwardFilter, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::ConvolutionBackwardFilter, MakeConvCaller3<megdnn::Convolution>,
                  megdnn::Convolution> {};
//! 3D convolutions use the Convolution3D param type instead of the default
template <>
struct OprLoadDumpImplV2<opr::Convolution3D, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::Convolution3D, MakeConvCaller2<megdnn::Convolution3D>,
                  megdnn::Convolution3D, MakeConvCallerEmpty<megdnn::Convolution3D>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImplV2<opr::Convolution3DBackwardData, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::Convolution3DBackwardData,
                  MakeConvCaller2<megdnn::Convolution3D>, megdnn::Convolution3D,
                  MakeConvCaller3<megdnn::Convolution3D>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImplV2<opr::Convolution3DBackwardFilter, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::Convolution3DBackwardFilter,
                  MakeConvCaller3<megdnn::Convolution3D>, megdnn::Convolution3D,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};
//! ConvBias accepts 2, 3 or 4 inputs (src/filter[/bias[/z]])
template <>
struct OprLoadDumpImplV2<opr::ConvBiasForward, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::ConvBiasForward, MakeConvCaller2<megdnn::ConvBiasForward>,
                  megdnn::ConvBiasForward, MakeConvCaller3<megdnn::ConvBiasForward>,
                  MakeConvCaller4<megdnn::ConvBiasForward>, megdnn::param::ConvBias> {};
template <>
struct OprLoadDumpImplV2<opr::BatchConvBiasForward, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::BatchConvBiasForward,
                  MakeConvCaller2<megdnn::BatchConvBiasForward>,
                  megdnn::BatchConvBiasForward,
                  MakeConvCaller3<megdnn::BatchConvBiasForward>,
                  MakeConvCaller4<megdnn::BatchConvBiasForward>,
                  megdnn::param::BatchConvBias> {};
template <>
struct OprLoadDumpImplV2<opr::LocalShare, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::LocalShare, MakeLocalShareCaller2<megdnn::LocalShare>,
                  megdnn::LocalShare, MakeLocalShareCallerEmpty<megdnn::LocalShare>,
                  MakeLocalShareCallerEmpty<megdnn::LocalShare>,
                  megdnn::param::LocalShare> {};
template <>
struct OprLoadDumpImplV2<opr::LocalShareBackwardData, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::LocalShareBackwardData,
                  MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare,
                  MakeLocalShareCallerEmpty<megdnn::LocalShare>,
                  MakeLocalShareCallerEmpty<megdnn::LocalShare>,
                  megdnn::param::LocalShare> {};
template <>
struct OprLoadDumpImplV2<opr::LocalShareBackwardFilter, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::LocalShareBackwardFilter,
                  MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare,
                  MakeLocalShareCallerEmpty<megdnn::LocalShare>,
                  MakeLocalShareCallerEmpty<megdnn::LocalShare>,
                  megdnn::param::LocalShare> {};
template <>
struct OprLoadDumpImplV2<opr::DeformableConvForward, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::DeformableConvForward,
                  MakeConvCaller4<megdnn::DeformableConvForward>, megdnn::Convolution> {
};
template <>
struct OprLoadDumpImplV2<opr::DeformableConvBackwardData, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::DeformableConvBackwardData,
                  MakeConvCaller5<megdnn::DeformableConvBackwardData>,
                  megdnn::Convolution> {};
template <>
struct OprLoadDumpImplV2<opr::DeformableConvBackwardFilter, 0>
        : public WithPolicyOprLoadDumpImpl<
                  opr::DeformableConvBackwardFilter,
                  MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
                  megdnn::Convolution> {};
} // namespace serialization | |||||
namespace opr { | |||||
//! helpers registering V2 serialization for the [VERSION_2, CURRENT_VERSION]
//! range; _converter (optional) rewrites the opr into compatible primitives
//! when an old-format graph is produced
#define SERGE_OPR_V2_CONVERTER(_cls, _arity, _converter) \
    MGB_SEREG_OPR_V2(_cls, _arity, _converter, VERSION_2, CURRENT_VERSION);
#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \
    MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION);

//! Softmax carries a converter so it can be expanded into primitive oprs
SERGE_OPR_V2_CONVERTER(
        Softmax, 1,
        (mgb::serialization::OprLoadDumpImplV2<opr::Softmax, 1>::replace_opr));
// NOTE(review): trailing semicolons made consistent across all invocations
// (several were missing; harmless since C++11 but inconsistent)
SERGE_OPR_V2_NO_CONVERTER(ConvBiasForward, 0);
SERGE_OPR_V2_NO_CONVERTER(BatchConvBiasForward, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution, 0);
SERGE_OPR_V2_NO_CONVERTER(ConvolutionBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(ConvolutionBackwardFilter, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution3D, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardFilter, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareForward, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardFilter, 0);
SERGE_OPR_V2_NO_CONVERTER(DeformableConvForward, 0);
SERGE_OPR_V2_NO_CONVERTER(DeformableConvBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(DeformableConvBackwardFilter, 0);

#undef SERGE_OPR_V2_CONVERTER
#undef SERGE_OPR_V2_NO_CONVERTER
} // namespace opr | |||||
} // namespace mgb | |||||
#endif | |||||
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,197 @@ | |||||
#if MGB_ENABLE_FBS_SERIALIZATION | |||||
#include "megbrain/comp_node_env.h" | |||||
#include "megbrain/opr/dnn/softmax.h" | |||||
#include "megbrain/opr/io.h" | |||||
#include "megbrain/serialization/oss_opr_load_dump.h" | |||||
#include "megbrain/serialization/sereg.h" | |||||
#include "megbrain/serialization/internal/mgb_cpp_opr_generated.h" | |||||
#include "megbrain/serialization/internal/schema_v2_generated.h" | |||||
namespace mgb { | |||||
namespace serialization { | |||||
//! V2 serialization for ImmutableTensor: the constant value is stored as an
//! anonymous tensor inside the Operator table rather than referenced by name
template <>
struct OprLoadDumpImplV2<opr::ImmutableTensor, 0> {
    using Opr = opr::ImmutableTensor;

    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        using Meth = OprDumpContext::TensorWriteMethod;
        auto&& opr = opr_.cast_final_safe<Opr>();
        // copy the device value to host (sync before the buffer is written);
        // VALUE_ANONYMOUS stores the tensor inline with no lookup name
        ctx.dump_tensor(
                {}, HostTensorND{}.copy_from(opr.value()).sync(),
                Meth::VALUE_ANONYMOUS);
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
        // ImmutableTensor has no graph inputs
        mgb_assert(inputs.empty());
        auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
                fbs_ctx.get_current_opr_data());
        if (fopr->tensors() && fopr->tensors()->size() > 0) {
            auto val = fbs_ctx.load_tensor();
            return Opr::make(fbs_ctx.graph(), *val, config).node()->owner_opr();
        } else {
            // a dumped ImmutableTensor must always carry its value
            mgb_throw(SerializationError, "ImmutableTensor load with no tensor data.");
        }
    }
};
//! V2 serialization for Host2DeviceCopy: dumps the opr param followed by the
//! host tensor, keyed by the opr name so the loader can rebind it as an input
template <>
struct OprLoadDumpImplV2<opr::Host2DeviceCopy, 0> {
    using Opr = opr::Host2DeviceCopy;

    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param(opr.param());
        using Meth = OprDumpContext::TensorWriteMethod;
        // VALUE_INPUT also persists the current tensor value as a default;
        // META_INPUT persists only shape/dtype metadata
        ctx.dump_tensor(
                opr.name(), *opr.host_data(),
                opr.param().dump_default_value ? Meth::VALUE_INPUT : Meth::META_INPUT);
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        mgb_assert(inputs.empty());
        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
        // read back in dump order: param first (index 0), then the tensor
        auto param = fbs_ctx.read_param<Opr::Param>(0);
        auto tensor = fbs_ctx.load_tensor();
        return Opr::make(fbs_ctx.graph(), tensor, param, config).node()->owner_opr();
    }
};
//! V2 serialization for SharedDeviceTensorWithFormat; the device value is
//! round-tripped through an anonymous host tensor
template <>
struct OprLoadDumpImplV2<opr::SharedDeviceTensorWithFormat, 0> {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        using Meth = OprDumpContext::TensorWriteMethod;
        auto&& opr = opr_.cast_final_safe<opr::SharedDeviceTensorWithFormat>();
        HostTensorND val;
        // synchronize so the host copy is complete before serializing
        val.copy_from(opr.get_dev_tensor()).sync();
        ctx.dump_tensor({}, val, Meth::VALUE_ANONYMOUS);
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        mgb_assert(inputs.empty());
        auto val = ctx.load_tensor();
        // allocate a device tensor with the same layout (preserving format)
        // and copy the loaded host value into it
        auto dev_val =
                std::make_shared<DeviceTensorND>(val->comp_node(), val->layout());
        dev_val->copy_from_fixlayout(*val);
        auto out_var =
                opr::SharedDeviceTensorWithFormat::make(ctx.graph(), dev_val, config);
        // wait for the async host->device copy issued above to finish
        dev_val->sync();
        return out_var.node()->owner_opr();
    }
};
//! V2 serialization for MultipleDeviceTensorHolder: each held value is dumped
//! as a named shared tensor; the count is recovered from the Operator table
template <>
struct OprLoadDumpImplV2<opr::MultipleDeviceTensorHolder, 0> {
    using Opr = opr::MultipleDeviceTensorHolder;

    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        using Meth = OprDumpContext::TensorWriteMethod;
        auto&& opr = opr_.cast_final_safe<Opr>();
        uint32_t nr_val = opr.values().size();
        for (uint32_t i = 0; i < nr_val; ++i) {
            HostTensorND val;
            val.copy_from(*opr.values()[i]).sync();
            // VALUE_SHARED: keyed by the output var name so loaders can share
            // the underlying storage across oprs
            ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED);
        }
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        mgb_assert(inputs.empty());
        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
        // the number of tensors to load comes from the flatbuffers record
        auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
                fbs_ctx.get_current_opr_data());
        uint32_t nr = 0;
        if (fopr && fopr->tensors()) {
            nr = fopr->tensors()->size();
        }
        Opr::ValueArray values(nr);
        for (auto&& i : values) {
            i = ctx.load_tensor_shared();
        }
        return Opr::make(ctx.graph(), std::move(values), config)[0].node()->owner_opr();
    }
};
//! V2 serialization for MultipleDeviceTensorWithFormatHolder: like
//! MultipleDeviceTensorHolder, but loaded tensors must be re-wrapped so their
//! (non-default) tensor format survives the round trip
template <>
struct OprLoadDumpImplV2<opr::MultipleDeviceTensorWithFormatHolder, 0> {
    using Opr = opr::MultipleDeviceTensorWithFormatHolder;

    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        using Meth = OprDumpContext::TensorWriteMethod;
        auto&& opr = opr_.cast_final_safe<Opr>();
        uint32_t nr_val = opr.values().size();
        for (uint32_t i = 0; i < nr_val; ++i) {
            HostTensorND val;
            auto value = *opr.values()[i];
            val.copy_from(value).sync();
            // VALUE_SHARED: keyed by the output var name for storage sharing
            ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED);
        }
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        mgb_assert(inputs.empty());
        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
        // the number of tensors to load comes from the flatbuffers record
        auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
                fbs_ctx.get_current_opr_data());
        uint32_t nr = 0;
        if (fopr && fopr->tensors()) {
            nr = fopr->tensors()->size();
        }
        Opr::ValueArray values(nr);
        for (auto&& i : values) {
            i = ctx.load_tensor_shared();
            //! set tensor format
            TensorLayout layout_with_format = i->layout();

            if (i->storage().comp_node().mem_node() ==
                CompNode::default_cpu().mem_node()) {
                mgb_assert(
                        i->storage().ptr(),
                        "storage should not be nullptr if mem_node is "
                        "default_cpu");
                // on (default) CPU the value can be materialized directly with
                // the format-carrying layout and proxied back as device data
                HostTensorND src{i->storage().comp_node(), layout_with_format};
                src.copy_from_fixlayout(*i).sync();
                *i = DeviceTensorND::make_proxy(src);
            } else {
                //! actually only layout of this tensor will be used later, see
                //! src/serialization/impl/batched_device_value_loader.cpp:49. But we
                //! have no way to reset layout only, so just construct a invalid
                //! storage instead
                auto size = layout_with_format.span().dist_byte();
                DeviceTensorStorage storage;
                storage.reset(i->comp_node(), size, nullptr);
                i->reset(storage, layout_with_format);
            }
        }
        return Opr::make(ctx.graph(), std::move(values), config)[0].node()->owner_opr();
    }
};
} // namespace serialization | |||||
namespace opr { | |||||
//! register V2 serialization for the I/O oprs specialized above; none of them
//! needs an opr converter
#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \
    MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION);

SERGE_OPR_V2_NO_CONVERTER(ImmutableTensor, 0);
SERGE_OPR_V2_NO_CONVERTER(Host2DeviceCopy, 0);
SERGE_OPR_V2_NO_CONVERTER(SharedDeviceTensorWithFormat, 0);
SERGE_OPR_V2_NO_CONVERTER(MultipleDeviceTensorWithFormatHolder, 0);
SERGE_OPR_V2_NO_CONVERTER(MultipleDeviceTensorHolder, 0);

//! undef so the helper macro does not leak out of this header (consistent
//! with dnn.sereg.v2.h, which undefs its registration helpers)
#undef SERGE_OPR_V2_NO_CONVERTER
} // namespace opr | |||||
} // namespace mgb | |||||
#endif | |||||
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -135,6 +135,16 @@ void LoopSerializer::reg_all() { | |||||
MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker); | MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker); | ||||
MGB_SEREG_OPR_INTL_CALL_ADD( | MGB_SEREG_OPR_INTL_CALL_ADD( | ||||
CounterProvider, dump_counter_provider, load_counter_provider); | CounterProvider, dump_counter_provider, load_counter_provider); | ||||
MGB_SEREG_OPR_INTL_CALL_ADD_V2( | |||||
opr::Loop, dump_loop, load_loop, nullptr, 2, | |||||
CURRENT_VERSION); | |||||
MGB_SEREG_OPR_INTL_CALL_ADD_V2( | |||||
InputMaker, dump_input_maker, load_input_maker, nullptr, 2, | |||||
CURRENT_VERSION); | |||||
MGB_SEREG_OPR_INTL_CALL_ADD_V2( | |||||
CounterProvider, dump_counter_provider, load_counter_provider, nullptr, 2, | |||||
CURRENT_VERSION); | |||||
} | } | ||||
void LoopSerializer::dump_loop(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { | void LoopSerializer::dump_loop(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { | ||||
@@ -20,6 +20,11 @@ struct StaticData { | |||||
ThinHashMap<Typeinfo*, OprRegistry*> type2reg; | ThinHashMap<Typeinfo*, OprRegistry*> type2reg; | ||||
std::unordered_map<std::string, OprRegistry*> name2reg; | std::unordered_map<std::string, OprRegistry*> name2reg; | ||||
ThinHashMap<size_t, OprRegistry*> unversioned_id2reg; | ThinHashMap<size_t, OprRegistry*> unversioned_id2reg; | ||||
//! versioned OprRegistryV2, version_id_reg_map is used for Operator | |||||
//! load/shallow copy and version_type_reg_map is used for Operator dump | |||||
ThinHashMap<uint8_t, ThinHashMap<size_t, OprRegistryV2>> version_id_reg_map; | |||||
ThinHashMap<uint8_t, ThinHashMap<Typeinfo*, OprRegistryV2*>> version_type_reg_map; | |||||
}; | }; | ||||
StaticData& static_data() { | StaticData& static_data() { | ||||
@@ -47,6 +52,20 @@ const OprRegistry* dynamic_registry() { | |||||
return ret; | return ret; | ||||
} | } | ||||
//! registry entry shared by all dynamically-loaded oprs; created on first use.
//! NOTE: initialization now uses a function-local static (C++11 "magic
//! static"), which the compiler guards, so the one-time registration is
//! thread-safe; the previous check-then-assign on a plain static pointer
//! could race if two threads entered concurrently.
const OprRegistryV2* dynamic_registry_v2() {
    static const OprRegistryV2* ret = []() {
        auto id = MGB_HASH_STR("dynamic");
        OprRegistryV2::versioned_add(
                {nullptr, id, {}, {}, dynamic_loader, {}}, CURRENT_VERSION,
                CURRENT_VERSION);
        auto reg = OprRegistryV2::versioned_find_by_id(id, CURRENT_VERSION);
        mgb_assert(reg);
        return reg;
    }();
    return ret;
}
class _Init { | class _Init { | ||||
public: | public: | ||||
_Init() { | _Init() { | ||||
@@ -63,8 +82,7 @@ void OprRegistry::add(const OprRegistry& record) { | |||||
auto registry_ins = sd.id2reg.emplace(persist_id, record); | auto registry_ins = sd.id2reg.emplace(persist_id, record); | ||||
mgb_assert( | mgb_assert( | ||||
registry_ins.second || persist_id == dynamic_registry()->persist_type_id, | registry_ins.second || persist_id == dynamic_registry()->persist_type_id, | ||||
"duplicated operator persist_type_id: %s", | |||||
std::to_string(persist_id).c_str()); | |||||
"duplicated operator name : %s", record.name.c_str()); | |||||
OprRegistry* persis_record_ptr; | OprRegistry* persis_record_ptr; | ||||
if (registry_ins.second) { | if (registry_ins.second) { | ||||
@@ -129,6 +147,73 @@ const OprRegistry* OprRegistry::find_by_unversioned_id(size_t unversioned_id) { | |||||
return iter == uid2reg.end() ? nullptr : iter->second; | return iter == uid2reg.end() ? nullptr : iter->second; | ||||
} | } | ||||
//! find the registry equal to the giving version | |||||
const OprRegistryV2* OprRegistryV2::versioned_find_by_id( | |||||
const size_t id, uint8_t version) { | |||||
auto&& id_reg_map = static_data().version_id_reg_map; | |||||
auto iter_version = id_reg_map.find(version); | |||||
if (iter_version != id_reg_map.end()) { | |||||
auto iter = iter_version->second.find(id); | |||||
return iter == iter_version->second.end() ? nullptr : &iter->second; | |||||
} | |||||
return nullptr; | |||||
} | |||||
//! find the registry equal or below the giving version | |||||
const OprRegistryV2* OprRegistryV2::versioned_find_by_typeinfo( | |||||
Typeinfo* type, uint8_t version) { | |||||
const auto& type_reg_map = static_data().version_type_reg_map; | |||||
for (int version_id = version; version_id > 0; version_id--) { | |||||
auto iter_version = type_reg_map.find(version_id); | |||||
if (iter_version != type_reg_map.end()) { | |||||
auto iter = iter_version->second.find(type); | |||||
if (iter == iter_version->second.end()) { | |||||
continue; | |||||
} else { | |||||
return iter->second; | |||||
} | |||||
} | |||||
} | |||||
return nullptr; | |||||
} | |||||
void OprRegistryV2::versioned_add( | |||||
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version) { | |||||
mgb_assert(max_version >= min_version); | |||||
auto&& sd = static_data(); | |||||
auto id = record.type_id; | |||||
uint64_t type_id = id; | |||||
//! record.type->name is nullptr when MGB_VERBOSE_TYPEINFO_NAME==0 | |||||
if (record.type && record.type->name) { | |||||
type_id = MGB_HASH_RUNTIME(std::string(record.type->name)); | |||||
} | |||||
for (uint8_t version = min_version; version <= max_version; version++) { | |||||
auto&& registry_map = sd.version_id_reg_map[version]; | |||||
auto versioned_record = record; | |||||
versioned_record.version = version; | |||||
mgb_assert( | |||||
registry_map.find(id) == registry_map.end() || | |||||
id == dynamic_registry_v2()->type_id, | |||||
"dduplicated OprRegistryV2 of %s\n", record.name.c_str()); | |||||
auto registry_ins = registry_map.emplace(id, versioned_record); | |||||
if (!registry_ins.second) { | |||||
//! the registry is dynamic | |||||
mgb_assert(!record.converter); | |||||
registry_map[id] = versioned_record; | |||||
} | |||||
//! sometimes the register id and the hash typeinfo is not same, just as | |||||
//! dynamic Operator | |||||
if (id != type_id) { | |||||
mgb_assert( | |||||
registry_map.find(type_id) == registry_map.end(), | |||||
"dduplicated OprRegistryV2 of %s\n", record.name.c_str()); | |||||
registry_map.emplace(type_id, versioned_record); | |||||
} | |||||
auto&& registry_type_map = sd.version_type_reg_map[version]; | |||||
registry_type_map.emplace(record.type, ®istry_map[id]); | |||||
} | |||||
} | |||||
void OprRegistry::add_using_dynamic_loader( | void OprRegistry::add_using_dynamic_loader( | ||||
Typeinfo* type, const std::string& name, const OprDumper& dumper) { | Typeinfo* type, const std::string& name, const OprDumper& dumper) { | ||||
// dynamic oprs are implemented by mapping different opr types to the same | // dynamic oprs are implemented by mapping different opr types to the same | ||||
@@ -140,6 +225,11 @@ void OprRegistry::add_using_dynamic_loader( | |||||
{}, | {}, | ||||
{}, | {}, | ||||
dynamic_registry()->unversioned_type_id}); | dynamic_registry()->unversioned_type_id}); | ||||
mgb_assert(type, "type must be not nullptr"); | |||||
OprRegistryV2::versioned_add( | |||||
{type, dynamic_registry_v2()->type_id, type->name, dumper, | |||||
dynamic_registry_v2()->loader, nullptr}, | |||||
CURRENT_VERSION, CURRENT_VERSION); | |||||
} | } | ||||
#if MGB_ENABLE_DEBUG_UTIL | #if MGB_ENABLE_DEBUG_UTIL | ||||
@@ -9,10 +9,12 @@ void call_sereg() {} | |||||
#include "../../opr/impl/blas.sereg.h" | #include "../../opr/impl/blas.sereg.h" | ||||
#include "../../opr/impl/cond.sereg.h" | #include "../../opr/impl/cond.sereg.h" | ||||
#include "../../opr/impl/dnn/dnn.sereg.h" | #include "../../opr/impl/dnn/dnn.sereg.h" | ||||
#include "../../opr/impl/dnn/dnn.sereg.v2.h" | |||||
#include "./extern_c_opr.sereg.h" | #include "./extern_c_opr.sereg.h" | ||||
#include "../../opr/impl/imgproc.sereg.h" | #include "../../opr/impl/imgproc.sereg.h" | ||||
#include "../../opr/impl/indexing.sereg.h" | #include "../../opr/impl/indexing.sereg.h" | ||||
#include "../../opr/impl/io.sereg.h" | #include "../../opr/impl/io.sereg.h" | ||||
#include "../../opr/impl/io.sereg.v2.h" | |||||
#include "../../opr/impl/loop/forward.sereg.h" | #include "../../opr/impl/loop/forward.sereg.h" | ||||
#include "../../opr/impl/loop/grad.sereg.h" | #include "../../opr/impl/loop/grad.sereg.h" | ||||
#include "../../opr/impl/misc.sereg.h" | #include "../../opr/impl/misc.sereg.h" | ||||
@@ -53,7 +53,6 @@ struct OprRegistry { | |||||
uint64_t unversioned_type_id; | uint64_t unversioned_type_id; | ||||
MGE_WIN_DECLSPEC_FUC static void add(const OprRegistry& record); | MGE_WIN_DECLSPEC_FUC static void add(const OprRegistry& record); | ||||
/*! | /*! | ||||
* \brief register an operator to use dynamic loader | * \brief register an operator to use dynamic loader | ||||
* | * | ||||
@@ -89,6 +88,39 @@ struct OprRegistry { | |||||
#endif | #endif | ||||
}; | }; | ||||
//! Convert some modified Opr to a compatible Opr (used when loading a graph
//! dumped with a newer opr definition)
using OprConvertToCompatible = thin_function<cg::OperatorNodeBase*(
        cg::OperatorNodeBase*, const VarNodeArray&)>;

//! record of a single operator in the versioned (V2) serialization registry
struct OprRegistryV2 {
    //! typeinfo of the operator; may be nullptr for the shared dynamic entry
    Typeinfo* type;
    //! persistent id used as the lookup key at load time
    uint64_t type_id;
    //! human-readable operator name
    std::string name;
    //! callback that serializes the operator
    OprDumper dumper;
    //! callback that deserializes the operator
    OprLoaderWrapper loader;
    //! optional rewrite to a compatible operator; empty if unused
    OprConvertToCompatible converter;
    //! serialization version this record was registered for
    uint8_t version = 2;
    MGE_WIN_DECLSPEC_FUC uint8_t get_version() const { return version; }

    //! register opr load/dump to version2regmap for every version in
    //! [min_version, max_version]
    MGE_WIN_DECLSPEC_FUC static void versioned_add(
            const OprRegistryV2& record, uint8_t min_version, uint8_t max_version);

    //! find the registry registered exactly under the given version
    MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_id(
            const size_t id, uint8_t version);

    //! find the registry with version equal to or below the given version
    MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_typeinfo(
            Typeinfo* type, uint8_t version);

#if MGB_ENABLE_DEBUG_UTIL
    //! dump registered oprs
    MGE_WIN_DECLSPEC_FUC static std::vector<std::pair<size_t, std::string>>
    dump_registries();
#endif
};
} // namespace serialization | } // namespace serialization | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -3,6 +3,7 @@ | |||||
#include "megbrain/serialization/opr_load_dump.h" | #include "megbrain/serialization/opr_load_dump.h" | ||||
#include "megbrain/serialization/opr_registry.h" | #include "megbrain/serialization/opr_registry.h" | ||||
#include "megbrain/serialization/opr_shallow_copy.h" | #include "megbrain/serialization/opr_shallow_copy.h" | ||||
#include "megbrain/serialization/oss_opr_load_dump.h" | |||||
#include "megbrain/utils/hash_ct.h" | #include "megbrain/utils/hash_ct.h" | ||||
namespace mgb { | namespace mgb { | ||||
@@ -66,6 +67,9 @@ struct OprLoadDumpImpl { | |||||
} | } | ||||
}; | }; | ||||
//! V2 (flatbuffers) load/dump implementation; defaults to the V1 behavior and
//! is specialized for oprs whose V2 format differs
template <class Opr, size_t arity>
struct OprLoadDumpImplV2 : public OprLoadDumpImpl<Opr, arity> {};
#define IMPL_OPR_MAKER(_arity, _args...) \ | #define IMPL_OPR_MAKER(_arity, _args...) \ | ||||
template <class Opr> \ | template <class Opr> \ | ||||
struct OprMaker<Opr, _arity> { \ | struct OprMaker<Opr, _arity> { \ | ||||
@@ -124,6 +128,12 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||||
__caller_OprReg##_cls##_ins; \ | __caller_OprReg##_cls##_ins; \ | ||||
} | } | ||||
//! define a file-scope registry-caller instance whose construction invokes
//! _impl::entry() at static-init time, registering the V2 serializer for _cls
#define MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _impl)                    \
    namespace {                                                          \
    [[gnu::unused]] ::mgb::serialization::OprRegistryCaller<_cls, _impl> \
            __caller_V2_OprReg##_cls##_ins;                              \
    }
// Trim the terminating null character and a "V0" like suffix from the string | // Trim the terminating null character and a "V0" like suffix from the string | ||||
// then hash it. | // then hash it. | ||||
// TODO: Get rid of this. | // TODO: Get rid of this. | ||||
@@ -138,17 +148,35 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||||
: 0), \ | : 0), \ | ||||
20160701)>::val | 20160701)>::val | ||||
//! call OprRegistry::add | |||||
#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \ | |||||
do { \ | |||||
::mgb::serialization::OprRegistry::add( \ | |||||
{_cls::typeinfo(), \ | |||||
MGB_HASH_STR(#_cls), \ | |||||
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \ | |||||
_dump, \ | |||||
_load, \ | |||||
{}, \ | |||||
MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \ | |||||
//! call OprRegistry::add for the old serialization format, and
//! OprRegistryV2::versioned_add for the new serialization format, which is
//! compatible with the old one; converter is nullptr since this registry is
//! only for version 1
#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \ | |||||
do { \ | |||||
::mgb::serialization::OprRegistry::add( \ | |||||
{_cls::typeinfo(), \ | |||||
MGB_HASH_STR(#_cls), \ | |||||
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \ | |||||
_dump, \ | |||||
_load, \ | |||||
{}, \ | |||||
MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \ | |||||
::mgb::serialization::OprRegistryV2::versioned_add( \ | |||||
{_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \ | |||||
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \ | |||||
::mgb::VERSION_1, ::mgb::VERSION_1); \ | |||||
} while (0) | |||||
//! call OprRegistryV2::versioned_add for the new serialization format; the
//! _convert function, if non-null, rewrites the Operator into a compatible one
#define MGB_SEREG_OPR_INTL_CALL_ADD_V2(                                       \
        _cls, _dump, _load, _convert, _version_min, _version_max)             \
    do {                                                                      \
        ::mgb::serialization::OprRegistryV2::versioned_add(                   \
                {_cls::typeinfo(), MGB_HASH_STR(#_cls),                       \
                 _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, _convert}, \
                _version_min, _version_max);                                  \
    } while (0)
/*! | /*! | ||||
@@ -171,6 +199,27 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||||
} \ | } \ | ||||
MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) | MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) | ||||
//! register V2 serialization for _cls over [_version_min, _version_max].
//! New dump/load functions should be implemented in OprLoadDumpImplV2;
//! _converter is optional — pass nullptr if it is not implemented.
#define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max) \
    namespace {                                                                \
    namespace ser = ::mgb::serialization;                                      \
    struct _OprRegV2##_cls {                                                   \
        using Impl = ser::OprLoadDumpImplV2<_cls, _arity>;                     \
        static ser::OprWithOutputAccessor wrap_loader(                         \
                ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
                const mgb::cg::OperatorNodeConfig& config) {                   \
            return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \
        }                                                                      \
        static void entry() {                                                  \
            MGB_SEREG_OPR_INTL_CALL_ADD_V2(                                    \
                    _cls, Impl::dump, wrap_loader, _converter, _version_min,   \
                    _version_max);                                             \
        }                                                                      \
    };                                                                         \
    }                                                                          \
    MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _OprRegV2##_cls)
//! use to check type is complete or not, midout need a complete type | //! use to check type is complete or not, midout need a complete type | ||||
template <class T, class = void> | template <class T, class = void> | ||||
struct IsComplete : std::false_type {}; | struct IsComplete : std::false_type {}; | ||||