You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

operation_table.h 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. /***************************************************************************************************
  2. * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. *
  4. * Redistribution and use in source and binary forms, with or without
  5. *modification, are permitted provided that the following conditions are met:
  6. * * Redistributions of source code must retain the above copyright notice,
  7. *this list of conditions and the following disclaimer.
  8. * * Redistributions in binary form must reproduce the above copyright
  9. *notice, this list of conditions and the following disclaimer in the
  10. *documentation and/or other materials provided with the distribution.
  11. * * Neither the name of the NVIDIA CORPORATION nor the names of its
  12. *contributors may be used to endorse or promote products derived from this
  13. *software without specific prior written permission.
  14. *
  15. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  16. *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  17. *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
  19. *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
  20. * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  21. *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
  22. *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  23. *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
  24. *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. *
  26. **************************************************************************************************/
  27. /**
  28. * \file dnn/src/cuda/cutlass/operation_table.h
  29. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  30. *
  31. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  32. *
  33. * Unless required by applicable law or agreed to in writing,
  34. * software distributed under the License is distributed on an
  35. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  36. * implied.
  37. */
  38. #pragma once
  39. #include <unordered_map>
  40. #include "src/common/hash_ct.h"
  41. #include "src/cuda/cutlass/manifest.h"
  42. #include "src/cuda/cutlass/util.h"
  43. /////////////////////////////////////////////////////////////////////////////////////////////////
  44. namespace cutlass {
  45. namespace library {
  46. /////////////////////////////////////////////////////////////////////////////////////////////////
  47. class Hash {
  48. public:
  49. Hash() : m_val(0) {}
  50. Hash& update(const void* ptr, size_t len) {
  51. m_val += megdnn::XXHash64CT::hash((const char*)ptr, len, 123456);
  52. return *this;
  53. }
  54. uint64_t digest() const { return m_val; }
  55. private:
  56. uint64_t m_val;
  57. };
  58. /////////////////////////////////////////////////////////////////////////////////////////////////
  59. // Data Structures for GemmOperationMap
  60. /////////////////////////////////////////////////////////////////////////////////////////////////
  61. struct GemmKey {
  62. NumericTypeID element_A;
  63. LayoutTypeID layout_A;
  64. NumericTypeID element_B;
  65. LayoutTypeID layout_B;
  66. NumericTypeID element_C;
  67. LayoutTypeID layout_C;
  68. NumericTypeID element_accumulator;
  69. int threadblock_shape_m;
  70. int threadblock_shape_n;
  71. int threadblock_shape_k;
  72. int warp_shape_m;
  73. int warp_shape_n;
  74. int warp_shape_k;
  75. int instruction_shape_m;
  76. int instruction_shape_n;
  77. int instruction_shape_k;
  78. int stages;
  79. int alignment_A;
  80. int alignment_B;
  81. SplitKMode split_k_mode;
  82. inline bool operator==(GemmKey const& rhs) const {
  83. return (element_A == rhs.element_A) && (layout_A == rhs.layout_A) &&
  84. (element_B == rhs.element_B) && (layout_B == rhs.layout_B) &&
  85. (element_C == rhs.element_C) && (layout_C == rhs.layout_C) &&
  86. (element_accumulator == rhs.element_accumulator) &&
  87. (threadblock_shape_m == rhs.threadblock_shape_m) &&
  88. (threadblock_shape_n == rhs.threadblock_shape_n) &&
  89. (threadblock_shape_k == rhs.threadblock_shape_k) &&
  90. (warp_shape_m == rhs.warp_shape_m) &&
  91. (warp_shape_n == rhs.warp_shape_n) &&
  92. (warp_shape_k == rhs.warp_shape_k) &&
  93. (instruction_shape_m == rhs.instruction_shape_m) &&
  94. (instruction_shape_n == rhs.instruction_shape_n) &&
  95. (instruction_shape_k == rhs.instruction_shape_k) &&
  96. (stages == rhs.stages) && (alignment_A == rhs.alignment_A) &&
  97. (alignment_B == rhs.alignment_B) && (split_k_mode == rhs.split_k_mode);
  98. }
  99. inline bool operator!=(GemmKey const& rhs) const { return !(*this == rhs); }
  100. inline std::string str() const {
  101. auto tuple_to_str = [](int m, int n, int k) -> std::string {
  102. return std::to_string(m) + " x " + std::to_string(n) + " x " +
  103. std::to_string(k);
  104. };
  105. std::string threadblock_shape_str = tuple_to_str(
  106. threadblock_shape_m, threadblock_shape_n, threadblock_shape_k);
  107. std::string warp_shape_str =
  108. tuple_to_str(warp_shape_m, warp_shape_n, warp_shape_k);
  109. std::string instruction_shape_str = tuple_to_str(
  110. instruction_shape_m, instruction_shape_n, instruction_shape_k);
  111. return std::string("{") + "\n element_A: " + to_string(element_A) +
  112. "\n layout_A: " + to_string(layout_A) +
  113. "\n element_B: " + to_string(element_B) +
  114. "\n layout_B: " + to_string(layout_B) +
  115. "\n element_C: " + to_string(element_C) +
  116. "\n layout_C: " + to_string(layout_C) +
  117. "\n element_accumulator: " + to_string(element_accumulator) +
  118. "\n threadblock_shape: " + threadblock_shape_str +
  119. "\n warp_shape: " + warp_shape_str +
  120. "\n instruction_shape: " + instruction_shape_str +
  121. "\n stages: " + std::to_string(stages) +
  122. "\n alignment_A: " + std::to_string(alignment_A) +
  123. "\n alignment_B: " + std::to_string(alignment_B) +
  124. "\n split_k_mode: " + to_string(split_k_mode) + "\n}";
  125. }
  126. };
  127. struct GemmKeyHasher {
  128. inline size_t operator()(GemmKey const& key) const {
  129. return Hash()
  130. .update(&key.element_A, sizeof(key.element_A))
  131. .update(&key.layout_A, sizeof(key.layout_A))
  132. .update(&key.element_B, sizeof(key.element_B))
  133. .update(&key.layout_B, sizeof(key.layout_B))
  134. .update(&key.element_C, sizeof(key.element_C))
  135. .update(&key.layout_C, sizeof(key.layout_C))
  136. .update(&key.element_accumulator, sizeof(key.element_accumulator))
  137. .update(&key.threadblock_shape_m, sizeof(key.threadblock_shape_m))
  138. .update(&key.threadblock_shape_n, sizeof(key.threadblock_shape_n))
  139. .update(&key.threadblock_shape_k, sizeof(key.threadblock_shape_k))
  140. .update(&key.warp_shape_m, sizeof(key.warp_shape_m))
  141. .update(&key.warp_shape_n, sizeof(key.warp_shape_n))
  142. .update(&key.warp_shape_k, sizeof(key.warp_shape_k))
  143. .update(&key.stages, sizeof(key.stages))
  144. .update(&key.alignment_A, sizeof(key.alignment_A))
  145. .update(&key.alignment_B, sizeof(key.alignment_B))
  146. .update(&key.split_k_mode, sizeof(key.split_k_mode))
  147. .digest();
  148. }
  149. };
  150. using GemmOperationMap =
  151. std::unordered_map<GemmKey, std::vector<Operation const*>, GemmKeyHasher>;
  152. /////////////////////////////////////////////////////////////////////////////////////////////////
  153. // Data Structures for ConvolutionOperationMap
  154. /////////////////////////////////////////////////////////////////////////////////////////////////
  155. struct ConvolutionKey {
  156. conv::Operator conv_op;
  157. library::NumericTypeID element_src;
  158. library::LayoutTypeID layout_src;
  159. library::NumericTypeID element_filter;
  160. library::LayoutTypeID layout_filter;
  161. library::NumericTypeID element_dst;
  162. library::LayoutTypeID layout_dst;
  163. library::NumericTypeID element_bias;
  164. library::LayoutTypeID layout_bias;
  165. NumericTypeID element_accumulator;
  166. conv::ConvType convolution_type;
  167. int threadblock_shape_m;
  168. int threadblock_shape_n;
  169. int threadblock_shape_k;
  170. int warp_shape_m;
  171. int warp_shape_n;
  172. int warp_shape_k;
  173. int instruction_shape_m;
  174. int instruction_shape_n;
  175. int instruction_shape_k;
  176. epilogue::EpilogueType epilogue_type;
  177. int stages;
  178. conv::SpecialOptimizeDesc special_optimization;
  179. int alignment_src;
  180. int alignment_filter;
  181. bool without_shared_load;
  182. inline bool operator==(ConvolutionKey const& rhs) const {
  183. return (conv_op == rhs.conv_op) && (element_src == rhs.element_src) &&
  184. (layout_src == rhs.layout_src) &&
  185. (element_filter == rhs.element_filter) &&
  186. (layout_filter == rhs.layout_filter) &&
  187. (element_dst == rhs.element_dst) && (layout_dst == rhs.layout_dst) &&
  188. (element_bias == rhs.element_bias) && (layout_bias == rhs.layout_bias) &&
  189. (element_accumulator == rhs.element_accumulator) &&
  190. (convolution_type == rhs.convolution_type) &&
  191. (threadblock_shape_m == rhs.threadblock_shape_m) &&
  192. (threadblock_shape_n == rhs.threadblock_shape_n) &&
  193. (threadblock_shape_k == rhs.threadblock_shape_k) &&
  194. (warp_shape_m == rhs.warp_shape_m) &&
  195. (warp_shape_n == rhs.warp_shape_n) &&
  196. (warp_shape_k == rhs.warp_shape_k) &&
  197. (instruction_shape_m == rhs.instruction_shape_m) &&
  198. (instruction_shape_n == rhs.instruction_shape_n) &&
  199. (instruction_shape_k == rhs.instruction_shape_k) &&
  200. (epilogue_type == rhs.epilogue_type) && (stages == rhs.stages) &&
  201. (special_optimization == rhs.special_optimization) &&
  202. (alignment_src == rhs.alignment_src) &&
  203. (alignment_filter == rhs.alignment_filter) &&
  204. (without_shared_load == rhs.without_shared_load);
  205. }
  206. inline bool operator!=(ConvolutionKey const& rhs) const { return !(*this == rhs); }
  207. inline std::string str() const {
  208. auto tuple_to_str = [](int m, int n, int k) -> std::string {
  209. return std::to_string(m) + " x " + std::to_string(n) + " x " +
  210. std::to_string(k);
  211. };
  212. std::string threadblock_shape_str = tuple_to_str(
  213. threadblock_shape_m, threadblock_shape_n, threadblock_shape_k);
  214. std::string warp_shape_str =
  215. tuple_to_str(warp_shape_m, warp_shape_n, warp_shape_k);
  216. std::string instruction_shape_str = tuple_to_str(
  217. instruction_shape_m, instruction_shape_n, instruction_shape_k);
  218. return std::string("{") + "\n conv_op: " + to_string(conv_op) +
  219. "\n element_src: " + to_string(element_src) +
  220. "\n layout_src: " + to_string(layout_src) +
  221. "\n element_filter: " + to_string(element_filter) +
  222. "\n layout_filter: " + to_string(layout_filter) +
  223. "\n element_dst: " + to_string(element_dst) +
  224. "\n layout_dst: " + to_string(layout_dst) +
  225. "\n element_bias: " + to_string(element_bias) +
  226. "\n layout_bias: " + to_string(layout_bias) +
  227. "\n element_accumulator: " + to_string(element_accumulator) +
  228. "\n convolution_type: " + to_string(convolution_type) +
  229. "\n threadblock_shape: " + threadblock_shape_str +
  230. "\n warp_shape: " + warp_shape_str +
  231. "\n instruction_shape: " + instruction_shape_str +
  232. "\n epilogue_type: " + to_string(epilogue_type) +
  233. "\n stages: " + std::to_string(stages) +
  234. "\n special_optimization: " + to_string(special_optimization) +
  235. "\n alignment_src: " + std::to_string(alignment_src) +
  236. "\n alignment_filter: " + std::to_string(alignment_filter) +
  237. "\n without_shared_load: " + to_string(without_shared_load) + "\n}";
  238. }
  239. };
  240. struct ConvolutionKeyHasher {
  241. inline size_t operator()(ConvolutionKey const& key) const {
  242. return Hash()
  243. .update(&key.conv_op, sizeof(key.conv_op))
  244. .update(&key.element_src, sizeof(key.element_src))
  245. .update(&key.layout_src, sizeof(key.layout_src))
  246. .update(&key.element_filter, sizeof(key.element_filter))
  247. .update(&key.layout_filter, sizeof(key.layout_filter))
  248. .update(&key.element_dst, sizeof(key.element_dst))
  249. .update(&key.layout_dst, sizeof(key.layout_dst))
  250. .update(&key.element_bias, sizeof(key.element_bias))
  251. .update(&key.layout_bias, sizeof(key.layout_bias))
  252. .update(&key.element_accumulator, sizeof(key.element_accumulator))
  253. .update(&key.convolution_type, sizeof(key.convolution_type))
  254. .update(&key.threadblock_shape_m, sizeof(key.threadblock_shape_m))
  255. .update(&key.threadblock_shape_n, sizeof(key.threadblock_shape_n))
  256. .update(&key.threadblock_shape_k, sizeof(key.threadblock_shape_k))
  257. .update(&key.warp_shape_m, sizeof(key.warp_shape_m))
  258. .update(&key.warp_shape_n, sizeof(key.warp_shape_n))
  259. .update(&key.warp_shape_k, sizeof(key.warp_shape_k))
  260. .update(&key.instruction_shape_m, sizeof(key.instruction_shape_m))
  261. .update(&key.instruction_shape_n, sizeof(key.instruction_shape_n))
  262. .update(&key.instruction_shape_k, sizeof(key.instruction_shape_k))
  263. .update(&key.epilogue_type, sizeof(key.epilogue_type))
  264. .update(&key.stages, sizeof(key.stages))
  265. .update(&key.special_optimization, sizeof(key.special_optimization))
  266. .update(&key.alignment_src, sizeof(key.alignment_src))
  267. .update(&key.alignment_filter, sizeof(key.alignment_filter))
  268. .update(&key.without_shared_load, sizeof(key.without_shared_load))
  269. .digest();
  270. }
  271. };
  272. using ConvolutionOperationMap = std::unordered_map<
  273. ConvolutionKey, std::vector<Operation const*>, ConvolutionKeyHasher>;
  274. /////////////////////////////////////////////////////////////////////////////////////////////////
  275. /// Table of cutlass::library::Operation instances
  276. class OperationTable {
  277. public:
  278. /// Map of all operations of type kGemm
  279. GemmOperationMap gemm_operations;
  280. /// Map of all operations of type kConvolution
  281. ConvolutionOperationMap convolution_operations;
  282. public:
  283. void append(Manifest const& manifest);
  284. Operation const* find_op(GemmKey const& key) const;
  285. Operation const* find_op(ConvolutionKey const& key) const;
  286. };
  287. /////////////////////////////////////////////////////////////////////////////////////////////////
  288. } // namespace library
  289. } // namespace cutlass
  290. /////////////////////////////////////////////////////////////////////////////////////////////////