|
|
@@ -116,35 +116,22 @@ void NetworkImplDft::application_config() { |
|
|
|
m_load_config.tensor_value_loader = decompressed_tensor_value_loader; |
|
|
|
} |
|
|
|
|
|
|
|
//! if device is LITE_NONE, the compnode information is stored in model |
|
|
|
//! if device is LITE_NONE, the compnode information is stored in model or |
|
|
|
//! xpu in MegEngine |
|
|
|
if (device_type != LiteDeviceType::LITE_DEVICE_DEFAULT) { |
|
|
|
//! currently not set Locator type because an atlas mgb model is a |
|
|
|
//! cross-compnode graph |
|
|
|
if (device_type == LiteDeviceType::LITE_ATLAS) { |
|
|
|
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) { |
|
|
|
if (loc.type == mgb::CompNode::DeviceType::ATLAS) { |
|
|
|
loc.device = m_compnode_locator.device; |
|
|
|
loc.stream = m_compnode_locator.stream; |
|
|
|
} else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) { |
|
|
|
loc.stream = m_nr_threads; |
|
|
|
} |
|
|
|
}; |
|
|
|
//! currently not set Locator type because a cambricon mgb model is a |
|
|
|
//! cross-compnode graph |
|
|
|
} else if (device_type == LiteDeviceType::LITE_CAMBRICON) { |
|
|
|
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) { |
|
|
|
if (loc.type == mgb::CompNode::DeviceType::CAMBRICON) { |
|
|
|
loc.device = m_compnode_locator.device; |
|
|
|
loc.stream = m_compnode_locator.stream; |
|
|
|
} else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) { |
|
|
|
loc.stream = m_nr_threads; |
|
|
|
} |
|
|
|
}; |
|
|
|
} else { |
|
|
|
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) { |
|
|
|
loc = m_compnode_locator; |
|
|
|
}; |
|
|
|
} |
|
|
|
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) { |
|
|
|
if (loc.type == mgb::CompNode::DeviceType::UNSPEC) { |
|
|
|
loc.type = m_compnode_locator.type; |
|
|
|
} |
|
|
|
loc.device = m_compnode_locator.device; |
|
|
|
//! if user set the thread number and the compnode is multithread |
|
|
|
if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD && |
|
|
|
m_nr_threads != 1) { |
|
|
|
loc.stream = m_nr_threads; |
|
|
|
} else { |
|
|
|
loc.stream = m_compnode_locator.stream; |
|
|
|
} |
|
|
|
}; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|