|
|
@@ -80,7 +80,9 @@ void XPUDeviceOption::config_model_internel<ModelMdl>( |
|
|
|
if (enable_cuda) { |
|
|
|
mgb_log_warn("using cuda device\n"); |
|
|
|
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { |
|
|
|
loc.type = mgb::CompNode::DeviceType::CUDA; |
|
|
|
if (loc.type == mgb::CompNode::DeviceType::UNSPEC) { |
|
|
|
loc.type = mgb::CompNode::DeviceType::CUDA; |
|
|
|
} |
|
|
|
loc.device = 0; |
|
|
|
}; |
|
|
|
} |
|
|
|