Browse Source

fix(lite): fix lite error when record level is 2

GitOrigin-RevId: 7dabfd8876
tags/v1.7.2.m1
Megvii Engine Team 3 years ago
parent
commit
ce119ef5a5
3 changed files with 22 additions and 1 deletions
  1. +4
    -0
      lite/src/mge/network_impl.cpp
  2. +17
    -0
      lite/test/test_network_options.cpp
  3. +1
    -1
      src/core/impl/graph/cg_impl.cpp

+ 4
- 0
lite/src/mge/network_impl.cpp View File

@@ -436,6 +436,10 @@ void NetworkImplDft::start() const {

void NetworkImplDft::forward() {
start();
if (m_load_config.comp_graph &&
m_user_config->options.comp_node_seq_record_level == 2) {
m_load_config.comp_graph.reset();
}
LITE_ASSERT(m_execute_func, "forward must be called after network loaded.");
m_execute_func->execute();
}


+ 17
- 0
lite/test/test_network_options.cpp View File

@@ -89,6 +89,23 @@ TEST(TestNetWorkOptions, const_shape) {
compare_lite_tensor<float>(output_tensor, result_mgb);
}

TEST(TestNetWorkOptions, record2) {
Config config;
std::string model_path = "./shufflenet.mge";

config.options.var_sanity_check_first_run = false;
config.options.const_shape = true;
config.options.comp_node_seq_record_level = 2;
std::shared_ptr<Network> network = std::make_shared<Network>(config);

network->load_model(model_path);

for (int i = 0; i < 3; i++) {
network->forward();
network->wait();
}
}

TEST(TestNetWorkOptions, NCHW44) {
Config config;
auto tensor = get_input_data("./input_data.npy");


+ 1
- 1
src/core/impl/graph/cg_impl.cpp View File

@@ -126,7 +126,7 @@ ComputingGraph::ComputingGraph() {

void ComputingGraph::assert_destroy(std::shared_ptr<ComputingGraph>& ptr) {
mgb_assert(
ptr.use_count() == 1, "unexpected use_count: %zu", size_t(ptr.use_count()));
ptr.use_count() <= 2, "unexpected use_count: %zu", size_t(ptr.use_count()));
ptr.reset();
}



Loading…
Cancel
Save