GitOrigin-RevId: 5b9488cb93
@@ -93,6 +93,7 @@ struct LITE_API Options {
     bool const_shape = false;
     bool force_dynamic_alloc = false;
     bool force_output_dynamic_alloc = false;
+    bool force_output_use_user_specified_memory = false;
     bool no_profiling_on_shape_change = false;
     uint8_t jit_level = 0;
     uint8_t comp_node_seq_record_level = 0;
@@ -83,6 +83,7 @@ typedef struct Options {
     int const_shape;
     int force_dynamic_alloc;
     int force_output_dynamic_alloc;
+    int force_output_use_user_specified_memory;
     int no_profiling_on_shape_change;
     int jit_level;
     int comp_node_seq_record_level;
@@ -29,6 +29,7 @@ const LiteOptions default_option = {
         .const_shape = false,
         .force_dynamic_alloc = false,
         .force_output_dynamic_alloc = false,
+        .force_output_use_user_specified_memory = false,
         .no_profiling_on_shape_change = false,
         .jit_level = 0,
         .comp_node_seq_record_level = 0,
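Taken together, the three hunks above plumb the new `force_output_use_user_specified_memory` flag through the C++ `Options` struct, the C-API `Options` struct, and the C-API defaults. For orientation, a minimal usage sketch from the Lite C++ API — the model path, output shape, and the `lite/tensor.h` include are assumptions borrowed from the `OutputNoCopy` test later in this diff, not part of the change itself:

```cpp
#include <memory>

#include "lite/network.h"
#include "lite/tensor.h"  // assumed header for lite::Tensor/Layout

using namespace lite;

int main() {
    // Enable the new option before building the network.
    Config config;
    config.options.force_output_use_user_specified_memory = true;

    std::shared_ptr<Network> network = std::make_shared<Network>(config);
    network->load_model("./shufflenet.mge");  // placeholder model path

    // Hand the network a user-owned output buffer; the graph then writes
    // the result into it directly, with no extra copy (input setup omitted).
    std::shared_ptr<Tensor> output = network->get_output_tensor(0);
    static float out_buf[1000];
    output->reset(out_buf, Layout{{1, 1000}, 2, LiteDataType::LITE_FLOAT});

    network->forward();
    network->wait();
}
```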
@@ -122,7 +123,9 @@ lite::Config convert_to_lite_config(const LiteConfig c_config) {
     lite_config.options.var_sanity_check_first_run =
             c_config.options.var_sanity_check_first_run;
     lite_config.options.const_shape = c_config.options.const_shape;
-    lite_config.options.force_dynamic_alloc = c_config.options.const_shape;
+    lite_config.options.force_dynamic_alloc = c_config.options.force_dynamic_alloc;
+    lite_config.options.force_output_use_user_specified_memory =
+            c_config.options.force_output_use_user_specified_memory;
     lite_config.options.force_output_dynamic_alloc =
             c_config.options.force_output_dynamic_alloc;
     lite_config.options.no_profiling_on_shape_change =
@@ -29,6 +29,7 @@ class LiteOptions(Structure):
         ("const_shape", c_int),
         ("force_dynamic_alloc", c_int),
         ("force_output_dynamic_alloc", c_int),
+        ("force_output_use_user_specified_memory", c_int),
         ("no_profiling_on_shape_change", c_int),
         ("jit_level", c_int),
         ("comp_node_seq_record_level", c_int),
@@ -52,6 +53,7 @@ class LiteOptions(Structure):
         self.const_shape = False
         self.force_dynamic_alloc = False
         self.force_output_dynamic_alloc = False
+        self.force_output_use_user_specified_memory = False
         self.no_profiling_on_shape_change = False
         self.jit_level = 0
         self.comp_node_seq_record_level = 0
@@ -67,6 +69,7 @@ class LiteOptions(Structure):
             "const_shape": bool(self.const_shape),
             "force_dynamic_alloc": bool(self.force_dynamic_alloc),
             "force_output_dynamic_alloc": bool(self.force_output_dynamic_alloc),
"force_output_nocopy": bool(self.force_output_nocopy), | |||||
"no_profiling_on_shape_change": bool(self.no_profiling_on_shape_change), | "no_profiling_on_shape_change": bool(self.no_profiling_on_shape_change), | ||||
"jit_level": self.jit_level, | "jit_level": self.jit_level, | ||||
"comp_node_seq_record_level": self.comp_node_seq_record_level, | "comp_node_seq_record_level": self.comp_node_seq_record_level, | ||||
@@ -84,6 +84,9 @@ void NetworkImplDft::application_config() {
     m_load_config.const_var_shape = m_user_config->options.const_shape;
     ConfigOption(force_dynamic_alloc, force_dynamic_alloc);
     ConfigOption(force_output_dynamic_alloc, force_output_dynamic_alloc);
+    ConfigOption(
+            force_output_use_user_specified_memory,
+            force_output_use_user_specified_memory);
     ConfigOption(no_profiling_on_shape_change, no_profiling_on_shape_change);
     LITE_ASSERT(
             m_user_config->options.jit_level == 0 ||
@@ -250,7 +253,13 @@ void NetworkImplDft::make_output_spec() {
                 }
             }
         };
-        m_output_spec.emplace_back(load_out, std::move(cb));
+        //! if the output is written to user-specified memory, the CallbackCaller must be nullptr.
+        if (m_user_config->options.force_output_use_user_specified_memory ||
+            m_user_config->options.force_output_dynamic_alloc) {
+            m_output_spec.emplace_back(load_out, nullptr);
+        } else {
+            m_output_spec.emplace_back(load_out, std::move(cb));
+        }
     } else {
         LITE_THROW(ssprintf("no output named : %s in the mode", out.name.c_str()));
     }
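This hunk is the crux of the output path: when the output is written to user-specified (or dynamically allocated) memory, the output spec gets an empty callback slot, so no `CallbackCaller` copies the result into a Lite-owned buffer. A toy, self-contained illustration of that decision — the types here are hypothetical stand-ins, not the MegBrain API:

```cpp
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-ins: an output spec is a list of
// (output name, optional copy-callback) pairs, as in make_output_spec().
using Callback = std::function<void(const float* /*computed result*/)>;
using OutputSpec = std::vector<std::pair<const char*, Callback>>;

int main() {
    bool force_output_use_user_specified_memory = true;
    bool force_output_dynamic_alloc = false;

    OutputSpec spec;
    if (force_output_use_user_specified_memory || force_output_dynamic_alloc) {
        // Graph writes the output in place; no copy callback is registered.
        spec.emplace_back("out", nullptr);
    } else {
        // Default path: a callback copies the result into the Lite tensor.
        spec.emplace_back("out", [](const float* v) { std::printf("%f\n", *v); });
    }
    return spec[0].second ? 1 : 0;
}
```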
@@ -444,8 +453,7 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) {
     }
 }
 
-void NetworkImplDft::try_infer_tensor_layout(
-        std::shared_ptr<Tensor> tensor, mgb::cg::SymbolVar var) {
+void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var) {
     auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
     auto infer_trait = var.node()->get_static_infer_trait();
     if (std::get<0>(infer_trait)) {
@@ -455,9 +463,13 @@ void NetworkImplDft::try_infer_tensor_layout(
                 "Lite infer output shape failed, maybe the model is "
                 "dynamic "
                 "shape.\n");
+        LITE_ASSERT(
+                !m_user_config->options.force_output_use_user_specified_memory,
+                "force_output_use_user_specified_memory can't be used when output "
+                "shape can't be derived.");
         return;
     }
-    Layout layout = to_lite_layout(mgb::TensorLayout{*shape, var.dtype()});
+    Layout layout = to_lite_layout(TensorLayout{*shape, var.dtype()});
     tensor->set_layout(layout);
 }
@@ -559,8 +571,7 @@ void NetworkImplDft::update_output() {
          out_it != m_network_io->outputs.end();) {
         if (std::find_if(
                     m_load_result.output_var_list.begin(),
-                    m_load_result.output_var_list.end(),
-                    [out_it](const mgb::SymbolVar var) {
+                    m_load_result.output_var_list.end(), [out_it](const SymbolVar var) {
                         return var.node()->name() == out_it->name;
                     }) == m_load_result.output_var_list.end()) {
             LITE_LOG("%s is not the network output, ignore it.", out_it->name.c_str());
@@ -584,7 +595,7 @@ void NetworkImplDft::update_output() {
             out_it->lite_tensor =
                     std::make_shared<Tensor>(device_id, stream_id, device_type);
         }
-        mgb::SymbolVar var;
+        SymbolVar var;
         for (auto&& out_var : m_load_result.output_var_list) {
             if (out_var.node()->name() == out_it->name) {
                 var = out_var;
@@ -592,10 +603,12 @@ void NetworkImplDft::update_output() {
             }
         }
         try_infer_tensor_layout(out_it->lite_tensor, var);
+        output_tensor_copy_optimize(var, out_it->lite_tensor);
     }
     //! user not set, use default output
 } else {
     for (auto&& out : m_load_result.output_var_list) {
+        std::shared_ptr<Tensor> lite_tensor = nullptr;
         auto it = std::find_if(
                 m_network_io->outputs.begin(), m_network_io->outputs.end(),
                 [&out](const IOInner io) { return io.name == out.node()->name(); });
@@ -608,6 +621,7 @@ void NetworkImplDft::update_output() {
                         std::make_shared<Tensor>(device_id, stream_id, device_type);
             }
             try_infer_tensor_layout(it->lite_tensor, out);
+            lite_tensor = it->lite_tensor;
         } else {
             IOInner output;
             output.name = out.node()->name();
@@ -615,11 +629,47 @@ void NetworkImplDft::update_output() {
                     device_id, stream_id, device_type, true);
             m_network_io->outputs.push_back({output});
             try_infer_tensor_layout(output.lite_tensor, out);
+            lite_tensor = output.lite_tensor;
         }
+        output_tensor_copy_optimize(out, lite_tensor);
     }
 }
 
+void NetworkImplDft::output_tensor_copy_optimize(
+        Var var, std::shared_ptr<Tensor> tensor) {
+    LITE_ASSERT(
+            !(m_user_config->options.force_output_use_user_specified_memory &&
+              m_user_config->options.force_output_dynamic_alloc),
+            "Can't set force_output_use_user_specified_memory and "
+            "force_output_dynamic_alloc at the same time.");
+    if (m_user_config->options.force_output_use_user_specified_memory) {
+        TensorHelper::implement(tensor)
+                ->cast_final_safe<TensorImplDft>()
+                .set_reset_callback([var](TensorImplDft* dft_tensor) {
+                    dft_tensor->device_share_host_memory();
+                    auto dv = dft_tensor->dev_tensor().get();
+                    dv->comp_node(var.node()->comp_node(), true);
+                    var.node()->init_mem_plan(dv);
+                    var.node()->reset_dev_tensor_from_tensor(*dv);
+                });
+    }
+    if (m_user_config->options.force_output_dynamic_alloc) {
+        TensorHelper::implement(tensor)
+                ->cast_final_safe<TensorImplDft>()
+                .set_get_memory_callback([var](TensorImplDft* dft_tensor) {
+                    if (dft_tensor->is_host()) {
+                        auto host_tensor = dft_tensor->m_host_tensor;
+                        *host_tensor =
+                                HostTensorND::make_proxy(var.node()->dev_tensor());
+                    } else {
+                        auto dev_tensor = dft_tensor->m_dev_tensor;
+                        *dev_tensor = var.node()->dev_tensor();
+                    }
+                });
+    }
+}
+
 std::shared_ptr<Tensor> NetworkImplDft::get_io_tensor(
         std::string io_name, LiteTensorPhase phase) {
     if (phase == LiteTensorPhase::LITE_INPUT || phase == LiteTensorPhase::LITE_IO) {
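`output_tensor_copy_optimize()` wires up the two tensor callbacks added later in this diff: under `force_output_use_user_specified_memory`, every `Tensor::reset()` re-binds the output var's device tensor to the user buffer; under `force_output_dynamic_alloc`, reading the memory pointer lazily proxies the var's dev tensor. A self-contained toy of the reset-callback half — the names are hypothetical, not the real classes:

```cpp
#include <cstdio>
#include <functional>

// Toy stand-in for TensorImplDft: reset() adopts user memory and then fires
// an optional callback so "the graph" can re-bind its output var to that
// buffer -- the same flow as set_reset_callback() in the hunk above.
struct ToyTensor {
    void* user_ptr = nullptr;
    std::function<void(ToyTensor*)> on_reset;

    void set_reset_callback(std::function<void(ToyTensor*)> cb) {
        on_reset = std::move(cb);
    }
    void reset(void* prepared_data) {
        user_ptr = prepared_data;  // adopt user-owned memory
        if (on_reset)
            on_reset(this);        // notify the graph side
    }
};

int main() {
    ToyTensor out;
    out.set_reset_callback([](ToyTensor* t) {
        std::printf("output var now writes directly to %p\n", t->user_ptr);
    });
    float buf[1000];
    out.reset(buf);  // forward() results would land in buf, copy-free
}
```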
@@ -12,6 +12,7 @@
 #pragma once
 
 #include "lite_build_config.h"
+#include "megbrain/graph.h"
 
 #if LITE_BUILD_WITH_MGE
 #include "lite/network.h"
@@ -41,6 +42,7 @@ class NetworkImplDft final : public Network::NetworkImplBase {
 public:
     NetworkImplDft() { m_load_config.comp_graph = mgb::ComputingGraph::make(); }
     using S = megdnn::param::ExecutionPolicy::Strategy;
+    using Var = mgb::cg::SymbolVar;
     //! set the config of the network, include:
     //! the inference device
     //! the other inference options, such as record_level, weight_preprocess...
@@ -207,8 +209,10 @@ private:
     void compile_graph();
 
     //! try to infer output tensor layout
-    void try_infer_tensor_layout(
-            std::shared_ptr<Tensor> tensor, mgb::cg::SymbolVar var);
+    void try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var);
+
+    //! optimize the output tensor copy
+    void output_tensor_copy_optimize(Var var, std::shared_ptr<Tensor> tensor);
 
 private:
     bool m_async = false;
@@ -149,6 +149,9 @@ Layout TensorImplDft::get_layout() const {
 }
 
 void* TensorImplDft::get_memory_ptr() const {
+    if (m_get_memory_callback) {
+        m_get_memory_callback(const_cast<TensorImplDft*>(this));
+    }
     if (is_host()) {
         return static_cast<void*>(m_host_tensor->raw_ptr());
     } else {
@@ -157,6 +160,9 @@ void* TensorImplDft::get_memory_ptr() const {
 }
 
 void* TensorImplDft::get_memory_ptr(const std::vector<size_t>& idx) const {
+    if (m_get_memory_callback) {
+        m_get_memory_callback(const_cast<TensorImplDft*>(this));
+    }
     if (is_host()) {
         auto elemsize_log = m_host_tensor->layout().dtype.size_log();
         switch (elemsize_log) {
@@ -317,6 +323,9 @@ void TensorImplDft::reset(void* prepared_data) {
         storage.reset(cn, size, raw_storage);
         m_dev_tensor->reset(storage, mge_layout);
     }
+    if (m_reset_callback) {
+        m_reset_callback(this);
+    }
 }
 
 void TensorImplDft::reset(void* prepared_data, const Layout& layout) {
@@ -430,6 +439,34 @@ void TensorImplDft::copy_from_mge_tensor(const mgb::DeviceTensorND& dv) {
     }
 }
 
+void TensorImplDft::set_reset_callback(const std::function<void(TensorImplDft*)>& cb) {
+    m_reset_callback = cb;
+}
+
+void TensorImplDft::set_get_memory_callback(
+        const std::function<void(TensorImplDft*)>& cb) {
+    m_get_memory_callback = cb;
+}
+
+void TensorImplDft::device_share_host_memory() {
+    if (is_host()) {
+        if (!m_dev_tensor) {
+            m_dev_tensor = std::make_shared<mgb::DeviceTensorND>(
+                    m_host_tensor->comp_node(), m_host_tensor->layout());
+        }
+        if (m_host_tensor->raw_ptr() != m_dev_tensor->raw_ptr()) {
+            auto raw_storage = std::shared_ptr<mgb::dt_byte>(
+                    m_host_tensor->raw_ptr(), [](void*) {});
+            auto cn = m_host_tensor->comp_node();
+            auto mge_layout = m_host_tensor->layout();
+            size_t size = mge_layout.span().dist_byte();
+            mgb::DeviceTensorStorage storage;
+            storage.reset(cn, size, raw_storage);
+            m_dev_tensor->reset(storage, mge_layout);
+        }
+    }
+}
+
 #endif
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
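`device_share_host_memory()` aliases the host buffer into the device tensor through a `shared_ptr` with a no-op deleter, so the storage is borrowed rather than owned. The idiom in isolation — a sketch independent of MegBrain:

```cpp
#include <cassert>
#include <memory>

int main() {
    float host[16] = {};  // buffer owned elsewhere (here: the stack)

    // Non-owning shared_ptr: the no-op deleter means the "device" view
    // borrows the host storage instead of freeing it, which is how
    // device_share_host_memory() makes both tensors share one allocation.
    std::shared_ptr<float> device_view(host, [](float*) {});

    device_view.get()[0] = 1.f;
    assert(host[0] == 1.f);  // same memory, no copy
    return 0;
}
```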
@@ -97,11 +97,22 @@ public:
     //! get host tensor
     std::shared_ptr<mgb::HostTensorND> host_tensor() const { return m_host_tensor; }
 
     //! get device tensor
     std::shared_ptr<mgb::DeviceTensorND> dev_tensor() const { return m_dev_tensor; }
 
     //! copy from mgb tensor
     void copy_from_mge_tensor(const mgb::DeviceTensorND& dv);
 
+    //! set the tensor reset callback
+    void set_reset_callback(const std::function<void(TensorImplDft*)>& cb);
+
+    //! set the tensor get-memory callback
+    void set_get_memory_callback(const std::function<void(TensorImplDft*)>& cb);
+
+    //! share the same memory between the host and device tensors
+    void device_share_host_memory();
+
 public:
     friend class NetworkImplDft;
@@ -115,6 +126,8 @@ private:
     void set_mge_tensor_compnode(const mgb::CompNode& comp_node);
 
 private:
+    std::function<void(TensorImplDft*)> m_get_memory_callback;
+    std::function<void(TensorImplDft*)> m_reset_callback;
     std::shared_ptr<mgb::HostTensorND> m_host_tensor;
     std::shared_ptr<mgb::DeviceTensorND> m_dev_tensor;
 };
@@ -153,6 +153,10 @@ std::shared_ptr<Tensor> Network::get_output_tensor(size_t index) {
 
 Network& Network::set_async_callback(const AsyncCallback& callback) {
     LITE_ERROR_HANDLER_BEGIN
+    LITE_ASSERT(
+            !m_config.options.force_output_use_user_specified_memory,
+            "Async mode can't run with force_output_use_user_specified_memory, "
+            "because output data is written directly to user-specified memory.");
     LITE_CHECK_NON_NULL_POINTER(m_impl);
     m_impl->set_async_callback(std::move(callback));
     return *this;
@@ -397,6 +397,73 @@ TEST(TestNetWork, ResetOutput) {
     compare_lite_tensor<float>(output_tensor, result_mgb);
 }
 
+TEST(TestNetWork, OutputNoCopy) {
+    Config config;
+    config.options.force_output_use_user_specified_memory = true;
+    auto tensor = get_input_data("./input_data.npy");
+    std::string model_path = "./shufflenet.mge";
+    std::string input_name = "data";
+    auto result_mgb = mgb_lar(model_path, config, input_name, tensor);
+
+    std::shared_ptr<Network> network = std::make_shared<Network>(config);
+    network->load_model(model_path);
+
+    std::shared_ptr<Tensor> input_tensor = network->get_io_tensor(input_name);
+    auto src_ptr = tensor->get_memory_ptr();
+    auto src_layout = tensor->get_layout();
+    input_tensor->reset(src_ptr, src_layout);
+
+    std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
+
+    size_t times = 5;
+    std::vector<std::shared_ptr<Tensor>> result_tensors;
+    for (size_t i = 0; i < times; i++) {
+        auto tmp = std::make_shared<Tensor>(
+                LiteDeviceType::LITE_CPU,
+                Layout{{1, 1000}, 2, LiteDataType::LITE_FLOAT});
+        result_tensors.push_back(tmp);
+    }
+
+    for (size_t i = 0; i < times; i++) {
+        void* out_data = result_tensors[i]->get_memory_ptr();
+        output_tensor->reset(out_data, result_tensors[i]->get_layout());
+
+        network->forward();
+        network->wait();
+
+        ASSERT_EQ(output_tensor->get_memory_ptr(), out_data);
+        compare_lite_tensor<float>(output_tensor, result_mgb);
+    }
+    for (size_t i = 0; i < times; i++) {
+        compare_lite_tensor<float>(result_tensors[i], result_mgb);
+    }
+}
+
+TEST(TestNetWork, OutputDynamicAlloc) {
+    Config config;
+    config.options.force_output_dynamic_alloc = true;
+    auto tensor = get_input_data("./input_data.npy");
+    std::string model_path = "./shufflenet.mge";
+    std::string input_name = "data";
+    auto result_mgb = mgb_lar(model_path, config, input_name, tensor);
+
+    std::shared_ptr<Network> network = std::make_shared<Network>(config);
+    network->load_model(model_path);
+
+    std::shared_ptr<Tensor> input_tensor = network->get_io_tensor(input_name);
+    auto src_ptr = tensor->get_memory_ptr();
+    auto src_layout = tensor->get_layout();
+    input_tensor->reset(src_ptr, src_layout);
+
+    std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
+
+    size_t times = 5;
+    for (size_t i = 0; i < times; i++) {
+        network->forward();
+        network->wait();
+
+        compare_lite_tensor<float>(output_tensor, result_mgb);
+    }
+}
+
 TEST(TestNetWork, AsyncExec) {
     Config config;
     config.options.var_sanity_check_first_run = false;
@@ -507,13 +507,12 @@ void ComputingGraphImpl::dest_var_optimize(VarNodeArray& dest_vars) {
             i->add_flag(F::NO_MEM_RECLAIM);
         }
     }
-    if (dest_vars[0]->owner_graph()->options().force_output_write_to_user_memory) {
+    if (dest_vars[0]->owner_graph()->options().force_output_use_user_specified_memory) {
         for (auto&& i : dest_vars) {
             mgb_assert(
                     !i->contain_flag(F::RT_FORCE_DYNAMIC_MEM_ALLOC),
-                    "var %s with force dynamic allocate should be set to write output "
-                    "to "
-                    "user memory",
+                    "var %s with RT_FORCE_DYNAMIC_MEM_ALLOC flag should not be "
+                    "forced to write its output to user memory",
                     i->cname());
             i->add_flag(
                     F::NO_SYS_MEM_ALLOC | F::NO_SYS_STATIC_MEM_ALLOC |
@@ -574,6 +574,10 @@ MemAllocPlan& VarNode::init_mem_plan(const DeviceTensorND* fixed_alloc) {
     return m_mem_plan;
 }
 
+bool VarNode::is_graph_dest_varnode() {
+    return ComputingGraphImpl::downcast(owner_graph())->var_receiver(this).size() == 0;
+}
+
 VarNode& VarNode::add_flag(Flag flag) {
     modify_flag(flag, m_flag | flag);
     return *this;
@@ -582,10 +586,13 @@ VarNode& VarNode::add_flag(Flag flag) {
 void VarNode::modify_flag(Flag delta, Flag new_flag) {
     if (contain_flag(Flag::FLAG_FREEZED)) {
         mgb_assert(
-                (delta & (Flag::NO_SYS_MEM_ALLOC | Flag::NO_MEM_RECLAIM |
-                          Flag::NO_SYS_STATIC_MEM_ALLOC |
-                          Flag::RT_FORCE_DYNAMIC_MEM_ALLOC)) == delta ||
-                        (new_flag & Flag::MEMORY_NO_NEED));
+                (delta & (Flag::NO_MEM_RECLAIM | Flag::NO_SYS_STATIC_MEM_ALLOC |
+                          Flag::RT_FORCE_DYNAMIC_MEM_ALLOC | Flag::MEMORY_NO_NEED)) ==
+                                delta ||
+                        is_graph_dest_varnode(),
+                "after FLAG_FREEZED is set, a var may only modify the "
+                "NO_MEM_RECLAIM, NO_SYS_STATIC_MEM_ALLOC, RT_FORCE_DYNAMIC_MEM_ALLOC "
+                "and MEMORY_NO_NEED flags, unless it is a graph dest var");
 
         mgb_assert(
                 !ComputingGraphImpl::downcast(owner_graph())
@@ -421,7 +421,7 @@ public:
          * Force the output to be written to the user specified memory, which
          * can optimize the copy of output data at one time
          */
-        bool force_output_write_to_user_memory = false;
+        bool force_output_use_user_specified_memory = false;
 
         //! whether to perform var sanity check on first run
         bool var_sanity_check_first_run = true;
@@ -549,6 +549,10 @@ private:
     MGE_WIN_DECLSPEC_FUC void modify_flag(Flag delta, Flag new_flag);
 
+    //! whether the var is a graph output; if so, the NO_SYS_MEM_ALLOC
+    //! flag can be modified
+    bool is_graph_dest_varnode();
+
     MGE_WIN_DECLSPEC_FUC void assign_dev_tensor_from_tensor(
             const DeviceTensorND& value);
@@ -82,7 +82,7 @@ TEST(TestNoCopy, BasicInputNoCopy) {
 TEST(TestNoCopy, IONoCopyPtrEQ) {
     auto test_graph = TestGraph();
     auto compute_graph = test_graph.m_network->graph;
-    compute_graph->options().force_output_write_to_user_memory = true;
+    compute_graph->options().force_output_use_user_specified_memory = true;
     test_graph.create_graph();
     auto func = test_graph.compile_without_copy();
     auto&& outvar = func->get_output_vars()[0];
@@ -123,7 +123,7 @@ TEST(TestNoCopy, IONoCopyPtrEQ) {
 TEST(TestNoCopy, IONoCopyCorrect) {
     auto test_graph = TestGraph();
     auto compute_graph = test_graph.m_network->graph;
-    compute_graph->options().force_output_write_to_user_memory = true;
+    compute_graph->options().force_output_use_user_specified_memory = true;
     test_graph.create_graph();
     HostTensorND truth;
     auto func = test_graph.compile_without_copy();