|
|
@@ -67,10 +67,15 @@ void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) { |
|
|
|
void NetworkImplDft::application_config() { |
|
|
|
auto device_type = m_user_config->device_type; |
|
|
|
m_compnode_locator.type = to_compnode_locator(device_type).type; |
|
|
|
m_compnode_locator.device = m_user_config->device_id; |
|
|
|
//! when the device id is not configured, configure it |
|
|
|
if (m_compnode_locator.device == -1) { |
|
|
|
m_compnode_locator.device = m_user_config->device_id; |
|
|
|
} |
|
|
|
if (m_nr_threads > 1 && device_type == LiteDeviceType::LITE_CPU) { |
|
|
|
m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD; |
|
|
|
m_compnode_locator.device = m_user_config->device_id; |
|
|
|
if (m_compnode_locator.device == -1) { |
|
|
|
m_compnode_locator.device = m_user_config->device_id; |
|
|
|
} |
|
|
|
} |
|
|
|
//! model options |
|
|
|
#define ConfigOption(mge_name, lite_name) \ |
|
|
@@ -155,11 +160,13 @@ void NetworkImplDft::set_cpu_inplace_mode() { |
|
|
|
m_is_cpu_inplace_mode = true; |
|
|
|
if (m_compnode_locator.type == mgb::CompNode::DeviceType::CPU) { |
|
|
|
m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT; |
|
|
|
m_user_config->device_id = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT; |
|
|
|
} else { |
|
|
|
LITE_ASSERT( |
|
|
|
m_compnode_locator.type == CompNode::DeviceType::MULTITHREAD, |
|
|
|
"cpu inplace mode is only avaliable in CPU."); |
|
|
|
m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT; |
|
|
|
m_user_config->device_id = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@@ -170,6 +177,12 @@ void NetworkImplDft::set_cpu_threads_number(size_t nr_threads) { |
|
|
|
if (nr_threads > 1) { |
|
|
|
m_nr_threads = nr_threads; |
|
|
|
m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD; |
|
|
|
if (m_is_cpu_inplace_mode) { |
|
|
|
m_compnode_locator.device = |
|
|
|
mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT; |
|
|
|
m_user_config->device_id = |
|
|
|
mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT; |
|
|
|
} |
|
|
|
m_compnode_locator.nr_threads = nr_threads; |
|
|
|
} |
|
|
|
} |
|
|
|