
conv_bias_int8.cpp 52 kB

/**
 * \file dnn/test/cuda/conv_bias_int8.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"
#include "src/cuda/cudnn_with_check.h"
#include "test/common/checker.h"
#include "test/common/conv_bias.h"
#include "test/cuda/benchmark.h"
#include "test/cuda/fixture.h"
#include "test/cuda/utils.h"

namespace megdnn {
namespace test {
#if MEGDNN_WITH_BENCHMARK
namespace {
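// One benchmark case: batch n, input channels ci, input spatial size hi x wi,
// output channels co, (square) filter size f, stride s. The drivers below set
// padding to f / 2, so stride-1 cases preserve the spatial size.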
struct BenchArgs {
    size_t n, ci, hi, wi, co, f, s;
};

std::vector<BenchArgs> get_resnet50_bench_args(size_t batch = 64) {
    std::vector<BenchArgs> args;
    args.emplace_back(BenchArgs{batch, 64, 56, 56, 256, 1, 1});
    args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 1, 1});
    args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 1, 1});
    args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 1});
    args.emplace_back(BenchArgs{batch, 64, 56, 56, 256, 1, 1});
    args.emplace_back(BenchArgs{batch, 256, 56, 56, 512, 1, 2});
    args.emplace_back(BenchArgs{batch, 256, 56, 56, 128, 1, 2});
    args.emplace_back(BenchArgs{batch, 512, 28, 28, 128, 1, 1});
    args.emplace_back(BenchArgs{batch, 128, 28, 28, 128, 3, 1});
    args.emplace_back(BenchArgs{batch, 128, 28, 28, 512, 1, 1});
    args.emplace_back(BenchArgs{batch, 512, 28, 28, 1024, 1, 2});
    args.emplace_back(BenchArgs{batch, 512, 28, 28, 256, 1, 2});
    args.emplace_back(BenchArgs{batch, 1024, 14, 14, 256, 1, 1});
    args.emplace_back(BenchArgs{batch, 256, 14, 14, 256, 3, 1});
    args.emplace_back(BenchArgs{batch, 256, 14, 14, 1024, 1, 1});
    args.emplace_back(BenchArgs{batch, 1024, 14, 14, 2048, 1, 2});
    args.emplace_back(BenchArgs{batch, 1024, 14, 14, 512, 1, 2});
    args.emplace_back(BenchArgs{batch, 2048, 7, 7, 512, 1, 1});
    args.emplace_back(BenchArgs{batch, 512, 7, 7, 512, 3, 1});
    args.emplace_back(BenchArgs{batch, 512, 7, 7, 2048, 1, 1});
    return args;
}
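
// Convolution shapes sampled from a detection network (default batch size 16).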
std::vector<BenchArgs> get_detection_bench_args(size_t batch = 16) {
    std::vector<BenchArgs> args;
    args.emplace_back(BenchArgs{batch, 4, 736, 1280, 8, 3, 2});
    args.emplace_back(BenchArgs{batch, 32, 184, 320, 16, 3, 1});
    args.emplace_back(BenchArgs{batch, 16, 184, 320, 32, 3, 1});
    args.emplace_back(BenchArgs{batch, 8, 184, 320, 16, 3, 1});
    args.emplace_back(BenchArgs{batch, 8, 184, 320, 32, 3, 1});
    args.emplace_back(BenchArgs{batch, 64, 92, 160, 32, 3, 1});
    args.emplace_back(BenchArgs{batch, 32, 184, 320, 64, 3, 2});
    args.emplace_back(BenchArgs{batch, 32, 184, 320, 32, 3, 2});
    args.emplace_back(BenchArgs{batch, 32, 92, 160, 64, 3, 1});
    args.emplace_back(BenchArgs{batch, 64, 92, 160, 8, 3, 1});
    args.emplace_back(BenchArgs{batch, 64, 92, 160, 128, 3, 2});
    args.emplace_back(BenchArgs{batch, 128, 46, 80, 32, 3, 1});
    args.emplace_back(BenchArgs{batch, 128, 46, 80, 256, 3, 2});
    args.emplace_back(BenchArgs{batch, 128, 46, 80, 8, 3, 1});
    args.emplace_back(BenchArgs{batch, 64, 92, 160, 32, 3, 2});
    args.emplace_back(BenchArgs{batch, 32, 46, 80, 128, 3, 1});
    args.emplace_back(BenchArgs{batch, 8, 46, 80, 32, 3, 1});
    args.emplace_back(BenchArgs{batch, 64, 23, 40, 256, 3, 1});
    args.emplace_back(BenchArgs{batch, 256, 23, 40, 64, 3, 1});
    args.emplace_back(BenchArgs{batch, 128, 46, 80, 64, 3, 2});
    args.emplace_back(BenchArgs{batch, 256, 23, 40, 8, 3, 1});
    args.emplace_back(BenchArgs{batch, 8, 23, 40, 32, 3, 2});
    args.emplace_back(BenchArgs{batch, 8, 12, 20, 8, 3, 1});
    args.emplace_back(BenchArgs{batch, 8, 12, 20, 8, 3, 2});
    args.emplace_back(BenchArgs{batch, 8, 6, 10, 8, 3, 1});
    return args;
}
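
// Benchmarks every case twice: once with the algorithm named by `algo` (or
// the heuristic choice when `algo` is nullptr) and once with cuDNN's
// IMPLICIT_PRECOMP_GEMM algorithm as the baseline, printing runtime, achieved
// TOPS and the speedup over cuDNN. Tensors 3 and 4 (the optional z input and
// the output) share dst_dtype.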
void benchmark_target_algo(
        Handle* handle, const std::vector<BenchArgs>& args, DType src_dtype,
        DType filter_dtype, DType bias_dtype, DType dst_dtype,
        const char* algo = nullptr,
        param::ConvBias::Format format = param::ConvBias::Format::NCHW4) {
    megdnn_assert(src_dtype.enumv() == filter_dtype.enumv());
    CUBenchmarker<ConvBiasForward> benchmarker(handle);
    CUBenchmarker<ConvBiasForward> benchmarker_cudnn(handle);
    size_t RUNS = 1000;
    benchmarker.set_display(false).set_times(RUNS);
    benchmarker_cudnn.set_display(false).set_times(RUNS);
    if (algo) {
        benchmarker.set_before_exec_callback(
                conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo));
    }
#define V1(x) #x
#define V(x) V1(x)
#define CUDNN_VERSION_STRING \
    "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL)
    benchmarker_cudnn.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
                    "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_"
                    "GEMM" CUDNN_VERSION_STRING));
    benchmarker.set_dtype(0, src_dtype)
            .set_dtype(1, filter_dtype)
            .set_dtype(2, bias_dtype)
            .set_dtype(3, dst_dtype)
            .set_dtype(4, dst_dtype);
    benchmarker_cudnn.set_dtype(0, src_dtype)
            .set_dtype(1, filter_dtype)
            .set_dtype(2, bias_dtype)
            .set_dtype(3, dst_dtype)
            .set_dtype(4, dst_dtype);
    using Param = ConvBias::Param;
    using Format = Param::Format;
    if (format == Format::NCHW4) {
        for (auto&& arg : args) {
            Param param;
            param.pad_h = param.pad_w = arg.f / 2;
            param.stride_h = param.stride_w = arg.s;
            param.format = Format::NCHW4;
            size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2);
            size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2);
            benchmarker.set_param(param);
            auto time_in_ms =
                    benchmarker.execs({{arg.n, arg.ci / 4, arg.hi, arg.wi, 4},
                                       {arg.co, arg.ci / 4, arg.f, arg.f, 4},
                                       {1, arg.co / 4, 1, 1, 4},
                                       {},
                                       {}}) /
                    RUNS;
            benchmarker_cudnn.set_param(param);
            auto time_in_ms_cudnn =
                    benchmarker_cudnn.execs(
                            {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4},
                             {arg.co, arg.ci / 4, arg.f, arg.f, 4},
                             {1, arg.co / 4, 1, 1, 4},
                             {},
                             {}}) /
                    RUNS;
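            // Direct-convolution cost: 2 * N * Co * Ho * Wo * Ci * F * F
            // multiply-accumulate operations, pre-scaled by 1e12 so that
            // flo / seconds reads directly in TOPS.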
            float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f *
                        arg.f / (1e12);
            TensorShape src{arg.n, arg.ci, arg.hi, arg.wi},
                    filter{arg.co, arg.ci, arg.f, arg.f};
            printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, "
                   "time(cudnn)=%.2f %.2fTops, "
                   "perf(algo=%s)/perf(cudnn)=%.2f\n",
                   src.to_string().c_str(), filter.to_string().c_str(), algo,
                   time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn,
                   (flo / (time_in_ms_cudnn * 1e-3)), algo,
                   time_in_ms_cudnn / time_in_ms);
        }
    } else if (format == Format::CHWN4) {
        for (auto&& arg : args) {
            Param param;
            param.pad_h = param.pad_w = arg.f / 2;
            param.stride_h = param.stride_w = arg.s;
            param.format = Format::CHWN4;
            size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2);
            size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2);
            benchmarker.set_param(param);
            auto time_in_ms =
                    benchmarker.execs({{arg.ci / 4, arg.hi, arg.wi, arg.n, 4},
                                       {arg.ci / 4, arg.f, arg.f, arg.co, 4},
                                       {arg.co / 4, 1, 1, 1, 4},
                                       {},
                                       {}}) /
                    RUNS;
            param.format = Format::NCHW4;
            benchmarker_cudnn.set_param(param);
            auto time_in_ms_cudnn =
                    benchmarker_cudnn.execs(
                            {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4},
                             {arg.co, arg.ci / 4, arg.f, arg.f, 4},
                             {1, arg.co / 4, 1, 1, 4},
                             {},
                             {}}) /
                    RUNS;
            float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f *
                        arg.f / (1e12);
            TensorShape src{arg.n, arg.ci, arg.hi, arg.wi},
                    filter{arg.co, arg.ci, arg.f, arg.f};
            printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, "
                   "time(cudnn)=%.2f %.2fTops, "
                   "perf(algo=%s)/perf(cudnn)=%.2f\n",
                   src.to_string().c_str(), filter.to_string().c_str(), algo,
                   time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn,
                   (flo / (time_in_ms_cudnn * 1e-3)), algo,
                   time_in_ms_cudnn / time_in_ms);
        }
        printf("bench with z tensor\n");
        for (auto&& arg : args) {
            Param param;
            param.pad_h = param.pad_w = arg.f / 2;
            param.stride_h = param.stride_w = arg.s;
            param.format = Format::CHWN4;
            size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2);
            size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2);
            benchmarker.set_param(param);
            auto time_in_ms =
                    benchmarker.execs({{arg.ci / 4, arg.hi, arg.wi, arg.n, 4},
                                       {arg.ci / 4, arg.f, arg.f, arg.co, 4},
                                       {arg.co / 4, 1, 1, 1, 4},
                                       {arg.co / 4, ho, wo, arg.n, 4},
                                       {}}) /
                    RUNS;
            param.format = Format::NCHW4;
            benchmarker_cudnn.set_param(param);
            auto time_in_ms_cudnn =
                    benchmarker_cudnn.execs(
                            {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4},
                             {arg.co, arg.ci / 4, arg.f, arg.f, 4},
                             {1, arg.co / 4, 1, 1, 4},
                             {arg.n, arg.co / 4, ho, wo, 4},
                             {}}) /
                    RUNS;
            float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f *
                        arg.f / (1e12);
            TensorShape src{arg.n, arg.ci, arg.hi, arg.wi},
                    filter{arg.co, arg.ci, arg.f, arg.f};
            printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, "
                   "time(cudnn)=%.2f %.2fTops, "
                   "perf(algo=%s)/perf(cudnn)=%.2f\n",
                   src.to_string().c_str(), filter.to_string().c_str(), algo,
                   time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn,
                   (flo / (time_in_ms_cudnn * 1e-3)), algo,
                   time_in_ms_cudnn / time_in_ms);
        }
    }
}
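
// Same driver, but the cuDNN baseline targets tensor cores: it runs in
// NCHW32 layout (int8 IMMA) where the channel counts are divisible by 32 and
// falls back to NCHW4 otherwise. With `algo == nullptr` an OprProxy with full
// algorithm search picks the fastest available kernel for each case.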
void benchmark_target_algo_with_cudnn_tsc(
        Handle* handle, const std::vector<BenchArgs>& args, DType src_dtype,
        DType filter_dtype, DType bias_dtype, DType dst_dtype,
        const char* algo = nullptr,
        param::ConvBias::Format format = param::ConvBias::Format::NCHW4) {
    megdnn_assert(src_dtype.enumv() == filter_dtype.enumv());
    CUBenchmarker<ConvBiasForward> benchmarker(handle);
    CUBenchmarker<ConvBiasForward> benchmarker_cudnn(handle);
    size_t RUNS = 1000;
    benchmarker.set_display(false).set_times(RUNS);
    benchmarker_cudnn.set_display(false).set_times(RUNS);
    std::unique_ptr<OprProxy<ConvBiasForward>> proxy{
            new OprProxy<ConvBiasForward>{true}};
    if (algo) {
        benchmarker.set_before_exec_callback(
                conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo));
    } else {
        benchmarker.set_proxy(proxy);
    }
    benchmarker_cudnn.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
                    "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_"
                    "GEMM" CUDNN_VERSION_STRING));
#undef V1
#undef V
#undef CUDNN_VERSION_STRING
    benchmarker.set_dtype(0, src_dtype)
            .set_dtype(1, filter_dtype)
            .set_dtype(2, bias_dtype)
            .set_dtype(3, dst_dtype)
            .set_dtype(4, dst_dtype);
    benchmarker_cudnn.set_dtype(0, src_dtype)
            .set_dtype(1, filter_dtype)
            .set_dtype(2, bias_dtype)
            .set_dtype(3, dst_dtype)
            .set_dtype(4, dst_dtype);
    using Param = ConvBias::Param;
    using Format = Param::Format;
    if (format == Format::NCHW4) {
        for (auto&& arg : args) {
            Param param;
            param.pad_h = param.pad_w = arg.f / 2;
            param.stride_h = param.stride_w = arg.s;
            param.format = Format::NCHW4;
            size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2);
            size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2);
            benchmarker.set_param(param);
            if (!algo) {
                benchmarker.proxy()->target_algo = nullptr;
            }
            auto time_in_ms =
                    benchmarker.execs({{arg.n, arg.ci / 4, arg.hi, arg.wi, 4},
                                       {arg.co, arg.ci / 4, arg.f, arg.f, 4},
                                       {1, arg.co / 4, 1, 1, 4},
                                       {},
                                       {}}) /
                    RUNS;
            param.format = Format::NCHW32;
            benchmarker_cudnn.set_param(param);
            auto time_in_ms_cudnn =
                    benchmarker_cudnn.execs(
                            {{arg.n, arg.ci / 32, arg.hi, arg.wi, 32},
                             {arg.co, arg.ci / 32, arg.f, arg.f, 32},
                             {1, arg.co / 32, 1, 1, 32},
                             {},
                             {}}) /
                    RUNS;
            float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f *
                        arg.f / (1e12);
            TensorShape src{arg.n, arg.ci, arg.hi, arg.wi},
                    filter{arg.co, arg.ci, arg.f, arg.f};
            printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, "
                   "time(cudnn)=%.2f %.2fTops, "
                   "perf(algo=%s)/perf(cudnn)=%.2f\n",
                   src.to_string().c_str(), filter.to_string().c_str(), algo,
                   time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn,
                   (flo / (time_in_ms_cudnn * 1e-3)), algo,
                   time_in_ms_cudnn / time_in_ms);
        }
    } else if (format == Format::CHWN4) {
        for (auto&& arg : args) {
            Param param;
            param.pad_h = param.pad_w = arg.f / 2;
            param.stride_h = param.stride_w = arg.s;
            param.format = Format::CHWN4;
            size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2);
            size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2);
            benchmarker.set_param(param);
            if (!algo) {
                benchmarker.proxy()->target_algo = nullptr;
            }
            auto time_in_ms =
                    benchmarker.execs({{arg.ci / 4, arg.hi, arg.wi, arg.n, 4},
                                       {arg.ci / 4, arg.f, arg.f, arg.co, 4},
                                       {arg.co / 4, 1, 1, 1, 4},
                                       {},
                                       {}}) /
                    RUNS;
            float time_in_ms_cudnn = 0.f;
            if (arg.ci % 32 == 0 && arg.co % 32 == 0) {
                param.format = Format::NCHW32;
                benchmarker_cudnn.set_param(param);
                time_in_ms_cudnn =
                        benchmarker_cudnn.execs(
                                {{arg.n, arg.ci / 32, arg.hi, arg.wi, 32},
                                 {arg.co, arg.ci / 32, arg.f, arg.f, 32},
                                 {1, arg.co / 32, 1, 1, 32},
                                 {},
                                 {}}) /
                        RUNS;
            } else {
                param.format = Format::NCHW4;
                benchmarker_cudnn.set_param(param);
                time_in_ms_cudnn =
                        benchmarker_cudnn.execs(
                                {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4},
                                 {arg.co, arg.ci / 4, arg.f, arg.f, 4},
                                 {1, arg.co / 4, 1, 1, 4},
                                 {},
                                 {}}) /
                        RUNS;
            }
            float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f *
                        arg.f / (1e12);
            TensorShape src{arg.n, arg.ci, arg.hi, arg.wi},
                    filter{arg.co, arg.ci, arg.f, arg.f};
            printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, "
                   "time(cudnn)=%.2f %.2fTops, "
                   "perf(algo=%s)/perf(cudnn)=%.2f\n",
                   src.to_string().c_str(), filter.to_string().c_str(), algo,
                   time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn,
                   (flo / (time_in_ms_cudnn * 1e-3)), algo,
                   time_in_ms_cudnn / time_in_ms);
        }
        printf("bench with z tensor\n");
        for (auto&& arg : args) {
            Param param;
            param.pad_h = param.pad_w = arg.f / 2;
            param.stride_h = param.stride_w = arg.s;
            param.format = Format::CHWN4;
            size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2);
            size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2);
            benchmarker.set_param(param);
            if (!algo) {
                benchmarker.proxy()->target_algo = nullptr;
            }
            auto time_in_ms =
                    benchmarker.execs({{arg.ci / 4, arg.hi, arg.wi, arg.n, 4},
                                       {arg.ci / 4, arg.f, arg.f, arg.co, 4},
                                       {arg.co / 4, 1, 1, 1, 4},
                                       {arg.co / 4, ho, wo, arg.n, 4},
                                       {}}) /
                    RUNS;
            float time_in_ms_cudnn = 0.f;
            if (arg.ci % 32 == 0 && arg.co % 32 == 0) {
                param.format = Format::NCHW32;
                benchmarker_cudnn.set_param(param);
                time_in_ms_cudnn =
                        benchmarker_cudnn.execs(
                                {{arg.n, arg.ci / 32, arg.hi, arg.wi, 32},
                                 {arg.co, arg.ci / 32, arg.f, arg.f, 32},
                                 {1, arg.co / 32, 1, 1, 32},
                                 {arg.n, arg.co / 32, ho, wo, 32},
                                 {}}) /
                        RUNS;
            } else {
                param.format = Format::NCHW4;
                benchmarker_cudnn.set_param(param);
                time_in_ms_cudnn =
                        benchmarker_cudnn.execs(
                                {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4},
                                 {arg.co, arg.ci / 4, arg.f, arg.f, 4},
                                 {1, arg.co / 4, 1, 1, 4},
                                 {arg.n, arg.co / 4, ho, wo, 4},
                                 {}}) /
                        RUNS;
            }
            float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f *
                        arg.f / (1e12);
            TensorShape src{arg.n, arg.ci, arg.hi, arg.wi},
                    filter{arg.co, arg.ci, arg.f, arg.f};
            printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, "
                   "time(cudnn)=%.2f %.2fTops, "
                   "perf(algo=%s)/perf(cudnn)=%.2f\n",
                   src.to_string().c_str(), filter.to_string().c_str(), algo,
                   time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn,
                   (flo / (time_in_ms_cudnn * 1e-3)), algo,
                   time_in_ms_cudnn / time_in_ms);
        }
    }
}
} // namespace
#endif
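
// The DOTPROD algorithms tested below build on the dp4a int8 dot-product
// instruction, first available on compute capability 6.1, hence the
// require_compute_capability(6, 1) guard.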
TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_1x1) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::NCHW4, conv_bias::get_int8_nchw4_args(1));
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_3x3) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::NCHW4);
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_5x5) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::NCHW4, conv_bias::get_int8_nchw4_args(5));
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_7x7) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::NCHW4, conv_bias::get_int8_nchw4_args(7));
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_WITH_Z) {
    require_compute_capability(6, 1);
    Checker<ConvBiasForward> checker(handle_cuda());
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
                    "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM"));
    UniformIntRNG rng{-3, 3};
    UniformIntRNG bias_rng{-50, 50};
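    // Quantized outputs may differ from the float reference by one
    // quantization step, so the checker allows an epsilon slightly above 1
    // and relaxed average-error bounds instead of demanding an exact match.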
    checker.set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &bias_rng)
            .set_rng(3, &rng)
            .set_dtype(0, dtype::QuantizedS8{1.2f})
            .set_dtype(1, dtype::QuantizedS8{1.3f})
            .set_dtype(2, dtype::QuantizedS32{1.2f * 1.3f})
            .set_dtype(3, dtype::QuantizedS8{1.1f})
            .set_dtype(4, dtype::QuantizedS8{1.0f})
            .set_epsilon(1 + 1e-3)
            .set_max_avg_error(1e-1)
            .set_max_avg_biased_error(1e-1);
    param::ConvBias param;
    param.pad_h = param.pad_w = 1;
    param.stride_h = param.stride_w = 1;
    param.format = param::ConvBias::Format::NCHW4;
    checker.set_param(param).execs({{32, 4, 12, 12, 4},
                                    {16, 4, 3, 3, 4},
                                    {1, 4, 1, 1, 4},
                                    {32, 4, 12, 12, 4},
                                    {}});
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_STRIDE2_WITH_Z) {
    require_compute_capability(6, 1);
    Checker<ConvBiasForward> checker(handle_cuda());
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
                    "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM"));
    UniformIntRNG rng{-3, 3};
    UniformIntRNG bias_rng{-50, 50};
    checker.set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &bias_rng)
            .set_rng(3, &rng)
            .set_dtype(0, dtype::QuantizedS8{1.2f})
            .set_dtype(1, dtype::QuantizedS8{1.3f})
            .set_dtype(2, dtype::QuantizedS32{1.2f * 1.3f})
            .set_dtype(3, dtype::QuantizedS8{1.1f})
            .set_dtype(4, dtype::QuantizedS8{1.0f})
            .set_epsilon(1 + 1e-3)
            .set_max_avg_error(1e-1)
            .set_max_avg_biased_error(1e-1);
    param::ConvBias param;
    param.pad_h = param.pad_w = 1;
    param.stride_h = param.stride_w = 2;
    param.format = param::ConvBias::Format::NCHW4;
    checker.set_param(param).execs({{32, 4, 12, 12, 4},
                                    {16, 4, 3, 3, 4},
                                    {1, 4, 1, 1, 4},
                                    {32, 4, 6, 6, 4},
                                    {}});
}
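
// The *_CHECK_BOUNDS variants use shapes chosen to stress out-of-bounds
// handling (padding and tail fragments) in the implicit-GEMM kernels; the
// helper's argument selects the square filter size.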
TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_CHECK_BOUNDS_1x1) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::NCHW4,
            conv_bias::get_int8_nchw4_args_check_bounds(1));
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_CHECK_BOUNDS_3x3) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::NCHW4,
            conv_bias::get_int8_nchw4_args_check_bounds(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_CHECK_BOUNDS_5x5) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::NCHW4,
            conv_bias::get_int8_nchw4_args_check_bounds(5));
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_CHECK_BOUNDS_7x7) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::NCHW4,
            conv_bias::get_int8_nchw4_args_check_bounds(7));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::CHWN4);
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_WITH_Z) {
    require_compute_capability(6, 1);
    Checker<ConvBiasForward> checker(handle_cuda());
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
                    "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM"));
    UniformIntRNG rng{-3, 3};
    UniformIntRNG bias_rng{-50, 50};
    checker.set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &bias_rng)
            .set_rng(3, &rng)
            .set_dtype(0, dtype::QuantizedS8{1.2f})
            .set_dtype(1, dtype::QuantizedS8{1.3f})
            .set_dtype(2, dtype::QuantizedS32{1.2f * 1.3f})
            .set_dtype(3, dtype::QuantizedS8{1.1f})
            .set_dtype(4, dtype::QuantizedS8{1.1f})
            .set_epsilon(1 + 1e-3)
            .set_max_avg_error(1e-1)
            .set_max_avg_biased_error(1e-1);
    param::ConvBias param;
    param.pad_h = param.pad_w = 1;
    param.stride_h = param.stride_w = 1;
    param.format = param::ConvBias::Format::CHWN4;
    checker.set_param(param).execs({{4, 12, 12, 32, 4},
                                    {4, 3, 3, 16, 4},
                                    {4, 1, 1, 1, 4},
                                    {4, 12, 12, 32, 4},
                                    {}});
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_HSWISH) {
    require_compute_capability(6, 1);
    Checker<ConvBiasForward> checker(handle_cuda());
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
                    "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM"));
    UniformIntRNG rng{-3, 3};
    UniformIntRNG bias_rng{-50, 50};
    checker.set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &bias_rng)
            .set_rng(3, &rng)
            .set_dtype(0, dtype::QuantizedS8{1.2f})
            .set_dtype(1, dtype::QuantizedS8{1.3f})
            .set_dtype(2, dtype::QuantizedS32{1.2f * 1.3f})
            .set_dtype(4, dtype::QuantizedS8{0.001f})
            .set_epsilon(1 + 1e-3)
            .set_max_avg_error(1e-1)
            .set_max_avg_biased_error(1e-1);
    param::ConvBias param;
    param.pad_h = param.pad_w = 1;
    param.stride_h = param.stride_w = 1;
    param.format = param::ConvBias::Format::CHWN4;
    param.nonlineMode = param::ConvBias::NonlineMode::H_SWISH;
    checker.set_param(param).execs({{4, 12, 12, 32, 4},
                                    {4, 3, 3, 16, 4},
                                    {4, 1, 1, 1, 4},
                                    {},
                                    {}});
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_CHECK_BOUNDS) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_args_check_bounds(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_SMALL_CHANNEL_1x1) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_small_channel_args(1));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_SMALL_CHANNEL_3x3) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_small_channel_args(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_SMALL_CHANNEL_5x5) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_small_channel_args(5));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_SMALL_CHANNEL_7x7) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_small_channel_args(7));
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_SMALL_CHANNEL_CHECK_BOUNDS) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::NCHW4,
            conv_bias::get_int8_nchw4_small_channel_args_check_bounds(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_SMALL_CHANNEL_1x1_CHECK_BOUNDS) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_small_channel_args_check_bounds(1));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_SMALL_CHANNEL_5x5_CHECK_BOUNDS) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_small_channel_args_check_bounds(5));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_SMALL_CHANNEL_7x7_CHECK_BOUNDS) {
    require_compute_capability(6, 1);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_small_channel_args_check_bounds(7));
}
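
// The IMMA algorithms below use int8 tensor cores and therefore require
// compute capability 7.5 (Turing); the mma16x16x16 / mma8x32x16 / mma32x8x16
// suffixes name the warp-level MMA tile shape (M x N x K) each variant is
// built around.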
TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_TENSORCORE_1x1) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_IMMA_IMPLICIT_GEMM_mma16x16x16",
            param::ConvBias::Format::NCHW4,
            conv_bias::get_int8_nchw4_tensorcore_args(1));
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_TENSORCORE_3x3) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_IMMA_IMPLICIT_GEMM_mma16x16x16",
            param::ConvBias::Format::NCHW4,
            conv_bias::get_int8_nchw4_tensorcore_args(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_TENSORCORE_5x5) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_IMMA_IMPLICIT_GEMM_mma16x16x16",
            param::ConvBias::Format::NCHW4,
            conv_bias::get_int8_nchw4_tensorcore_args(5));
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_TENSORCORE_7x7) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_IMMA_IMPLICIT_GEMM_mma16x16x16",
            param::ConvBias::Format::NCHW4,
            conv_bias::get_int8_nchw4_tensorcore_args(7));
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_TENSORCORE_CHECK_BOUNDS_ALGO_0) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_IMMA_IMPLICIT_GEMM_mma16x16x16",
            param::ConvBias::Format::NCHW4,
            conv_bias::get_int8_nchw4_args_check_bounds(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_TENSORCORE_CHECK_BOUNDS_ALGO_1) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_IMMA_IMPLICIT_GEMM_mma8x32x16",
            param::ConvBias::Format::NCHW4,
            conv_bias::get_int8_nchw4_args_check_bounds(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_TENSORCORE_CHECK_BOUNDS_ALGO_2) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_NCHW4_IMMA_IMPLICIT_GEMM_mma32x8x16",
            param::ConvBias::Format::NCHW4,
            conv_bias::get_int8_nchw4_args_check_bounds(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_TENSORCORE_ALGO_0) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.1f},
            handle_cuda(), "INT8_CHWN4_IMMA_IMPLICIT_GEMM_mma16x16x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_tensorcore_args(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_TENSORCORE_ALGO_1) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.1f},
            handle_cuda(), "INT8_CHWN4_IMMA_IMPLICIT_GEMM_mma32x8x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_tensorcore_args(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_TENSORCORE_ALGO_2) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.1f},
            handle_cuda(), "INT8_CHWN4_IMMA_IMPLICIT_GEMM_mma8x32x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_tensorcore_args(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_TENSORCORE_CHECK_BOUNDS_1x1) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_CHWN4_IMMA_IMPLICIT_GEMM_mma16x16x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_args_check_bounds(1));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_TENSORCORE_CHECK_BOUNDS_5x5) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_CHWN4_IMMA_IMPLICIT_GEMM_mma16x16x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_args_check_bounds(5));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_TENSORCORE_CHECK_BOUNDS_7x7) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(), "INT8_CHWN4_IMMA_IMPLICIT_GEMM_mma16x16x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_args_check_bounds(7));
}

TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_TENSORCORE_WITH_Z) {
    require_compute_capability(7, 5);
    Checker<ConvBiasForward> checker(handle_cuda());
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
                    "INT8_NCHW4_IMMA_IMPLICIT_GEMM_mma16x16x16"));
    UniformIntRNG rng{-3, 3};
    UniformIntRNG bias_rng{-50, 50};
    checker.set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &bias_rng)
            .set_rng(3, &rng)
            .set_dtype(0, dtype::QuantizedS8{1.2f})
            .set_dtype(1, dtype::QuantizedS8{1.3f})
            .set_dtype(2, dtype::QuantizedS32{1.2f * 1.3f})
            .set_dtype(3, dtype::QuantizedS8{1.1f})
            .set_dtype(4, dtype::QuantizedS8{1.0f})
            .set_epsilon(1 + 1e-3)
            .set_max_avg_error(1e-1)
            .set_max_avg_biased_error(1e-1);
    param::ConvBias param;
    param.pad_h = param.pad_w = 1;
    param.stride_h = param.stride_w = 1;
    param.format = param::ConvBias::Format::NCHW4;
    checker.set_param(param).execs({{64, 8, 12, 12, 4},
                                    {64, 8, 3, 3, 4},
                                    {1, 16, 1, 1, 4},
                                    {64, 16, 12, 12, 4},
                                    {}});
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_TENSORCORE_WITH_Z) {
    require_compute_capability(7, 5);
    Checker<ConvBiasForward> checker(handle_cuda());
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
                    "INT8_CHWN4_IMMA_IMPLICIT_GEMM_mma16x16x16"));
    UniformIntRNG rng{-3, 3};
    UniformIntRNG bias_rng{-50, 50};
    checker.set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &bias_rng)
            .set_rng(3, &rng)
            .set_dtype(0, dtype::QuantizedS8{1.2f})
            .set_dtype(1, dtype::QuantizedS8{1.3f})
            .set_dtype(2, dtype::QuantizedS32{1.2f * 1.3f})
            .set_dtype(3, dtype::QuantizedS8{1.1f})
            .set_dtype(4, dtype::QuantizedS8{1.0f})
            .set_epsilon(1 + 1e-3)
            .set_max_avg_error(1e-1)
            .set_max_avg_biased_error(1e-1);
    param::ConvBias param;
    param.pad_h = param.pad_w = 1;
    param.stride_h = param.stride_w = 1;
    param.format = param::ConvBias::Format::CHWN4;
    checker.set_param(param).execs({{8, 12, 12, 64, 4},
                                    {8, 3, 3, 64, 4},
                                    {16, 1, 1, 1, 4},
                                    {16, 12, 12, 64, 4},
                                    {}});
}
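
// REORDER_FILTER variants presumably reorder the filter tensor into the
// layout the IMMA kernels consume, and UNROLL_WIDTH variants unroll along the
// output width; both families are exercised with the same three mma tile
// shapes as above.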
TEST_F(CUDA,
       CONV_BIAS_INT8_CHWN4_REFORMAT_FILTER_TENSORCORE_CHECK_BOUNDS_ALGO_0) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_REORDER_FILTER_mma16x16x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_args_check_bounds(3));
}

TEST_F(CUDA,
       CONV_BIAS_INT8_CHWN4_REFORMAT_FILTER_TENSORCORE_CHECK_BOUNDS_ALGO_1) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_REORDER_FILTER_mma8x32x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_args_check_bounds(3));
}

TEST_F(CUDA,
       CONV_BIAS_INT8_CHWN4_REFORMAT_FILTER_TENSORCORE_CHECK_BOUNDS_ALGO_2) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_REORDER_FILTER_mma32x8x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_args_check_bounds(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_REFORMAT_FILTER_TENSORCORE_ALGO_0) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_REORDER_FILTER_mma16x16x16",
            param::ConvBias::Format::CHWN4, conv_bias::get_int8_chwn4_args(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_REFORMAT_FILTER_TENSORCORE_ALGO_1) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_REORDER_FILTER_mma8x32x16",
            param::ConvBias::Format::CHWN4, conv_bias::get_int8_chwn4_args(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_REFORMAT_FILTER_TENSORCORE_ALGO_2) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_REORDER_FILTER_mma32x8x16",
            param::ConvBias::Format::CHWN4, conv_bias::get_int8_chwn4_args(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_ALGO_0) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.1f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_mma16x16x16",
            param::ConvBias::Format::CHWN4, conv_bias::get_int8_chwn4_args(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_ALGO_1) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_mma8x32x16",
            param::ConvBias::Format::CHWN4, conv_bias::get_int8_chwn4_args(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_ALGO_2) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.3f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_mma32x8x16",
            param::ConvBias::Format::CHWN4, conv_bias::get_int8_chwn4_args(3));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.1f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_mma16x16x16",
            param::ConvBias::Format::CHWN4, conv_bias::get_int8_chwn4_args(1));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_5x5) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.1f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_mma16x16x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_args_small_batch(5));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_7x7) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.1f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_mma16x16x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_args_small_batch(7));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_5x5_ALGO_1) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.1f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_mma32x8x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_args_small_batch(5));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_5x5_ALGO_2) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.1f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_mma8x32x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_args_small_batch(5));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_1) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.1f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_mma32x8x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_args_small_batch(1));
}

TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_2) {
    require_compute_capability(7, 5);
    conv_bias::check_conv_bias(
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.1f},
            handle_cuda(),
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_mma8x32x16",
            param::ConvBias::Format::CHWN4,
            conv_bias::get_int8_chwn4_args_small_batch(1));
}
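
// The remaining tests are performance benchmarks rather than correctness
// checks; they are compiled only when MEGDNN_WITH_BENCHMARK is enabled.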
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4) {
    require_compute_capability(6, 1);
    benchmark_target_algo(
            handle_cuda(), get_resnet50_bench_args(), dtype::QuantizedS8{1.2f},
            dtype::QuantizedS8{1.3f}, dtype::QuantizedS32{1.2f * 1.3f},
            dtype::QuantizedS8{1.0f}, "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::CHWN4);
}

TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_NCHW4) {
    require_compute_capability(6, 1);
    benchmark_target_algo(
            handle_cuda(), get_resnet50_bench_args(), dtype::QuantizedS8{1.2f},
            dtype::QuantizedS8{1.3f}, dtype::QuantizedS32{1.2f * 1.3f},
            dtype::QuantizedS8{1.0f}, "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::NCHW4);
}

TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4_TENSORCORE) {
    require_compute_capability(7, 5);
    benchmark_target_algo_with_cudnn_tsc(
            handle_cuda(), get_resnet50_bench_args(256),
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.0f},
            "INT8_CHWN4_IMMA_IMPLICIT_GEMM_mma16x16x16",
            param::ConvBias::Format::CHWN4);
}

TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4_TENSORCORE_ALL_ALGO) {
    require_compute_capability(7, 5);
    benchmark_target_algo_with_cudnn_tsc(
            handle_cuda(), get_resnet50_bench_args(256),
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.0f}, nullptr,
            param::ConvBias::Format::CHWN4);
}

TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4_DET_ALL_ALGO) {
    require_compute_capability(7, 5);
    benchmark_target_algo_with_cudnn_tsc(
            handle_cuda(), get_detection_bench_args(), dtype::QuantizedS8{1.2f},
            dtype::QuantizedS8{1.3f}, dtype::QuantizedS32{1.2f * 1.3f},
            dtype::QuantizedS8{1.0f}, nullptr, param::ConvBias::Format::CHWN4);
}

TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_NCHW4_TENSORCORE) {
    require_compute_capability(7, 5);
    benchmark_target_algo_with_cudnn_tsc(
            handle_cuda(), get_resnet50_bench_args(256),
            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.0f},
            "INT8_NCHW4_IMMA_IMPLICIT_GEMM_mma16x16x16",
            param::ConvBias::Format::NCHW4);
}

TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4_SMALL_CHANNEL) {
    require_compute_capability(6, 1);
    std::vector<BenchArgs> args;
    args.push_back(BenchArgs{64, 4, 224, 224, 64, 7, 2});
    benchmark_target_algo(
            handle_cuda(), args, dtype::QuantizedS8{1.2f},
            dtype::QuantizedS8{1.3f}, dtype::QuantizedS32{1.2f * 1.3f},
            dtype::QuantizedS8{1.0f}, "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM",
            param::ConvBias::Format::CHWN4);
}
#endif
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
