GitOrigin-RevId: aea62de345
tags/v1.7.2.m1
@@ -510,7 +510,10 @@ void test_io_no_copy_ax(std::string model_name, int record = 1) { | |||||
std::vector<std::vector<std::shared_ptr<Tensor>>> inputs; | std::vector<std::vector<std::shared_ptr<Tensor>>> inputs; | ||||
std::vector<std::vector<std::shared_ptr<Tensor>>> outputs; | std::vector<std::vector<std::shared_ptr<Tensor>>> outputs; | ||||
std::shared_ptr<Network> network = std::make_shared<Network>(); | |||||
Config config; | |||||
config.options.graph_opt_level = 0; | |||||
std::shared_ptr<Network> network = std::make_shared<Network>(config); | |||||
network->load_model(model_path); | network->load_model(model_path); | ||||
input_names = network->get_all_input_name(); | input_names = network->get_all_input_name(); | ||||
@@ -559,10 +562,10 @@ void test_io_no_copy_ax(std::string model_name, int record = 1) { | |||||
outputs.push_back(net_outputs); | outputs.push_back(net_outputs); | ||||
} | } | ||||
Config config; | |||||
config.options.force_output_use_user_specified_memory = true; | config.options.force_output_use_user_specified_memory = true; | ||||
config.options.comp_node_seq_record_level = record; | config.options.comp_node_seq_record_level = record; | ||||
config.options.const_shape = true; | config.options.const_shape = true; | ||||
config.options.graph_opt_level = 2; | |||||
std::shared_ptr<Network> network_record = std::make_shared<Network>(config); | std::shared_ptr<Network> network_record = std::make_shared<Network>(config); | ||||
@@ -10,6 +10,7 @@ | |||||
*/ | */ | ||||
#include "megbrain/serialization/serializer.h" | #include "megbrain/serialization/serializer.h" | ||||
#include "megbrain/gopt/inference.h" | |||||
#include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
namespace mgb { | namespace mgb { | ||||
@@ -27,6 +28,35 @@ std::unique_ptr<cg::AsyncExecutable> GraphLoader::LoadResult::graph_compile( | |||||
return ret; | return ret; | ||||
} | } | ||||
void GraphLoader::LoadResult::graph_compile_ahead() { | |||||
//! when force_output_use_user_specified_memory is set, the output var may | |||||
//! be changed by gopt, then the var in LoadResult can not exist, so here | |||||
//! just do basic optimize_for_inference ahead, and replace the var in | |||||
//! LoadResult | |||||
if (graph->options().force_output_use_user_specified_memory) { | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | |||||
auto new_vars = gopt::optimize_for_inference(output_var_list, options); | |||||
output_var_list = new_vars; | |||||
output_var_map.clear(); | |||||
for (auto& var : new_vars) { | |||||
output_var_map[var.node()->cname()] = var; | |||||
} | |||||
std::unordered_map<size_t, SymbolVar> var_map_id; | |||||
for (auto& var : new_vars) { | |||||
bool found = false; | |||||
for (auto& old_var_it : output_var_map_id) { | |||||
if (old_var_it.second.node()->name() == var.node()->name()) { | |||||
found = true; | |||||
var_map_id[old_var_it.first] = var; | |||||
} | |||||
} | |||||
mgb_assert( | |||||
found, "can't find var name %s when optimize_for_inference. ", | |||||
var.node()->cname()); | |||||
} | |||||
} | |||||
} | |||||
GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() { | GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() { | ||||
SharedTensorNameMap ret; | SharedTensorNameMap ret; | ||||
for (auto&& i : shared_tensor_id_map()) { | for (auto&& i : shared_tensor_id_map()) { | ||||
@@ -946,6 +946,7 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi | |||||
mgb_assert(fbs_end > cur); | mgb_assert(fbs_end > cur); | ||||
// Skip to Graph end | // Skip to Graph end | ||||
m_file->skip(fbs_end - cur); | m_file->skip(fbs_end - cur); | ||||
result.graph_compile_ahead(); | |||||
return result; | return result; | ||||
} | } | ||||
@@ -63,6 +63,14 @@ public: | |||||
*/ | */ | ||||
MGE_WIN_DECLSPEC_FUC std::unique_ptr<cg::AsyncExecutable> graph_compile( | MGE_WIN_DECLSPEC_FUC std::unique_ptr<cg::AsyncExecutable> graph_compile( | ||||
const ComputingGraph::OutputSpec& outspec); | const ComputingGraph::OutputSpec& outspec); | ||||
/*! | |||||
* \brief after the graph is loaded, run a basic optimize_for_inference pass, | |||||
* because some dest vars may be replaced by it, which causes errors when the | |||||
* option force_output_use_user_specified_memory is on | |||||
* | |||||
*/ | |||||
MGE_WIN_DECLSPEC_FUC void graph_compile_ahead(); | |||||
}; | }; | ||||
//! helper to disable inplace arith graph optimization during | //! helper to disable inplace arith graph optimization during | ||||