Browse Source

fix(lite): fix lite compnode mapper

GitOrigin-RevId: 994308b511
tags/v1.7.2.m1
Megvii Engine Team 3 years ago
parent
commit
643ab1c12b
2 changed files with 16 additions and 29 deletions
  1. +1
    -1
      lite/include/lite/common_enum_c.h
  2. +15
    -28
      lite/src/mge/network_impl.cpp

+ 1
- 1
lite/include/lite/common_enum_c.h View File

@@ -33,7 +33,7 @@ typedef enum {
LITE_NPU = 4,
LITE_CAMBRICON = 5,
//! when the device information is set in model, so set LITE_DEVICE_DEFAULT
//! in lite
//! in lite, which equal to xpu in megengine
LITE_DEVICE_DEFAULT = 6,
} LiteDeviceType;



+ 15
- 28
lite/src/mge/network_impl.cpp View File

@@ -116,35 +116,22 @@ void NetworkImplDft::application_config() {
m_load_config.tensor_value_loader = decompressed_tensor_value_loader;
}

//! if device is LITE_NONE, the compnode information is stored in model
//! if device is LITE_NONE, the compnode information is stored in model or
//! xpu in MegEngine
if (device_type != LiteDeviceType::LITE_DEVICE_DEFAULT) {
//! currently not set Locator type because an atlas mgb model is a
//! cross-compnode graph
if (device_type == LiteDeviceType::LITE_ATLAS) {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::ATLAS) {
loc.device = m_compnode_locator.device;
loc.stream = m_compnode_locator.stream;
} else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) {
loc.stream = m_nr_threads;
}
};
//! currently not set Locator type because a cambricon mgb model is a
//! cross-compnode graph
} else if (device_type == LiteDeviceType::LITE_CAMBRICON) {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::CAMBRICON) {
loc.device = m_compnode_locator.device;
loc.stream = m_compnode_locator.stream;
} else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) {
loc.stream = m_nr_threads;
}
};
} else {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
loc = m_compnode_locator;
};
}
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
loc.type = m_compnode_locator.type;
}
loc.device = m_compnode_locator.device;
//! if user set the thread number and the compnode is multithread
if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD &&
m_nr_threads != 1) {
loc.stream = m_nr_threads;
} else {
loc.stream = m_compnode_locator.stream;
}
};
}
}



Loading…
Cancel
Save