You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

opr_impl.h 3.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. /**
  2. * \file dnn/src/aarch64/matrix_mul/opr_impl.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "src/arm_common/matrix_mul/opr_impl.h"
  14. namespace megdnn {
  15. namespace aarch64 {
  16. class MatrixMulImpl : public arm_common::MatrixMulImpl { // AArch64 matrix-multiply operator; derives from the arm_common implementation
  17. public:
  18. using arm_common::MatrixMulImpl::MatrixMulImpl; // inherit the base-class constructors unchanged
  19. class AlgoBase : public arm_common::MatrixMulImpl::AlgoBase { // base type for the AArch64 algorithms declared below
  20. public:
  21. AlgoBase() : arm_common::MatrixMulImpl::AlgoBase() {
  22. m_handle_type = Handle::HandleType::AARCH64; // tag every algorithm built on this base with the AARCH64 handle type
  23. }
  24. };
  25. SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo()
  26. override; // overrides the base-class hook; returns a list of fallback::MatrixMulImpl::AlgoBase pointers (defined out of line)
  27. MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); // NOTE(review): macro is defined elsewhere; presumably declares the algorithm-from-descriptor lookup — confirm
  28. private:
  29. class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8x12x1
  30. class AlgoF32MK4_8x12x1; // Aarch64 F32 Kernel MK4 8x12x1
  31. class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1
  32. class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 4x16 (comment formerly read "16x4"; aligned with the class name — TODO confirm)
  33. class AlgoF32Gemv; // Aarch64 F32 Gemv
  34. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  35. class AlgoF16K8x24x1; // Aarch64 F16 Kernel 8x24x1
  36. class AlgoF16MK8_8x8; // Aarch64 F16 Format MK8 block 8x8 (comment formerly read "16x8"; aligned with the class name — TODO confirm)
  37. #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  38. #if __ARM_FEATURE_DOTPROD
  39. class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel
  40. // 8x12x4 DotProduct
  41. class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel
  42. // 8x12x4 DotProduct
  43. #else // !__ARM_FEATURE_DOTPROD: these Int8x8x32 kernels are declared only when the dot-product variants are not
  44. class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16
  45. class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16
  46. class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8
  47. #endif // __ARM_FEATURE_DOTPROD
  48. class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8
  49. class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16
  50. class AlgoInt8x8x16MK4_16x12x4; // Aarch64 Int8x8x16 MK4 Kernel 16x12x4 (comment formerly read "16x12x16"; aligned with the class name — TODO confirm)
  51. class AlgoInt8x8x16MK4_4x4x8; // Aarch64 Int8x8x16 MK4 Kernel 4x4x8
  52. class AlgoInt16x16x32K12x8x1; // Aarch64 Int16x16x32 Kernel 12x8x1
  53. class AlgoInt16x16x32MK8_8x8; // Aarch64 Int16x16x32 Format MK8 block 8x8
  54. #if __ARM_FEATURE_DOTPROD
  55. class AlgoQuint8K8x8x4DotProd; // Aarch64 Quint8 Kernel
  56. // 8x8x4 DotProduct
  57. class AlgoQuint8GemvDotProd; // Aarch64 Quint8 Gemv DotProduct
  58. #else // !__ARM_FEATURE_DOTPROD
  59. class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8
  60. #endif // __ARM_FEATURE_DOTPROD
  61. class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int8x8x16 MK4 Kernel 8x8x8 (comment formerly read "Int4x4x16 Kernel 4x4x16"; aligned with the class name — TODO confirm)
  62. class AlgoPack; // forward declaration; the type returned by algo_pack() below
  63. public:
  64. static const AlgoPack& algo_pack(); // static accessor returning a const reference to an AlgoPack (defined out of line)
  65. };
  66. } // namespace aarch64
  67. } // namespace megdnn
  68. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台