GitOrigin-RevId: e4771d6bc4
tags/v1.7.0.m1
@@ -33,9 +33,6 @@ class ConverterWriter(IndentWriterBase): | |||
self._last_param = p | |||
self._param_fields = [] | |||
self._fb_fields = ["builder"] | |||
if p.is_legacy: | |||
self._skip_current_param = True | |||
return | |||
self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {", | |||
p.name, indent=1) | |||
self._write("using MegDNNType = megdnn::param::%s;", p.name) | |||
@@ -80,9 +80,6 @@ class FlatBuffersWriter(IndentWriterBase): | |||
def _on_param_begin(self, p): | |||
self._last_param = p | |||
self._cur_const_val = {} | |||
if p.is_legacy: | |||
self._skip_current_param = True | |||
return | |||
self._write_doc(p.name) | |||
self._write("table %s {", p.name, indent=1) | |||
@@ -52,9 +52,6 @@ class ConverterWriter(IndentWriterBase): | |||
def _on_param_begin(self, p): | |||
self._last_param = p | |||
if p.is_legacy: | |||
self._skip_current_param = True | |||
return | |||
self._packed = True | |||
self._current_tparams = [] | |||
self._const = set() | |||
@@ -62,6 +62,37 @@ struct PersistentAddUpdateParam { | |||
} // namespace opr_add_update | |||
// Old SerializedDType used in MegBrain 7.22.0 - 7.23.1 | |||
// Should be kept as-is even if there are new dtypes. | |||
struct SerializedDTypeV1 { | |||
static constexpr uint32_t TAG = megdnn::param::FakeSerializedDType::TAG; | |||
DTypeEnum enumv; | |||
union { | |||
megdnn::DTypeParam<dtype::Quantized8Asymm> Quantized8Asymm; | |||
megdnn::DTypeParam<dtype::QuantizedS8> QuantizedS8; | |||
megdnn::DTypeParam<dtype::QuantizedS32> QuantizedS32; | |||
} param; | |||
operator DType() const { | |||
switch (enumv) { | |||
#define cb(_dt) \ | |||
case DTypeEnum::_dt: \ | |||
return DType::from_enum(enumv); | |||
MEGDNN_FOREACH_DTYPE_NAME(cb) | |||
#undef cb | |||
case DTypeEnum::Quantized8Asymm: | |||
return dtype::Quantized8Asymm{param.Quantized8Asymm}; | |||
case DTypeEnum::QuantizedS8: | |||
return dtype::QuantizedS8{param.QuantizedS8}; | |||
case DTypeEnum::QuantizedS32: | |||
return dtype::QuantizedS32{param.QuantizedS32}; | |||
default: | |||
mgb_assert( | |||
false, "unexpected old serialized dtype: invalid enumv %d", | |||
static_cast<uint32_t>(enumv)); | |||
} | |||
} | |||
}; | |||
template <> | |||
struct OprPersistentParam<opr::AddUpdate> { | |||
using Param = opr_add_update::PersistentAddUpdateParam; | |||
@@ -104,7 +135,18 @@ struct ParamConverter<megdnn::DType> { | |||
return fbs::intl::build_dtype(builder, dtype); | |||
} | |||
}; | |||
} // namespace fbs | |||
template <> | |||
struct ParamConverter<SerializedDTypeV1> { | |||
using FlatBufferType = SerializedDTypeV1; | |||
static SerializedDTypeV1 to_param(const FlatBufferType* fb) { | |||
mgb_assert( | |||
false, | |||
"You are calling SerializedDTypeV1 in flatbuffer, you should not call " | |||
"here, this code is just to avoid compiling errors, but not be used in " | |||
"flatbuffer."); | |||
} | |||
}; | |||
}; // namespace fbs | |||
#endif | |||
template <> | |||
@@ -16,6 +16,7 @@ | |||
#include "megbrain/opr/io.h" | |||
#include "megbrain/opr/tensor_gen.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/serialization/serializer.h" | |||
#include "megbrain/test/autocheck.h" | |||
#include "megbrain/test/helper.h" | |||
#include "megbrain/test/megdnn_helper.h" | |||
@@ -907,5 +908,39 @@ TEST(TestOprBlas, MatrixMulExePolicy) { | |||
} | |||
#endif | |||
#if MGB_ENABLE_FBS_SERIALIZATION | |||
TEST(TestOprDNN, MatrixMulSerialization) { | |||
using namespace serialization; | |||
auto fname = output_file("MatrixMulSerializationTest"); | |||
auto dump = [&]() { | |||
opr::MatrixMul::Param param; | |||
auto cn = CompNode::load("cpu0"); | |||
auto graph = ComputingGraph::make(); | |||
HostTensorND a_host{cn, {24, 24}, dtype::Float32()}; | |||
HostTensorND b_host{cn, {24, 24}, dtype::Float32()}; | |||
auto a = opr::ImmutableTensor::make(*graph, a_host); | |||
auto b = opr::ImmutableTensor::make(*graph, b_host); | |||
auto opr = opr::MatrixMul::make(a, b, param, {}); | |||
auto dumper = GraphDumper::make( | |||
OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS); | |||
auto rst = dumper->dump({opr}); | |||
ASSERT_EQ(rst.outputs.size(), 1u); | |||
}; | |||
auto load = [&]() { | |||
auto loader = GraphLoader::make( | |||
InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS); | |||
auto rst = loader->load(); | |||
ASSERT_EQ(rst.output_var_list.size(), 1u); | |||
auto opr = rst.output_var_list[0].node()->owner_opr(); | |||
ASSERT_TRUE(opr->same_type<opr::MatrixMul>()); | |||
}; | |||
dump(); | |||
load(); | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
// | |||
// |
@@ -47,7 +47,13 @@ namespace { | |||
constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH; | |||
constexpr uint32_t MGB_MAGIC = 0x5342474D; | |||
constexpr uint32_t MGB_MAGIC = 0x4342474D; | |||
// In order to maintain compatibility and to allow old models to be loaded, we keep | |||
// the old magic(MAGIC_V0) value and creat a new magic(MGB_MAGIC) | |||
constexpr uint32_t MAGIC_V0 = 0x5342474D; | |||
// Used to judge whether Magic is old or new, the new magic(MGB_MAGIC) is true and the | |||
// old magic(MAGIC_V0) is false. | |||
bool magic_compare = true; | |||
template <typename T> | |||
bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) { | |||
@@ -79,6 +85,18 @@ void check_tensor_value_valid(const std::string& name, const HostTensorND& tenso | |||
} | |||
} | |||
//! feature bits for backward compatibility; default value should be 0 | |||
struct FeatureBits64 { | |||
//! reserved for new fields | |||
uint64_t : 64; | |||
static void write(OutputFile& fout) { | |||
static_assert(sizeof(FeatureBits64) == 8, "bad feature bits"); | |||
FeatureBits64 fb64; | |||
memset(&fb64, 0, sizeof(fb64)); | |||
fout.write(&fb64, 8); | |||
} | |||
}; | |||
} // namespace | |||
namespace mgb { | |||
@@ -266,7 +284,7 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr( | |||
} | |||
fbs::OperatorBuilder builder(m_builder); | |||
builder.add_type_id(registry->unversioned_type_id); | |||
builder.add_type_id(registry->persist_type_id); | |||
builder.add_inputs(inputs); | |||
if (m_config.keep_opr_priority) { | |||
builder.add_priority(opr->node_prop().attribute().priority); | |||
@@ -322,6 +340,8 @@ GraphDumper::DumpResult GraphDumperOSS::dump( | |||
uint32_t magic = MGB_MAGIC; | |||
m_file->write(&magic, sizeof(magic)); | |||
// write FeatureBits | |||
FeatureBits64::write(*m_file); | |||
// Padding | |||
uint32_t reserved = 0; | |||
m_file->write(&reserved, sizeof(reserved)); | |||
@@ -459,6 +479,7 @@ void GraphDumperOSS::dump_buf_with_len(const void* data, uint32_t size) { | |||
class GraphLoaderOSS final : public GraphLoader { | |||
const LoadConfig* m_cur_load_config = nullptr; | |||
std::unique_ptr<InputFile> m_file; | |||
FeatureBits64 m_feature_bits; | |||
SharedBuffer m_graph_buf{{}, 0}; | |||
const fbs::Graph* m_graph; | |||
SharedTensorIDMap m_shared_tensor_map; | |||
@@ -754,8 +775,12 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(const fbs::Operator* fb | |||
} | |||
config.comp_node_arr(comp_node_arr); | |||
} | |||
auto registry = OprRegistry::find_by_unversioned_id(fbopr->type_id()); | |||
const OprRegistry* registry; | |||
if (magic_compare) { | |||
registry = OprRegistry::find_by_id(fbopr->type_id()); | |||
} else { | |||
registry = OprRegistry::find_by_unversioned_id(fbopr->type_id()); | |||
} | |||
mgb_throw_if( | |||
!registry, SerializationError, | |||
"failed to find opr with type %s, use python env " | |||
@@ -841,10 +866,17 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi | |||
uint32_t magic; | |||
m_file->read(&magic, sizeof(magic)); | |||
mgb_throw_if( | |||
magic != MGB_MAGIC, SerializationError, | |||
"wrong magic: wanted %#08x, actual %#08x (not a invalid fbs " | |||
(magic != MGB_MAGIC) && (magic != MAGIC_V0), SerializationError, | |||
"wrong magic: wanted %#08x or %#08x, actual %#08x (not a invalid fbs " | |||
"model?)", | |||
MGB_MAGIC, magic); | |||
MGB_MAGIC, MAGIC_V0, magic); | |||
if (magic == MGB_MAGIC) { | |||
// read FeatureBits | |||
magic_compare = true; | |||
m_file->read(&m_feature_bits, 8); | |||
} else { | |||
magic_compare = false; | |||
} | |||
m_file->skip(4); | |||
uint64_t offset_to_fbs; | |||
@@ -929,7 +961,7 @@ bool is_fbs_file(InputFile& file) { | |||
uint64_t magic_with_reserved = 0; | |||
file.read(&magic_with_reserved, sizeof(magic_with_reserved)); | |||
file.skip(-sizeof(magic_with_reserved)); | |||
return magic_with_reserved == MGB_MAGIC; | |||
return (magic_with_reserved == MGB_MAGIC) || (magic_with_reserved == MAGIC_V0); | |||
} | |||
} // namespace serialization | |||
@@ -199,7 +199,7 @@ struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {}; | |||
static ser::OprWithOutputAccessor compat_loader( \ | |||
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | |||
const mgb::cg::OperatorNodeConfig& config) { \ | |||
auto&& ctx_ = static_cast<ser::OprLoadContextRawPOD&>(ctx); \ | |||
auto&& ctx_ = static_cast<ser::OprLoadContext&>(ctx); \ | |||
return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), _accessor); \ | |||
} \ | |||
static void entry() { \ | |||