
chanwise_convolution.cpp

/**
 * \file dnn/test/cuda/chanwise_convolution.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs/nn.h"

#include "cuda.h"
#include "megcore_cuda.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/convolution.h"
#include "test/common/tensor.h"
#include "test/common/workspace_wrapper.h"
#include "test/cuda/benchmark.h"
#include "test/cuda/fixture.h"
#include "test/cuda/utils.h"

#include <cuda_profiler_api.h>
#include <cuda_runtime_api.h>

using namespace megdnn;
using namespace test;

namespace {

#if MEGDNN_WITH_BENCHMARK
bool check_need_full_bench() {
    if (getenv("MEGDNN_CHANWISE_CONV_FULLBENCH"))
        return true;
    printf("set MEGDNN_CHANWISE_CONV_FULLBENCH to run full benchmark\n");
    return false;
}
#endif
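
//! build a group (channel-wise) convolution param from a dense one; when
//! io16xc32 is set, fp16 I/O is combined with fp32 accumulation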
Convolution::Param gconv_param(Convolution::Param p, bool io16xc32 = false) {
    p.sparse = Convolution::Param::Sparse::GROUP;
    if (io16xc32)
        p.compute_mode = Convolution::Param::ComputeMode::FLOAT32;
    return p;
}

template <int P0, int P1, int P2>
class BenchmarkEnv {
    Handle *handle, *handle_cpu;
    std::unique_ptr<GaussianRNG> rng;
    TensorLayout lsrc, lflt0, lflt1, ldst;
    std::unique_ptr<Tensor<>> src0, src1, flt0, flt0_cpu, flt1, flt1_cpu, dst0, dst1;
    cudaEvent_t cuda_ev[3];
    cudaStream_t cuda_stream;
    size_t pad_h, pad_w;
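
    //! reorder (src, flt, dst) into the operand order of the opr under test:
    //! <0, 1, 2> for forward, <1, 2, 0> for backward data and <0, 2, 1> for
    //! backward filter (see the BenchmarkEnv instantiations below)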
    template <typename T>
    static std::tuple<T, T, T> shuffle(std::tuple<T, T, T> data) {
        return std::make_tuple(
                std::get<P0>(data), std::get<P1>(data), std::get<P2>(data));
    }

public:
    BenchmarkEnv(Handle* handle, Handle* handle_cpu) {
        this->handle = handle;
        this->handle_cpu = handle_cpu;
        rng = handle->create_operator<GaussianRNG>();
        // make cpu handle used
        handle_cpu->create_operator<Sleep>()->exec();
        for (int i = 0; i < 3; ++i)
            cudaEventCreate(&cuda_ev[i]);
        megcoreGetCUDAStream(handle->megcore_computing_handle(), &cuda_stream);
    }

    ~BenchmarkEnv() {
        for (int i = 0; i < 3; ++i)
            cudaEventDestroy(cuda_ev[i]);
    }
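
    //! allocate device tensors for both implementations: flt0 holds the dense
    //! filter layout consumed by the cuDNN reference run, flt1 the
    //! channel-wise (group) layout consumed by the MegDNN kernel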
    void alloc(
            size_t N, size_t IC, size_t IH, size_t IW, size_t CHL_MUL, size_t FH,
            size_t FW, size_t PH, size_t PW) {
        pad_h = PH;
        pad_w = PW;
        auto mkly = [](const TensorShape& s) {
            return TensorLayout{s, dtype::Float32()};
        };
        lsrc = mkly({N, IC, IH, IW});
        lflt0 = mkly({CHL_MUL * IC, IC, FH, FW});
        lflt1 = mkly({IC, CHL_MUL, 1, FH, FW});
        ldst = mkly({N, IC * CHL_MUL, IH - FH + 1 + PH * 2, IW - FW + 1 + PW * 2});
        src0.reset(new Tensor<>(handle, lsrc));
        src1.reset(new Tensor<>(handle, lsrc));
        flt0.reset(new Tensor<>(handle, lflt0));
        flt0_cpu.reset(new Tensor<>(handle_cpu, lflt0));
        flt1.reset(new Tensor<>(handle, lflt1));
        flt1_cpu.reset(new Tensor<>(handle_cpu, lflt1));
        dst0.reset(new Tensor<>(handle, ldst));
        dst1.reset(new Tensor<>(handle, ldst));
    }

    void fill_src() {
        rng->exec(src0->tensornd(), {});
        megdnn_memcpy_D2D(handle, src1->ptr(), src0->ptr(), lsrc.span().dist_byte());
    }
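
    //! expand the channel-wise filter into the block-diagonal dense filter:
    //! group filter entry (i, j) becomes output channel i * CHL_MUL + j,
    //! input channel i of flt0; all other input-channel slots stay zero, so
    //! both operators compute the same convolution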
    void fill_flt() {
        rng->exec(flt1->tensornd(), {});
        megdnn_memcpy_D2H(
                handle, flt1_cpu->ptr(), flt1->ptr(), lflt1.span().dist_byte());

        const size_t IC = lflt1[0], CHL_MUL = lflt1[1], FSIZE = lflt1[3] * lflt1[4];

        // fill flt0 from flt1
        float* src = flt1_cpu->ptr();
        float* dst = flt0_cpu->ptr();
        memset(dst, 0, lflt0.span().dist_byte());
        for (size_t i = 0; i < IC; ++i) {
            for (size_t j = 0; j < CHL_MUL; ++j) {
                memcpy(dst + ((i * CHL_MUL + j) * IC + i) * FSIZE,
                       src + (i * CHL_MUL + j) * FSIZE, FSIZE * sizeof(float));
            }
        }

        megdnn_memcpy_H2D(handle, flt0->ptr(), dst, lflt0.span().dist_byte());
    }
    void fill_dst() {
        rng->exec(dst0->tensornd(), {});
        megdnn_memcpy_D2D(handle, dst1->ptr(), dst0->ptr(), ldst.span().dist_byte());
    }

    template <class Opr>
    void exec(Opr* opr0, Opr* opr1) {
        opr0->param().pad_h = pad_h;
        opr0->param().pad_w = pad_w;
        opr1->param() = opr0->param();
        opr1->param().sparse = param::Convolution::Sparse::GROUP;
        TensorND a0, b0, c0, a1, b1, c1;
        std::tie(a0, b0, c0) = shuffle(
                std::make_tuple(src0->tensornd(), flt0->tensornd(), dst0->tensornd()));
        std::tie(a1, b1, c1) = shuffle(
                std::make_tuple(src1->tensornd(), flt1->tensornd(), dst1->tensornd()));
        WorkspaceWrapper wk(
                handle,
                std::max(
                        opr0->get_workspace_in_bytes(a0.layout, b0.layout, c0.layout),
                        opr1->get_workspace_in_bytes(a1.layout, b1.layout, c1.layout)));
        cudaProfilerStart();
        cudaEventRecord(cuda_ev[0], cuda_stream);
        opr0->exec(a0, b0, c0, wk.workspace());
        cudaEventRecord(cuda_ev[1], cuda_stream);
        opr1->exec(a1, b1, c1, wk.workspace());
        cudaEventRecord(cuda_ev[2], cuda_stream);
        cudaProfilerStop();
        if (getenv("MEGDNN_CHANWISE_CONV_VERBOSE") ||
            getenv("MEGDNN_CHANWISE_CONV_FULLBENCH")) {
            cudaStreamSynchronize(cuda_stream);
            float t0 = -1, t1 = -1;
            cudaEventElapsedTime(&t0, cuda_ev[0], cuda_ev[1]);
            cudaEventElapsedTime(&t1, cuda_ev[1], cuda_ev[2]);
            printf("%s;%s;%s: cudnn/megdnn: %.3fms/%.3fms=%.3f\n",
                   lsrc.TensorShape::to_string().c_str(),
                   lflt1.TensorShape::to_string().c_str(),
                   ldst.TensorShape::to_string().c_str(), t0, t1, t0 / t1);
        }
    }
    //! special for weight preprocess
    void exec_convolution(ConvolutionForward* opr0, ConvolutionForward* opr1) {
        opr0->param().pad_h = pad_h;
        opr0->param().pad_w = pad_w;
        opr1->param() = opr0->param();
        opr1->param().sparse = param::Convolution::Sparse::GROUP;
        TensorND a0, b0, c0, a1, b1, c1;
        std::tie(a0, b0, c0) = shuffle(
                std::make_tuple(src0->tensornd(), flt0->tensornd(), dst0->tensornd()));
        std::tie(a1, b1, c1) = shuffle(
                std::make_tuple(src1->tensornd(), flt1->tensornd(), dst1->tensornd()));
        WorkspaceWrapper wk(
                handle, std::max(
                                opr0->get_workspace_in_bytes(
                                        a0.layout, b0.layout, c0.layout, nullptr),
                                opr1->get_workspace_in_bytes(
                                        a1.layout, b1.layout, c1.layout, nullptr)));
        cudaProfilerStart();
        cudaEventRecord(cuda_ev[0], cuda_stream);
        opr0->exec(a0, b0, c0, nullptr, wk.workspace());
        cudaEventRecord(cuda_ev[1], cuda_stream);
        opr1->exec(a1, b1, c1, nullptr, wk.workspace());
        cudaEventRecord(cuda_ev[2], cuda_stream);
        cudaProfilerStop();
        if (getenv("MEGDNN_CHANWISE_CONV_VERBOSE") ||
            getenv("MEGDNN_CHANWISE_CONV_FULLBENCH")) {
            cudaStreamSynchronize(cuda_stream);
            float t0 = -1, t1 = -1;
            cudaEventElapsedTime(&t0, cuda_ev[0], cuda_ev[1]);
            cudaEventElapsedTime(&t1, cuda_ev[1], cuda_ev[2]);
            printf("%s;%s;%s: cudnn/megdnn: %.3fms/%.3fms=%.3f\n",
                   lsrc.TensorShape::to_string().c_str(),
                   lflt1.TensorShape::to_string().c_str(),
                   ldst.TensorShape::to_string().c_str(), t0, t1, t0 / t1);
        }
    }
    void cmp_dst() {
        Tensor<> dst0_cpu(handle_cpu, ldst), dst1_cpu(handle_cpu, ldst);
        megdnn_memcpy_D2H(handle, dst0_cpu.ptr(), dst0->ptr(), ldst.span().dist_byte());
        megdnn_memcpy_D2H(handle, dst1_cpu.ptr(), dst1->ptr(), ldst.span().dist_byte());
        dst0_cpu.check_with(dst1_cpu);
    }

    void cmp_src() {
        Tensor<> src0_cpu(handle_cpu, lsrc), src1_cpu(handle_cpu, lsrc);
        megdnn_memcpy_D2H(handle, src0_cpu.ptr(), src0->ptr(), lsrc.span().dist_byte());
        megdnn_memcpy_D2H(handle, src1_cpu.ptr(), src1->ptr(), lsrc.span().dist_byte());
        src0_cpu.check_with(src1_cpu);
    }

    void cmp_flt() {
        Tensor<> flt0_cpu(handle_cpu, lflt0), flt1_cpu(handle_cpu, lflt1);
        float* p0 = flt0_cpu.ptr();
        float* p1 = flt1_cpu.ptr();
        megdnn_memcpy_D2H(handle, p0, flt0->ptr(), lflt0.span().dist_byte());
        megdnn_memcpy_D2H(handle, p1, flt1->ptr(), lflt1.span().dist_byte());
        size_t IC = lflt1[0], CHL_MUL = lflt1[1], FSIZE = lflt1[3] * lflt1[4];
        double tot_err = 0, tot_err_num = 0;
        for (size_t i = 0; i < IC; ++i) {
            for (size_t j = 0; j < CHL_MUL; ++j) {
                auto t0 = p0 + ((i * CHL_MUL + j) * IC + i) * FSIZE,
                     t1 = p1 + (i * CHL_MUL + j) * FSIZE;
                for (size_t k = 0; k < FSIZE; ++k) {
                    auto err = std::abs(diff(t0[k], t1[k]));
                    tot_err += err;
                    tot_err_num += 1;
                    ASSERT_LT(err, 1e-2) << "failed at " << i << " " << j << " " << k
                                         << " vals=" << t0[k] << "," << t1[k];
                }
            }
        }
        auto avg_err = tot_err / tot_err_num;
        ASSERT_LT(avg_err, 1e-4);
    }
};

} // anonymous namespace

constexpr auto M = Convolution::Mode::CROSS_CORRELATION;

TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD) {
    Checker<Convolution> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(AlgoChecker<ConvolutionForward>(
            ExecutionPolicyAlgoName{
                    "DEFAULT",
                    {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                              "CHANNEL_WISE", {})
                              .c_str(),
                      {}}}},
            &require_algo));

    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
        checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(2e-2);

        // simple case
        // clang-format off
        for (uint32_t s : {1, 2})
        for (uint32_t p : {0, 1, 2, 3})
        for (size_t f : {2, 3, 5, 7})
        for (size_t ocpg : {1, 3}) {
            checker.set_param(gconv_param({M, p, p, s, s}))
                    .execs({{2, 3, 16, 16}, {3, ocpg, 1, f, f}, {}});
        }
        // clang-format on

        checker.set_param(gconv_param({M, 2, 3, 2, 1}))
                .execs({{32, 12, 20, 10}, {12, 2, 1, 4, 5}, {}});

        // padding larger than kern
        checker.set_param(gconv_param({M, 20, 30, 4, 5}))
                .execs({{32, 12, 20, 10}, {12, 2, 1, 4, 5}, {}});
    }
}

TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD_SMALL) {
    Checker<Convolution> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(AlgoChecker<ConvolutionForward>(
            ExecutionPolicyAlgoName{
                    "DEFAULT",
                    {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                              "CHANNEL_WISE_SMALL", {})
                              .c_str(),
                      {}}}},
            &require_algo));

    for (auto dtype : std::vector<DType> {
             dtype::Float32(),
#if CUDA_VERSION >= 9000
             dtype::Float16()
#endif
         }) {
        checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(2e-2);

        // clang-format off
        for (uint32_t s : {1})
        for (uint32_t f : {1, 3, 5, 7}) {
            checker.set_param(gconv_param({M, f / 2, f / 2, s, s}))
                    .execs({{2, 3, 16, 16}, {3, 1, 1, f, f}, {}});
        }
        // clang-format on

        checker.set_param(gconv_param({M, 1, 1, 1, 1}))
                .execs({{2, 3, 3, 16}, {3, 1, 1, 3, 3}, {}})
                .execs({{2, 3, 8, 3}, {3, 1, 1, 3, 3}, {}});
    }
}

TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_DATA) {
    Checker<ConvolutionBackwardData> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE", &require_algo));

    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
        checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(1e-1);

        // simple case
        // clang-format off
        for (uint32_t s : {1, 2})
        for (uint32_t p : {0, 1, 2, 3})
        for (size_t f : {1, 2, 3, 5, 7})
        for (size_t ocpg : {1, 3}) {
            size_t ii = infer_conv_shape(16, f, s, p, true);
            checker.set_param(gconv_param({M, p, p, s, s}))
                    .execs({{3, ocpg, 1, f, f},
                            {2, 3 * ocpg, ii, ii},
                            {2, 3, 16, 16}});
        }
        // clang-format on

        checker.set_param(gconv_param({M, 2, 3, 2, 1}))
                .execs({{12, 3, 1, 4, 5}, {32, 36, 20, 10}, {32, 12, 39, 8}});

        checker.set_param(gconv_param({M, 30, 20, 5, 4}))
                .execs({{6, 2, 1, 5, 4}, {32, 12, 12, 10}, {32, 6, 3, 2}});

        checker.set_param(gconv_param({M, 20, 30, 4, 5}))
                .execs({{6, 2, 1, 4, 5}, {32, 12, 10, 12}, {32, 6, 2, 3}});
    }
}

TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_DATA_SMALL) {
    Checker<ConvolutionBackwardData> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE_SMALL", &require_algo));

    for (auto dtype : std::vector<DType> {
             dtype::Float32(),
#if CUDA_VERSION >= 9000
             dtype::Float16()
#endif
         }) {
        checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(2e-2);

        for (uint32_t f : {1, 3, 5, 7}) {
            checker.set_param(gconv_param({M, f / 2, f / 2, 1, 1}))
                    .execs({{3, 1, 1, f, f}, {2, 3, 16, 16}, {2, 3, 16, 16}});
        }

        checker.set_param(gconv_param({M, 1, 1, 1, 1}))
                .execs({{3, 1, 1, 3, 3}, {2, 3, 3, 16}, {2, 3, 3, 16}})
                .execs({{3, 1, 1, 3, 3}, {2, 3, 8, 3}, {2, 3, 8, 3}});
    }
}

TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_FILTER) {
    Checker<ConvolutionBackwardFilter> checker(handle_cuda());
    bool require_algo = false;
    checker.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardFilter>("CHANNEL_WISE", &require_algo));
    UniformFloatRNG rng(-0.1, 0.1);

    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
        checker.set_dtype(0, dtype)
                .set_dtype(1, dtype)
                .set_dtype(2, dtype)
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        if (dtype.enumv() == DTypeEnum::Float16)
            checker.set_epsilon(2e-1);

        // simple case
        // clang-format off
        for (uint32_t s : {1, 2})
        for (uint32_t p : {0, 1, 2, 3})
        for (uint32_t f : {1, 2, 3, 5, 7})
        for (uint32_t ocpg : {1, 3})
        for (uint32_t i : {8, 16, 32, 64}) {
            size_t ii = infer_conv_shape(i, f, s, p, true);
            checker.set_param(gconv_param({M, p, p, s, s}))
                    .execs({{2, 3, i, i},
                            {2, 3 * ocpg, ii, ii},
                            {3, ocpg, 1, f, f}});
        }
        // clang-format on

        // padding larger than kern
        checker.set_param(gconv_param({M, 20, 30, 4, 5}))
                .execs({{32, 6, 2, 3}, {32, 12, 10, 12}, {6, 2, 1, 4, 5}});

        // unused filter items
        checker.set_param(gconv_param({M, 2, 3, 2, 3}))
                .execs({{32, 6, 1, 1}, {32, 12, 1, 1}, {6, 2, 1, 5, 7}});
    }
}

namespace {
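
// forward chanwise algorithms are exposed through the ConvBias "DEFAULT"
// dispatcher, so AlgoChecker<ConvolutionForward> needs the nested
// ExecutionPolicyAlgoName form; the backward oprs accept the algorithm name
// directly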
template <typename Op>
struct AlgoCheckerMaker {
    static auto make(const char* name, bool* require_algo) {
        return AlgoChecker<Op>(name, require_algo);
    }
};

template <>
struct AlgoCheckerMaker<ConvolutionForward> {
    static auto make(const char* name, bool* require_algo) {
        return AlgoChecker<ConvolutionForward>(
                ExecutionPolicyAlgoName{
                        "DEFAULT",
                        {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                                  name, {})
                                  .c_str(),
                          {}}}},
                require_algo);
    }
};
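
// run a fixed battery of shapes (tensor alignment 8/1/2, custom padding and
// stride) against the named chanwise algorithm; io16xc32 selects fp16 I/O
// with fp32 computation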
template <typename Op>
void check_chanwise(DType io_type, DType comp_type, Handle* handle, const char* name) {
    Checker<Op> checker(handle);
    bool require_algo = false;
    checker.set_before_exec_callback(AlgoCheckerMaker<Op>::make(name, &require_algo));
    checker.set_dtype(0, io_type).set_dtype(1, io_type).set_dtype(2, io_type);
    bool io16xc32 = false;
    if (io_type == dtype::Float16()) {
        if (comp_type == dtype::Float16()) {
            checker.set_epsilon(1e-1);
        } else {
            io16xc32 = true;
        }
    }
    // dispatch testcase by operation
    if (std::is_same<Op, ConvolutionForward>::value) {
        // align 8
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{8, 2, 16, 16}, {2, 1, 1, 15, 15}, {}});
        // align 1
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{8, 2, 15, 15}, {2, 1, 1, 15, 15}, {}});
        // align 2
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{8, 2, 14, 14}, {2, 1, 1, 15, 15}, {}});
        // custom padding
        checker.set_param(gconv_param({M, 3, 3, 1, 1}, io16xc32))
                .execs({{8, 2, 16, 16}, {2, 1, 1, 15, 15}, {}});
        // custom stride
        checker.set_param(gconv_param({M, 7, 7, 2, 2}, io16xc32))
                .execs({{8, 2, 16, 16}, {2, 1, 1, 15, 15}, {}});
    } else if (std::is_same<Op, ConvolutionBackwardData>::value) {
        // align 8
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{2, 1, 1, 15, 15}, {8, 2, 16, 16}, {8, 2, 16, 16}});
        // align 1
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{2, 1, 1, 15, 15}, {8, 2, 15, 15}, {8, 2, 15, 15}});
        // align 2
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{2, 1, 1, 15, 15}, {8, 2, 14, 14}, {8, 2, 14, 14}});
        // custom padding
        checker.set_param(gconv_param({M, 3, 3, 1, 1}, io16xc32))
                .execs({{2, 1, 1, 15, 15}, {8, 2, 8, 8}, {8, 2, 16, 16}});
        // custom stride
        checker.set_param(gconv_param({M, 7, 7, 2, 2}, io16xc32))
                .execs({{2, 1, 1, 15, 15}, {8, 2, 7, 7}, {8, 2, 14, 14}});
    } else if (std::is_same<Op, ConvolutionBackwardFilter>::value) {
        // align 8
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{8, 2, 16, 16}, {8, 2, 16, 16}, {2, 1, 1, 15, 15}});
        // align 1
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{8, 2, 15, 15}, {8, 2, 15, 15}, {2, 1, 1, 15, 15}});
        // align 2
        checker.set_param(gconv_param({M, 7, 7, 1, 1}, io16xc32))
                .execs({{8, 2, 14, 14}, {8, 2, 14, 14}, {2, 1, 1, 15, 15}});
        // custom padding
        checker.set_param(gconv_param({M, 3, 3, 1, 1}, io16xc32))
                .execs({{8, 2, 16, 16}, {8, 2, 8, 8}, {2, 1, 1, 15, 15}});
        // custom stride
        checker.set_param(gconv_param({M, 7, 7, 2, 2}, io16xc32))
                .execs({{8, 2, 14, 14}, {8, 2, 7, 7}, {2, 1, 1, 15, 15}});
    }
}

} // namespace
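
// each cb entry instantiates one test for a single cutlass kernel
// configuration: cb(tag, threadblock_m, threadblock_n, threadblock_k,
// warp_m, warp_n, warp_k); the same numbers are stringified into the
// algorithm name checked below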
#define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL(cb) \
    cb(1, 128, 128, 8, 32, 64, 8);                          \
    cb(2, 128, 64, 8, 64, 32, 8);                           \
    cb(3, 128, 32, 8, 64, 32, 8);                           \
    cb(4, 64, 128, 8, 32, 64, 8);                           \
    cb(5, 32, 128, 8, 32, 64, 8);                           \
    cb(6, 64, 64, 8, 32, 64, 8);                            \
    cb(7, 32, 64, 8, 32, 64, 8);                            \
    cb(8, 32, 32, 8, 32, 32, 8);                            \
    cb(9, 64, 32, 8, 64, 32, 8);

#define cb(tag, tbm, tbn, tbk, wm, wn, wk)                                       \
    TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD_CUTLASS_FMA_##tag) {               \
        require_compute_capability(6, 1);                                        \
        check_chanwise<ConvolutionForward>(                                      \
                dtype::Float32(), dtype::Float32(), handle_cuda(),               \
                "FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
                "_" #wm "X" #wn "X" #wk "_2stage");                              \
    }
MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL(cb)
#undef cb

#define cb(tag, tbm, tbn, tbk, wm, wn, wk)                                       \
    TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_DATA_CUTLASS_FMA_##tag) {         \
        require_compute_capability(6, 1);                                        \
        check_chanwise<ConvolutionBackwardData>(                                 \
                dtype::Float32(), dtype::Float32(), handle_cuda(),               \
                "FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
                "_" #wm "X" #wn "X" #wk "_2stage");                              \
    }
MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL(cb)
#undef cb

#define cb(tag, tbm, tbn, tbk, wm, wn, wk)                                       \
    TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_FILTER_CUTLASS_FMA_##tag) {       \
        require_compute_capability(6, 1);                                        \
        check_chanwise<ConvolutionBackwardFilter>(                               \
                dtype::Float32(), dtype::Float32(), handle_cuda(),               \
                "FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
                "_" #wm "X" #wn "X" #wk "_2stage");                              \
    }
MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_FMA_KERNEL

#if CUDA_VERSION >= 10010
#define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb) \
    cb(1, 128, 128, 32, 32, 32, 32);                         \
    cb(2, 128, 256, 32, 64, 64, 32);                         \
    cb(3, 128, 64, 32, 32, 32, 32);                          \
    cb(4, 64, 128, 32, 32, 32, 32);                          \
    cb(5, 64, 64, 32, 32, 32, 32);
#else
// the HMMA testcases are only built when CUDA_VERSION >= 10010 (CUDA 10.1);
// on older toolkits the foreach macro expands to nothing
#define MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb)
#endif

// check both ioc16 and io16xc32
#define cb(tag, tbm, tbn, tbk, wm, wn, wk)                                       \
    TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD_CUTLASS_HMMA_##tag) {              \
        require_compute_capability(7, 0);                                        \
        check_chanwise<ConvolutionForward>(                                      \
                dtype::Float16(), dtype::Float16(), handle_cuda(),               \
                "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
                "_" #wm "X" #wn "X" #wk "_2stage");                              \
        check_chanwise<ConvolutionForward>(                                      \
                dtype::Float16(), dtype::Float32(), handle_cuda(),               \
                "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
                "_" #wm "X" #wn "X" #wk "_2stage");                              \
    }
MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb)
#undef cb

#define cb(tag, tbm, tbn, tbk, wm, wn, wk)                                       \
    TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_DATA_CUTLASS_HMMA_##tag) {        \
        require_compute_capability(7, 0);                                        \
        check_chanwise<ConvolutionBackwardData>(                                 \
                dtype::Float16(), dtype::Float16(), handle_cuda(),               \
                "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
                "_" #wm "X" #wn "X" #wk "_mma8X8X4_2stage");                     \
        check_chanwise<ConvolutionBackwardData>(                                 \
                dtype::Float16(), dtype::Float32(), handle_cuda(),               \
                "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
                "_" #wm "X" #wn "X" #wk "_mma8X8X4_2stage");                     \
    }
MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb)
#undef cb

#define cb(tag, tbm, tbn, tbk, wm, wn, wk)                                       \
    TEST_F(CUDA, CHANWISE_CONVOLUTION_BACKWARD_FILTER_CUTLASS_HMMA_##tag) {      \
        require_compute_capability(7, 0);                                        \
        check_chanwise<ConvolutionBackwardFilter>(                               \
                dtype::Float16(), dtype::Float32(), handle_cuda(),               \
                "FLOAT16_NCHW_HMMA_IMPLICIT_BATCHED_GEMM_" #tbm "X" #tbn "X" #tbk \
                "_" #wm "X" #wn "X" #wk "_mma8X8X4_2stage");                     \
    }
MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL(cb)
#undef cb
#undef MEGDNN_FOREACH_CUTLASS_CHANWISE_CONV_HMMA_KERNEL

#if MEGDNN_WITH_BENCHMARK
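
// the BENCH_CHECK tests below run the dense (cuDNN) and the channel-wise
// (MegDNN) implementation on identical data, compare the results, and print
// the timing ratio when the verbose env vars are set; the BenchmarkEnv
// template arguments fix the operand order for each operator type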
TEST_F(CUDA, CHANWISE_CONVOLUTION_FORWARD_BENCH_CHECK) {
    auto handle = handle_cuda();
    auto handle_cpu = handle_naive();
    auto conv0 = handle->create_operator<ConvolutionForward>();
    auto conv1 = handle->create_operator<ConvolutionForward>();
    BenchmarkEnv<0, 1, 2> benv(handle, handle_cpu);
    auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t CHL_MUL, size_t FH,
                   size_t FW, size_t PH, size_t PW) {
        benv.alloc(N, IC, IH, IW, CHL_MUL, FH, FW, PH, PW);
        benv.fill_src();
        benv.fill_flt();
        benv.exec_convolution(conv0.get(), conv1.get());
        benv.cmp_dst();
    };

    run(64, 60, 50, 50, 1, 3, 3, 1, 1);
    if (check_need_full_bench()) {
        run(64, 728, 18, 18, 2, 5, 5, 2, 2);
        run(64, 64, 150, 150, 2, 3, 3, 1, 1);
        run(1, 2048, 4, 4, 2, 3, 3, 1, 1);
    }
}

TEST_F(CUDA, CHANWISE_CONVOLUTION_BWD_DATA_BENCH_CHECK) {
    auto handle = handle_cuda();
    auto handle_cpu = handle_naive();
    auto conv0 = handle->create_operator<ConvolutionBackwardData>();
    auto conv1 = handle->create_operator<ConvolutionBackwardData>();
    BenchmarkEnv<1, 2, 0> benv(handle, handle_cpu);
    auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t CHL_MUL, size_t FH,
                   size_t FW, size_t PH, size_t PW) {
        benv.alloc(N, IC, IH, IW, CHL_MUL, FH, FW, PH, PW);
        benv.fill_dst();
        benv.fill_flt();
        benv.exec(conv0.get(), conv1.get());
        benv.cmp_src();
    };

    run(64, 60, 50, 50, 1, 3, 3, 1, 1);
    if (check_need_full_bench()) {
        run(64, 728, 18, 18, 2, 5, 5, 2, 2);
        run(64, 64, 150, 150, 2, 3, 3, 1, 1);
        run(1, 2048, 4, 4, 2, 3, 3, 1, 1);
    }
}

TEST_F(CUDA, CHANWISE_CONVOLUTION_BWD_FILTER_BENCH_CHECK) {
    auto handle = handle_cuda();
    auto handle_cpu = handle_naive();
    auto conv0 = handle->create_operator<ConvolutionBackwardFilter>();
    auto conv1 = handle->create_operator<ConvolutionBackwardFilter>();
    BenchmarkEnv<0, 2, 1> benv(handle, handle_cpu);
    auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t CHL_MUL, size_t FH,
                   size_t FW, size_t PH, size_t PW) {
        benv.alloc(N, IC, IH, IW, CHL_MUL, FH, FW, PH, PW);
        benv.fill_src();
        benv.fill_dst();
        benv.exec(conv0.get(), conv1.get());
        benv.cmp_flt();
    };

    run(64, 60, 50, 50, 1, 3, 3, 1, 1);
    if (check_need_full_bench()) {
        run(64, 728, 18, 18, 2, 5, 5, 2, 2);
        run(64, 64, 150, 150, 2, 3, 3, 1, 1);
        run(1, 2048, 4, 4, 2, 3, 3, 1, 1);
    }
}

TEST_F(CUDA, CHANWISE_CONVOLUTION_BENCH_ALL_ALGO_FWD) {
    // enable profiling
    std::unique_ptr<OprProxy<ConvolutionForward>> proxy{
            new OprProxy<ConvolutionForward>{true}};
    proxy->warmup_times = 1;
    proxy->exec_times = 10;
    Benchmarker<ConvolutionForward> checker(handle_cuda());
    checker.set_times(1);
    ConvolutionForward::Param param;
    param.sparse = ConvolutionForward::Param::Sparse::GROUP;
    checker.set_param(param);
    checker.set_proxy(proxy);

    auto run = [&](size_t N, size_t C, size_t IH, size_t IW, size_t FH, size_t FW) {
        checker.proxy()->target_execution_policy = {};
        checker.execs({{N, C, IH, IW}, {C, 1, 1, FH, FW}, {}});
    };

    run(128, 64, 90, 80, 3, 3);
    run(128, 90, 100, 100, 3, 5);
    run(128, 32, 62, 62, 5, 5);
}

TEST_F(CUDA, CHANWISE_CONVOLUTION_BENCH_ALL_ALGO_BWD_DATA) {
    // enable profiling
    std::unique_ptr<OprProxy<ConvolutionBackwardData>> proxy{
            new OprProxy<ConvolutionBackwardData>{true}};
    proxy->warmup_times = 1;
    proxy->exec_times = 10;
    Benchmarker<ConvolutionBackwardData> checker(handle_cuda());
    checker.set_times(1);
    ConvolutionBackwardData::Param param;
    param.sparse = ConvolutionForward::Param::Sparse::GROUP;
    checker.set_param(param);
    checker.set_proxy(proxy);

    auto run = [&](size_t N, size_t C, size_t IH, size_t IW, size_t FH, size_t FW) {
        checker.proxy()->target_execution_policy.algo.reset();
        checker.execs(
                {{C, 1, 1, FH, FW}, {N, C, IH - FH + 1, IW - FW + 1}, {N, C, IH, IW}});
    };

    run(128, 64, 90, 80, 3, 3);
    run(128, 90, 100, 100, 3, 5);
    run(128, 32, 62, 62, 5, 5);
}

TEST_F(CUDA, CHANWISE_CONVOLUTION_BENCH_ALL_ALGO_BWD_FILTER) {
    // enable profiling
    std::unique_ptr<OprProxy<ConvolutionBackwardFilter>> proxy{
            new OprProxy<ConvolutionBackwardFilter>{true}};
    proxy->warmup_times = 1;
    proxy->exec_times = 10;
    Benchmarker<ConvolutionBackwardFilter> checker(handle_cuda());
    checker.set_times(1);
    ConvolutionBackwardFilter::Param param;
    param.sparse = ConvolutionForward::Param::Sparse::GROUP;
    checker.set_param(param);
    checker.set_proxy(proxy);

    auto run = [&](size_t N, size_t C, size_t IH, size_t IW, size_t FH, size_t FW) {
        checker.proxy()->target_execution_policy.algo.reset();
        checker.execs(
                {{N, C, IH, IW}, {N, C, IH - FH + 1, IW - FW + 1}, {C, 1, 1, FH, FW}});
    };

    run(128, 64, 90, 80, 3, 3);
    run(128, 90, 100, 100, 3, 5);
    run(128, 32, 62, 62, 5, 5);
}

TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_ALL_ALGO_FORWARD) {
    CUBenchmarker<ConvolutionForward> bencher(handle_cuda());
    size_t RUNS = 10;
    bencher.set_display(false).set_times(RUNS);
    std::unique_ptr<OprProxy<ConvolutionForward>> proxy{
            new OprProxy<ConvolutionForward>{true}};
    bencher.set_proxy(proxy);

    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;

    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
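        // rough memory-traffic estimate: total (src + filter + dst) elements,
        // scaled so that multiplying by the element size (4 bytes for fp32,
        // 2 for fp16, applied at the printf below) and dividing by the time
        // in ms yields GB/s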
        float bandwidth = static_cast<float>(
                                  src.total_nr_elems() + filter.total_nr_elems() +
                                  dst_layout.total_nr_elems()) /
                          (1024 * 1024 * 1024) * 1e3;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.proxy()->target_execution_policy = {};
        auto time_in_ms_fp32 = bencher.execs({src, filter, {}}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.proxy()->target_execution_policy = {};
        auto time_in_ms_fp16 = bencher.execs({src, filter, {}}) / RUNS;

        bencher.proxy()->target_execution_policy.algo.reset();
        param.compute_mode = param::Convolution::ComputeMode::FLOAT32;
        bencher.set_param(param);
        auto time_in_ms_pseudo_fp16 = bencher.execs({src, filter, {}}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "pseudo float16: %.2fms %.2fGB/s "
               "speedup: "
               "%0.2f (fp16/fp32) %.2f (fp16/pseudo fp16)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(), time_in_ms_fp32,
               bandwidth * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwidth * 2 / time_in_ms_fp16, time_in_ms_pseudo_fp16,
               bandwidth * 2 / time_in_ms_pseudo_fp16, time_in_ms_fp32 / time_in_ms_fp16,
               time_in_ms_pseudo_fp16 / time_in_ms_fp16);
    };

    // clang-format off
    for (size_t s : {1, 2})
    for (size_t f : {3, 5, 7})
    for (size_t batch : {64})
    for (size_t c : {16, 32, 64, 128})
    for (size_t ih: {128, 256})
    for (size_t iw : {128, 256})
        run(batch, c, ih, iw, f, s);
    // clang-format on

    run(128, 192, 28, 28, 3, 1);
    run(128, 192, 28, 28, 3, 2);
    run(128, 576, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 32, 112, 112, 3, 1);
    run(128, 960, 7, 7, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 144, 56, 56, 3, 2);
    run(128, 384, 14, 14, 3, 1);
    run(128, 144, 56, 56, 3, 1);
    run(128, 96, 112, 112, 3, 2);
    run(128, 384, 14, 14, 3, 1);
    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
    run(128, 576, 14, 14, 3, 2);
}

TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_FORWARD_FLOAT) {
    CUBenchmarker<ConvolutionForward> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    bencher.set_before_exec_callback(
            AlgoChecker<ConvolutionForward>(ExecutionPolicyAlgoName{
                    "DEFAULT",
                    {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                              "CHANNEL_WISE", {})
                              .c_str(),
                      {}}}}));

    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;

    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
        float bandwidth = static_cast<float>(
                                  src.total_nr_elems() + filter.total_nr_elems() +
                                  dst_layout.total_nr_elems()) /
                          (1024 * 1024 * 1024) * 1e3;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp32 = bencher.execs({src, filter, {}}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16 = bencher.execs({src, filter, {}}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "speedup: "
               "%0.2f (fp16/fp32)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(), time_in_ms_fp32,
               bandwidth * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwidth * 2 / time_in_ms_fp16, time_in_ms_fp32 / time_in_ms_fp16);
    };

    // clang-format off
    for (size_t s : {1})
    for (size_t f : {3, 5, 7})
    for (size_t batch : {64})
    for (size_t c : {16, 32, 64, 128})
    for (size_t ih: {8, 16, 32, 128, 256})
    for (size_t iw : {8, 16, 32, 128, 256})
        run(batch, c, ih, iw, f, s);
    // clang-format on
}

TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_FORWARD_FLOAT_SMALL) {
    CUBenchmarker<ConvolutionForward> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);

    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;

    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
        float bandwidth = static_cast<float>(
                                  src.total_nr_elems() + filter.total_nr_elems() +
                                  dst_layout.total_nr_elems()) /
                          (1024 * 1024 * 1024) * 1e3;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_before_exec_callback(AlgoChecker<
                                          ConvolutionForward>(ExecutionPolicyAlgoName{
                        "DEFAULT",
                        {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                                  "CHANNEL_WISE", {})
                                  .c_str(),
                          {}}}}));
        auto time_in_ms_fp32_normal = bencher.execs({src, filter, {}}) / RUNS;

        bencher.set_before_exec_callback(
                AlgoChecker<ConvolutionForward>(ExecutionPolicyAlgoName{
                        "DEFAULT",
                        {{ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
                                  "CHANNEL_WISE_SMALL", {})
                                  .c_str(),
                          {}}}}));
        auto time_in_ms_fp32_small = bencher.execs({src, filter, {}}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16_small = bencher.execs({src, filter, {}}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, fp32 normal: %.2fms %.2fGB/s "
               "small: %.2fms %.2fGB/s, fp16 small: %.2fms %.2fGB/s, "
               "speedup: "
               "%0.2f (fp32 small/normal) %0.2f (small fp16/fp32)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(),
               time_in_ms_fp32_normal, bandwidth * 4 / time_in_ms_fp32_normal,
               time_in_ms_fp32_small, bandwidth * 4 / time_in_ms_fp32_small,
               time_in_ms_fp16_small, bandwidth * 2 / time_in_ms_fp16_small,
               time_in_ms_fp32_normal / time_in_ms_fp32_small,
               time_in_ms_fp32_small / time_in_ms_fp16_small);
    };

    // clang-format off
    for (size_t s : {1})
    for (size_t f : {3, 5})
    for (size_t batch : {64})
    for (size_t c : {16, 32, 64, 128})
    for (size_t ih: {8, 16, 32})
    for (size_t iw : {8, 16, 32})
        run(batch, c, ih, iw, f, s);
    // clang-format on

    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 960, 7, 7, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
}

TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_FORWARD_CUDNN_DNN) {
    CUBenchmarker<ConvBiasForward> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);

    ConvBias::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = ConvBias::Param::Sparse::GROUP;
    NormalRNG rng;

    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::ConvBias::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f},
                    bias = {1, c, 1, 1};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<ConvBias>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()},
                {bias, dtype::Float32()}, {}, dst_layout);
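        // FLOP estimate for depthwise conv: every output element consumes one
        // f x f window, i.e. 2 * f * f ops (multiply + add); kept in Mops so
        // that Mops / ms prints as Gops/s below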
        float computation_mops =
                static_cast<float>(dst_layout.total_nr_elems() * f * f * 2) * 1e-6;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.set_before_exec_callback(
                AlgoChecker<ConvBiasForward>(".+CHANNEL_WISE.+"));
        auto time_in_ms_dnn = bencher.execs({src, filter, bias, {}, {}}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.set_before_exec_callback(AlgoChecker<ConvBiasForward>(
                ".+CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM.+"));
        auto time_in_ms_cudnn = bencher.execs({src, filter, bias, {}, {}}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, dst=%s, dnn: %.2fms %.2fGops/s "
               "cudnn: %.2fms %.2fGops/s "
               "speedup: "
               "%0.2f (dnn/cudnn)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(),
               dst_layout.to_string().c_str(), time_in_ms_dnn,
               computation_mops / time_in_ms_dnn, time_in_ms_cudnn,
               computation_mops / time_in_ms_cudnn, time_in_ms_cudnn / time_in_ms_dnn);
    };

    // clang-format off
    for (size_t batch : {1, 16, 32, 64, 128}) {
        run(batch, 32, 112, 112, 3, 1);
        run(batch, 96, 112, 112, 3, 2);
        run(batch, 96, 112, 112, 3, 1);
        run(batch, 144, 56, 56, 3, 2);
        run(batch, 144, 56, 56, 3, 1);
        run(batch, 192, 28, 28, 3, 1);
        run(batch, 384, 14, 14, 3, 1);
        run(batch, 576, 14, 14, 3, 1);
        run(batch, 960, 7, 7, 3, 1);
        //! calibrate the heuristic algo policy hw_size param
        run(batch, 144, 24, 24, 3, 1);
        run(batch, 144, 22, 22, 3, 1);
        run(batch, 144, 20, 20, 3, 1);
        run(batch, 144, 18, 18, 3, 1);
    }
    // clang-format on
}

TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BACKWARD_DATA_FLOAT_SMALL) {
    CUBenchmarker<ConvolutionBackwardData> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);

    ConvolutionBackwardData::Param param;
    param.format = Convolution::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;

    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        float bandwidth = static_cast<float>(
                                  src.total_nr_elems() + filter.total_nr_elems() +
                                  src.total_nr_elems()) /
                          (1024 * 1024 * 1024) * 1e3;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng)
                .set_before_exec_callback(
                        AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE"));
        auto time_in_ms_fp32_normal = bencher.execs({filter, src, src}) / RUNS;

        bencher.set_before_exec_callback(
                AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE_SMALL"));
        auto time_in_ms_fp32_small = bencher.execs({filter, src, src}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16_small = bencher.execs({filter, src, src}) / RUNS;

        printf("stride=%zu src=%s, filter=%s, fp32 normal: %.2fms %.2fGB/s "
               "small: %.2fms %.2fGB/s, fp16 small: %.2fms %.2fGB/s, "
               "speedup: "
               "%0.2f (fp32 small/normal) %0.2f (small fp16/fp32)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(),
               time_in_ms_fp32_normal, bandwidth * 4 / time_in_ms_fp32_normal,
               time_in_ms_fp32_small, bandwidth * 4 / time_in_ms_fp32_small,
               time_in_ms_fp16_small, bandwidth * 2 / time_in_ms_fp16_small,
               time_in_ms_fp32_normal / time_in_ms_fp32_small,
               time_in_ms_fp32_small / time_in_ms_fp16_small);
    };

    // clang-format off
    for (size_t s : {1})
    for (size_t f : {3, 5})
    for (size_t batch : {64})
    for (size_t c : {16, 32, 64, 128})
    for (size_t ih: {8, 16, 32})
    for (size_t iw : {8, 16, 32})
        run(batch, c, ih, iw, f, s);
    // clang-format on

    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 960, 7, 7, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 384, 14, 14, 3, 1);
    run(128, 192, 28, 28, 3, 1);
    run(128, 576, 14, 14, 3, 1);
}

TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BWD_DATA) {
    CUBenchmarker<ConvolutionBackwardData> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    bencher.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardData>("CHANNEL_WISE"));

    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;

    auto run = [&](size_t batch, size_t ocpg, size_t group, size_t ih, size_t iw,
                   size_t f, size_t p, size_t s) {
        param.pad_h = p;
        param.pad_w = p;
        param.stride_h = s;
        param.stride_w = s;
        size_t oh, ow;
        infer_conv_shape2d(ih, iw, f, f, s, s, p, p, oh, ow, true);
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src_grad = {batch, group, ih, iw},
                    dst_grad = {batch, group * ocpg, oh, ow},
                    flt = {group, ocpg, 1, f, f};
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        float bandwidth = static_cast<float>(
                                  flt.total_nr_elems() + dst_grad.total_nr_elems() +
                                  src_grad.total_nr_elems()) /
                          (1024 * 1024 * 1024) * 1e3;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp32 = bencher.execs({flt, dst_grad, src_grad}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16 = bencher.execs({flt, dst_grad, src_grad}) / RUNS;

        printf("stride=%zu, src_grad=%s, flt=%s, "
               "float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "speedup: "
               "%0.2f (fp16/fp32)\n",
               s, src_grad.to_string().c_str(), flt.to_string().c_str(),
               time_in_ms_fp32, bandwidth * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwidth * 2 / time_in_ms_fp16, time_in_ms_fp32 / time_in_ms_fp16);
    };

    // clang-format off
    for (size_t s : {1, 2})
    for (size_t f : {3, 5, 7})
    for (size_t p : {f / 2})
    for (size_t batch : {64})
    for (size_t ocpg : {1})
    for (size_t group : {16, 32, 64, 128})
    for (size_t ih : {8, 16, 32, 128, 256})
    for (size_t iw : {8, 16, 32, 128, 256})
        run(batch, ocpg, group, ih, iw, f, p, s);
    // clang-format on
}

TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BWD_FILTER) {
    CUBenchmarker<ConvolutionBackwardFilter> bencher(handle_cuda());
    size_t RUNS = 1;
    bencher.set_display(false).set_times(RUNS);
    bencher.set_before_exec_callback(
            AlgoChecker<ConvolutionBackwardFilter>("CHANNEL_WISE"));

    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;

    auto run = [&](size_t batch, size_t ocpg, size_t group, size_t i, size_t f,
                   size_t p, size_t s) {
        param.pad_h = p;
        param.pad_w = p;
        param.stride_h = s;
        param.stride_w = s;
        size_t d = infer_conv_shape(i, f, s, p, true);
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, group, i, i}, dst_grad = {batch, group * ocpg, d, d},
                    flt_grad = {group, ocpg, 1, f, f};
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        float bandwidth = static_cast<float>(
                                  flt_grad.total_nr_elems() + dst_grad.total_nr_elems() +
                                  src.total_nr_elems()) /
                          (1024 * 1024 * 1024) * 1e3;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp32 = bencher.execs({src, dst_grad, flt_grad}) / RUNS;

        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        auto time_in_ms_fp16 = bencher.execs({src, dst_grad, flt_grad}) / RUNS;

        printf("stride=%zu, src=%s, flt_grad=%s, "
               "float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "speedup: "
               "%.2f (fp16/fp32)\n",
               s, src.to_string().c_str(), flt_grad.to_string().c_str(),
               time_in_ms_fp32, bandwidth * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwidth * 2 / time_in_ms_fp16, time_in_ms_fp32 / time_in_ms_fp16);
    };

    // clang-format off
    for (size_t s : {1, 2})
    for (size_t f : {3, 5, 7})
    for (size_t p : {f / 2})
    for (size_t batch : {64})
    for (size_t ocpg : {1})
    for (size_t group : {16, 32, 64, 128})
    for (size_t i : {8, 16, 32, 64, 128})
        run(batch, ocpg, group, i, f, p, s);
    // clang-format on
}
  1173. TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_FORWARD_LARGE_KERNEL) {
  1174. CUBenchmarker<ConvolutionForward> bencher(handle_cuda());
  1175. size_t RUNS = 100;
  1176. bencher.set_display(false).set_times(RUNS);
  1177. std::unique_ptr<OprProxy<ConvolutionForward>> proxy{
  1178. new OprProxy<ConvolutionForward>{true}};
  1179. bencher.set_proxy(proxy);
  1180. Convolution::Param param;
  1181. param.format = ConvBias::Param::Format::NCHW;
  1182. param.sparse = Convolution::Param::Sparse::GROUP;
  1183. NormalRNG rng;
  1184. auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
  1185. param.pad_h = f / 2;
  1186. param.pad_w = f / 2;
  1187. param.stride_h = s;
  1188. param.stride_w = s;
  1189. param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
  1190. TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
  1191. TensorLayout dst_layout;
  1192. auto opr = handle_cuda()->create_operator<Convolution>();
  1193. opr->param() = param;
  1194. opr->deduce_layout(
  1195. {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
        float bandwidth = static_cast<float>(
                                  src.total_nr_elems() + filter.total_nr_elems() +
                                  dst_layout.total_nr_elems()) /
                          (1024 * 1024 * 1024) * 1e3;
        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.proxy()->target_execution_policy = {};
        auto time_in_ms_fp32 = bencher.execs({src, filter, {}}) / RUNS;
        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.proxy()->target_execution_policy = {};
        auto time_in_ms_fp16 = bencher.execs({src, filter, {}}) / RUNS;
        bencher.proxy()->target_execution_policy.algo.reset();
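        // Clearing the cached algorithm should force the profiling proxy to
        // re-select an algorithm for the pseudo-fp16 configuration that
        // follows, rather than reusing the choice made for true fp16.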
        param.compute_mode = param::Convolution::ComputeMode::FLOAT32;
        bencher.set_param(param);
        auto time_in_ms_pseudo_fp16 = bencher.execs({src, filter, {}}) / RUNS;
        printf("stride=%zu src=%s, filter=%s, float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "pseudo float16: %.2fms %.2fGB/s "
               "speedup: %.2f (fp16/fp32) %.2f (fp16/pseudo fp16)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(), time_in_ms_fp32,
               bandwidth * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwidth * 2 / time_in_ms_fp16, time_in_ms_pseudo_fp16,
               bandwidth * 2 / time_in_ms_pseudo_fp16, time_in_ms_fp32 / time_in_ms_fp16,
               time_in_ms_pseudo_fp16 / time_in_ms_fp16);
    };
    // clang-format off
    for (size_t b : {32, 64})
    for (size_t f : {3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}) {
        run(b, 384, 32, 32, f, 1);
        run(b, 384, 64, 64, f, 1);
    }
    // clang-format on
}
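// Same sweep for ConvolutionBackwardData. Since stride is 1 and padding is
// f / 2 with odd f, the deduced output matches the input spatially, which is
// why `src` can serve as both the gradient and the output shape in execs().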
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BACKWARD_DATA_LARGE_KERNEL) {
    CUBenchmarker<ConvolutionBackwardData> bencher(handle_cuda());
    size_t RUNS = 100;
    bencher.set_display(false).set_times(RUNS);
    std::unique_ptr<OprProxy<ConvolutionBackwardData>> proxy{
            new OprProxy<ConvolutionBackwardData>{true}};
    bencher.set_proxy(proxy);
    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
        float bandwidth = static_cast<float>(
                                  src.total_nr_elems() + filter.total_nr_elems() +
                                  dst_layout.total_nr_elems()) /
                          (1024 * 1024 * 1024) * 1e3;
        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.proxy()->target_execution_policy = {};
        auto time_in_ms_fp32 = bencher.execs({filter, src, src}) / RUNS;
        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.proxy()->target_execution_policy = {};
        auto time_in_ms_fp16 = bencher.execs({filter, src, src}) / RUNS;
        bencher.proxy()->target_execution_policy.algo.reset();
        param.compute_mode = param::Convolution::ComputeMode::FLOAT32;
        bencher.set_param(param);
        auto time_in_ms_pseudo_fp16 = bencher.execs({filter, src, src}) / RUNS;
        printf("stride=%zu src=%s, filter=%s, float32: %.2fms %.2fGB/s "
               "float16: %.2fms %.2fGB/s "
               "pseudo float16: %.2fms %.2fGB/s "
               "speedup: %.2f (fp16/fp32) %.2f (fp16/pseudo fp16)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(), time_in_ms_fp32,
               bandwidth * 4 / time_in_ms_fp32, time_in_ms_fp16,
               bandwidth * 2 / time_in_ms_fp16, time_in_ms_pseudo_fp16,
               bandwidth * 2 / time_in_ms_pseudo_fp16, time_in_ms_fp32 / time_in_ms_fp16,
               time_in_ms_pseudo_fp16 / time_in_ms_fp16);
    };
    // clang-format off
    for (size_t b : {32, 64})
    for (size_t f : {3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}) {
        run(b, 384, 32, 32, f, 1);
        run(b, 384, 64, 64, f, 1);
    }
    // clang-format on
}
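// Backward-filter variant of the sweep above. Only fp32 and pseudo fp16 are
// timed here: the fp16 dtypes are set, but the compute mode is switched to
// FLOAT32 before the second measurement, so there is no true-fp16 run.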
TEST_F(CUDA, BENCHMARK_CHANWISE_CONV_BACKWARD_FILTER_LARGE_KERNEL) {
    CUBenchmarker<ConvolutionBackwardFilter> bencher(handle_cuda());
    size_t RUNS = 100;
    bencher.set_display(false).set_times(RUNS);
    std::unique_ptr<OprProxy<ConvolutionBackwardFilter>> proxy{
            new OprProxy<ConvolutionBackwardFilter>{true}};
    bencher.set_proxy(proxy);
    Convolution::Param param;
    param.format = ConvBias::Param::Format::NCHW;
    param.sparse = Convolution::Param::Sparse::GROUP;
    NormalRNG rng;
    auto run = [&](size_t batch, size_t c, size_t ih, size_t iw, size_t f, size_t s) {
        param.pad_h = f / 2;
        param.pad_w = f / 2;
        param.stride_h = s;
        param.stride_w = s;
        param.compute_mode = param::Convolution::ComputeMode::DEFAULT;
        TensorShape src = {batch, c, ih, iw}, filter = {c, 1, 1, f, f};
        TensorLayout dst_layout;
        auto opr = handle_cuda()->create_operator<Convolution>();
        opr->param() = param;
        opr->deduce_layout(
                {src, dtype::Float32()}, {filter, dtype::Float32()}, dst_layout);
        float bandwidth = static_cast<float>(
                                  src.total_nr_elems() + filter.total_nr_elems() +
                                  dst_layout.total_nr_elems()) /
                          (1024 * 1024 * 1024) * 1e3;
        bencher.set_param(param)
                .set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.proxy()->target_execution_policy = {};
        auto time_in_ms_fp32 = bencher.execs({src, src, filter}) / RUNS;
        bencher.set_param(param)
                .set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16())
                .set_rng(0, &rng)
                .set_rng(1, &rng);
        bencher.proxy()->target_execution_policy = {};
        param.compute_mode = param::Convolution::ComputeMode::FLOAT32;
        bencher.set_param(param);
        auto time_in_ms_pseudo_fp16 = bencher.execs({src, src, filter}) / RUNS;
        printf("stride=%zu src=%s, filter=%s, float32: %.2fms %.2fGB/s "
               "pseudo float16: %.2fms %.2fGB/s "
               "speedup: %.2f (pseudo fp16/fp32)\n",
               s, src.to_string().c_str(), filter.to_string().c_str(), time_in_ms_fp32,
               bandwidth * 4 / time_in_ms_fp32, time_in_ms_pseudo_fp16,
               bandwidth * 2 / time_in_ms_pseudo_fp16,
               time_in_ms_fp32 / time_in_ms_pseudo_fp16);
    };
    // clang-format off
    for (size_t b : {32, 64})
    for (size_t f : {3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}) {
        run(b, 384, 32, 32, f, 1);
        run(b, 384, 64, 64, f, 1);
    }
    // clang-format on
}
#endif
// vim: syntax=cpp.doxygen