You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

rsqrt_kernel.cc 5.3 kB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "host_kernels/rsqrt_kernel.h"
  17. #include <cfloat>
  18. #include <memory>
  19. #include "common/debug/ge_log.h"
  20. #include "common/debug/log.h"
  21. #include "common/ge_inner_error_codes.h"
  22. #include "common/op/ge_op_utils.h"
  23. #include "framework/common/debug/ge_log.h"
  24. #include "host_kernels/kernel_utils.h"
  25. #include "inc/kernel_factory.h"
  26. #include "common/math/math_util.h"
  27. #include "framework/common/types.h"
  28. namespace ge {
  29. namespace {
  30. const size_t kRsqrtInputSize = 1;
  31. const size_t kRsqrtInputIndex0 = 0;
  32. template <typename T>
  33. Status ZeroCheck(T x, const DataType &data_type) {
  34. switch (data_type) {
  35. case DT_FLOAT16:
  36. FMK_FP16_ZEROCHECK(static_cast<double>(x))
  37. break;
  38. case DT_FLOAT:
  39. FMK_FLOAT_ZEROCHECK(static_cast<float>(x))
  40. break;
  41. case DT_DOUBLE:
  42. FMK_DOUBLE_ZEROCHECK(static_cast<double>(x))
  43. break;
  44. default:
  45. break;
  46. }
  47. return SUCCESS;
  48. }
  49. #define SET_RSQRT_CASE(DTYPE, TYPE) \
  50. case (DTYPE): \
  51. ret = RsqrtKernel::RsqrtCompute<TYPE>(input_ptr, output_ptr); \
  52. break;
  53. } // namespace
  54. template<typename T>
  55. Status RsqrtKernel::RsqrtCompute(ConstGeTensorPtr &input_tensor_ptr, GeTensorPtr &output_tensor_ptr) {
  56. GE_CHECK_NOTNULL(input_tensor_ptr);
  57. GE_CHECK_NOTNULL(output_tensor_ptr);
  58. size_t data_size = input_tensor_ptr->GetData().size();
  59. size_t data_count = data_size / sizeof(T);
  60. auto data_type = input_tensor_ptr->GetTensorDesc().GetDataType();
  61. if (data_count > 0) {
  62. unique_ptr<T[]> buf(new(std::nothrow) T[data_count]());
  63. if (buf == nullptr) {
  64. GELOGW("New buf failed");
  65. return NOT_CHANGED;
  66. }
  67. auto ptr = const_cast<T * >(reinterpret_cast<const T *>(input_tensor_ptr->GetData().data()));
  68. for (size_t i = 0; i < data_count; i++) {
  69. if (ZeroCheck(*(ptr + i), data_type) != SUCCESS) {
  70. GELOGW("Rsqrt: The input data can not less than or equal to zero, rsqrt folding failed.");
  71. return NOT_CHANGED;
  72. }
  73. switch (data_type) {
  74. case DT_FLOAT16: {
  75. double val = static_cast<double>(*(reinterpret_cast<const fp16_t*>(input_tensor_ptr->GetData().data()) + i));
  76. double drSqrt = 1.0 / std::sqrt(val);
  77. buf[i] = drSqrt;
  78. break;
  79. }
  80. case DT_FLOAT:{
  81. float denominator = std::sqrt(*(reinterpret_cast<const float*>(input_tensor_ptr->GetData().data()) + i));
  82. buf[i] = static_cast<float >(1 / denominator);
  83. break;
  84. }
  85. case DT_DOUBLE: {
  86. double denominator = std::sqrt(*(reinterpret_cast<const double*>(input_tensor_ptr->GetData().data()) + i));
  87. buf[i] = static_cast<double>(1 / denominator);
  88. break;
  89. }
  90. default:
  91. GELOGW("Input data type must be FP16, FP32 and DOUBLE.");
  92. return NOT_CHANGED;
  93. }
  94. }
  95. GE_IF_BOOL_EXEC(output_tensor_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), data_size) != GRAPH_SUCCESS,
  96. GELOGW("Set data failed"); return NOT_CHANGED);
  97. output_tensor_ptr->MutableTensorDesc().SetDataType(data_type);
  98. output_tensor_ptr->MutableTensorDesc().SetShape(input_tensor_ptr->GetTensorDesc().GetShape());
  99. }
  100. return SUCCESS;
  101. }
  102. Status RsqrtKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstGeTensorPtr> &input,
  103. std::vector<GeTensorPtr> &v_output) {
  104. GELOGI("RsqrtKernel in.");
  105. GE_CHECK_NOTNULL(op_desc_ptr);
  106. // check input size
  107. if (input.size() != kRsqrtInputSize) {
  108. GELOGW("The number of input for rsqrt must be %zu.", kRsqrtInputSize);
  109. return NOT_CHANGED;
  110. }
  111. ConstGeTensorPtr input_ptr = input.at(kRsqrtInputIndex0);
  112. GE_CHECK_NOTNULL(input_ptr);
  113. // Index 0 can always gets a GeTensorDesc object from any OpDescPtr.
  114. auto output_tensor_desc = op_desc_ptr->GetOutputDesc(0);
  115. GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc);
  116. if (output_ptr == nullptr) {
  117. GELOGW("MakeShared GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str());
  118. return NOT_CHANGED;
  119. }
  120. Status ret = NOT_CHANGED;
  121. auto dtype = input_ptr->GetTensorDesc().GetDataType();
  122. switch (dtype) {
  123. SET_RSQRT_CASE(DT_FLOAT16, fp16_t)
  124. SET_RSQRT_CASE(DT_FLOAT, float)
  125. SET_RSQRT_CASE(DT_DOUBLE, double)
  126. default:
  127. GELOGW("Input data type must be FP16, FP32 and DOUBLE.");
  128. return NOT_CHANGED;
  129. }
  130. if (ret != SUCCESS) {
  131. GELOGW("Rsqrt folding failed.");
  132. return NOT_CHANGED;
  133. }
  134. v_output.push_back(output_ptr);
  135. GELOGI("RsqrtKernel success.");
  136. return SUCCESS;
  137. }
  138. REGISTER_KERNEL(RSQRT, RsqrtKernel);
  139. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示