You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

operator_factory_impl.cc 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/operator_factory_impl.h"
  17. #include "debug/ge_log.h"
  18. #include "framework/common/debug/ge_log.h"
  19. namespace ge {
  20. shared_ptr<std::map<string, OpCreator>> OperatorFactoryImpl::operator_creators_;
  21. shared_ptr<std::map<string, OpCreatorV2>> OperatorFactoryImpl::operator_creators_v2_;
  22. shared_ptr<std::map<string, InferShapeFunc>> OperatorFactoryImpl::operator_infershape_funcs_;
  23. shared_ptr<std::map<string, InferFormatFunc>> OperatorFactoryImpl::operator_inferformat_funcs_;
  24. shared_ptr<std::map<string, VerifyFunc>> OperatorFactoryImpl::operator_verify_funcs_;
  25. shared_ptr<std::map<string, InferDataSliceFunc>> OperatorFactoryImpl::operator_infer_data_slice_funcs_;
  26. Operator OperatorFactoryImpl::CreateOperator(const std::string &operator_name, const std::string &operator_type) {
  27. if (operator_creators_v2_ != nullptr) {
  28. auto it_v2 = operator_creators_v2_->find(operator_type);
  29. if (it_v2 != operator_creators_v2_->end()) {
  30. return it_v2->second(operator_name.c_str());
  31. } else {
  32. GELOGW("No OpProto of [%s] registered by AscendString.", operator_type.c_str());
  33. }
  34. }
  35. if (operator_creators_ == nullptr) {
  36. return Operator();
  37. }
  38. auto it = operator_creators_->find(operator_type);
  39. if (it == operator_creators_->end()) {
  40. GELOGW("no OpProto of [%s] registered by string.", operator_type.c_str());
  41. return Operator();
  42. }
  43. return it->second(operator_name);
  44. }
  45. graphStatus OperatorFactoryImpl::GetOpsTypeList(std::vector<std::string> &all_ops) {
  46. all_ops.clear();
  47. if (operator_creators_v2_ != nullptr) {
  48. for (auto it_v2 = operator_creators_v2_->begin(); it_v2 != operator_creators_v2_->end(); ++it_v2) {
  49. all_ops.emplace_back(it_v2->first);
  50. }
  51. return GRAPH_SUCCESS;
  52. } else {
  53. GELOGW("Ops not registered by AscendString.");
  54. }
  55. if (operator_creators_ != nullptr) {
  56. for (auto it = operator_creators_->begin(); it != operator_creators_->end(); ++it) {
  57. all_ops.emplace_back(it->first);
  58. }
  59. } else {
  60. GELOGE(GRAPH_FAILED, "no operator creators found");
  61. return GRAPH_FAILED;
  62. }
  63. return GRAPH_SUCCESS;
  64. }
  65. bool OperatorFactoryImpl::IsExistOp(const string &operator_type) {
  66. if (operator_creators_v2_ != nullptr) {
  67. auto it_v2 = operator_creators_v2_->find(operator_type);
  68. if (it_v2 != operator_creators_v2_->end()) {
  69. return true;
  70. }
  71. }
  72. if (operator_creators_ == nullptr) {
  73. return false;
  74. }
  75. auto it = operator_creators_->find(operator_type);
  76. if (it == operator_creators_->end()) {
  77. return false;
  78. }
  79. return true;
  80. }
  81. InferShapeFunc OperatorFactoryImpl::GetInferShapeFunc(const std::string &operator_type) {
  82. if (operator_infershape_funcs_ == nullptr) {
  83. return nullptr;
  84. }
  85. auto it = operator_infershape_funcs_->find(operator_type);
  86. if (it == operator_infershape_funcs_->end()) {
  87. return nullptr;
  88. }
  89. return it->second;
  90. }
  91. InferFormatFunc OperatorFactoryImpl::GetInferFormatFunc(const std::string &operator_type) {
  92. if (operator_inferformat_funcs_ == nullptr) {
  93. GELOGI("operator_inferformat_funcs_ is null");
  94. return nullptr;
  95. }
  96. auto it = operator_inferformat_funcs_->find(operator_type);
  97. if (it == operator_inferformat_funcs_->end()) {
  98. return nullptr;
  99. }
  100. return it->second;
  101. }
  102. VerifyFunc OperatorFactoryImpl::GetVerifyFunc(const std::string &operator_type) {
  103. if (operator_verify_funcs_ == nullptr) {
  104. return nullptr;
  105. }
  106. auto it = operator_verify_funcs_->find(operator_type);
  107. if (it == operator_verify_funcs_->end()) {
  108. return nullptr;
  109. }
  110. return it->second;
  111. }
  112. InferDataSliceFunc OperatorFactoryImpl::GetInferDataSliceFunc(const std::string &operator_type) {
  113. if (operator_infer_data_slice_funcs_ == nullptr) {
  114. return nullptr;
  115. }
  116. auto it = operator_infer_data_slice_funcs_->find(operator_type);
  117. if (it == operator_infer_data_slice_funcs_->end()) {
  118. return nullptr;
  119. }
  120. return it->second;
  121. }
  122. graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) {
  123. if (operator_creators_ == nullptr) {
  124. operator_creators_.reset(new (std::nothrow) std::map<string, OpCreator>());
  125. }
  126. auto it = operator_creators_->find(operator_type);
  127. if (it != operator_creators_->end()) {
  128. return GRAPH_FAILED;
  129. }
  130. (void)operator_creators_->emplace(operator_type, op_creator);
  131. return GRAPH_SUCCESS;
  132. }
  133. graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreatorV2 const &op_creator) {
  134. if (operator_creators_v2_ == nullptr) {
  135. operator_creators_v2_.reset(new (std::nothrow) std::map<string, OpCreatorV2>());
  136. }
  137. auto it = operator_creators_v2_->find(operator_type);
  138. if (it != operator_creators_v2_->end()) {
  139. return GRAPH_FAILED;
  140. }
  141. (void)operator_creators_v2_->emplace(operator_type, op_creator);
  142. return GRAPH_SUCCESS;
  143. }
  144. graphStatus OperatorFactoryImpl::RegisterInferShapeFunc(const std::string &operator_type,
  145. InferShapeFunc const infer_shape_func) {
  146. if (operator_infershape_funcs_ == nullptr) {
  147. GELOGI("operator_infershape_funcs_ init");
  148. operator_infershape_funcs_.reset(new (std::nothrow) std::map<string, InferShapeFunc>());
  149. }
  150. auto it = operator_infershape_funcs_->find(operator_type);
  151. if (it != operator_infershape_funcs_->end()) {
  152. return GRAPH_FAILED;
  153. }
  154. (void)operator_infershape_funcs_->emplace(operator_type, infer_shape_func);
  155. return GRAPH_SUCCESS;
  156. }
  157. graphStatus OperatorFactoryImpl::RegisterInferFormatFunc(const std::string &operator_type,
  158. InferFormatFunc const infer_format_func) {
  159. if (operator_inferformat_funcs_ == nullptr) {
  160. GELOGI("operator_inferformat_funcs_ init");
  161. operator_inferformat_funcs_.reset(new (std::nothrow) std::map<string, InferFormatFunc>());
  162. }
  163. auto it = operator_inferformat_funcs_->find(operator_type);
  164. if (it != operator_inferformat_funcs_->end()) {
  165. return GRAPH_FAILED;
  166. }
  167. (void)operator_inferformat_funcs_->emplace(operator_type, infer_format_func);
  168. return GRAPH_SUCCESS;
  169. }
  170. graphStatus OperatorFactoryImpl::RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func) {
  171. if (operator_verify_funcs_ == nullptr) {
  172. GELOGI("operator_verify_funcs_ init");
  173. operator_verify_funcs_.reset(new (std::nothrow) std::map<string, VerifyFunc>());
  174. }
  175. auto it = operator_verify_funcs_->find(operator_type);
  176. if (it != operator_verify_funcs_->end()) {
  177. return GRAPH_FAILED;
  178. }
  179. (void)operator_verify_funcs_->emplace(operator_type, verify_func);
  180. return GRAPH_SUCCESS;
  181. }
  182. graphStatus OperatorFactoryImpl::RegisterInferDataSliceFunc(const std::string &operator_type,
  183. InferDataSliceFunc const infer_data_slice_func) {
  184. if (operator_infer_data_slice_funcs_ == nullptr) {
  185. GELOGI("operator_infer_data_slice_funcs_ init");
  186. operator_infer_data_slice_funcs_.reset(new (std::nothrow) std::map<string, InferDataSliceFunc>());
  187. }
  188. auto it = operator_infer_data_slice_funcs_->find(operator_type);
  189. if (it != operator_infer_data_slice_funcs_->end()) {
  190. return GRAPH_FAILED;
  191. }
  192. (void)operator_infer_data_slice_funcs_->emplace(operator_type, infer_data_slice_func);
  193. return GRAPH_SUCCESS;
  194. }
  195. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：