You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tbe_task_builder.cc 16 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "single_op/task/tbe_task_builder.h"
  17. #include <mutex>
  18. #include <vector>
  19. #include "graph/debug/ge_attr_define.h"
  20. #include "graph/load/model_manager/model_utils.h"
  21. #include "graph/manager/graph_var_manager.h"
  22. #include "runtime/rt.h"
  23. #include "single_op/task/build_task_utils.h"
  24. namespace ge {
  25. namespace {
  26. constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape";
  27. constexpr char const *kAttrOpParamSize = "op_para_size";
  28. std::mutex g_reg_mutex;
  29. inline void GetKernelName(const OpDescPtr &op_desc, std::string &kernel_name) {
  30. (void)AttrUtils::GetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name);
  31. }
  32. inline TBEKernelPtr GetTbeKernel(const OpDescPtr &op_desc) {
  33. return op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr());
  34. }
  35. } // namespace
  36. KernelHolder::KernelHolder(const char *stub_func, std::shared_ptr<ge::OpKernelBin> kernel_bin)
  37. : stub_func_(stub_func), bin_handle_(nullptr), kernel_bin_(std::move(kernel_bin)) {}
  38. KernelHolder::~KernelHolder() {
  39. if (bin_handle_ != nullptr) {
  40. GE_CHK_RT(rtDevBinaryUnRegister(bin_handle_));
  41. }
  42. }
  43. HandleHolder::HandleHolder(void *bin_handle)
  44. : bin_handle_(bin_handle) {}
  45. HandleHolder::~HandleHolder() {
  46. if (bin_handle_ != nullptr) {
  47. GE_CHK_RT(rtDevBinaryUnRegister(bin_handle_));
  48. }
  49. }
  50. const char *KernelBinRegistry::GetUnique(const string &stub_func) {
  51. std::lock_guard<std::mutex> lock(mutex_);
  52. auto it = unique_stubs_.find(stub_func);
  53. if (it != unique_stubs_.end()) {
  54. return it->c_str();
  55. } else {
  56. it = unique_stubs_.insert(unique_stubs_.end(), stub_func);
  57. return it->c_str();
  58. }
  59. }
  60. const char *KernelBinRegistry::GetStubFunc(const std::string &stub_name) {
  61. std::lock_guard<std::mutex> lock(mutex_);
  62. auto iter = registered_bins_.find(stub_name);
  63. if (iter != registered_bins_.end()) {
  64. return iter->second->stub_func_;
  65. }
  66. return nullptr;
  67. }
  68. bool KernelBinRegistry::AddKernel(const std::string &stub_name, std::unique_ptr<KernelHolder> &&holder) {
  69. std::lock_guard<std::mutex> lock(mutex_);
  70. auto ret = registered_bins_.emplace(stub_name, std::move(holder));
  71. return ret.second;
  72. }
  73. bool HandleRegistry::AddHandle(std::unique_ptr<HandleHolder> &&holder) {
  74. auto ret = registered_handles_.emplace(std::move(holder));
  75. return ret.second;
  76. }
  77. TbeTaskBuilder::TbeTaskBuilder(const std::string &model_name, const NodePtr &node, const domi::TaskDef &task_def)
  78. : node_(node),
  79. op_desc_(node->GetOpDesc()),
  80. task_def_(task_def),
  81. kernel_def_(task_def.kernel()),
  82. kernel_def_with_handle_(task_def.kernel_with_handle()),
  83. stub_name_(model_name + "/" + node->GetName() + "_tvmbin") {}
  84. Status TbeTaskBuilder::DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle,
  85. const SingleOpModelParam &param) const {
  86. rtDevBinary_t binary;
  87. binary.version = 0;
  88. binary.data = kernel_bin.GetBinData();
  89. binary.length = kernel_bin.GetBinDataSize();
  90. binary.magic = param.core_type == 0 ? RT_DEV_BINARY_MAGIC_ELF : RT_DEV_BINARY_MAGIC_ELF_AIVEC;
  91. Status ret = 0;
  92. if (task_def_.type() == RT_MODEL_TASK_ALL_KERNEL) {
  93. ret = rtRegisterAllKernel(&binary, bin_handle);
  94. } else {
  95. ret = rtDevBinaryRegister(&binary, bin_handle);
  96. }
  97. if (ret != RT_ERROR_NONE) {
  98. GELOGE(ret, "DoRegisterBinary failed, bin key = %s, core_type = %ld, rt ret = %d", stub_name_.c_str(),
  99. param.core_type, static_cast<int>(ret));
  100. return ret;
  101. }
  102. return SUCCESS;
  103. }
  104. Status TbeTaskBuilder::DoRegisterMeta(void *bin_handle) {
  105. std::string meta_data;
  106. (void)AttrUtils::GetStr(op_desc_, TVM_ATTR_NAME_METADATA, meta_data);
  107. GELOGI("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str());
  108. if (!meta_data.empty()) {
  109. auto rt_ret = rtMetadataRegister(bin_handle, meta_data.c_str());
  110. if (rt_ret != RT_ERROR_NONE) {
  111. GELOGE(rt_ret, "rtMetadataRegister failed. bin key = %s, meta_data = %s, rt ret = %d", stub_name_.c_str(),
  112. meta_data.c_str(), static_cast<int>(rt_ret));
  113. return rt_ret;
  114. }
  115. }
  116. return SUCCESS;
  117. }
  118. Status TbeTaskBuilder::DoRegisterFunction(void *bin_handle, const char *stub_name, const char *kernel_name) {
  119. auto rt_ret = rtFunctionRegister(bin_handle, stub_name, stub_name, kernel_name, FUNC_MODE_NORMAL);
  120. if (rt_ret != RT_ERROR_NONE) {
  121. GELOGE(rt_ret, "rtFunctionRegister failed. bin key = %s, kernel name = %s, rt ret = %d", stub_name, kernel_name,
  122. static_cast<int>(rt_ret));
  123. return rt_ret;
  124. }
  125. return SUCCESS;
  126. }
  127. Status TbeTaskBuilder::DoRegisterKernel(const ge::OpKernelBin &tbe_kernel, const char *bin_file_key, void **bin_handle,
  128. const SingleOpModelParam &param) {
  129. void *handle = nullptr;
  130. auto ret = DoRegisterBinary(tbe_kernel, &handle, param);
  131. if (ret != SUCCESS) {
  132. return ret;
  133. }
  134. if (task_def_.type() == RT_MODEL_TASK_ALL_KERNEL) {
  135. *bin_handle = handle;
  136. return SUCCESS;
  137. }
  138. ret = DoRegisterMeta(handle);
  139. if (ret != SUCCESS) {
  140. GE_CHK_RT(rtDevBinaryUnRegister(handle));
  141. return ret;
  142. }
  143. std::string kernel_name;
  144. GetKernelName(op_desc_, kernel_name);
  145. ret = DoRegisterFunction(handle, bin_file_key, kernel_name.c_str());
  146. if (ret != SUCCESS) {
  147. GE_CHK_RT(rtDevBinaryUnRegister(handle));
  148. return ret;
  149. }
  150. GELOGI("Register function succeeded: kernel_name = %s", kernel_name.c_str());
  151. *bin_handle = handle;
  152. return SUCCESS;
  153. }
  154. Status TbeTaskBuilder::RegisterKernel(TbeOpTask &task, const SingleOpModelParam &param) {
  155. KernelBinRegistry &registry = KernelBinRegistry::GetInstance();
  156. // check if already registered
  157. const char *stub_func = registry.GetStubFunc(stub_name_);
  158. if (stub_func != nullptr) {
  159. task.SetStubFunc(stub_name_, stub_func);
  160. return SUCCESS;
  161. }
  162. // to avoid repeat register
  163. std::lock_guard<std::mutex> lock(g_reg_mutex);
  164. // check again
  165. stub_func = registry.GetStubFunc(stub_name_);
  166. if (stub_func == nullptr) {
  167. stub_func = registry.GetUnique(stub_name_);
  168. GELOGI("RegisterKernel begin, stub_func = %s", stub_func);
  169. auto tbe_kernel = GetTbeKernel(op_desc_);
  170. if (tbe_kernel == nullptr) {
  171. GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "OP EXT ATTR NAME TBE_KERNEL not found. op = %s",
  172. op_desc_->GetName().c_str());
  173. return ACL_ERROR_GE_INTERNAL_ERROR;
  174. }
  175. auto holder = std::unique_ptr<KernelHolder>(new (std::nothrow) KernelHolder(stub_func, tbe_kernel));
  176. if (holder == nullptr) {
  177. GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "create KernelHodler failed.");
  178. return ACL_ERROR_GE_MEMORY_ALLOCATION;
  179. }
  180. void *bin_handle = nullptr;
  181. auto ret = DoRegisterKernel(*tbe_kernel, stub_func, &bin_handle, param);
  182. if (ret != SUCCESS) {
  183. GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "RegisterKernel failed. stub name = %s", stub_name_.c_str());
  184. return ACL_ERROR_GE_INTERNAL_ERROR;
  185. }
  186. holder->SetBinHandle(bin_handle);
  187. if (!registry.AddKernel(stub_name_, std::move(holder))) {
  188. // should not happen. only one thread can reach here
  189. GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "Add kernel failed. stub name = %s", stub_name_.c_str());
  190. return ACL_ERROR_GE_INTERNAL_ERROR;
  191. }
  192. }
  193. task.SetStubFunc(stub_name_, stub_func);
  194. return SUCCESS;
  195. }
  196. Status TbeTaskBuilder::RegisterKernelWithHandle(TbeOpTask &task, const SingleOpModelParam &param) {
  197. GELOGD("RegisterKernelWithHandle begin.");
  198. HandleRegistry &registry = HandleRegistry::GetInstance();
  199. auto tbe_kernel = GetTbeKernel(op_desc_);
  200. if (tbe_kernel == nullptr) {
  201. GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "OP EXT ATTR NAME TBE_KERNEL not found. op = %s",
  202. op_desc_->GetName().c_str());
  203. return ACL_ERROR_GE_INTERNAL_ERROR;
  204. }
  205. void *bin_handle = nullptr;
  206. auto ret = DoRegisterKernel(*tbe_kernel, nullptr, &bin_handle, param);
  207. if (ret != SUCCESS) {
  208. GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "RegisterKernel failed. node name = %s", op_desc_->GetName().c_str());
  209. return ACL_ERROR_GE_INTERNAL_ERROR;
  210. }
  211. handle_ = bin_handle;
  212. auto holder = std::unique_ptr<HandleHolder>(new (std::nothrow) HandleHolder(handle_));
  213. if (holder == nullptr) {
  214. GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "create HandleHodler failed.");
  215. return ACL_ERROR_GE_MEMORY_ALLOCATION;
  216. }
  217. if (!registry.AddHandle(std::move(holder))) {
  218. GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "Add handle failed. node name = %s", op_desc_->GetName().c_str());
  219. return ACL_ERROR_GE_INTERNAL_ERROR;
  220. }
  221. return SUCCESS;
  222. }
  223. Status TbeTaskBuilder::GetSmDesc(void **sm_desc, const SingleOpModelParam &param) const {
  224. const std::string &sm_desc_str = kernel_def_.sm_desc();
  225. if (sm_desc_str.empty()) {
  226. *sm_desc = nullptr;
  227. } else {
  228. GELOGD("To process sm desc, size = %zu", sm_desc_str.size());
  229. char *sm_control = const_cast<char *>(sm_desc_str.data());
  230. auto *l2_ctrl_info = reinterpret_cast<rtL2Ctrl_t *>(sm_control);
  231. uint64_t gen_base_addr = param.base_addr;
  232. // There is no weight for te op now. Update L2_mirror_addr by data memory base.
  233. uint64_t data_base_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(param.mem_base)) - gen_base_addr;
  234. for (auto &data_index : l2_ctrl_info->data) {
  235. if (data_index.L2_mirror_addr != 0) {
  236. data_index.L2_mirror_addr += data_base_addr;
  237. }
  238. }
  239. auto rt_ret = rtMemAllocManaged(sm_desc, sm_desc_str.size(), RT_MEMORY_SPM);
  240. if (rt_ret != RT_ERROR_NONE) {
  241. GELOGE(rt_ret, "rtMemAllocManaged failed, ret: %d", static_cast<int>(rt_ret));
  242. return rt_ret;
  243. }
  244. rt_ret = rtMemcpy(*sm_desc, sm_desc_str.size(), sm_desc_str.data(), sm_desc_str.size(), RT_MEMCPY_HOST_TO_DEVICE);
  245. if (rt_ret != RT_ERROR_NONE) {
  246. (void)rtMemFreeManaged(*sm_desc);
  247. GELOGE(rt_ret, "rtMemcpy, ret: %d", static_cast<int>(rt_ret));
  248. return rt_ret;
  249. }
  250. }
  251. return SUCCESS;
  252. }
  253. Status TbeTaskBuilder::SetKernelArgs(TbeOpTask &task, const SingleOpModelParam &param, const OpDescPtr &op_desc) {
  254. size_t arg_size = kernel_def_.args_size();
  255. auto args = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[arg_size]);
  256. GE_CHECK_NOTNULL(args);
  257. auto rt_ret = rtMemcpy(args.get(), arg_size, kernel_def_.args().data(), arg_size, RT_MEMCPY_HOST_TO_HOST);
  258. if (rt_ret != RT_ERROR_NONE) {
  259. GELOGE(rt_ret, "rtMemcpy args failed, size = %zu, ret = %d", arg_size, static_cast<int>(rt_ret));
  260. return RT_ERROR_TO_GE_STATUS(rt_ret);
  261. }
  262. const domi::KernelContext &context = kernel_def_.context();
  263. const auto *args_offset_tmp = reinterpret_cast<const uint16_t *>(context.args_offset().data());
  264. uint16_t offset = *args_offset_tmp;
  265. bool is_dynamic = false;
  266. (void)AttrUtils::GetBool(op_desc_, kAttrSupportDynamicShape, is_dynamic);
  267. if (is_dynamic) {
  268. GE_CHK_STATUS_RET_NOLOG(InitTilingInfo(task));
  269. } else {
  270. // copy args
  271. std::vector<void *> tensor_device_addr_vec = BuildTaskUtils::GetKernelArgs(op_desc_, param);
  272. void *src_addr = reinterpret_cast<void *>(tensor_device_addr_vec.data());
  273. uint64_t src_len = sizeof(void *) * tensor_device_addr_vec.size();
  274. rt_ret = rtMemcpy(args.get() + offset, arg_size - offset, src_addr, src_len, RT_MEMCPY_HOST_TO_HOST);
  275. if (rt_ret != RT_ERROR_NONE) {
  276. GELOGE(rt_ret, "rtMemcpy addresses failed, ret = %d", static_cast<int>(rt_ret));
  277. return RT_ERROR_TO_GE_STATUS(rt_ret);
  278. }
  279. }
  280. task.SetKernelArgs(std::move(args), arg_size, kernel_def_.block_dim(), op_desc);
  281. return SUCCESS;
  282. }
  283. Status TbeTaskBuilder::SetKernelWithHandleArgs(TbeOpTask &task, const SingleOpModelParam &param,
  284. const OpDescPtr &op_desc) {
  285. size_t arg_size = kernel_def_with_handle_.args_size();
  286. auto args = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[arg_size]);
  287. GE_CHECK_NOTNULL(args);
  288. auto rt_ret = rtMemcpy(args.get(), arg_size, kernel_def_with_handle_.args().data(), arg_size, RT_MEMCPY_HOST_TO_HOST);
  289. if (rt_ret != RT_ERROR_NONE) {
  290. GELOGE(rt_ret, "rtMemcpy args failed, size = %zu, ret = %d", arg_size, static_cast<int>(rt_ret));
  291. return rt_ret;
  292. }
  293. const domi::KernelContext &context = kernel_def_with_handle_.context();
  294. const auto *args_offset_tmp = reinterpret_cast<const uint16_t *>(context.args_offset().data());
  295. uint16_t offset = *args_offset_tmp;
  296. bool is_dynamic = false;
  297. (void)AttrUtils::GetBool(op_desc_, kAttrSupportDynamicShape, is_dynamic);
  298. if (is_dynamic) {
  299. GE_CHK_STATUS_RET_NOLOG(InitTilingInfo(task));
  300. } else {
  301. // copy args
  302. std::vector<void *> tensor_device_addr_vec = BuildTaskUtils::GetKernelArgs(op_desc_, param);
  303. void *src_addr = reinterpret_cast<void *>(tensor_device_addr_vec.data());
  304. uint64_t src_len = sizeof(void *) * tensor_device_addr_vec.size();
  305. rt_ret = rtMemcpy(args.get() + offset, arg_size - offset, src_addr, src_len, RT_MEMCPY_HOST_TO_HOST);
  306. if (rt_ret != RT_ERROR_NONE) {
  307. GELOGE(rt_ret, "rtMemcpy addresses failed, ret = %d", static_cast<int>(rt_ret));
  308. return rt_ret;
  309. }
  310. }
  311. task.SetKernelWithHandleArgs(std::move(args), arg_size, kernel_def_with_handle_.block_dim(), op_desc,
  312. kernel_def_with_handle_);
  313. return SUCCESS;
  314. }
  315. Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam &param) {
  316. GELOGD("Build tbe task begin");
  317. auto task_type = static_cast<rtModelTaskType_t>(task_def_.type());
  318. auto ret = task_type == RT_MODEL_TASK_ALL_KERNEL ? SetKernelWithHandleArgs(task, param, op_desc_) :
  319. SetKernelArgs(task, param, op_desc_);
  320. if (ret != SUCCESS) {
  321. return ret;
  322. }
  323. ret = task_type == RT_MODEL_TASK_ALL_KERNEL ? RegisterKernelWithHandle(task, param) :
  324. RegisterKernel(task, param);
  325. task.SetHandle(handle_);
  326. if (ret != SUCCESS) {
  327. return ret;
  328. }
  329. auto task_info = BuildTaskUtils::GetTaskInfo(op_desc_);
  330. GELOGI("[TASK_INFO] %s %s", stub_name_.c_str(), task_info.c_str());
  331. if (task_type != RT_MODEL_TASK_ALL_KERNEL) {
  332. void *stub_func = nullptr;
  333. auto rt_ret = rtGetFunctionByName(stub_name_.c_str(), &stub_func);
  334. if (rt_ret != SUCCESS) {
  335. GELOGE(rt_ret, "rtGetFunctionByName failed.");
  336. return RT_ERROR_TO_GE_STATUS(rt_ret);
  337. }
  338. task.SetStubFunc(stub_name_, stub_func);
  339. }
  340. return SUCCESS;
  341. }
  342. Status TbeTaskBuilder::InitTilingInfo(TbeOpTask &task) {
  343. GELOGD("Start alloc tiling data of node %s.", op_desc_->GetName().c_str());
  344. int64_t max_size = -1;
  345. (void)AttrUtils::GetInt(op_desc_, kAttrOpParamSize, max_size);
  346. GELOGD("Got op param size by key: %s, ret = %ld", kAttrOpParamSize, max_size);
  347. if (max_size < 0) {
  348. GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[%s] Invalid op_param_size: %ld.", op_desc_->GetName().c_str(), max_size);
  349. return ACL_ERROR_GE_PARAM_INVALID;
  350. }
  351. void *tiling_buffer = nullptr;
  352. if (max_size > 0) {
  353. GE_CHK_RT_RET(rtMalloc(&tiling_buffer, static_cast<uint64_t>(max_size), RT_MEMORY_HBM));
  354. GE_CHECK_NOTNULL(tiling_buffer);
  355. GELOGD("[%s] Done allocating tiling buffer, size=%ld.", op_desc_->GetName().c_str(), max_size);
  356. }
  357. task.EnableDynamicSupport(node_, tiling_buffer, static_cast<size_t>(max_size));
  358. return SUCCESS;
  359. }
  360. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。