You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.h 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. /**
  2. * \file src/gopt/impl/utils.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megbrain/gopt/global_layout_transform.h"
  14. namespace mgb {
  15. namespace gopt {
  16. static inline const char* opr_format_to_string(
  17. OprTensorFormatsConfiguration::OprFormat opr_format) {
  18. using OprFormat = OprTensorFormatsConfiguration::OprFormat;
  19. #define cb(_fmt) \
  20. case OprFormat::_fmt: \
  21. return #_fmt
  22. switch (opr_format) {
  23. cb(NCHW);
  24. cb(NHWC);
  25. cb(NCHW4);
  26. cb(NCHW32);
  27. cb(NCHW64);
  28. cb(CHWN4);
  29. default:
  30. mgb_assert(false, "Invalid opr format(got:%u)",
  31. static_cast<uint32_t>(opr_format));
  32. }
  33. #undef cb
  34. }
  35. static inline TensorFormats opr_format_to_tensor_formats(
  36. OprTensorFormatsConfiguration::OprFormat opr_format) {
  37. using OprFormat = OprTensorFormatsConfiguration::OprFormat;
  38. switch (opr_format) {
  39. case OprFormat::NCHW:
  40. return TensorFormats::NCHW;
  41. case OprFormat::NHWC:
  42. return TensorFormats::NHWC;
  43. case OprFormat::NCHW4:
  44. return TensorFormats::NCHWc4;
  45. case OprFormat::NCHW32:
  46. return TensorFormats::NCHWc32;
  47. case OprFormat::NCHW64:
  48. return TensorFormats::NCHWc64;
  49. case OprFormat::CHWN4:
  50. return TensorFormats::CHWNc4;
  51. default:
  52. mgb_throw(AssertionError, "format(%s) is not supported",
  53. opr_format_to_string(opr_format));
  54. };
  55. }
  56. static inline megdnn::NamedTensorShape tensor_formats_to_named_tensor_shape(
  57. TensorFormats format) {
  58. switch (format) {
  59. case TensorFormats::NCHW:
  60. return {{"N"}, {"C"}, {"H"}, {"W"}};
  61. case TensorFormats::NHWC:
  62. return {{"N"}, {"H"}, {"W"}, {"C"}};
  63. case TensorFormats::NCHWc4:
  64. return {{"N"}, {"C//4"}, {"H"}, {"W"}, {"C%4"}};
  65. case TensorFormats::NCHWc8:
  66. return {{"N"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}};
  67. case TensorFormats::NCHWc32:
  68. return {{"N"}, {"C//32"}, {"H"}, {"W"}, {"C%32"}};
  69. case TensorFormats::NCHWc64:
  70. return {{"N"}, {"C//64"}, {"H"}, {"W"}, {"C%64"}};
  71. case TensorFormats::CHWNc4:
  72. return {{"C//4"}, {"H"}, {"W"}, {"N"}, {"C%4"}};
  73. case TensorFormats::NHCWc4:
  74. return {{"N"}, {"H"}, {"C//4"}, {"W"}, {"C%4"}};
  75. case TensorFormats::KRSCk4:
  76. return {{"K//4"}, {"R"}, {"S"}, {"C"}, {"K%4"}};
  77. case TensorFormats::GKRSCk4:
  78. return {{"G"}, {"K//4"}, {"R"}, {"S"}, {"C"}, {"K%4"}};
  79. case TensorFormats::C1RSc4:
  80. return {{"C//4"}, {"C%1"}, {"R"}, {"S"}, {"C%4"}};
  81. case TensorFormats::KRSCk4c4:
  82. return {{"K//4"}, {"R"}, {"S"}, {"C//4"}, {"K%4"}, {"C%4"}};
  83. case TensorFormats::GKRSCk4c4:
  84. return {{"G"}, {"K//4"}, {"R"}, {"S"}, {"C//4"}, {"K%4"}, {"C%4"}};
  85. case TensorFormats::KCRSk4c4:
  86. return {{"K//4"}, {"C//4"}, {"R"}, {"S"}, {"K%4"}, {"C%4"}};
  87. case TensorFormats::GKCRSk4c4:
  88. return {{"G"}, {"K//4"}, {"C//4"}, {"R"}, {"S"}, {"K%4"}, {"C%4"}};
  89. case TensorFormats::KCRSc4k4:
  90. return {{"K//4"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}, {"K%4"}};
  91. case TensorFormats::GKCRSc4k4:
  92. return {{"G"}, {"K//4"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}, {"K%4"}};
  93. case TensorFormats::C11RSc4:
  94. return {{"C//4"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}, {"C%4"}};
  95. case TensorFormats::KCRSc8k8:
  96. return {{"K//8"}, {"C//8"}, {"R"}, {"S"}, {"C%8"}, {"K%8"}};
  97. case TensorFormats::GKCRSc8k8:
  98. return {{"G"}, {"K//8"}, {"C//8"}, {"R"}, {"S"}, {"C%8"}, {"K%8"}};
  99. case TensorFormats::C11RSc8:
  100. return {{"C//8"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}, {"C%8"}};
  101. case TensorFormats::KRSCk8:
  102. return {{"K//8"}, {"R"}, {"S"}, {"C"}, {"K%8"}};
  103. case TensorFormats::KCRSc4:
  104. return {{"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}};
  105. case TensorFormats::GKCRSc4:
  106. return {{"G"}, {"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}};
  107. case TensorFormats::KCRS:
  108. return {{"K"}, {"C"}, {"R"}, {"S"}};
  109. case TensorFormats::GKCRS:
  110. return {{"G"}, {"K"}, {"C"}, {"R"}, {"S"}};
  111. case TensorFormats::C11RS:
  112. return {{"C"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}};
  113. default:
  114. mgb_throw(AssertionError, "invalid tensor formats(%u)",
  115. static_cast<uint32_t>(format));
  116. }
  117. }
  118. } // namespace gopt
  119. } // namespace mgb
  120. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台