@@ -1,6 +1,6 @@
load("//brain/megbrain/lite:flags.bzl","pthread_select", "lite_opts")
cc_library(
-name = "mgblar",
+name = "lar_object",
srcs = glob(["src/**/*.cpp"], exclude = ["src/main.cpp"]),
hdrs = glob(["src/**/*.h"]),
includes = ["src"],
@@ -28,7 +28,8 @@ cc_megvii_binary(
"no_exceptions",
"no_rtti",
]),
-internal_deps = [":mgblar"],
+internal_deps = [":lar_object"],
visibility = ["//visibility:public"],
)
@@ -1,23 +1,35 @@
# BUILD the load and run for lite
include_directories(PUBLIC
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/lite/load_and_run/src>)
-file(GLOB_RECURSE SOURCES ./*.cpp ${PROJECT_SOURCE_DIR}/lite/src/pack_model/*.cpp)
+file(GLOB_RECURSE SOURCES src/**/*.cpp ${PROJECT_SOURCE_DIR}/lite/src/pack_model/*.cpp)
-add_executable(load_and_run ${SOURCES})
-target_link_libraries(load_and_run lite_static)
-target_link_libraries(load_and_run megbrain)
+add_library(lar_object OBJECT ${SOURCES})
+target_link_libraries(lar_object lite_static)
+target_link_libraries(lar_object megbrain)
if(APPLE)
-target_link_libraries(load_and_run gflags)
+target_link_libraries(lar_object gflags)
else()
-target_link_libraries(load_and_run gflags -Wl,--version-script=${MGE_VERSION_SCRIPT})
+target_link_libraries(lar_object gflags -Wl,--version-script=${MGE_VERSION_SCRIPT})
+endif()
+if(LITE_BUILD_WITH_MGE
+AND NOT WIN32
+AND NOT APPLE)
+# FXIME third_party cpp redis do not support build with clang-cl
+target_include_directories(lar_object PRIVATE ${CPP_REDIS_INCLUDES})
endif()
+add_executable(load_and_run src/main.cpp)
+target_link_libraries(load_and_run lar_object)
if(LITE_BUILD_WITH_RKNPU)
# rknn sdk1.0.0 depend on libc++_shared, use gold to remove NEEDED so symbol check
target_link_options(load_and_run PRIVATE "-fuse-ld=gold")
endif()
-if(MGE_WITH_ROCM)
+if(LITE_BUILD_WITH_MGE AND MGE_WITH_ROCM)
+message(WARNING "MGE_WITH_ROCM is valid link to megdnn")
# FIXME: hip obj can not find cpp obj only through lite_static
target_link_libraries(load_and_run megdnn)
endif()
@@ -30,17 +42,11 @@ if(UNIX)
endif()
endif()
-if(LITE_BUILD_WITH_MGE
-AND NOT WIN32
-AND NOT APPLE)
-# FXIME third_party cpp redis do not support build with clang-cl
-target_include_directories(load_and_run PRIVATE ${CPP_REDIS_INCLUDES})
-endif()
install(
TARGETS load_and_run
EXPORT ${LITE_EXPORT_TARGETS}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
if(BUILD_SHARED_LIBS)
if(LITE_BUILD_WITH_MGE
AND NOT WIN32
@@ -48,7 +54,7 @@ if(BUILD_SHARED_LIBS)
# FXIME third_party cpp redis do not support build with clang-cl
list(APPEND SOURCES ${CPP_REDIS_SRCS})
endif()
-add_executable(load_and_run_depends_shared ${SOURCES})
+add_executable(load_and_run_depends_shared ${SOURCES} src/main.cpp)
target_link_libraries(load_and_run_depends_shared lite_shared)
target_link_libraries(load_and_run_depends_shared gflags)
target_link_libraries(load_and_run_depends_shared megengine)
@@ -58,7 +64,8 @@ if(BUILD_SHARED_LIBS)
target_link_options(load_and_run_depends_shared PRIVATE "-fuse-ld=gold")
endif()
-if(MGE_WITH_ROCM)
+if(LITE_BUILD_WITH_MGE AND MGE_WITH_ROCM)
+message(WARNING "MGE_WITH_ROCM is valid link to megdnn")
# FIXME: hip obj can not find cpp obj only through lite_static
target_link_libraries(load_and_run_depends_shared megdnn)
endif()
@@ -30,6 +30,8 @@ enum class RunStage {
AFTER_MODEL_RUNNING = 7,
GLOBAL_OPTIMIZATION = 8,
+UPDATE_IO = 9,
};
/*!
* \brief: type of different model
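The new UPDATE_IO stage extends the staged dispatch that the option plugins below rely on: each option's config_model_internel is called once per stage and branches on runtime_param.stage. The diff does not show which component drives the new stage or what work runs there, so the handler body below is only a hedged illustration of the dispatch pattern; "OptionSketch" and the comments inside the UPDATE_IO branch are hypothetical.

```cpp
// Hedged illustration of the stage dispatch used by the options in this diff.
// "OptionSketch" and the UPDATE_IO handler body are made up; only the enum
// value and the branching pattern come from the diff itself.
template <>
void OptionSketch::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        // read flags into the model config, as the real options do
    } else if (runtime_param.stage == RunStage::UPDATE_IO) {
        // presumably: refresh input/output bindings between runs
    }
}
```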
@@ -1,7 +1,7 @@
#include <gflags/gflags.h>
#include <string>
+#include "misc.h"
#include "strategys/strategy.h"
std::string simple_usage = R"(
load_and_run: load_and_run <model_path> [options Flags...]
@@ -29,6 +29,8 @@ More details using "--help" to get!!
)";
int main(int argc, char** argv) {
+mgb::set_log_level(mgb::LogLevel::INFO);
+lite::set_log_level(LiteLogLevel::INFO);
std::string usage = "load_and_run <model_path> [options Flags...]";
if (argc < 2) {
printf("usage: %s\n", simple_usage.c_str());
@@ -8,17 +8,17 @@ DECLARE_bool(share_param_mem);
using namespace lar;
ModelLite::ModelLite(const std::string& path) : model_path(path) {
-LITE_WARN("creat lite model use CPU as default comp node");
+LITE_LOG("creat lite model use CPU as default comp node");
};
void ModelLite::load_model() {
m_network = std::make_shared<lite::Network>(config, IO);
if (enable_layout_transform) {
-LITE_WARN("enable layout transform while load model for lite");
+LITE_LOG("enable layout transform while load model for lite");
lite::Runtime::enable_global_layout_transform(m_network);
}
if (share_model_mem) {
//! WARNNING:maybe not right to share param memmory for this
-LITE_WARN("enable share model memory");
+LITE_LOG("enable share model memory");
FILE* fin = fopen(model_path.c_str(), "rb");
LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
@@ -19,27 +19,27 @@ void XPUDeviceOption::config_model_internel<ModelLite>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if ((enable_cpu) || (enable_cpu_default) || (enable_multithread) ||
(enable_multithread_default)) {
-LITE_WARN("using cpu device\n");
+LITE_LOG("using cpu device\n");
model->get_config().device_type = LiteDeviceType::LITE_CPU;
}
#if LITE_WITH_CUDA
if (enable_cuda) {
-LITE_WARN("using cuda device\n");
+LITE_LOG("using cuda device\n");
model->get_config().device_type = LiteDeviceType::LITE_CUDA;
}
#endif
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
auto&& network = model->get_lite_network();
if (enable_cpu_default) {
-LITE_WARN("using cpu default device\n");
+LITE_LOG("using cpu default device\n");
lite::Runtime::set_cpu_inplace_mode(network);
}
if (enable_multithread) {
-LITE_WARN("using multithread device\n");
+LITE_LOG("using multithread device\n");
lite::Runtime::set_cpu_threads_number(network, thread_num);
}
if (enable_multithread_default) {
-LITE_WARN("using multithread default device\n");
+LITE_LOG("using multithread default device\n");
lite::Runtime::set_cpu_inplace_mode(network);
lite::Runtime::set_cpu_threads_number(network, thread_num);
}
@@ -48,7 +48,7 @@ void XPUDeviceOption::config_model_internel<ModelLite>(
for (auto id : core_ids) {
core_str += std::to_string(id) + ",";
}
-LITE_WARN("multi thread core ids: %s\n", core_str.c_str());
+LITE_LOG("multi thread core ids: %s\n", core_str.c_str());
lite::ThreadAffinityCallback affinity_callback = [&](size_t thread_id) {
mgb::sys::set_cpu_affinity({core_ids[thread_id]});
};
@@ -62,14 +62,14 @@ void XPUDeviceOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (enable_cpu) {
-mgb_log_warn("using cpu device\n");
+mgb_log("using cpu device\n");
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) {
loc.type = mgb::CompNode::DeviceType::CPU;
};
}
#if LITE_WITH_CUDA
if (enable_cuda) {
-mgb_log_warn("using cuda device\n");
+mgb_log("using cuda device\n");
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
loc.type = mgb::CompNode::DeviceType::CUDA;
@@ -79,14 +79,14 @@ void XPUDeviceOption::config_model_internel<ModelMdl>(
}
#endif
if (enable_cpu_default) {
-mgb_log_warn("using cpu default device\n");
+mgb_log("using cpu default device\n");
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) {
loc.type = mgb::CompNode::DeviceType::CPU;
loc.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
};
}
if (enable_multithread) {
-mgb_log_warn("using multithread device\n");
+mgb_log("using multithread device\n");
model->get_mdl_config().comp_node_mapper =
[&](mgb::CompNode::Locator& loc) {
loc.type = mgb::CompNode::DeviceType::MULTITHREAD;
@@ -95,7 +95,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>(
};
}
if (enable_multithread_default) {
-mgb_log_warn("using multithread default device\n");
+mgb_log("using multithread default device\n");
model->get_mdl_config().comp_node_mapper =
[&](mgb::CompNode::Locator& loc) {
loc.type = mgb::CompNode::DeviceType::MULTITHREAD;
@@ -108,7 +108,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>(
for (auto id : core_ids) {
core_str += std::to_string(id) + ",";
}
-mgb_log_warn("set multi thread core ids:%s\n", core_str.c_str());
+mgb_log("set multi thread core ids:%s\n", core_str.c_str());
auto affinity_callback = [&](size_t thread_id) {
mgb::sys::set_cpu_affinity({core_ids[thread_id]});
};
@@ -122,7 +122,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>(
}
} // namespace lar
-XPUDeviceOption::XPUDeviceOption() {
+void XPUDeviceOption::update() {
m_option_name = "xpu_device";
enable_cpu = FLAGS_cpu;
#if LITE_WITH_CUDA
@@ -198,6 +198,7 @@ bool XPUDeviceOption::is_valid() {
std::shared_ptr<OptionBase> XPUDeviceOption::create_option() {
static std::shared_ptr<lar::XPUDeviceOption> option(new XPUDeviceOption);
if (XPUDeviceOption::is_valid()) {
+option->update();
return std::static_pointer_cast<lar::OptionBase>(option);
} else {
return nullptr;
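This hunk shows the refactor that repeats for every option class below: the flag-reading constructor becomes a defaulted constructor plus an update() override, and create_option() calls update() on the static singleton before returning it, so gflags values are re-read each time create_option() succeeds instead of only once at first construction. A condensed sketch of the resulting pattern; anything not visible in the diff (for example the rest of OptionBase) is an assumption.

```cpp
// Condensed sketch of the option pattern after this refactor; members of
// OptionBase other than update() are assumptions.
class XPUDeviceOption : public OptionBase {
public:
    static bool is_valid();                        // checks the relevant FLAGS_*
    static std::shared_ptr<OptionBase> create_option();
    void update() override;                        // re-reads FLAGS_* into members
private:
    XPUDeviceOption() = default;                   // no flag parsing here any more
    bool enable_cpu;
};

std::shared_ptr<OptionBase> XPUDeviceOption::create_option() {
    static std::shared_ptr<XPUDeviceOption> option(new XPUDeviceOption);
    if (XPUDeviceOption::is_valid()) {
        option->update();  // refresh from the command line on every creation
        return std::static_pointer_cast<OptionBase>(option);
    }
    return nullptr;
}
```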
@@ -24,8 +24,10 @@ public:
OptionValMap* get_option() override { return &m_option; }
+void update() override;
private:
-XPUDeviceOption();
+XPUDeviceOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
bool enable_cpu;
@@ -25,6 +25,7 @@ void COprLibOption::config_model_internel(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (!lib_path.empty()) {
+mgb_log("load external C opr lib from %s\n", lib_path.c_str());
load_lib();
}
if (c_opr_args.is_run_c_opr_with_param) {
@@ -176,7 +177,7 @@ void COprLibOption::set_Copr_IO(std::shared_ptr<ModelBase> model_ptr) {
config_extern_c_opr_dynamic_param(model->get_async_func(), c_opr_param);
}
-COprLibOption::COprLibOption() {
+void COprLibOption::update() {
m_option_name = "c_opr_lib";
lib_path = FLAGS_c_opr_lib;
c_opr_args.is_run_c_opr = !lib_path.empty();
@@ -191,6 +192,7 @@ bool COprLibOption::is_valid() {
std::shared_ptr<OptionBase> COprLibOption::create_option() {
static std::shared_ptr<COprLibOption> option(new COprLibOption);
if (COprLibOption::is_valid()) {
+option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
@@ -32,8 +32,10 @@ public:
std::string option_name() const override { return m_option_name; };
+void update() override;
private:
-COprLibOption();
+COprLibOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
@@ -25,10 +25,10 @@ void FastRunOption::config_model_internel<ModelLite>(
uint32_t strategy = 0;
#if MGB_ENABLE_FASTRUN
if (enable_full_run) {
-LITE_WARN("enable full-run strategy for algo profile");
+LITE_LOG("enable full-run strategy for algo profile");
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) | strategy;
} else if (enable_fast_run) {
-LITE_WARN("enable fast-run strategy for algo profile");
+LITE_LOG("enable fast-run strategy for algo profile");
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) |
static_cast<uint32_t>(Strategy::LITE_ALGO_OPTIMIZED) | strategy;
} else {
@@ -38,7 +38,7 @@ void FastRunOption::config_model_internel<ModelLite>(
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
#endif
if (batch_binary_equal || enable_reproducible) {
-LITE_WARN("enable reproducible strategy for algo profile");
+LITE_LOG("enable reproducible strategy for algo profile");
if (batch_binary_equal)
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_REPRODUCIBLE) |
strategy;
@@ -81,10 +81,10 @@ void FastRunOption::config_model_internel<ModelMdl>(
auto strategy = static_cast<Strategy>(0);
#if MGB_ENABLE_FASTRUN
if (enable_full_run) {
-mgb_log_warn("enable full-run strategy for algo profile");
+mgb_log("enable full-run strategy for algo profile");
strategy = Strategy::PROFILE | strategy;
} else if (enable_fast_run) {
-mgb_log_warn("enable fast-run strategy for algo profile");
+mgb_log("enable fast-run strategy for algo profile");
strategy = Strategy::PROFILE | Strategy::OPTIMIZED | strategy;
} else {
strategy = Strategy::HEURISTIC | strategy;
@@ -93,20 +93,20 @@ void FastRunOption::config_model_internel<ModelMdl>(
strategy = Strategy::HEURISTIC | strategy;
#endif
if (batch_binary_equal || enable_reproducible) {
-mgb_log_warn("enable reproducible strategy for algo profile");
+mgb_log("enable reproducible strategy for algo profile");
strategy = Strategy::REPRODUCIBLE | strategy;
}
model->set_mdl_strategy(strategy);
//! set binary_equal_between_batch and shared_batch_size
if (batch_binary_equal) {
-mgb_log_warn("enable batch binary equal");
+mgb_log("enable batch binary equal");
model->get_mdl_config()
.comp_graph->options()
.fast_run_config.binary_equal_between_batch = true;
}
if (share_batch_size > 0) {
-mgb_log_warn("set shared shared batch");
+mgb_log("set shared shared batch");
model->get_mdl_config()
.comp_graph->options()
.fast_run_config.shared_batch_size = share_batch_size;
@@ -145,7 +145,7 @@ void FastRunOption::config_model_internel<ModelMdl>(
using namespace lar;
bool FastRunOption::m_valid;
-FastRunOption::FastRunOption() {
+void FastRunOption::update() {
m_option_name = "fastrun";
#if MGB_ENABLE_FASTRUN
enable_fast_run = FLAGS_fast_run;
@@ -207,6 +207,7 @@ bool FastRunOption::is_valid() {
std::shared_ptr<OptionBase> FastRunOption::create_option() {
static std::shared_ptr<FastRunOption> option(new FastRunOption);
if (FastRunOption::is_valid()) {
+option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
@@ -250,7 +251,7 @@ DEFINE_bool(
"https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/"
"index.html#reproducibility"
"for more details.");
-DEFINE_uint32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun");
+DEFINE_int32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun");
DEFINE_string(fast_run_algo_policy, "", "fast-run cache path.");
REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option);
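One flag also changes type here: fast_run_shared_batch_size moves from DEFINE_uint32 to DEFINE_int32, and the matching DECLARE_* in the header follows in the next hunk. gflags requires the DECLARE_ macro seen by consumers to use the same type as the DEFINE_ in the owning translation unit, since DECLARE_ expands to an extern declaration of the same FLAGS_ variable. A minimal sketch of that pairing; the file split named in the comments is illustrative.

```cpp
// In the defining .cpp (illustrative file name): the flag is created once.
#include <gflags/gflags.h>
DEFINE_int32(fast_run_shared_batch_size, 0,
             "Set the batch size used during fastrun");

// In the header included by consumers (illustrative file name): the matching
// declaration. DECLARE_int32 expands to an extern for
// FLAGS_fast_run_shared_batch_size, so its type must agree with the DEFINE above.
DECLARE_int32(fast_run_shared_batch_size);
```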
@@ -10,7 +10,7 @@ DECLARE_bool(full_run);
#endif
DECLARE_bool(reproducible);
DECLARE_bool(binary_equal_between_batch);
-DECLARE_uint32(fast_run_shared_batch_size);
+DECLARE_int32(fast_run_shared_batch_size);
DECLARE_string(fast_run_algo_policy);
namespace lar {
@@ -33,8 +33,10 @@ public:
OptionValMap* get_option() override { return &m_option; }
+void update() override;
private:
-FastRunOption();
+FastRunOption() = default;
//! config template for different model
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
@@ -93,11 +93,11 @@ void IOdumpOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
if (enable_io_dump) {
-LITE_WARN("enable text io dump");
+LITE_LOG("enable text io dump");
lite::Runtime::enable_io_txt_dump(model->get_lite_network(), dump_path);
}
if (enable_bin_io_dump) {
-LITE_WARN("enable binary io dump");
+LITE_LOG("enable binary io dump");
lite::Runtime::enable_io_bin_dump(model->get_lite_network(), dump_path);
}
//! FIX:when add API in lite complate this
@@ -108,7 +108,7 @@ void IOdumpOption::config_model_internel<ModelLite>(
LITE_THROW("lite model don't support the binary output dump");
}
if (enable_copy_to_host) {
-LITE_WARN("lite model set copy to host defaultly");
+LITE_LOG("lite model set copy to host defaultly");
}
}
}
@@ -118,7 +118,7 @@ void IOdumpOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (enable_io_dump) {
-mgb_log_warn("enable text io dump");
+mgb_log("enable text io dump");
auto iodump = std::make_unique<mgb::TextOprIODump>(
model->get_mdl_config().comp_graph.get(), dump_path.c_str());
iodump->print_addr(false);
@@ -126,7 +126,7 @@ void IOdumpOption::config_model_internel<ModelMdl>(
}
if (enable_io_dump_stdout) {
-mgb_log_warn("enable text io dump to stdout");
+mgb_log("enable text io dump to stdout");
std::shared_ptr<FILE> std_out(stdout, [](FILE*) {});
auto iodump = std::make_unique<mgb::TextOprIODump>(
model->get_mdl_config().comp_graph.get(), std_out);
@@ -135,7 +135,7 @@ void IOdumpOption::config_model_internel<ModelMdl>(
}
if (enable_io_dump_stderr) {
-mgb_log_warn("enable text io dump to stderr");
+mgb_log("enable text io dump to stderr");
std::shared_ptr<FILE> std_err(stderr, [](FILE*) {});
auto iodump = std::make_unique<mgb::TextOprIODump>(
model->get_mdl_config().comp_graph.get(), std_err);
@@ -144,14 +144,14 @@ void IOdumpOption::config_model_internel<ModelMdl>(
}
if (enable_bin_io_dump) {
-mgb_log_warn("enable binary io dump");
+mgb_log("enable binary io dump");
auto iodump = std::make_unique<mgb::BinaryOprIODump>(
model->get_mdl_config().comp_graph.get(), dump_path);
io_dumper = std::move(iodump);
}
if (enable_bin_out_dump) {
-mgb_log_warn("enable binary output dump");
+mgb_log("enable binary output dump");
out_dumper = std::make_unique<OutputDumper>(dump_path.c_str());
}
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
@@ -190,7 +190,7 @@ void IOdumpOption::config_model_internel<ModelMdl>(
////////////////////// Input options ////////////////////////
using namespace lar;
-InputOption::InputOption() {
+void InputOption::update() {
m_option_name = "input";
size_t start = 0;
auto end = FLAGS_input.find(";", start);
@@ -204,9 +204,10 @@ InputOption::InputOption() {
}
std::shared_ptr<lar::OptionBase> lar::InputOption::create_option() {
-static std::shared_ptr<InputOption> m_option(new InputOption);
+static std::shared_ptr<InputOption> option(new InputOption);
if (InputOption::is_valid()) {
-return std::static_pointer_cast<OptionBase>(m_option);
+option->update();
+return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
}
@@ -219,7 +220,7 @@ void InputOption::config_model(
////////////////////// OprIOdump options ////////////////////////
-IOdumpOption::IOdumpOption() {
+void IOdumpOption::update() {
m_option_name = "iodump";
size_t valid_flag = 0;
if (!FLAGS_io_dump.empty()) {
@@ -268,6 +269,7 @@ bool IOdumpOption::is_valid() {
std::shared_ptr<OptionBase> IOdumpOption::create_option() {
static std::shared_ptr<IOdumpOption> option(new IOdumpOption);
if (IOdumpOption::is_valid()) {
+option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
@@ -30,8 +30,10 @@ public:
//! interface implement from OptionBase
std::string option_name() const override { return m_option_name; };
+void update() override;
private:
-InputOption();
+InputOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
@@ -50,8 +52,10 @@ public:
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
std::string option_name() const override { return m_option_name; };
+void update() override;
private:
-IOdumpOption();
+IOdumpOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
@@ -11,7 +11,7 @@ void LayoutOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
#define ENABLE_LAYOUT(layout) \
-LITE_WARN("enable " #layout " optimization"); \
+LITE_LOG("enable " #layout " optimization"); \
model->get_config().options.enable_##layout = true; \
break;
@@ -51,7 +51,7 @@ void lar::LayoutOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
#define ENABLE_LAYOUT(layout) \
-mgb_log_warn("enable " #layout " optimization"); \
+mgb_log("enable " #layout " optimization"); \
model->get_mdl_config().comp_graph->options().graph_opt.enable_##layout(); \
break;
@@ -91,7 +91,7 @@ void lar::LayoutOption::config_model_internel<ModelMdl>(
using namespace lar;
bool LayoutOption::m_valid;
-LayoutOption::LayoutOption() {
+void LayoutOption::update() {
m_option_name = "layout";
m_option_flag = static_cast<OptLayoutType>(0);
m_option = {
@@ -157,6 +157,7 @@ bool LayoutOption::is_valid() {
std::shared_ptr<OptionBase> LayoutOption::create_option() {
static std::shared_ptr<LayoutOption> option(new LayoutOption);
if (LayoutOption::is_valid()) {
+option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
@@ -166,16 +167,20 @@ std::shared_ptr<OptionBase> LayoutOption::create_option() {
void LayoutOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
size_t valid_flag = 0;
-if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw4"])->get_value()) {
+if (FLAGS_enable_nchw4 ||
+std::static_pointer_cast<lar::Bool>(m_option["enable_nchw4"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW4);
}
-if (std::static_pointer_cast<lar::Bool>(m_option["enable_chwn4"])->get_value()) {
+if (FLAGS_enable_chwn4 ||
+std::static_pointer_cast<lar::Bool>(m_option["enable_chwn4"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::CHWN4);
}
-if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44"])->get_value()) {
+if (FLAGS_enable_nchw44 ||
+std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44);
}
-if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw88"])->get_value()) {
+if (FLAGS_enable_nchw88 ||
+std::static_pointer_cast<lar::Bool>(m_option["enable_nchw88"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW88);
}
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw32"])->get_value()) {
@@ -37,9 +37,11 @@ public:
OptionValMap* get_option() override { return &m_option; }
+void update() override;
private:
//! Constructor
-LayoutOption();
+LayoutOption() = default;
//! configuration for different model implement
template <typename ModelImpl>
@@ -11,7 +11,7 @@ void GoptLayoutOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (m_layout_transform) {
-LITE_WARN("using global layout transform optimization\n");
+LITE_LOG("using global layout transform optimization\n");
if (m_layout_transform_target ==
mgb::gopt::GraphTuningOptions::Target::CPU) {
model->get_config().device_type = LiteDeviceType::LITE_CPU;
@@ -98,7 +98,7 @@ void GoptLayoutOption::config_model_internel<ModelMdl>(
}
} else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
if (m_layout_transform) {
-mgb_log_warn("using global layout transform optimization\n");
+mgb_log("using global layout transform optimization\n");
auto&& load_result = model->get_mdl_load_result();
load_result.output_var_list = mgb::gopt::layout_transform(
load_result.output_var_list, m_layout_transform_target);
@@ -150,7 +150,7 @@ void GoptLayoutOption::config_model_internel<ModelMdl>(
using namespace lar;
bool GoptLayoutOption::m_valid;
-GoptLayoutOption::GoptLayoutOption() {
+void GoptLayoutOption::update() {
m_option_name = "gopt_layout";
if (FLAGS_layout_transform != "cpu"
#if LITE_WITH_CUDA
@@ -216,6 +216,7 @@ bool GoptLayoutOption::is_valid() {
std::shared_ptr<OptionBase> GoptLayoutOption::create_option() {
static std::shared_ptr<GoptLayoutOption> option(new GoptLayoutOption);
if (GoptLayoutOption::is_valid()) {
+option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
@@ -28,8 +28,10 @@ public:
OptionValMap* get_option() override { return &m_option; }
+void update() override;
private:
-GoptLayoutOption();
+GoptLayoutOption() = default;
//! config template for different model
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
@@ -24,7 +24,7 @@ void PackModelOption::config_model_internel(
using namespace lar;
////////////////////// PackModel options ////////////////////////
-PackModelOption::PackModelOption() {
+void PackModelOption::update() {
m_option_name = "pack_model";
if (!FLAGS_packed_model_dump.empty())
packed_model_dump = FLAGS_packed_model_dump;
@@ -45,6 +45,7 @@ bool PackModelOption::is_valid() {
std::shared_ptr<OptionBase> PackModelOption::create_option() {
static std::shared_ptr<PackModelOption> option(new PackModelOption);
if (PackModelOption::is_valid()) {
+option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
@@ -19,8 +19,10 @@ public:
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
std::string option_name() const override { return m_option_name; }
+void update() override;
private:
-PackModelOption();
+PackModelOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>);
@@ -15,7 +15,7 @@ void FusePreprocessOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (enable_fuse_preprocess) {
-LITE_WARN("enable fuse-preprocess optimization");
+LITE_LOG("enable fuse-preprocess optimization");
model->get_config().options.fuse_preprocess = true;
}
}
@@ -27,7 +27,7 @@ void FusePreprocessOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (enable_fuse_preprocess) {
-mgb_log_warn("enable fuse-preprocess optimization");
+mgb_log("enable fuse-preprocess optimization");
graph_option.graph_opt.enable_fuse_preprocess();
}
}
@@ -35,7 +35,7 @@ void FusePreprocessOption::config_model_internel<ModelMdl>(
} // namespace lar
using namespace lar;
bool FusePreprocessOption::m_valid;
-FusePreprocessOption::FusePreprocessOption() {
+void FusePreprocessOption::update() {
m_option_name = "fuse_preprocess";
enable_fuse_preprocess = FLAGS_enable_fuse_preprocess;
m_option = {{"enable_fuse_preprocess", lar::Bool::make(false)}};
@@ -51,6 +51,7 @@ bool FusePreprocessOption::is_valid() {
std::shared_ptr<OptionBase> FusePreprocessOption::create_option() {
static std::shared_ptr<FusePreprocessOption> option(new FusePreprocessOption);
if (FusePreprocessOption::is_valid()) {
+option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
@@ -73,7 +74,7 @@ void WeightPreprocessOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (weight_preprocess) {
-LITE_WARN("enable weight-preprocess optimization");
+LITE_LOG("enable weight-preprocess optimization");
model->get_config().options.weight_preprocess = true;
}
}
@@ -85,14 +86,14 @@ void WeightPreprocessOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (weight_preprocess) {
-mgb_log_warn("enable weight-preprocess optimization");
+mgb_log("enable weight-preprocess optimization");
graph_option.graph_opt.enable_weight_preprocess();
}
}
}
} // namespace lar
-WeightPreprocessOption::WeightPreprocessOption() {
+void WeightPreprocessOption::update() {
m_option_name = "weight_preprocess";
weight_preprocess = FLAGS_weight_preprocess;
m_option = {{"weight_preprocess", lar::Bool::make(false)}};
@@ -108,6 +109,7 @@ bool WeightPreprocessOption::is_valid() {
std::shared_ptr<OptionBase> WeightPreprocessOption::create_option() {
static std::shared_ptr<WeightPreprocessOption> option(new WeightPreprocessOption);
if (WeightPreprocessOption::is_valid()) {
+option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
@@ -142,14 +144,14 @@ void FuseConvBiasNonlinearOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (enable_fuse_conv_bias_nonlinearity) {
-mgb_log_warn("enable fuse conv+bias+nonlinearity optimization");
+mgb_log("enable fuse conv+bias+nonlinearity optimization");
graph_option.graph_opt.enable_fuse_conv_bias_nonlinearity();
}
}
}
} // namespace lar
-FuseConvBiasNonlinearOption::FuseConvBiasNonlinearOption() {
+void FuseConvBiasNonlinearOption::update() {
m_option_name = "fuse_conv_bias_nonlinearity";
enable_fuse_conv_bias_nonlinearity = FLAGS_enable_fuse_conv_bias_nonlinearity;
m_option = {{"enable_fuse_conv_bias_nonlinearity", lar::Bool::make(false)}};
@@ -166,6 +168,7 @@ std::shared_ptr<OptionBase> FuseConvBiasNonlinearOption::create_option() {
static std::shared_ptr<FuseConvBiasNonlinearOption> option(
new FuseConvBiasNonlinearOption);
if (FuseConvBiasNonlinearOption::is_valid()) {
+option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
@@ -203,14 +206,14 @@ void FuseConvBiasElemwiseAddOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (enable_fuse_conv_bias_with_z) {
-mgb_log_warn("enable fuse conv+bias+z optimization");
+mgb_log("enable fuse conv+bias+z optimization");
graph_option.graph_opt.enable_fuse_conv_bias_with_z();
}
}
}
} // namespace lar
-FuseConvBiasElemwiseAddOption::FuseConvBiasElemwiseAddOption() {
+void FuseConvBiasElemwiseAddOption::update() {
m_option_name = "fuse_conv_bias_with_z";
enable_fuse_conv_bias_with_z = FLAGS_enable_fuse_conv_bias_with_z;
m_option = {{"enable_fuse_conv_bias_with_z", lar::Bool::make(false)}};
@@ -227,6 +230,7 @@ std::shared_ptr<OptionBase> FuseConvBiasElemwiseAddOption::create_option() {
static std::shared_ptr<FuseConvBiasElemwiseAddOption> option(
new FuseConvBiasElemwiseAddOption);
if (FuseConvBiasElemwiseAddOption::is_valid()) {
+option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
@@ -250,26 +254,26 @@ void GraphRecordOption::config_model_internel<ModelLite>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& config_option = model->get_config().options;
if (const_shape) {
-LITE_WARN("enable const var shape");
+LITE_LOG("enable const var shape");
config_option.const_shape = true;
}
if (fake_first) {
-LITE_WARN("enable fake-first optimization");
+LITE_LOG("enable fake-first optimization");
config_option.fake_next_exec = true;
}
if (no_sanity_check) {
-LITE_WARN("disable var sanity check optimization");
+LITE_LOG("disable var sanity check optimization");
config_option.var_sanity_check_first_run = false;
}
if (m_record_comp_seq == 1) {
-LITE_WARN("set record_comp_seq_level to 1");
+LITE_LOG("set record_comp_seq_level to 1");
}
if (m_record_comp_seq == 2) {
mgb_assert(
no_sanity_check,
"--no-sanity-check should be set before "
"--record-comp-seq2");
-LITE_WARN("set record_comp_seq_level to 2");
+LITE_LOG("set record_comp_seq_level to 2");
}
config_option.comp_node_seq_record_level = m_record_comp_seq;
}
@@ -281,33 +285,33 @@ void GraphRecordOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (const_shape) {
-mgb_log_warn("enable const var shape");
+mgb_log("enable const var shape");
model->get_mdl_config().const_var_shape = true;
}
if (fake_first) {
-mgb_log_warn("enable fake-first optimization");
+mgb_log("enable fake-first optimization");
graph_option.fake_next_exec = true;
}
if (no_sanity_check) {
-mgb_log_warn("disable var sanity check optimization");
+mgb_log("disable var sanity check optimization");
graph_option.var_sanity_check_first_run = false;
}
if (m_record_comp_seq == 1) {
-mgb_log_warn("set record_comp_seq_level to 1");
+mgb_log("set record_comp_seq_level to 1");
}
if (m_record_comp_seq == 2) {
mgb_assert(
no_sanity_check && !fake_first,
"--no-sanity-check should be set before "
"--record-comp-seq2 and --fake-first should not be set");
-mgb_log_warn("set record_comp_seq_level to 2");
+mgb_log("set record_comp_seq_level to 2");
}
graph_option.comp_node_seq_record_level = m_record_comp_seq;
}
}
} // namespace lar
-GraphRecordOption::GraphRecordOption() {
+void GraphRecordOption::update() {
m_option_name = "graph_record";
m_record_comp_seq = 0;
const_shape = FLAGS_const_shape;
@@ -350,6 +354,7 @@ bool GraphRecordOption::is_valid() {
std::shared_ptr<OptionBase> GraphRecordOption::create_option() {
static std::shared_ptr<GraphRecordOption> option(new GraphRecordOption);
if (GraphRecordOption::is_valid()) {
+option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
@@ -387,7 +392,7 @@ void MemoryOptimizeOption::config_model_internel<ModelLite>(
}
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
if (workspace_limit != SIZE_MAX) {
-LITE_WARN("set workspace limit to %ld", workspace_limit);
+LITE_LOG("set workspace limit to %ld", workspace_limit);
lite::Runtime::set_network_algo_workspace_limit(
model->get_lite_network(), workspace_limit);
}
@@ -400,12 +405,12 @@ void MemoryOptimizeOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (disable_mem_opt) {
-mgb_log_warn("disable memory optimization");
+mgb_log("disable memory optimization");
graph_option.seq_opt.enable_mem_plan_opt = false;
graph_option.seq_opt.enable_mem_reuse_alloc = false;
}
if (workspace_limit < SIZE_MAX) {
-mgb_log_warn("set workspace limit to %ld", workspace_limit);
+mgb_log("set workspace limit to %ld", workspace_limit);
auto&& output_spec = model->get_output_spec();
mgb::SymbolVarArray vars;
for (auto i : output_spec) {
@@ -417,7 +422,7 @@ void MemoryOptimizeOption::config_model_internel<ModelMdl>(
}
} // namespace lar
-MemoryOptimizeOption::MemoryOptimizeOption() {
+void MemoryOptimizeOption::update() {
m_option_name = "memory_optimize";
disable_mem_opt = FLAGS_disable_mem_opt;
workspace_limit = FLAGS_workspace_limit;
@@ -432,6 +437,7 @@ bool MemoryOptimizeOption::is_valid() {
std::shared_ptr<OptionBase> MemoryOptimizeOption::create_option() {
static std::shared_ptr<MemoryOptimizeOption> option(new MemoryOptimizeOption); | static std::shared_ptr<MemoryOptimizeOption> option(new MemoryOptimizeOption); | ||||
if (MemoryOptimizeOption::is_valid()) { | if (MemoryOptimizeOption::is_valid()) { | ||||
option->update(); | |||||
return std::static_pointer_cast<OptionBase>(option); | return std::static_pointer_cast<OptionBase>(option); | ||||
} else { | } else { | ||||
return nullptr; | return nullptr; | ||||
@@ -451,7 +457,7 @@ void JITOption::config_model_internel<ModelLite>( | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | ||||
auto&& config_option = model->get_config().options; | auto&& config_option = model->get_config().options; | ||||
if (enable_jit) { | if (enable_jit) { | ||||
LITE_WARN("enable JIT (level 1)"); | |||||
LITE_LOG("enable JIT (level 1)"); | |||||
config_option.jit_level = 1; | config_option.jit_level = 1; | ||||
} | } | ||||
} | } | ||||
@@ -463,13 +469,13 @@ void JITOption::config_model_internel<ModelMdl>( | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | ||||
auto&& graph_option = model->get_mdl_config().comp_graph->options(); | auto&& graph_option = model->get_mdl_config().comp_graph->options(); | ||||
if (enable_jit) { | if (enable_jit) { | ||||
mgb_log_warn("enable JIT (level 1)"); | |||||
mgb_log("enable JIT (level 1)"); | |||||
graph_option.graph_opt.jit = 1; | graph_option.graph_opt.jit = 1; | ||||
} | } | ||||
} | } | ||||
} | } | ||||
} // namespace lar | } // namespace lar | ||||
JITOption::JITOption() { | |||||
void JITOption::update() { | |||||
m_option_name = "JIT"; | m_option_name = "JIT"; | ||||
enable_jit = FLAGS_enable_jit; | enable_jit = FLAGS_enable_jit; | ||||
} | } | ||||
@@ -482,6 +488,7 @@ bool JITOption::is_valid() { | |||||
std::shared_ptr<OptionBase> JITOption::create_option() { | std::shared_ptr<OptionBase> JITOption::create_option() { | ||||
static std::shared_ptr<JITOption> option(new JITOption); | static std::shared_ptr<JITOption> option(new JITOption); | ||||
if (JITOption::is_valid()) { | if (JITOption::is_valid()) { | ||||
option->update(); | |||||
return std::static_pointer_cast<OptionBase>(option); | return std::static_pointer_cast<OptionBase>(option); | ||||
} else { | } else { | ||||
return nullptr; | return nullptr; | ||||
@@ -500,12 +507,12 @@ void TensorRTOption::config_model_internel<ModelLite>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | ||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | ||||
if (!tensorrt_cache.empty()) { | if (!tensorrt_cache.empty()) { | ||||
LITE_WARN("set tensorrt cache as %s", tensorrt_cache.c_str()); | |||||
LITE_LOG("set tensorrt cache as %s", tensorrt_cache.c_str()); | |||||
lite::set_tensor_rt_cache(tensorrt_cache); | lite::set_tensor_rt_cache(tensorrt_cache); | ||||
} | } | ||||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | ||||
if (enable_tensorrt) { | if (enable_tensorrt) { | ||||
LITE_WARN("enable TensorRT"); | |||||
LITE_LOG("enable TensorRT"); | |||||
lite::Runtime::use_tensorrt(model->get_lite_network()); | lite::Runtime::use_tensorrt(model->get_lite_network()); | ||||
} | } | ||||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | ||||
@@ -521,11 +528,11 @@ void TensorRTOption::config_model_internel<ModelMdl>( | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | ||||
auto&& graph_option = model->get_mdl_config().comp_graph->options(); | auto&& graph_option = model->get_mdl_config().comp_graph->options(); | ||||
if (enable_tensorrt) { | if (enable_tensorrt) { | ||||
mgb_log_warn("using tensorRT"); | |||||
mgb_log("using tensorRT"); | |||||
graph_option.graph_opt.tensorrt = true; | graph_option.graph_opt.tensorrt = true; | ||||
} | } | ||||
if (!tensorrt_cache.empty()) { | if (!tensorrt_cache.empty()) { | ||||
mgb_log_warn("use tensorrt cache: %s", tensorrt_cache.c_str()); | |||||
mgb_log("use tensorrt cache: %s", tensorrt_cache.c_str()); | |||||
mgb::TensorRTEngineCache::enable_engine_cache(true); | mgb::TensorRTEngineCache::enable_engine_cache(true); | ||||
mgb::TensorRTEngineCache::set_impl( | mgb::TensorRTEngineCache::set_impl( | ||||
std::make_shared<mgb::TensorRTEngineCacheIO>( | std::make_shared<mgb::TensorRTEngineCacheIO>( | ||||
@@ -541,7 +548,7 @@ void TensorRTOption::config_model_internel<ModelMdl>( | |||||
} | } | ||||
} // namespace lar | } // namespace lar | ||||
TensorRTOption::TensorRTOption() { | |||||
void TensorRTOption::update() { | |||||
m_option_name = "tensorRT"; | m_option_name = "tensorRT"; | ||||
enable_tensorrt = FLAGS_tensorrt; | enable_tensorrt = FLAGS_tensorrt; | ||||
tensorrt_cache = FLAGS_tensorrt_cache; | tensorrt_cache = FLAGS_tensorrt_cache; | ||||
@@ -556,6 +563,7 @@ bool TensorRTOption::is_valid() { | |||||
std::shared_ptr<OptionBase> TensorRTOption::create_option() { | std::shared_ptr<OptionBase> TensorRTOption::create_option() { | ||||
static std::shared_ptr<TensorRTOption> option(new TensorRTOption); | static std::shared_ptr<TensorRTOption> option(new TensorRTOption); | ||||
if (TensorRTOption::is_valid()) { | if (TensorRTOption::is_valid()) { | ||||
option->update(); | |||||
return std::static_pointer_cast<OptionBase>(option); | return std::static_pointer_cast<OptionBase>(option); | ||||
} else { | } else { | ||||
return nullptr; | return nullptr; | ||||
@@ -39,8 +39,10 @@ public: | |||||
OptionValMap* get_option() override { return &m_option; } | OptionValMap* get_option() override { return &m_option; } | ||||
void update() override; | |||||
private: | private: | ||||
FusePreprocessOption(); | |||||
FusePreprocessOption() = default; | |||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | ||||
@@ -65,8 +67,10 @@ public: | |||||
OptionValMap* get_option() override { return &m_option; } | OptionValMap* get_option() override { return &m_option; } | ||||
void update() override; | |||||
private: | private: | ||||
WeightPreprocessOption(); | |||||
WeightPreprocessOption() = default; | |||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | ||||
@@ -91,8 +95,10 @@ public: | |||||
OptionValMap* get_option() override { return &m_option; } | OptionValMap* get_option() override { return &m_option; } | ||||
void update() override; | |||||
private: | private: | ||||
FuseConvBiasNonlinearOption(); | |||||
FuseConvBiasNonlinearOption() = default; | |||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | ||||
@@ -117,8 +123,10 @@ public: | |||||
OptionValMap* get_option() override { return &m_option; } | OptionValMap* get_option() override { return &m_option; } | ||||
void update() override; | |||||
private: | private: | ||||
FuseConvBiasElemwiseAddOption(); | |||||
FuseConvBiasElemwiseAddOption() = default; | |||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | ||||
std::string m_option_name; | std::string m_option_name; | ||||
@@ -143,8 +151,10 @@ public: | |||||
OptionValMap* get_option() override { return &m_option; } | OptionValMap* get_option() override { return &m_option; } | ||||
void update() override; | |||||
private: | private: | ||||
GraphRecordOption(); | |||||
GraphRecordOption() = default; | |||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | ||||
@@ -169,8 +179,10 @@ public: | |||||
std::string option_name() const override { return m_option_name; }; | std::string option_name() const override { return m_option_name; }; | ||||
void update() override; | |||||
private: | private: | ||||
MemoryOptimizeOption(); | |||||
MemoryOptimizeOption() = default; | |||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | ||||
@@ -191,8 +203,10 @@ public: | |||||
std::string option_name() const override { return m_option_name; }; | std::string option_name() const override { return m_option_name; }; | ||||
void update() override; | |||||
private: | private: | ||||
JITOption(); | |||||
JITOption() = default; | |||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | ||||
@@ -212,8 +226,10 @@ public: | |||||
std::string option_name() const override { return m_option_name; }; | std::string option_name() const override { return m_option_name; }; | ||||
void update() override; | |||||
private: | private: | ||||
TensorRTOption(); | |||||
TensorRTOption() = default; | |||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | ||||
@@ -28,6 +28,10 @@ public: | |||||
//! get option map | //! get option map | ||||
virtual OptionValMap* get_option() { return nullptr; } | virtual OptionValMap* get_option() { return nullptr; } | ||||
    //! update option values from the parsed command-line flags; called by create_option() before the option is returned | |||||
virtual void update(){}; | |||||
virtual ~OptionBase() = default; | virtual ~OptionBase() = default; | ||||
}; | }; | ||||
@@ -22,10 +22,10 @@ void PluginOption::config_model_internel<ModelLite>( | |||||
else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | ||||
if (!profile_path.empty()) { | if (!profile_path.empty()) { | ||||
if (!enable_profile_host) { | if (!enable_profile_host) { | ||||
LITE_WARN("enable profiling"); | |||||
LITE_LOG("enable profiling"); | |||||
model->get_lite_network()->enable_profile_performance(profile_path); | model->get_lite_network()->enable_profile_performance(profile_path); | ||||
} else { | } else { | ||||
LITE_WARN("enable profiling for host"); | |||||
LITE_LOG("enable profiling for host"); | |||||
model->get_lite_network()->enable_profile_performance(profile_path); | model->get_lite_network()->enable_profile_performance(profile_path); | ||||
} | } | ||||
} | } | ||||
@@ -39,18 +39,18 @@ void PluginOption::config_model_internel<ModelMdl>( | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | ||||
auto&& config = model->get_mdl_config(); | auto&& config = model->get_mdl_config(); | ||||
if (range > 0) { | if (range > 0) { | ||||
mgb_log_warn("enable number range check"); | |||||
mgb_log("enable number range check"); | |||||
model->set_num_range_checker(float(range)); | model->set_num_range_checker(float(range)); | ||||
} | } | ||||
if (enable_check_dispatch) { | if (enable_check_dispatch) { | ||||
mgb_log_warn("enable cpu dispatch check"); | |||||
mgb_log("enable cpu dispatch check"); | |||||
cpu_dispatch_checker = | cpu_dispatch_checker = | ||||
std::make_unique<mgb::CPUDispatchChecker>(config.comp_graph.get()); | std::make_unique<mgb::CPUDispatchChecker>(config.comp_graph.get()); | ||||
} | } | ||||
if (!var_value_check_str.empty()) { | if (!var_value_check_str.empty()) { | ||||
mgb_log_warn("enable variable value check"); | |||||
mgb_log("enable variable value check"); | |||||
size_t init_idx = 0, switch_interval; | size_t init_idx = 0, switch_interval; | ||||
auto sep = var_value_check_str.find(':'); | auto sep = var_value_check_str.find(':'); | ||||
if (sep != std::string::npos) { | if (sep != std::string::npos) { | ||||
@@ -67,9 +67,9 @@ void PluginOption::config_model_internel<ModelMdl>( | |||||
if (!profile_path.empty()) { | if (!profile_path.empty()) { | ||||
if (!enable_profile_host) { | if (!enable_profile_host) { | ||||
mgb_log_warn("enable profiling"); | |||||
mgb_log("enable profiling"); | |||||
} else { | } else { | ||||
mgb_log_warn("enable profiling for host"); | |||||
mgb_log("enable profiling for host"); | |||||
} | } | ||||
model->set_profiler(); | model->set_profiler(); | ||||
} | } | ||||
@@ -79,12 +79,11 @@ void PluginOption::config_model_internel<ModelMdl>( | |||||
else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | ||||
#if MGB_ENABLE_JSON | #if MGB_ENABLE_JSON | ||||
if (!profile_path.empty()) { | if (!profile_path.empty()) { | ||||
mgb_log_warn("filename %s", profile_path.c_str()); | |||||
if (model->get_profiler()) { | if (model->get_profiler()) { | ||||
model->get_profiler() | model->get_profiler() | ||||
->to_json_full(model->get_async_func().get()) | ->to_json_full(model->get_async_func().get()) | ||||
->writeto_fpath(profile_path); | ->writeto_fpath(profile_path); | ||||
mgb_log_warn("profiling result written to %s", profile_path.c_str()); | |||||
mgb_log("profiling result written to %s", profile_path.c_str()); | |||||
} | } | ||||
} | } | ||||
#endif | #endif | ||||
@@ -94,7 +93,7 @@ void PluginOption::config_model_internel<ModelMdl>( | |||||
} // namespace lar | } // namespace lar | ||||
using namespace lar; | using namespace lar; | ||||
PluginOption::PluginOption() { | |||||
void PluginOption::update() { | |||||
m_option_name = "plugin"; | m_option_name = "plugin"; | ||||
range = FLAGS_range; | range = FLAGS_range; | ||||
enable_check_dispatch = FLAGS_check_dispatch; | enable_check_dispatch = FLAGS_check_dispatch; | ||||
@@ -125,6 +124,7 @@ bool PluginOption::is_valid() { | |||||
std::shared_ptr<OptionBase> PluginOption::create_option() { | std::shared_ptr<OptionBase> PluginOption::create_option() { | ||||
static std::shared_ptr<PluginOption> option(new PluginOption); | static std::shared_ptr<PluginOption> option(new PluginOption); | ||||
if (PluginOption::is_valid()) { | if (PluginOption::is_valid()) { | ||||
option->update(); | |||||
return std::static_pointer_cast<OptionBase>(option); | return std::static_pointer_cast<OptionBase>(option); | ||||
} else { | } else { | ||||
return nullptr; | return nullptr; | ||||
@@ -199,7 +199,7 @@ void DebugOption::format_and_print( | |||||
std::stringstream ss; | std::stringstream ss; | ||||
ss << table; | ss << table; | ||||
printf("%s\n\n", ss.str().c_str()); | |||||
LITE_LOG("%s\n\n", ss.str().c_str()); | |||||
} | } | ||||
template <> | template <> | ||||
@@ -243,7 +243,7 @@ void DebugOption::format_and_print( | |||||
std::stringstream ss; | std::stringstream ss; | ||||
ss << table; | ss << table; | ||||
printf("%s\n\n", ss.str().c_str()); | |||||
mgb_log("%s\n\n", ss.str().c_str()); | |||||
} | } | ||||
template <> | template <> | ||||
@@ -260,7 +260,7 @@ void DebugOption::config_model_internel<ModelLite>( | |||||
#endif | #endif | ||||
#endif | #endif | ||||
if (enable_verbose) { | if (enable_verbose) { | ||||
LITE_WARN("enable verbose"); | |||||
LITE_LOG("enable verbose"); | |||||
lite::set_log_level(LiteLogLevel::DEBUG); | lite::set_log_level(LiteLogLevel::DEBUG); | ||||
} | } | ||||
@@ -272,7 +272,7 @@ void DebugOption::config_model_internel<ModelLite>( | |||||
#endif | #endif | ||||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | ||||
if (enable_display_model_info) { | if (enable_display_model_info) { | ||||
LITE_WARN("enable display model information"); | |||||
LITE_LOG("enable display model information"); | |||||
format_and_print<ModelLite>("Runtime Model Info", model); | format_and_print<ModelLite>("Runtime Model Info", model); | ||||
} | } | ||||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | ||||
@@ -287,7 +287,7 @@ void DebugOption::config_model_internel<ModelMdl>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | ||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | ||||
if (enable_verbose) { | if (enable_verbose) { | ||||
mgb_log_warn("enable verbose"); | |||||
mgb_log("enable verbose"); | |||||
mgb::set_log_level(mgb::LogLevel::DEBUG); | mgb::set_log_level(mgb::LogLevel::DEBUG); | ||||
} | } | ||||
@@ -299,21 +299,21 @@ void DebugOption::config_model_internel<ModelMdl>( | |||||
#endif | #endif | ||||
} else if (runtime_param.stage == RunStage::BEFORE_OUTSPEC_SET) { | } else if (runtime_param.stage == RunStage::BEFORE_OUTSPEC_SET) { | ||||
if (enable_display_model_info) { | if (enable_display_model_info) { | ||||
mgb_log_warn("enable display model information"); | |||||
mgb_log("enable display model information"); | |||||
format_and_print<ModelMdl>("Runtime Model Info", model); | format_and_print<ModelMdl>("Runtime Model Info", model); | ||||
} | } | ||||
} else if (runtime_param.stage == RunStage::AFTER_OUTSPEC_SET) { | } else if (runtime_param.stage == RunStage::AFTER_OUTSPEC_SET) { | ||||
#ifndef __IN_TEE_ENV__ | #ifndef __IN_TEE_ENV__ | ||||
#if MGB_ENABLE_JSON | #if MGB_ENABLE_JSON | ||||
if (!static_mem_log_dir_path.empty()) { | if (!static_mem_log_dir_path.empty()) { | ||||
mgb_log_warn("enable get static memeory information"); | |||||
mgb_log("enable get static memeory information"); | |||||
model->get_async_func()->get_static_memory_alloc_info( | model->get_async_func()->get_static_memory_alloc_info( | ||||
static_mem_log_dir_path); | static_mem_log_dir_path); | ||||
} | } | ||||
#endif | #endif | ||||
#endif | #endif | ||||
if (disable_assert_throw) { | if (disable_assert_throw) { | ||||
mgb_log_warn("disable assert throw"); | |||||
mgb_log("disable assert throw"); | |||||
auto on_opr = [](mgb::cg::OperatorNodeBase* opr) { | auto on_opr = [](mgb::cg::OperatorNodeBase* opr) { | ||||
if (opr->same_type<mgb::opr::AssertEqual>()) { | if (opr->same_type<mgb::opr::AssertEqual>()) { | ||||
opr->cast_final<mgb::opr::AssertEqual>().disable_throw_on_error(); | opr->cast_final<mgb::opr::AssertEqual>().disable_throw_on_error(); | ||||
@@ -333,7 +333,7 @@ void DebugOption::config_model_internel<ModelMdl>( | |||||
} // namespace lar | } // namespace lar | ||||
DebugOption::DebugOption() { | |||||
void DebugOption::update() { | |||||
m_option_name = "debug"; | m_option_name = "debug"; | ||||
enable_display_model_info = FLAGS_model_info; | enable_display_model_info = FLAGS_model_info; | ||||
enable_verbose = FLAGS_verbose; | enable_verbose = FLAGS_verbose; | ||||
@@ -367,6 +367,7 @@ bool DebugOption::is_valid() { | |||||
std::shared_ptr<OptionBase> DebugOption::create_option() { | std::shared_ptr<OptionBase> DebugOption::create_option() { | ||||
static std::shared_ptr<DebugOption> option(new DebugOption); | static std::shared_ptr<DebugOption> option(new DebugOption); | ||||
if (DebugOption::is_valid()) { | if (DebugOption::is_valid()) { | ||||
option->update(); | |||||
return std::static_pointer_cast<OptionBase>(option); | return std::static_pointer_cast<OptionBase>(option); | ||||
} else { | } else { | ||||
return nullptr; | return nullptr; | ||||
@@ -44,8 +44,10 @@ public: | |||||
std::string option_name() const override { return m_option_name; }; | std::string option_name() const override { return m_option_name; }; | ||||
void update() override; | |||||
private: | private: | ||||
PluginOption(); | |||||
PluginOption() = default; | |||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | ||||
double range; | double range; | ||||
@@ -74,8 +76,10 @@ public: | |||||
std::string option_name() const override { return m_option_name; }; | std::string option_name() const override { return m_option_name; }; | ||||
void update() override; | |||||
private: | private: | ||||
DebugOption(); | |||||
DebugOption() = default; | |||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
void format_and_print(const std::string&, std::shared_ptr<ModelImpl>){}; | void format_and_print(const std::string&, std::shared_ptr<ModelImpl>){}; | ||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
@@ -5,7 +5,7 @@ using namespace lar; | |||||
DECLARE_bool(c_opr_lib_with_param); | DECLARE_bool(c_opr_lib_with_param); | ||||
DECLARE_bool(fitting); | DECLARE_bool(fitting); | ||||
StrategyOption::StrategyOption() { | |||||
void StrategyOption::update() { | |||||
m_option_name = "run_strategy"; | m_option_name = "run_strategy"; | ||||
warmup_iter = FLAGS_fitting ? 3 : FLAGS_warmup_iter; | warmup_iter = FLAGS_fitting ? 3 : FLAGS_warmup_iter; | ||||
run_iter = FLAGS_fitting ? 10 : FLAGS_iter; | run_iter = FLAGS_fitting ? 10 : FLAGS_iter; | ||||
@@ -20,6 +20,7 @@ StrategyOption::StrategyOption() { | |||||
std::shared_ptr<OptionBase> StrategyOption::create_option() { | std::shared_ptr<OptionBase> StrategyOption::create_option() { | ||||
static std::shared_ptr<StrategyOption> option(new StrategyOption); | static std::shared_ptr<StrategyOption> option(new StrategyOption); | ||||
option->update(); | |||||
return std::static_pointer_cast<OptionBase>(option); | return std::static_pointer_cast<OptionBase>(option); | ||||
} | } | ||||
@@ -43,12 +44,13 @@ void StrategyOption::config_model( | |||||
} | } | ||||
} | } | ||||
TestcaseOption::TestcaseOption() { | |||||
void TestcaseOption::update() { | |||||
m_option_name = "run_testcase"; | m_option_name = "run_testcase"; | ||||
} | } | ||||
std::shared_ptr<OptionBase> TestcaseOption::create_option() { | std::shared_ptr<OptionBase> TestcaseOption::create_option() { | ||||
static std::shared_ptr<TestcaseOption> option(new TestcaseOption); | static std::shared_ptr<TestcaseOption> option(new TestcaseOption); | ||||
option->update(); | |||||
return std::static_pointer_cast<OptionBase>(option); | return std::static_pointer_cast<OptionBase>(option); | ||||
} | } | ||||
@@ -25,9 +25,11 @@ public: | |||||
OptionValMap* get_option() override { return &m_option; } | OptionValMap* get_option() override { return &m_option; } | ||||
void update() override; | |||||
private: | private: | ||||
//! Constructor | //! Constructor | ||||
StrategyOption(); | |||||
StrategyOption() = default; | |||||
//! configuration for different model implement | //! configuration for different model implement | ||||
std::string m_option_name; | std::string m_option_name; | ||||
@@ -52,9 +54,11 @@ public: | |||||
//! get option name | //! get option name | ||||
std::string option_name() const override { return m_option_name; }; | std::string option_name() const override { return m_option_name; }; | ||||
void update() override; | |||||
private: | private: | ||||
//! Constructor | //! Constructor | ||||
TestcaseOption(); | |||||
TestcaseOption() = default; | |||||
//! configuration for different model implement | //! configuration for different model implement | ||||
std::string m_option_name; | std::string m_option_name; | ||||
@@ -233,8 +233,8 @@ void OptionsTimeProfiler::profile_with_given_options( | |||||
"the log) when profile option:\n%s\n", | "the log) when profile option:\n%s\n", | ||||
option_code.c_str()); | option_code.c_str()); | ||||
} else { | } else { | ||||
printf("profile option:\n%s\naverage time = %.2f\n", option_code.c_str(), | |||||
average); | |||||
mgb_log("profile option:\n%s\naverage time = %.2f\n", option_code.c_str(), | |||||
average); | |||||
//! record profile result | //! record profile result | ||||
m_options_profile_result.insert({option_code, average}); | m_options_profile_result.insert({option_code, average}); | ||||
@@ -370,7 +370,6 @@ void UserInfoParser::parse_info(std::shared_ptr<OptionsFastManager>& manager) { | |||||
FittingStrategy::FittingStrategy(std::string model_path) { | FittingStrategy::FittingStrategy(std::string model_path) { | ||||
m_manager = std::make_shared<OptionsFastManager>(); | m_manager = std::make_shared<OptionsFastManager>(); | ||||
m_dumped_model = FLAGS_dump_fitting_model; | m_dumped_model = FLAGS_dump_fitting_model; | ||||
mgb::set_log_level(mgb::LogLevel::INFO); | |||||
m_options = std::make_shared<OptionMap>(); | m_options = std::make_shared<OptionMap>(); | ||||
m_model_path = model_path; | m_model_path = model_path; | ||||
auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map(); | auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map(); | ||||
@@ -518,10 +517,10 @@ void FittingStrategy::AutoCleanFile::dump_model() { | |||||
void FittingStrategy::run() { | void FittingStrategy::run() { | ||||
auto mgb_version = mgb::get_version(); | auto mgb_version = mgb::get_version(); | ||||
auto dnn_version = megdnn::get_version(); | auto dnn_version = megdnn::get_version(); | ||||
printf("megbrain/lite/load_and_run:\nusing MegBrain " | |||||
"%d.%d.%d(%d) and MegDNN %d.%d.%d\n", | |||||
mgb_version.major, mgb_version.minor, mgb_version.patch, mgb_version.is_dev, | |||||
dnn_version.major, dnn_version.minor, dnn_version.patch); | |||||
mgb_log("megbrain/lite/load_and_run:\nusing MegBrain " | |||||
"%d.%d.%d(%d) and MegDNN %d.%d.%d\n", | |||||
mgb_version.major, mgb_version.minor, mgb_version.patch, mgb_version.is_dev, | |||||
dnn_version.major, dnn_version.minor, dnn_version.patch); | |||||
// ! create profiler with given user info | // ! create profiler with given user info | ||||
m_info_parser.get_user_info(); | m_info_parser.get_user_info(); | ||||
m_info_parser.parse_info(m_manager); | m_info_parser.parse_info(m_manager); | ||||
@@ -5,13 +5,10 @@ | |||||
#include "megbrain/utils/timer.h" | #include "megbrain/utils/timer.h" | ||||
#include "megbrain/version.h" | #include "megbrain/version.h" | ||||
#include "megdnn/version.h" | #include "megdnn/version.h" | ||||
#include "misc.h" | |||||
using namespace lar; | using namespace lar; | ||||
NormalStrategy::NormalStrategy(std::string model_path) { | NormalStrategy::NormalStrategy(std::string model_path) { | ||||
mgb::set_log_level(mgb::LogLevel::WARN); | |||||
lite::set_log_level(LiteLogLevel::WARN); | |||||
m_options = std::make_shared<OptionMap>(); | m_options = std::make_shared<OptionMap>(); | ||||
m_model_path = model_path; | m_model_path = model_path; | ||||
auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map(); | auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map(); | ||||
@@ -47,7 +44,7 @@ void NormalStrategy::run_subline() { | |||||
mgb::RealTimer timer; | mgb::RealTimer timer; | ||||
model->load_model(); | model->load_model(); | ||||
printf("load model: %.3fms\n", timer.get_msecs_reset()); | |||||
mgb_log("load model: %.3fms\n", timer.get_msecs_reset()); | |||||
//! after load configure | //! after load configure | ||||
auto config_after_load = [&]() { | auto config_after_load = [&]() { | ||||
@@ -62,10 +59,10 @@ void NormalStrategy::run_subline() { | |||||
auto warm_up = [&]() { | auto warm_up = [&]() { | ||||
auto warmup_num = m_runtime_param.warmup_iter; | auto warmup_num = m_runtime_param.warmup_iter; | ||||
for (size_t i = 0; i < warmup_num; i++) { | for (size_t i = 0; i < warmup_num; i++) { | ||||
printf("=== prepare: %.3fms; going to warmup\n\n", timer.get_msecs_reset()); | |||||
mgb_log("=== prepare: %.3fms; going to warmup", timer.get_msecs_reset()); | |||||
model->run_model(); | model->run_model(); | ||||
model->wait(); | model->wait(); | ||||
printf("warm up %lu %.3fms\n", i, timer.get_msecs_reset()); | |||||
mgb_log("warm up %lu %.3fms", i, timer.get_msecs_reset()); | |||||
m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT; | m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT; | ||||
stage_config_model(); | stage_config_model(); | ||||
} | } | ||||
@@ -83,21 +80,21 @@ void NormalStrategy::run_subline() { | |||||
auto cur = timer.get_msecs(); | auto cur = timer.get_msecs(); | ||||
m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT; | m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT; | ||||
stage_config_model(); | stage_config_model(); | ||||
printf("iter %lu/%lu: e2e=%.3f ms (host=%.3f ms)\n", i, run_num, cur, | |||||
exec_time); | |||||
mgb_log("iter %lu/%lu: e2e=%.3f ms (host=%.3f ms)", i, run_num, cur, | |||||
exec_time); | |||||
time_sum += cur; | time_sum += cur; | ||||
time_sqrsum += cur * cur; | time_sqrsum += cur * cur; | ||||
fflush(stdout); | fflush(stdout); | ||||
min_time = std::min(min_time, cur); | min_time = std::min(min_time, cur); | ||||
max_time = std::max(max_time, cur); | max_time = std::max(max_time, cur); | ||||
} | } | ||||
printf("\n=== finished test #%u: time=%.3f ms avg_time=%.3f ms " | |||||
"standard_deviation=%.3f ms min=%.3f ms max=%.3f ms\n\n", | |||||
idx, time_sum, time_sum / run_num, | |||||
std::sqrt( | |||||
(time_sqrsum * run_num - time_sum * time_sum) / | |||||
(run_num * (run_num - 1))), | |||||
min_time, max_time); | |||||
mgb_log("=== finished test #%u: time=%.3f ms avg_time=%.3f ms " | |||||
"standard_deviation=%.3f ms min=%.3f ms max=%.3f ms", | |||||
idx, time_sum, time_sum / run_num, | |||||
std::sqrt( | |||||
(time_sqrsum * run_num - time_sum * time_sum) / | |||||
(run_num * (run_num - 1))), | |||||
min_time, max_time); | |||||
return time_sum; | return time_sum; | ||||
}; | }; | ||||
@@ -122,7 +119,7 @@ void NormalStrategy::run_subline() { | |||||
stage_config_model(); | stage_config_model(); | ||||
} | } | ||||
printf("=== total time: %.3fms\n", tot_time); | |||||
mgb_log("=== total time: %.3fms\n", tot_time); | |||||
//! execute after run | //! execute after run | ||||
m_runtime_param.stage = RunStage::AFTER_MODEL_RUNNING; | m_runtime_param.stage = RunStage::AFTER_MODEL_RUNNING; | ||||
stage_config_model(); | stage_config_model(); | ||||
@@ -131,9 +128,9 @@ void NormalStrategy::run_subline() { | |||||
void NormalStrategy::run() { | void NormalStrategy::run() { | ||||
auto v0 = mgb::get_version(); | auto v0 = mgb::get_version(); | ||||
auto v1 = megdnn::get_version(); | auto v1 = megdnn::get_version(); | ||||
printf("megbrain/lite/load_and_run:\nusing MegBrain " | |||||
"%d.%d.%d(%d) and MegDNN %d.%d.%d\n", | |||||
v0.major, v0.minor, v0.patch, v0.is_dev, v1.major, v1.minor, v1.patch); | |||||
mgb_log("megbrain/lite/load_and_run:\nusing MegBrain " | |||||
"%d.%d.%d(%d) and MegDNN %d.%d.%d\n", | |||||
v0.major, v0.minor, v0.patch, v0.is_dev, v1.major, v1.minor, v1.patch); | |||||
size_t thread_num = m_runtime_param.threads; | size_t thread_num = m_runtime_param.threads; | ||||
auto run_sub = [&]() { run_subline(); }; | auto run_sub = [&]() { run_subline(); }; | ||||
@@ -73,7 +73,7 @@ public: | |||||
#define LITE_LOG_(level, msg...) (void)0 | #define LITE_LOG_(level, msg...) (void)0 | ||||
#endif | #endif | ||||
#define LITE_LOG(fmt...) LITE_LOG_(DEBUG, fmt); | |||||
#define LITE_LOG(fmt...) LITE_LOG_(INFO, fmt); | |||||
#define LITE_DEBUG(fmt...) LITE_LOG_(DEBUG, fmt); | #define LITE_DEBUG(fmt...) LITE_LOG_(DEBUG, fmt); | ||||
#define LITE_WARN(fmt...) LITE_LOG_(WARN, fmt); | #define LITE_WARN(fmt...) LITE_LOG_(WARN, fmt); | ||||
#define LITE_ERROR(fmt...) LITE_LOG_(ERROR, fmt); | #define LITE_ERROR(fmt...) LITE_LOG_(ERROR, fmt); | ||||
@@ -1,14 +1,14 @@ | |||||
if(MGE_WITH_TEST) | if(MGE_WITH_TEST) | ||||
include_directories(PUBLIC | |||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/lite/load_and_run/src>) | |||||
file(GLOB_RECURSE SOURCES ./*.cpp main.cpp) | file(GLOB_RECURSE SOURCES ./*.cpp main.cpp) | ||||
add_executable(lite_test ${SOURCES}) | add_executable(lite_test ${SOURCES}) | ||||
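  # lar_object carries the load_and_run sources, so lite_test can drive NormalStrategy and its options | |||||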
target_link_libraries(lite_test lar_object) | |||||
target_link_libraries(lite_test gtest) | target_link_libraries(lite_test gtest) | ||||
target_link_libraries(lite_test lite_static) | |||||
if(LITE_BUILD_WITH_MGE) | if(LITE_BUILD_WITH_MGE) | ||||
# lite_test will depends megbrain interface | |||||
target_link_libraries(lite_test megbrain) | |||||
if(MGE_WITH_ROCM) | if(MGE_WITH_ROCM) | ||||
# FIXME: hip obj can not find cpp obj only through lite_static | # FIXME: hip obj can not find cpp obj only through lite_static | ||||
message(WARNING "MGE_WITH_ROCM is valid link to megdnn") | |||||
target_link_libraries(lite_test megdnn) | target_link_libraries(lite_test megdnn) | ||||
endif() | endif() | ||||
endif() | endif() | ||||
@@ -0,0 +1,85 @@ | |||||
#include <gtest/gtest.h> | |||||
#include <string.h> | |||||
#include <memory> | |||||
#include "test_options.h" | |||||
using namespace lar; | |||||
DECLARE_bool(lite); | |||||
DECLARE_bool(cpu); | |||||
#if LITE_WITH_CUDA | |||||
DECLARE_bool(cuda); | |||||
#endif | |||||
DECLARE_bool(enable_nchw4); | |||||
DECLARE_bool(enable_chwn4); | |||||
DECLARE_bool(enable_nchw44); | |||||
DECLARE_bool(enable_nchw88); | |||||
DECLARE_bool(enable_nchw32); | |||||
DECLARE_bool(enable_nchw64); | |||||
DECLARE_bool(enable_nhwcd4); | |||||
DECLARE_bool(enable_nchw44_dot); | |||||
namespace { | |||||
BOOL_OPTION_WRAP(enable_nchw4); | |||||
BOOL_OPTION_WRAP(enable_chwn4); | |||||
BOOL_OPTION_WRAP(enable_nchw44); | |||||
BOOL_OPTION_WRAP(enable_nchw88); | |||||
BOOL_OPTION_WRAP(enable_nchw32); | |||||
BOOL_OPTION_WRAP(enable_nchw64); | |||||
BOOL_OPTION_WRAP(enable_nhwcd4); | |||||
BOOL_OPTION_WRAP(enable_nchw44_dot); | |||||
BOOL_OPTION_WRAP(lite); | |||||
BOOL_OPTION_WRAP(cpu); | |||||
#if LITE_WITH_CUDA | |||||
BOOL_OPTION_WRAP(cuda); | |||||
#endif | |||||
} // anonymous namespace | |||||
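//! Each test pins a device flag for its whole body with DEFINE_WRAP and then | |||||
//! runs NormalStrategy on ./shufflenet.mge once per layout flag via TEST_BOOL_OPTION. | |||||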
TEST(TestLarLayout, X86_CPU) { | |||||
DEFINE_WRAP(cpu); | |||||
std::string model_path = "./shufflenet.mge"; | |||||
TEST_BOOL_OPTION(enable_nchw4); | |||||
TEST_BOOL_OPTION(enable_chwn4); | |||||
TEST_BOOL_OPTION(enable_nchw44); | |||||
TEST_BOOL_OPTION(enable_nchw44_dot); | |||||
TEST_BOOL_OPTION(enable_nchw64); | |||||
TEST_BOOL_OPTION(enable_nchw32); | |||||
TEST_BOOL_OPTION(enable_nchw88); | |||||
} | |||||
TEST(TestLarLayout, X86_CPU_LITE) { | |||||
DEFINE_WRAP(cpu); | |||||
DEFINE_WRAP(lite); | |||||
std::string model_path = "./shufflenet.mge"; | |||||
TEST_BOOL_OPTION(enable_nchw4); | |||||
TEST_BOOL_OPTION(enable_nchw44); | |||||
TEST_BOOL_OPTION(enable_nchw44_dot); | |||||
TEST_BOOL_OPTION(enable_nchw64); | |||||
TEST_BOOL_OPTION(enable_nchw32); | |||||
TEST_BOOL_OPTION(enable_nchw88); | |||||
} | |||||
#if LITE_WITH_CUDA | |||||
TEST(TestLarLayout, CUDA) { | |||||
DEFINE_WRAP(cuda); | |||||
std::string model_path = "./shufflenet.mge"; | |||||
TEST_BOOL_OPTION(enable_nchw4); | |||||
TEST_BOOL_OPTION(enable_chwn4); | |||||
TEST_BOOL_OPTION(enable_nchw64); | |||||
TEST_BOOL_OPTION(enable_nchw32); | |||||
FLAGS_cuda = false; | |||||
} | |||||
TEST(TestLarLayout, CUDA_LITE) { | |||||
DEFINE_WRAP(cuda); | |||||
DEFINE_WRAP(lite); | |||||
std::string model_path = "./shufflenet.mge"; | |||||
TEST_BOOL_OPTION(enable_nchw4); | |||||
TEST_BOOL_OPTION(enable_nchw64); | |||||
TEST_BOOL_OPTION(enable_nchw32); | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,12 @@ | |||||
#include "test_options.h" | |||||
using namespace lar; | |||||
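//! Run the plain load_and_run flow once; the MegBrain log level is switched to | |||||
//! WARN while the strategy runs and restored afterwards. | |||||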
void lar::run_NormalStrategy(std::string model_path) { | |||||
auto origin_level = mgb::get_log_level(); | |||||
mgb::set_log_level(mgb::LogLevel::WARN); | |||||
NormalStrategy strategy(model_path); | |||||
strategy.run(); | |||||
mgb::set_log_level(origin_level); | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,33 @@ | |||||
#pragma once | |||||
#include <iostream> | |||||
#include <thread> | |||||
#include "../load_and_run/src/strategys/strategy.h" | |||||
#include "../load_and_run/src/strategys/strategy_normal.h" | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/utils/timer.h" | |||||
#include "megbrain/version.h" | |||||
#include "megdnn/version.h" | |||||
#include "misc.h" | |||||
namespace lar { | |||||
//! run load_and_run NormalStrategy to test different options | |||||
void run_NormalStrategy(std::string model_path); | |||||
} // namespace lar | |||||
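//! Test helpers: BOOL_OPTION_WRAP(option) declares an RAII guard that sets | |||||
//! FLAGS_<option> on construction and clears it on destruction; DEFINE_WRAP | |||||
//! instantiates that guard for the enclosing scope; TEST_BOOL_OPTION toggles | |||||
//! the flag just long enough to run NormalStrategy once. | |||||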
#define BOOL_OPTION_WRAP(option) \ | |||||
struct BoolOptionWrap_##option { \ | |||||
BoolOptionWrap_##option() { FLAGS_##option = true; } \ | |||||
~BoolOptionWrap_##option() { FLAGS_##option = false; } \ | |||||
}; | |||||
#define DEFINE_WRAP(option) BoolOptionWrap_##option flags_##option; | |||||
#define TEST_BOOL_OPTION(option) \ | |||||
{ \ | |||||
BoolOptionWrap_##option flags_##option; \ | |||||
run_NormalStrategy(model_path); \ | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |