GitOrigin-RevId: 29f785b701
release-1.11
@@ -27,7 +27,7 @@ class OprParamsLoadContext final : public serialization::OprLoadContextRawPOD {
     std::shared_ptr<DeviceTensorND> load_tensor_shared(
             bool copy_immediatly = false) override {
-        (void)copy_immediatly;
+        MGB_MARK_USED_VAR(copy_immediatly);
         mgb_assert(0);
     }
@@ -56,7 +56,7 @@ public:
     }
     void dump_tensor(
             const std::string& name, const HostTensorND& tensor,
-            TensorWriteMethod method) {
+            TensorWriteMethod method, TensorFormat format = {}) {
         mgb_assert(0);
     }
     const serialization::GraphDumpConfig& config() const { mgb_assert(0); }
@@ -72,16 +72,20 @@ struct OprLoadDumpImplV2<opr::SharedDeviceTensorWithFormat, 0> {
         auto&& opr = opr_.cast_final_safe<opr::SharedDeviceTensorWithFormat>();
         HostTensorND val;
         val.copy_from(opr.get_dev_tensor()).sync();
-        ctx.dump_tensor({}, val, Meth::VALUE_ANONYMOUS);
+        ctx.dump_tensor(
+                {}, val, Meth::VALUE_ANONYMOUS, opr.get_dev_tensor().layout().format);
     }
     static cg::OperatorNodeBase* load(
             OprLoadContext& ctx, const cg::VarNodeArray& inputs,
             const OperatorNodeConfig& config) {
         mgb_assert(inputs.empty());
-        auto val = ctx.load_tensor();
+        auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
+        auto val = fbs_ctx.load_tensor();
+        auto format = fbs_ctx.load_tensor_format(0);
+        TensorLayout layout_with_format = {val->shape(), val->dtype(), format};
         auto dev_val =
-                std::make_shared<DeviceTensorND>(val->comp_node(), val->layout());
+                std::make_shared<DeviceTensorND>(val->comp_node(), layout_with_format);
         dev_val->copy_from_fixlayout(*val);
         auto out_var =
                 opr::SharedDeviceTensorWithFormat::make(ctx.graph(), dev_val, config);
@@ -136,7 +140,9 @@ struct OprLoadDumpImplV2<opr::MultipleDeviceTensorWithFormatHolder, 0> {
             HostTensorND val;
             auto value = *opr.values()[i];
             val.copy_from(value).sync();
-            ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED);
+            ctx.dump_tensor(
+                    opr.output(i)->name(), val, Meth::VALUE_SHARED,
+                    value.layout().format);
         }
     }
@@ -152,10 +158,12 @@ struct OprLoadDumpImplV2<opr::MultipleDeviceTensorWithFormatHolder, 0> {
             nr = fopr->tensors()->size();
         }
         Opr::ValueArray values(nr);
+        size_t id = 0;
         for (auto&& i : values) {
             i = ctx.load_tensor_shared();
             //! set tensor format
-            TensorLayout layout_with_format = i->layout();
+            auto format = fbs_ctx.load_tensor_format(id++);
+            TensorLayout layout_with_format{i->layout(), i->layout().dtype, format};
             if (i->storage().comp_node().mem_node() ==
                 CompNode::default_cpu().mem_node()) {
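Taken together, the two load() hunks above establish the FlatBuffers-V2 loader-side pattern: read the tensor value, then ask the V2 context for the format that was serialized next to it, and rebuild a layout that carries that format. A minimal sketch of the pattern, using only the context API that appears in this change (the surrounding opr boilerplate is omitted):

    // sketch: inside an OprLoadDumpImplV2<...>::load() where ctx is the V2 load context
    auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
    size_t id = 0;
    for (auto&& i : values) {
        i = ctx.load_tensor_shared();
        // formats are indexed in the same order the tensors were dumped
        auto format = fbs_ctx.load_tensor_format(id++);
        TensorLayout layout_with_format{i->layout(), i->layout().dtype, format};
        // ... reset i's layout / copy with layout_with_format, as shown above
    }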
@@ -498,48 +498,66 @@ TEST(TestOprIO, MultipleDeviceTensorWithFormatHolderCpu) {
     auto fname = GET_OUTPUT_FILE();
     auto cn = CompNode::load("cpu0");
     HostTensorGenerator<> gen;
-    {
-        // dump
-        auto graph = ComputingGraph::make();
-        graph->options().graph_opt_level = 0;
-        auto mkcvar = [&](const char* name, const TensorShape& shp) {
-            return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
+    auto test = [&](serialization::GraphDumpFormat format) {
+        {
+            // dump
+            auto graph = ComputingGraph::make();
+            graph->options().graph_opt_level = 0;
+            auto mkcvar = [&](const char* name, const TensorShape& shp) {
+                return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                        .rename(name);
+            };
+            auto host_x = gen({8, 8, 8, 8}, cn);
+            auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"});
+            opr::Convolution::Param param;
+            param.pad_h = param.pad_w = 0;
+            auto w1 = mkcvar("w1", {4, 8, 3, 3}),
+                 conv1 = opr::Convolution::make(x, w1, param);
+            auto w2 = mkcvar("w2", {4, 4, 3, 3}),
+                 conv2 = opr::Convolution::make(conv1, w2, param);
+            auto y = opr::Elemwise::make({conv2}, opr::Elemwise::Param::Mode::RELU);
+            auto options = gopt::OptimizeForInferenceOptions{};
+            options.enable_nhwcd4();
+            SymbolVar y_opt =
+                    gopt::optimize_for_inference({y}, options)[0].rename("out");
+            auto dumper = serialization::GraphDumper::make(
+                    serialization::OutputFile::make_fs(fname.c_str()), format);
+            serialization::GraphDumper::DumpConfig config;
+            config.keep_param_name = true;
+            dumper->dump({y_opt}, config);
+        }
+        auto loader = serialization::GraphLoader::make(
+                serialization::InputFile::make_fs(fname.c_str()), format);
+        auto load = [&](CompNode dest_cn) {
+            auto dest_cn_loc = dest_cn.locator_logical();
+            auto rst =
+                    loader->load({[&](CompNode::Locator& loc) { loc = dest_cn_loc; }});
+            HostTensorND host_z, host_z_expect;
+            auto func = rst.graph_compile(
+                    {make_callback_copy(rst.output_var_map.at("out"), host_z)});
+            func->execute();
+            func->wait();
+            auto&& shared_tensor_map = loader->shared_tensor_id_map();
+            bool cd4 = false;
+            for (auto&& i : shared_tensor_map) {
+                auto&& shared_tensor = i.second.begin()->second;
+                if (shared_tensor->format().type() ==
+                    TensorFormat::Type::IMAGE2D_PACK4) {
+                    cd4 = true;
+                }
+            }
+            ASSERT_TRUE(cd4);
         };
-        auto host_x = gen({8, 8, 8, 8}, cn);
-        auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"});
-        opr::Convolution::Param param;
-        param.pad_h = param.pad_w = 0;
-        auto w1 = mkcvar("w1", {4, 8, 3, 3}),
-             conv1 = opr::Convolution::make(x, w1, param);
-        auto w2 = mkcvar("w2", {4, 4, 3, 3}),
-             conv2 = opr::Convolution::make(conv1, w2, param);
-        auto y = opr::Elemwise::make({conv2}, opr::Elemwise::Param::Mode::RELU);
-        auto options = gopt::OptimizeForInferenceOptions{};
-        options.enable_nhwcd4();
-        SymbolVar y_opt = gopt::optimize_for_inference({y}, options)[0].rename("out");
-        auto dumper = serialization::GraphDumper::make(
-                serialization::OutputFile::make_fs(fname.c_str()));
-        serialization::GraphDumper::DumpConfig config;
-        config.keep_param_name = true;
-        dumper->dump({y_opt}, config);
-    }
-    auto loader = serialization::GraphLoader::make(
-            serialization::InputFile::make_fs(fname.c_str()));
-    auto load = [&](CompNode dest_cn) {
-        auto dest_cn_loc = dest_cn.locator_logical();
-        auto rst = loader->load({[&](CompNode::Locator& loc) { loc = dest_cn_loc; }});
-        HostTensorND host_z, host_z_expect;
-        auto func = rst.graph_compile(
-                {make_callback_copy(rst.output_var_map.at("out"), host_z)});
-        func->execute();
+        load(cn);
     };
-    load(cn);
+    test({});
+    test(serialization::GraphDumpFormat::FLATBUFFERS_V2);
 }
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
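The rewritten test wraps the whole dump-and-load flow in a test(format) lambda and runs it twice, once with the default (legacy) serialization format and once with FLATBUFFERS_V2, asserting that at least one reloaded shared tensor still reports an IMAGE2D_PACK4 format after the NHWCD4 optimization. A stripped-down round trip using only calls that appear in the test (the file name is a placeholder and y_opt stands for whatever output var the application dumps):

    auto dumper = serialization::GraphDumper::make(
            serialization::OutputFile::make_fs("model.mge"),
            serialization::GraphDumpFormat::FLATBUFFERS_V2);
    dumper->dump({y_opt});  // tensor formats are serialized alongside the values

    auto loader = serialization::GraphLoader::make(
            serialization::InputFile::make_fs("model.mge"),
            serialization::GraphDumpFormat::FLATBUFFERS_V2);
    auto rst = loader->load();  // shared tensors come back with their format restored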
@@ -32,7 +32,9 @@ class OprDumpContextMemory final : public OprDumpContextRawPOD {
     }
     void dump_tensor(
-            const std::string&, const HostTensorND&, TensorWriteMethod) override {
+            const std::string&, const HostTensorND&, TensorWriteMethod,
+            TensorFormat format = {}) override {
+        MGB_MARK_USED_VAR(format);
         mgb_throw(GraphError, "OprDumpContextMemory does not support dump tensor");
     }
@@ -92,7 +92,7 @@ public:
     const GraphDumpConfig& config() const override { return m_config; }
     void dump_tensor(
             const std::string& name, const HostTensorND& tensor,
-            TensorWriteMethod method) override;
+            TensorWriteMethod method, TensorFormat format = {}) override;
     flatbuffers::FlatBufferBuilder& builder() override { return m_builder; }
     void append_param(uint32_t type, uint32_t value) override {
         static_assert(
@@ -359,7 +359,8 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
 }
 void GraphDumperOSS::dump_tensor(
-        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) {
+        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method,
+        TensorFormat) {
     using namespace flatbuffers;
     using Meth = TensorWriteMethod;
     mgb_assert(
@@ -671,17 +672,17 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSS::OprLoadContextImpl::load_tensor_
         sh_reg.first = tensor->name()->str();
     }
-    if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
+    if (comp_node.mem_node() == CompNode::default_cpu().mem_node() || copy_immediatly) {
         // directly forward CPU memory
         HostTensorND hv{comp_node};
         load_tensor_value(&hv, layout, tensor);
         sh_ptr_ref = std::make_shared<DeviceTensorND>();
-        *sh_ptr_ref = DeviceTensorND::make_proxy(hv);
-    } else if (copy_immediatly) {
-        HostTensorND hv{CompNode::default_cpu()};
-        load_tensor_value(&hv, layout, tensor);
-        sh_ptr_ref = std::make_shared<DeviceTensorND>();
-        sh_ptr_ref->comp_node(comp_node).copy_from(hv).sync();
+        if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
+            *sh_ptr_ref = DeviceTensorND::make_proxy(hv);
+        } else {
+            mgb_assert(copy_immediatly);
+            sh_ptr_ref->comp_node(comp_node).copy_from(hv).sync();
+        }
     } else {
         // use lazy load for non-CPU devices
         HostTensorND hv{CompNode::default_cpu()};
@@ -455,7 +455,8 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
 }
 void GraphDumperOSSV2::dump_tensor(
-        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) {
+        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method,
+        TensorFormat format) {
     using namespace flatbuffers;
     using Meth = TensorWriteMethod;
     mgb_assert(
@@ -510,8 +511,8 @@ void GraphDumperOSSV2::dump_tensor(
             m_builder.CreateSharedString(tensor.comp_node().to_string_logical()));
     auto fdtype = build_dtype(layout.dtype);
-    auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
-    auto fformat = build_tensor_format(layout.format);
+    auto fformat_type = get_flatbuffer_tensor_format_type(format);
+    auto fformat = build_tensor_format(format);
     auto serialized_tensor = fbs::v2::CreateTensor(
             m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat, data);
     m_cur_opr_tensor.emplace_back(serialized_tensor);
@@ -605,7 +606,7 @@ CompNode GraphLoaderOSSV2::OprLoadContextImpl::load_comp_node(
     return CompNode::load(loc);
 }
-TensorFormat load_tensor_format(
+TensorFormat get_tensor_format(
         const fbs::v2::TensorFormat fformat_type, const void* fformat,
         const CompNode& comp_node) {
     switch (fformat_type) {
@@ -631,8 +632,7 @@ TensorFormat load_tensor_format(
     }
 }
-TensorLayout load_tensor_layout(
-        const fbs::v2::Tensor* tensor, const CompNode& comp_node) {
+TensorLayout load_tensor_layout_without_format(const fbs::v2::Tensor* tensor) {
     TensorLayout layout;
     if (tensor->shape()) {
         layout.ndim = tensor->shape()->size();
@@ -642,14 +642,21 @@ TensorLayout load_tensor_layout(
         // modify data type inplace for TensorLayout
         layout.modify_dtype_inplace(fbs::intl::load_dtype(tensor->dtype()));
     }
-    if (tensor->format() && tensor->format_type()) {
-        layout.format =
-                load_tensor_format(tensor->format_type(), tensor->format(), comp_node);
-    }
     layout.init_contiguous_stride();
     return layout;
 }
+TensorFormat GraphLoaderOSSV2::OprLoadContextImpl::load_tensor_format(size_t id) {
+    mgb_assert(m_current_opr->tensors() && id < m_current_opr->tensors()->size());
+    auto tensor = m_current_opr->tensors()->Get(id);
+    auto comp_node = load_comp_node(tensor->comp_node());
+    TensorFormat format;
+    if (tensor->format() && tensor->format_type()) {
+        format = get_tensor_format(tensor->format_type(), tensor->format(), comp_node);
+    }
+    return format;
+}
 //! the opr loader should make sure the exist of tensors and the number of
 //! tensor, here just assert it.
 std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor() {
@@ -658,7 +665,7 @@ std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor(
             m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
     auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
     auto comp_node = load_comp_node(tensor->comp_node());
-    auto layout = load_tensor_layout(tensor, comp_node);
+    auto layout = load_tensor_layout_without_format(tensor);
     auto ret = std::make_shared<HostTensorND>(comp_node, layout);
     auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;
@@ -692,7 +699,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
             m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
     auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
     auto comp_node = load_comp_node(tensor->comp_node());
-    auto layout = load_tensor_layout(tensor, comp_node);
+    auto layout = load_tensor_layout_without_format(tensor);
     mgb_assert(tensor->data());
     if (m_loader->m_shared_tensor_map.size() <= m_cur_shared_tensor_idx) {
         m_loader->m_shared_tensor_map.resize(m_cur_shared_tensor_idx + 5);
@@ -712,7 +719,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
         shared_pair.first = tensor->name()->str();
     }
-    if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
+    if (comp_node.mem_node() == CompNode::default_cpu().mem_node() || copy_immediatly) {
         // directly forward CPU memory
         shared_tensor_ref = std::make_shared<DeviceTensorND>();
         HostTensorND hv{comp_node};
@@ -722,18 +729,13 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
                     hv, tensor->data()->data(), tensor->data()->size(),
                     m_loader->m_file->is_shared_memory());
         }
-        *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
-        m_tensor_alignment->add_device_tensor(shared_tensor_ref);
-    } else if (copy_immediatly) {
-        HostTensorND hv{CompNode::default_cpu()};
-        shared_tensor_ref = std::make_shared<DeviceTensorND>();
-        if (tensor->data() && tensor->data()->size() > 0) {
-            hv.dtype(layout.dtype).resize(layout);
-            fill_tensor_memory(
-                    hv, tensor->data()->data(), tensor->data()->size(),
-                    m_loader->m_file->is_shared_memory());
+        if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
+            *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
+            m_tensor_alignment->add_device_tensor(shared_tensor_ref);
+        } else {
+            mgb_assert(copy_immediatly);
+            shared_tensor_ref->comp_node(comp_node).copy_from(hv).sync();
         }
-        shared_tensor_ref->comp_node(comp_node).copy_from(hv).sync();
     } else {
         // use lazy load for non-CPU devices
         HostTensorND hv{CompNode::default_cpu()};
@@ -47,7 +47,7 @@ public:
     //! whether this can be write
     virtual bool writable() { return false; }
-    //! whether this file have been wrote
+    //! tag this file have been wrote
     virtual void have_modified() {}
     /*!
@@ -63,7 +63,7 @@ public:
      */
     virtual void dump_tensor(
             const std::string& name, const HostTensorND& tensor,
-            TensorWriteMethod method) = 0;
+            TensorWriteMethod method, TensorFormat format = {}) = 0;
     //! get associated global configuration
     virtual const GraphDumpConfig& config() const = 0;
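Because the new TensorFormat parameter is defaulted, existing dump_tensor() call sites keep compiling unchanged, but every concrete OprDumpContext must adopt the extended signature in its override. A minimal override sketch, modeled on the OprDumpContextMemory change earlier in this diff (the body is illustrative):

    void dump_tensor(
            const std::string& name, const HostTensorND& tensor,
            TensorWriteMethod method, TensorFormat format = {}) override {
        MGB_MARK_USED_VAR(format);  // this context does not serialize formats
        // ... handle name/tensor/method as before
    }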
@@ -63,7 +63,7 @@ public:
     void dump_tensor(
             const std::string& name, const HostTensorND& tensor,
-            TensorWriteMethod method) override;
+            TensorWriteMethod method, TensorFormat format = {}) override;
     void append_param(uint32_t type, uint32_t value) override {
         static_assert(
@@ -148,6 +148,8 @@ public:
         return *m_loader->m_cur_load_config;
     }
+    TensorFormat load_tensor_format(size_t id);
     //! shared or copy the loaded flatbuffer memory to the CPU tensor, this can reduce
     //! the memory used when load model, but should consider the memory
     //! alignment