You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

operator_factory_impl.cc 5.5 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/operator_factory_impl.h"
  17. #include "debug/ge_log.h"
  18. #include "framework/common/debug/ge_log.h"
  19. namespace ge {
  20. shared_ptr<std::map<string, OpCreator>> OperatorFactoryImpl::operator_creators_;
  21. shared_ptr<std::map<string, InferShapeFunc>> OperatorFactoryImpl::operator_infershape_funcs_;
  22. shared_ptr<std::map<string, InferFormatFunc>> OperatorFactoryImpl::operator_inferformat_funcs_;
  23. shared_ptr<std::map<string, VerifyFunc>> OperatorFactoryImpl::operator_verify_funcs_;
  24. Operator OperatorFactoryImpl::CreateOperator(const std::string &operator_name, const std::string &operator_type) {
  25. if (operator_creators_ == nullptr) {
  26. return Operator();
  27. }
  28. auto it = operator_creators_->find(operator_type);
  29. if (it == operator_creators_->end()) {
  30. GELOGW("no OpProto of [%s] registered", operator_type.c_str());
  31. return Operator();
  32. }
  33. return it->second(operator_name);
  34. }
  35. graphStatus OperatorFactoryImpl::GetOpsTypeList(std::vector<std::string> &all_ops) {
  36. all_ops.clear();
  37. if (operator_creators_ != nullptr) {
  38. for (auto it = operator_creators_->begin(); it != operator_creators_->end(); ++it) {
  39. all_ops.emplace_back(it->first);
  40. }
  41. } else {
  42. GELOGE(GRAPH_FAILED, "no operator creators found");
  43. return GRAPH_FAILED;
  44. }
  45. return GRAPH_SUCCESS;
  46. }
  47. bool OperatorFactoryImpl::IsExistOp(const string &operator_type) {
  48. if (operator_creators_ == nullptr) {
  49. return false;
  50. }
  51. auto it = operator_creators_->find(operator_type);
  52. if (it == operator_creators_->end()) {
  53. return false;
  54. }
  55. return true;
  56. }
  57. InferShapeFunc OperatorFactoryImpl::GetInferShapeFunc(const std::string &operator_type) {
  58. if (operator_infershape_funcs_ == nullptr) {
  59. return nullptr;
  60. }
  61. auto it = operator_infershape_funcs_->find(operator_type);
  62. if (it == operator_infershape_funcs_->end()) {
  63. return nullptr;
  64. }
  65. return it->second;
  66. }
  67. InferFormatFunc OperatorFactoryImpl::GetInferFormatFunc(const std::string &operator_type) {
  68. if (operator_inferformat_funcs_ == nullptr) {
  69. GELOGI("operator_inferformat_funcs_ is null");
  70. return nullptr;
  71. }
  72. auto it = operator_inferformat_funcs_->find(operator_type);
  73. if (it == operator_inferformat_funcs_->end()) {
  74. return nullptr;
  75. }
  76. return it->second;
  77. }
  78. VerifyFunc OperatorFactoryImpl::GetVerifyFunc(const std::string &operator_type) {
  79. if (operator_verify_funcs_ == nullptr) {
  80. return nullptr;
  81. }
  82. auto it = operator_verify_funcs_->find(operator_type);
  83. if (it == operator_verify_funcs_->end()) {
  84. return nullptr;
  85. }
  86. return it->second;
  87. }
  88. graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) {
  89. if (operator_creators_ == nullptr) {
  90. GELOGI("operator_creators_ init");
  91. operator_creators_.reset(new (std::nothrow) std::map<string, OpCreator>());
  92. }
  93. auto it = operator_creators_->find(operator_type);
  94. if (it != operator_creators_->end()) {
  95. return GRAPH_FAILED;
  96. }
  97. (void)operator_creators_->emplace(operator_type, op_creator);
  98. return GRAPH_SUCCESS;
  99. }
  100. graphStatus OperatorFactoryImpl::RegisterInferShapeFunc(const std::string &operator_type,
  101. InferShapeFunc const infer_shape_func) {
  102. if (operator_infershape_funcs_ == nullptr) {
  103. GELOGI("operator_infershape_funcs_ init");
  104. operator_infershape_funcs_.reset(new (std::nothrow) std::map<string, InferShapeFunc>());
  105. }
  106. auto it = operator_infershape_funcs_->find(operator_type);
  107. if (it != operator_infershape_funcs_->end()) {
  108. return GRAPH_FAILED;
  109. }
  110. (void)operator_infershape_funcs_->emplace(operator_type, infer_shape_func);
  111. return GRAPH_SUCCESS;
  112. }
  113. graphStatus OperatorFactoryImpl::RegisterInferFormatFunc(const std::string &operator_type,
  114. InferFormatFunc const infer_format_func) {
  115. if (operator_inferformat_funcs_ == nullptr) {
  116. GELOGI("operator_inferformat_funcs_ init");
  117. operator_inferformat_funcs_.reset(new (std::nothrow) std::map<string, InferFormatFunc>());
  118. }
  119. auto it = operator_inferformat_funcs_->find(operator_type);
  120. if (it != operator_inferformat_funcs_->end()) {
  121. return GRAPH_FAILED;
  122. }
  123. (void)operator_inferformat_funcs_->emplace(operator_type, infer_format_func);
  124. return GRAPH_SUCCESS;
  125. }
  126. graphStatus OperatorFactoryImpl::RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func) {
  127. if (operator_verify_funcs_ == nullptr) {
  128. GELOGI("operator_verify_funcs_ init");
  129. operator_verify_funcs_.reset(new (std::nothrow) std::map<string, VerifyFunc>());
  130. }
  131. auto it = operator_verify_funcs_->find(operator_type);
  132. if (it != operator_verify_funcs_->end()) {
  133. return GRAPH_FAILED;
  134. }
  135. (void)operator_verify_funcs_->emplace(operator_type, verify_func);
  136. return GRAPH_SUCCESS;
  137. }
  138. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。