You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trans_var_data_utils.cc 23 kB

5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/manager/trans_var_data_utils.h"
  17. #include "common/debug/log.h"
  18. #include "common/debug/memory_dumper.h"
  19. #include "common/formats/formats.h"
  20. #include "common/formats/utils/formats_trans_utils.h"
  21. #include "common/op/ge_op_utils.h"
  22. #include "framework/common/debug/ge_log.h"
  23. #include "graph/manager/graph_var_manager.h"
  24. #include "graph/types.h"
  25. #include "graph/utils/type_utils.h"
  26. #include "common/thread_pool.h"
  27. #include <algorithm>
  28. namespace ge {
  29. namespace {
  30. class RtContextSwitchGuard {
  31. public:
  32. RtContextSwitchGuard(rtCtxMode_t mode, uint32_t device_id) : last_(nullptr), current_(nullptr) {
  33. auto ret = rtCtxGetCurrent(&last_);
  34. if (ret != RT_ERROR_NONE) {
  35. GELOGE(RT_FAILED, "Failed to get current context from rt, error-code %d", ret);
  36. return;
  37. }
  38. ret = rtCtxCreate(&current_, mode, static_cast<int32_t>(device_id));
  39. if (ret != RT_ERROR_NONE) {
  40. GELOGE(RT_FAILED, "Failed to create new context for device %u, error-code %d", device_id, ret);
  41. return;
  42. }
  43. ret = rtCtxSetCurrent(current_);
  44. if (ret != RT_ERROR_NONE) {
  45. GELOGE(RT_FAILED, "Failed to switch context to normal, context %p, device %u", current_, device_id);
  46. return;
  47. }
  48. GELOGD("Create and switch rt context %p type %d for device %u, backup last %p.", current_, mode, device_id, last_);
  49. }
  50. ~RtContextSwitchGuard() {
  51. if (current_ != nullptr) {
  52. auto ret = rtCtxDestroy(current_);
  53. GELOGD("Destory current context %p result %d", current_, ret);
  54. }
  55. if (last_ != nullptr) {
  56. auto ret = rtCtxSetCurrent(last_);
  57. GELOGD("Recovery last context %p result %d.", last_, ret);
  58. }
  59. }
  60. private:
  61. rtContext_t last_;
  62. rtContext_t current_;
  63. };
  64. int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) {
  65. int64_t var_size = GetSizeByDataType(desc.GetDataType());
  66. if (var_size <= 0) {
  67. GELOGE(PARAM_INVALID, "Failed to calc var data size from data type %s",
  68. TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str());
  69. return -1;
  70. }
  71. auto shape = desc.GetShape();
  72. auto dim_num = shape.GetDimNum();
  73. for (size_t dim_index = 0; dim_index < dim_num; ++dim_index) {
  74. var_size *= shape.GetDim(dim_index);
  75. }
  76. return var_size;
  77. }
  78. Status CopyVarToDevice(const NodePtr &var, const formats::TransResult &trans_result, void *var_addr) {
  79. GELOGD("Copy var %s from host to device, size %zu", var->GetName().c_str(), trans_result.length);
  80. auto ret = rtMemcpy(var_addr, trans_result.length, reinterpret_cast<void *>(trans_result.data.get()),
  81. trans_result.length, RT_MEMCPY_HOST_TO_DEVICE);
  82. if (ret != RT_ERROR_NONE) {
  83. GELOGE(RT_FAILED, "Failed to copy memory to device, size %zu", trans_result.length);
  84. return RT_FAILED;
  85. }
  86. return SUCCESS;
  87. }
  88. Status CopyVarFromDevice(uint64_t session_id, const NodePtr &var, std::unique_ptr<uint8_t[]> &var_data,
  89. const GeTensorDesc &input_desc) {
  90. uint8_t *var_logic = nullptr;
  91. GE_CHECK_NOTNULL(var);
  92. auto ret = VarManager::Instance(session_id)->GetVarAddr(var->GetName(), input_desc, &var_logic);
  93. if (ret != SUCCESS) {
  94. GELOGE(INTERNAL_ERROR,
  95. "Failed to copy var %s from device, can not find it"
  96. " from var manager %u",
  97. var->GetName().c_str(), ret);
  98. return INTERNAL_ERROR;
  99. }
  100. uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
  101. if (var_addr == nullptr) {
  102. GELOGE(INTERNAL_ERROR,
  103. "Failed to copy var %s from device, cant not get "
  104. "var addr from logic addr %p",
  105. var->GetName().c_str(), var_logic);
  106. return INTERNAL_ERROR;
  107. }
  108. int64_t var_size_bytes = CalcVarSizeInBytes(input_desc);
  109. if (var_size_bytes <= 0) {
  110. return INTERNAL_ERROR;
  111. }
  112. std::unique_ptr<uint8_t[]> var_host(new(std::nothrow) uint8_t[var_size_bytes]);
  113. if (var_host == nullptr) {
  114. GELOGE(OUT_OF_MEMORY, "Failed to malloc rt-host memory, size %ld", var_size_bytes);
  115. return OUT_OF_MEMORY;
  116. }
  117. ret = rtMemcpy(reinterpret_cast<void *>(var_host.get()), var_size_bytes, reinterpret_cast<void *>(var_addr),
  118. var_size_bytes, RT_MEMCPY_DEVICE_TO_HOST);
  119. if (ret != RT_ERROR_NONE) {
  120. GELOGE(RT_FAILED,
  121. "Failed to copy var memory from device, var %s, size %ld,"
  122. " rt-error-code %u",
  123. var->GetName().c_str(), var_size_bytes, ret);
  124. return RT_FAILED;
  125. }
  126. GELOGD("Copy var %s from device to host, size %ld", var->GetName().c_str(), var_size_bytes);
  127. var_data.swap(var_host);
  128. GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr);
  129. return SUCCESS;
  130. }
  131. Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats::TransResult &result) {
  132. formats::TransResult result_last_time{};
  133. bool use_init_data = true;
  134. for (const auto &trans_info : trans_road) {
  135. if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) {
  136. GELOGD("Skip to trans variable data on the reshape/reformat node");
  137. continue;
  138. }
  139. uint8_t *src_data = nullptr;
  140. if (use_init_data) {
  141. src_data = var_data;
  142. use_init_data = false;
  143. } else {
  144. src_data = result_last_time.data.get();
  145. }
  146. formats::TransResult tmp_result{};
  147. if (trans_info.node_type == TRANSDATA || trans_info.node_type == TRANSPOSED) {
  148. auto src_format = trans_info.input.GetFormat();
  149. auto src_shape = trans_info.input.GetShape().GetDims();
  150. auto dst_format = trans_info.output.GetFormat();
  151. auto dst_shape = trans_info.output.GetShape().GetDims();
  152. auto data_type = trans_info.input.GetDataType();
  153. GELOGD("Trans format from %s to %s, shape %s to %s, data-type %s",
  154. TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
  155. formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(),
  156. TypeUtils::DataTypeToSerialString(data_type).c_str());
  157. auto ret = formats::TransFormat({src_data, src_format, dst_format, src_shape, dst_shape, data_type}, tmp_result);
  158. if (ret != SUCCESS) {
  159. GELOGE(INTERNAL_ERROR,
  160. "Failed to trans format from %s to %s, shape %s to %s, "
  161. "data type %s error code %u",
  162. TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
  163. formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(),
  164. TypeUtils::DataTypeToSerialString(data_type).c_str(), ret);
  165. return ret;
  166. }
  167. } else if (trans_info.node_type == CAST) {
  168. auto input_shape = trans_info.input.GetShape();
  169. auto src_data_size = input_shape.GetShapeSize() == 0 ? 1 : input_shape.GetShapeSize();
  170. auto src_data_type = trans_info.input.GetDataType();
  171. auto dst_data_type = trans_info.output.GetDataType();
  172. GELOGD("Trans data type from %s to %s, input shape %s, data size %ld",
  173. TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
  174. TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
  175. src_data_size);
  176. auto ret = formats::TransDataType({src_data, static_cast<size_t>(src_data_size), src_data_type, dst_data_type},
  177. tmp_result);
  178. if (ret != SUCCESS) {
  179. GELOGE(INTERNAL_ERROR, "Failed to trans data type from %s to %s, input shape %s, data size %ld, error code %u",
  180. TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
  181. TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
  182. src_data_size, ret);
  183. return ret;
  184. }
  185. } else {
  186. GELOGE(UNSUPPORTED, "Failed to trans var data, the trans type %s does not supported",
  187. trans_info.node_type.c_str());
  188. return UNSUPPORTED;
  189. }
  190. result_last_time = tmp_result;
  191. }
  192. result = result_last_time;
  193. return SUCCESS;
  194. }
  195. /// re-alloc var memory on device using var-manager
  196. /// free origin var memory(var manager does not support now)
  197. /// @param session_id
  198. /// @param var
  199. /// @param var_size_bytes
  200. /// @param var_device
  201. /// @return
  202. Status ReAssignVarAddr(uint64_t session_id,
  203. const std::string &var_name,
  204. const GeTensorDesc &tensor_desc,
  205. void **var_device) {
  206. uint8_t *var_logic = nullptr;
  207. Status ret = VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, &var_logic);
  208. if (ret != SUCCESS) {
  209. GELOGE(INTERNAL_ERROR,
  210. "Failed to get var %s device addr, can not find it"
  211. " from var manager %u",
  212. var_name.c_str(), ret);
  213. return INTERNAL_ERROR;
  214. }
  215. uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
  216. if (var_addr == nullptr) {
  217. GELOGE(INTERNAL_ERROR, "Failed to convert var %s logic addr to real addr", var_name.c_str());
  218. return INTERNAL_ERROR;
  219. }
  220. *var_device = var_addr;
  221. GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr);
  222. return SUCCESS;
  223. }
  224. Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t session_id) {
  225. // do not need to do anything if only all reshape/reformat node on the trans_road
  226. GE_CHECK_NOTNULL(var);
  227. bool need_trans = false;
  228. for (auto &road : trans_road) {
  229. if (road.node_type != RESHAPE && road.node_type != REFORMAT) {
  230. need_trans = true;
  231. break;
  232. }
  233. }
  234. if (!need_trans) {
  235. return SUCCESS;
  236. }
  237. // Sync var data from device
  238. std::unique_ptr<uint8_t[]> var_data;
  239. if (trans_road.empty()) {
  240. GELOGE(INTERNAL_ERROR, "Failed to get trans_road, trans_road is empty.");
  241. return INTERNAL_ERROR;
  242. }
  243. const GeTensorDesc &input_desc = trans_road.begin()->input;
  244. auto ret = CopyVarFromDevice(session_id, var, var_data, input_desc);
  245. if (ret != SUCCESS) {
  246. return ret;
  247. }
  248. formats::TransResult trans_result{};
  249. ret = TransVarOnHost(var_data.get(), trans_road, trans_result);
  250. if (ret != SUCCESS) {
  251. GELOGE(ret, "Failed to trans var data on host, error code %u", ret);
  252. return ret;
  253. }
  254. void *var_device = nullptr;
  255. /// It is a temporary solution to use the last GeTensorDesc to assign variable memory because the variable manager
  256. /// depends on TensorDesc and it is difficult to be modified. The correct solution is to assign memory based on the
  257. /// size of the converted variable. To complete the final solution, the dependency of the variable manager on
  258. /// TensorDesc needs to be removed. This change is large and needs to be performed step by step.
  259. ret = ReAssignVarAddr(session_id, var->GetName(), trans_road.rbegin()->output, &var_device);
  260. if (ret != SUCCESS) {
  261. GELOGE(ret, "Failed to re-assign memory on device, size %zu", trans_result.length);
  262. return ret;
  263. }
  264. // sync new data to device
  265. ret = CopyVarToDevice(var, trans_result, var_device);
  266. if (ret != SUCCESS) {
  267. GELOGE(ret, "Failed to send var data to device");
  268. return ret;
  269. }
  270. return SUCCESS;
  271. }
  272. Status TransTensor(uint8_t *var_data, const NodePtr &var_src, const NodePtr &var_dst, formats::TransResult &result) {
  273. GE_CHECK_NOTNULL(var_src);
  274. GE_CHECK_NOTNULL(var_src->GetOpDesc());
  275. GE_CHECK_NOTNULL(var_dst);
  276. GE_CHECK_NOTNULL(var_dst->GetOpDesc());
  277. auto src_data_shape_size = var_src->GetOpDesc()->GetOutputDesc(0).GetShape().GetShapeSize();
  278. auto src_data_datatype = var_src->GetOpDesc()->GetOutputDesc(0).GetDataType();
  279. auto dst_data_datatype = var_dst->GetOpDesc()->GetOutputDesc(0).GetDataType();
  280. GE_IF_BOOL_EXEC(
  281. src_data_datatype != dst_data_datatype,
  282. auto ret = formats::TransDataType(
  283. {var_data, static_cast<size_t>(src_data_shape_size), src_data_datatype, dst_data_datatype}, result);
  284. if (ret != SUCCESS) {
  285. GELOGE(INTERNAL_ERROR, "trans var data on host failed");
  286. return ret;
  287. });
  288. return SUCCESS;
  289. }
  290. Status CopyTensorFromSrcVarNode(const NodePtr &var_src,
  291. const NodePtr &var_dst,
  292. uint64_t session_id,
  293. uint32_t device_id) {
  294. /// after FE fusion pass, input num of applymomentum op was changed, 0th input is var_fp32, 6th input is
  295. /// var_fp16(new).
  296. /// unlink edges between var_fp32 and "dst_node" (need fp16) of var_fp32, add edge between var_fp16 and dst_node.
  297. /// need copy value from var_fp32 to var_fp16.
  298. /// [opdesc of var_src and var_dst are checked before passed in, no need to check if they are nullptr]
  299. GE_IF_BOOL_EXEC(var_src == nullptr || var_dst == nullptr, GELOGE(FAILED, "node var is nullptr"); return FAILED);
  300. // src_node output_desc (fp32)
  301. GeTensorDesc output_desc = var_src->GetOpDesc()->GetOutputDesc(0);
  302. auto src_data_type = output_desc.GetDataType();
  303. auto src_shape = output_desc.GetShape();
  304. auto src_format = output_desc.GetFormat();
  305. GELOGI("src_node %s, src_format %s, src_shape %s, src_type %s", var_src->GetName().c_str(),
  306. TypeUtils::FormatToSerialString(src_format).c_str(), formats::ShapeToString(src_shape).c_str(),
  307. TypeUtils::DataTypeToSerialString(src_data_type).c_str());
  308. // dst_node output_desc (fp16)
  309. GeTensorDesc dst_tensor_desc = var_dst->GetOpDesc()->GetOutputDesc(0);
  310. auto data_type = dst_tensor_desc.GetDataType();
  311. auto data_shape = dst_tensor_desc.GetShape();
  312. auto data_format = dst_tensor_desc.GetFormat();
  313. GELOGI("dst_node %s, src_format %s, src_shape %s, src_type %s", var_dst->GetName().c_str(),
  314. TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(data_shape).c_str(),
  315. TypeUtils::DataTypeToSerialString(data_type).c_str());
  316. // Sync var data from device
  317. std::unique_ptr<uint8_t[]> var_src_data;
  318. RtContextSwitchGuard switch_context(RT_CTX_NORMAL_MODE, device_id);
  319. // copy from src_node
  320. auto ret = CopyVarFromDevice(session_id, var_src, var_src_data, output_desc);
  321. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "Copy Var From Device failed"); return ret);
  322. // trans dtype
  323. formats::TransResult trans_result{};
  324. ret = TransTensor(var_src_data.get(), var_src, var_dst, trans_result);
  325. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "trans var data on host failed"); return ret);
  326. // reset src value.
  327. void *var_device = nullptr;
  328. ret = ReAssignVarAddr(session_id, var_dst->GetName(), dst_tensor_desc, &var_device);
  329. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "assign mem failed"); return ret);
  330. // copy to device
  331. ret = CopyVarToDevice(var_dst, trans_result, var_device);
  332. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Failed to send var data to device"); return ret);
  333. return SUCCESS;
  334. }
  335. } // namespace
  336. Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
  337. uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) {
  338. GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "dst addr is null. ");
  339. uint8_t *src_host_addr = nullptr;
  340. int64_t src_addr_size = 0;
  341. GE_MAKE_GUARD_RTMEM(src_host_addr);
  342. GE_CHK_STATUS_RET(SyncTensorToHost(var_name, src_tensor_desc, &src_host_addr, src_addr_size, session_id));
  343. GELOGI("src_addr_size: %ld, dst_addr_size: %ld", src_addr_size, dst_addr_size);
  344. GE_CHK_BOOL_RET_STATUS(src_addr_size == dst_addr_size, FAILED, "var data size is not equal broadcast ");
  345. GE_CHK_RT_RET(rtMemcpy(dst_addr, dst_addr_size, src_host_addr, src_addr_size, RT_MEMCPY_HOST_TO_DEVICE));
  346. return SUCCESS;
  347. }
  348. Status TransVarDataUtils::SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name,
  349. const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) {
  350. GE_CHK_BOOL_RET_STATUS(src_addr != nullptr, FAILED, "src addr is null. ");
  351. uint8_t *host_addr = nullptr;
  352. GE_MAKE_GUARD_RTMEM(host_addr);
  353. GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(&host_addr), src_addr_size));
  354. GE_CHK_RT_RET(rtMemcpy(host_addr, src_addr_size, src_addr, src_addr_size, RT_MEMCPY_DEVICE_TO_HOST));
  355. GE_CHK_STATUS_RET(
  356. SyncTensorToDevice(var_name, reinterpret_cast<uint8_t *>(host_addr), src_addr_size, dst_tensor_desc, session_id));
  357. return SUCCESS;
  358. }
  359. Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc,
  360. uint8_t **host_addr, int64_t &src_tensor_size, uint64_t session_id) {
  361. GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(src_tensor_desc, src_tensor_size), "get size from TensorDesc failed");
  362. uint8_t *src_addr = nullptr;
  363. GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr));
  364. uint8_t *mem_addr =
  365. src_addr -
  366. static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) +
  367. static_cast<int64_t>(
  368. reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM)));
  369. GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(host_addr), src_tensor_size));
  370. GE_CHK_RT_RET(rtMemcpy(*host_addr, src_tensor_size, mem_addr, src_tensor_size, RT_MEMCPY_DEVICE_TO_HOST));
  371. GELOGI("SyncTensorToHost var_name %s, src_tensor_size %ld", var_name.c_str(), src_tensor_size);
  372. return SUCCESS;
  373. }
  374. Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8_t *host_addr, uint32_t addr_size,
  375. const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) {
  376. uint8_t *dst_addr = nullptr;
  377. GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, dst_tensor_desc, &dst_addr));
  378. uint8_t *mem_addr =
  379. dst_addr -
  380. static_cast<int64_t>(static_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemLogicBase())) +
  381. static_cast<int64_t>(
  382. reinterpret_cast<uintptr_t>(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM)));
  383. GE_CHK_RT_RET(rtMemcpy(mem_addr, addr_size, host_addr, addr_size, RT_MEMCPY_HOST_TO_DEVICE));
  384. GELOGI("SyncTensorToDevice var_name %s, addr_size %u", var_name.c_str(), addr_size);
  385. return SUCCESS;
  386. }
  387. Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes,
  388. uint64_t session_id,
  389. rtContext_t context,
  390. uint32_t graph_id,
  391. uint32_t thread_num) {
  392. ThreadPool executor(thread_num);
  393. std::vector<std::future<Status>> vector_future;
  394. for (auto &node : variable_nodes) {
  395. if (node == nullptr) {
  396. continue;
  397. }
  398. if (node->GetType() != VARIABLE) {
  399. continue;
  400. }
  401. std::future<Status> f = executor.commit(
  402. [](const ge::NodePtr &node, uint64_t session_id, rtContext_t ctx, uint32_t graph_id) -> Status {
  403. rtError_t rt_ret = rtCtxSetCurrent(ctx);
  404. if (rt_ret != RT_ERROR_NONE) {
  405. GELOGE(RT_FAILED, "Failed to set context, error_code is: 0x%X.", rt_ret);
  406. return RT_ERROR_TO_GE_STATUS(rt_ret);
  407. }
  408. uint32_t allocated_graph_id = 0;
  409. Status ret = VarManager::Instance(session_id)->GetAllocatedGraphId(node->GetName(), allocated_graph_id);
  410. if (ret != SUCCESS) {
  411. GELOGE(INTERNAL_ERROR, "var has not been allocated, node:%s, graph_id:%u.", node->GetName().c_str(),
  412. graph_id);
  413. return INTERNAL_ERROR;
  414. }
  415. uint32_t changed_graph_id = 0;
  416. ret = VarManager::Instance(session_id)->GetChangedGraphId(node->GetName(), changed_graph_id);
  417. bool call_trans_var =
  418. (ret == SUCCESS && changed_graph_id == graph_id && changed_graph_id != allocated_graph_id);
  419. if (call_trans_var) {
  420. GELOGI("VarManager::GetChangedGraphId() success, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id);
  421. VarTransRoad *trans_road = VarManager::Instance(session_id)->GetTransRoad(node->GetName());
  422. if (trans_road == nullptr) {
  423. GELOGI("The variable %s does not have any trans road", node->GetName().c_str());
  424. return SUCCESS;
  425. }
  426. ret = TransVarData(node, *trans_road, session_id);
  427. if (ret != SUCCESS) {
  428. GELOGE(INTERNAL_ERROR, "TransVarData failed, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id);
  429. return INTERNAL_ERROR;
  430. }
  431. VarManager::Instance(session_id)->RemoveChangedGraphId(node->GetName());
  432. }
  433. return SUCCESS;
  434. },
  435. node, session_id, context, graph_id);
  436. if (!f.valid()) {
  437. GELOGE(FAILED, "Future is invalid");
  438. return FAILED;
  439. }
  440. vector_future.push_back(std::move(f));
  441. }
  442. Status ret_status;
  443. for (size_t i = 0; i < vector_future.size(); ++i) {
  444. ret_status = vector_future[i].get();
  445. if (ret_status != SUCCESS) {
  446. GELOGE(ret_status, "TransAllVarData:: trans %zu vardata failed", i);
  447. return ret_status;
  448. }
  449. }
  450. return SUCCESS;
  451. }
  452. Status TransVarDataUtils::CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id) {
  453. GELOGD("CopyVarData start: session_id:%lu.", session_id);
  454. if (compute_graph == nullptr) {
  455. GELOGE(FAILED, "compute_graph is nullptr");
  456. return FAILED;
  457. }
  458. string cp_from_node;
  459. bool copy_value = false;
  460. for (auto &node : compute_graph->GetAllNodes()) {
  461. GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() != VARIABLE, continue);
  462. GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), "_copy_from_var_node", cp_from_node),
  463. GELOGI("Get original type of cp_from_node"));
  464. if (cp_from_node.length() != 0) {
  465. (void) ge::AttrUtils::GetBool(node->GetOpDesc(), "_copy_value", copy_value); // no need to check value
  466. if (!copy_value) {
  467. auto src_node = compute_graph->FindNode(cp_from_node);
  468. GE_CHECK_NOTNULL(src_node);
  469. GELOGI("current_var_node__: [%s] copy_from_var_node__: [%s].", node->GetName().c_str(),
  470. src_node->GetName().c_str());
  471. auto ret = CopyTensorFromSrcVarNode(src_node, node, session_id, device_id);
  472. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "copy tensor failed!"); return FAILED);
  473. // only copy once
  474. (void) ge::AttrUtils::SetBool(node->GetOpDesc(), "_copy_value", true); // no need to check value
  475. }
  476. }
  477. }
  478. return SUCCESS;
  479. }
  480. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示