@@ -1,6 +1,6 @@ | |||
load("//brain/megbrain/lite:flags.bzl","pthread_select", "lite_opts") | |||
cc_library( | |||
name = "mgblar", | |||
name = "lar_object", | |||
srcs = glob(["src/**/*.cpp"], exclude = ["src/main.cpp"]), | |||
hdrs = glob(["src/**/*.h"]), | |||
includes = ["src"], | |||
@@ -28,7 +28,8 @@ cc_megvii_binary( | |||
"no_exceptions", | |||
"no_rtti", | |||
]), | |||
internal_deps = [":mgblar"], | |||
internal_deps = [":lar_object"], | |||
visibility = ["//visibility:public"], | |||
) | |||
@@ -1,23 +1,35 @@ | |||
# Build load_and_run for lite
include_directories(PUBLIC | |||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/lite/load_and_run/src>) | |||
file(GLOB_RECURSE SOURCES ./*.cpp ${PROJECT_SOURCE_DIR}/lite/src/pack_model/*.cpp) | |||
file(GLOB_RECURSE SOURCES src/**/*.cpp ${PROJECT_SOURCE_DIR}/lite/src/pack_model/*.cpp) | |||
add_executable(load_and_run ${SOURCES}) | |||
target_link_libraries(load_and_run lite_static) | |||
target_link_libraries(load_and_run megbrain) | |||
add_library(lar_object OBJECT ${SOURCES}) | |||
target_link_libraries(lar_object lite_static) | |||
target_link_libraries(lar_object megbrain) | |||
if(APPLE) | |||
target_link_libraries(load_and_run gflags) | |||
target_link_libraries(lar_object gflags) | |||
else() | |||
target_link_libraries(load_and_run gflags -Wl,--version-script=${MGE_VERSION_SCRIPT}) | |||
target_link_libraries(lar_object gflags -Wl,--version-script=${MGE_VERSION_SCRIPT}) | |||
endif() | |||
if(LITE_BUILD_WITH_MGE | |||
AND NOT WIN32 | |||
AND NOT APPLE) | |||
# FIXME: third_party cpp_redis does not support building with clang-cl
target_include_directories(lar_object PRIVATE ${CPP_REDIS_INCLUDES}) | |||
endif() | |||
add_executable(load_and_run src/main.cpp) | |||
target_link_libraries(load_and_run lar_object) | |||
if(LITE_BUILD_WITH_RKNPU) | |||
# rknn sdk 1.0.0 depends on libc++_shared; use the gold linker to remove the NEEDED entry so the symbol check passes
target_link_options(load_and_run PRIVATE "-fuse-ld=gold") | |||
endif() | |||
if(MGE_WITH_ROCM) | |||
if(LITE_BUILD_WITH_MGE AND MGE_WITH_ROCM) | |||
message(WARNING "MGE_WITH_ROCM is valid link to megdnn") | |||
# FIXME: hip obj can not find cpp obj only through lite_static | |||
target_link_libraries(load_and_run megdnn) | |||
endif() | |||
@@ -30,17 +42,11 @@ if(UNIX) | |||
endif() | |||
endif() | |||
if(LITE_BUILD_WITH_MGE | |||
AND NOT WIN32 | |||
AND NOT APPLE) | |||
# FIXME: third_party cpp_redis does not support building with clang-cl
target_include_directories(load_and_run PRIVATE ${CPP_REDIS_INCLUDES}) | |||
endif() | |||
install( | |||
TARGETS load_and_run | |||
EXPORT ${LITE_EXPORT_TARGETS} | |||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) | |||
if(BUILD_SHARED_LIBS) | |||
if(LITE_BUILD_WITH_MGE | |||
AND NOT WIN32 | |||
@@ -48,7 +54,7 @@ if(BUILD_SHARED_LIBS) | |||
# FIXME: third_party cpp_redis does not support building with clang-cl
list(APPEND SOURCES ${CPP_REDIS_SRCS}) | |||
endif() | |||
add_executable(load_and_run_depends_shared ${SOURCES}) | |||
add_executable(load_and_run_depends_shared ${SOURCES} src/main.cpp) | |||
target_link_libraries(load_and_run_depends_shared lite_shared) | |||
target_link_libraries(load_and_run_depends_shared gflags) | |||
target_link_libraries(load_and_run_depends_shared megengine) | |||
@@ -58,7 +64,8 @@ if(BUILD_SHARED_LIBS) | |||
target_link_options(load_and_run_depends_shared PRIVATE "-fuse-ld=gold") | |||
endif() | |||
if(MGE_WITH_ROCM) | |||
if(LITE_BUILD_WITH_MGE AND MGE_WITH_ROCM) | |||
message(WARNING "MGE_WITH_ROCM is valid link to megdnn") | |||
# FIXME: hip obj can not find cpp obj only through lite_static | |||
target_link_libraries(load_and_run_depends_shared megdnn) | |||
endif() | |||
@@ -30,6 +30,8 @@ enum class RunStage { | |||
AFTER_MODEL_RUNNING = 7, | |||
GLOBAL_OPTIMIZATION = 8, | |||
UPDATE_IO = 9, | |||
}; | |||
/*! | |||
* \brief: type of different model | |||
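The two stages added here (GLOBAL_OPTIMIZATION and UPDATE_IO) feed the same per-stage dispatch that every option implements in config_model_internel. A minimal sketch of that dispatch with a hypothetical option; only RunStage, RuntimeParam and the handler signature are taken from this diff, everything else is illustrative:

// sketch only: how an option can react to the new stages
template <typename ModelImpl>
void HypotheticalOption::config_model_internel(
        RuntimeParam& runtime_param, std::shared_ptr<ModelImpl> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        // settings that must be in place before the model file is parsed
    } else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
        // new stage: graph-level transforms (e.g. the global layout transform)
        // run after load and before the output spec is fixed
    } else if (runtime_param.stage == RunStage::UPDATE_IO) {
        // new stage: presumably for refreshing input/output bindings
    }
}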
@@ -1,7 +1,7 @@ | |||
#include <gflags/gflags.h> | |||
#include <string> | |||
#include "misc.h" | |||
#include "strategys/strategy.h" | |||
std::string simple_usage = R"( | |||
load_and_run: load_and_run <model_path> [options Flags...] | |||
@@ -29,6 +29,8 @@ More details using "--help" to get!! | |||
)"; | |||
int main(int argc, char** argv) { | |||
mgb::set_log_level(mgb::LogLevel::INFO); | |||
lite::set_log_level(LiteLogLevel::INFO); | |||
std::string usage = "load_and_run <model_path> [options Flags...]"; | |||
if (argc < 2) { | |||
printf("usage: %s\n", simple_usage.c_str()); | |||
@@ -8,17 +8,17 @@ DECLARE_bool(share_param_mem); | |||
using namespace lar; | |||
ModelLite::ModelLite(const std::string& path) : model_path(path) { | |||
LITE_WARN("creat lite model use CPU as default comp node"); | |||
LITE_LOG("creat lite model use CPU as default comp node"); | |||
}; | |||
void ModelLite::load_model() { | |||
m_network = std::make_shared<lite::Network>(config, IO); | |||
if (enable_layout_transform) { | |||
LITE_WARN("enable layout transform while load model for lite"); | |||
LITE_LOG("enable layout transform while load model for lite"); | |||
lite::Runtime::enable_global_layout_transform(m_network); | |||
} | |||
if (share_model_mem) { | |||
//! WARNING: it may not be right to share param memory for this model
LITE_WARN("enable share model memory"); | |||
LITE_LOG("enable share model memory"); | |||
FILE* fin = fopen(model_path.c_str(), "rb"); | |||
LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno)); | |||
@@ -19,27 +19,27 @@ void XPUDeviceOption::config_model_internel<ModelLite>( | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
if ((enable_cpu) || (enable_cpu_default) || (enable_multithread) || | |||
(enable_multithread_default)) { | |||
LITE_WARN("using cpu device\n"); | |||
LITE_LOG("using cpu device\n"); | |||
model->get_config().device_type = LiteDeviceType::LITE_CPU; | |||
} | |||
#if LITE_WITH_CUDA | |||
if (enable_cuda) { | |||
LITE_WARN("using cuda device\n"); | |||
LITE_LOG("using cuda device\n"); | |||
model->get_config().device_type = LiteDeviceType::LITE_CUDA; | |||
} | |||
#endif | |||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
auto&& network = model->get_lite_network(); | |||
if (enable_cpu_default) { | |||
LITE_WARN("using cpu default device\n"); | |||
LITE_LOG("using cpu default device\n"); | |||
lite::Runtime::set_cpu_inplace_mode(network); | |||
} | |||
if (enable_multithread) { | |||
LITE_WARN("using multithread device\n"); | |||
LITE_LOG("using multithread device\n"); | |||
lite::Runtime::set_cpu_threads_number(network, thread_num); | |||
} | |||
if (enable_multithread_default) { | |||
LITE_WARN("using multithread default device\n"); | |||
LITE_LOG("using multithread default device\n"); | |||
lite::Runtime::set_cpu_inplace_mode(network); | |||
lite::Runtime::set_cpu_threads_number(network, thread_num); | |||
} | |||
@@ -48,7 +48,7 @@ void XPUDeviceOption::config_model_internel<ModelLite>( | |||
for (auto id : core_ids) { | |||
core_str += std::to_string(id) + ","; | |||
} | |||
LITE_WARN("multi thread core ids: %s\n", core_str.c_str()); | |||
LITE_LOG("multi thread core ids: %s\n", core_str.c_str()); | |||
lite::ThreadAffinityCallback affinity_callback = [&](size_t thread_id) { | |||
mgb::sys::set_cpu_affinity({core_ids[thread_id]}); | |||
}; | |||
@@ -62,14 +62,14 @@ void XPUDeviceOption::config_model_internel<ModelMdl>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
if (enable_cpu) { | |||
mgb_log_warn("using cpu device\n"); | |||
mgb_log("using cpu device\n"); | |||
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { | |||
loc.type = mgb::CompNode::DeviceType::CPU; | |||
}; | |||
} | |||
#if LITE_WITH_CUDA | |||
if (enable_cuda) { | |||
mgb_log_warn("using cuda device\n"); | |||
mgb_log("using cuda device\n"); | |||
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { | |||
if (loc.type == mgb::CompNode::DeviceType::UNSPEC) { | |||
loc.type = mgb::CompNode::DeviceType::CUDA; | |||
@@ -79,14 +79,14 @@ void XPUDeviceOption::config_model_internel<ModelMdl>( | |||
} | |||
#endif | |||
if (enable_cpu_default) { | |||
mgb_log_warn("using cpu default device\n"); | |||
mgb_log("using cpu default device\n"); | |||
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { | |||
loc.type = mgb::CompNode::DeviceType::CPU; | |||
loc.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT; | |||
}; | |||
} | |||
if (enable_multithread) { | |||
mgb_log_warn("using multithread device\n"); | |||
mgb_log("using multithread device\n"); | |||
model->get_mdl_config().comp_node_mapper = | |||
[&](mgb::CompNode::Locator& loc) { | |||
loc.type = mgb::CompNode::DeviceType::MULTITHREAD; | |||
@@ -95,7 +95,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>( | |||
}; | |||
} | |||
if (enable_multithread_default) { | |||
mgb_log_warn("using multithread default device\n"); | |||
mgb_log("using multithread default device\n"); | |||
model->get_mdl_config().comp_node_mapper = | |||
[&](mgb::CompNode::Locator& loc) { | |||
loc.type = mgb::CompNode::DeviceType::MULTITHREAD; | |||
@@ -108,7 +108,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>( | |||
for (auto id : core_ids) { | |||
core_str += std::to_string(id) + ","; | |||
} | |||
mgb_log_warn("set multi thread core ids:%s\n", core_str.c_str()); | |||
mgb_log("set multi thread core ids:%s\n", core_str.c_str()); | |||
auto affinity_callback = [&](size_t thread_id) { | |||
mgb::sys::set_cpu_affinity({core_ids[thread_id]}); | |||
}; | |||
@@ -122,7 +122,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>( | |||
} | |||
} // namespace lar | |||
XPUDeviceOption::XPUDeviceOption() { | |||
void XPUDeviceOption::update() { | |||
m_option_name = "xpu_device"; | |||
enable_cpu = FLAGS_cpu; | |||
#if LITE_WITH_CUDA | |||
@@ -198,6 +198,7 @@ bool XPUDeviceOption::is_valid() { | |||
std::shared_ptr<OptionBase> XPUDeviceOption::create_option() { | |||
static std::shared_ptr<lar::XPUDeviceOption> option(new XPUDeviceOption); | |||
if (XPUDeviceOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<lar::OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
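Every option below gets the same treatment: the flag-reading constructor becomes update(), and create_option() refreshes the cached singleton before handing it out, presumably so repeated calls can pick up the current flag values. A generic sketch of the resulting pattern (MyOption is an illustrative name):

std::shared_ptr<OptionBase> MyOption::create_option() {
    // one lazily constructed instance, shared by every call
    static std::shared_ptr<MyOption> option(new MyOption);
    if (MyOption::is_valid()) {   // any relevant command-line flag set?
        option->update();         // re-read FLAGS_* into member state
        return std::static_pointer_cast<OptionBase>(option);
    }
    return nullptr;               // this option is not requested for this run
}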
@@ -24,8 +24,10 @@ public: | |||
OptionValMap* get_option() override { return &m_option; } | |||
void update() override; | |||
private: | |||
XPUDeviceOption(); | |||
XPUDeviceOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
bool enable_cpu; | |||
@@ -25,6 +25,7 @@ void COprLibOption::config_model_internel( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
if (!lib_path.empty()) { | |||
mgb_log("load external C opr lib from %s\n", lib_path.c_str()); | |||
load_lib(); | |||
} | |||
if (c_opr_args.is_run_c_opr_with_param) { | |||
@@ -176,7 +177,7 @@ void COprLibOption::set_Copr_IO(std::shared_ptr<ModelBase> model_ptr) { | |||
config_extern_c_opr_dynamic_param(model->get_async_func(), c_opr_param); | |||
} | |||
COprLibOption::COprLibOption() { | |||
void COprLibOption::update() { | |||
m_option_name = "c_opr_lib"; | |||
lib_path = FLAGS_c_opr_lib; | |||
c_opr_args.is_run_c_opr = !lib_path.empty(); | |||
@@ -191,6 +192,7 @@ bool COprLibOption::is_valid() { | |||
std::shared_ptr<OptionBase> COprLibOption::create_option() { | |||
static std::shared_ptr<COprLibOption> option(new COprLibOption); | |||
if (COprLibOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -32,8 +32,10 @@ public: | |||
std::string option_name() const override { return m_option_name; }; | |||
void update() override; | |||
private: | |||
COprLibOption(); | |||
COprLibOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
@@ -25,10 +25,10 @@ void FastRunOption::config_model_internel<ModelLite>( | |||
uint32_t strategy = 0; | |||
#if MGB_ENABLE_FASTRUN | |||
if (enable_full_run) { | |||
LITE_WARN("enable full-run strategy for algo profile"); | |||
LITE_LOG("enable full-run strategy for algo profile"); | |||
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) | strategy; | |||
} else if (enable_fast_run) { | |||
LITE_WARN("enable fast-run strategy for algo profile"); | |||
LITE_LOG("enable fast-run strategy for algo profile"); | |||
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) | | |||
static_cast<uint32_t>(Strategy::LITE_ALGO_OPTIMIZED) | strategy; | |||
} else { | |||
@@ -38,7 +38,7 @@ void FastRunOption::config_model_internel<ModelLite>( | |||
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy; | |||
#endif | |||
if (batch_binary_equal || enable_reproducible) { | |||
LITE_WARN("enable reproducible strategy for algo profile"); | |||
LITE_LOG("enable reproducible strategy for algo profile"); | |||
if (batch_binary_equal) | |||
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_REPRODUCIBLE) | | |||
strategy; | |||
@@ -81,10 +81,10 @@ void FastRunOption::config_model_internel<ModelMdl>( | |||
auto strategy = static_cast<Strategy>(0); | |||
#if MGB_ENABLE_FASTRUN | |||
if (enable_full_run) { | |||
mgb_log_warn("enable full-run strategy for algo profile"); | |||
mgb_log("enable full-run strategy for algo profile"); | |||
strategy = Strategy::PROFILE | strategy; | |||
} else if (enable_fast_run) { | |||
mgb_log_warn("enable fast-run strategy for algo profile"); | |||
mgb_log("enable fast-run strategy for algo profile"); | |||
strategy = Strategy::PROFILE | Strategy::OPTIMIZED | strategy; | |||
} else { | |||
strategy = Strategy::HEURISTIC | strategy; | |||
@@ -93,20 +93,20 @@ void FastRunOption::config_model_internel<ModelMdl>( | |||
strategy = Strategy::HEURISTIC | strategy; | |||
#endif | |||
if (batch_binary_equal || enable_reproducible) { | |||
mgb_log_warn("enable reproducible strategy for algo profile"); | |||
mgb_log("enable reproducible strategy for algo profile"); | |||
strategy = Strategy::REPRODUCIBLE | strategy; | |||
} | |||
model->set_mdl_strategy(strategy); | |||
//! set binary_equal_between_batch and shared_batch_size | |||
if (batch_binary_equal) { | |||
mgb_log_warn("enable batch binary equal"); | |||
mgb_log("enable batch binary equal"); | |||
model->get_mdl_config() | |||
.comp_graph->options() | |||
.fast_run_config.binary_equal_between_batch = true; | |||
} | |||
if (share_batch_size > 0) { | |||
mgb_log_warn("set shared shared batch"); | |||
mgb_log("set shared shared batch"); | |||
model->get_mdl_config() | |||
.comp_graph->options() | |||
.fast_run_config.shared_batch_size = share_batch_size; | |||
@@ -145,7 +145,7 @@ void FastRunOption::config_model_internel<ModelMdl>( | |||
using namespace lar; | |||
bool FastRunOption::m_valid; | |||
FastRunOption::FastRunOption() { | |||
void FastRunOption::update() { | |||
m_option_name = "fastrun"; | |||
#if MGB_ENABLE_FASTRUN | |||
enable_fast_run = FLAGS_fast_run; | |||
@@ -207,6 +207,7 @@ bool FastRunOption::is_valid() { | |||
std::shared_ptr<OptionBase> FastRunOption::create_option() { | |||
static std::shared_ptr<FastRunOption> option(new FastRunOption); | |||
if (FastRunOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -250,7 +251,7 @@ DEFINE_bool( | |||
"https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/" | |||
"index.html#reproducibility" | |||
"for more details."); | |||
DEFINE_uint32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun"); | |||
DEFINE_int32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun"); | |||
DEFINE_string(fast_run_algo_policy, "", "fast-run cache path."); | |||
REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option); |
@@ -10,7 +10,7 @@ DECLARE_bool(full_run); | |||
#endif | |||
DECLARE_bool(reproducible); | |||
DECLARE_bool(binary_equal_between_batch); | |||
DECLARE_uint32(fast_run_shared_batch_size); | |||
DECLARE_int32(fast_run_shared_batch_size); | |||
DECLARE_string(fast_run_algo_policy); | |||
namespace lar { | |||
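The gflags declaration and definition for fast_run_shared_batch_size have to change type in lockstep: DECLARE_int32 declares the same FLAGS_ symbol that DEFINE_int32 defines, and a header still saying DECLARE_uint32 would declare it with a conflicting type. A minimal sketch of the paired usage; the helper function and the file placement comments are illustrative:

#include <gflags/gflags.h>

// definition (normally in one .cpp), now signed to match the new declaration
DEFINE_int32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun");
// declaration other translation units would include (normally in the header):
// DECLARE_int32(fast_run_shared_batch_size);

int shared_batch_size_or_zero() {
    // consumers can keep treating non-positive values as "not set"
    return FLAGS_fast_run_shared_batch_size > 0 ? FLAGS_fast_run_shared_batch_size : 0;
}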
@@ -33,8 +33,10 @@ public: | |||
OptionValMap* get_option() override { return &m_option; } | |||
void update() override; | |||
private: | |||
FastRunOption(); | |||
FastRunOption() = default; | |||
//! config template for different model | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {} | |||
@@ -93,11 +93,11 @@ void IOdumpOption::config_model_internel<ModelLite>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
if (enable_io_dump) { | |||
LITE_WARN("enable text io dump"); | |||
LITE_LOG("enable text io dump"); | |||
lite::Runtime::enable_io_txt_dump(model->get_lite_network(), dump_path); | |||
} | |||
if (enable_bin_io_dump) { | |||
LITE_WARN("enable binary io dump"); | |||
LITE_LOG("enable binary io dump"); | |||
lite::Runtime::enable_io_bin_dump(model->get_lite_network(), dump_path); | |||
} | |||
//! FIXME: complete this when the corresponding API is added to lite
@@ -108,7 +108,7 @@ void IOdumpOption::config_model_internel<ModelLite>( | |||
LITE_THROW("lite model don't support the binary output dump"); | |||
} | |||
if (enable_copy_to_host) { | |||
LITE_WARN("lite model set copy to host defaultly"); | |||
LITE_LOG("lite model set copy to host defaultly"); | |||
} | |||
} | |||
} | |||
@@ -118,7 +118,7 @@ void IOdumpOption::config_model_internel<ModelMdl>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
if (enable_io_dump) { | |||
mgb_log_warn("enable text io dump"); | |||
mgb_log("enable text io dump"); | |||
auto iodump = std::make_unique<mgb::TextOprIODump>( | |||
model->get_mdl_config().comp_graph.get(), dump_path.c_str()); | |||
iodump->print_addr(false); | |||
@@ -126,7 +126,7 @@ void IOdumpOption::config_model_internel<ModelMdl>( | |||
} | |||
if (enable_io_dump_stdout) { | |||
mgb_log_warn("enable text io dump to stdout"); | |||
mgb_log("enable text io dump to stdout"); | |||
std::shared_ptr<FILE> std_out(stdout, [](FILE*) {}); | |||
auto iodump = std::make_unique<mgb::TextOprIODump>( | |||
model->get_mdl_config().comp_graph.get(), std_out); | |||
@@ -135,7 +135,7 @@ void IOdumpOption::config_model_internel<ModelMdl>( | |||
} | |||
if (enable_io_dump_stderr) { | |||
mgb_log_warn("enable text io dump to stderr"); | |||
mgb_log("enable text io dump to stderr"); | |||
std::shared_ptr<FILE> std_err(stderr, [](FILE*) {}); | |||
auto iodump = std::make_unique<mgb::TextOprIODump>( | |||
model->get_mdl_config().comp_graph.get(), std_err); | |||
@@ -144,14 +144,14 @@ void IOdumpOption::config_model_internel<ModelMdl>( | |||
} | |||
if (enable_bin_io_dump) { | |||
mgb_log_warn("enable binary io dump"); | |||
mgb_log("enable binary io dump"); | |||
auto iodump = std::make_unique<mgb::BinaryOprIODump>( | |||
model->get_mdl_config().comp_graph.get(), dump_path); | |||
io_dumper = std::move(iodump); | |||
} | |||
if (enable_bin_out_dump) { | |||
mgb_log_warn("enable binary output dump"); | |||
mgb_log("enable binary output dump"); | |||
out_dumper = std::make_unique<OutputDumper>(dump_path.c_str()); | |||
} | |||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
@@ -190,7 +190,7 @@ void IOdumpOption::config_model_internel<ModelMdl>( | |||
////////////////////// Input options //////////////////////// | |||
using namespace lar; | |||
InputOption::InputOption() { | |||
void InputOption::update() { | |||
m_option_name = "input"; | |||
size_t start = 0; | |||
auto end = FLAGS_input.find(";", start); | |||
@@ -204,9 +204,10 @@ InputOption::InputOption() { | |||
} | |||
std::shared_ptr<lar::OptionBase> lar::InputOption::create_option() { | |||
static std::shared_ptr<InputOption> m_option(new InputOption); | |||
static std::shared_ptr<InputOption> option(new InputOption); | |||
if (InputOption::is_valid()) { | |||
return std::static_pointer_cast<OptionBase>(m_option); | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
} | |||
@@ -219,7 +220,7 @@ void InputOption::config_model( | |||
////////////////////// OprIOdump options //////////////////////// | |||
IOdumpOption::IOdumpOption() { | |||
void IOdumpOption::update() { | |||
m_option_name = "iodump"; | |||
size_t valid_flag = 0; | |||
if (!FLAGS_io_dump.empty()) { | |||
@@ -268,6 +269,7 @@ bool IOdumpOption::is_valid() { | |||
std::shared_ptr<OptionBase> IOdumpOption::create_option() { | |||
static std::shared_ptr<IOdumpOption> option(new IOdumpOption); | |||
if (IOdumpOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -30,8 +30,10 @@ public: | |||
//! interface implement from OptionBase | |||
std::string option_name() const override { return m_option_name; }; | |||
void update() override; | |||
private: | |||
InputOption(); | |||
InputOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
@@ -50,8 +52,10 @@ public: | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
std::string option_name() const override { return m_option_name; }; | |||
void update() override; | |||
private: | |||
IOdumpOption(); | |||
IOdumpOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
@@ -11,7 +11,7 @@ void LayoutOption::config_model_internel<ModelLite>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
#define ENABLE_LAYOUT(layout) \ | |||
LITE_WARN("enable " #layout " optimization"); \ | |||
LITE_LOG("enable " #layout " optimization"); \ | |||
model->get_config().options.enable_##layout = true; \ | |||
break; | |||
@@ -51,7 +51,7 @@ void lar::LayoutOption::config_model_internel<ModelMdl>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
#define ENABLE_LAYOUT(layout) \ | |||
mgb_log_warn("enable " #layout " optimization"); \ | |||
mgb_log("enable " #layout " optimization"); \ | |||
model->get_mdl_config().comp_graph->options().graph_opt.enable_##layout(); \ | |||
break; | |||
@@ -91,7 +91,7 @@ void lar::LayoutOption::config_model_internel<ModelMdl>( | |||
using namespace lar; | |||
bool LayoutOption::m_valid; | |||
LayoutOption::LayoutOption() { | |||
void LayoutOption::update() { | |||
m_option_name = "layout"; | |||
m_option_flag = static_cast<OptLayoutType>(0); | |||
m_option = { | |||
@@ -157,6 +157,7 @@ bool LayoutOption::is_valid() { | |||
std::shared_ptr<OptionBase> LayoutOption::create_option() { | |||
static std::shared_ptr<LayoutOption> option(new LayoutOption); | |||
if (LayoutOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -166,16 +167,20 @@ std::shared_ptr<OptionBase> LayoutOption::create_option() { | |||
void LayoutOption::config_model( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
size_t valid_flag = 0; | |||
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw4"])->get_value()) { | |||
if (FLAGS_enable_nchw4 || | |||
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw4"])->get_value()) { | |||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW4); | |||
} | |||
if (std::static_pointer_cast<lar::Bool>(m_option["enable_chwn4"])->get_value()) { | |||
if (FLAGS_enable_chwn4 || | |||
std::static_pointer_cast<lar::Bool>(m_option["enable_chwn4"])->get_value()) { | |||
valid_flag |= static_cast<size_t>(OptLayoutType::CHWN4); | |||
} | |||
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44"])->get_value()) { | |||
if (FLAGS_enable_nchw44 || | |||
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44"])->get_value()) { | |||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44); | |||
} | |||
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw88"])->get_value()) { | |||
if (FLAGS_enable_nchw88 || | |||
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw88"])->get_value()) { | |||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW88); | |||
} | |||
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw32"])->get_value()) { | |||
@@ -37,9 +37,11 @@ public: | |||
OptionValMap* get_option() override { return &m_option; } | |||
void update() override; | |||
private: | |||
//! Constructor | |||
LayoutOption(); | |||
LayoutOption() = default; | |||
//! configuration for different model implement | |||
template <typename ModelImpl> | |||
@@ -11,7 +11,7 @@ void GoptLayoutOption::config_model_internel<ModelLite>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
if (m_layout_transform) { | |||
LITE_WARN("using global layout transform optimization\n"); | |||
LITE_LOG("using global layout transform optimization\n"); | |||
if (m_layout_transform_target == | |||
mgb::gopt::GraphTuningOptions::Target::CPU) { | |||
model->get_config().device_type = LiteDeviceType::LITE_CPU; | |||
@@ -98,7 +98,7 @@ void GoptLayoutOption::config_model_internel<ModelMdl>( | |||
} | |||
} else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) { | |||
if (m_layout_transform) { | |||
mgb_log_warn("using global layout transform optimization\n"); | |||
mgb_log("using global layout transform optimization\n"); | |||
auto&& load_result = model->get_mdl_load_result(); | |||
load_result.output_var_list = mgb::gopt::layout_transform( | |||
load_result.output_var_list, m_layout_transform_target); | |||
@@ -150,7 +150,7 @@ void GoptLayoutOption::config_model_internel<ModelMdl>( | |||
using namespace lar; | |||
bool GoptLayoutOption::m_valid; | |||
GoptLayoutOption::GoptLayoutOption() { | |||
void GoptLayoutOption::update() { | |||
m_option_name = "gopt_layout"; | |||
if (FLAGS_layout_transform != "cpu" | |||
#if LITE_WITH_CUDA | |||
@@ -216,6 +216,7 @@ bool GoptLayoutOption::is_valid() { | |||
std::shared_ptr<OptionBase> GoptLayoutOption::create_option() { | |||
static std::shared_ptr<GoptLayoutOption> option(new GoptLayoutOption); | |||
if (GoptLayoutOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -28,8 +28,10 @@ public: | |||
OptionValMap* get_option() override { return &m_option; } | |||
void update() override; | |||
private: | |||
GoptLayoutOption(); | |||
GoptLayoutOption() = default; | |||
//! config template for different model | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {} | |||
@@ -24,7 +24,7 @@ void PackModelOption::config_model_internel( | |||
using namespace lar; | |||
////////////////////// PackModel options //////////////////////// | |||
PackModelOption::PackModelOption() { | |||
void PackModelOption::update() { | |||
m_option_name = "pack_model"; | |||
if (!FLAGS_packed_model_dump.empty()) | |||
packed_model_dump = FLAGS_packed_model_dump; | |||
@@ -45,6 +45,7 @@ bool PackModelOption::is_valid() { | |||
std::shared_ptr<OptionBase> PackModelOption::create_option() { | |||
static std::shared_ptr<PackModelOption> option(new PackModelOption); | |||
if (PackModelOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -19,8 +19,10 @@ public: | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
std::string option_name() const override { return m_option_name; } | |||
void update() override; | |||
private: | |||
PackModelOption(); | |||
PackModelOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>); | |||
@@ -15,7 +15,7 @@ void FusePreprocessOption::config_model_internel<ModelLite>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
if (enable_fuse_preprocess) { | |||
LITE_WARN("enable fuse-preprocess optimization"); | |||
LITE_LOG("enable fuse-preprocess optimization"); | |||
model->get_config().options.fuse_preprocess = true; | |||
} | |||
} | |||
@@ -27,7 +27,7 @@ void FusePreprocessOption::config_model_internel<ModelMdl>( | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
if (enable_fuse_preprocess) { | |||
mgb_log_warn("enable fuse-preprocess optimization"); | |||
mgb_log("enable fuse-preprocess optimization"); | |||
graph_option.graph_opt.enable_fuse_preprocess(); | |||
} | |||
} | |||
@@ -35,7 +35,7 @@ void FusePreprocessOption::config_model_internel<ModelMdl>( | |||
} // namespace lar | |||
using namespace lar; | |||
bool FusePreprocessOption::m_valid; | |||
FusePreprocessOption::FusePreprocessOption() { | |||
void FusePreprocessOption::update() { | |||
m_option_name = "fuse_preprocess"; | |||
enable_fuse_preprocess = FLAGS_enable_fuse_preprocess; | |||
m_option = {{"enable_fuse_preprocess", lar::Bool::make(false)}}; | |||
@@ -51,6 +51,7 @@ bool FusePreprocessOption::is_valid() { | |||
std::shared_ptr<OptionBase> FusePreprocessOption::create_option() { | |||
static std::shared_ptr<FusePreprocessOption> option(new FusePreprocessOption); | |||
if (FusePreprocessOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -73,7 +74,7 @@ void WeightPreprocessOption::config_model_internel<ModelLite>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
if (weight_preprocess) { | |||
LITE_WARN("enable weight-preprocess optimization"); | |||
LITE_LOG("enable weight-preprocess optimization"); | |||
model->get_config().options.weight_preprocess = true; | |||
} | |||
} | |||
@@ -85,14 +86,14 @@ void WeightPreprocessOption::config_model_internel<ModelMdl>( | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
if (weight_preprocess) { | |||
mgb_log_warn("enable weight-preprocess optimization"); | |||
mgb_log("enable weight-preprocess optimization"); | |||
graph_option.graph_opt.enable_weight_preprocess(); | |||
} | |||
} | |||
} | |||
} // namespace lar | |||
WeightPreprocessOption::WeightPreprocessOption() { | |||
void WeightPreprocessOption::update() { | |||
m_option_name = "weight_preprocess"; | |||
weight_preprocess = FLAGS_weight_preprocess; | |||
m_option = {{"weight_preprocess", lar::Bool::make(false)}}; | |||
@@ -108,6 +109,7 @@ bool WeightPreprocessOption::is_valid() { | |||
std::shared_ptr<OptionBase> WeightPreprocessOption::create_option() { | |||
static std::shared_ptr<WeightPreprocessOption> option(new WeightPreprocessOption); | |||
if (WeightPreprocessOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -142,14 +144,14 @@ void FuseConvBiasNonlinearOption::config_model_internel<ModelMdl>( | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
if (enable_fuse_conv_bias_nonlinearity) { | |||
mgb_log_warn("enable fuse conv+bias+nonlinearity optimization"); | |||
mgb_log("enable fuse conv+bias+nonlinearity optimization"); | |||
graph_option.graph_opt.enable_fuse_conv_bias_nonlinearity(); | |||
} | |||
} | |||
} | |||
} // namespace lar | |||
FuseConvBiasNonlinearOption::FuseConvBiasNonlinearOption() { | |||
void FuseConvBiasNonlinearOption::update() { | |||
m_option_name = "fuse_conv_bias_nonlinearity"; | |||
enable_fuse_conv_bias_nonlinearity = FLAGS_enable_fuse_conv_bias_nonlinearity; | |||
m_option = {{"enable_fuse_conv_bias_nonlinearity", lar::Bool::make(false)}}; | |||
@@ -166,6 +168,7 @@ std::shared_ptr<OptionBase> FuseConvBiasNonlinearOption::create_option() { | |||
static std::shared_ptr<FuseConvBiasNonlinearOption> option( | |||
new FuseConvBiasNonlinearOption); | |||
if (FuseConvBiasNonlinearOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -203,14 +206,14 @@ void FuseConvBiasElemwiseAddOption::config_model_internel<ModelMdl>( | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
if (enable_fuse_conv_bias_with_z) { | |||
mgb_log_warn("enable fuse conv+bias+z optimization"); | |||
mgb_log("enable fuse conv+bias+z optimization"); | |||
graph_option.graph_opt.enable_fuse_conv_bias_with_z(); | |||
} | |||
} | |||
} | |||
} // namespace lar | |||
FuseConvBiasElemwiseAddOption::FuseConvBiasElemwiseAddOption() { | |||
void FuseConvBiasElemwiseAddOption::update() { | |||
m_option_name = "fuse_conv_bias_with_z"; | |||
enable_fuse_conv_bias_with_z = FLAGS_enable_fuse_conv_bias_with_z; | |||
m_option = {{"enable_fuse_conv_bias_with_z", lar::Bool::make(false)}}; | |||
@@ -227,6 +230,7 @@ std::shared_ptr<OptionBase> FuseConvBiasElemwiseAddOption::create_option() { | |||
static std::shared_ptr<FuseConvBiasElemwiseAddOption> option( | |||
new FuseConvBiasElemwiseAddOption); | |||
if (FuseConvBiasElemwiseAddOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -250,26 +254,26 @@ void GraphRecordOption::config_model_internel<ModelLite>( | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
auto&& config_option = model->get_config().options; | |||
if (const_shape) { | |||
LITE_WARN("enable const var shape"); | |||
LITE_LOG("enable const var shape"); | |||
config_option.const_shape = true; | |||
} | |||
if (fake_first) { | |||
LITE_WARN("enable fake-first optimization"); | |||
LITE_LOG("enable fake-first optimization"); | |||
config_option.fake_next_exec = true; | |||
} | |||
if (no_sanity_check) { | |||
LITE_WARN("disable var sanity check optimization"); | |||
LITE_LOG("disable var sanity check optimization"); | |||
config_option.var_sanity_check_first_run = false; | |||
} | |||
if (m_record_comp_seq == 1) { | |||
LITE_WARN("set record_comp_seq_level to 1"); | |||
LITE_LOG("set record_comp_seq_level to 1"); | |||
} | |||
if (m_record_comp_seq == 2) { | |||
mgb_assert( | |||
no_sanity_check, | |||
"--no-sanity-check should be set before " | |||
"--record-comp-seq2"); | |||
LITE_WARN("set record_comp_seq_level to 2"); | |||
LITE_LOG("set record_comp_seq_level to 2"); | |||
} | |||
config_option.comp_node_seq_record_level = m_record_comp_seq; | |||
} | |||
@@ -281,33 +285,33 @@ void GraphRecordOption::config_model_internel<ModelMdl>( | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
if (const_shape) { | |||
mgb_log_warn("enable const var shape"); | |||
mgb_log("enable const var shape"); | |||
model->get_mdl_config().const_var_shape = true; | |||
} | |||
if (fake_first) { | |||
mgb_log_warn("enable fake-first optimization"); | |||
mgb_log("enable fake-first optimization"); | |||
graph_option.fake_next_exec = true; | |||
} | |||
if (no_sanity_check) { | |||
mgb_log_warn("disable var sanity check optimization"); | |||
mgb_log("disable var sanity check optimization"); | |||
graph_option.var_sanity_check_first_run = false; | |||
} | |||
if (m_record_comp_seq == 1) { | |||
mgb_log_warn("set record_comp_seq_level to 1"); | |||
mgb_log("set record_comp_seq_level to 1"); | |||
} | |||
if (m_record_comp_seq == 2) { | |||
mgb_assert( | |||
no_sanity_check && !fake_first, | |||
"--no-sanity-check should be set before " | |||
"--record-comp-seq2 and --fake-first should not be set"); | |||
mgb_log_warn("set record_comp_seq_level to 2"); | |||
mgb_log("set record_comp_seq_level to 2"); | |||
} | |||
graph_option.comp_node_seq_record_level = m_record_comp_seq; | |||
} | |||
} | |||
} // namespace lar | |||
GraphRecordOption::GraphRecordOption() { | |||
void GraphRecordOption::update() { | |||
m_option_name = "graph_record"; | |||
m_record_comp_seq = 0; | |||
const_shape = FLAGS_const_shape; | |||
@@ -350,6 +354,7 @@ bool GraphRecordOption::is_valid() { | |||
std::shared_ptr<OptionBase> GraphRecordOption::create_option() { | |||
static std::shared_ptr<GraphRecordOption> option(new GraphRecordOption); | |||
if (GraphRecordOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -387,7 +392,7 @@ void MemoryOptimizeOption::config_model_internel<ModelLite>( | |||
} | |||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
if (workspace_limit != SIZE_MAX) { | |||
LITE_WARN("set workspace limit to %ld", workspace_limit); | |||
LITE_LOG("set workspace limit to %ld", workspace_limit); | |||
lite::Runtime::set_network_algo_workspace_limit( | |||
model->get_lite_network(), workspace_limit); | |||
} | |||
@@ -400,12 +405,12 @@ void MemoryOptimizeOption::config_model_internel<ModelMdl>( | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
if (disable_mem_opt) { | |||
mgb_log_warn("disable memory optimization"); | |||
mgb_log("disable memory optimization"); | |||
graph_option.seq_opt.enable_mem_plan_opt = false; | |||
graph_option.seq_opt.enable_mem_reuse_alloc = false; | |||
} | |||
if (workspace_limit < SIZE_MAX) { | |||
mgb_log_warn("set workspace limit to %ld", workspace_limit); | |||
mgb_log("set workspace limit to %ld", workspace_limit); | |||
auto&& output_spec = model->get_output_spec(); | |||
mgb::SymbolVarArray vars; | |||
for (auto i : output_spec) { | |||
@@ -417,7 +422,7 @@ void MemoryOptimizeOption::config_model_internel<ModelMdl>( | |||
} | |||
} // namespace lar | |||
MemoryOptimizeOption::MemoryOptimizeOption() { | |||
void MemoryOptimizeOption::update() { | |||
m_option_name = "memory_optimize"; | |||
disable_mem_opt = FLAGS_disable_mem_opt; | |||
workspace_limit = FLAGS_workspace_limit; | |||
@@ -432,6 +437,7 @@ bool MemoryOptimizeOption::is_valid() { | |||
std::shared_ptr<OptionBase> MemoryOptimizeOption::create_option() { | |||
static std::shared_ptr<MemoryOptimizeOption> option(new MemoryOptimizeOption); | |||
if (MemoryOptimizeOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -451,7 +457,7 @@ void JITOption::config_model_internel<ModelLite>( | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
auto&& config_option = model->get_config().options; | |||
if (enable_jit) { | |||
LITE_WARN("enable JIT (level 1)"); | |||
LITE_LOG("enable JIT (level 1)"); | |||
config_option.jit_level = 1; | |||
} | |||
} | |||
@@ -463,13 +469,13 @@ void JITOption::config_model_internel<ModelMdl>( | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
if (enable_jit) { | |||
mgb_log_warn("enable JIT (level 1)"); | |||
mgb_log("enable JIT (level 1)"); | |||
graph_option.graph_opt.jit = 1; | |||
} | |||
} | |||
} | |||
} // namespace lar | |||
JITOption::JITOption() { | |||
void JITOption::update() { | |||
m_option_name = "JIT"; | |||
enable_jit = FLAGS_enable_jit; | |||
} | |||
@@ -482,6 +488,7 @@ bool JITOption::is_valid() { | |||
std::shared_ptr<OptionBase> JITOption::create_option() { | |||
static std::shared_ptr<JITOption> option(new JITOption); | |||
if (JITOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -500,12 +507,12 @@ void TensorRTOption::config_model_internel<ModelLite>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
if (!tensorrt_cache.empty()) { | |||
LITE_WARN("set tensorrt cache as %s", tensorrt_cache.c_str()); | |||
LITE_LOG("set tensorrt cache as %s", tensorrt_cache.c_str()); | |||
lite::set_tensor_rt_cache(tensorrt_cache); | |||
} | |||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
if (enable_tensorrt) { | |||
LITE_WARN("enable TensorRT"); | |||
LITE_LOG("enable TensorRT"); | |||
lite::Runtime::use_tensorrt(model->get_lite_network()); | |||
} | |||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||
@@ -521,11 +528,11 @@ void TensorRTOption::config_model_internel<ModelMdl>( | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
if (enable_tensorrt) { | |||
mgb_log_warn("using tensorRT"); | |||
mgb_log("using tensorRT"); | |||
graph_option.graph_opt.tensorrt = true; | |||
} | |||
if (!tensorrt_cache.empty()) { | |||
mgb_log_warn("use tensorrt cache: %s", tensorrt_cache.c_str()); | |||
mgb_log("use tensorrt cache: %s", tensorrt_cache.c_str()); | |||
mgb::TensorRTEngineCache::enable_engine_cache(true); | |||
mgb::TensorRTEngineCache::set_impl( | |||
std::make_shared<mgb::TensorRTEngineCacheIO>( | |||
@@ -541,7 +548,7 @@ void TensorRTOption::config_model_internel<ModelMdl>( | |||
} | |||
} // namespace lar | |||
TensorRTOption::TensorRTOption() { | |||
void TensorRTOption::update() { | |||
m_option_name = "tensorRT"; | |||
enable_tensorrt = FLAGS_tensorrt; | |||
tensorrt_cache = FLAGS_tensorrt_cache; | |||
@@ -556,6 +563,7 @@ bool TensorRTOption::is_valid() { | |||
std::shared_ptr<OptionBase> TensorRTOption::create_option() { | |||
static std::shared_ptr<TensorRTOption> option(new TensorRTOption); | |||
if (TensorRTOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -39,8 +39,10 @@ public: | |||
OptionValMap* get_option() override { return &m_option; } | |||
void update() override; | |||
private: | |||
FusePreprocessOption(); | |||
FusePreprocessOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
@@ -65,8 +67,10 @@ public: | |||
OptionValMap* get_option() override { return &m_option; } | |||
void update() override; | |||
private: | |||
WeightPreprocessOption(); | |||
WeightPreprocessOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
@@ -91,8 +95,10 @@ public: | |||
OptionValMap* get_option() override { return &m_option; } | |||
void update() override; | |||
private: | |||
FuseConvBiasNonlinearOption(); | |||
FuseConvBiasNonlinearOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
@@ -117,8 +123,10 @@ public: | |||
OptionValMap* get_option() override { return &m_option; } | |||
void update() override; | |||
private: | |||
FuseConvBiasElemwiseAddOption(); | |||
FuseConvBiasElemwiseAddOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
std::string m_option_name; | |||
@@ -143,8 +151,10 @@ public: | |||
OptionValMap* get_option() override { return &m_option; } | |||
void update() override; | |||
private: | |||
GraphRecordOption(); | |||
GraphRecordOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
@@ -169,8 +179,10 @@ public: | |||
std::string option_name() const override { return m_option_name; }; | |||
void update() override; | |||
private: | |||
MemoryOptimizeOption(); | |||
MemoryOptimizeOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
@@ -191,8 +203,10 @@ public: | |||
std::string option_name() const override { return m_option_name; }; | |||
void update() override; | |||
private: | |||
JITOption(); | |||
JITOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
@@ -212,8 +226,10 @@ public: | |||
std::string option_name() const override { return m_option_name; }; | |||
void update() override; | |||
private: | |||
TensorRTOption(); | |||
TensorRTOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
@@ -28,6 +28,10 @@ public: | |||
//! get option map | |||
virtual OptionValMap* get_option() { return nullptr; } | |||
//! update option value | |||
virtual void update(){}; | |||
virtual ~OptionBase() = default; | |||
}; | |||
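For context, the update() hook added here sits next to the virtuals the option classes above already override. A hedged reconstruction of the interface, assembled only from the overrides visible in this diff (the real class may have additional members, and which methods are pure virtual is not shown here):

class OptionBase {
public:
    //! configure the model for the current RunStage
    virtual void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) = 0;
    //! get option name
    virtual std::string option_name() const = 0;
    //! get option map
    virtual OptionValMap* get_option() { return nullptr; }
    //! update option value (the new hook: re-read flags before each use)
    virtual void update() {}
    virtual ~OptionBase() = default;
};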
@@ -22,10 +22,10 @@ void PluginOption::config_model_internel<ModelLite>( | |||
else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
if (!profile_path.empty()) { | |||
if (!enable_profile_host) { | |||
LITE_WARN("enable profiling"); | |||
LITE_LOG("enable profiling"); | |||
model->get_lite_network()->enable_profile_performance(profile_path); | |||
} else { | |||
LITE_WARN("enable profiling for host"); | |||
LITE_LOG("enable profiling for host"); | |||
model->get_lite_network()->enable_profile_performance(profile_path); | |||
} | |||
} | |||
@@ -39,18 +39,18 @@ void PluginOption::config_model_internel<ModelMdl>( | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
auto&& config = model->get_mdl_config(); | |||
if (range > 0) { | |||
mgb_log_warn("enable number range check"); | |||
mgb_log("enable number range check"); | |||
model->set_num_range_checker(float(range)); | |||
} | |||
if (enable_check_dispatch) { | |||
mgb_log_warn("enable cpu dispatch check"); | |||
mgb_log("enable cpu dispatch check"); | |||
cpu_dispatch_checker = | |||
std::make_unique<mgb::CPUDispatchChecker>(config.comp_graph.get()); | |||
} | |||
if (!var_value_check_str.empty()) { | |||
mgb_log_warn("enable variable value check"); | |||
mgb_log("enable variable value check"); | |||
size_t init_idx = 0, switch_interval; | |||
auto sep = var_value_check_str.find(':'); | |||
if (sep != std::string::npos) { | |||
@@ -67,9 +67,9 @@ void PluginOption::config_model_internel<ModelMdl>( | |||
if (!profile_path.empty()) { | |||
if (!enable_profile_host) { | |||
mgb_log_warn("enable profiling"); | |||
mgb_log("enable profiling"); | |||
} else { | |||
mgb_log_warn("enable profiling for host"); | |||
mgb_log("enable profiling for host"); | |||
} | |||
model->set_profiler(); | |||
} | |||
@@ -79,12 +79,11 @@ void PluginOption::config_model_internel<ModelMdl>( | |||
else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||
#if MGB_ENABLE_JSON | |||
if (!profile_path.empty()) { | |||
mgb_log_warn("filename %s", profile_path.c_str()); | |||
if (model->get_profiler()) { | |||
model->get_profiler() | |||
->to_json_full(model->get_async_func().get()) | |||
->writeto_fpath(profile_path); | |||
mgb_log_warn("profiling result written to %s", profile_path.c_str()); | |||
mgb_log("profiling result written to %s", profile_path.c_str()); | |||
} | |||
} | |||
#endif | |||
@@ -94,7 +93,7 @@ void PluginOption::config_model_internel<ModelMdl>( | |||
} // namespace lar | |||
using namespace lar; | |||
PluginOption::PluginOption() { | |||
void PluginOption::update() { | |||
m_option_name = "plugin"; | |||
range = FLAGS_range; | |||
enable_check_dispatch = FLAGS_check_dispatch; | |||
@@ -125,6 +124,7 @@ bool PluginOption::is_valid() { | |||
std::shared_ptr<OptionBase> PluginOption::create_option() { | |||
static std::shared_ptr<PluginOption> option(new PluginOption); | |||
if (PluginOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -199,7 +199,7 @@ void DebugOption::format_and_print( | |||
std::stringstream ss; | |||
ss << table; | |||
printf("%s\n\n", ss.str().c_str()); | |||
LITE_LOG("%s\n\n", ss.str().c_str()); | |||
} | |||
template <> | |||
@@ -243,7 +243,7 @@ void DebugOption::format_and_print( | |||
std::stringstream ss; | |||
ss << table; | |||
printf("%s\n\n", ss.str().c_str()); | |||
mgb_log("%s\n\n", ss.str().c_str()); | |||
} | |||
template <> | |||
@@ -260,7 +260,7 @@ void DebugOption::config_model_internel<ModelLite>( | |||
#endif | |||
#endif | |||
if (enable_verbose) { | |||
LITE_WARN("enable verbose"); | |||
LITE_LOG("enable verbose"); | |||
lite::set_log_level(LiteLogLevel::DEBUG); | |||
} | |||
@@ -272,7 +272,7 @@ void DebugOption::config_model_internel<ModelLite>( | |||
#endif | |||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
if (enable_display_model_info) { | |||
LITE_WARN("enable display model information"); | |||
LITE_LOG("enable display model information"); | |||
format_and_print<ModelLite>("Runtime Model Info", model); | |||
} | |||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||
@@ -287,7 +287,7 @@ void DebugOption::config_model_internel<ModelMdl>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
if (enable_verbose) { | |||
mgb_log_warn("enable verbose"); | |||
mgb_log("enable verbose"); | |||
mgb::set_log_level(mgb::LogLevel::DEBUG); | |||
} | |||
@@ -299,21 +299,21 @@ void DebugOption::config_model_internel<ModelMdl>( | |||
#endif | |||
} else if (runtime_param.stage == RunStage::BEFORE_OUTSPEC_SET) { | |||
if (enable_display_model_info) { | |||
mgb_log_warn("enable display model information"); | |||
mgb_log("enable display model information"); | |||
format_and_print<ModelMdl>("Runtime Model Info", model); | |||
} | |||
} else if (runtime_param.stage == RunStage::AFTER_OUTSPEC_SET) { | |||
#ifndef __IN_TEE_ENV__ | |||
#if MGB_ENABLE_JSON | |||
if (!static_mem_log_dir_path.empty()) { | |||
mgb_log_warn("enable get static memeory information"); | |||
mgb_log("enable get static memeory information"); | |||
model->get_async_func()->get_static_memory_alloc_info( | |||
static_mem_log_dir_path); | |||
} | |||
#endif | |||
#endif | |||
if (disable_assert_throw) { | |||
mgb_log_warn("disable assert throw"); | |||
mgb_log("disable assert throw"); | |||
auto on_opr = [](mgb::cg::OperatorNodeBase* opr) { | |||
if (opr->same_type<mgb::opr::AssertEqual>()) { | |||
opr->cast_final<mgb::opr::AssertEqual>().disable_throw_on_error(); | |||
@@ -333,7 +333,7 @@ void DebugOption::config_model_internel<ModelMdl>( | |||
} // namespace lar | |||
DebugOption::DebugOption() { | |||
void DebugOption::update() { | |||
m_option_name = "debug"; | |||
enable_display_model_info = FLAGS_model_info; | |||
enable_verbose = FLAGS_verbose; | |||
@@ -367,6 +367,7 @@ bool DebugOption::is_valid() { | |||
std::shared_ptr<OptionBase> DebugOption::create_option() { | |||
static std::shared_ptr<DebugOption> option(new DebugOption); | |||
if (DebugOption::is_valid()) { | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
@@ -44,8 +44,10 @@ public: | |||
std::string option_name() const override { return m_option_name; }; | |||
void update() override; | |||
private: | |||
PluginOption(); | |||
PluginOption() = default; | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
double range; | |||
@@ -74,8 +76,10 @@ public: | |||
std::string option_name() const override { return m_option_name; }; | |||
void update() override; | |||
private: | |||
DebugOption(); | |||
DebugOption() = default; | |||
template <typename ModelImpl> | |||
void format_and_print(const std::string&, std::shared_ptr<ModelImpl>){}; | |||
template <typename ModelImpl> | |||
@@ -5,7 +5,7 @@ using namespace lar; | |||
DECLARE_bool(c_opr_lib_with_param); | |||
DECLARE_bool(fitting); | |||
StrategyOption::StrategyOption() { | |||
void StrategyOption::update() { | |||
m_option_name = "run_strategy"; | |||
warmup_iter = FLAGS_fitting ? 3 : FLAGS_warmup_iter; | |||
run_iter = FLAGS_fitting ? 10 : FLAGS_iter; | |||
@@ -20,6 +20,7 @@ StrategyOption::StrategyOption() { | |||
std::shared_ptr<OptionBase> StrategyOption::create_option() { | |||
static std::shared_ptr<StrategyOption> option(new StrategyOption); | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} | |||
@@ -43,12 +44,13 @@ void StrategyOption::config_model( | |||
} | |||
} | |||
TestcaseOption::TestcaseOption() { | |||
void TestcaseOption::update() { | |||
m_option_name = "run_testcase"; | |||
} | |||
std::shared_ptr<OptionBase> TestcaseOption::create_option() { | |||
static std::shared_ptr<TestcaseOption> option(new TestcaseOption); | |||
option->update(); | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} | |||
@@ -25,9 +25,11 @@ public: | |||
OptionValMap* get_option() override { return &m_option; } | |||
void update() override; | |||
private: | |||
//! Constructor | |||
StrategyOption(); | |||
StrategyOption() = default; | |||
//! configuration for different model implement | |||
std::string m_option_name; | |||
@@ -52,9 +54,11 @@ public: | |||
//! get option name | |||
std::string option_name() const override { return m_option_name; }; | |||
void update() override; | |||
private: | |||
//! Constructor | |||
TestcaseOption(); | |||
TestcaseOption() = default; | |||
//! configuration for different model implement | |||
std::string m_option_name; | |||
@@ -233,8 +233,8 @@ void OptionsTimeProfiler::profile_with_given_options( | |||
"the log) when profile option:\n%s\n", | |||
option_code.c_str()); | |||
} else { | |||
printf("profile option:\n%s\naverage time = %.2f\n", option_code.c_str(), | |||
average); | |||
mgb_log("profile option:\n%s\naverage time = %.2f\n", option_code.c_str(), | |||
average); | |||
//! record profile result | |||
m_options_profile_result.insert({option_code, average}); | |||
@@ -370,7 +370,6 @@ void UserInfoParser::parse_info(std::shared_ptr<OptionsFastManager>& manager) { | |||
FittingStrategy::FittingStrategy(std::string model_path) { | |||
m_manager = std::make_shared<OptionsFastManager>(); | |||
m_dumped_model = FLAGS_dump_fitting_model; | |||
mgb::set_log_level(mgb::LogLevel::INFO); | |||
m_options = std::make_shared<OptionMap>(); | |||
m_model_path = model_path; | |||
auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map(); | |||
@@ -518,10 +517,10 @@ void FittingStrategy::AutoCleanFile::dump_model() { | |||
void FittingStrategy::run() { | |||
auto mgb_version = mgb::get_version(); | |||
auto dnn_version = megdnn::get_version(); | |||
printf("megbrain/lite/load_and_run:\nusing MegBrain " | |||
"%d.%d.%d(%d) and MegDNN %d.%d.%d\n", | |||
mgb_version.major, mgb_version.minor, mgb_version.patch, mgb_version.is_dev, | |||
dnn_version.major, dnn_version.minor, dnn_version.patch); | |||
mgb_log("megbrain/lite/load_and_run:\nusing MegBrain " | |||
"%d.%d.%d(%d) and MegDNN %d.%d.%d\n", | |||
mgb_version.major, mgb_version.minor, mgb_version.patch, mgb_version.is_dev, | |||
dnn_version.major, dnn_version.minor, dnn_version.patch); | |||
// ! create profiler with given user info | |||
m_info_parser.get_user_info(); | |||
m_info_parser.parse_info(m_manager); | |||
@@ -5,13 +5,10 @@ | |||
#include "megbrain/utils/timer.h" | |||
#include "megbrain/version.h" | |||
#include "megdnn/version.h" | |||
#include "misc.h" | |||
using namespace lar; | |||
NormalStrategy::NormalStrategy(std::string model_path) { | |||
mgb::set_log_level(mgb::LogLevel::WARN); | |||
lite::set_log_level(LiteLogLevel::WARN); | |||
m_options = std::make_shared<OptionMap>(); | |||
m_model_path = model_path; | |||
auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map(); | |||
@@ -47,7 +44,7 @@ void NormalStrategy::run_subline() { | |||
mgb::RealTimer timer; | |||
model->load_model(); | |||
printf("load model: %.3fms\n", timer.get_msecs_reset()); | |||
mgb_log("load model: %.3fms\n", timer.get_msecs_reset()); | |||
//! after load configure | |||
auto config_after_load = [&]() { | |||
@@ -62,10 +59,10 @@ void NormalStrategy::run_subline() { | |||
auto warm_up = [&]() { | |||
auto warmup_num = m_runtime_param.warmup_iter; | |||
for (size_t i = 0; i < warmup_num; i++) { | |||
printf("=== prepare: %.3fms; going to warmup\n\n", timer.get_msecs_reset()); | |||
mgb_log("=== prepare: %.3fms; going to warmup", timer.get_msecs_reset()); | |||
model->run_model(); | |||
model->wait(); | |||
printf("warm up %lu %.3fms\n", i, timer.get_msecs_reset()); | |||
mgb_log("warm up %lu %.3fms", i, timer.get_msecs_reset()); | |||
m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT; | |||
stage_config_model(); | |||
} | |||
@@ -83,21 +80,21 @@ void NormalStrategy::run_subline() { | |||
auto cur = timer.get_msecs(); | |||
m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT; | |||
stage_config_model(); | |||
printf("iter %lu/%lu: e2e=%.3f ms (host=%.3f ms)\n", i, run_num, cur, | |||
exec_time); | |||
mgb_log("iter %lu/%lu: e2e=%.3f ms (host=%.3f ms)", i, run_num, cur, | |||
exec_time); | |||
time_sum += cur; | |||
time_sqrsum += cur * cur; | |||
fflush(stdout); | |||
min_time = std::min(min_time, cur); | |||
max_time = std::max(max_time, cur); | |||
} | |||
printf("\n=== finished test #%u: time=%.3f ms avg_time=%.3f ms " | |||
"standard_deviation=%.3f ms min=%.3f ms max=%.3f ms\n\n", | |||
idx, time_sum, time_sum / run_num, | |||
std::sqrt( | |||
(time_sqrsum * run_num - time_sum * time_sum) / | |||
(run_num * (run_num - 1))), | |||
min_time, max_time); | |||
mgb_log("=== finished test #%u: time=%.3f ms avg_time=%.3f ms " | |||
"standard_deviation=%.3f ms min=%.3f ms max=%.3f ms", | |||
idx, time_sum, time_sum / run_num, | |||
std::sqrt( | |||
(time_sqrsum * run_num - time_sum * time_sum) / | |||
(run_num * (run_num - 1))), | |||
min_time, max_time); | |||
return time_sum; | |||
}; | |||
@@ -122,7 +119,7 @@ void NormalStrategy::run_subline() { | |||
stage_config_model(); | |||
} | |||
printf("=== total time: %.3fms\n", tot_time); | |||
mgb_log("=== total time: %.3fms\n", tot_time); | |||
//! execute after run | |||
m_runtime_param.stage = RunStage::AFTER_MODEL_RUNNING; | |||
stage_config_model(); | |||
@@ -131,9 +128,9 @@ void NormalStrategy::run_subline() { | |||
void NormalStrategy::run() { | |||
auto v0 = mgb::get_version(); | |||
auto v1 = megdnn::get_version(); | |||
printf("megbrain/lite/load_and_run:\nusing MegBrain " | |||
"%d.%d.%d(%d) and MegDNN %d.%d.%d\n", | |||
v0.major, v0.minor, v0.patch, v0.is_dev, v1.major, v1.minor, v1.patch); | |||
mgb_log("megbrain/lite/load_and_run:\nusing MegBrain " | |||
"%d.%d.%d(%d) and MegDNN %d.%d.%d\n", | |||
v0.major, v0.minor, v0.patch, v0.is_dev, v1.major, v1.minor, v1.patch); | |||
size_t thread_num = m_runtime_param.threads; | |||
auto run_sub = [&]() { run_subline(); }; | |||
@@ -73,7 +73,7 @@ public: | |||
#define LITE_LOG_(level, msg...) (void)0 | |||
#endif | |||
#define LITE_LOG(fmt...) LITE_LOG_(DEBUG, fmt); | |||
#define LITE_LOG(fmt...) LITE_LOG_(INFO, fmt); | |||
#define LITE_DEBUG(fmt...) LITE_LOG_(DEBUG, fmt); | |||
#define LITE_WARN(fmt...) LITE_LOG_(WARN, fmt); | |||
#define LITE_ERROR(fmt...) LITE_LOG_(ERROR, fmt); | |||
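The only functional change in this hunk is that LITE_LOG now dispatches at INFO rather than DEBUG, so its messages appear in default builds. A rough illustration of the dispatch idea with a printf backend standing in for the real lite logger (illustration only; the DEMO_* names are made up, and the named variadic arguments follow the same GNU style as the original macros):

#include <cstdio>

#define DEMO_LOG_(level, msg...) std::printf("[" #level "] " msg)
#define DEMO_LOG(fmt...) DEMO_LOG_(INFO, fmt)   // was DEBUG before this change
#define DEMO_WARN(fmt...) DEMO_LOG_(WARN, fmt)

int main() {
    DEMO_LOG("loaded %d inputs\n", 3);          // now tagged INFO, visible by default
    DEMO_WARN("fallback kernel used\n");
    return 0;
}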
@@ -1,14 +1,14 @@ | |||
if(MGE_WITH_TEST) | |||
include_directories(PUBLIC | |||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/lite/load_and_run/src>) | |||
file(GLOB_RECURSE SOURCES ./*.cpp main.cpp) | |||
add_executable(lite_test ${SOURCES}) | |||
target_link_libraries(lite_test lar_object) | |||
target_link_libraries(lite_test gtest) | |||
target_link_libraries(lite_test lite_static) | |||
if(LITE_BUILD_WITH_MGE) | |||
    # lite_test will depend on the megbrain interface | |||
target_link_libraries(lite_test megbrain) | |||
if(MGE_WITH_ROCM) | |||
# FIXME: hip obj can not find cpp obj only through lite_static | |||
message(WARNING "MGE_WITH_ROCM is valid link to megdnn") | |||
target_link_libraries(lite_test megdnn) | |||
endif() | |||
endif() | |||
@@ -0,0 +1,85 @@ | |||
#include <gtest/gtest.h> | |||
#include <string.h> | |||
#include <memory> | |||
#include "test_options.h" | |||
using namespace lar; | |||
DECLARE_bool(lite); | |||
DECLARE_bool(cpu); | |||
#if LITE_WITH_CUDA | |||
DECLARE_bool(cuda); | |||
#endif | |||
DECLARE_bool(enable_nchw4); | |||
DECLARE_bool(enable_chwn4); | |||
DECLARE_bool(enable_nchw44); | |||
DECLARE_bool(enable_nchw88); | |||
DECLARE_bool(enable_nchw32); | |||
DECLARE_bool(enable_nchw64); | |||
DECLARE_bool(enable_nhwcd4); | |||
DECLARE_bool(enable_nchw44_dot); | |||
namespace { | |||
BOOL_OPTION_WRAP(enable_nchw4); | |||
BOOL_OPTION_WRAP(enable_chwn4); | |||
BOOL_OPTION_WRAP(enable_nchw44); | |||
BOOL_OPTION_WRAP(enable_nchw88); | |||
BOOL_OPTION_WRAP(enable_nchw32); | |||
BOOL_OPTION_WRAP(enable_nchw64); | |||
BOOL_OPTION_WRAP(enable_nhwcd4); | |||
BOOL_OPTION_WRAP(enable_nchw44_dot); | |||
BOOL_OPTION_WRAP(lite); | |||
BOOL_OPTION_WRAP(cpu); | |||
#if LITE_WITH_CUDA | |||
BOOL_OPTION_WRAP(cuda); | |||
#endif | |||
} // anonymous namespace | |||
TEST(TestLarLayout, X86_CPU) { | |||
DEFINE_WRAP(cpu); | |||
std::string model_path = "./shufflenet.mge"; | |||
TEST_BOOL_OPTION(enable_nchw4); | |||
TEST_BOOL_OPTION(enable_chwn4); | |||
TEST_BOOL_OPTION(enable_nchw44); | |||
TEST_BOOL_OPTION(enable_nchw44_dot); | |||
TEST_BOOL_OPTION(enable_nchw64); | |||
TEST_BOOL_OPTION(enable_nchw32); | |||
TEST_BOOL_OPTION(enable_nchw88); | |||
} | |||
TEST(TestLarLayout, X86_CPU_LITE) { | |||
DEFINE_WRAP(cpu); | |||
DEFINE_WRAP(lite); | |||
std::string model_path = "./shufflenet.mge"; | |||
TEST_BOOL_OPTION(enable_nchw4); | |||
TEST_BOOL_OPTION(enable_nchw44); | |||
TEST_BOOL_OPTION(enable_nchw44_dot); | |||
TEST_BOOL_OPTION(enable_nchw64); | |||
TEST_BOOL_OPTION(enable_nchw32); | |||
TEST_BOOL_OPTION(enable_nchw88); | |||
} | |||
#if LITE_WITH_CUDA | |||
TEST(TestLarLayout, CUDA) { | |||
DEFINE_WRAP(cuda); | |||
std::string model_path = "./shufflenet.mge"; | |||
TEST_BOOL_OPTION(enable_nchw4); | |||
TEST_BOOL_OPTION(enable_chwn4); | |||
TEST_BOOL_OPTION(enable_nchw64); | |||
TEST_BOOL_OPTION(enable_nchw32); | |||
FLAGS_cuda = false; | |||
} | |||
TEST(TestLarLayout, CUDA_LITE) { | |||
DEFINE_WRAP(cuda); | |||
DEFINE_WRAP(lite); | |||
std::string model_path = "./shufflenet.mge"; | |||
TEST_BOOL_OPTION(enable_nchw4); | |||
TEST_BOOL_OPTION(enable_nchw64); | |||
TEST_BOOL_OPTION(enable_nchw32); | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,12 @@ | |||
#include "test_options.h" | |||
using namespace lar; | |||
void lar::run_NormalStrategy(std::string model_path) { | |||
auto origin_level = mgb::get_log_level(); | |||
mgb::set_log_level(mgb::LogLevel::WARN); | |||
NormalStrategy strategy(model_path); | |||
strategy.run(); | |||
mgb::set_log_level(origin_level); | |||
} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,33 @@ | |||
#pragma once | |||
#include <iostream> | |||
#include <thread> | |||
#include "../load_and_run/src/strategys/strategy.h" | |||
#include "../load_and_run/src/strategys/strategy_normal.h" | |||
#include "megbrain/common.h" | |||
#include "megbrain/utils/timer.h" | |||
#include "megbrain/version.h" | |||
#include "megdnn/version.h" | |||
#include "misc.h" | |||
namespace lar { | |||
//! run load_and_run NormalStrategy to test different options | |||
void run_NormalStrategy(std::string model_path); | |||
} // namespace lar | |||
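//! RAII helpers for the gflags booleans declared in the tests: each wrapper
//! sets FLAGS_<option> to true on construction and back to false on destruction,
//! so a TEST_BOOL_OPTION block runs load_and_run with exactly one extra layout
//! flag enabled on top of the flags pinned by DEFINE_WRAP.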
#define BOOL_OPTION_WRAP(option) \ | |||
struct BoolOptionWrap_##option { \ | |||
BoolOptionWrap_##option() { FLAGS_##option = true; } \ | |||
~BoolOptionWrap_##option() { FLAGS_##option = false; } \ | |||
}; | |||
#define DEFINE_WRAP(option) BoolOptionWrap_##option flags_##option; | |||
#define TEST_BOOL_OPTION(option) \ | |||
{ \ | |||
BoolOptionWrap_##option flags_##option; \ | |||
run_NormalStrategy(model_path); \ | |||
} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
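For reference, a test body written with these helpers expands to roughly the following (illustrative expansion of TEST_BOOL_OPTION(enable_nchw44) inside a test that already did DEFINE_WRAP(cpu)):

{
    BoolOptionWrap_enable_nchw44 flags_enable_nchw44;  // ctor: FLAGS_enable_nchw44 = true
    run_NormalStrategy(model_path);                    // run with cpu and nchw44 enabled
}                                                      // dtor: FLAGS_enable_nchw44 = false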