GitOrigin-RevId: 7dabfd8876
tags/v1.7.2.m1
@@ -436,6 +436,10 @@ void NetworkImplDft::start() const { | |||||
void NetworkImplDft::forward() { | void NetworkImplDft::forward() { | ||||
start(); | start(); | ||||
if (m_load_config.comp_graph && | |||||
m_user_config->options.comp_node_seq_record_level == 2) { | |||||
m_load_config.comp_graph.reset(); | |||||
} | |||||
LITE_ASSERT(m_execute_func, "forward must be called after network loaded."); | LITE_ASSERT(m_execute_func, "forward must be called after network loaded."); | ||||
m_execute_func->execute(); | m_execute_func->execute(); | ||||
} | } | ||||
@@ -89,6 +89,23 @@ TEST(TestNetWorkOptions, const_shape) { | |||||
compare_lite_tensor<float>(output_tensor, result_mgb); | compare_lite_tensor<float>(output_tensor, result_mgb); | ||||
} | } | ||||
TEST(TestNetWorkOptions, record2) { | |||||
Config config; | |||||
std::string model_path = "./shufflenet.mge"; | |||||
config.options.var_sanity_check_first_run = false; | |||||
config.options.const_shape = true; | |||||
config.options.comp_node_seq_record_level = 2; | |||||
std::shared_ptr<Network> network = std::make_shared<Network>(config); | |||||
network->load_model(model_path); | |||||
for (int i = 0; i < 3; i++) { | |||||
network->forward(); | |||||
network->wait(); | |||||
} | |||||
} | |||||
TEST(TestNetWorkOptions, NCHW44) { | TEST(TestNetWorkOptions, NCHW44) { | ||||
Config config; | Config config; | ||||
auto tensor = get_input_data("./input_data.npy"); | auto tensor = get_input_data("./input_data.npy"); | ||||
@@ -126,7 +126,7 @@ ComputingGraph::ComputingGraph() { | |||||
void ComputingGraph::assert_destroy(std::shared_ptr<ComputingGraph>& ptr) { | void ComputingGraph::assert_destroy(std::shared_ptr<ComputingGraph>& ptr) { | ||||
mgb_assert( | mgb_assert( | ||||
ptr.use_count() == 1, "unexpected use_count: %zu", size_t(ptr.use_count())); | |||||
ptr.use_count() <= 2, "unexpected use_count: %zu", size_t(ptr.use_count())); | |||||
ptr.reset(); | ptr.reset(); | ||||
} | } | ||||