GitOrigin-RevId: ee58271276
HuaHua404-patch-4
@@ -40,16 +40,16 @@ mgb::cg::OperatorNodeBase* custom_loader(
} // namespace serialization
} // namespace mgb
#define CUSTOM_OP_SEREG_REG(cls) \
    namespace { \
    struct _OprReg##cls { \
        static void entry() { \
            MGB_SEREG_OPR_INTL_CALL_ADD( \
                    cls, ::mgb::serialization::custom_dumper, \
                    ::mgb::serialization::custom_loader); \
        } \
    }; \
    } \
#define CUSTOM_OP_SEREG_REG(cls) \
    namespace { \
    struct _OprReg##cls { \
        static void entry() { \
            MGB_SEREG_OPR_INTL_CALL_ADD( \
                    cls, ::mgb::serialization::custom_dumper, \
                    ::mgb::serialization::custom_loader, true); \
        } \
    }; \
    } \
    MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls)
#define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max) \
@@ -131,10 +131,10 @@ cg::OperatorNodeBase* serialization::opr_shallow_copy_loop(
}
void LoopSerializer::reg_all() {
    MGB_SEREG_OPR_INTL_CALL_ADD(opr::Loop, dump_loop, load_loop);
    MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker);
    MGB_SEREG_OPR_INTL_CALL_ADD(opr::Loop, dump_loop, load_loop, true);
    MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker, true);
    MGB_SEREG_OPR_INTL_CALL_ADD(
            CounterProvider, dump_counter_provider, load_counter_provider);
            CounterProvider, dump_counter_provider, load_counter_provider, true);
    MGB_SEREG_OPR_INTL_CALL_ADD_V2(
            opr::Loop, dump_loop, load_loop, nullptr, 2, CURRENT_VERSION);
@@ -1,3 +1,4 @@
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/serialization/sereg.h"
@@ -7,10 +8,74 @@ template <>
struct OprMaker<opr::ElemwiseMultiType, 0>
        : public OprMakerVariadic<opr::ElemwiseMultiType> {};
template <>
struct OprLoadDumpImplV2<opr::ElemwiseMultiType, 0> {
    using Opr = opr::ElemwiseMultiType;
    using PersisParam = opr::ElemwiseMultiType::Param;
    using PersisElemwseiParam = opr::Elemwise::Param;
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
        ctx.write_param<PersisParam>(opr.cast_final_safe<Opr>().param());
    }
    static cg::OperatorNodeBase* replace_opr(
            cg::OperatorNodeBase* opr, const VarNodeArray& inputs) {
        auto mode = opr->cast_final_safe<Opr>().param().mode;
        auto change_to_elemwise_mode = [&](PersisParam::Mode multitype_mode) {
            if (multitype_mode == PersisParam::Mode::EQ) {
                return PersisElemwseiParam::Mode::EQ;
            } else if (multitype_mode == PersisParam::Mode::LT) {
                return PersisElemwseiParam::Mode::LT;
            } else if (multitype_mode == PersisParam::Mode::LEQ) {
                return PersisElemwseiParam::Mode::LEQ;
            }
            mgb_assert(0, "no supported model.");
        };
        if (PersisParam::Mode::EQ == mode || PersisParam::Mode::LT == mode ||
            PersisParam::Mode::LEQ == mode) {
            auto elemwise_mode = change_to_elemwise_mode(mode);
            auto elemiwse_out = opr::Elemwise::make(inputs, {elemwise_mode});
            return opr::TypeCvt::make(elemiwse_out, dtype::Bool()).node()->owner_opr();
        } else if (PersisParam::Mode::NEQ == mode) {
            auto elemiwse_out =
                    opr::Elemwise::make(inputs, {PersisElemwseiParam::Mode::EQ});
            auto bool_out = opr::TypeCvt::make(elemiwse_out, dtype::Bool());
            return opr::Elemwise::make({bool_out}, {PersisElemwseiParam::Mode::NOT})
                    .node()
                    ->owner_opr();
        } else if (PersisParam::Mode::ISNAN == mode) {
            auto elemiwse_out = opr::Elemwise::make(
                    {inputs[0], inputs[0]}, {PersisElemwseiParam::Mode::EQ});
            auto bool_out = opr::TypeCvt::make(elemiwse_out, dtype::Bool());
            return opr::Elemwise::make({bool_out}, {PersisElemwseiParam::Mode::NOT})
                    .node()
                    ->owner_opr();
        } else if (PersisParam::Mode::ISINF == mode) {
            auto input_var = SymbolVar{inputs[0]};
            auto inf_var = input_var.make_scalar(INFINITY);
            auto float_out = opr::TypeCvt::make(inputs[0], dtype::Float32());
            auto elemiwse_out = opr::Elemwise::make(
                    {float_out, inf_var}, {PersisElemwseiParam::Mode::EQ});
            return opr::TypeCvt::make(elemiwse_out, dtype::Bool()).node()->owner_opr();
        }
        return opr;
    }
    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        return OprMaker<opr::ElemwiseMultiType, 0>::make(
                ctx.read_param<PersisParam>(), inputs, ctx.graph(), config);
    }
};
} // namespace serialization
namespace opr {
MGB_SEREG_OPR(ElemwiseMultiType, 0);
MGB_SEREG_OPR_CONDITION(ElemwiseMultiType, 0, false);
MGB_SEREG_OPR_V2(
        ElemwiseMultiType, 0,
        (mgb::serialization::OprLoadDumpImplV2<opr::ElemwiseMultiType, 0>::replace_opr),
        VERSION_1, VERSION_1);
MGB_SEREG_OPR(AffineInt, 3);
} // namespace opr
} // namespace mgb
@@ -125,16 +125,18 @@ ComputingGraph* serialization::OprShallowCopyContext::owner_graph(
cg::OperatorNodeBase* serialization::copy_opr_shallow(
        const cg::OperatorNodeBase& opr, const VarNodeArray& inputs,
        const OperatorNodeConfig& config, const OprShallowCopyContext& ctx) {
    auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo());
    mgb_assert(
            registry, "could not find OprReceiver to copy opr %s{%s}", opr.cname(),
            opr.dyn_typeinfo()->name);
    OprShallowCopy shallow_copy = nullptr;
    if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) {
        shallow_copy = registry->shallow_copy;
    } else {
        shallow_copy = intl::copy_opr_shallow_default_impl;
    }
    mgb_assert(inputs.size() == opr.input().size());
    auto dst_og = ctx.owner_graph(opr, inputs);
    auto do_copy = [&]() {
        auto nr_opr_before = opr.owner_graph()->nr_oprs_in_graph();
        auto ret = registry->shallow_copy(ctx, opr, inputs, config);
        auto ret = shallow_copy(ctx, opr, inputs, config);
        if (dst_og != opr.owner_graph() ||
            opr.owner_graph()->nr_oprs_in_graph() != nr_opr_before) {
@@ -188,18 +190,28 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
        const OprShallowCopyContext& ctx, const cg::OperatorNodeBase& opr,
        const VarNodeArray& inputs, const OperatorNodeConfig& config) {
    MGB_MARK_USED_VAR(ctx);
    OprDumper opr_dumper = nullptr;
    OprLoaderWrapper opr_loader = nullptr;
    auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo());
    if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) {
        opr_loader = registry->loader;
        opr_dumper = registry->dumper;
    } else {
        auto registryv2 = OprRegistryV2::versioned_find_by_typeinfo(
                opr.dyn_typeinfo(), CURRENT_VERSION);
        opr_loader = registryv2->loader;
        opr_dumper = registryv2->dumper;
    }
    mgb_assert(
            registry && registry->dumper && registry->loader,
            opr_dumper && opr_loader,
            "can not shallow_copy operator %s{%s}: "
            "no dumper/loader registered",
            opr.cname(), opr.dyn_typeinfo()->name);
    OprDumpContextMemory dumper;
    registry->dumper(dumper, opr);
    OprDumpContextMemory memory_dumper;
    opr_dumper(memory_dumper, opr);
    OprLoadContextMemory loader{opr.owner_graph(), dumper};
    return registry->loader(loader, inputs, config).opr();
    OprLoadContextMemory loader{opr.owner_graph(), memory_dumper};
    return opr_loader(loader, inputs, config).opr();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -358,6 +358,11 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
    auto new_output_vars = output_vars;
    if (!config.no_change_graph) {
        new_output_vars = converter_all_opr_to_compatiable(output_vars);
        mgb_assert(output_vars.size() == new_output_vars.size());
        for (size_t id = 0; id < output_vars.size(); id++) {
            auto& new_var = new_output_vars[id];
            new_var.rename(output_vars[id].node()->name());
        }
    }
    auto begin_pos = m_file->tell();
@@ -151,20 +151,22 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
//! call OprRegistryV2::versioned_add for new serialization which is compatiable
//! with old serialization, convert is nullptr, this registry is just only for
//! varsion 1
#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \
    do { \
        ::mgb::serialization::OprRegistry::add( \
                {_cls::typeinfo(), \
                 MGB_HASH_STR(#_cls), \
                 _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \
                 _dump, \
                 _load, \
                 {}, \
                 MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \
        ::mgb::serialization::OprRegistryV2::versioned_add( \
                {_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \
                 _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \
                ::mgb::VERSION_1, ::mgb::VERSION_1); \
#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load, _registerv2) \
    do { \
        ::mgb::serialization::OprRegistry::add( \
                {_cls::typeinfo(), \
                 MGB_HASH_STR(#_cls), \
                 _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \
                 _dump, \
                 _load, \
                 {}, \
                 MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \
        if (_registerv2) { \
            ::mgb::serialization::OprRegistryV2::versioned_add( \
                    {_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \
                     _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \
                    ::mgb::VERSION_1, ::mgb::VERSION_1); \
        } \
    } while (0)
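
A minimal usage sketch of the extended macro (the operator class MyOpr and its dump_my_opr/load_my_opr callbacks are hypothetical names for illustration, not part of this patch): the new fourth argument decides whether, besides the V1 OprRegistry entry, a plain converter-less entry is also added to OprRegistryV2.

// Sketch only: MyOpr, dump_my_opr and load_my_opr are illustrative names.
static void reg_my_opr_example() {
    // Registers MyOpr in OprRegistry (V1) and, because the last argument is
    // true, also in OprRegistryV2 without a converter.
    MGB_SEREG_OPR_INTL_CALL_ADD(MyOpr, dump_my_opr, load_my_opr, true);
    // Passing false would skip the automatic V2 entry so a dedicated V2
    // registration with a converter can be supplied instead, as the
    // ElemwiseMultiType hunk in this patch does via MGB_SEREG_OPR_CONDITION.
}
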
//! call OprRegistryV2::versioned_add for new serialization, in which convert the
@@ -181,23 +183,25 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
/*!
 * \brief register opr serialization methods
 */
#define MGB_SEREG_OPR(_cls, _arity) \
    namespace { \
    namespace ser = ::mgb::serialization; \
    struct _OprReg##_cls { \
        using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \
        static ser::OprWithOutputAccessor wrap_loader( \
                ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
                const mgb::cg::OperatorNodeConfig& config) { \
            return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \
        } \
        static void entry() { \
            MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader); \
        } \
    }; \
    } \
#define MGB_SEREG_OPR_CONDITION(_cls, _arity, _registerv2) \
    namespace { \
    namespace ser = ::mgb::serialization; \
    struct _OprReg##_cls { \
        using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \
        static ser::OprWithOutputAccessor wrap_loader( \
                ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
                const mgb::cg::OperatorNodeConfig& config) { \
            return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \
        } \
        static void entry() { \
            MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader, _registerv2); \
        } \
    }; \
    } \
    MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls)
#define MGB_SEREG_OPR(_cls, _arity) MGB_SEREG_OPR_CONDITION(_cls, _arity, true)
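
For reference, this is the split that the ElemwiseMultiType hunk above relies on; a condensed sketch (names taken from that hunk, the enclosing mgb/opr namespaces assumed): MGB_SEREG_OPR_CONDITION keeps the V1 registration but suppresses the automatic V2 entry, and MGB_SEREG_OPR_V2 then supplies the V2 entry together with the replace_opr converter.

namespace mgb {
namespace opr {
// V1 registration only; false suppresses the automatic OprRegistryV2 entry.
MGB_SEREG_OPR_CONDITION(ElemwiseMultiType, 0, false);
// Explicit V2 registration with a converter that rewrites the opr on load.
MGB_SEREG_OPR_V2(
        ElemwiseMultiType, 0,
        (mgb::serialization::OprLoadDumpImplV2<opr::ElemwiseMultiType, 0>::replace_opr),
        VERSION_1, VERSION_1);
} // namespace opr
} // namespace mgb
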
//! new dump/load function should implement in OprLoadDumpImplV2, _converter is
//! optional , if not implement pass nullptr
#define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max) \
@@ -1,3 +1,4 @@
#include "megbrain/opr/nn_int.h"
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/opr/basic_arith_wrapper.h"
@@ -1016,4 +1017,107 @@ TEST(TestSerializer2, TestSoftMaxLoadDump) {
    load();
}
TEST(TestSerializer2, TestElemwiseMultiTypeLoadDump) {
    auto fname = GET_OUTPUT_FILE(GraphDumpFormat::FLATBUFFERS_V2);
    TensorShape shape{3};
    auto cn = CompNode::load("xpu0");
    std::shared_ptr<HostTensorND> host0 =
            std::make_shared<HostTensorND>(cn, shape, dtype::Float32{});
    std::shared_ptr<HostTensorND> host1 =
            std::make_shared<HostTensorND>(cn, shape, dtype::Float32{});
    HostTensorND dst_truth;
    host0->ptr<float>()[0] = 2;
    host0->ptr<float>()[1] = 2;
    host0->ptr<float>()[2] = -1;
    host1->ptr<float>()[0] = 1;
    host1->ptr<float>()[1] = 2;
    host1->ptr<float>()[2] = 3;
    auto dump = [&](opr::ElemwiseMultiType::Param::Mode mode, size_t nr_opr) {
        auto graph = ComputingGraph::make();
        OperatorNodeConfig config;
        config.name("input0");
        auto h2d0 = opr::Host2DeviceCopy::make(*graph, host0, config);
        config.name("input1");
        auto h2d1 = opr::Host2DeviceCopy::make(*graph, host1, config);
        auto x = opr::ElemwiseMultiType::make(
                {h2d0, h2d1}, {mode}, OperatorNodeConfig{dtype::Bool()});
        x.rename("out");
        auto func = graph->compile({make_callback_copy(x, dst_truth)});
        auto dumper = GraphDumper::make(
                OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2);
        auto rst = dumper->dump({x});
        func->execute().wait();
        ASSERT_EQ(rst.nr_opr, nr_opr);
    };
    auto load = [&]() {
        auto loader = GraphLoader::make(
                InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2);
        auto rst = loader->load();
        ASSERT_EQ(rst.tensor_map.size(), 2);
        ASSERT_EQ(rst.output_var_map.count("out"), 1);
        HostTensorND host_x;
        auto func =
                rst.graph_compile({make_callback_copy(rst.output_var_list[0], host_x)});
        for (auto& input : rst.tensor_map) {
            if (input.first == "input0") {
                input.second->copy_from(*host0).sync();
            } else if (input.first == "input1") {
                input.second->copy_from(*host1).sync();
            }
        }
        func->execute().wait();
        for (int i = 0; i < 3; i++) {
            EXPECT_EQ(host_x.ptr<bool>()[i], dst_truth.ptr<bool>()[i]);
        }
    };
    dump(opr::ElemwiseMultiType::Param::Mode::EQ, 4);
    load();
    dump(opr::ElemwiseMultiType::Param::Mode::LT, 4);
    load();
    dump(opr::ElemwiseMultiType::Param::Mode::LEQ, 4);
    load();
    dump(opr::ElemwiseMultiType::Param::Mode::NEQ, 5);
    load();
    auto dump_single_input = [&](opr::ElemwiseMultiType::Param::Mode mode,
                                 size_t nr_opr) {
        auto graph = ComputingGraph::make();
        auto h2d0 = opr::Host2DeviceCopy::make(*graph, host0);
        auto x = opr::ElemwiseMultiType::make(
                {h2d0}, {mode}, OperatorNodeConfig{dtype::Bool()});
        x.rename("out");
        auto func = graph->compile({make_callback_copy(x, dst_truth)});
        auto dumper = GraphDumper::make(
                OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2);
        auto rst = dumper->dump({x});
        func->execute().wait();
        ASSERT_EQ(rst.nr_opr, nr_opr);
    };
    auto load_single_input = [&]() {
        auto loader = GraphLoader::make(
                InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2);
        auto rst = loader->load();
        ASSERT_EQ(rst.tensor_map.size(), 1);
        ASSERT_EQ(rst.output_var_map.count("out"), 1);
        HostTensorND host_x;
        auto func =
                rst.graph_compile({make_callback_copy(rst.output_var_list[0], host_x)});
        rst.tensor_map.begin()->second->copy_from(*host0).sync();
        func->execute().wait();
        for (int i = 0; i < 3; i++) {
            EXPECT_EQ(host_x.ptr<bool>()[i], dst_truth.ptr<bool>()[i]);
        }
    };
    host0->ptr<float>()[2] = INFINITY;
    dump_single_input(opr::ElemwiseMultiType::Param::Mode::ISINF, 4);
    load_single_input();
    host0->ptr<float>()[2] = NAN;
    dump_single_input(opr::ElemwiseMultiType::Param::Mode::ISNAN, 4);
    load_single_input();
}
#endif