GitOrigin-RevId: d9abb8de9e
HuaHua404-patch-4
@@ -36,16 +36,19 @@ std::shared_ptr<ModelBase> ModelBase::create_model(std::string model_path) { | |||
auto model_type = get_model_type(model_path); | |||
if (ModelType::LITE_MODEL == model_type) { | |||
if (FLAGS_lite) { | |||
mgb_log("run model force lite mode\n"); | |||
return std::make_shared<ModelLite>(model_path); | |||
} else if (FLAGS_mdl) { | |||
mgb_log("run model force mdl mode\n"); | |||
return std::make_shared<ModelMdl>(model_path); | |||
} else if (ModelType::LITE_MODEL == model_type) { | |||
return std::make_shared<ModelLite>(model_path); | |||
} else if (ModelType::MEGDL_MODEL == model_type) { | |||
if (FLAGS_lite) | |||
return std::make_shared<ModelLite>(model_path); | |||
else | |||
return std::make_shared<ModelMdl>(model_path); | |||
} else { | |||
return nullptr; | |||
mgb_assert(ModelType::MEGDL_MODEL == model_type); | |||
return std::make_shared<ModelMdl>(model_path); | |||
} | |||
} | |||
DEFINE_bool(lite, false, "use megengine lite interface to run model"); | |||
DEFINE_bool(mdl, false, "use megengine mdl interface to run model"); | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -4,6 +4,7 @@ | |||
#include "helpers/common.h" | |||
#include "megbrain/utils/json.h" | |||
DECLARE_bool(lite); | |||
DECLARE_bool(mdl); | |||
namespace lar { | |||
/*! | |||
@@ -42,6 +42,8 @@ public: | |||
return m_load_result; | |||
} | |||
void update_mdl_load_result(const mgb::SymbolVarArray& output_var_array); | |||
//! get load config for megDL model | |||
mgb::serialization::GraphLoadConfig& get_mdl_config() { return m_load_config; } | |||
@@ -31,6 +31,13 @@ void FastRunOption::config_model_internel<ModelLite>( | |||
LITE_LOG("enable fast-run strategy for algo profile"); | |||
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) | | |||
static_cast<uint32_t>(Strategy::LITE_ALGO_OPTIMIZED) | strategy; | |||
} else if ((!m_fast_run_cache.empty() && | |||
!access(m_fast_run_cache.c_str(), F_OK))) { | |||
LITE_LOG( | |||
"detect fast-run cache usable set LITE_ALGO_PROFILE for algo " | |||
"profile"); | |||
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) | | |||
static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy; | |||
} else { | |||
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy; | |||
} | |||
@@ -299,6 +299,75 @@ void FuseConvBiasElemwiseAddOption::config_model( | |||
CONFIG_MODEL_FUN; | |||
} | |||
///////////////////////// optimize for inference options /////////////// | |||
bool OptimizeForInferenceOption::m_valid; | |||
namespace lar { | |||
template <> | |||
void OptimizeForInferenceOption::config_model_internel<ModelLite>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
LITE_MARK_USED_VAR(model); | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
auto optimize_for_infer = | |||
std::static_pointer_cast<lar::Bool>(m_option["optimize_for_inference"]) | |||
->get_value(); | |||
if (optimize_for_infer) { | |||
LITE_THROW( | |||
"optimize for inference not supported in lite " | |||
"model"); | |||
} | |||
} | |||
} | |||
template <> | |||
void OptimizeForInferenceOption::config_model_internel<ModelMdl>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
auto optimize_for_infer = | |||
std::static_pointer_cast<lar::Bool>(m_option["optimize_for_inference"]) | |||
->get_value(); | |||
if (optimize_for_infer) { | |||
mgb_log("enable optimize for inference optimization"); | |||
auto&& load_result = model->get_mdl_load_result(); | |||
mgb::cg::GraphCommonOptimizeOptions opt = | |||
model->get_mdl_load_result().graph->options().graph_opt; | |||
auto inference_opt2 = mgb::gopt::OptimizeForInferenceOptions(opt); | |||
auto output_var_list = mgb::gopt::optimize_for_inference( | |||
load_result.output_var_list, inference_opt2); | |||
model->get_mdl_load_result().update_output_var_list(output_var_list); | |||
model->get_mdl_load_result().graph->options().graph_opt.clear(); | |||
} | |||
} | |||
} | |||
} // namespace lar | |||
void OptimizeForInferenceOption::update() { | |||
m_option_name = "optimize_for_inference"; | |||
m_option = {{"optimize_for_inference", lar::Bool::make(false)}}; | |||
std::static_pointer_cast<lar::Bool>(m_option["optimize_for_inference"]) | |||
->set_value(FLAGS_optimize_for_inference); | |||
} | |||
bool OptimizeForInferenceOption::is_valid() { | |||
bool ret = FLAGS_optimize_for_inference; | |||
return ret || m_valid; | |||
} | |||
std::shared_ptr<OptionBase> OptimizeForInferenceOption::create_option() { | |||
static std::shared_ptr<OptimizeForInferenceOption> option( | |||
new OptimizeForInferenceOption); | |||
if (OptimizeForInferenceOption::is_valid()) { | |||
option->update(); | |||
return option; | |||
} else { | |||
return nullptr; | |||
} | |||
} | |||
void OptimizeForInferenceOption::config_model( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
CONFIG_MODEL_FUN; | |||
} | |||
///////////////////////// graph retrict options ///////////////////////// | |||
bool GraphRecordOption::m_valid; | |||
namespace lar { | |||
@@ -646,6 +715,9 @@ DEFINE_bool( | |||
enable_fuse_conv_bias_with_z, false, | |||
"fuse conv, bias (elemwise add), z(elemwise add) into one opr " | |||
"(only support on GPU)"); | |||
DEFINE_bool( | |||
optimize_for_inference, false, | |||
"whether to optimize_for_inference, fuse bn and many base optimize"); | |||
///////////////////////// graph retrict options ///////////////////////// | |||
DEFINE_bool( | |||
@@ -700,6 +772,11 @@ REGIST_OPTION_VALIDATER( | |||
fuse_conv_bias_nonlinearity, lar::FuseConvBiasNonlinearOption::set_valid); | |||
REGIST_OPTION_CREATOR( | |||
optimize_for_inference, lar::OptimizeForInferenceOption::create_option); | |||
REGIST_OPTION_VALIDATER( | |||
optimize_for_inference, lar::OptimizeForInferenceOption::set_valid); | |||
REGIST_OPTION_CREATOR( | |||
fuse_conv_bias_with_z, lar::FuseConvBiasElemwiseAddOption::create_option); | |||
REGIST_OPTION_VALIDATER( | |||
fuse_conv_bias_with_z, lar::FuseConvBiasElemwiseAddOption::set_valid); | |||
@@ -5,6 +5,7 @@ | |||
#include "option_base.h" | |||
DECLARE_bool(enable_fuse_preprocess); | |||
DECLARE_bool(optimize_for_inference); | |||
DECLARE_bool(fuse_grain); | |||
DECLARE_bool(weight_preprocess); | |||
DECLARE_bool(enable_fuse_conv_bias_nonlinearity); | |||
@@ -216,6 +217,34 @@ private: | |||
uint64_t workspace_limit; | |||
}; | |||
///////////////////////// optimize for inference options ///////////////////////// | |||
class OptimizeForInferenceOption final : public OptionBase { | |||
public: | |||
static bool is_valid(); | |||
static std::shared_ptr<OptionBase> create_option(); | |||
void config_model( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
static void set_valid(bool val) { m_valid = val; } | |||
std::string option_name() const override { return m_option_name; }; | |||
OptionValMap* get_option() override { return &m_option; } | |||
void update() override; | |||
private: | |||
OptimizeForInferenceOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
std::string m_option_name; | |||
static bool m_valid; | |||
OptionValMap m_option; | |||
}; | |||
///////////////////////// other options for optimization ///////////////// | |||
class JITOption final : public OptionBase { | |||
public: | |||
@@ -366,19 +366,7 @@ void NetworkImplDft::layout_transform_optimization() { | |||
mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map; | |||
auto output_var_array = mgb::gopt::layout_transform( | |||
m_load_result.output_var_list, m_layout_transform_target); | |||
// replace symvar in output_var_list | |||
for (size_t idx = 0; idx < output_var_array.size(); ++idx) { | |||
out_var_map[m_load_result.output_var_list[idx]] = output_var_array[idx]; | |||
m_load_result.output_var_list[idx] = output_var_array[idx]; | |||
} | |||
// replace symvar in output_var_map_id | |||
for (auto&& item : m_load_result.output_var_map_id) { | |||
item.second = out_var_map[item.second]; | |||
} | |||
// replace symvar in output_var_map | |||
for (auto&& item : m_load_result.output_var_map) { | |||
item.second = out_var_map[item.second]; | |||
} | |||
m_load_result.update_output_var_list(output_var_array); | |||
} else if (m_user_config->auto_optimize_inference) { | |||
//! set model weight preprocess | |||
m_load_config.comp_graph->options().graph_opt.weight_preprocess = true; | |||
@@ -8,6 +8,7 @@ | |||
#include "../src/misc.h" | |||
#include "lite/network.h" | |||
#include "lite/tensor.h" | |||
#include "megbrain/comp_node.h" | |||
#include "megbrain/graph/bases.h" | |||
#include "megbrain/plugin/opr_io_dump.h" | |||
#include "megbrain/plugin/profiler.h" | |||
@@ -167,4 +168,18 @@ __attribute__((unused)) static std::shared_ptr<Tensor> mgb_lar( | |||
#endif | |||
static inline bool check_gpu_available(size_t num) { | |||
if (mgb::CompNode::get_device_count(mgb::CompNode::DeviceType::CUDA) < num) { | |||
mgb_log_warn("skip test case that requires %zu GPU(s)", num); | |||
return false; | |||
} | |||
return true; | |||
} | |||
#define REQUIRE_CUDA() \ | |||
{ \ | |||
if (!check_gpu_available(1)) { \ | |||
return; \ | |||
} \ | |||
} \ | |||
while (0) | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,51 @@ | |||
#include <gtest/gtest.h> | |||
#include <string.h> | |||
#include <memory> | |||
#include "test_common.h" | |||
#include "test_options.h" | |||
using namespace lar; | |||
DECLARE_bool(lite); | |||
DECLARE_bool(cpu); | |||
DECLARE_bool(optimize_for_inference); | |||
#if LITE_WITH_CUDA | |||
DECLARE_bool(cuda); | |||
#endif | |||
namespace { | |||
BOOL_OPTION_WRAP(optimize_for_inference); | |||
BOOL_OPTION_WRAP(lite); | |||
BOOL_OPTION_WRAP(cpu); | |||
#if LITE_WITH_CUDA | |||
BOOL_OPTION_WRAP(cuda); | |||
#endif | |||
} // anonymous namespace | |||
TEST(TestLarOption, OPTIMIZE_FOR_INFERENCE) { | |||
DEFINE_WRAP(cpu); | |||
std::string model_path = "./shufflenet.mge"; | |||
TEST_BOOL_OPTION(optimize_for_inference); | |||
} | |||
#if LITE_WITH_OPENCL | |||
TEST(TestLarOption, OPTIMIZE_FOR_INFERENCE_OPENCL) { | |||
REQUIRE_OPENCL(); | |||
DEFINE_WRAP(opencl); | |||
std::string model_path = "./shufflenet.mge"; | |||
TEST_BOOL_OPTION(optimize_for_inference); | |||
} | |||
#endif | |||
#if LITE_WITH_CUDA | |||
TEST(TestLarOption, OPTIMIZE_FOR_INFERENCE_CUDA) { | |||
REQUIRE_CUDA(); | |||
DEFINE_WRAP(cuda); | |||
std::string model_path = "./shufflenet.mge"; | |||
TEST_BOOL_OPTION(optimize_for_inference); | |||
} | |||
#endif |
@@ -1,6 +1,7 @@ | |||
#include <gtest/gtest.h> | |||
#include <string.h> | |||
#include <memory> | |||
#include "test_common.h" | |||
#include "test_options.h" | |||
using namespace lar; | |||
@@ -109,6 +109,16 @@ struct GraphCommonOptimizeOptions { | |||
///< support on Nvidia GPU | |||
}; | |||
LayoutTransform layout_transform = LayoutTransform::DEFAULT; | |||
void clear() { | |||
f16_io_f32_comp = false; | |||
f16_io_comp = false; | |||
fuse_conv_bias_nonlinearity = false; | |||
fuse_conv_bias_with_z = false; | |||
weight_preprocess = false; | |||
fuse_preprocess = false; | |||
fuse_grain = false; | |||
layout_transform = LayoutTransform::DEFAULT; | |||
} | |||
#define SET(n) \ | |||
GraphCommonOptimizeOptions& enable_##n() { \ | |||
@@ -312,6 +312,9 @@ public: | |||
}; | |||
struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions { | |||
OptimizeForInferenceOptions() = default; | |||
OptimizeForInferenceOptions(const cg::GraphCommonOptimizeOptions& opt) | |||
: cg::GraphCommonOptimizeOptions(opt){}; | |||
uint64_t serialize() { | |||
uint64_t ret = 0; | |||
ret |= (uint64_t)layout_transform << 32; | |||
@@ -17,6 +17,25 @@ std::unique_ptr<cg::AsyncExecutable> GraphLoader::LoadResult::graph_compile( | |||
return ret; | |||
} | |||
void GraphLoader::LoadResult::update_output_var_list( | |||
const SymbolVarArray& output_var_array) { | |||
mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map; | |||
mgb_assert(output_var_array.size() == output_var_list.size()); | |||
// replace symvar in output_var_list | |||
for (size_t idx = 0; idx < output_var_array.size(); ++idx) { | |||
out_var_map[output_var_list[idx]] = output_var_array[idx]; | |||
output_var_list[idx] = output_var_array[idx]; | |||
} | |||
// replace symvar in output_var_map_id | |||
for (auto&& item : output_var_map_id) { | |||
item.second = out_var_map[item.second]; | |||
} | |||
// replace symvar in output_var_map | |||
for (auto&& item : output_var_map) { | |||
item.second = out_var_map[item.second].rename(item.first); | |||
} | |||
} | |||
void GraphLoader::LoadResult::graph_compile_ahead() { | |||
//! when force_output_use_user_specified_memory is set, the output var may | |||
//! be changed by gopt, then the var in LoadResult can not exist, so here | |||
@@ -45,6 +45,13 @@ public: | |||
//! GraphDumper::dump | |||
SymbolVarArray output_var_list; | |||
/** | |||
* \brief update output_var_list with output_var_map, output_var_map_id | |||
* | |||
*/ | |||
MGE_WIN_DECLSPEC_FUC void update_output_var_list( | |||
const SymbolVarArray& output_var_array); | |||
/*! | |||
* \brief call graph->compile() but also checks for comp seq rec | |||
* | |||