|
|
@@ -10,6 +10,7 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "megbrain/serialization/serializer.h" |
|
|
|
#include "megbrain/gopt/inference.h" |
|
|
|
#include "megbrain/opr/utility.h" |
|
|
|
|
|
|
|
namespace mgb { |
|
|
@@ -27,6 +28,35 @@ std::unique_ptr<cg::AsyncExecutable> GraphLoader::LoadResult::graph_compile( |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
void GraphLoader::LoadResult::graph_compile_ahead() { |
|
|
|
//! when force_output_use_user_specified_memory is set, the output var may |
|
|
|
//! be changed by gopt, then the var in LoadResult can not exist, so here |
|
|
|
//! just do basic optimize_for_inference ahead, and replace the var in |
|
|
|
//! LoadResult |
|
|
|
if (graph->options().force_output_use_user_specified_memory) { |
|
|
|
auto options = gopt::OptimizeForInferenceOptions{}; |
|
|
|
auto new_vars = gopt::optimize_for_inference(output_var_list, options); |
|
|
|
output_var_list = new_vars; |
|
|
|
output_var_map.clear(); |
|
|
|
for (auto& var : new_vars) { |
|
|
|
output_var_map[var.node()->cname()] = var; |
|
|
|
} |
|
|
|
std::unordered_map<size_t, SymbolVar> var_map_id; |
|
|
|
for (auto& var : new_vars) { |
|
|
|
bool found = false; |
|
|
|
for (auto& old_var_it : output_var_map_id) { |
|
|
|
if (old_var_it.second.node()->name() == var.node()->name()) { |
|
|
|
found = true; |
|
|
|
var_map_id[old_var_it.first] = var; |
|
|
|
} |
|
|
|
} |
|
|
|
mgb_assert( |
|
|
|
found, "can't find var name %s when optimize_for_inference. ", |
|
|
|
var.node()->cname()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() { |
|
|
|
SharedTensorNameMap ret; |
|
|
|
for (auto&& i : shared_tensor_id_map()) { |
|
|
|