GitOrigin-RevId: e4771d6bc4
tags/v1.7.0.m1
@@ -33,9 +33,6 @@ class ConverterWriter(IndentWriterBase): | |||||
self._last_param = p | self._last_param = p | ||||
self._param_fields = [] | self._param_fields = [] | ||||
self._fb_fields = ["builder"] | self._fb_fields = ["builder"] | ||||
if p.is_legacy: | |||||
self._skip_current_param = True | |||||
return | |||||
self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {", | self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {", | ||||
p.name, indent=1) | p.name, indent=1) | ||||
self._write("using MegDNNType = megdnn::param::%s;", p.name) | self._write("using MegDNNType = megdnn::param::%s;", p.name) | ||||
@@ -80,9 +80,6 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
def _on_param_begin(self, p): | def _on_param_begin(self, p): | ||||
self._last_param = p | self._last_param = p | ||||
self._cur_const_val = {} | self._cur_const_val = {} | ||||
if p.is_legacy: | |||||
self._skip_current_param = True | |||||
return | |||||
self._write_doc(p.name) | self._write_doc(p.name) | ||||
self._write("table %s {", p.name, indent=1) | self._write("table %s {", p.name, indent=1) | ||||
@@ -52,9 +52,6 @@ class ConverterWriter(IndentWriterBase): | |||||
def _on_param_begin(self, p): | def _on_param_begin(self, p): | ||||
self._last_param = p | self._last_param = p | ||||
if p.is_legacy: | |||||
self._skip_current_param = True | |||||
return | |||||
self._packed = True | self._packed = True | ||||
self._current_tparams = [] | self._current_tparams = [] | ||||
self._const = set() | self._const = set() | ||||
@@ -62,6 +62,37 @@ struct PersistentAddUpdateParam { | |||||
} // namespace opr_add_update | } // namespace opr_add_update | ||||
// Old SerializedDType used in MegBrain 7.22.0 - 7.23.1 | |||||
// Should be kept as-is even if there are new dtypes. | |||||
struct SerializedDTypeV1 { | |||||
static constexpr uint32_t TAG = megdnn::param::FakeSerializedDType::TAG; | |||||
DTypeEnum enumv; | |||||
union { | |||||
megdnn::DTypeParam<dtype::Quantized8Asymm> Quantized8Asymm; | |||||
megdnn::DTypeParam<dtype::QuantizedS8> QuantizedS8; | |||||
megdnn::DTypeParam<dtype::QuantizedS32> QuantizedS32; | |||||
} param; | |||||
operator DType() const { | |||||
switch (enumv) { | |||||
#define cb(_dt) \ | |||||
case DTypeEnum::_dt: \ | |||||
return DType::from_enum(enumv); | |||||
MEGDNN_FOREACH_DTYPE_NAME(cb) | |||||
#undef cb | |||||
case DTypeEnum::Quantized8Asymm: | |||||
return dtype::Quantized8Asymm{param.Quantized8Asymm}; | |||||
case DTypeEnum::QuantizedS8: | |||||
return dtype::QuantizedS8{param.QuantizedS8}; | |||||
case DTypeEnum::QuantizedS32: | |||||
return dtype::QuantizedS32{param.QuantizedS32}; | |||||
default: | |||||
mgb_assert( | |||||
false, "unexpected old serialized dtype: invalid enumv %d", | |||||
static_cast<uint32_t>(enumv)); | |||||
} | |||||
} | |||||
}; | |||||
template <> | template <> | ||||
struct OprPersistentParam<opr::AddUpdate> { | struct OprPersistentParam<opr::AddUpdate> { | ||||
using Param = opr_add_update::PersistentAddUpdateParam; | using Param = opr_add_update::PersistentAddUpdateParam; | ||||
@@ -104,7 +135,18 @@ struct ParamConverter<megdnn::DType> { | |||||
return fbs::intl::build_dtype(builder, dtype); | return fbs::intl::build_dtype(builder, dtype); | ||||
} | } | ||||
}; | }; | ||||
} // namespace fbs | |||||
template <> | |||||
struct ParamConverter<SerializedDTypeV1> { | |||||
using FlatBufferType = SerializedDTypeV1; | |||||
static SerializedDTypeV1 to_param(const FlatBufferType* fb) { | |||||
mgb_assert( | |||||
false, | |||||
"You are calling SerializedDTypeV1 in flatbuffer, you should not call " | |||||
"here, this code is just to avoid compiling errors, but not be used in " | |||||
"flatbuffer."); | |||||
} | |||||
}; | |||||
}; // namespace fbs | |||||
#endif | #endif | ||||
template <> | template <> | ||||
@@ -16,6 +16,7 @@ | |||||
#include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
#include "megbrain/opr/tensor_gen.h" | #include "megbrain/opr/tensor_gen.h" | ||||
#include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
#include "megbrain/serialization/serializer.h" | |||||
#include "megbrain/test/autocheck.h" | #include "megbrain/test/autocheck.h" | ||||
#include "megbrain/test/helper.h" | #include "megbrain/test/helper.h" | ||||
#include "megbrain/test/megdnn_helper.h" | #include "megbrain/test/megdnn_helper.h" | ||||
@@ -907,5 +908,39 @@ TEST(TestOprBlas, MatrixMulExePolicy) { | |||||
} | } | ||||
#endif | #endif | ||||
#if MGB_ENABLE_FBS_SERIALIZATION | |||||
TEST(TestOprDNN, MatrixMulSerialization) { | |||||
using namespace serialization; | |||||
auto fname = output_file("MatrixMulSerializationTest"); | |||||
auto dump = [&]() { | |||||
opr::MatrixMul::Param param; | |||||
auto cn = CompNode::load("cpu0"); | |||||
auto graph = ComputingGraph::make(); | |||||
HostTensorND a_host{cn, {24, 24}, dtype::Float32()}; | |||||
HostTensorND b_host{cn, {24, 24}, dtype::Float32()}; | |||||
auto a = opr::ImmutableTensor::make(*graph, a_host); | |||||
auto b = opr::ImmutableTensor::make(*graph, b_host); | |||||
auto opr = opr::MatrixMul::make(a, b, param, {}); | |||||
auto dumper = GraphDumper::make( | |||||
OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS); | |||||
auto rst = dumper->dump({opr}); | |||||
ASSERT_EQ(rst.outputs.size(), 1u); | |||||
}; | |||||
auto load = [&]() { | |||||
auto loader = GraphLoader::make( | |||||
InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS); | |||||
auto rst = loader->load(); | |||||
ASSERT_EQ(rst.output_var_list.size(), 1u); | |||||
auto opr = rst.output_var_list[0].node()->owner_opr(); | |||||
ASSERT_TRUE(opr->same_type<opr::MatrixMul>()); | |||||
}; | |||||
dump(); | |||||
load(); | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
// | |||||
// |
@@ -47,7 +47,13 @@ namespace { | |||||
constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH; | constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH; | ||||
constexpr uint32_t MGB_MAGIC = 0x5342474D; | |||||
constexpr uint32_t MGB_MAGIC = 0x4342474D; | |||||
// In order to maintain compatibility and to allow old models to be loaded, we keep | |||||
// the old magic(MAGIC_V0) value and creat a new magic(MGB_MAGIC) | |||||
constexpr uint32_t MAGIC_V0 = 0x5342474D; | |||||
// Used to judge whether Magic is old or new, the new magic(MGB_MAGIC) is true and the | |||||
// old magic(MAGIC_V0) is false. | |||||
bool magic_compare = true; | |||||
template <typename T> | template <typename T> | ||||
bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) { | bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) { | ||||
@@ -79,6 +85,18 @@ void check_tensor_value_valid(const std::string& name, const HostTensorND& tenso | |||||
} | } | ||||
} | } | ||||
//! feature bits for backward compatibility; default value should be 0 | |||||
struct FeatureBits64 { | |||||
//! reserved for new fields | |||||
uint64_t : 64; | |||||
static void write(OutputFile& fout) { | |||||
static_assert(sizeof(FeatureBits64) == 8, "bad feature bits"); | |||||
FeatureBits64 fb64; | |||||
memset(&fb64, 0, sizeof(fb64)); | |||||
fout.write(&fb64, 8); | |||||
} | |||||
}; | |||||
} // namespace | } // namespace | ||||
namespace mgb { | namespace mgb { | ||||
@@ -266,7 +284,7 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr( | |||||
} | } | ||||
fbs::OperatorBuilder builder(m_builder); | fbs::OperatorBuilder builder(m_builder); | ||||
builder.add_type_id(registry->unversioned_type_id); | |||||
builder.add_type_id(registry->persist_type_id); | |||||
builder.add_inputs(inputs); | builder.add_inputs(inputs); | ||||
if (m_config.keep_opr_priority) { | if (m_config.keep_opr_priority) { | ||||
builder.add_priority(opr->node_prop().attribute().priority); | builder.add_priority(opr->node_prop().attribute().priority); | ||||
@@ -322,6 +340,8 @@ GraphDumper::DumpResult GraphDumperOSS::dump( | |||||
uint32_t magic = MGB_MAGIC; | uint32_t magic = MGB_MAGIC; | ||||
m_file->write(&magic, sizeof(magic)); | m_file->write(&magic, sizeof(magic)); | ||||
// write FeatureBits | |||||
FeatureBits64::write(*m_file); | |||||
// Padding | // Padding | ||||
uint32_t reserved = 0; | uint32_t reserved = 0; | ||||
m_file->write(&reserved, sizeof(reserved)); | m_file->write(&reserved, sizeof(reserved)); | ||||
@@ -459,6 +479,7 @@ void GraphDumperOSS::dump_buf_with_len(const void* data, uint32_t size) { | |||||
class GraphLoaderOSS final : public GraphLoader { | class GraphLoaderOSS final : public GraphLoader { | ||||
const LoadConfig* m_cur_load_config = nullptr; | const LoadConfig* m_cur_load_config = nullptr; | ||||
std::unique_ptr<InputFile> m_file; | std::unique_ptr<InputFile> m_file; | ||||
FeatureBits64 m_feature_bits; | |||||
SharedBuffer m_graph_buf{{}, 0}; | SharedBuffer m_graph_buf{{}, 0}; | ||||
const fbs::Graph* m_graph; | const fbs::Graph* m_graph; | ||||
SharedTensorIDMap m_shared_tensor_map; | SharedTensorIDMap m_shared_tensor_map; | ||||
@@ -754,8 +775,12 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(const fbs::Operator* fb | |||||
} | } | ||||
config.comp_node_arr(comp_node_arr); | config.comp_node_arr(comp_node_arr); | ||||
} | } | ||||
auto registry = OprRegistry::find_by_unversioned_id(fbopr->type_id()); | |||||
const OprRegistry* registry; | |||||
if (magic_compare) { | |||||
registry = OprRegistry::find_by_id(fbopr->type_id()); | |||||
} else { | |||||
registry = OprRegistry::find_by_unversioned_id(fbopr->type_id()); | |||||
} | |||||
mgb_throw_if( | mgb_throw_if( | ||||
!registry, SerializationError, | !registry, SerializationError, | ||||
"failed to find opr with type %s, use python env " | "failed to find opr with type %s, use python env " | ||||
@@ -841,10 +866,17 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi | |||||
uint32_t magic; | uint32_t magic; | ||||
m_file->read(&magic, sizeof(magic)); | m_file->read(&magic, sizeof(magic)); | ||||
mgb_throw_if( | mgb_throw_if( | ||||
magic != MGB_MAGIC, SerializationError, | |||||
"wrong magic: wanted %#08x, actual %#08x (not a invalid fbs " | |||||
(magic != MGB_MAGIC) && (magic != MAGIC_V0), SerializationError, | |||||
"wrong magic: wanted %#08x or %#08x, actual %#08x (not a invalid fbs " | |||||
"model?)", | "model?)", | ||||
MGB_MAGIC, magic); | |||||
MGB_MAGIC, MAGIC_V0, magic); | |||||
if (magic == MGB_MAGIC) { | |||||
// read FeatureBits | |||||
magic_compare = true; | |||||
m_file->read(&m_feature_bits, 8); | |||||
} else { | |||||
magic_compare = false; | |||||
} | |||||
m_file->skip(4); | m_file->skip(4); | ||||
uint64_t offset_to_fbs; | uint64_t offset_to_fbs; | ||||
@@ -929,7 +961,7 @@ bool is_fbs_file(InputFile& file) { | |||||
uint64_t magic_with_reserved = 0; | uint64_t magic_with_reserved = 0; | ||||
file.read(&magic_with_reserved, sizeof(magic_with_reserved)); | file.read(&magic_with_reserved, sizeof(magic_with_reserved)); | ||||
file.skip(-sizeof(magic_with_reserved)); | file.skip(-sizeof(magic_with_reserved)); | ||||
return magic_with_reserved == MGB_MAGIC; | |||||
return (magic_with_reserved == MGB_MAGIC) || (magic_with_reserved == MAGIC_V0); | |||||
} | } | ||||
} // namespace serialization | } // namespace serialization | ||||
@@ -199,7 +199,7 @@ struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {}; | |||||
static ser::OprWithOutputAccessor compat_loader( \ | static ser::OprWithOutputAccessor compat_loader( \ | ||||
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | ||||
const mgb::cg::OperatorNodeConfig& config) { \ | const mgb::cg::OperatorNodeConfig& config) { \ | ||||
auto&& ctx_ = static_cast<ser::OprLoadContextRawPOD&>(ctx); \ | |||||
auto&& ctx_ = static_cast<ser::OprLoadContext&>(ctx); \ | |||||
return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), _accessor); \ | return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), _accessor); \ | ||||
} \ | } \ | ||||
static void entry() { \ | static void entry() { \ | ||||