/**
 * \file dnn/test/cuda/chanwise_convolution.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "megdnn/oprs/nn.h"

#include "cuda.h"
#include "megcore_cuda.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/convolution.h"
#include "test/common/tensor.h"
#include "test/common/workspace_wrapper.h"
#include "test/cuda/benchmark.h"
#include "test/cuda/fixture.h"

#include <cuda_profiler_api.h>
#include <cuda_runtime_api.h>

using namespace megdnn;
using namespace test;

namespace {
#if MEGDNN_WITH_BENCHMARK
bool check_need_full_bench() {
    if (getenv("MEGDNN_CHANWISE_CONV_FULLBENCH"))
        return true;
    printf("set MEGDNN_CHANWISE_CONV_FULLBENCH to run full benchmark\n");
    return false;
}
#endif
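
// Turn a plain convolution param into its channel-wise (grouped) form; when
// io16xc32 is set, float16 I/O is accumulated in float32 (FLOAT32 compute
// mode).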
Convolution::Param gconv_param(Convolution::Param p, bool io16xc32 = false) {
    p.sparse = Convolution::Param::Sparse::GROUP;
    if (io16xc32)
        p.compute_mode = Convolution::Param::ComputeMode::FLOAT32;
    return p;
}
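
// Runs the same problem through a dense convolution (opr0, filter layout
// lflt0) and a channel-wise/grouped convolution (opr1, filter layout lflt1),
// timing both and comparing results; the dense run is expected to hit cuDNN,
// hence the "cudnn/megdnn" label in the timing printf. The template
// parameters P0, P1, P2 permute the (src, filter, dst) tuple so the one class
// also serves the backward-data and backward-filter oprs, which take the same
// tensors in a different order.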
template <int P0, int P1, int P2>
class BenchmarkEnv {
    Handle *handle, *handle_cpu;
    std::unique_ptr<GaussianRNG> rng;
    TensorLayout lsrc, lflt0, lflt1, ldst;
    std::unique_ptr<Tensor<>> src0, src1, flt0, flt0_cpu, flt1, flt1_cpu, dst0, dst1;
    cudaEvent_t cuda_ev[3];
    cudaStream_t cuda_stream;
    size_t pad_h, pad_w;

    template <typename T>
    static std::tuple<T, T, T> shuffle(std::tuple<T, T, T> data) {
        return std::make_tuple(
                std::get<P0>(data), std::get<P1>(data), std::get<P2>(data));
    }

public:
    BenchmarkEnv(Handle* handle, Handle* handle_cpu) {
        this->handle = handle;
        this->handle_cpu = handle_cpu;
        rng = handle->create_operator<GaussianRNG>();
        // make cpu handle used
        handle_cpu->create_operator<Sleep>()->exec();
        for (int i = 0; i < 3; ++i)
            cudaEventCreate(&cuda_ev[i]);
        megcoreGetCUDAStream(handle->megcore_computing_handle(), &cuda_stream);
    }

    ~BenchmarkEnv() {
        for (int i = 0; i < 3; ++i)
            cudaEventDestroy(cuda_ev[i]);
    }
    void alloc(
            size_t N, size_t IC, size_t IH, size_t IW, size_t CHL_MUL, size_t FH,
            size_t FW, size_t PH, size_t PW) {
        pad_h = PH;
        pad_w = PW;
        auto mkly = [](const TensorShape& s) {
            return TensorLayout{s, dtype::Float32()};
        };
        lsrc = mkly({N, IC, IH, IW});
        lflt0 = mkly({CHL_MUL * IC, IC, FH, FW});
        lflt1 = mkly({IC, CHL_MUL, 1, FH, FW});
        ldst = mkly({N, IC * CHL_MUL, IH - FH + 1 + PH * 2, IW - FW + 1 + PW * 2});
        src0.reset(new Tensor<>(handle, lsrc));
        src1.reset(new Tensor<>(handle, lsrc));
        flt0.reset(new Tensor<>(handle, lflt0));
        flt0_cpu.reset(new Tensor<>(handle_cpu, lflt0));
        flt1.reset(new Tensor<>(handle, lflt1));
        flt1_cpu.reset(new Tensor<>(handle_cpu, lflt1));
        dst0.reset(new Tensor<>(handle, ldst));
        dst1.reset(new Tensor<>(handle, ldst));
    }

    void fill_src() {
        rng->exec(src0->tensornd(), {});
        megdnn_memcpy_D2D(handle, src1->ptr(), src0->ptr(), lsrc.span().dist_byte());
    }
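
    // A channel-wise filter of shape (IC, CHL_MUL, 1, FH, FW) is equivalent
    // to a dense filter of shape (IC * CHL_MUL, IC, FH, FW) that is
    // block-diagonal over channels: output channel i * CHL_MUL + j reads only
    // input channel i, and every other (oc, ic) slice stays zero. fill_flt()
    // materializes that dense filter on the CPU from the random channel-wise
    // one, so opr0 and opr1 compute the same math.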
    void fill_flt() {
        rng->exec(flt1->tensornd(), {});
        megdnn_memcpy_D2H(
                handle, flt1_cpu->ptr(), flt1->ptr(), lflt1.span().dist_byte());

        const size_t IC = lflt1[0], CHL_MUL = lflt1[1], FSIZE = lflt1[3] * lflt1[4];

        // fill flt0 from flt1
        float* src = flt1_cpu->ptr();
        float* dst = flt0_cpu->ptr();
        memset(dst, 0, lflt0.span().dist_byte());
        for (size_t i = 0; i < IC; ++i) {
            for (size_t j = 0; j < CHL_MUL; ++j) {
                memcpy(dst + ((i * CHL_MUL + j) * IC + i) * FSIZE,
                       src + (i * CHL_MUL + j) * FSIZE, FSIZE * sizeof(float));
            }
        }

        megdnn_memcpy_H2D(handle, flt0->ptr(), dst, lflt0.span().dist_byte());
    }

    void fill_dst() {
        rng->exec(dst0->tensornd(), {});
        megdnn_memcpy_D2D(handle, dst1->ptr(), dst0->ptr(), ldst.span().dist_byte());
    }
    template <class Opr>
    void exec(Opr* opr0, Opr* opr1) {
        opr0->param().pad_h = pad_h;
        opr0->param().pad_w = pad_w;
        opr1->param() = opr0->param();
        opr1->param().sparse = param::Convolution::Sparse::GROUP;
        TensorND a0, b0, c0, a1, b1, c1;
        std::tie(a0, b0, c0) = shuffle(
                std::make_tuple(src0->tensornd(), flt0->tensornd(), dst0->tensornd()));
        std::tie(a1, b1, c1) = shuffle(
                std::make_tuple(src1->tensornd(), flt1->tensornd(), dst1->tensornd()));
        WorkspaceWrapper wk(
                handle,
                std::max(
                        opr0->get_workspace_in_bytes(a0.layout, b0.layout, c0.layout),
                        opr1->get_workspace_in_bytes(a1.layout, b1.layout, c1.layout)));
        cudaProfilerStart();
        cudaEventRecord(cuda_ev[0], cuda_stream);
        opr0->exec(a0, b0, c0, wk.workspace());
        cudaEventRecord(cuda_ev[1], cuda_stream);
        opr1->exec(a1, b1, c1, wk.workspace());
        cudaEventRecord(cuda_ev[2], cuda_stream);
        cudaProfilerStop();
        if (getenv("MEGDNN_CHANWISE_CONV_VERBOSE") ||
            getenv("MEGDNN_CHANWISE_CONV_FULLBENCH")) {
            cudaStreamSynchronize(cuda_stream);
            float t0 = -1, t1 = -1;
            cudaEventElapsedTime(&t0, cuda_ev[0], cuda_ev[1]);
            cudaEventElapsedTime(&t1, cuda_ev[1], cuda_ev[2]);
            printf("%s;%s;%s: cudnn/megdnn: %.3fms/%.3fms=%.3f\n",
                   lsrc.TensorShape::to_string().c_str(),
                   lflt1.TensorShape::to_string().c_str(),
                   ldst.TensorShape::to_string().c_str(), t0, t1, t0 / t1);
        }
    }
    //! special for weight preprocess
    void exec_convolution(ConvolutionForward* opr0, ConvolutionForward* opr1) {
        opr0->param().pad_h = pad_h;
        opr0->param().pad_w = pad_w;
        opr1->param() = opr0->param();
        opr1->param().sparse = param::Convolution::Sparse::GROUP;
        TensorND a0, b0, c0, a1, b1, c1;
        std::tie(a0, b0, c0) = shuffle(
                std::make_tuple(src0->tensornd(), flt0->tensornd(), dst0->tensornd()));
        std::tie(a1, b1, c1) = shuffle(
                std::make_tuple(src1->tensornd(), flt1->tensornd(), dst1->tensornd()));
        WorkspaceWrapper wk(
                handle, std::max(
                                opr0->get_workspace_in_bytes(
                                        a0.layout, b0.layout, c0.layout, nullptr),
                                opr1->get_workspace_in_bytes(
                                        a1.layout, b1.layout, c1.layout, nullptr)));
        cudaProfilerStart();
        cudaEventRecord(cuda_ev[0], cuda_stream);
        opr0->exec(a0, b0, c0, nullptr, wk.workspace());
        cudaEventRecord(cuda_ev[1], cuda_stream);
        opr1->exec(a1, b1, c1, nullptr, wk.workspace());
        cudaEventRecord(cuda_ev[2], cuda_stream);
        cudaProfilerStop();
        if (getenv("MEGDNN_CHANWISE_CONV_VERBOSE") ||
            getenv("MEGDNN_CHANWISE_CONV_FULLBENCH")) {
            cudaStreamSynchronize(cuda_stream);
            float t0 = -1, t1 = -1;
            cudaEventElapsedTime(&t0, cuda_ev[0], cuda_ev[1]);
            cudaEventElapsedTime(&t1, cuda_ev[1], cuda_ev[2]);
            printf("%s;%s;%s: cudnn/megdnn: %.3fms/%.3fms=%.3f\n",
                   lsrc.TensorShape::to_string().c_str(),
                   lflt1.TensorShape::to_string().c_str(),
                   ldst.TensorShape::to_string().c_str(), t0, t1, t0 / t1);
        }
    }
    void cmp_dst() {
        Tensor<> dst0_cpu(handle_cpu, ldst), dst1_cpu(handle_cpu, ldst);
        megdnn_memcpy_D2H(handle, dst0_cpu.ptr(), dst0->ptr(), ldst.span().dist_byte());
        megdnn_memcpy_D2H(handle, dst1_cpu.ptr(), dst1->ptr(), ldst.span().dist_byte());
        dst0_cpu.check_with(dst1_cpu);
    }

    void cmp_src() {
        Tensor<> src0_cpu(handle_cpu, lsrc), src1_cpu(handle_cpu, lsrc);
        megdnn_memcpy_D2H(handle, src0_cpu.ptr(), src0->ptr(), lsrc.span().dist_byte());
        megdnn_memcpy_D2H(handle, src1_cpu.ptr(), src1->ptr(), lsrc.span().dist_byte());
        src0_cpu.check_with(src1_cpu);
    }
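
    // Compare filter gradients: only the diagonal blocks of the dense filter
    // gradient can be nonzero, so walk those blocks against the channel-wise
    // filter gradient, bounding both the per-element and the average absolute
    // error.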
    void cmp_flt() {
        Tensor<> flt0_cpu(handle_cpu, lflt0), flt1_cpu(handle_cpu, lflt1);
        float* p0 = flt0_cpu.ptr();
        float* p1 = flt1_cpu.ptr();
        megdnn_memcpy_D2H(handle, p0, flt0->ptr(), lflt0.span().dist_byte());
        megdnn_memcpy_D2H(handle, p1, flt1->ptr(), lflt1.span().dist_byte());

        size_t IC = lflt1[0], CHL_MUL = lflt1[1], FSIZE = lflt1[3] * lflt1[4];

        double tot_err = 0, tot_err_num = 0;
        for (size_t i = 0; i < IC; ++i) {
            for (size_t j = 0; j < CHL_MUL; ++j) {
                auto t0 = p0 + ((i * CHL_MUL + j) * IC + i) * FSIZE,
                     t1 = p1 + (i * CHL_MUL + j) * FSIZE;
                for (size_t k = 0; k < FSIZE; ++k) {
                    auto err = std::abs(diff(t0[k], t1[k]));
                    tot_err += err;
                    tot_err_num += 1;
                    ASSERT_LT(err, 1e-2) << "failed at " << i << " " << j << " " << k
                                         << " vals=" << t0[k] << "," << t1[k];
                }
            }
        }
        auto avg_err = tot_err / tot_err_num;
        ASSERT_LT(avg_err, 1e-4);
    }
};

}  // anonymous namespace

constexpr auto M = Convolution::Mode::CROSS_CORRELATION;
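
// All tests below use cross-correlation mode (M). Checker shapes follow the
// opr signature, e.g. {src, filter, dst} for forward; a grouped filter is
// 5-dimensional (group, ocpg, icpg, FH, FW), and an empty {} shape asks the
// checker to deduce the output layout.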
TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD) {
    Checker<Convolution> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(AlgoChecker<ConvolutionForward>(
            ExecutionPolicyAlgoName{
                    "DEFAULT",
                    {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                              "CHANNEL_WISE", {})
                              .c_str(),
                      {}}}},
            &require_algo));
    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
        checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(2e-2);

        // simple case
        // clang-format off
        for (uint32_t s : {1, 2})
        for (uint32_t p : {0, 1, 2, 3})
        for (size_t f : {2, 3, 5, 7})
        for (size_t ocpg : {1, 3}) {
            checker.set_param(gconv_param({M, p, p, s, s}))
                    .execs({{2, 3, 16, 16}, {3, ocpg, 1, f, f}, {}});
        }
        // clang-format on

        checker.set_param(gconv_param({M, 2, 3, 2, 1}))
                .execs({{32, 12, 20, 10}, {12, 2, 1, 4, 5}, {}});

        // padding larger than kern
        checker.set_param(gconv_param({M, 20, 30, 4, 5}))
                .execs({{32, 12, 20, 10}, {12, 2, 1, 4, 5}, {}});
    }
}
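
// CHANNEL_WISE_SMALL is the variant specialized for small spatial sizes;
// judging from the guard below, its float16 kernel is only available with
// CUDA >= 9.0.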
TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD_SMALL) {
    Checker<Convolution> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(AlgoChecker<ConvolutionForward>(
            ExecutionPolicyAlgoName{
                    "DEFAULT",
                    {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                              "CHANNEL_WISE_SMALL", {})
                              .c_str(),
                      {}}}},
            &require_algo));
    for (auto dtype : std::vector<DType>{
                 dtype::Float32(),
#if CUDA_VERSION >= 9000
                 dtype::Float16()
#endif
         }) {
        checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(2e-2);
        // clang-format off
        for (uint32_t s : {1})
        for (uint32_t f : {1, 3, 5, 7}) {
            checker.set_param(gconv_param({M, f / 2, f / 2, s, s}))
                    .execs({{2, 3, 16, 16}, {3, 1, 1, f, f}, {}});
        }
        // clang-format on
        checker.set_param(gconv_param({M, 1, 1, 1, 1}))
                .execs({{2, 3, 3, 16}, {3, 1, 1, 3, 3}, {}})
                .execs({{2, 3, 8, 3}, {3, 1, 1, 3, 3}, {}});
    }
}
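
// For ConvolutionBackwardData the checker tensors are {filter, diff, grad}:
// the last shape is the input gradient being computed, and
// infer_conv_shape() derives the diff (output) spatial size from the input
// size, filter, stride and padding.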
TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_DATA) {
    Checker<ConvolutionBackwardData> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE", &require_algo));
    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
        checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(1e-1);

        // simple case
        // clang-format off
        for (uint32_t s : {1, 2})
        for (uint32_t p : {0, 1, 2, 3})
        for (size_t f : {1, 2, 3, 5, 7})
        for (size_t ocpg : {1, 3}) {
            size_t ii = infer_conv_shape(16, f, s, p, true);
            checker.set_param(gconv_param({M, p, p, s, s}))
                    .execs({{3, ocpg, 1, f, f},
                            {2, 3 * ocpg, ii, ii},
                            {2, 3, 16, 16}});
        }
        // clang-format on

        checker.set_param(gconv_param({M, 2, 3, 2, 1}))
                .execs({{12, 3, 1, 4, 5}, {32, 36, 20, 10}, {32, 12, 39, 8}});
        checker.set_param(gconv_param({M, 30, 20, 5, 4}))
                .execs({{6, 2, 1, 5, 4}, {32, 12, 12, 10}, {32, 6, 3, 2}});
        checker.set_param(gconv_param({M, 20, 30, 4, 5}))
                .execs({{6, 2, 1, 4, 5}, {32, 12, 10, 12}, {32, 6, 2, 3}});
    }
}
TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_DATA_SMALL) {
    Checker<ConvolutionBackwardData> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE_SMALL", &require_algo));
    for (auto dtype : std::vector<DType>{
                 dtype::Float32(),
#if CUDA_VERSION >= 9000
                 dtype::Float16()
#endif
         }) {
        checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(2e-2);
        for (uint32_t f : {1, 3, 5, 7}) {
            checker.set_param(gconv_param({M, f / 2, f / 2, 1, 1}))
                    .execs({{3, 1, 1, f, f}, {2, 3, 16, 16}, {2, 3, 16, 16}});
        }
        checker.set_param(gconv_param({M, 1, 1, 1, 1}))
                .execs({{3, 1, 1, 3, 3}, {2, 3, 3, 16}, {2, 3, 3, 16}})
                .execs({{3, 1, 1, 3, 3}, {2, 3, 8, 3}, {2, 3, 8, 3}});
    }
}
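
// For ConvolutionBackwardFilter the checker tensors are {src, diff, grad},
// with the filter gradient last.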
TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_FILTER) {
    Checker<ConvolutionBackwardFilter> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardFilter>("CHANNEL_WISE", &require_algo));
    UniformFloatRNG rng(-0.1, 0.1);
    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
        checker.set_dtype(0, dtype)
                .set_dtype(1, dtype)
                .set_dtype(2, dtype)
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(2e-1);

        // simple case
        // clang-format off
        for (uint32_t s : {1, 2})
        for (uint32_t p : {0, 1, 2, 3})
        for (uint32_t f : {1, 2, 3, 5, 7})
        for (uint32_t ocpg : {1, 3})
        for (uint32_t i : {8, 16, 32, 64}) {
            size_t ii = infer_conv_shape(i, f, s, p, true);
            checker.set_param(gconv_param({M, p, p, s, s}))
                    .execs({{2, 3, i, i},
                            {2, 3 * ocpg, ii, ii},
                            {3, ocpg, 1, f, f}});
        }
        // clang-format on

        // padding larger than kern
        checker.set_param(gconv_param({M, 20, 30, 4, 5}))
                .execs({{32, 6, 2, 3}, {32, 12, 10, 12}, {6, 2, 1, 4, 5}});

        // unused filter items
        checker.set_param(gconv_param({M, 2, 3, 2, 3}))
                .execs({{32, 6, 1, 1}, {32, 12, 1, 1}, {6, 2, 1, 5, 7}});
    }
}
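
// ConvolutionForward dispatches channel-wise kernels through the DEFAULT
// meta-algorithm, which delegates to a conv-bias algorithm, so its algo name
// must be wrapped in an ExecutionPolicyAlgoName; the backward oprs take the
// plain algorithm name. AlgoCheckerMaker hides that difference.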
namespace {

template <typename Op>
struct AlgoCheckerMaker {
    static auto make(const char* name, bool* require_algo) {
        return AlgoChecker<Op>(name, require_algo);
    }
};

template <>
struct AlgoCheckerMaker<ConvolutionForward> {
    static auto make(const char* name, bool* require_algo) {
        return AlgoChecker<ConvolutionForward>(
                ExecutionPolicyAlgoName{
                        "DEFAULT",
                        {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                                  name, {})
                                  .c_str(),
                          {}}}},
                require_algo);
    }
};
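
// Run algorithm `name` on shapes that exercise different address alignments
// (widths 16/15/14 give 8-, 1- and 2-element alignment for the fp16 cases)
// plus custom padding and stride, optionally with float32 accumulation for
// float16 I/O (io16xc32).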
template <typename Op>
void check_chanwise(DType io_type, DType comp_type, Handle* handle, const char* name) {
    Checker<Op> checker(handle);
    bool require_algo = false;
    checker.set_before_exec_callback(AlgoCheckerMaker<Op>::make(name, &require_algo));
    checker.set_dtype(0, io_type).set_dtype(1, io_type).set_dtype(2, io_type);
    bool io16xc32 = false;
    if (io_type == dtype::Float16()) {
        if (comp_type == dtype::Float16()) {
            checker.set_epsilon(1e-1);
        } else {
            io16xc32 = true;
        }
    }
    // dispatch testcase by operation
    if (std::is_same<Op, ConvolutionForward>::value) {
        // align 8
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{8, 2, 16, 16}, {2, 1, 1, 15, 15}, {}});
        // align 1
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{8, 2, 15, 15}, {2, 1, 1, 15, 15}, {}});
        // align 2
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{8, 2, 14, 14}, {2, 1, 1, 15, 15}, {}});
        // custom padding
        checker.set_param(gconv_param({M, 3, 3, 1, 1}, io16xc32))
                .execs({{8, 2, 16, 16}, {2, 1, 1, 15, 15}, {}});
        // custom stride
        checker.set_param(gconv_param({M, 7, 7, 2, 2}, io16xc32))
                .execs({{8, 2, 16, 16}, {2, 1, 1, 15, 15}, {}});
    } else if (std::is_same<Op, ConvolutionBackwardData>::value) {
        // align 8
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{2, 1, 1, 15, 15}, {8, 2, 16, 16}, {8, 2, 16, 16}});
        // align 1
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{2, 1, 1, 15, 15}, {8, 2, 15, 15}, {8, 2, 15, 15}});
        // align 2
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{2, 1, 1, 15, 15}, {8, 2, 14, 14}, {8, 2, 14, 14}});
        // custom padding
        checker.set_param(gconv_param({M, 3, 3, 1, 1}, io16xc32))
                .execs({{2, 1, 1, 15, 15}, {8, 2, 8, 8}, {8, 2, 16, 16}});
        // custom stride
        checker.set_param(gconv_param({M, 7, 7, 2, 2}, io16xc32))
                .execs({{2, 1, 1, 15, 15}, {8, 2, 7, 7}, {8, 2, 14, 14}});
    } else if (std::is_same<Op, ConvolutionBackwardFilter>::value) {
    }
}

}  // namespace
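
// Instantiate one test per CUTLASS kernel configuration: cb(tag, tbm, tbn,
// tbk, wm, wn, wk) selects the kernel whose threadblock tile is
// tbm x tbn x tbk and whose warp tile is wm x wn x wk, as spelled out in the
// algorithm name below; the tag only keeps test names unique.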
#define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_FMA_KERNEL(cb) \
    cb(1, 128, 128, 8, 32, 64, 8);                              \
    cb(2, 128, 64, 8, 64, 32, 8);                               \
    cb(3, 128, 32, 8, 64, 32, 8);                               \
    cb(4, 64, 128, 8, 64, 32, 8);                               \
    cb(5, 32, 128, 8, 32, 64, 8);                               \
    cb(6, 64, 64, 8, 64, 32, 8);                                \
    cb(7, 32, 64, 8, 32, 64, 8);                                \
    cb(8, 32, 32, 8, 32, 32, 8);                                \
    cb(9, 64, 32, 8, 64, 32, 8);

#define cb(tag, tbm, tbn, tbk, wm, wn, wk)                               \
    TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD_CUTLASS_FMA_##tag) {       \
        check_chanwise<ConvolutionForward>(                              \
                dtype::Float32(), dtype::Float32(), handle_cuda(),       \
                "FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn  \
                "X" #tbk "_" #wm "X" #wn "X" #wk "_2stage");             \
    }

MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_FMA_KERNEL(cb)

#undef cb
#undef MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_FMA_KERNEL
#define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_HMMA_KERNEL(cb) \
    cb(1, 128, 128, 32, 32, 32, 32);                             \
    cb(2, 128, 256, 32, 64, 64, 32);                             \
    cb(3, 128, 64, 32, 32, 32, 32);                              \
    cb(4, 64, 128, 32, 32, 32, 32);                              \
    cb(5, 64, 64, 32, 32, 32, 32);

// check both ioc16 and io16xc32
#define cb(tag, tbm, tbn, tbk, wm, wn, wk)                                \
    TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD_CUTLASS_HMMA_##tag) {       \
        check_chanwise<ConvolutionForward>(                               \
                dtype::Float16(), dtype::Float16(), handle_cuda(),        \
                "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn  \
                "X" #tbk "_" #wm "X" #wn "X" #wk "_2stage");              \
        check_chanwise<ConvolutionForward>(                               \
                dtype::Float16(), dtype::Float32(), handle_cuda(),        \
                "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn  \
                "X" #tbk "_" #wm "X" #wn "X" #wk "_2stage");              \
    }

MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_HMMA_KERNEL(cb)

#undef cb
#undef MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FWD_HMMA_KERNEL
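
// Everything below is benchmark-only and compiled when MEGDNN_WITH_BENCHMARK
// is set; the *_BENCH_CHECK tests still verify results through
// BenchmarkEnv::cmp_dst()/cmp_src()/cmp_flt().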
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD_BENCH_CHECK) {
    auto handle = handle_cuda();
    auto handle_cpu = handle_naive();
    auto conv0 = handle->create_operator<ConvolutionForward>();
    auto conv1 = handle->create_operator<ConvolutionForward>();
    BenchmarkEnv<0, 1, 2> benv(handle, handle_cpu);
    auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t CHL_MUL, size_t FH,
                   size_t FW, size_t PH, size_t PW) {
        benv.alloc(N, IC, IH, IW, CHL_MUL, FH, FW, PH, PW);
        benv.fill_src();
        benv.fill_flt();
        benv.exec_convolution(conv0.get(), conv1.get());
        benv.cmp_dst();
    };

    run(64, 60, 50, 50, 1, 3, 3, 1, 1);
    if (check_need_full_bench()) {
        run(64, 728, 18, 18, 2, 5, 5, 2, 2);
        run(64, 64, 150, 150, 2, 3, 3, 1, 1);
        run(1, 2048, 4, 4, 2, 3, 3, 1, 1);
    }
}
TEST_F(CUDA, CHANWISE_CONVOLUTION_BWD_DATA_BENCH_CHECK) {
    auto handle = handle_cuda();
    auto handle_cpu = handle_naive();
    auto conv0 = handle->create_operator<ConvolutionBackwardData>();
    auto conv1 = handle->create_operator<ConvolutionBackwardData>();
    BenchmarkEnv<1, 2, 0> benv(handle, handle_cpu);
    auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t CHL_MUL, size_t FH,
                   size_t FW, size_t PH, size_t PW) {
        benv.alloc(N, IC, IH, IW, CHL_MUL, FH, FW, PH, PW);
        benv.fill_dst();
        benv.fill_flt();
        benv.exec(conv0.get(), conv1.get());
        benv.cmp_src();
    };

    run(64, 60, 50, 50, 1, 3, 3, 1, 1);
    if (check_need_full_bench()) {
        run(64, 728, 18, 18, 2, 5, 5, 2, 2);
        run(64, 64, 150, 150, 2, 3, 3, 1, 1);
        run(1, 2048, 4, 4, 2, 3, 3, 1, 1);
    }
}
TEST_F(CUDA, CHANWISE_CONVOLUTION_BWD_FILTER_BENCH_CHECK) {
    auto handle = handle_cuda();
    auto handle_cpu = handle_naive();
    auto conv0 = handle->create_operator<ConvolutionBackwardFilter>();
    auto conv1 = handle->create_operator<ConvolutionBackwardFilter>();
    BenchmarkEnv<0, 2, 1> benv(handle, handle_cpu);
    auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t CHL_MUL, size_t FH,
                   size_t FW, size_t PH, size_t PW) {
        benv.alloc(N, IC, IH, IW, CHL_MUL, FH, FW, PH, PW);
        benv.fill_src();
        benv.fill_dst();
        benv.exec(conv0.get(), conv1.get());
        benv.cmp_flt();
    };

    run(64, 60, 50, 50, 1, 3, 3, 1, 1);
    if (check_need_full_bench()) {
        run(64, 728, 18, 18, 2, 5, 5, 2, 2);
        run(64, 64, 150, 150, 2, 3, 3, 1, 1);
        run(1, 2048, 4, 4, 2, 3, 3, 1, 1);
    }
}
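
// The *_BENCH_ALL_ALGO_* tests construct an OprProxy with profiling enabled,
// so every available algorithm is profiled instead of pinning one
// implementation.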
TEST_F(CUDA, CHANWISE_CONVOLUTION_BENCH_ALL_ALGO_FWD) {
    // enable profiling
    std::unique_ptr<OprProxy<ConvolutionForward>> proxy{
            new OprProxy<ConvolutionForward>{true}};
    proxy->warmup_times = 1;
    proxy->exec_times = 10;
    Benchmarker<ConvolutionForward> checker(handle_cuda());
    checker.set_times(1);
    ConvolutionForward::Param param;
    param.sparse = ConvolutionForward::Param::Sparse::GROUP;
    checker.set_param(param);
    checker.set_proxy(proxy);
    auto run = [&](size_t N, size_t C, size_t IH, size_t IW, size_t FH, size_t FW) {
        checker.proxy()->target_execution_policy = {};
        checker.execs({{N, C, IH, IW}, {C, 1, 1, FH, FW}, {}});
    };
    run(128, 64, 90, 80, 3, 3);
    run(128, 90, 100, 100, 3, 5);
    run(128, 32, 62, 62, 5, 5);
}
TEST_F(CUDA, CHANWISE_CONVOLUTION_BENCH_ALL_ALGO_BWD_DATA) {
    // enable profiling
    std::unique_ptr<OprProxy<ConvolutionBackwardData>> proxy{
            new OprProxy<ConvolutionBackwardData>{true}};
    proxy->warmup_times = 1;
    proxy->exec_times = 10;
    Benchmarker<ConvolutionBackwardData> checker(handle_cuda());
    checker.set_times(1);
    ConvolutionBackwardData::Param param;
    param.sparse = ConvolutionForward::Param::Sparse::GROUP;
    checker.set_param(param);
    checker.set_proxy(proxy);
    auto run = [&](size_t N, size_t C, size_t IH, size_t IW, size_t FH, size_t FW) {
        checker.proxy()->target_execution_policy.algo.reset();
        checker.execs(
                {{C, 1, 1, FH, FW}, {N, C, IH - FH + 1, IW - FW + 1}, {N, C, IH, IW}});
    };
    run(128, 64, 90, 80, 3, 3);
    run(128, 90, 100, 100, 3, 5);
    run(128, 32, 62, 62, 5, 5);
}
TEST_F(CUDA, CHANWISE_CONVOLUTION_BENCH_ALL_ALGO_BWD_FILTER) {
    // enable profiling
    std::unique_ptr<OprProxy<ConvolutionBackwardFilter>> proxy{
            new OprProxy<ConvolutionBackwardFilter>{true}};
    proxy->warmup_times = 1;
    proxy->exec_times = 10;
    Benchmarker<ConvolutionBackwardFilter> checker(handle_cuda());
    checker.set_times(1);
    ConvolutionBackwardFilter::Param param;
    param.sparse = ConvolutionForward::Param::Sparse::GROUP;
    checker.set_param(param);
    checker.set_proxy(proxy);
    auto run = [&](size_t N, size_t C, size_t IH, size_t IW, size_t FH, size_t FW) {
        checker.proxy()->target_execution_policy.algo.reset();
        checker.execs(
                {{N, C, IH, IW}, {N, C, IH - FH + 1, IW - FW + 1}, {C, 1, 1, FH, FW}});
    };
    run(128, 64, 90, 80, 3, 3);
    run(128, 90, 100, 100, 3, 5);
    run(128, 32, 62, 62, 5, 5);
}
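
// In the benchmarks below the GB/s figure is (src + filter + dst element
// count) x element size (4 bytes for fp32, 2 for fp16) divided by the
// measured time, i.e. effective memory traffic; channel-wise convolution
// does little arithmetic per byte, so bandwidth is arguably the metric that
// matters.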
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_ALL_ALGO_FORWARD) {
    CUBenchmarker<ConvolutionForward> bencher(handle_cuda());
    size_t RUNS = 10;
    bencher.set_display(false).set_times(RUNS);
    std::unique_ptr<OprProxy<ConvolutionForward>> proxy{
            new OprProxy<ConvolutionForward>{true}};
    bencher.set_proxy(proxy);
    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
        float bandwith = static_cast<float>(
                                 src.total_nr_elems() + filter.total_nr_elems() +
                                 dst_layout.total_nr_elems()) /
                         (1024 * 1024 * 1024) * 1e3;
        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.proxy()->target_execution_policy = {};
        auto time_in_ms_fp32 = bencher.execs({src, filter, {}}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.proxy()->target_execution_policy = {};
        auto time_in_ms_fp16 = bencher.execs({src, filter, {}}) / RUNS;

        bencher.proxy()->target_execution_policy.algo.reset();
        param.compute_mode = param::Convolution::ComputeMode::FLOAT32;
        bencher.set_param(param);
        auto time_in_ms_pseudo_fp16 = bencher.execs({src, filter, {}}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "pseudo float16: %.2fms %.2fGB/s "
               "speedup: "
               "%0.2f (fp16/fp32) %.2f (fp16/pseudo fp16)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(), time_in_ms_fp32,
               bandwith * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwith * 2 / time_in_ms_fp16, time_in_ms_pseudo_fp16,
               bandwith * 2 / time_in_ms_pseudo_fp16, time_in_ms_fp32 / time_in_ms_fp16,
               time_in_ms_pseudo_fp16 / time_in_ms_fp16);
    };

    // clang-format off
    for (size_t s : {1, 2})
    for (size_t f : {3, 5, 7})
    for (size_t batch : {64})
    for (size_t c : {16, 32, 64, 128})
    for (size_t ih : {128, 256})
    for (size_t iw : {128, 256})
        run(batch, c, ih, iw, f, s);
    // clang-format on

    run(128, 192, 28, 28, 3, 1);
    run(128, 192, 28, 28, 3, 2);
    run(128, 576, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 32, 112, 112, 3, 1);
    run(128, 960, 7, 7, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 144, 56, 56, 3, 2);
    run(128, 384, 14, 14, 3, 1);
    run(128, 144, 56, 56, 3, 1);
    run(128, 96, 112, 112, 3, 2);
    run(128, 384, 14, 14, 3, 1);
    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
    run(128, 576, 14, 14, 3, 2);
}
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_FORWARD_FLOAT) {
    CUBenchmarker<ConvolutionForward> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    bencher.set_before_exec_callback(
            AlgoChecker<ConvolutionForward>(ExecutionPolicyAlgoName{
                    "DEFAULT",
                    {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                              "CHANNEL_WISE", {})
                              .c_str(),
                      {}}}}));
    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
        float bandwith = static_cast<float>(
                                 src.total_nr_elems() + filter.total_nr_elems() +
                                 dst_layout.total_nr_elems()) /
                         (1024 * 1024 * 1024) * 1e3;
        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp32 = bencher.execs({src, filter, {}}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16 = bencher.execs({src, filter, {}}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "speedup: "
               "%0.2f (fp16/fp32)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(), time_in_ms_fp32,
               bandwith * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwith * 2 / time_in_ms_fp16, time_in_ms_fp32 / time_in_ms_fp16);
    };

    // clang-format off
    for (size_t s : {1})
    for (size_t f : {3, 5, 7})
    for (size_t batch : {64})
    for (size_t c : {16, 32, 64, 128})
    for (size_t ih : {8, 16, 32, 128, 256})
    for (size_t iw : {8, 16, 32, 128, 256})
        run(batch, c, ih, iw, f, s);
    // clang-format on
}
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_FORWARD_FLOAT_SMALL) {
    CUBenchmarker<ConvolutionForward> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
        float bandwith = static_cast<float>(
                                 src.total_nr_elems() + filter.total_nr_elems() +
                                 dst_layout.total_nr_elems()) /
                         (1024 * 1024 * 1024) * 1e3;
        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_before_exec_callback(
                        AlgoChecker<ConvolutionForward>(ExecutionPolicyAlgoName{
                                "DEFAULT",
                                {{ConvBiasForward::algo_name<
                                          ConvBiasForward::DirectParam>(
                                          "CHANNEL_WISE", {})
                                          .c_str(),
                                  {}}}}));
        auto time_in_ms_fp32_normal = bencher.execs({src, filter, {}}) / RUNS;

        bencher.set_before_exec_callback(
                AlgoChecker<ConvolutionForward>(ExecutionPolicyAlgoName{
                        "DEFAULT",
                        {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                                  "CHANNEL_WISE_SMALL", {})
                                  .c_str(),
                          {}}}}));
        auto time_in_ms_fp32_small = bencher.execs({src, filter, {}}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16_small = bencher.execs({src, filter, {}}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, fp32 normal: %.2fms %.2fGB/s "
               "small: %.2fms %.2fGB/s, fp16 small: %.2fms %.2fGB/s, "
               "speedup: "
               "%0.2f (fp32 small/normal) %0.2f (small fp16/fp32)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(),
               time_in_ms_fp32_normal, bandwith * 4 / time_in_ms_fp32_normal,
               time_in_ms_fp32_small, bandwith * 4 / time_in_ms_fp32_small,
               time_in_ms_fp16_small, bandwith * 2 / time_in_ms_fp16_small,
               time_in_ms_fp32_normal / time_in_ms_fp32_small,
               time_in_ms_fp32_small / time_in_ms_fp16_small);
    };

    // clang-format off
    for (size_t s : {1})
    for (size_t f : {3, 5})
    for (size_t batch : {64})
    for (size_t c : {16, 32, 64, 128})
    for (size_t ih : {8, 16, 32})
    for (size_t iw : {8, 16, 32})
        run(batch, c, ih, iw, f, s);
    // clang-format on

    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 960, 7, 7, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
}
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_FORWARD_CUDNN_DNN) {
    CUBenchmarker<ConvBiasForward> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    ConvBias::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = ConvBias::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::ConvBias::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f},
                    bias = {1, c, 1, 1};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<ConvBias>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()},
                {bias, dtype::Float32()}, {}, dst_layout);
        float computation_mops =
                static_cast<float>(dst_layout.total_nr_elems() * f * f * 2) * 1e-6;
        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.set_before_exec_callback(
                AlgoChecker<ConvBiasForward>(".+CHANNEL_WISE.+"));
        auto time_in_ms_dnn = bencher.execs({src, filter, bias, {}, {}}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.set_before_exec_callback(AlgoChecker<ConvBiasForward>(
                ".+CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM.+"));
        auto time_in_ms_cudnn = bencher.execs({src, filter, bias, {}, {}}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, dst=%s, dnn: %.2fms %.2fGFLOPS "
               "cudnn: %.2fms %.2fGFLOPS "
               "speedup: "
               "%0.2f (dnn/cudnn)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(),
               dst_layout.to_string().c_str(), time_in_ms_dnn,
               computation_mops / time_in_ms_dnn, time_in_ms_cudnn,
               computation_mops / time_in_ms_cudnn, time_in_ms_cudnn / time_in_ms_dnn);
    };

    // clang-format off
    for (size_t batch : {1, 16, 32, 64, 128}) {
        run(batch, 32, 112, 112, 3, 1);
        run(batch, 96, 112, 112, 3, 2);
        run(batch, 96, 112, 112, 3, 1);
        run(batch, 144, 56, 56, 3, 2);
        run(batch, 144, 56, 56, 3, 1);
        run(batch, 192, 28, 28, 3, 1);
        run(batch, 384, 14, 14, 3, 1);
        run(batch, 576, 14, 14, 3, 1);
        run(batch, 960, 7, 7, 3, 1);
        //! calibrate the hw_size param of the heuristic algo policy
        run(batch, 144, 24, 24, 3, 1);
        run(batch, 144, 22, 22, 3, 1);
        run(batch, 144, 20, 20, 3, 1);
        run(batch, 144, 18, 18, 3, 1);
    }
    // clang-format on
}
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BACKWARD_DATA_FLOAT_SMALL) {
    CUBenchmarker<ConvolutionBackwardData> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    ConvolutionBackwardData::Param param;
    param.format = Convolution::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        float bandwith = static_cast<float>(
                                 src.total_nr_elems() + filter.total_nr_elems() +
                                 src.total_nr_elems()) /
                         (1024 * 1024 * 1024) * 1e3;
        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_before_exec_callback(
                        AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE"));
        auto time_in_ms_fp32_normal = bencher.execs({filter, src, src}) / RUNS;

        bencher.set_before_exec_callback(
                AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE_SMALL"));
        auto time_in_ms_fp32_small = bencher.execs({filter, src, src}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16_small = bencher.execs({filter, src, src}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, fp32 normal: %.2fms %.2fGB/s "
               "small: %.2fms %.2fGB/s, fp16 small: %.2fms %.2fGB/s, "
               "speedup: "
               "%0.2f (fp32 small/normal) %0.2f (small fp16/fp32)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(),
               time_in_ms_fp32_normal, bandwith * 4 / time_in_ms_fp32_normal,
               time_in_ms_fp32_small, bandwith * 4 / time_in_ms_fp32_small,
               time_in_ms_fp16_small, bandwith * 2 / time_in_ms_fp16_small,
               time_in_ms_fp32_normal / time_in_ms_fp32_small,
               time_in_ms_fp32_small / time_in_ms_fp16_small);
    };

    // clang-format off
    for (size_t s : {1})
    for (size_t f : {3, 5})
    for (size_t batch : {64})
    for (size_t c : {16, 32, 64, 128})
    for (size_t ih : {8, 16, 32})
    for (size_t iw : {8, 16, 32})
        run(batch, c, ih, iw, f, s);
    // clang-format on

    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 960, 7, 7, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
}
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BWD_DATA) {
    CUBenchmarker<ConvolutionBackwardData> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    bencher.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE"));
    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t ocpg, size_t group, size_t ih, size_t iw,
                   size_t f, size_t p, size_t s) {
        param.pad_h = p;
        param.pad_w = p;
        param.stride_h = s;
        param.stride_w = s;
        size_t oh, ow;
        infer_conv_shape2d(ih, iw, f, f, s, s, p, p, oh, ow, true);
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src_grad = {batch, group, ih, iw},
                    dst_grad = {batch, group * ocpg, oh, ow},
                    flt = {group, ocpg, 1, f, f};
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        float bandwith = static_cast<float>(
                                 flt.total_nr_elems() + dst_grad.total_nr_elems() +
                                 src_grad.total_nr_elems()) /
                         (1024 * 1024 * 1024) * 1e3;
        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp32 = bencher.execs({flt, dst_grad, src_grad}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16 = bencher.execs({flt, dst_grad, src_grad}) / RUNS;

        printf("stride=%zu, src_grad=%s, flt=%s, "
               "float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "speedup: "
               "%0.2f (fp16/fp32)\n",
               s, src_grad.to_string().c_str(), flt.to_string().c_str(),
               time_in_ms_fp32, bandwith * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwith * 2 / time_in_ms_fp16, time_in_ms_fp32 / time_in_ms_fp16);
    };

    // clang-format off
    for (size_t s : {1, 2})
    for (size_t f : {3, 5, 7})
    for (size_t p : {f / 2})
    for (size_t batch : {64})
    for (size_t ocpg : {1})
    for (size_t group : {16, 32, 64, 128})
    for (size_t ih : {8, 16, 32, 128, 256})
    for (size_t iw : {8, 16, 32, 128, 256})
        run(batch, ocpg, group, ih, iw, f, p, s);
    // clang-format on
}
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BWD_FILTER) {
    CUBenchmarker<ConvolutionBackwardFilter> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    bencher.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardFilter>("CHANNEL_WISE"));
    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t ocpg, size_t group, size_t i, size_t f,
                   size_t p, size_t s) {
        param.pad_h = p;
        param.pad_w = p;
        param.stride_h = s;
        param.stride_w = s;
        size_t d = infer_conv_shape(i, f, s, p, true);
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, group, i, i}, dst_grad = {batch, group * ocpg, d, d},
                    flt_grad = {group, ocpg, 1, f, f};
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        float bandwith = static_cast<float>(
                                 flt_grad.total_nr_elems() + dst_grad.total_nr_elems() +
                                 src.total_nr_elems()) /
                         (1024 * 1024 * 1024) * 1e3;
        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp32 = bencher.execs({src, dst_grad, flt_grad}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16 = bencher.execs({src, dst_grad, flt_grad}) / RUNS;

        printf("stride=%zu, src=%s, flt_grad=%s, "
               "float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "speedup: "
               "%.2f (fp16/fp32)\n",
               s, src.to_string().c_str(), flt_grad.to_string().c_str(),
               time_in_ms_fp32, bandwith * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwith * 2 / time_in_ms_fp16, time_in_ms_fp32 / time_in_ms_fp16);
    };

    // clang-format off
    for (size_t s : {1, 2})
    for (size_t f : {3, 5, 7})
    for (size_t p : {f / 2})
    for (size_t batch : {64})
    for (size_t ocpg : {1})
    for (size_t group : {16, 32, 64, 128})
    for (size_t i : {8, 16, 32, 64, 128})
        run(batch, ocpg, group, i, f, p, s);
    // clang-format on
}
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_FORWARD_LARGE_KERNEL) {
    CUBenchmarker<ConvolutionForward> bencher(handle_cuda());
    size_t RUNS = 100;
    bencher.set_display(false).set_times(RUNS);
    std::unique_ptr<OprProxy<ConvolutionForward>> proxy{
            new OprProxy<ConvolutionForward>{true}};
    bencher.set_proxy(proxy);
    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
        float bandwith = static_cast<float>(
                                 src.total_nr_elems() + filter.total_nr_elems() +
                                 dst_layout.total_nr_elems()) /
                         (1024 * 1024 * 1024) * 1e3;
        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.proxy()->target_execution_policy = {};
        auto time_in_ms_fp32 = bencher.execs({src, filter, {}}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.proxy()->target_execution_policy = {};
        auto time_in_ms_fp16 = bencher.execs({src, filter, {}}) / RUNS;

        bencher.proxy()->target_execution_policy.algo.reset();
        param.compute_mode = param::Convolution::ComputeMode::FLOAT32;
        bencher.set_param(param);
        auto time_in_ms_pseudo_fp16 = bencher.execs({src, filter, {}}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "pseudo float16: %.2fms %.2fGB/s "
               "speedup: "
               "%0.2f (fp16/fp32) %.2f (fp16/pseudo fp16)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(), time_in_ms_fp32,
               bandwith * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwith * 2 / time_in_ms_fp16, time_in_ms_pseudo_fp16,
               bandwith * 2 / time_in_ms_pseudo_fp16, time_in_ms_fp32 / time_in_ms_fp16,
               time_in_ms_pseudo_fp16 / time_in_ms_fp16);
    };

    // clang-format off
    for (size_t b : {32, 64})
    for (size_t f : {3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}) {
        run(b, 384, 32, 32, f, 1);
        run(b, 384, 64, 64, f, 1);
    }
    // clang-format on
}
#endif

// vim: syntax=cpp.doxygen