You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

transpose_utils.cuh 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. /**
  2. * \file dnn/src/cuda/transpose_utils.cuh
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #if MEGDNN_CC_CUDA
  13. #pragma once
  14. #include "src/cuda/utils.cuh"
  15. namespace megdnn {
  16. namespace cuda {
  /**
   * \brief Transpose a 4x4 byte matrix stored as four 32-bit words.
   *
   * Each of src0..src3 is viewed as a row of four bytes, byte 0 being the least
   * significant one (the numbering used by __byte_perm). On return, dstK packs
   * byte K of src0, src1, src2 and src3 into its bytes 0, 1, 2 and 3
   * respectively.
   *
   * \param src0,src1,src2,src3 input rows, passed by value
   * \param dst0,dst1,dst2,dst3 output words, one per input byte column
   */
  MEGDNN_DEVICE __forceinline__ void transpose_int8_4x4_impl(
          const int src0, const int src1, const int src2, const int src3,
          int& dst0, int& dst1, int& dst2, int& dst3) {
      // Columns 0 and 1: selector 0x5140 interleaves bytes 0/1 of a row pair
      // (a, b) into {a.b0, b.b0, a.b1, b.b1}.
      const int cols01_rows01 = __byte_perm(src0, src1, 0x5140);
      const int cols01_rows23 = __byte_perm(src2, src3, 0x5140);
      // 0x5410 keeps the low half of each operand (column 0), 0x7632 the high
      // half (column 1).
      dst0 = __byte_perm(cols01_rows01, cols01_rows23, 0x5410);
      dst1 = __byte_perm(cols01_rows01, cols01_rows23, 0x7632);

      // Columns 2 and 3: selector 0x7362 does the same interleave for bytes 2/3.
      const int cols23_rows01 = __byte_perm(src0, src1, 0x7362);
      const int cols23_rows23 = __byte_perm(src2, src3, 0x7362);
      dst2 = __byte_perm(cols23_rows01, cols23_rows23, 0x5410);
      dst3 = __byte_perm(cols23_rows01, cols23_rows23, 0x7632);
  }
  /**
   * \brief Primary template: transpose \p interleaved packed words into four
   * vector-typed outputs.
   *
   * Only the declaration appears here; no generic definition is visible in this
   * header, so a call has to resolve to an explicit specialization.
   *
   * \tparam interleaved number of 32-bit words read from \p src
   * \tparam vec_type    element type of the four-entry output array
   * \param src input words
   * \param dst reference to an array of exactly four outputs
   */
  template <uint32_t interleaved, class vec_type>
  MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4(
          const int src[interleaved],
          vec_type (&dst)[4]);
  /**
   * \brief Specialization for four input words and plain int outputs.
   *
   * Forwards src[0..3] and dst[0..3] to a single call of
   * transpose_int8_4x4_impl.
   */
  template <>
  MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<4, int>(
          const int src[4], int (&dst)[4]) {
      int& out0 = dst[0];
      int& out1 = dst[1];
      int& out2 = dst[2];
      int& out3 = dst[3];
      transpose_int8_4x4_impl(src[0], src[1], src[2], src[3],
                              out0, out1, out2, out3);
  }
  /**
   * \brief Specialization for eight input words and int2 outputs.
   *
   * The input is handled as two groups of four words: src[0..3] is transposed
   * into the .x lanes of dst[0..3], and src[4..7] into the .y lanes.
   */
  template <>
  MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<8, int2>(
          const int src[8], int2 (&dst)[4]) {
      const int* group_x = src;      // words 0..3 -> .x lanes
      const int* group_y = src + 4;  // words 4..7 -> .y lanes

      transpose_int8_4x4_impl(group_x[0], group_x[1], group_x[2], group_x[3],
                              dst[0].x, dst[1].x, dst[2].x, dst[3].x);
      transpose_int8_4x4_impl(group_y[0], group_y[1], group_y[2], group_y[3],
                              dst[0].y, dst[1].y, dst[2].y, dst[3].y);
  }
  /**
   * \brief Specialization for sixteen input words and int4 outputs.
   *
   * The input is handled as four groups of four words. Group g (words
   * 4*g .. 4*g+3) is transposed into one lane of dst[0..3]: group 0 into .x,
   * group 1 into .y, group 2 into .z and group 3 into .w.
   */
  template <>
  MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<16, int4>(
          const int src[16], int4 (&dst)[4]) {
      const int* group_x = src;       // words 0..3   -> .x lanes
      const int* group_y = src + 4;   // words 4..7   -> .y lanes
      const int* group_z = src + 8;   // words 8..11  -> .z lanes
      const int* group_w = src + 12;  // words 12..15 -> .w lanes

      transpose_int8_4x4_impl(group_x[0], group_x[1], group_x[2], group_x[3],
                              dst[0].x, dst[1].x, dst[2].x, dst[3].x);
      transpose_int8_4x4_impl(group_y[0], group_y[1], group_y[2], group_y[3],
                              dst[0].y, dst[1].y, dst[2].y, dst[3].y);
      transpose_int8_4x4_impl(group_z[0], group_z[1], group_z[2], group_z[3],
                              dst[0].z, dst[1].z, dst[2].z, dst[3].z);
      transpose_int8_4x4_impl(group_w[0], group_w[1], group_w[2], group_w[3],
                              dst[0].w, dst[1].w, dst[2].w, dst[3].w);
  }
  58. } // namespace cuda
  59. } // namespace megdnn
  60. #endif
  61. // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台