GitOrigin-RevId: d9abb8de9e
HuaHua404-patch-4
@@ -36,16 +36,19 @@ std::shared_ptr<ModelBase> ModelBase::create_model(std::string model_path) {
    auto model_type = get_model_type(model_path);
    if (ModelType::LITE_MODEL == model_type) {
    if (FLAGS_lite) {
        mgb_log("run model force lite mode\n");
        return std::make_shared<ModelLite>(model_path);
    } else if (FLAGS_mdl) {
        mgb_log("run model force mdl mode\n");
        return std::make_shared<ModelMdl>(model_path);
    } else if (ModelType::LITE_MODEL == model_type) {
        return std::make_shared<ModelLite>(model_path);
    } else if (ModelType::MEGDL_MODEL == model_type) {
        if (FLAGS_lite)
            return std::make_shared<ModelLite>(model_path);
        else
            return std::make_shared<ModelMdl>(model_path);
    } else {
        return nullptr;
        mgb_assert(ModelType::MEGDL_MODEL == model_type);
        return std::make_shared<ModelMdl>(model_path);
    }
}
DEFINE_bool(lite, false, "use megengine lite interface to run model");
DEFINE_bool(mdl, false, "use megengine mdl interface to run model");
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
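A minimal usage sketch of the new dispatch (an assumption based on the hunk above, not part of the patch; "model.mge" is a placeholder path and the lar:: namespace is assumed): --lite forces the lite interface, --mdl forces the MegDL interface, and with neither flag set the detected model type decides.

// Hedged sketch: exercise ModelBase::create_model under the new flags.
// Assumes DECLARE_bool(lite) / DECLARE_bool(mdl) are visible in this TU.
FLAGS_lite = true;                                             // force lite interface
auto forced_lite = lar::ModelBase::create_model("model.mge");  // -> ModelLite
FLAGS_lite = false;
FLAGS_mdl = true;                                              // force mdl interface
auto forced_mdl = lar::ModelBase::create_model("model.mge");   // -> ModelMdl
FLAGS_lite = FLAGS_mdl = false;                                // dispatch by model type
auto by_type = lar::ModelBase::create_model("model.mge");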
@@ -4,6 +4,7 @@
#include "helpers/common.h"
#include "megbrain/utils/json.h"
DECLARE_bool(lite);
DECLARE_bool(mdl);
namespace lar {
/*!
@@ -42,6 +42,8 @@ public:
        return m_load_result;
    }
    void update_mdl_load_result(const mgb::SymbolVarArray& output_var_array);
    //! get load config for megDL model
    mgb::serialization::GraphLoadConfig& get_mdl_config() { return m_load_config; }
@@ -31,6 +31,13 @@ void FastRunOption::config_model_internel<ModelLite>(
        LITE_LOG("enable fast-run strategy for algo profile");
        strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) |
                   static_cast<uint32_t>(Strategy::LITE_ALGO_OPTIMIZED) | strategy;
    } else if ((!m_fast_run_cache.empty() &&
                !access(m_fast_run_cache.c_str(), F_OK))) {
        LITE_LOG(
                "usable fast-run cache detected, set LITE_ALGO_PROFILE for algo "
                "profile");
        strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) |
                   static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
    } else {
        strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
    }
@@ -299,6 +299,75 @@ void FuseConvBiasElemwiseAddOption::config_model(
    CONFIG_MODEL_FUN;
}
///////////////////////// optimize for inference options ///////////////
bool OptimizeForInferenceOption::m_valid;
namespace lar {
template <>
void OptimizeForInferenceOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    LITE_MARK_USED_VAR(model);
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        auto optimize_for_infer =
                std::static_pointer_cast<lar::Bool>(m_option["optimize_for_inference"])
                        ->get_value();
        if (optimize_for_infer) {
            LITE_THROW(
                    "optimize for inference not supported in lite "
                    "model");
        }
    }
}
template <>
void OptimizeForInferenceOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
        auto optimize_for_infer =
                std::static_pointer_cast<lar::Bool>(m_option["optimize_for_inference"])
                        ->get_value();
        if (optimize_for_infer) {
            mgb_log("enable optimize for inference optimization");
            auto&& load_result = model->get_mdl_load_result();
            mgb::cg::GraphCommonOptimizeOptions opt =
                    model->get_mdl_load_result().graph->options().graph_opt;
            auto inference_opt2 = mgb::gopt::OptimizeForInferenceOptions(opt);
            auto output_var_list = mgb::gopt::optimize_for_inference(
                    load_result.output_var_list, inference_opt2);
            model->get_mdl_load_result().update_output_var_list(output_var_list);
            model->get_mdl_load_result().graph->options().graph_opt.clear();
        }
    }
}
}  // namespace lar
void OptimizeForInferenceOption::update() {
    m_option_name = "optimize_for_inference";
    m_option = {{"optimize_for_inference", lar::Bool::make(false)}};
    std::static_pointer_cast<lar::Bool>(m_option["optimize_for_inference"])
            ->set_value(FLAGS_optimize_for_inference);
}
bool OptimizeForInferenceOption::is_valid() {
    bool ret = FLAGS_optimize_for_inference;
    return ret || m_valid;
}
std::shared_ptr<OptionBase> OptimizeForInferenceOption::create_option() {
    static std::shared_ptr<OptimizeForInferenceOption> option(
            new OptimizeForInferenceOption);
    if (OptimizeForInferenceOption::is_valid()) {
        option->update();
        return option;
    } else {
        return nullptr;
    }
}
void OptimizeForInferenceOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    CONFIG_MODEL_FUN;
}
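For reference, an annotated sketch of the graph rewrite this option performs on a MegDL model (an assumption drawn from the ModelMdl branch above; model stands for a loaded std::shared_ptr<ModelMdl>):

// Hedged sketch of the optimize-for-inference flow and the intent of each step.
auto&& load_result = model->get_mdl_load_result();
// Seed the inference options from whatever graph_opt flags are already set,
// via the converting constructor added in this patch.
mgb::gopt::OptimizeForInferenceOptions inference_opt(
        load_result.graph->options().graph_opt);
// Rewrite the output vars eagerly (this is where e.g. BN folding happens).
auto new_outputs = mgb::gopt::optimize_for_inference(
        load_result.output_var_list, inference_opt);
// Swap the rewritten vars into output_var_list / output_var_map(_id).
load_result.update_output_var_list(new_outputs);
// Clear graph_opt so the same passes are not applied again at compile time.
load_result.graph->options().graph_opt.clear();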
///////////////////////// graph retrict options /////////////////////////
bool GraphRecordOption::m_valid;
namespace lar {
@@ -646,6 +715,9 @@ DEFINE_bool(
        enable_fuse_conv_bias_with_z, false,
        "fuse conv, bias (elemwise add), z(elemwise add) into one opr "
        "(only support on GPU)");
DEFINE_bool(
        optimize_for_inference, false,
        "run optimize_for_inference on the model graph (fuses BN and applies "
        "other basic inference optimizations)");
///////////////////////// graph retrict options /////////////////////////
DEFINE_bool(
@@ -700,6 +772,11 @@ REGIST_OPTION_VALIDATER(
        fuse_conv_bias_nonlinearity, lar::FuseConvBiasNonlinearOption::set_valid);
REGIST_OPTION_CREATOR(
        optimize_for_inference, lar::OptimizeForInferenceOption::create_option);
REGIST_OPTION_VALIDATER(
        optimize_for_inference, lar::OptimizeForInferenceOption::set_valid);
REGIST_OPTION_CREATOR(
        fuse_conv_bias_with_z, lar::FuseConvBiasElemwiseAddOption::create_option);
REGIST_OPTION_VALIDATER(
        fuse_conv_bias_with_z, lar::FuseConvBiasElemwiseAddOption::set_valid);
@@ -5,6 +5,7 @@
#include "option_base.h"
DECLARE_bool(enable_fuse_preprocess);
DECLARE_bool(optimize_for_inference);
DECLARE_bool(fuse_grain);
DECLARE_bool(weight_preprocess);
DECLARE_bool(enable_fuse_conv_bias_nonlinearity);
@@ -216,6 +217,34 @@ private:
    uint64_t workspace_limit;
};
///////////////////////// optimize for inference options /////////////////////////
class OptimizeForInferenceOption final : public OptionBase {
public:
    static bool is_valid();
    static std::shared_ptr<OptionBase> create_option();
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
    static void set_valid(bool val) { m_valid = val; }
    std::string option_name() const override { return m_option_name; };
    OptionValMap* get_option() override { return &m_option; }
    void update() override;
private:
    OptimizeForInferenceOption() = default;
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
    std::string m_option_name;
    static bool m_valid;
    OptionValMap m_option;
};
///////////////////////// other options for optimization /////////////////
class JITOption final : public OptionBase {
public:
@@ -366,19 +366,7 @@ void NetworkImplDft::layout_transform_optimization() {
        mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
        auto output_var_array = mgb::gopt::layout_transform(
                m_load_result.output_var_list, m_layout_transform_target);
        // replace symvar in output_var_list
        for (size_t idx = 0; idx < output_var_array.size(); ++idx) {
            out_var_map[m_load_result.output_var_list[idx]] = output_var_array[idx];
            m_load_result.output_var_list[idx] = output_var_array[idx];
        }
        // replace symvar in output_var_map_id
        for (auto&& item : m_load_result.output_var_map_id) {
            item.second = out_var_map[item.second];
        }
        // replace symvar in output_var_map
        for (auto&& item : m_load_result.output_var_map) {
            item.second = out_var_map[item.second];
        }
        m_load_result.update_output_var_list(output_var_array);
    } else if (m_user_config->auto_optimize_inference) {
        //! set model weight preprocess
        m_load_config.comp_graph->options().graph_opt.weight_preprocess = true;
@@ -8,6 +8,7 @@
#include "../src/misc.h"
#include "lite/network.h"
#include "lite/tensor.h"
#include "megbrain/comp_node.h"
#include "megbrain/graph/bases.h"
#include "megbrain/plugin/opr_io_dump.h"
#include "megbrain/plugin/profiler.h"
@@ -167,4 +168,18 @@ __attribute__((unused)) static std::shared_ptr<Tensor> mgb_lar(
#endif
static inline bool check_gpu_available(size_t num) {
    if (mgb::CompNode::get_device_count(mgb::CompNode::DeviceType::CUDA) < num) {
        mgb_log_warn("skip test case that requires %zu GPU(s)", num);
        return false;
    }
    return true;
}
#define REQUIRE_CUDA()                 \
    {                                  \
        if (!check_gpu_available(1)) { \
            return;                    \
        }                              \
    }                                  \
    while (0)
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -0,0 +1,51 @@
#include <gtest/gtest.h>
#include <string.h>
#include <memory>
#include "test_common.h"
#include "test_options.h"
using namespace lar;
DECLARE_bool(lite);
DECLARE_bool(cpu);
DECLARE_bool(optimize_for_inference);
#if LITE_WITH_CUDA
DECLARE_bool(cuda);
#endif
namespace {
BOOL_OPTION_WRAP(optimize_for_inference);
BOOL_OPTION_WRAP(lite);
BOOL_OPTION_WRAP(cpu);
#if LITE_WITH_CUDA
BOOL_OPTION_WRAP(cuda);
#endif
}  // anonymous namespace
TEST(TestLarOption, OPTIMIZE_FOR_INFERENCE) {
    DEFINE_WRAP(cpu);
    std::string model_path = "./shufflenet.mge";
    TEST_BOOL_OPTION(optimize_for_inference);
}
#if LITE_WITH_OPENCL
TEST(TestLarOption, OPTIMIZE_FOR_INFERENCE_OPENCL) {
    REQUIRE_OPENCL();
    DEFINE_WRAP(opencl);
    std::string model_path = "./shufflenet.mge";
    TEST_BOOL_OPTION(optimize_for_inference);
}
#endif
#if LITE_WITH_CUDA
TEST(TestLarOption, OPTIMIZE_FOR_INFERENCE_CUDA) {
    REQUIRE_CUDA();
    DEFINE_WRAP(cuda);
    std::string model_path = "./shufflenet.mge";
    TEST_BOOL_OPTION(optimize_for_inference);
}
#endif |
@@ -1,6 +1,7 @@
#include <gtest/gtest.h>
#include <string.h>
#include <memory>
#include "test_common.h"
#include "test_options.h"
using namespace lar;
@@ -109,6 +109,16 @@ struct GraphCommonOptimizeOptions {
                              ///< support on Nvidia GPU
    };
    LayoutTransform layout_transform = LayoutTransform::DEFAULT;
    void clear() {
        f16_io_f32_comp = false;
        f16_io_comp = false;
        fuse_conv_bias_nonlinearity = false;
        fuse_conv_bias_with_z = false;
        weight_preprocess = false;
        fuse_preprocess = false;
        fuse_grain = false;
        layout_transform = LayoutTransform::DEFAULT;
    }
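A hedged sketch of how the new clear() is expected to be used (an assumption, mirroring the graph_opt.clear() call in the load-and-run change above): once the requested passes have been applied eagerly, the flags are reset so the compile-time optimization stage does not apply them a second time.

// Hedged sketch: graph is an assumed mgb::cg::ComputingGraph instance.
auto& graph_opt = graph->options().graph_opt;
graph_opt.fuse_conv_bias_nonlinearity = true;   // request a pass
// ... run the pass eagerly (e.g. through gopt::optimize_for_inference) ...
graph_opt.clear();                              // back to all-false / DEFAULT
mgb_assert(!graph_opt.fuse_conv_bias_nonlinearity);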
#define SET(n)                             \
    GraphCommonOptimizeOptions& enable_##n() { \
@@ -312,6 +312,9 @@ public:
};
struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {
    OptimizeForInferenceOptions() = default;
    OptimizeForInferenceOptions(const cg::GraphCommonOptimizeOptions& opt)
            : cg::GraphCommonOptimizeOptions(opt){};
    uint64_t serialize() {
        uint64_t ret = 0;
        ret |= (uint64_t)layout_transform << 32;
@@ -17,6 +17,25 @@ std::unique_ptr<cg::AsyncExecutable> GraphLoader::LoadResult::graph_compile(
    return ret;
}
void GraphLoader::LoadResult::update_output_var_list(
        const SymbolVarArray& output_var_array) {
    mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
    mgb_assert(output_var_array.size() == output_var_list.size());
    // replace symvar in output_var_list
    for (size_t idx = 0; idx < output_var_array.size(); ++idx) {
        out_var_map[output_var_list[idx]] = output_var_array[idx];
        output_var_list[idx] = output_var_array[idx];
    }
    // replace symvar in output_var_map_id
    for (auto&& item : output_var_map_id) {
        item.second = out_var_map[item.second];
    }
    // replace symvar in output_var_map
    for (auto&& item : output_var_map) {
        item.second = out_var_map[item.second].rename(item.first);
    }
}
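A hedged sketch of the contract of update_output_var_list (an assumption; target is a placeholder and layout_transform stands for any pass that returns rewritten vars): all three views of the outputs are swapped to the new vars, and output_var_map values are renamed back to their keys so name-based lookup keeps working.

// Hedged sketch: load_result is a GraphLoader::LoadResult.
auto rewritten_vars = mgb::gopt::layout_transform(
        load_result.output_var_list, target);
load_result.update_output_var_list(rewritten_vars);
for (auto&& item : load_result.output_var_map) {
    // the value is now the rewritten var, renamed to the original key
    mgb_assert(item.second.node()->name() == item.first);
}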
void GraphLoader::LoadResult::graph_compile_ahead() {
    //! when force_output_use_user_specified_memory is set, the output var may
    //! be changed by gopt, then the var in LoadResult can not exist, so here
@@ -45,6 +45,13 @@ public:
    //! GraphDumper::dump
    SymbolVarArray output_var_list;
    /**
     * \brief replace the vars in output_var_list, output_var_map and
     * output_var_map_id with the given output_var_array
     */
    MGE_WIN_DECLSPEC_FUC void update_output_var_list(
            const SymbolVarArray& output_var_array);
    /*!
     * \brief call graph->compile() but also checks for comp seq rec
     *