You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

integer_subbyte_utils.cuh 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. /**
  2. * \file dnn/src/cuda/integer_subbyte_utils.cuh
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #if MEGDNN_CC_CUDA
  13. #pragma once
  14. #include "src/cuda/utils.cuh"
  15. namespace megdnn {
  16. namespace cuda {
  17. namespace integer_subbyte {
  18. template <bool signedness>
  19. struct integer_trait;
  20. template <>
  21. struct integer_trait<true> {
  22. using type = int;
  23. };
  24. template <>
  25. struct integer_trait<false> {
  26. using type = unsigned;
  27. };
  28. MEGDNN_DEVICE __forceinline__ static int transform_int8_to_int4x8(
  29. int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
  30. unsigned out;
  31. #if __CUDA_ARCH__ >= 750 && \
  32. ((__CUDACC_VER_MAJOR__ > 10) || \
  33. ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
  34. asm volatile(
  35. "{ .reg .u32 r4;"
  36. "cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;"
  37. "cvt.pack.sat.s4.s32.b32 r4, %6, %5, r4;"
  38. "cvt.pack.sat.s4.s32.b32 r4, %4, %3, r4;"
  39. "cvt.pack.sat.s4.s32.b32 %0, %2, %1, r4;"
  40. "}"
  41. : "=r"(out)
  42. : "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6),
  43. "r"(s7));
  44. #else
  45. #define CVT_SAT_S4_S32(r, bits) \
  46. r = r <= -8 ? -8 : r; \
  47. r = r > 7 ? 7 : r; \
  48. r = (((unsigned)r & 0xf) << bits);
  49. CVT_SAT_S4_S32(s0, 0)
  50. CVT_SAT_S4_S32(s1, 4)
  51. CVT_SAT_S4_S32(s2, 8)
  52. CVT_SAT_S4_S32(s3, 12)
  53. CVT_SAT_S4_S32(s4, 16)
  54. CVT_SAT_S4_S32(s5, 20)
  55. CVT_SAT_S4_S32(s6, 24)
  56. CVT_SAT_S4_S32(s7, 28)
  57. out = s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7;
  58. #undef CVT_SAT_S4_S32
  59. #endif
  60. return reinterpret_cast<int const&>(out);
  61. }
  62. MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8(
  63. int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
  64. unsigned out;
  65. #if __CUDA_ARCH__ >= 750 && \
  66. ((__CUDACC_VER_MAJOR__ > 10) || \
  67. ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
  68. asm volatile(
  69. "{ .reg .u32 r4;"
  70. "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
  71. "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;"
  72. "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;"
  73. "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;"
  74. "}"
  75. : "=r"(out)
  76. : "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6),
  77. "r"(s7));
  78. #else
  79. #define CVT_SAT_U4_S32(r, bits) \
  80. r = r <= 0 ? 0 : r; \
  81. r = r > 15 ? 15 : r; \
  82. r = (((unsigned)r & 0xf) << bits);
  83. CVT_SAT_U4_S32(s0, 0)
  84. CVT_SAT_U4_S32(s1, 4)
  85. CVT_SAT_U4_S32(s2, 8)
  86. CVT_SAT_U4_S32(s3, 12)
  87. CVT_SAT_U4_S32(s4, 16)
  88. CVT_SAT_U4_S32(s5, 20)
  89. CVT_SAT_U4_S32(s6, 24)
  90. CVT_SAT_U4_S32(s7, 28)
  91. out = s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7;
  92. #undef CVT_SAT_U4_S32
  93. #endif
  94. return reinterpret_cast<int const&>(out);
  95. }
  96. template <bool signedness, typename T>
  97. MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage,
  98. int bits) {
  99. //! size in bits of 32 bit integer - 4 bits
  100. static constexpr int shift = 28;
  101. using type = typename integer_trait<signedness>::type;
  102. unsigned intermediate = static_cast<unsigned>(storage);
  103. type result = reinterpret_cast<type&>(intermediate);
  104. return (result << (shift - bits)) >> shift;
  105. }
  106. MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8(
  107. int (&result)[8], const int& source) {
  108. #pragma unroll
  109. for (int i = 0; i < 8; i++) {
  110. result[i] = unpack_integer_4bits<true>(
  111. reinterpret_cast<unsigned const&>(source), (i << 2));
  112. }
  113. }
  114. MEGDNN_DEVICE __forceinline__ static void transform_uint4x8_to_int8(
  115. int (&result)[8], const int& source) {
  116. #pragma unroll
  117. for (int i = 0; i < 8; i++) {
  118. result[i] = unpack_integer_4bits<false>(
  119. reinterpret_cast<unsigned const&>(source), (i << 2));
  120. }
  121. }
  122. MEGDNN_DEVICE __forceinline__ static void transform_int4x2_to_int8(
  123. int (&result)[2], const uint8_t& source) {
  124. result[0] = unpack_integer_4bits<true>(source, 0);
  125. result[1] = unpack_integer_4bits<true>(source, 4);
  126. }
  127. MEGDNN_DEVICE __forceinline__ static void transform_uint4x2_to_int8(
  128. int (&result)[2], const uint8_t& source) {
  129. result[0] = unpack_integer_4bits<false>(source, 0);
  130. result[1] = unpack_integer_4bits<false>(source, 4);
  131. }
  132. } // namespace integer_subbyte
  133. } // namespace cuda
  134. } // namespace megdnn
  135. #endif
  136. // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。