Browse Source

fix(core): fix output var replaced by optpass

GitOrigin-RevId: aea62de345
tags/v1.7.2.m1
Megvii Engine Team 3 years ago
parent
commit
202b407149
4 changed files with 44 additions and 2 deletions
  1. +5
    -2
      lite/test/test_network.cpp
  2. +30
    -0
      src/serialization/impl/serializer.cpp
  3. +1
    -0
      src/serialization/impl/serializer_oss.cpp
  4. +8
    -0
      src/serialization/include/megbrain/serialization/serializer.h

+ 5
- 2
lite/test/test_network.cpp View File

@@ -510,7 +510,10 @@ void test_io_no_copy_ax(std::string model_name, int record = 1) {
std::vector<std::vector<std::shared_ptr<Tensor>>> inputs;
std::vector<std::vector<std::shared_ptr<Tensor>>> outputs;

std::shared_ptr<Network> network = std::make_shared<Network>();
Config config;

config.options.graph_opt_level = 0;
std::shared_ptr<Network> network = std::make_shared<Network>(config);
network->load_model(model_path);

input_names = network->get_all_input_name();
@@ -559,10 +562,10 @@ void test_io_no_copy_ax(std::string model_name, int record = 1) {
outputs.push_back(net_outputs);
}

Config config;
config.options.force_output_use_user_specified_memory = true;
config.options.comp_node_seq_record_level = record;
config.options.const_shape = true;
config.options.graph_opt_level = 2;

std::shared_ptr<Network> network_record = std::make_shared<Network>(config);



+ 30
- 0
src/serialization/impl/serializer.cpp View File

@@ -10,6 +10,7 @@
*/

#include "megbrain/serialization/serializer.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/utility.h"

namespace mgb {
@@ -27,6 +28,35 @@ std::unique_ptr<cg::AsyncExecutable> GraphLoader::LoadResult::graph_compile(
return ret;
}

/*!
 * Run a basic optimize_for_inference pass on the loaded output vars and
 * refresh every output-var container of this LoadResult (output_var_list,
 * output_var_map and output_var_map_id) so they all refer to the optimized
 * vars. Does nothing unless force_output_use_user_specified_memory is set.
 */
void GraphLoader::LoadResult::graph_compile_ahead() {
    //! when force_output_use_user_specified_memory is set, the output var may
    //! be changed by gopt, then the var in LoadResult can not exist, so here
    //! just do basic optimize_for_inference ahead, and replace the var in
    //! LoadResult
    if (graph->options().force_output_use_user_specified_memory) {
        auto options = gopt::OptimizeForInferenceOptions{};
        auto new_vars = gopt::optimize_for_inference(output_var_list, options);
        output_var_list = new_vars;

        //! rebuild the name -> var map from the optimized vars
        output_var_map.clear();
        for (auto& var : new_vars) {
            output_var_map[var.node()->cname()] = var;
        }

        //! rebuild the id -> var map: every optimized var takes over the
        //! id(s) of the old var(s) carrying the same name. Same type as the
        //! member so the result can be assigned back below.
        decltype(output_var_map_id) var_map_id;
        for (auto& var : new_vars) {
            bool found = false;
            for (auto& old_var_it : output_var_map_id) {
                if (old_var_it.second.node()->name() == var.node()->name()) {
                    found = true;
                    var_map_id[old_var_it.first] = var;
                }
            }
            mgb_assert(
                    found, "can't find var name %s when optimize_for_inference. ",
                    var.node()->cname());
        }
        //! publish the rebuilt map; previously it was computed and then
        //! dropped, leaving output_var_map_id pointing at the replaced vars
        output_var_map_id = var_map_id;
    }
}

GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() {
SharedTensorNameMap ret;
for (auto&& i : shared_tensor_id_map()) {


+ 1
- 0
src/serialization/impl/serializer_oss.cpp View File

@@ -946,6 +946,7 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi
mgb_assert(fbs_end > cur);
// Skip to Graph end
m_file->skip(fbs_end - cur);
result.graph_compile_ahead();
return result;
}



+ 8
- 0
src/serialization/include/megbrain/serialization/serializer.h View File

@@ -63,6 +63,14 @@ public:
*/
MGE_WIN_DECLSPEC_FUC std::unique_ptr<cg::AsyncExecutable> graph_compile(
const ComputingGraph::OutputSpec& outspec);

/*!
 * \brief after the graph is loaded, run a basic optimize_for_inference
 * pass ahead of compilation, because some dest vars may be replaced by
 * the optimization, which causes errors when the option
 * force_output_use_user_specified_memory is on
 */
MGE_WIN_DECLSPEC_FUC void graph_compile_ahead();
};

//! helper to disable inplace arith graph optimization during


Loading…
Cancel
Save