You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trans_var_data_utils.cc 25 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/manager/trans_var_data_utils.h"
  17. #include "framework/common/debug/log.h"
  18. #include "common/debug/memory_dumper.h"
  19. #include "common/formats/formats.h"
  20. #include "common/formats/utils/formats_trans_utils.h"
  21. #include "framework/common/op/ge_op_utils.h"
  22. #include "framework/common/debug/ge_log.h"
  23. #include "graph/manager/graph_var_manager.h"
  24. #include "external/graph/types.h"
  25. #include "graph/utils/type_utils.h"
  26. #include "common/thread_pool.h"
  27. #include <algorithm>
  28. namespace ge {
  29. namespace {
  30. class RtContextSwitchGuard {
  31. public:
  32. RtContextSwitchGuard(rtCtxMode_t mode, uint32_t device_id) : last_(nullptr), current_(nullptr) {
  33. auto ret = rtCtxGetCurrent(&last_);
  34. if (ret != RT_ERROR_NONE) {
  35. REPORT_CALL_ERROR("E19999", "Call rtCtxGetCurrent failed, device_id:%u, ret:0x%X,",
  36. device_id, ret);
  37. GELOGE(RT_FAILED, "[Call][RtCtxGetCurrent] Failed to get current context, device_id:%u, ret:0x%X",
  38. device_id, ret);
  39. return;
  40. }
  41. ret = rtCtxCreate(&current_, mode, static_cast<int32_t>(device_id));
  42. if (ret != RT_ERROR_NONE) {
  43. REPORT_CALL_ERROR("E19999", "Call rtCtxCreate failed, device_id:%u, ret:0x%X,",
  44. device_id, ret);
  45. GELOGE(RT_FAILED, "[Call][RtCtxCreate] Failed to create new context for device:%u, ret:%d", device_id, ret);
  46. return;
  47. }
  48. ret = rtCtxSetCurrent(current_);
  49. if (ret != RT_ERROR_NONE) {
  50. REPORT_CALL_ERROR("E19999", "Call rtCtxSetCurrent failed, device_id:%u, ret:0x%X", device_id, ret);
  51. GELOGE(RT_FAILED, "[Call][RtCtxSetCurrent] failed, device_id:%u, ret:0x%X", device_id, ret);
  52. return;
  53. }
  54. GELOGD("Create and switch rt context %p type %d for device %u, backup last %p.", current_, mode, device_id, last_);
  55. }
  56. ~RtContextSwitchGuard() {
  57. if (current_ != nullptr) {
  58. auto ret = rtCtxDestroy(current_);
  59. GELOGD("Destory current context %p result %d", current_, ret);
  60. }
  61. if (last_ != nullptr) {
  62. auto ret = rtCtxSetCurrent(last_);
  63. GELOGD("Recovery last context %p result %d.", last_, ret);
  64. }
  65. }
  66. private:
  67. rtContext_t last_;
  68. rtContext_t current_;
  69. };
  70. int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) {
  71. int64_t var_size = GetSizeByDataType(desc.GetDataType());
  72. if (var_size <= 0) {
  73. REPORT_INNER_ERROR("E19999", "Data type:%s in desc, it's size:%ld < 0, check invalid",
  74. TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str(), var_size);
  75. GELOGE(PARAM_INVALID, "[Calc][VarDataSize] by data type %s failed.",
  76. TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str());
  77. return -1;
  78. }
  79. auto shape = desc.GetShape();
  80. auto dim_num = shape.GetDimNum();
  81. for (size_t dim_index = 0; dim_index < dim_num; ++dim_index) {
  82. var_size *= shape.GetDim(dim_index);
  83. }
  84. return var_size;
  85. }
  86. Status CopyVarToDevice(const NodePtr &var, const formats::TransResult &trans_result, void *var_addr) {
  87. GELOGD("Copy var %s from host to device, size %zu", var->GetName().c_str(), trans_result.length);
  88. auto ret = rtMemcpy(var_addr, trans_result.length, reinterpret_cast<void *>(trans_result.data.get()),
  89. trans_result.length, RT_MEMCPY_HOST_TO_DEVICE);
  90. if (ret != RT_ERROR_NONE) {
  91. REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, op:%s(%s), size:%lu, ret:0x%X,", var->GetName().c_str(),
  92. var->GetType().c_str(), trans_result.length, ret);
  93. GELOGE(RT_FAILED, "[Call][RtMemcpy] failed, op:%s(%s), size:%lu, ret:0x%X,", var->GetName().c_str(),
  94. var->GetType().c_str(), trans_result.length, ret);
  95. return RT_FAILED;
  96. }
  97. return SUCCESS;
  98. }
  99. Status CopyVarFromDevice(uint64_t session_id, const NodePtr &var, std::unique_ptr<uint8_t[]> &var_data,
  100. const GeTensorDesc &input_desc) {
  101. uint8_t *var_logic = nullptr;
  102. GE_CHECK_NOTNULL(var);
  103. auto ret = VarManager::Instance(session_id)->GetVarAddr(var->GetName(), input_desc, &var_logic);
  104. if (ret != SUCCESS) {
  105. GELOGE(INTERNAL_ERROR, "[Get][VarAddr] failed, node:%s, session_id:%lu, ret:%d",
  106. var->GetName().c_str(), session_id, ret);
  107. return INTERNAL_ERROR;
  108. }
  109. uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
  110. if (var_addr == nullptr) {
  111. REPORT_CALL_ERROR("E19999", "Get variable memory addr failed, mem_type:%d, op:%s(%s), session_id:%lu",
  112. RT_MEMORY_HBM, var->GetName().c_str(), var->GetType().c_str(), session_id);
  113. GELOGE(INTERNAL_ERROR, "[Get][VarMemoryAddr] failed, mem_type:%d, op:%s(%s), session_id:%lu",
  114. RT_MEMORY_HBM, var->GetName().c_str(), var->GetType().c_str(), session_id);
  115. return INTERNAL_ERROR;
  116. }
  117. int64_t var_size_bytes = CalcVarSizeInBytes(input_desc);
  118. if (var_size_bytes <= 0) {
  119. return INTERNAL_ERROR;
  120. }
  121. std::unique_ptr<uint8_t[]> var_host(new(std::nothrow) uint8_t[var_size_bytes]);
  122. if (var_host == nullptr) {
  123. REPORT_CALL_ERROR("E19999", "New host memory failed, size:%ld, op:%s(%s), session_id:%lu",
  124. var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id);
  125. GELOGE(OUT_OF_MEMORY, "[New][Memory] for rt-host failed, size:%ld, op:%s(%s), session_id:%lu",
  126. var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id);
  127. return OUT_OF_MEMORY;
  128. }
  129. ret = rtMemcpy(reinterpret_cast<void *>(var_host.get()), var_size_bytes, reinterpret_cast<void *>(var_addr),
  130. var_size_bytes, RT_MEMCPY_DEVICE_TO_HOST);
  131. if (ret != RT_ERROR_NONE) {
  132. REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, size:%ld, op:%s(%s), session_id:%lu, ret:0x%X",
  133. var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id, ret);
  134. GELOGE(RT_FAILED, "[Call][RtMemcpy] failed, size:%ld, op:%s(%s), session_id:%lu, ret:0x%X",
  135. var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id, ret);
  136. return RT_FAILED;
  137. }
  138. GELOGD("Copy var %s from device to host, size %ld", var->GetName().c_str(), var_size_bytes);
  139. var_data.swap(var_host);
  140. GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr);
  141. return SUCCESS;
  142. }
  143. Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats::TransResult &result) {
  144. formats::TransResult result_last_time{};
  145. bool use_init_data = true;
  146. for (const auto &trans_info : trans_road) {
  147. if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) {
  148. GELOGD("Skip to trans variable data on the reshape/reformat node");
  149. continue;
  150. }
  151. uint8_t *src_data = nullptr;
  152. if (use_init_data) {
  153. src_data = var_data;
  154. use_init_data = false;
  155. } else {
  156. src_data = result_last_time.data.get();
  157. }
  158. formats::TransResult tmp_result{};
  159. if (trans_info.node_type == TRANSDATA || trans_info.node_type == TRANSPOSED) {
  160. auto src_format = trans_info.input.GetFormat();
  161. auto src_shape = trans_info.input.GetShape().GetDims();
  162. auto dst_format = trans_info.output.GetFormat();
  163. auto dst_shape = trans_info.output.GetShape().GetDims();
  164. auto data_type = trans_info.input.GetDataType();
  165. GELOGD("Trans format from %s to %s, shape %s to %s, data-type %s",
  166. TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
  167. formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(),
  168. TypeUtils::DataTypeToSerialString(data_type).c_str());
  169. auto ret = formats::TransFormat({src_data, src_format, dst_format, src_shape, dst_shape, data_type}, tmp_result);
  170. if (ret != SUCCESS) {
  171. REPORT_CALL_ERROR("E19999", "Trans format from %s to %s, shape %s to %s failed, data type:%s, ret:%u,",
  172. TypeUtils::FormatToSerialString(src_format).c_str(),
  173. TypeUtils::FormatToSerialString(dst_format).c_str(),
  174. formats::ShapeToString(src_shape).c_str(),
  175. formats::ShapeToString(dst_shape).c_str(),
  176. TypeUtils::DataTypeToSerialString(data_type).c_str(), ret);
  177. GELOGE(INTERNAL_ERROR, "[Trans][Format] from %s to %s, shape %s to %s failed, data type %s error code %u",
  178. TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
  179. formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(),
  180. TypeUtils::DataTypeToSerialString(data_type).c_str(), ret);
  181. return ret;
  182. }
  183. } else if (trans_info.node_type == CAST) {
  184. auto input_shape = trans_info.input.GetShape();
  185. auto src_data_size = input_shape.GetShapeSize() == 0 ? 1 : input_shape.GetShapeSize();
  186. auto src_data_type = trans_info.input.GetDataType();
  187. auto dst_data_type = trans_info.output.GetDataType();
  188. GELOGD("Trans data type from %s to %s, input shape %s, data size %ld",
  189. TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
  190. TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
  191. src_data_size);
  192. auto ret = formats::TransDataType({src_data, static_cast<size_t>(src_data_size), src_data_type, dst_data_type},
  193. tmp_result);
  194. if (ret != SUCCESS) {
  195. REPORT_CALL_ERROR("E19999", "Trans data type from %s to %s failed, input shape %s, data size %ld, ret:%u",
  196. TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
  197. TypeUtils::DataTypeToSerialString(dst_data_type).c_str(),
  198. formats::ShapeToString(input_shape).c_str(), src_data_size, ret);
  199. GELOGE(INTERNAL_ERROR, "[Trans][DataType] from %s to %s failed, input shape %s, data size %ld, error code %u",
  200. TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
  201. TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
  202. src_data_size, ret);
  203. return ret;
  204. }
  205. } else {
  206. REPORT_INNER_ERROR("E19999", "Trans var data failed, the trans type %s does not supported, check invalid",
  207. trans_info.node_type.c_str());
  208. GELOGE(UNSUPPORTED, "[Trans][VarData] failed, the trans type %s does not supported",
  209. trans_info.node_type.c_str());
  210. return UNSUPPORTED;
  211. }
  212. result_last_time = tmp_result;
  213. }
  214. result = result_last_time;
  215. return SUCCESS;
  216. }
  217. /// re-alloc var memory on device using var-manager
  218. /// free origin var memory(var manager does not support now)
  219. /// @param session_id
  220. /// @param var
  221. /// @param var_size_bytes
  222. /// @param var_device
  223. /// @return
  224. Status ReAssignVarAddr(uint64_t session_id,
  225. const std::string &var_name,
  226. const GeTensorDesc &tensor_desc,
  227. void **var_device) {
  228. uint8_t *var_logic = nullptr;
  229. Status ret = VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, &var_logic);
  230. if (ret != SUCCESS) {
  231. GELOGE(INTERNAL_ERROR, "[Get][VarAddr] failed, var name:%s, session_id:%lu, ret:%u",
  232. var_name.c_str(), session_id, ret);
  233. return INTERNAL_ERROR;
  234. }
  235. uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
  236. if (var_addr == nullptr) {
  237. REPORT_CALL_ERROR("E19999", "Get variable memory addr failed, mem_type:%d, var_name:%s, session_id:%lu,",
  238. RT_MEMORY_HBM, var_name.c_str(), session_id);
  239. GELOGE(INTERNAL_ERROR, "[Get][VarMemoryAddr] failed, mem_type:%d, var_name:%s, session_id:%lu",
  240. RT_MEMORY_HBM, var_name.c_str(), session_id);
  241. return INTERNAL_ERROR;
  242. }
  243. *var_device = var_addr;
  244. GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr);
  245. return SUCCESS;
  246. }
  247. Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t session_id) {
  248. // do not need to do anything if only all reshape/reformat node on the trans_road
  249. GE_CHECK_NOTNULL(var);
  250. bool need_trans = false;
  251. for (auto &road : trans_road) {
  252. if (road.node_type != RESHAPE && road.node_type != REFORMAT) {
  253. need_trans = true;
  254. break;
  255. }
  256. }
  257. if (!need_trans) {
  258. return SUCCESS;
  259. }
  260. // Sync var data from device
  261. std::unique_ptr<uint8_t[]> var_data;
  262. if (trans_road.empty()) {
  263. REPORT_INNER_ERROR("E19999", "Param trans_road is empty, session_id:%lu, check invalid", session_id);
  264. GELOGE(INTERNAL_ERROR, "[Check][Param] trans_road is empty, session_id:%lu", session_id);
  265. return INTERNAL_ERROR;
  266. }
  267. const GeTensorDesc &input_desc = trans_road.begin()->input;
  268. auto ret = CopyVarFromDevice(session_id, var, var_data, input_desc);
  269. if (ret != SUCCESS) {
  270. return ret;
  271. }
  272. formats::TransResult trans_result{};
  273. ret = TransVarOnHost(var_data.get(), trans_road, trans_result);
  274. if (ret != SUCCESS) {
  275. GELOGE(ret, "[Call][TransVarOnHost] failed, session_id:%lu, ret:%u", session_id, ret);
  276. return ret;
  277. }
  278. void *var_device = nullptr;
  279. /// It is a temporary solution to use the last GeTensorDesc to assign variable memory because the variable manager
  280. /// depends on TensorDesc and it is difficult to be modified. The correct solution is to assign memory based on the
  281. /// size of the converted variable. To complete the final solution, the dependency of the variable manager on
  282. /// TensorDesc needs to be removed. This change is large and needs to be performed step by step.
  283. ret = ReAssignVarAddr(session_id, var->GetName(), trans_road.rbegin()->output, &var_device);
  284. if (ret != SUCCESS) {
  285. GELOGE(ret, "[Call][ReAssignVarAddr] failed, session id:%lu, op:%s, ret:%u",
  286. session_id, var->GetName().c_str(), ret);
  287. return ret;
  288. }
  289. // sync new data to device
  290. ret = CopyVarToDevice(var, trans_result, var_device);
  291. if (ret != SUCCESS) {
  292. GELOGE(ret, "[Call][CopyVarToDevice] failed, var:%s, ret:%u", var->GetName().c_str(), ret);
  293. return ret;
  294. }
  295. return SUCCESS;
  296. }
  297. Status TransTensor(uint8_t *var_data, const NodePtr &var_src, const NodePtr &var_dst, formats::TransResult &result) {
  298. GE_CHECK_NOTNULL(var_src);
  299. GE_CHECK_NOTNULL(var_src->GetOpDesc());
  300. GE_CHECK_NOTNULL(var_dst);
  301. GE_CHECK_NOTNULL(var_dst->GetOpDesc());
  302. auto src_data_shape_size = var_src->GetOpDesc()->GetOutputDesc(0).GetShape().GetShapeSize();
  303. auto src_data_datatype = var_src->GetOpDesc()->GetOutputDesc(0).GetDataType();
  304. auto dst_data_datatype = var_dst->GetOpDesc()->GetOutputDesc(0).GetDataType();
  305. GE_IF_BOOL_EXEC(
  306. src_data_datatype != dst_data_datatype,
  307. auto ret = formats::TransDataType(
  308. {var_data, static_cast<size_t>(src_data_shape_size), src_data_datatype, dst_data_datatype}, result);
  309. if (ret != SUCCESS) {
  310. REPORT_CALL_ERROR("E19999", "Trans data type from %s to %s failed, data size %ld, ret:%u",
  311. TypeUtils::DataTypeToSerialString(src_data_datatype).c_str(),
  312. TypeUtils::DataTypeToSerialString(dst_data_datatype).c_str(),
  313. src_data_shape_size, ret);
  314. GELOGE(INTERNAL_ERROR, "[Trans][DataType] from %s to %s failed, data size %ld, ret:%u",
  315. TypeUtils::DataTypeToSerialString(src_data_datatype).c_str(),
  316. TypeUtils::DataTypeToSerialString(dst_data_datatype).c_str(),
  317. src_data_shape_size, ret);
  318. return ret;
  319. });
  320. return SUCCESS;
  321. }
  322. Status CopyTensorFromSrcVarNode(const NodePtr &var_src,
  323. const NodePtr &var_dst,
  324. uint64_t session_id,
  325. uint32_t device_id) {
  326. /// after FE fusion pass, input num of applymomentum op was changed, 0th input is var_fp32, 6th input is
  327. /// var_fp16(new).
  328. /// unlink edges between var_fp32 and "dst_node" (need fp16) of var_fp32, add edge between var_fp16 and dst_node.
  329. /// need copy value from var_fp32 to var_fp16.
  330. /// [opdesc of var_src and var_dst are checked before passed in, no need to check if they are nullptr]
  331. GE_IF_BOOL_EXEC(var_src == nullptr || var_dst == nullptr,
  332. REPORT_INNER_ERROR("E19999", "Param var_src or var_dst is nullptr, session_id:%lu, device_id:%u, "
  333. "check invalid", session_id, device_id);
  334. GELOGE(FAILED, "[Check][Param] Param var_src or var_dst is nullptr, session_id:%lu, device_id:%u",
  335. session_id, device_id);
  336. return FAILED);
  337. // src_node output_desc (fp32)
  338. GeTensorDesc output_desc = var_src->GetOpDesc()->GetOutputDesc(0);
  339. auto src_data_type = output_desc.GetDataType();
  340. auto src_shape = output_desc.GetShape();
  341. auto src_format = output_desc.GetFormat();
  342. GELOGI("src_node %s, src_format %s, src_shape %s, src_type %s", var_src->GetName().c_str(),
  343. TypeUtils::FormatToSerialString(src_format).c_str(), formats::ShapeToString(src_shape).c_str(),
  344. TypeUtils::DataTypeToSerialString(src_data_type).c_str());
  345. // dst_node output_desc (fp16)
  346. GeTensorDesc dst_tensor_desc = var_dst->GetOpDesc()->GetOutputDesc(0);
  347. auto data_type = dst_tensor_desc.GetDataType();
  348. auto data_shape = dst_tensor_desc.GetShape();
  349. auto data_format = dst_tensor_desc.GetFormat();
  350. GELOGI("dst_node %s, src_format %s, src_shape %s, src_type %s", var_dst->GetName().c_str(),
  351. TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(data_shape).c_str(),
  352. TypeUtils::DataTypeToSerialString(data_type).c_str());
  353. // Sync var data from device
  354. std::unique_ptr<uint8_t[]> var_src_data;
  355. RtContextSwitchGuard switch_context(RT_CTX_NORMAL_MODE, device_id);
  356. // copy from src_node
  357. auto ret = CopyVarFromDevice(session_id, var_src, var_src_data, output_desc);
  358. GE_IF_BOOL_EXEC(ret != SUCCESS,
  359. GELOGE(FAILED, "[Call][CopyVarFromDevice] failed, session id:%lu, var_src:%s",
  360. session_id, var_src->GetName().c_str());
  361. return ret);
  362. // trans dtype
  363. formats::TransResult trans_result{};
  364. ret = TransTensor(var_src_data.get(), var_src, var_dst, trans_result);
  365. GE_IF_BOOL_EXEC(ret != SUCCESS,
  366. GELOGE(INTERNAL_ERROR, "[Trans][Tensor] failed, var_src:%s, var_dst:%s",
  367. var_src->GetName().c_str(), var_dst->GetName().c_str());
  368. return ret);
  369. // reset src value.
  370. void *var_device = nullptr;
  371. ret = ReAssignVarAddr(session_id, var_dst->GetName(), dst_tensor_desc, &var_device);
  372. GE_IF_BOOL_EXEC(ret != SUCCESS,
  373. GELOGE(INTERNAL_ERROR, "[Call][ReAssignVarAddr] failed, session id:%lu, var_dst:%s",
  374. session_id, var_dst->GetName().c_str());
  375. return ret);
  376. // copy to device
  377. ret = CopyVarToDevice(var_dst, trans_result, var_device);
  378. GE_IF_BOOL_EXEC(ret != SUCCESS,
  379. GELOGE(ret, "[Call][CopyVarToDevice] failed, var_dst:%s, ret:%u",
  380. var_dst->GetName().c_str(), ret);
  381. return ret);
  382. return SUCCESS;
  383. }
  384. } // namespace
  385. Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes,
  386. uint64_t session_id,
  387. rtContext_t context,
  388. uint32_t graph_id,
  389. uint32_t thread_num) {
  390. ThreadPool executor(thread_num);
  391. std::vector<std::future<Status>> vector_future;
  392. for (auto &node : variable_nodes) {
  393. if (node == nullptr) {
  394. continue;
  395. }
  396. if (node->GetType() != VARIABLE) {
  397. continue;
  398. }
  399. std::future<Status> f = executor.commit(
  400. [](const ge::NodePtr &node, uint64_t session_id, rtContext_t ctx, uint32_t graph_id,
  401. const struct error_message::Context &error_context) -> Status {
  402. ErrorManager::GetInstance().SetErrorContext(error_context);
  403. rtError_t rt_ret = rtCtxSetCurrent(ctx);
  404. if (rt_ret != RT_ERROR_NONE) {
  405. REPORT_CALL_ERROR("E19999", "Call rtCtxSetCurrent failed, session_id:%lu, graph_id:%u, ret:0x%X,",
  406. session_id, graph_id, rt_ret);
  407. GELOGE(RT_FAILED, "[Call][RtCtxSetCurrent] failed, session_id:%lu, graph_id:%u, ret:0x%X,",
  408. session_id, graph_id, rt_ret);
  409. return RT_ERROR_TO_GE_STATUS(rt_ret);
  410. }
  411. uint32_t allocated_graph_id = 0;
  412. Status ret = VarManager::Instance(session_id)->GetAllocatedGraphId(node->GetName(), allocated_graph_id);
  413. if (ret != SUCCESS) {
  414. REPORT_CALL_ERROR("E19999", "Get allocated GraphId failed, session_id:%lu, graph_id:%u, ret:0x%X,",
  415. session_id, graph_id, ret);
  416. GELOGE(INTERNAL_ERROR, "[Get][AllocatedGraphId] failed, node:%s, graph_id:%u.",
  417. node->GetName().c_str(), graph_id);
  418. return INTERNAL_ERROR;
  419. }
  420. uint32_t changed_graph_id = 0;
  421. ret = VarManager::Instance(session_id)->GetChangedGraphId(node->GetName(), changed_graph_id);
  422. bool call_trans_var =
  423. (ret == SUCCESS && changed_graph_id == graph_id && changed_graph_id != allocated_graph_id);
  424. if (call_trans_var) {
  425. GELOGI("VarManager::GetChangedGraphId() success, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id);
  426. VarTransRoad *trans_road = VarManager::Instance(session_id)->GetTransRoad(node->GetName());
  427. if (trans_road == nullptr) {
  428. GELOGI("The variable %s does not have any trans road", node->GetName().c_str());
  429. return SUCCESS;
  430. }
  431. ret = TransVarData(node, *trans_road, session_id);
  432. if (ret != SUCCESS) {
  433. GELOGE(INTERNAL_ERROR, "[Trans][VarData] failed, node:%s, graph_id:%u, session_id:%lu.",
  434. node->GetName().c_str(), graph_id, session_id);
  435. return INTERNAL_ERROR;
  436. }
  437. VarManager::Instance(session_id)->RemoveChangedGraphId(node->GetName());
  438. }
  439. return SUCCESS;
  440. },
  441. node, session_id, context, graph_id, ErrorManager::GetInstance().GetErrorManagerContext());
  442. if (!f.valid()) {
  443. GELOGE(FAILED, "[Check][Param] Future is invalid, session id:%lu, graph id:%u", session_id, graph_id);
  444. return FAILED;
  445. }
  446. vector_future.push_back(std::move(f));
  447. }
  448. Status ret_status;
  449. for (size_t i = 0; i < vector_future.size(); ++i) {
  450. ret_status = vector_future[i].get();
  451. if (ret_status != SUCCESS) {
  452. GELOGE(ret_status, "[Check][Param] trans %zu vardata failed", i);
  453. return ret_status;
  454. }
  455. }
  456. return SUCCESS;
  457. }
  458. Status TransVarDataUtils::CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id) {
  459. GELOGD("CopyVarData start: session_id:%lu.", session_id);
  460. if (compute_graph == nullptr) {
  461. REPORT_INNER_ERROR("E19999", "Param compute_graph is nullptr, session_id:%lu, device_id:%u, check invalid",
  462. session_id, device_id);
  463. GELOGE(FAILED, "[Check][Param] compute_graph is nullptr, session_id:%lu, device_id:%u",
  464. session_id, device_id);
  465. return FAILED;
  466. }
  467. string cp_from_node;
  468. bool copy_value = false;
  469. for (auto &node : compute_graph->GetAllNodes()) {
  470. GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() != VARIABLE, continue);
  471. GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), "_copy_from_var_node", cp_from_node),
  472. GELOGI("Get original type of cp_from_node"));
  473. if (cp_from_node.length() != 0) {
  474. (void) ge::AttrUtils::GetBool(node->GetOpDesc(), "_copy_value", copy_value); // no need to check value
  475. if (!copy_value) {
  476. auto src_node = compute_graph->FindNode(cp_from_node);
  477. GE_CHECK_NOTNULL(src_node);
  478. GELOGI("current_var_node__: [%s] copy_from_var_node__: [%s].", node->GetName().c_str(),
  479. src_node->GetName().c_str());
  480. auto ret = CopyTensorFromSrcVarNode(src_node, node, session_id, device_id);
  481. GE_IF_BOOL_EXEC(ret != SUCCESS,
  482. GELOGE(FAILED, "[Copy][Tensor] failed, src_node:%s, node:%s, session_id:%lu, device_id:%u",
  483. src_node->GetName().c_str(), node->GetName().c_str(), session_id, device_id);
  484. return FAILED);
  485. // only copy once
  486. (void) ge::AttrUtils::SetBool(node->GetOpDesc(), "_copy_value", true); // no need to check value
  487. }
  488. }
  489. }
  490. return SUCCESS;
  491. }
  492. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：