You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

rsqrt_kernel.cc 5.2 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "host_kernels/rsqrt_kernel.h"

#include <cfloat>
#include <cmath>
#include <memory>

#include "common/debug/ge_log.h"
#include "common/debug/log.h"
#include "common/ge_inner_error_codes.h"
#include "common/math/math_util.h"
#include "common/op/ge_op_utils.h"
#include "framework/common/debug/ge_log.h"
#include "host_kernels/kernel_utils.h"
#include "inc/kernel_factory.h"
  27. namespace ge {
  28. namespace {
  29. const size_t kRsqrtInputSize = 1;
  30. const size_t kRsqrtInputIndex0 = 0;
  31. template <typename T>
  32. Status ZeroCheck(T x, const DataType &data_type) {
  33. switch (data_type) {
  34. case DT_FLOAT16:
  35. FMK_FP16_ZEROCHECK(static_cast<double>(x))
  36. break;
  37. case DT_FLOAT:
  38. FMK_FLOAT_ZEROCHECK(static_cast<float>(x))
  39. break;
  40. case DT_DOUBLE:
  41. FMK_DOUBLE_ZEROCHECK(static_cast<double>(x))
  42. break;
  43. default:
  44. break;
  45. }
  46. return SUCCESS;
  47. }
// Expands to one switch case that dispatches the typed compute routine.
// NOTE(review): the macro captures `ret`, `input_ptr` and `output_ptr` from
// the expansion site (RsqrtKernel::Compute); those names must stay in scope
// wherever this macro is used.
#define SET_RSQRT_CASE(DTYPE, TYPE) \
  case (DTYPE): \
    ret = RsqrtKernel::RsqrtCompute<TYPE>(input_ptr, output_ptr); \
    break;
}  // namespace
  53. template <typename T>
  54. Status RsqrtKernel::RsqrtCompute(ConstGeTensorPtr &input_tensor_ptr, GeTensorPtr &output_tensor_ptr) {
  55. GE_CHECK_NOTNULL(input_tensor_ptr);
  56. GE_CHECK_NOTNULL(output_tensor_ptr);
  57. size_t data_size = input_tensor_ptr->GetData().size();
  58. size_t data_count = data_size / sizeof(T);
  59. auto data_type = input_tensor_ptr->GetTensorDesc().GetDataType();
  60. if (data_count > 0) {
  61. unique_ptr<T[]> buf(new (std::nothrow) T[data_count]());
  62. if (buf == nullptr) {
  63. GELOGW("New buf failed");
  64. return NOT_CHANGED;
  65. }
  66. auto ptr = const_cast<T *>(reinterpret_cast<const T *>(input_tensor_ptr->GetData().data()));
  67. for (size_t i = 0; i < data_count; i++) {
  68. if (ZeroCheck(*(ptr + i), data_type) != SUCCESS) {
  69. GELOGW("Rsqrt: The input data can not less than or equal to zero, rsqrt folding failed.");
  70. return NOT_CHANGED;
  71. }
  72. switch (data_type) {
  73. case DT_FLOAT16: {
  74. double val = static_cast<double>(*(reinterpret_cast<const fp16_t *>(input_tensor_ptr->GetData().data()) + i));
  75. double drSqrt = 1.0 / std::sqrt(val);
  76. buf[i] = drSqrt;
  77. break;
  78. }
  79. case DT_FLOAT: {
  80. float denominator = std::sqrt(*(reinterpret_cast<const float *>(input_tensor_ptr->GetData().data()) + i));
  81. buf[i] = static_cast<float>(1 / denominator);
  82. break;
  83. }
  84. case DT_DOUBLE: {
  85. double denominator = std::sqrt(*(reinterpret_cast<const double *>(input_tensor_ptr->GetData().data()) + i));
  86. buf[i] = static_cast<double>(1 / denominator);
  87. break;
  88. }
  89. default:
  90. GELOGW("Input data type must be FP16, FP32 and DOUBLE.");
  91. return NOT_CHANGED;
  92. }
  93. }
  94. GE_IF_BOOL_EXEC(output_tensor_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), data_size) != GRAPH_SUCCESS,
  95. GELOGW("Set data failed");
  96. return NOT_CHANGED);
  97. output_tensor_ptr->MutableTensorDesc().SetDataType(data_type);
  98. output_tensor_ptr->MutableTensorDesc().SetShape(input_tensor_ptr->GetTensorDesc().GetShape());
  99. }
  100. return SUCCESS;
  101. }
  102. Status RsqrtKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstGeTensorPtr> &input,
  103. std::vector<GeTensorPtr> &v_output) {
  104. GELOGI("RsqrtKernel in.");
  105. GE_CHECK_NOTNULL(op_desc_ptr);
  106. // check input size
  107. if (input.size() != kRsqrtInputSize) {
  108. GELOGW("The number of input for rsqrt must be %zu.", kRsqrtInputSize);
  109. return NOT_CHANGED;
  110. }
  111. ConstGeTensorPtr input_ptr = input.at(kRsqrtInputIndex0);
  112. GE_CHECK_NOTNULL(input_ptr);
  113. // Index 0 can always gets a GeTensorDesc object from any OpDescPtr.
  114. auto output_tensor_desc = op_desc_ptr->GetOutputDesc(0);
  115. GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc);
  116. if (output_ptr == nullptr) {
  117. GELOGW("MakeShared GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str());
  118. return NOT_CHANGED;
  119. }
  120. Status ret = NOT_CHANGED;
  121. auto dtype = input_ptr->GetTensorDesc().GetDataType();
  122. switch (dtype) {
  123. SET_RSQRT_CASE(DT_FLOAT16, fp16_t)
  124. SET_RSQRT_CASE(DT_FLOAT, float)
  125. SET_RSQRT_CASE(DT_DOUBLE, double)
  126. default:
  127. GELOGW("Input data type must be FP16, FP32 and DOUBLE.");
  128. return NOT_CHANGED;
  129. }
  130. if (ret != SUCCESS) {
  131. GELOGW("Rsqrt folding failed.");
  132. return NOT_CHANGED;
  133. }
  134. v_output.push_back(output_ptr);
  135. GELOGI("RsqrtKernel success.");
  136. return SUCCESS;
  137. }
// Register this kernel with the kernel factory under the RSQRT op type so the
// constant-folding pass can instantiate it.
REGISTER_KERNEL(RSQRT, RsqrtKernel);
}  // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示