
chanwise_convolution.cpp

/**
 * \file dnn/test/cuda/chanwise_convolution.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs/nn.h"

#include "cuda.h"
#include "megcore_cuda.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/convolution.h"
#include "test/common/tensor.h"
#include "test/common/workspace_wrapper.h"
#include "test/cuda/benchmark.h"
#include "test/cuda/fixture.h"

#include <cuda_profiler_api.h>
#include <cuda_runtime_api.h>

using namespace megdnn;
using namespace test;
namespace {

#if MEGDNN_WITH_BENCHMARK
bool check_need_full_bench() {
    if (getenv("MEGDNN_CHANWISE_CONV_FULLBENCH"))
        return true;
    printf("set MEGDNN_CHANWISE_CONV_FULLBENCH to run full benchmark\n");
    return false;
}
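
// Example invocation for the full benchmark (the gtest binary name below is an
// assumption; substitute whatever name your build produces):
//   MEGDNN_CHANWISE_CONV_FULLBENCH=1 ./megdnn_test \
//       --gtest_filter='CUDA.CHANWISE_CONVOLUTION_*BENCH_CHECK'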
#endif
//! force GROUP sparsity on a convolution param; with one group per input
//! channel this is what makes the convolution channel-wise in the tests below
Convolution::Param gconv_param(Convolution::Param p) {
    p.sparse = Convolution::Param::Sparse::GROUP;
    return p;
}
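
// BenchmarkEnv runs the same channel-wise convolution twice on identical data:
// once as a dense convolution (dispatched to cuDNN) with an expanded filter,
// and once as a GROUP-sparse convolution (MegDNN's channel-wise kernels). It
// times both with CUDA events and cross-checks the results. The template
// parameters P0/P1/P2 permute the (src, filter, dst) triple so the same
// plumbing serves forward (<0, 1, 2>), backward-data (<1, 2, 0>) and
// backward-filter (<0, 2, 1>) benchmarks.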
template <int P0, int P1, int P2>
class BenchmarkEnv {
    Handle *handle, *handle_cpu;
    std::unique_ptr<GaussianRNG> rng;
    TensorLayout lsrc, lflt0, lflt1, ldst;
    std::unique_ptr<Tensor<>> src0, src1, flt0, flt0_cpu, flt1, flt1_cpu, dst0, dst1;
    cudaEvent_t cuda_ev[3];
    cudaStream_t cuda_stream;
    size_t pad_h, pad_w;

    template <typename T>
    static std::tuple<T, T, T> shuffle(std::tuple<T, T, T> data) {
        return std::make_tuple(
                std::get<P0>(data), std::get<P1>(data), std::get<P2>(data));
    }
public:
    BenchmarkEnv(Handle* handle, Handle* handle_cpu) {
        this->handle = handle;
        this->handle_cpu = handle_cpu;
        rng = handle->create_operator<GaussianRNG>();
        // make cpu handle used
        handle_cpu->create_operator<Sleep>()->exec();
        for (int i = 0; i < 3; ++i)
            cudaEventCreate(&cuda_ev[i]);
        megcoreGetCUDAStream(handle->megcore_computing_handle(), &cuda_stream);
    }

    ~BenchmarkEnv() {
        for (int i = 0; i < 3; ++i)
            cudaEventDestroy(cuda_ev[i]);
    }

    void alloc(
            size_t N, size_t IC, size_t IH, size_t IW, size_t CHL_MUL, size_t FH,
            size_t FW, size_t PH, size_t PW) {
        pad_h = PH;
        pad_w = PW;
        auto mkly = [](const TensorShape& s) {
            return TensorLayout{s, dtype::Float32()};
        };
        lsrc = mkly({N, IC, IH, IW});
        lflt0 = mkly({CHL_MUL * IC, IC, FH, FW});
        lflt1 = mkly({IC, CHL_MUL, 1, FH, FW});
        ldst = mkly({N, IC * CHL_MUL, IH - FH + 1 + PH * 2, IW - FW + 1 + PW * 2});
        src0.reset(new Tensor<>(handle, lsrc));
        src1.reset(new Tensor<>(handle, lsrc));
        flt0.reset(new Tensor<>(handle, lflt0));
        flt0_cpu.reset(new Tensor<>(handle_cpu, lflt0));
        flt1.reset(new Tensor<>(handle, lflt1));
        flt1_cpu.reset(new Tensor<>(handle_cpu, lflt1));
        dst0.reset(new Tensor<>(handle, ldst));
        dst1.reset(new Tensor<>(handle, ldst));
    }

    void fill_src() {
        rng->exec(src0->tensornd(), {});
        megdnn_memcpy_D2D(handle, src1->ptr(), src0->ptr(), lsrc.span().dist_byte());
    }
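
    // Randomize the grouped filter (lflt1: IC groups x CHL_MUL x 1 x FH x FW),
    // then expand it into the dense layout (lflt0: (CHL_MUL*IC) x IC x FH x FW)
    // by copying each group's weights into the matching diagonal block and
    // zeroing everything else, so the dense and grouped runs compute the same
    // math.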
    void fill_flt() {
        rng->exec(flt1->tensornd(), {});
        megdnn_memcpy_D2H(
                handle, flt1_cpu->ptr(), flt1->ptr(), lflt1.span().dist_byte());

        const size_t IC = lflt1[0], CHL_MUL = lflt1[1], FSIZE = lflt1[3] * lflt1[4];

        // fill flt0 from flt1
        float* src = flt1_cpu->ptr();
        float* dst = flt0_cpu->ptr();
        memset(dst, 0, lflt0.span().dist_byte());
        for (size_t i = 0; i < IC; ++i) {
            for (size_t j = 0; j < CHL_MUL; ++j) {
                memcpy(dst + ((i * CHL_MUL + j) * IC + i) * FSIZE,
                       src + (i * CHL_MUL + j) * FSIZE, FSIZE * sizeof(float));
            }
        }

        megdnn_memcpy_H2D(handle, flt0->ptr(), dst, lflt0.span().dist_byte());
    }
    void fill_dst() {
        rng->exec(dst0->tensornd(), {});
        megdnn_memcpy_D2D(handle, dst1->ptr(), dst0->ptr(), ldst.span().dist_byte());
    }

    template <class Opr>
    void exec(Opr* opr0, Opr* opr1) {
        opr0->param().pad_h = pad_h;
        opr0->param().pad_w = pad_w;
        opr1->param() = opr0->param();
        opr1->param().sparse = param::Convolution::Sparse::GROUP;
        TensorND a0, b0, c0, a1, b1, c1;
        std::tie(a0, b0, c0) = shuffle(
                std::make_tuple(src0->tensornd(), flt0->tensornd(), dst0->tensornd()));
        std::tie(a1, b1, c1) = shuffle(
                std::make_tuple(src1->tensornd(), flt1->tensornd(), dst1->tensornd()));
        WorkspaceWrapper wk(
                handle,
                std::max(
                        opr0->get_workspace_in_bytes(a0.layout, b0.layout, c0.layout),
                        opr1->get_workspace_in_bytes(a1.layout, b1.layout, c1.layout)));
        cudaProfilerStart();
        cudaEventRecord(cuda_ev[0], cuda_stream);
        opr0->exec(a0, b0, c0, wk.workspace());
        cudaEventRecord(cuda_ev[1], cuda_stream);
        opr1->exec(a1, b1, c1, wk.workspace());
        cudaEventRecord(cuda_ev[2], cuda_stream);
        cudaProfilerStop();
        if (getenv("MEGDNN_CHANWISE_CONV_VERBOSE") ||
            getenv("MEGDNN_CHANWISE_CONV_FULLBENCH")) {
            cudaStreamSynchronize(cuda_stream);
            float t0 = -1, t1 = -1;
            cudaEventElapsedTime(&t0, cuda_ev[0], cuda_ev[1]);
            cudaEventElapsedTime(&t1, cuda_ev[1], cuda_ev[2]);
            printf("%s;%s;%s: cudnn/megdnn: %.3fms/%.3fms=%.3f\n",
                   lsrc.TensorShape::to_string().c_str(),
                   lflt1.TensorShape::to_string().c_str(),
                   ldst.TensorShape::to_string().c_str(), t0, t1, t0 / t1);
        }
    }
    //! special for weight preprocess
    void exec_convolution(ConvolutionForward* opr0, ConvolutionForward* opr1) {
        opr0->param().pad_h = pad_h;
        opr0->param().pad_w = pad_w;
        opr1->param() = opr0->param();
        opr1->param().sparse = param::Convolution::Sparse::GROUP;
        TensorND a0, b0, c0, a1, b1, c1;
        std::tie(a0, b0, c0) = shuffle(
                std::make_tuple(src0->tensornd(), flt0->tensornd(), dst0->tensornd()));
        std::tie(a1, b1, c1) = shuffle(
                std::make_tuple(src1->tensornd(), flt1->tensornd(), dst1->tensornd()));
        WorkspaceWrapper wk(
                handle, std::max(
                                opr0->get_workspace_in_bytes(
                                        a0.layout, b0.layout, c0.layout, nullptr),
                                opr1->get_workspace_in_bytes(
                                        a1.layout, b1.layout, c1.layout, nullptr)));
        cudaProfilerStart();
        cudaEventRecord(cuda_ev[0], cuda_stream);
        opr0->exec(a0, b0, c0, nullptr, wk.workspace());
        cudaEventRecord(cuda_ev[1], cuda_stream);
        opr1->exec(a1, b1, c1, nullptr, wk.workspace());
        cudaEventRecord(cuda_ev[2], cuda_stream);
        cudaProfilerStop();
        if (getenv("MEGDNN_CHANWISE_CONV_VERBOSE") ||
            getenv("MEGDNN_CHANWISE_CONV_FULLBENCH")) {
            cudaStreamSynchronize(cuda_stream);
            float t0 = -1, t1 = -1;
            cudaEventElapsedTime(&t0, cuda_ev[0], cuda_ev[1]);
            cudaEventElapsedTime(&t1, cuda_ev[1], cuda_ev[2]);
            printf("%s;%s;%s: cudnn/megdnn: %.3fms/%.3fms=%.3f\n",
                   lsrc.TensorShape::to_string().c_str(),
                   lflt1.TensorShape::to_string().c_str(),
                   ldst.TensorShape::to_string().c_str(), t0, t1, t0 / t1);
        }
    }
    void cmp_dst() {
        Tensor<> dst0_cpu(handle_cpu, ldst), dst1_cpu(handle_cpu, ldst);
        megdnn_memcpy_D2H(handle, dst0_cpu.ptr(), dst0->ptr(), ldst.span().dist_byte());
        megdnn_memcpy_D2H(handle, dst1_cpu.ptr(), dst1->ptr(), ldst.span().dist_byte());
        dst0_cpu.check_with(dst1_cpu);
    }

    void cmp_src() {
        Tensor<> src0_cpu(handle_cpu, lsrc), src1_cpu(handle_cpu, lsrc);
        megdnn_memcpy_D2H(handle, src0_cpu.ptr(), src0->ptr(), lsrc.span().dist_byte());
        megdnn_memcpy_D2H(handle, src1_cpu.ptr(), src1->ptr(), lsrc.span().dist_byte());
        src0_cpu.check_with(src1_cpu);
    }
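
    // Compare the diagonal blocks of the dense filter gradient against the
    // grouped filter gradient: every element must differ by less than 1e-2 and
    // the mean absolute error must stay below 1e-4.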
    void cmp_flt() {
        Tensor<> flt0_cpu(handle_cpu, lflt0), flt1_cpu(handle_cpu, lflt1);
        float* p0 = flt0_cpu.ptr();
        float* p1 = flt1_cpu.ptr();
        megdnn_memcpy_D2H(handle, p0, flt0->ptr(), lflt0.span().dist_byte());
        megdnn_memcpy_D2H(handle, p1, flt1->ptr(), lflt1.span().dist_byte());
        size_t IC = lflt1[0], CHL_MUL = lflt1[1], FSIZE = lflt1[3] * lflt1[4];
        double tot_err = 0, tot_err_num = 0;
        for (size_t i = 0; i < IC; ++i) {
            for (size_t j = 0; j < CHL_MUL; ++j) {
                auto t0 = p0 + ((i * CHL_MUL + j) * IC + i) * FSIZE,
                     t1 = p1 + (i * CHL_MUL + j) * FSIZE;
                for (size_t k = 0; k < FSIZE; ++k) {
                    auto err = std::abs(diff(t0[k], t1[k]));
                    tot_err += err;
                    tot_err_num += 1;
                    ASSERT_LT(err, 1e-2) << "failed at " << i << " " << j << " " << k
                                         << " vals=" << t0[k] << "," << t1[k];
                }
            }
        }
        auto avg_err = tot_err / tot_err_num;
        ASSERT_LT(avg_err, 1e-4);
    }
};

}  // anonymous namespace

constexpr auto M = Convolution::Mode::CROSS_CORRELATION;
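
// The correctness tests below pin the algorithm under test via an AlgoChecker
// callback (CHANNEL_WISE / CHANNEL_WISE_SMALL) and sweep strides, paddings,
// filter sizes and channel multipliers; Checker validates each output against
// the reference implementation.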
TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD) {
    Checker<Convolution> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(AlgoChecker<ConvolutionForward>(
            ExecutionPolicyAlgoName{
                    "DEFAULT",
                    {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                              "CHANNEL_WISE", {})
                              .c_str(),
                      {}}}},
            &require_algo));

    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
        checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(2e-2);

        // simple case
        // clang-format off
        for (uint32_t s : {1, 2})
        for (uint32_t p : {0, 1, 2, 3})
        for (size_t f : {2, 3, 5, 7})
        for (size_t ocpg : {1, 3}) {
            checker.set_param(gconv_param({M, p, p, s, s}))
                    .execs({{2, 3, 16, 16}, {3, ocpg, 1, f, f}, {}});
        }
        // clang-format on

        checker.set_param(gconv_param({M, 2, 3, 2, 1}))
                .execs({{32, 12, 20, 10}, {12, 2, 1, 4, 5}, {}});

        // padding larger than kern
        checker.set_param(gconv_param({M, 20, 30, 4, 5}))
                .execs({{32, 12, 20, 10}, {12, 2, 1, 4, 5}, {}});
    }
}
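
// The CHANNEL_WISE_SMALL variant is exercised only with stride 1 and small
// spatial extents here; its fp16 coverage is compiled in only for CUDA >= 9.0.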
TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD_SMALL) {
    Checker<Convolution> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(AlgoChecker<ConvolutionForward>(
            ExecutionPolicyAlgoName{
                    "DEFAULT",
                    {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                              "CHANNEL_WISE_SMALL", {})
                              .c_str(),
                      {}}}},
            &require_algo));
    for (auto dtype : std::vector<DType> {
             dtype::Float32(),
#if CUDA_VERSION >= 9000
                     dtype::Float16()
#endif
         }) {
        checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(2e-2);
        // clang-format off
        for (uint32_t s : {1})
        for (uint32_t f : {1, 3, 5, 7}) {
            checker.set_param(gconv_param({M, f / 2, f / 2, s, s}))
                    .execs({{2, 3, 16, 16}, {3, 1, 1, f, f}, {}});
        }
        // clang-format on
        checker.set_param(gconv_param({M, 1, 1, 1, 1}))
                .execs({{2, 3, 3, 16}, {3, 1, 1, 3, 3}, {}})
                .execs({{2, 3, 8, 3}, {3, 1, 1, 3, 3}, {}});
    }
}
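
// For the backward-data tests, infer_conv_shape computes the forward output
// size for each (input, filter, stride, padding) combination so that the
// gradient tensor shapes passed to execs stay mutually consistent.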
TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_DATA) {
    Checker<ConvolutionBackwardData> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE", &require_algo));
    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
        checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(1e-1);

        // simple case
        // clang-format off
        for (uint32_t s : {1, 2})
        for (uint32_t p : {0, 1, 2, 3})
        for (size_t f : {1, 2, 3, 5, 7})
        for (size_t ocpg : {1, 3}) {
            size_t ii = infer_conv_shape(16, f, s, p, true);
            checker.set_param(gconv_param({M, p, p, s, s}))
                    .execs({{3, ocpg, 1, f, f},
                            {2, 3 * ocpg, ii, ii},
                            {2, 3, 16, 16}});
        }
        // clang-format on

        checker.set_param(gconv_param({M, 2, 3, 2, 1}))
                .execs({{12, 3, 1, 4, 5}, {32, 36, 20, 10}, {32, 12, 39, 8}});

        checker.set_param(gconv_param({M, 30, 20, 5, 4}))
                .execs({{6, 2, 1, 5, 4}, {32, 12, 12, 10}, {32, 6, 3, 2}});

        checker.set_param(gconv_param({M, 20, 30, 4, 5}))
                .execs({{6, 2, 1, 4, 5}, {32, 12, 10, 12}, {32, 6, 2, 3}});
    }
}
TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_DATA_SMALL) {
    Checker<ConvolutionBackwardData> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE_SMALL", &require_algo));
    for (auto dtype : std::vector<DType> {
             dtype::Float32(),
#if CUDA_VERSION >= 9000
                     dtype::Float16()
#endif
         }) {
        checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(2e-2);

        for (uint32_t f : {1, 3, 5, 7}) {
            checker.set_param(gconv_param({M, f / 2, f / 2, 1, 1}))
                    .execs({{3, 1, 1, f, f}, {2, 3, 16, 16}, {2, 3, 16, 16}});
        }

        checker.set_param(gconv_param({M, 1, 1, 1, 1}))
                .execs({{3, 1, 1, 3, 3}, {2, 3, 3, 16}, {2, 3, 3, 16}})
                .execs({{3, 1, 1, 3, 3}, {2, 3, 8, 3}, {2, 3, 8, 3}});
    }
}
TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_FILTER) {
    Checker<ConvolutionBackwardFilter> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardFilter>("CHANNEL_WISE", &require_algo));
    UniformFloatRNG rng(-0.1, 0.1);
    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
        checker.set_dtype(0, dtype)
                .set_dtype(1, dtype)
                .set_dtype(2, dtype)
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(2e-1);

        // simple case
        // clang-format off
        for (uint32_t s : {1, 2})
        for (uint32_t p : {0, 1, 2, 3})
        for (uint32_t f : {1, 2, 3, 5, 7})
        for (uint32_t ocpg : {1, 3})
        for (uint32_t i : {8, 16, 32, 64}) {
            size_t ii = infer_conv_shape(i, f, s, p, true);
            checker.set_param(gconv_param({M, p, p, s, s}))
                    .execs({{2, 3, i, i},
                            {2, 3 * ocpg, ii, ii},
                            {3, ocpg, 1, f, f}});
        }
        // clang-format on

        // padding larger than kern
        checker.set_param(gconv_param({M, 20, 30, 4, 5}))
                .execs({{32, 6, 2, 3}, {32, 12, 10, 12}, {6, 2, 1, 4, 5}});

        // unused filter items
        checker.set_param(gconv_param({M, 2, 3, 2, 3}))
                .execs({{32, 6, 1, 1}, {32, 12, 1, 1}, {6, 2, 1, 5, 7}});
    }
}
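
// Everything below is benchmark-only code, compiled only when
// MEGDNN_WITH_BENCHMARK is set. The *_BENCH_CHECK tests time both the dense
// (cuDNN) and grouped (MegDNN channel-wise) paths through BenchmarkEnv and
// cross-check their results.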
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD_BENCH_CHECK) {
    auto handle = handle_cuda();
    auto handle_cpu = handle_naive();
    auto conv0 = handle->create_operator<ConvolutionForward>();
    auto conv1 = handle->create_operator<ConvolutionForward>();
    BenchmarkEnv<0, 1, 2> benv(handle, handle_cpu);
    auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t CHL_MUL,
                   size_t FH, size_t FW, size_t PH, size_t PW) {
        benv.alloc(N, IC, IH, IW, CHL_MUL, FH, FW, PH, PW);
        benv.fill_src();
        benv.fill_flt();
        benv.exec_convolution(conv0.get(), conv1.get());
        benv.cmp_dst();
    };

    run(64, 60, 50, 50, 1, 3, 3, 1, 1);
    if (check_need_full_bench()) {
        run(64, 728, 18, 18, 2, 5, 5, 2, 2);
        run(64, 64, 150, 150, 2, 3, 3, 1, 1);
        run(1, 2048, 4, 4, 2, 3, 3, 1, 1);
    }
}
TEST_F(CUDA, CHANWISE_CONVOLUTION_BWD_DATA_BENCH_CHECK) {
    auto handle = handle_cuda();
    auto handle_cpu = handle_naive();
    auto conv0 = handle->create_operator<ConvolutionBackwardData>();
    auto conv1 = handle->create_operator<ConvolutionBackwardData>();
    BenchmarkEnv<1, 2, 0> benv(handle, handle_cpu);
    auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t CHL_MUL,
                   size_t FH, size_t FW, size_t PH, size_t PW) {
        benv.alloc(N, IC, IH, IW, CHL_MUL, FH, FW, PH, PW);
        benv.fill_dst();
        benv.fill_flt();
        benv.exec(conv0.get(), conv1.get());
        benv.cmp_src();
    };

    run(64, 60, 50, 50, 1, 3, 3, 1, 1);
    if (check_need_full_bench()) {
        run(64, 728, 18, 18, 2, 5, 5, 2, 2);
        run(64, 64, 150, 150, 2, 3, 3, 1, 1);
        run(1, 2048, 4, 4, 2, 3, 3, 1, 1);
    }
}

TEST_F(CUDA, CHANWISE_CONVOLUTION_BWD_FILTER_BENCH_CHECK) {
    auto handle = handle_cuda();
    auto handle_cpu = handle_naive();
    auto conv0 = handle->create_operator<ConvolutionBackwardFilter>();
    auto conv1 = handle->create_operator<ConvolutionBackwardFilter>();
    BenchmarkEnv<0, 2, 1> benv(handle, handle_cpu);
    auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t CHL_MUL,
                   size_t FH, size_t FW, size_t PH, size_t PW) {
        benv.alloc(N, IC, IH, IW, CHL_MUL, FH, FW, PH, PW);
        benv.fill_src();
        benv.fill_dst();
        benv.exec(conv0.get(), conv1.get());
        benv.cmp_flt();
    };

    run(64, 60, 50, 50, 1, 3, 3, 1, 1);
    if (check_need_full_bench()) {
        run(64, 728, 18, 18, 2, 5, 5, 2, 2);
        run(64, 64, 150, 150, 2, 3, 3, 1, 1);
        run(1, 2048, 4, 4, 2, 3, 3, 1, 1);
    }
}
TEST_F(CUDA, CHANWISE_CONVOLUTION_BENCH_ALL_ALGO_FWD) {
    // enable profiling
    std::unique_ptr<OprProxy<ConvolutionForward>> proxy{
            new OprProxy<ConvolutionForward>{true}};
    proxy->warmup_times = 1;
    proxy->exec_times = 10;
    Benchmarker<ConvolutionForward> checker(handle_cuda());
    checker.set_times(1);
    ConvolutionForward::Param param;
    param.sparse = ConvolutionForward::Param::Sparse::GROUP;
    checker.set_param(param);
    checker.set_proxy(proxy);
    auto run = [&](size_t N, size_t C, size_t IH, size_t IW, size_t FH, size_t FW) {
        checker.proxy()->target_execution_policy = {};
        checker.execs({{N, C, IH, IW}, {C, 1, 1, FH, FW}, {}});
    };
    run(128, 64, 90, 80, 3, 3);
    run(128, 90, 100, 100, 3, 5);
    run(128, 32, 62, 62, 5, 5);
}

TEST_F(CUDA, CHANWISE_CONVOLUTION_BENCH_ALL_ALGO_BWD_DATA) {
    // enable profiling
    std::unique_ptr<OprProxy<ConvolutionBackwardData>> proxy{
            new OprProxy<ConvolutionBackwardData>{true}};
    proxy->warmup_times = 1;
    proxy->exec_times = 10;
    Benchmarker<ConvolutionBackwardData> checker(handle_cuda());
    checker.set_times(1);
    ConvolutionBackwardData::Param param;
    param.sparse = ConvolutionForward::Param::Sparse::GROUP;
    checker.set_param(param);
    checker.set_proxy(proxy);
    auto run = [&](size_t N, size_t C, size_t IH, size_t IW, size_t FH, size_t FW) {
        checker.proxy()->target_execution_policy.algo.reset();
        checker.execs(
                {{C, 1, 1, FH, FW}, {N, C, IH - FH + 1, IW - FW + 1}, {N, C, IH, IW}});
    };
    run(128, 64, 90, 80, 3, 3);
    run(128, 90, 100, 100, 3, 5);
    run(128, 32, 62, 62, 5, 5);
}

TEST_F(CUDA, CHANWISE_CONVOLUTION_BENCH_ALL_ALGO_BWD_FILTER) {
    // enable profiling
    std::unique_ptr<OprProxy<ConvolutionBackwardFilter>> proxy{
            new OprProxy<ConvolutionBackwardFilter>{true}};
    proxy->warmup_times = 1;
    proxy->exec_times = 10;
    Benchmarker<ConvolutionBackwardFilter> checker(handle_cuda());
    checker.set_times(1);
    ConvolutionBackwardFilter::Param param;
    param.sparse = ConvolutionForward::Param::Sparse::GROUP;
    checker.set_param(param);
    checker.set_proxy(proxy);
    auto run = [&](size_t N, size_t C, size_t IH, size_t IW, size_t FH, size_t FW) {
        checker.proxy()->target_execution_policy.algo.reset();
        checker.execs(
                {{N, C, IH, IW}, {N, C, IH - FH + 1, IW - FW + 1}, {C, 1, 1, FH, FW}});
    };
    run(128, 64, 90, 80, 3, 3);
    run(128, 90, 100, 100, 3, 5);
    run(128, 32, 62, 62, 5, 5);
}
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_ALL_ALGO_FORWARD) {
    CUBenchmarker<ConvolutionForward> bencher(handle_cuda());
    size_t RUNS = 10;
    bencher.set_display(false).set_times(RUNS);
    std::unique_ptr<OprProxy<ConvolutionForward>> proxy{
            new OprProxy<ConvolutionForward>{true}};
    bencher.set_proxy(proxy);
    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
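        // `bandwith` counts elements moved per run, scaled by 1e3 / 2^30;
        // multiplying by the element size in bytes (4 for fp32, 2 for fp16)
        // and dividing by the measured time in ms yields GB/s in the printout.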
        float bandwith = static_cast<float>(
                                 src.total_nr_elems() + filter.total_nr_elems() +
                                 dst_layout.total_nr_elems()) /
                         (1024 * 1024 * 1024) * 1e3;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.proxy()->target_execution_policy = {};
        auto time_in_ms_fp32 = bencher.execs({src, filter, {}}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.proxy()->target_execution_policy = {};
        auto time_in_ms_fp16 = bencher.execs({src, filter, {}}) / RUNS;

        bencher.proxy()->target_execution_policy.algo.reset();
        param.compute_mode = param::Convolution::ComputeMode::FLOAT32;
        bencher.set_param(param);
        auto time_in_ms_pseudo_fp16 = bencher.execs({src, filter, {}}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "pseudo float16: %.2fms %.2fGB/s "
               "speedup: "
               "%0.2f (fp16/fp32) %.2f (fp16/pseudo fp16)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(),
               time_in_ms_fp32, bandwith * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwith * 2 / time_in_ms_fp16, time_in_ms_pseudo_fp16,
               bandwith * 2 / time_in_ms_pseudo_fp16,
               time_in_ms_fp32 / time_in_ms_fp16,
               time_in_ms_pseudo_fp16 / time_in_ms_fp16);
    };

    // clang-format off
    for (size_t s : {1, 2})
    for (size_t f : {3, 5, 7})
    for (size_t batch : {64})
    for (size_t c : {16, 32, 64, 128})
    for (size_t ih : {128, 256})
    for (size_t iw : {128, 256})
        run(batch, c, ih, iw, f, s);
    // clang-format on

    run(128, 192, 28, 28, 3, 1);
    run(128, 192, 28, 28, 3, 2);
    run(128, 576, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 32, 112, 112, 3, 1);
    run(128, 960, 7, 7, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 144, 56, 56, 3, 2);
    run(128, 384, 14, 14, 3, 1);
    run(128, 144, 56, 56, 3, 1);
    run(128, 96, 112, 112, 3, 2);
    run(128, 384, 14, 14, 3, 1);
    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
    run(128, 576, 14, 14, 3, 2);
}
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_FORWARD_FLOAT) {
    CUBenchmarker<ConvolutionForward> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    bencher.set_before_exec_callback(
            AlgoChecker<ConvolutionForward>(ExecutionPolicyAlgoName{
                    "DEFAULT",
                    {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                              "CHANNEL_WISE", {})
                              .c_str(),
                      {}}}}));
    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
        float bandwith = static_cast<float>(
                                 src.total_nr_elems() + filter.total_nr_elems() +
                                 dst_layout.total_nr_elems()) /
                         (1024 * 1024 * 1024) * 1e3;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp32 = bencher.execs({src, filter, {}}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16 = bencher.execs({src, filter, {}}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "speedup: "
               "%0.2f (fp16/fp32)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(),
               time_in_ms_fp32, bandwith * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwith * 2 / time_in_ms_fp16, time_in_ms_fp32 / time_in_ms_fp16);
    };

    // clang-format off
    for (size_t s : {1})
    for (size_t f : {3, 5, 7})
    for (size_t batch : {64})
    for (size_t c : {16, 32, 64, 128})
    for (size_t ih : {8, 16, 32, 128, 256})
    for (size_t iw : {8, 16, 32, 128, 256})
        run(batch, c, ih, iw, f, s);
    // clang-format on
}
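
// Compares the regular CHANNEL_WISE kernel against the CHANNEL_WISE_SMALL
// kernel (fp32 and fp16) on the small spatial sizes the latter targets.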
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_FORWARD_FLOAT_SMALL) {
    CUBenchmarker<ConvolutionForward> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
        float bandwith = static_cast<float>(
                                 src.total_nr_elems() + filter.total_nr_elems() +
                                 dst_layout.total_nr_elems()) /
                         (1024 * 1024 * 1024) * 1e3;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_before_exec_callback(AlgoChecker<
                                          ConvolutionForward>(ExecutionPolicyAlgoName{
                        "DEFAULT",
                        {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                                  "CHANNEL_WISE", {})
                                  .c_str(),
                          {}}}}));
        auto time_in_ms_fp32_normal = bencher.execs({src, filter, {}}) / RUNS;

        bencher.set_before_exec_callback(
                AlgoChecker<ConvolutionForward>(ExecutionPolicyAlgoName{
                        "DEFAULT",
                        {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                                  "CHANNEL_WISE_SMALL", {})
                                  .c_str(),
                          {}}}}));
        auto time_in_ms_fp32_small = bencher.execs({src, filter, {}}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16_small = bencher.execs({src, filter, {}}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, fp32 normal: %.2fms %.2fGB/s "
               "small: %.2fms %.2fGB/s, fp16 small: %.2fms %.2fGB/s, "
               "speedup: "
               "%0.2f (fp32 small/normal) %0.2f (small fp16/fp32)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(),
               time_in_ms_fp32_normal, bandwith * 4 / time_in_ms_fp32_normal,
               time_in_ms_fp32_small, bandwith * 4 / time_in_ms_fp32_small,
               time_in_ms_fp16_small, bandwith * 2 / time_in_ms_fp16_small,
               time_in_ms_fp32_normal / time_in_ms_fp32_small,
               time_in_ms_fp32_small / time_in_ms_fp16_small);
    };

    // clang-format off
    for (size_t s : {1})
    for (size_t f : {3, 5})
    for (size_t batch : {64})
    for (size_t c : {16, 32, 64, 128})
    for (size_t ih : {8, 16, 32})
    for (size_t iw : {8, 16, 32})
        run(batch, c, ih, iw, f, s);
    // clang-format on

    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 960, 7, 7, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
}
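
// Pits MegDNN's CHANNEL_WISE ConvBias algorithm against cuDNN's implicit-GEMM
// forward algorithm on MobileNet-style layers; throughput here is derived from
// the multiply-add count (2 * f * f per output element) rather than memory
// traffic.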
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_FORWARD_CUDNN_DNN) {
    CUBenchmarker<ConvBiasForward> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    ConvBias::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = ConvBias::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::ConvBias::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f},
                    bias = {1, c, 1, 1};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<ConvBias>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()},
                {bias, dtype::Float32()}, {}, dst_layout);
        // 2 * f * f ops per output element, in units of 1e6 ops
        float computation_mops =
                static_cast<float>(dst_layout.total_nr_elems() * f * f * 2) * 1e-6;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.set_before_exec_callback(
                AlgoChecker<ConvBiasForward>(".+CHANNEL_WISE.+"));
        auto time_in_ms_dnn = bencher.execs({src, filter, bias, {}, {}}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.set_before_exec_callback(AlgoChecker<ConvBiasForward>(
                ".+CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM.+"));
        auto time_in_ms_cudnn = bencher.execs({src, filter, bias, {}, {}}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, dst=%s, dnn: %.2fms %.2fGops/s "
               "cudnn: %.2fms %.2fGops/s "
               "speedup: "
               "%0.2f (dnn/cudnn)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(),
               dst_layout.to_string().c_str(), time_in_ms_dnn,
               computation_mops / time_in_ms_dnn, time_in_ms_cudnn,
               computation_mops / time_in_ms_cudnn, time_in_ms_cudnn / time_in_ms_dnn);
    };

    // clang-format off
    for (size_t batch : {1, 16, 32, 64, 128}) {
        run(batch, 32, 112, 112, 3, 1);
        run(batch, 96, 112, 112, 3, 2);
        run(batch, 96, 112, 112, 3, 1);
        run(batch, 144, 56, 56, 3, 2);
        run(batch, 144, 56, 56, 3, 1);
        run(batch, 192, 28, 28, 3, 1);
        run(batch, 384, 14, 14, 3, 1);
        run(batch, 576, 14, 14, 3, 1);
        run(batch, 960, 7, 7, 3, 1);
        //! calibrate heu algo policy hw_size param
        run(batch, 144, 24, 24, 3, 1);
        run(batch, 144, 22, 22, 3, 1);
        run(batch, 144, 20, 20, 3, 1);
        run(batch, 144, 18, 18, 3, 1);
    }
    // clang-format on
}
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BACKWARD_DATA_FLOAT_SMALL) {
    CUBenchmarker<ConvolutionBackwardData> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    ConvolutionBackwardData::Param param;
    param.format = Convolution::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        float bandwith = static_cast<float>(
                                 src.total_nr_elems() + filter.total_nr_elems() +
                                 src.total_nr_elems()) /
                         (1024 * 1024 * 1024) * 1e3;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_before_exec_callback(
                        AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE"));
        auto time_in_ms_fp32_normal = bencher.execs({filter, src, src}) / RUNS;

        bencher.set_before_exec_callback(
                AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE_SMALL"));
        auto time_in_ms_fp32_small = bencher.execs({filter, src, src}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16_small = bencher.execs({filter, src, src}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, fp32 normal: %.2fms %.2fGB/s "
               "small: %.2fms %.2fGB/s, fp16 small: %.2fms %.2fGB/s, "
               "speedup: "
               "%0.2f (fp32 small/normal) %0.2f (small fp16/fp32)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(),
               time_in_ms_fp32_normal, bandwith * 4 / time_in_ms_fp32_normal,
               time_in_ms_fp32_small, bandwith * 4 / time_in_ms_fp32_small,
               time_in_ms_fp16_small, bandwith * 2 / time_in_ms_fp16_small,
               time_in_ms_fp32_normal / time_in_ms_fp32_small,
               time_in_ms_fp32_small / time_in_ms_fp16_small);
    };

    // clang-format off
    for (size_t s : {1})
    for (size_t f : {3, 5})
    for (size_t batch : {64})
    for (size_t c : {16, 32, 64, 128})
    for (size_t ih : {8, 16, 32})
    for (size_t iw : {8, 16, 32})
        run(batch, c, ih, iw, f, s);
    // clang-format on

    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 960, 7, 7, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
}
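
// infer_conv_shape2d derives the forward output size (oh, ow) so the gradient
// tensors below line up with the corresponding forward shapes.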
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BWD_DATA) {
    CUBenchmarker<ConvolutionBackwardData> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    bencher.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE"));
    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t ocpg, size_t group, size_t ih, size_t iw,
                   size_t f, size_t p, size_t s) {
        param.pad_h = p;
        param.pad_w = p;
        param.stride_h = s;
        param.stride_w = s;
        size_t oh, ow;
        infer_conv_shape2d(ih, iw, f, f, s, s, p, p, oh, ow, true);
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src_grad = {batch, group, ih, iw},
                    dst_grad = {batch, group * ocpg, oh, ow},
                    flt = {group, ocpg, 1, f, f};
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        float bandwith = static_cast<float>(
                                 flt.total_nr_elems() + dst_grad.total_nr_elems() +
                                 src_grad.total_nr_elems()) /
                         (1024 * 1024 * 1024) * 1e3;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp32 = bencher.execs({flt, dst_grad, src_grad}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16 = bencher.execs({flt, dst_grad, src_grad}) / RUNS;

        printf("stride=%zu, src_grad=%s, flt=%s, "
               "float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "speedup: "
               "%0.2f (fp16/fp32)\n",
               s, src_grad.to_string().c_str(), flt.to_string().c_str(),
               time_in_ms_fp32, bandwith * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwith * 2 / time_in_ms_fp16, time_in_ms_fp32 / time_in_ms_fp16);
    };

    // clang-format off
    for (size_t s : {1, 2})
    for (size_t f : {3, 5, 7})
    for (size_t p : {f / 2})
    for (size_t batch : {64})
    for (size_t ocpg : {1})
    for (size_t group : {16, 32, 64, 128})
    for (size_t ih : {8, 16, 32, 128, 256})
    for (size_t iw : {8, 16, 32, 128, 256})
        run(batch, ocpg, group, ih, iw, f, p, s);
    // clang-format on
}
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BWD_FILTER) {
    CUBenchmarker<ConvolutionBackwardFilter> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    bencher.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardFilter>("CHANNEL_WISE"));
    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t ocpg, size_t group, size_t i, size_t f,
                   size_t p, size_t s) {
        param.pad_h = p;
        param.pad_w = p;
        param.stride_h = s;
        param.stride_w = s;
        size_t d = infer_conv_shape(i, f, s, p, true);
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, group, i, i},
                    dst_grad = {batch, group * ocpg, d, d},
                    flt_grad = {group, ocpg, 1, f, f};
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        float bandwith = static_cast<float>(
                                 flt_grad.total_nr_elems() +
                                 dst_grad.total_nr_elems() + src.total_nr_elems()) /
                         (1024 * 1024 * 1024) * 1e3;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp32 = bencher.execs({src, dst_grad, flt_grad}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16 = bencher.execs({src, dst_grad, flt_grad}) / RUNS;

        printf("stride=%zu, src=%s, flt_grad=%s, "
               "float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "speedup: "
               "%.2f (fp16/fp32)\n",
               s, src.to_string().c_str(), flt_grad.to_string().c_str(),
               time_in_ms_fp32, bandwith * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwith * 2 / time_in_ms_fp16, time_in_ms_fp32 / time_in_ms_fp16);
    };

    // clang-format off
    for (size_t s : {1, 2})
    for (size_t f : {3, 5, 7})
    for (size_t p : {f / 2})
    for (size_t batch : {64})
    for (size_t ocpg : {1})
    for (size_t group : {16, 32, 64, 128})
    for (size_t i : {8, 16, 32, 64, 128})
        run(batch, ocpg, group, i, f, p, s);
    // clang-format on
}
#endif  // MEGDNN_WITH_BENCHMARK

// vim: syntax=cpp.doxygen