GitOrigin-RevId: 6066516c31
release-1.10
@@ -430,6 +430,7 @@ def dump_graph( | |||||
dump_format_map = { | dump_format_map = { | ||||
None: None, | None: None, | ||||
"FBS_V2": SerializationFormat.FBS_V2, | |||||
"FBS": SerializationFormat.FBS, | "FBS": SerializationFormat.FBS, | ||||
} | } | ||||
dump_format = dump_format_map[dump_format] | dump_format = dump_format_map[dump_format] | ||||
@@ -613,8 +613,9 @@ class trace: | |||||
resize_input: whether resize input image to fit input var shape. | resize_input: whether resize input image to fit input var shape. | ||||
input_transform: a python expression to transform the input data. | input_transform: a python expression to transform the input data. | ||||
Example: data / np.std(data) | Example: data / np.std(data) | ||||
dump_format: using different dump formats. the open source MegEngine defaults to the FBS | |||||
format. internal MegEngine have a choice of FBS and internal proprietary formats | |||||
dump_format: using different dump formats. the open source MegEngine | |||||
defaults to the FBS_V2 format, there are two format FBS_V2 and FBS to choose, | |||||
internal MegEngine have an other choice of internal proprietary formats | |||||
Keyword Arguments: | Keyword Arguments: | ||||
@@ -308,6 +308,7 @@ void init_graph_rt(py::module m) { | |||||
py::enum_<_SerializationFormat>(m, "SerializationFormat") | py::enum_<_SerializationFormat>(m, "SerializationFormat") | ||||
.value("FBS", _SerializationFormat::FLATBUFFERS) | .value("FBS", _SerializationFormat::FLATBUFFERS) | ||||
.value("FBS_V2", _SerializationFormat::FLATBUFFERS_V2) | |||||
.export_values(); | .export_values(); | ||||
m.def("optimize_for_inference", | m.def("optimize_for_inference", | ||||
@@ -384,11 +385,9 @@ void init_graph_rt(py::module m) { | |||||
std::optional<_SerializationFormat> dump_format, py::list& stat, | std::optional<_SerializationFormat> dump_format, py::list& stat, | ||||
py::list& inputs, py::list& outputs, py::list& params) { | py::list& inputs, py::list& outputs, py::list& params) { | ||||
std::vector<uint8_t> buf; | std::vector<uint8_t> buf; | ||||
ser::GraphDumpFormat format; | |||||
ser::GraphDumpFormat format = ser::GraphDumpFormat::FLATBUFFERS_V2; | |||||
if (dump_format.has_value()) { | if (dump_format.has_value()) { | ||||
format = dump_format.value(); | format = dump_format.value(); | ||||
} else { | |||||
format = {}; | |||||
} | } | ||||
auto dumper = ser::GraphDumper::make( | auto dumper = ser::GraphDumper::make( | ||||
ser::OutputFile::make_vector_proxy(&buf), format); | ser::OutputFile::make_vector_proxy(&buf), format); | ||||
@@ -3,7 +3,6 @@ | |||||
#include "megbrain/serialization/opr_load_dump.h" | #include "megbrain/serialization/opr_load_dump.h" | ||||
#include "megbrain/serialization/opr_registry.h" | #include "megbrain/serialization/opr_registry.h" | ||||
#include "megbrain/serialization/opr_shallow_copy.h" | #include "megbrain/serialization/opr_shallow_copy.h" | ||||
#include "megbrain/serialization/oss_opr_load_dump.h" | |||||
#include "megbrain/utils/hash_ct.h" | #include "megbrain/utils/hash_ct.h" | ||||
namespace mgb { | namespace mgb { | ||||