
feat(serialization): support the registry for the new serialization format

GitOrigin-RevId: 8eacd5e77c
release-1.10
Megvii Engine Team, 3 years ago
parent commit 50faabf614
9 changed files with 639 additions and 14 deletions
1. src/core/include/megbrain/utils/hash_ct.h (+2, -0)
2. src/opr/impl/custom_opnode.sereg.h (+15, -0)
3. src/opr/impl/dnn/dnn.sereg.v2.h (+228, -0)
4. src/opr/impl/io.sereg.v2.h (+197, -0)
5. src/opr/impl/loop/forward_sereg.cpp (+10, -0)
6. src/serialization/impl/opr_registry.cpp (+92, -2)
7. src/serialization/impl/sereg_caller.cpp (+2, -0)
8. src/serialization/include/megbrain/serialization/opr_registry.h (+33, -1)
9. src/serialization/include/megbrain/serialization/sereg.h (+60, -11)
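
For orientation, a minimal sketch of how an operator opts into the new format with the pieces added in this commit: specialize OprLoadDumpImplV2 and register it through MGB_SEREG_OPR_V2. MyOpr, its Param, and its single-input make() are hypothetical; the includes mirror the new .sereg.v2.h files below.

    #include "megbrain/serialization/oss_opr_load_dump.h"
    #include "megbrain/serialization/sereg.h"

    namespace mgb {
    namespace serialization {

    // sketch only: new-format dump/load for a hypothetical single-input opr::MyOpr
    template <>
    struct OprLoadDumpImplV2<opr::MyOpr, 1> {
        using Param = opr::MyOpr::Param;
        static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
            ctx.write_param<Param>(opr.cast_final_safe<opr::MyOpr>().param());
        }
        static cg::OperatorNodeBase* load(
                OprLoadContext& ctx, const cg::VarNodeArray& inputs,
                const OperatorNodeConfig& config) {
            auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
            auto param = fbs_ctx.read_param<Param>(0);
            return opr::MyOpr::make(inputs[0], param, config).node()->owner_opr();
        }
    };

    }  // namespace serialization

    namespace opr {
    // register for the new format only, valid from VERSION_2 up to CURRENT_VERSION,
    // with no compatibility converter
    MGB_SEREG_OPR_V2(MyOpr, 1, nullptr, VERSION_2, CURRENT_VERSION);
    }  // namespace opr
    }  // namespace mgb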

src/core/include/megbrain/utils/hash_ct.h (+2, -0)

@@ -153,4 +153,6 @@ struct EnsureHashConstexpr {
#define MGB_HASH_STR(v) \
::mgb::EnsureHashConstexpr<::mgb::XXHash64CT::hash(v, sizeof(v), 20160701)>::val

#define MGB_HASH_RUNTIME(v) XXHash64CT::hash((v).c_str(), (v).size() + 1, 20160701)

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
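
The added MGB_HASH_RUNTIME is meant to hash a runtime std::string to the same value that MGB_HASH_STR yields for the matching string literal: both cover the terminating null byte (sizeof(v) versus size() + 1) and use the same seed, which is presumably what lets opr_registry.cpp hash record.type->name at runtime and compare it with compile-time ids. A small hedged sanity check, assuming only what this header declares:

    #include <cassert>
    #include <string>

    #include "megbrain/utils/hash_ct.h"

    using namespace mgb;

    void check_hash_consistency() {
        // compile-time hash of the literal; sizeof("Softmax") counts the trailing '\0'
        constexpr auto compile_time = MGB_HASH_STR("Softmax");
        // runtime hash of the equivalent std::string; size() + 1 covers the same '\0'
        auto runtime = MGB_HASH_RUNTIME(std::string("Softmax"));
        assert(compile_time == runtime);
    }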

src/opr/impl/custom_opnode.sereg.h (+15, -0)

@@ -52,6 +52,21 @@ mgb::cg::OperatorNodeBase* custom_loader(
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls)

#define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max) \
namespace { \
struct _OprRegV2##cls { \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD_V2( \
cls, ::mgb::serialization::custom_dumper, \
::mgb::serialization::custom_loader, nullptr, _version_min, \
_version_max); \
} \
}; \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(cls, _OprRegV2##cls)

using namespace mgb;
using CustomOpNode = opr::CustomOpNode;
CUSTOM_OP_SEREG_REG(CustomOpNode);

CUSTOM_OP_SEREG_REG_V2(CustomOpNode, 2, CURRENT_VERSION);

src/opr/impl/dnn/dnn.sereg.v2.h (+228, -0)

@@ -0,0 +1,228 @@
#include "megbrain/graph/symbol_var.h"
#include "megdnn/oprs/general.h"
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/softmax.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/nn.h"

namespace mgb {
namespace serialization {

template <>
struct OprLoadDumpImplV2<opr::Softmax, 1> {
using Opr = opr::Softmax;
using PersisParam = opr::Softmax::Param;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
ctx.write_param<PersisParam>(opr.cast_final_safe<Opr>().param());
}

static cg::OperatorNodeBase* replace_opr(
cg::OperatorNodeBase* opr, const VarNodeArray& inputs) {
int32_t axis = opr->cast_final_safe<Opr>().param().axis;
auto input_var = inputs[0];
auto max_reduce_out =
opr::Reduce::make(input_var, {megdnn::Reduce::Mode::MAX, axis});
auto elemwise_sub_out = opr::Elemwise::make(
{input_var, max_reduce_out}, {megdnn::Elemwise::Mode::SUB});
auto elemwise_exp_out =
opr::Elemwise::make({elemwise_sub_out}, {megdnn::Elemwise::Mode::EXP});
auto sum_reduce_out =
opr::Reduce::make(elemwise_exp_out, {megdnn::Reduce::Mode::SUM, axis});
auto out = opr::Elemwise::make(
{elemwise_exp_out, sum_reduce_out}, {megdnn::Elemwise::Mode::TRUE_DIV});
return out.node()->owner_opr();
}

static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto param = fbs_ctx.read_param<PersisParam>(0);
return Opr::make(inputs[0], param, config).node()->owner_opr();
}
};

template <
class Opr, class Maker0, class MegDNNConv,
class Maker1 = MakeConvCallerEmpty<MegDNNConv>,
class Maker2 = MakeConvCallerEmpty<MegDNNConv>,
typename ConvParam = megdnn::param::Convolution>
struct WithPolicyOprLoadDumpImpl {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param<ConvParam>(opr.param());
ctx.write_param<megdnn::param::ExecutionPolicy>(
opr.execution_policy_transient());
}
static VarNode* make(
const cg::VarNodeArray& inputs, const ConvParam& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
VarNode* ret =
Maker0::template make<Opr>(inputs, param, execution_policy, config);
if (!ret) {
ret = Maker1::template make<Opr>(inputs, param, execution_policy, config);
}
if (!ret) {
ret = Maker2::template make<Opr>(inputs, param, execution_policy, config);
}
mgb_assert(ret);
return ret;
}

static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
fbs_ctx.get_current_opr_data());
auto conv_param = fbs_ctx.read_param<ConvParam>(0);
megdnn::param::ExecutionPolicy policy;
if (fopr->additional_params() && fopr->additional_params()->size()) {
policy = fbs_ctx.read_param<megdnn::param::ExecutionPolicy>(1);
}
return make(inputs, conv_param, policy, config)->owner_opr();
}
};

template <>
struct OprLoadDumpImplV2<opr::Convolution, 0>
: public WithPolicyOprLoadDumpImpl<
opr::Convolution, MakeConvCaller2<megdnn::Convolution>,
megdnn::Convolution> {};
template <>
struct OprLoadDumpImplV2<opr::ConvolutionBackwardData, 0>
: public WithPolicyOprLoadDumpImpl<
opr::ConvolutionBackwardData, MakeConvCaller2<megdnn::Convolution>,
megdnn::Convolution, MakeConvCaller3<megdnn::Convolution>> {};
template <>
struct OprLoadDumpImplV2<opr::ConvolutionBackwardFilter, 0>
: public WithPolicyOprLoadDumpImpl<
opr::ConvolutionBackwardFilter, MakeConvCaller3<megdnn::Convolution>,
megdnn::Convolution> {};

template <>
struct OprLoadDumpImplV2<opr::Convolution3D, 0>
: public WithPolicyOprLoadDumpImpl<
opr::Convolution3D, MakeConvCaller2<megdnn::Convolution3D>,
megdnn::Convolution3D, MakeConvCallerEmpty<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImplV2<opr::Convolution3DBackwardData, 0>
: public WithPolicyOprLoadDumpImpl<
opr::Convolution3DBackwardData,
MakeConvCaller2<megdnn::Convolution3D>, megdnn::Convolution3D,
MakeConvCaller3<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImplV2<opr::Convolution3DBackwardFilter, 0>
: public WithPolicyOprLoadDumpImpl<
opr::Convolution3DBackwardFilter,
MakeConvCaller3<megdnn::Convolution3D>, megdnn::Convolution3D,
MakeConvCallerEmpty<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImplV2<opr::ConvBiasForward, 0>
: public WithPolicyOprLoadDumpImpl<
opr::ConvBiasForward, MakeConvCaller2<megdnn::ConvBiasForward>,
megdnn::ConvBiasForward, MakeConvCaller3<megdnn::ConvBiasForward>,
MakeConvCaller4<megdnn::ConvBiasForward>, megdnn::param::ConvBias> {};
template <>
struct OprLoadDumpImplV2<opr::BatchConvBiasForward, 0>
: public WithPolicyOprLoadDumpImpl<
opr::BatchConvBiasForward,
MakeConvCaller2<megdnn::BatchConvBiasForward>,
megdnn::BatchConvBiasForward,
MakeConvCaller3<megdnn::BatchConvBiasForward>,
MakeConvCaller4<megdnn::BatchConvBiasForward>,
megdnn::param::BatchConvBias> {};

template <>
struct OprLoadDumpImplV2<opr::LocalShare, 0>
: public WithPolicyOprLoadDumpImpl<
opr::LocalShare, MakeLocalShareCaller2<megdnn::LocalShare>,
megdnn::LocalShare, MakeLocalShareCallerEmpty<megdnn::LocalShare>,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
megdnn::param::LocalShare> {};
template <>
struct OprLoadDumpImplV2<opr::LocalShareBackwardData, 0>
: public WithPolicyOprLoadDumpImpl<
opr::LocalShareBackwardData,
MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
megdnn::param::LocalShare> {};
template <>
struct OprLoadDumpImplV2<opr::LocalShareBackwardFilter, 0>
: public WithPolicyOprLoadDumpImpl<
opr::LocalShareBackwardFilter,
MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
megdnn::param::LocalShare> {};
template <>
struct OprLoadDumpImplV2<opr::DeformableConvForward, 0>
: public WithPolicyOprLoadDumpImpl<
opr::DeformableConvForward,
MakeConvCaller4<megdnn::DeformableConvForward>, megdnn::Convolution> {
};
template <>
struct OprLoadDumpImplV2<opr::DeformableConvBackwardData, 0>
: public WithPolicyOprLoadDumpImpl<
opr::DeformableConvBackwardData,
MakeConvCaller5<megdnn::DeformableConvBackwardData>,
megdnn::Convolution> {};
template <>
struct OprLoadDumpImplV2<opr::DeformableConvBackwardFilter, 0>
: public WithPolicyOprLoadDumpImpl<
opr::DeformableConvBackwardFilter,
MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
megdnn::Convolution> {};

} // namespace serialization

namespace opr {
#define SERGE_OPR_V2_CONVERTER(_cls, _arity, _converter) \
MGB_SEREG_OPR_V2(_cls, _arity, _converter, VERSION_2, CURRENT_VERSION);

#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \
MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION);

SERGE_OPR_V2_CONVERTER(
Softmax, 1,
(mgb::serialization::OprLoadDumpImplV2<opr::Softmax, 1>::replace_opr));

SERGE_OPR_V2_NO_CONVERTER(ConvBiasForward, 0)
SERGE_OPR_V2_NO_CONVERTER(BatchConvBiasForward, 0);

SERGE_OPR_V2_NO_CONVERTER(Convolution, 0)
SERGE_OPR_V2_NO_CONVERTER(ConvolutionBackwardData, 0)
SERGE_OPR_V2_NO_CONVERTER(ConvolutionBackwardFilter, 0)

SERGE_OPR_V2_NO_CONVERTER(Convolution3D, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardFilter, 0);

SERGE_OPR_V2_NO_CONVERTER(LocalShareForward, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardFilter, 0);

SERGE_OPR_V2_NO_CONVERTER(DeformableConvForward, 0);
SERGE_OPR_V2_NO_CONVERTER(DeformableConvBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(DeformableConvBackwardFilter, 0);

#undef SERGE_OPR_V2_CONVERTER
#undef SERGE_OPR_V2_NO_CONVERTER
} // namespace opr

} // namespace mgb

#endif

// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
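
The converter registered above for Softmax rewrites the operator into primitives, effectively computing exp(x - max(x)) / sum(exp(x - max(x))) along the configured axis. A hedged sketch of how a caller might fetch and apply such a converter through the V2 registry; the includes, the visibility of CURRENT_VERSION, and dyn_typeinfo() as the typeinfo accessor are assumptions, not part of this patch:

    #include "megbrain/graph.h"
    #include "megbrain/serialization/opr_registry.h"

    using namespace mgb;
    using namespace mgb::serialization;

    // replace an operator with its registered compatible form, if one exists
    cg::OperatorNodeBase* to_compatible(
            cg::OperatorNodeBase* opr, const VarNodeArray& inputs) {
        const OprRegistryV2* reg = OprRegistryV2::versioned_find_by_typeinfo(
                opr->dyn_typeinfo(), CURRENT_VERSION);
        if (reg && reg->converter) {
            // for opr::Softmax this returns the final TRUE_DIV node of the
            // Reduce/Elemwise decomposition built by replace_opr above
            return reg->converter(opr, inputs);
        }
        return opr;  // nothing registered for this type: keep the operator as is
    }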

src/opr/impl/io.sereg.v2.h (+197, -0)

@@ -0,0 +1,197 @@
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/comp_node_env.h"
#include "megbrain/opr/dnn/softmax.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/serialization/sereg.h"

#include "megbrain/serialization/internal/mgb_cpp_opr_generated.h"
#include "megbrain/serialization/internal/schema_v2_generated.h"

namespace mgb {
namespace serialization {

template <>
struct OprLoadDumpImplV2<opr::ImmutableTensor, 0> {
using Opr = opr::ImmutableTensor;

static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
using Meth = OprDumpContext::TensorWriteMethod;
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.dump_tensor(
{}, HostTensorND{}.copy_from(opr.value()).sync(),
Meth::VALUE_ANONYMOUS);
}

static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
mgb_assert(inputs.empty());
auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
fbs_ctx.get_current_opr_data());
if (fopr->tensors() && fopr->tensors()->size() > 0) {
auto val = fbs_ctx.load_tensor();
return Opr::make(fbs_ctx.graph(), *val, config).node()->owner_opr();
} else {
mgb_throw(SerializationError, "ImmutableTensor load with no tensor data.");
}
}
};

template <>
struct OprLoadDumpImplV2<opr::Host2DeviceCopy, 0> {
using Opr = opr::Host2DeviceCopy;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param(opr.param());

using Meth = OprDumpContext::TensorWriteMethod;
ctx.dump_tensor(
opr.name(), *opr.host_data(),
opr.param().dump_default_value ? Meth::VALUE_INPUT : Meth::META_INPUT);
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
mgb_assert(inputs.empty());
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto param = fbs_ctx.read_param<Opr::Param>(0);
auto tensor = fbs_ctx.load_tensor();
return Opr::make(fbs_ctx.graph(), tensor, param, config).node()->owner_opr();
}
};

template <>
struct OprLoadDumpImplV2<opr::SharedDeviceTensorWithFormat, 0> {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
using Meth = OprDumpContext::TensorWriteMethod;
auto&& opr = opr_.cast_final_safe<opr::SharedDeviceTensorWithFormat>();
HostTensorND val;
val.copy_from(opr.get_dev_tensor()).sync();
ctx.dump_tensor({}, val, Meth::VALUE_ANONYMOUS);
}

static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
mgb_assert(inputs.empty());
auto val = ctx.load_tensor();
auto dev_val =
std::make_shared<DeviceTensorND>(val->comp_node(), val->layout());
dev_val->copy_from_fixlayout(*val);
auto out_var =
opr::SharedDeviceTensorWithFormat::make(ctx.graph(), dev_val, config);
dev_val->sync();
return out_var.node()->owner_opr();
}
};

template <>
struct OprLoadDumpImplV2<opr::MultipleDeviceTensorHolder, 0> {
using Opr = opr::MultipleDeviceTensorHolder;

static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
using Meth = OprDumpContext::TensorWriteMethod;
auto&& opr = opr_.cast_final_safe<Opr>();
uint32_t nr_val = opr.values().size();
for (uint32_t i = 0; i < nr_val; ++i) {
HostTensorND val;
val.copy_from(*opr.values()[i]).sync();
ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED);
}
}

static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
mgb_assert(inputs.empty());
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
fbs_ctx.get_current_opr_data());
uint32_t nr = 0;
if (fopr && fopr->tensors()) {
nr = fopr->tensors()->size();
}
Opr::ValueArray values(nr);
for (auto&& i : values) {
i = ctx.load_tensor_shared();
}
return Opr::make(ctx.graph(), std::move(values), config)[0].node()->owner_opr();
}
};

template <>
struct OprLoadDumpImplV2<opr::MultipleDeviceTensorWithFormatHolder, 0> {
using Opr = opr::MultipleDeviceTensorWithFormatHolder;

static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
using Meth = OprDumpContext::TensorWriteMethod;
auto&& opr = opr_.cast_final_safe<Opr>();
uint32_t nr_val = opr.values().size();
for (uint32_t i = 0; i < nr_val; ++i) {
HostTensorND val;
auto value = *opr.values()[i];
val.copy_from(value).sync();
ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED);
}
}

static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
mgb_assert(inputs.empty());
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
fbs_ctx.get_current_opr_data());
uint32_t nr = 0;
if (fopr && fopr->tensors()) {
nr = fopr->tensors()->size();
}
Opr::ValueArray values(nr);
for (auto&& i : values) {
i = ctx.load_tensor_shared();
//! set tensor format
TensorLayout layout_with_format = i->layout();

if (i->storage().comp_node().mem_node() ==
CompNode::default_cpu().mem_node()) {
mgb_assert(
i->storage().ptr(),
"storage should not be nullptr if mem_node is "
"default_cpu");
HostTensorND src{i->storage().comp_node(), layout_with_format};
src.copy_from_fixlayout(*i).sync();
*i = DeviceTensorND::make_proxy(src);
} else {
//! actually only the layout of this tensor will be used later, see
//! src/serialization/impl/batched_device_value_loader.cpp:49. But we
//! have no way to reset only the layout, so just construct an invalid
//! storage instead
auto size = layout_with_format.span().dist_byte();
DeviceTensorStorage storage;
storage.reset(i->comp_node(), size, nullptr);
i->reset(storage, layout_with_format);
}
}
return Opr::make(ctx.graph(), std::move(values), config)[0].node()->owner_opr();
}
};

} // namespace serialization

namespace opr {
#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \
MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION);

SERGE_OPR_V2_NO_CONVERTER(ImmutableTensor, 0);
SERGE_OPR_V2_NO_CONVERTER(Host2DeviceCopy, 0);
SERGE_OPR_V2_NO_CONVERTER(SharedDeviceTensorWithFormat, 0);
SERGE_OPR_V2_NO_CONVERTER(MultipleDeviceTensorWithFormatHolder, 0);
SERGE_OPR_V2_NO_CONVERTER(MultipleDeviceTensorHolder, 0);
} // namespace opr
} // namespace mgb

#endif

// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
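
In the non-CPU branch of the MultipleDeviceTensorWithFormatHolder loader above, only the layout (including the tensor format) is consumed later, so the code installs a storage of the right byte size whose pointer is null. A hedged sketch of that trick in isolation; the helper name and the include path are assumptions:

    #include "megbrain/tensor.h"

    using namespace mgb;

    // build a DeviceTensorND that carries a valid layout/format but no real
    // allocation; downstream code that only reads the layout keeps working
    DeviceTensorND make_layout_placeholder(CompNode cn, const TensorLayout& layout) {
        DeviceTensorStorage storage;
        storage.reset(cn, layout.span().dist_byte(), nullptr);  // size only, no data
        DeviceTensorND placeholder;
        placeholder.reset(storage, layout);
        return placeholder;
    }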

src/opr/impl/loop/forward_sereg.cpp (+10, -0)

@@ -135,6 +135,16 @@ void LoopSerializer::reg_all() {
MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker);
MGB_SEREG_OPR_INTL_CALL_ADD(
CounterProvider, dump_counter_provider, load_counter_provider);

MGB_SEREG_OPR_INTL_CALL_ADD_V2(
opr::Loop, dump_loop, load_loop, nullptr, 2,
CURRENT_VERSION);
MGB_SEREG_OPR_INTL_CALL_ADD_V2(
InputMaker, dump_input_maker, load_input_maker, nullptr, 2,
CURRENT_VERSION);
MGB_SEREG_OPR_INTL_CALL_ADD_V2(
CounterProvider, dump_counter_provider, load_counter_provider, nullptr, 2,
CURRENT_VERSION);
}

void LoopSerializer::dump_loop(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {


src/serialization/impl/opr_registry.cpp (+92, -2)

@@ -20,6 +20,11 @@ struct StaticData {
ThinHashMap<Typeinfo*, OprRegistry*> type2reg;
std::unordered_map<std::string, OprRegistry*> name2reg;
ThinHashMap<size_t, OprRegistry*> unversioned_id2reg;

//! versioned OprRegistryV2 maps: version_id_reg_map is used for Operator
//! load/shallow copy and version_type_reg_map is used for Operator dump
ThinHashMap<uint8_t, ThinHashMap<size_t, OprRegistryV2>> version_id_reg_map;
ThinHashMap<uint8_t, ThinHashMap<Typeinfo*, OprRegistryV2*>> version_type_reg_map;
};

StaticData& static_data() {
@@ -47,6 +52,20 @@ const OprRegistry* dynamic_registry() {
return ret;
}

const OprRegistryV2* dynamic_registry_v2() {
static const OprRegistryV2* ret = nullptr;
if (ret)
return ret;

auto id = MGB_HASH_STR("dynamic");
OprRegistryV2::versioned_add(
{nullptr, id, {}, {}, dynamic_loader, {}}, CURRENT_VERSION,
CURRENT_VERSION);
ret = OprRegistryV2::versioned_find_by_id(id, CURRENT_VERSION);
mgb_assert(ret);
return ret;
}

class _Init {
public:
_Init() {
@@ -63,8 +82,7 @@ void OprRegistry::add(const OprRegistry& record) {
auto registry_ins = sd.id2reg.emplace(persist_id, record);
mgb_assert(
registry_ins.second || persist_id == dynamic_registry()->persist_type_id,
"duplicated operator persist_type_id: %s",
std::to_string(persist_id).c_str());
"duplicated operator name : %s", record.name.c_str());

OprRegistry* persis_record_ptr;
if (registry_ins.second) {
@@ -129,6 +147,73 @@ const OprRegistry* OprRegistry::find_by_unversioned_id(size_t unversioned_id) {
return iter == uid2reg.end() ? nullptr : iter->second;
}

//! find the registry whose version equals the given version
const OprRegistryV2* OprRegistryV2::versioned_find_by_id(
const size_t id, uint8_t version) {
auto&& id_reg_map = static_data().version_id_reg_map;
auto iter_version = id_reg_map.find(version);
if (iter_version != id_reg_map.end()) {
auto iter = iter_version->second.find(id);
return iter == iter_version->second.end() ? nullptr : &iter->second;
}
return nullptr;
}
//! find the registry with the highest version equal to or below the given version
const OprRegistryV2* OprRegistryV2::versioned_find_by_typeinfo(
Typeinfo* type, uint8_t version) {
const auto& type_reg_map = static_data().version_type_reg_map;
for (int version_id = version; version_id > 0; version_id--) {
auto iter_version = type_reg_map.find(version_id);
if (iter_version != type_reg_map.end()) {
auto iter = iter_version->second.find(type);
if (iter == iter_version->second.end()) {
continue;
} else {
return iter->second;
}
}
}
return nullptr;
}

void OprRegistryV2::versioned_add(
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version) {
mgb_assert(max_version >= min_version);

auto&& sd = static_data();
auto id = record.type_id;
uint64_t type_id = id;
//! record.type->name is nullptr when MGB_VERBOSE_TYPEINFO_NAME==0
if (record.type && record.type->name) {
type_id = MGB_HASH_RUNTIME(std::string(record.type->name));
}
for (uint8_t version = min_version; version <= max_version; version++) {
auto&& registry_map = sd.version_id_reg_map[version];
auto versioned_record = record;
versioned_record.version = version;
mgb_assert(
registry_map.find(id) == registry_map.end() ||
id == dynamic_registry_v2()->type_id,
"dduplicated OprRegistryV2 of %s\n", record.name.c_str());
auto registry_ins = registry_map.emplace(id, versioned_record);
if (!registry_ins.second) {
//! the registry is dynamic
mgb_assert(!record.converter);
registry_map[id] = versioned_record;
}
//! sometimes the registered id and the typeinfo hash are not the same, e.g. for
//! dynamic Operators
if (id != type_id) {
mgb_assert(
registry_map.find(type_id) == registry_map.end(),
"dduplicated OprRegistryV2 of %s\n", record.name.c_str());
registry_map.emplace(type_id, versioned_record);
}
auto&& registry_type_map = sd.version_type_reg_map[version];
registry_type_map.emplace(record.type, &registry_map[id]);
}
}

void OprRegistry::add_using_dynamic_loader(
Typeinfo* type, const std::string& name, const OprDumper& dumper) {
// dynamic oprs are implemented by mapping different opr types to the same
@@ -140,6 +225,11 @@ void OprRegistry::add_using_dynamic_loader(
{},
{},
dynamic_registry()->unversioned_type_id});
mgb_assert(type, "type must not be nullptr");
OprRegistryV2::versioned_add(
{type, dynamic_registry_v2()->type_id, type->name, dumper,
dynamic_registry_v2()->loader, nullptr},
CURRENT_VERSION, CURRENT_VERSION);
}

#if MGB_ENABLE_DEBUG_UTIL
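
The new registry is keyed by (version, id) for loading and by (version, typeinfo) for dumping; versioned_find_by_typeinfo walks downward from the requested version, so the dump path picks the newest record at or below the bound. A hedged sketch of both lookups; the free function is illustrative, and the visibility of mgb_assert and CURRENT_VERSION is assumed:

    #include "megbrain/serialization/opr_registry.h"

    using namespace mgb;
    using namespace mgb::serialization;

    void lookup_example(Typeinfo* type, size_t persist_id) {
        // exact-version lookup, as used on the load / shallow-copy path
        const OprRegistryV2* by_id =
                OprRegistryV2::versioned_find_by_id(persist_id, VERSION_2);
        mgb_assert(!by_id || by_id->get_version() == VERSION_2);

        // dump-path lookup: falls back to older versions, so an operator that was
        // only registered for VERSION_1 is still found with a CURRENT_VERSION bound
        const OprRegistryV2* by_type =
                OprRegistryV2::versioned_find_by_typeinfo(type, CURRENT_VERSION);
        mgb_assert(!by_type || by_type->get_version() <= CURRENT_VERSION);
    }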


src/serialization/impl/sereg_caller.cpp (+2, -0)

@@ -9,10 +9,12 @@ void call_sereg() {}
#include "../../opr/impl/blas.sereg.h"
#include "../../opr/impl/cond.sereg.h"
#include "../../opr/impl/dnn/dnn.sereg.h"
#include "../../opr/impl/dnn/dnn.sereg.v2.h"
#include "./extern_c_opr.sereg.h"
#include "../../opr/impl/imgproc.sereg.h"
#include "../../opr/impl/indexing.sereg.h"
#include "../../opr/impl/io.sereg.h"
#include "../../opr/impl/io.sereg.v2.h"
#include "../../opr/impl/loop/forward.sereg.h"
#include "../../opr/impl/loop/grad.sereg.h"
#include "../../opr/impl/misc.sereg.h"


src/serialization/include/megbrain/serialization/opr_registry.h (+33, -1)

@@ -53,7 +53,6 @@ struct OprRegistry {
uint64_t unversioned_type_id;

MGE_WIN_DECLSPEC_FUC static void add(const OprRegistry& record);

/*!
* \brief register an operator to use dynamic loader
*
@@ -89,6 +88,39 @@ struct OprRegistry {
#endif
};

//! Convert a modified Opr to a compatible Opr
using OprConvertToCompatible = thin_function<cg::OperatorNodeBase*(
cg::OperatorNodeBase*, const VarNodeArray&)>;

//! record of a single operator
struct OprRegistryV2 {
Typeinfo* type;
uint64_t type_id;
std::string name;
OprDumper dumper;
OprLoaderWrapper loader;
OprConvertToCompatible converter;
uint8_t version = 2;

MGE_WIN_DECLSPEC_FUC uint8_t get_version() const { return version; }

//! register opr load/dump into the versioned registry maps
MGE_WIN_DECLSPEC_FUC static void versioned_add(
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version);

MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_id(
const size_t id, uint8_t version);

MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_typeinfo(
Typeinfo* type, uint8_t version);

#if MGB_ENABLE_DEBUG_UTIL
//! dump registered oprs
MGE_WIN_DECLSPEC_FUC static std::vector<std::pair<size_t, std::string>>
dump_registries();
#endif
};

} // namespace serialization
} // namespace mgb



src/serialization/include/megbrain/serialization/sereg.h (+60, -11)

@@ -3,6 +3,7 @@
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/opr_registry.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/utils/hash_ct.h"

namespace mgb {
@@ -66,6 +67,9 @@ struct OprLoadDumpImpl {
}
};

template <class Opr, size_t arity>
struct OprLoadDumpImplV2 : public OprLoadDumpImpl<Opr, arity> {};

#define IMPL_OPR_MAKER(_arity, _args...) \
template <class Opr> \
struct OprMaker<Opr, _arity> { \
@@ -124,6 +128,12 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
__caller_OprReg##_cls##_ins; \
}

#define MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _impl) \
namespace { \
[[gnu::unused]] ::mgb::serialization::OprRegistryCaller<_cls, _impl> \
__caller_V2_OprReg##_cls##_ins; \
}

// Trim the terminating null character and a "V0" like suffix from the string
// then hash it.
// TODO: Get rid of this.
@@ -138,17 +148,35 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
: 0), \
20160701)>::val

//! call OprRegistry::add
#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \
do { \
::mgb::serialization::OprRegistry::add( \
{_cls::typeinfo(), \
MGB_HASH_STR(#_cls), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \
_dump, \
_load, \
{}, \
MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \
//! call OprRegistry::add for the old serialization format, and
//! OprRegistryV2::versioned_add for the new format so that it stays compatible
//! with the old one; the converter is nullptr and this registration only covers
//! VERSION_1
#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \
do { \
::mgb::serialization::OprRegistry::add( \
{_cls::typeinfo(), \
MGB_HASH_STR(#_cls), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \
_dump, \
_load, \
{}, \
MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \
::mgb::serialization::OprRegistryV2::versioned_add( \
{_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \
::mgb::VERSION_1, ::mgb::VERSION_1); \
} while (0)

//! call OprRegistryV2::versioned_add for the new serialization format; the
//! _convert function converts the Operator to a compatible one
#define MGB_SEREG_OPR_INTL_CALL_ADD_V2( \
_cls, _dump, _load, _convert, _version_min, _version_max) \
do { \
::mgb::serialization::OprRegistryV2::versioned_add( \
{_cls::typeinfo(), MGB_HASH_STR(#_cls), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, _convert}, \
_version_min, _version_max); \
} while (0)

/*!
@@ -171,6 +199,27 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls)

//! the new dump/load functions should be implemented in OprLoadDumpImplV2;
//! _converter is optional, pass nullptr if it is not needed
#define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max) \
namespace { \
namespace ser = ::mgb::serialization; \
struct _OprRegV2##_cls { \
using Impl = ser::OprLoadDumpImplV2<_cls, _arity>; \
static ser::OprWithOutputAccessor wrap_loader( \
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
const mgb::cg::OperatorNodeConfig& config) { \
return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \
} \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD_V2( \
_cls, Impl::dump, wrap_loader, _converter, _version_min, \
_version_max); \
} \
}; \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _OprRegV2##_cls)

//! use to check type is complete or not, midout need a complete type
template <class T, class = void>
struct IsComplete : std::false_type {};

