You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

dct.cpp 42 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645
  1. /**
  2. * \file dnn/test/naive/dct.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs/nn.h"
  13. #include "test/common/checker.h"
  14. #include "test/common/dct_ref.h"
  15. #include "test/common/rng.h"
  16. #include "test/common/tensor.h"
  17. #include "test/naive/fixture.h"
  18. namespace megdnn {
  19. namespace test {
  20. TEST_F(NAIVE, DCT) {
  21. Checker<DctChannelSelectForward> checker(
  22. handle(),
  23. /* check_dispatch */ false);
  24. DctChannelSelectForward::Param param;
  25. checker.set_param(param).exect(
  26. Testcase{
  27. TensorValue(
  28. {1, 1, 16, 16}, dtype::Uint8(),
  29. {87, 155, 59, 161, 24, 200, 58, 3, 40, 43, 156, 7,
  30. 176, 232, 226, 78, 73, 236, 185, 109, 196, 169, 62, 32,
  31. 167, 180, 96, 157, 101, 53, 150, 47, 26, 238, 218, 210,
  32. 204, 236, 249, 111, 16, 35, 169, 204, 117, 16, 3, 147,
  33. 12, 233, 135, 162, 58, 118, 184, 237, 90, 105, 156, 195,
  34. 196, 104, 138, 19, 82, 62, 126, 140, 220, 171, 206, 232,
  35. 105, 123, 2, 135, 137, 41, 26, 219, 167, 245, 104, 103,
  36. 24, 144, 141, 210, 208, 114, 169, 170, 22, 11, 69, 106,
  37. 236, 150, 57, 184, 75, 241, 28, 175, 178, 186, 190, 124,
  38. 187, 116, 112, 162, 214, 154, 207, 31, 43, 40, 15, 188,
  39. 81, 197, 20, 199, 246, 132, 159, 111, 79, 95, 148, 184,
  40. 171, 173, 203, 146, 150, 33, 178, 9, 141, 49, 237, 222,
  41. 72, 5, 23, 38, 248, 82, 93, 229, 70, 180, 149, 232,
  42. 245, 72, 196, 138, 4, 31, 160, 30, 8, 109, 153, 252,
  43. 204, 126, 15, 182, 145, 130, 179, 234, 21, 240, 144, 105,
  44. 77, 116, 155, 232, 168, 99, 159, 92, 251, 223, 119, 173,
  45. 166, 39, 228, 91, 34, 5, 62, 172, 131, 164, 143, 10,
  46. 161, 165, 221, 214, 178, 110, 185, 254, 152, 149, 46, 144,
  47. 173, 237, 76, 210, 221, 45, 200, 113, 58, 20, 47, 135,
  48. 228, 80, 91, 51, 238, 194, 222, 231, 174, 244, 139, 96,
  49. 71, 25, 25, 62, 172, 181, 71, 27, 86, 0, 121, 38,
  50. 199, 236, 93, 158}),
  51. {},
  52. {},
  53. {}},
  54. Testcase{
  55. {},
  56. {},
  57. {},
  58. TensorValue(
  59. {1, 64, 2, 2}, dtype::Float32(),
  60. {1.10687500e+03, 9.59500000e+02, 8.98125000e+02,
  61. 1.21912500e+03, 1.38846378e+01, 3.91629181e+01,
  62. -1.50343018e+02, -1.02085358e+02, 2.34341068e+01,
  63. -8.40960388e+01, -4.23510742e+01, 1.72630596e+01,
  64. -4.66624413e+01, -4.87857285e+01, -7.06332016e+01,
  65. 6.31493912e+01, -9.96249924e+01, 7.72499924e+01,
  66. 7.46250153e+01, 5.81250114e+01, -9.07061768e+01,
  67. -7.68266630e+00, -3.15778809e+01, -3.35406876e+01,
  68. 8.55864143e+00, -7.36760712e+01, 6.20557327e+01,
  69. -2.92043419e+01, -1.39985870e+02, 2.56675129e+01,
  70. 5.21866226e+01, 1.07624054e+02, -6.16851950e+00,
  71. -8.56008530e+01, 7.35654449e+01, -2.56767311e+01,
  72. -2.09981880e+01, -6.22950821e+01, -1.31617493e+02,
  73. -6.30962448e+01, -2.21552780e+02, -4.79528542e+01,
  74. 1.04179153e+02, 7.45253448e+01, 3.19730816e+01,
  75. 1.24306192e+01, -9.93905945e+01, -8.95680237e+01,
  76. -1.44870041e+02, -9.44738235e+01, -4.09417763e+01,
  77. 4.50356903e+01, -3.65339231e+00, 5.79474449e+01,
  78. -2.46253452e+01, 3.29394951e+01, -1.09065903e+02,
  79. 5.23808861e+01, -1.00386992e+01, -7.92311325e+01,
  80. -1.44292374e+01, 5.74285736e+01, 2.28798485e+01,
  81. 6.84826508e+01, -1.49241837e+02, 9.35751495e+01,
  82. -4.02763329e+01, -6.63586197e+01, 2.15622040e+02,
  83. -7.83887939e+01, -8.06824951e+01, -2.51097183e+01,
  84. 1.58941059e+01, -5.66967869e+00, -1.53566467e+02,
  85. -4.33494377e+01, 8.12108078e+01, 1.21169144e+02,
  86. 2.14673615e+02, -3.72018318e+01, 2.45811577e+01,
  87. -1.27189613e+02, 4.98553581e+01, -5.83694696e+00,
  88. -4.80477619e+00, -2.24601650e+01, -5.02191353e+00,
  89. 5.16259460e+01, 1.07266571e+02, -3.41748886e+01,
  90. -5.44621315e+01, 6.25573196e+01, -4.24649086e+01,
  91. 4.42625465e+01, 2.71147366e+01, 4.83264275e+01,
  92. -6.99711227e+01, -1.00299120e+01, 1.33173111e+02,
  93. 2.48003254e+01, -1.74687519e+01, 9.44530487e-01,
  94. 1.35930038e+02, 6.72219162e+01, 4.53297043e+01,
  95. 1.37072708e+02, -7.73253784e+01, 6.12967606e+01,
  96. 9.78184891e+01, 3.63894577e+01, -1.64039135e+01,
  97. -6.67858887e+01, 5.27859840e+01, -4.99117432e+01,
  98. 8.77927475e+01, -5.86666260e+01, 3.86430244e+01,
  99. 2.17759323e+01, 8.34562683e+01, 3.06256886e+01,
  100. 1.61030369e+01, 8.11268158e+01, 1.36932516e+01,
  101. -1.06112595e+02, -9.31621475e+01, 3.13674717e+01,
  102. -4.90609503e+00, 7.96453857e+01, -1.02625000e+02,
  103. 1.40000076e+01, 3.18749981e+01, -1.08375000e+02,
  104. -5.44420319e+01, -1.50944397e+02, 5.29974670e+01,
  105. -1.44041641e+02, 4.86086197e+01, -7.13610382e+01,
  106. 3.06417294e+01, 7.20477829e+01, -6.95384140e+01,
  107. 1.25441925e+02, -1.54897385e+01, 3.78566666e+01,
  108. 4.23749886e+01, -3.37500000e+01, -9.96250000e+01,
  109. -6.73750076e+01, 3.34241295e+01, -6.24825974e+01,
  110. 1.76387348e+01, -6.45708389e+01, 1.70728874e+01,
  111. -5.73032570e+01, -1.71570969e+01, 1.84064590e+02,
  112. 4.17566071e+01, 7.08248520e+00, -2.59306641e+01,
  113. 1.37766739e+02, -2.16669798e+00, 6.03565750e+01,
  114. 6.84421844e+01, 6.19825096e+01, -1.44220114e+01,
  115. -3.12404213e+01, -2.50061111e+01, 6.73021851e+01,
  116. 2.52050266e+01, -8.35850677e+01, -4.70746574e+01,
  117. 1.73889160e+01, 1.18955564e+01, 6.16792488e+00,
  118. -3.29667168e+01, 4.55779572e+01, -4.17868996e+00,
  119. -9.40233841e+01, -9.77727051e+01, 1.74934635e+01,
  120. 5.25992851e+01, 1.23662634e+01, 5.26129305e-01,
  121. 4.69518929e+01, -1.52657738e+01, 9.96897888e+01,
  122. -9.51726151e+01, 9.99432602e+01, -1.75949844e+02,
  123. 1.00472336e+02, -5.89417953e+01, -1.72231483e+01,
  124. 1.89282093e+01, -8.17851868e+01, 7.22908936e+01,
  125. -9.06294174e+01, 2.46093607e+00, -4.03946457e+01,
  126. 2.17710762e+01, -5.62999649e+01, 4.77665749e+01,
  127. -4.04248848e+01, 4.78787374e+00, 1.05557320e+02,
  128. -4.60584450e+01, -7.33774490e+01, -4.25107193e+01,
  129. 1.71907139e+01, -8.01314316e+01, 1.69647141e+01,
  130. -8.24824219e+01, 8.29206543e+01, 3.72900200e+01,
  131. 3.77470016e+01, 6.70151443e+01, 1.79784470e+01,
  132. -4.01441078e+01, 6.29196739e+01, 7.60664597e+01,
  133. -5.59005699e+01, 8.81600475e+00, -6.89491081e+00,
  134. -8.03825378e+01, -5.33856511e-01, 7.26196136e+01,
  135. -3.76809120e+01, -1.08401566e+02, 6.35455990e+00,
  136. -8.66767120e+01, -1.02679443e+02, -9.54313660e+00,
  137. -3.55650787e+01, -1.21355652e+02, 2.32628040e+01,
  138. 3.94072838e+01, 1.24754738e+02, 9.51344986e+01,
  139. -5.84752541e+01, -4.65028038e+01, 6.00556993e+00,
  140. 4.94889374e+01, 7.64868622e+01, -1.49546280e+01,
  141. -3.70648766e+01, 5.55572205e+01, -1.17196434e+02,
  142. 9.20216217e+01, 3.29843826e+01, 3.25113411e+01,
  143. 5.62059135e+01, 6.30202141e+01, 4.99030991e+01,
  144. 2.85804024e+01, -1.44606361e+01, 7.64952774e+01,
  145. -2.95697536e+01})});
  146. }
  147. TEST_F(NAIVE, DCT_INT8) {
  148. Checker<DctChannelSelectForward> checker(
  149. handle(),
  150. /* check_dispatch */ false);
  151. DctChannelSelectForward::Param param;
  152. param.format = DctChannelSelectForward::Param::Format::NCHW4;
  153. checker.set_param(param).exect(
  154. Testcase{
  155. TensorValue(
  156. {1, 1, 16, 16}, dtype::Uint8(),
  157. {113, 223, 229, 159, 249, 252, 89, 84, 45, 16, 41, 72,
  158. 184, 236, 70, 184, 86, 172, 218, 211, 47, 177, 18, 85,
  159. 174, 226, 37, 109, 38, 135, 228, 195, 133, 238, 47, 246,
  160. 244, 118, 175, 143, 34, 10, 28, 4, 82, 103, 89, 55,
  161. 235, 78, 151, 178, 249, 62, 183, 84, 105, 0, 121, 98,
  162. 249, 90, 161, 114, 121, 241, 21, 199, 196, 119, 231, 209,
  163. 250, 180, 192, 213, 116, 105, 114, 169, 1, 142, 3, 30,
  164. 140, 245, 201, 109, 19, 26, 224, 68, 123, 228, 64, 150,
  165. 184, 212, 136, 172, 241, 152, 222, 233, 15, 72, 130, 144,
  166. 107, 130, 242, 79, 195, 46, 226, 57, 183, 36, 88, 161,
  167. 121, 170, 2, 215, 109, 212, 35, 18, 76, 197, 117, 81,
  168. 208, 8, 237, 75, 15, 20, 16, 192, 61, 113, 96, 126,
  169. 211, 57, 49, 62, 185, 211, 155, 87, 233, 163, 164, 84,
  170. 61, 28, 1, 11, 190, 253, 145, 30, 38, 98, 153, 56,
  171. 231, 152, 12, 204, 96, 8, 47, 87, 25, 237, 21, 150,
  172. 173, 19, 41, 175, 164, 231, 39, 145, 39, 187, 210, 123,
  173. 165, 98, 87, 242, 38, 136, 182, 145, 41, 47, 147, 171,
  174. 172, 35, 170, 148, 26, 89, 107, 151, 130, 232, 65, 217,
  175. 27, 206, 68, 219, 60, 106, 3, 209, 175, 189, 191, 32,
  176. 119, 141, 56, 48, 105, 58, 94, 163, 185, 60, 83, 249,
  177. 112, 245, 137, 60, 178, 51, 177, 106, 199, 209, 4, 247,
  178. 3, 127, 88, 46}),
  179. {},
  180. {},
  181. {}},
  182. Testcase{
  183. {},
  184. {},
  185. {},
  186. TensorValue(
  187. {1, 16, 2, 2, 4}, dtype::QuantizedS8(10.f),
  188. {122, -1, -8, 4, 92, -13, -5, 7, 99, 4, 5, 3,
  189. 89, 7, 2, -6, 3, -8, -10, 2, -1, 0, 4, -3,
  190. -5, -8, -11, 1, 14, 4, -10, -18, 3, 12, -14, -2,
  191. -4, -9, 12, 4, -2, -2, 2, 6, -9, 6, 1, 5,
  192. -5, -1, 2, -12, 4, -5, -0, 4, 1, 5, -8, 5,
  193. -3, 4, 2, 6, -0, 9, -4, -7, -4, -5, -2, 8,
  194. 2, 4, 0, 7, -8, 4, -2, 3, -6, -5, 19, 5,
  195. -4, -4, -5, -16, -8, -3, -5, 19, 4, 3, 4, -6,
  196. 1, -12, -1, 7, 11, -5, -1, -8, 2, -12, -9, -2,
  197. -4, -20, -11, -15, -15, -9, -2, -9, -2, -3, 13, 2,
  198. 5, 6, 7, -4, 1, -7, 6, 4, 2, 6, 0, -0,
  199. 8, 8, -6, 5, 1, -2, -2, -12, 2, -12, -2, 6,
  200. 7, 3, 4, 14, 14, -3, 1, -3, 6, 0, -20, 2,
  201. -10, 10, -5, -5, 13, 0, -3, 7, -12, -17, -13, 1,
  202. -6, 10, -1, -9, 4, -16, 3, 2, 5, 1, -4, 9,
  203. -0, 1, 3, 15, -4, -13, -6, 4, 3, -2, -1, -4,
  204. -7, -7, -2, 8, -16, -4, -10, 5, 1, -3, 2, -9,
  205. -4, 1, -1, -1, -4, -6, -4, 1, 0, -9, 15, -1,
  206. -7, -3, -5, -0, 3, -0, -6, -17, 16, -3, 3, -2,
  207. -3, 5, 3, -2, 3, 13, 8, 1, -3, -8, -7, -4,
  208. 6, -6, -15, -7, 0, 4, -3, -3, -10, 14, 1, 3,
  209. 14, 4, -1, 14})});
  210. }
  211. TEST_F(NAIVE, DCT_INT8_MASK) {
  212. Checker<DctChannelSelectForward> checker(
  213. handle(),
  214. /* check_dispatch */ false);
  215. DctChannelSelectForward::Param param;
  216. param.format = DctChannelSelectForward::Param::Format::NCHW4;
  217. auto src_tensor = TensorValue(
  218. {1, 3, 8, 16}, dtype::Uint8(),
  219. {195, 165, 82, 30, 154, 60, 175, 195, 179, 165, 132, 37, 250, 107, 36,
  220. 80, 5, 54, 247, 218, 191, 211, 239, 76, 140, 33, 253, 85, 132, 101,
  221. 105, 177, 46, 183, 102, 99, 19, 175, 108, 252, 42, 238, 48, 251, 108,
  222. 90, 176, 2, 35, 46, 161, 252, 38, 225, 195, 174, 58, 165, 198, 249,
  223. 162, 118, 198, 41, 154, 10, 87, 24, 201, 12, 188, 1, 93, 179, 246,
  224. 134, 18, 178, 173, 36, 122, 89, 115, 46, 43, 205, 232, 55, 149, 30,
  225. 206, 97, 186, 125, 35, 209, 51, 48, 222, 222, 130, 173, 63, 0, 223,
  226. 19, 5, 162, 154, 143, 134, 63, 123, 102, 102, 212, 145, 80, 87, 212,
  227. 42, 26, 219, 225, 120, 94, 213, 238,
  228. 25, 172, 141, 45, 182, 203, 50, 94, 44, 88, 74, 76, 151, 105, 138,
  229. 87, 125, 55, 60, 211, 15, 158, 198, 37, 54, 203, 239, 79, 56, 6,
  230. 53, 201, 97, 233, 178, 74, 193, 46, 249, 65, 5, 208, 130, 67, 191,
  231. 168, 152, 129, 253, 195, 231, 3, 109, 229, 254, 193, 229, 202, 108, 22,
  232. 89, 251, 13, 53, 47, 192, 12, 81, 19, 53, 93, 104, 41, 217, 215,
  233. 184, 136, 249, 14, 244, 4, 220, 33, 53, 142, 219, 43, 28, 68, 198,
  234. 202, 88, 235, 7, 233, 47, 84, 127, 28, 17, 189, 135, 183, 192, 239,
  235. 116, 31, 118, 186, 49, 251, 233, 220, 27, 97, 30, 43, 193, 217, 48,
  236. 24, 225, 15, 3, 26, 71, 82, 104,
  237. 175, 125, 79, 195, 50, 236, 114, 179, 180, 177, 230, 173, 43, 195, 123,
  238. 111, 106, 5, 91, 254, 34, 76, 52, 82, 193, 179, 185, 71, 57, 215,
  239. 18, 5, 151, 13, 59, 206, 154, 95, 149, 40, 229, 16, 116, 144, 249,
  240. 67, 97, 223, 208, 144, 92, 174, 246, 77, 196, 211, 20, 123, 239, 250,
  241. 235, 65, 184, 54, 239, 168, 135, 17, 79, 117, 171, 173, 109, 39, 57,
  242. 13, 129, 79, 236, 117, 134, 123, 149, 113, 198, 160, 249, 242, 220, 226,
  243. 44, 113, 164, 217, 46, 249, 182, 22, 98, 228, 49, 78, 101, 236, 181,
  244. 5, 245, 72, 62, 182, 151, 210, 254, 190, 35, 73, 190, 247, 50, 81,
  245. 49, 217, 86, 229, 139, 203, 57, 194});
  246. checker.set_param(param).exect(
  247. Testcase{
  248. src_tensor,
  249. TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
  250. TensorValue(
  251. {32}, dtype::Int32(),
  252. {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
  253. 0, 1, 8, 16, 9, 2, 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
  254. {}},
  255. Testcase{
  256. {},
  257. {},
  258. {},
  259. TensorValue(
  260. {1, 8, 1, 2, 4}, dtype::QuantizedS8(10.f),
  261. {100, -12, 7, 7, 104, 2, -2, -2, -7, -7, -3, 8, 12,
  262. -12, -5, -1, 5, -7, -1, 7, -7, -3, 6, 7, -0, -2,
  263. -7, 11, 6, 3, -1, 7, 94, -5, 6, -5, 98, 0, -3,
  264. -16, 5, 7, 13, -8, 1, 5, -5, -8, 108, -3, -8, -7,
  265. 110, 1, -2, 5, -0, 7, 8, -9, 14, -0, 1, -4})});
  266. checker.set_param(param).exect(
  267. Testcase{
  268. TensorValue(
  269. {1, 3, 8, 16}, dtype::Uint8(),
  270. {195, 165, 82, 30, 154, 60, 175, 195, 179, 165, 132, 37,
  271. 250, 107, 36, 80, 5, 54, 247, 218, 191, 211, 239, 76,
  272. 140, 33, 253, 85, 132, 101, 105, 177, 46, 183, 102, 99,
  273. 19, 175, 108, 252, 42, 238, 48, 251, 108, 90, 176, 2,
  274. 35, 46, 161, 252, 38, 225, 195, 174, 58, 165, 198, 249,
  275. 162, 118, 198, 41, 154, 10, 87, 24, 201, 12, 188, 1,
  276. 93, 179, 246, 134, 18, 178, 173, 36, 122, 89, 115, 46,
  277. 43, 205, 232, 55, 149, 30, 206, 97, 186, 125, 35, 209,
  278. 51, 48, 222, 222, 130, 173, 63, 0, 223, 19, 5, 162,
  279. 154, 143, 134, 63, 123, 102, 102, 212, 145, 80, 87, 212,
  280. 42, 26, 219, 225, 120, 94, 213, 238,
  281. 25, 172, 141, 45, 182, 203, 50, 94, 44, 88, 74, 76,
  282. 151, 105, 138, 87, 125, 55, 60, 211, 15, 158, 198, 37,
  283. 54, 203, 239, 79, 56, 6, 53, 201, 97, 233, 178, 74,
  284. 193, 46, 249, 65, 5, 208, 130, 67, 191, 168, 152, 129,
  285. 253, 195, 231, 3, 109, 229, 254, 193, 229, 202, 108, 22,
  286. 89, 251, 13, 53, 47, 192, 12, 81, 19, 53, 93, 104,
  287. 41, 217, 215, 184, 136, 249, 14, 244, 4, 220, 33, 53,
  288. 142, 219, 43, 28, 68, 198, 202, 88, 235, 7, 233, 47,
  289. 84, 127, 28, 17, 189, 135, 183, 192, 239, 116, 31, 118,
  290. 186, 49, 251, 233, 220, 27, 97, 30, 43, 193, 217, 48,
  291. 24, 225, 15, 3, 26, 71, 82, 104,
  292. 175, 125, 79, 195, 50, 236, 114, 179, 180, 177, 230, 173,
  293. 43, 195, 123, 111, 106, 5, 91, 254, 34, 76, 52, 82,
  294. 193, 179, 185, 71, 57, 215, 18, 5, 151, 13, 59, 206,
  295. 154, 95, 149, 40, 229, 16, 116, 144, 249, 67, 97, 223,
  296. 208, 144, 92, 174, 246, 77, 196, 211, 20, 123, 239, 250,
  297. 235, 65, 184, 54, 239, 168, 135, 17, 79, 117, 171, 173,
  298. 109, 39, 57, 13, 129, 79, 236, 117, 134, 123, 149, 113,
  299. 198, 160, 249, 242, 220, 226, 44, 113, 164, 217, 46, 249,
  300. 182, 22, 98, 228, 49, 78, 101, 236, 181, 5, 245, 72,
  301. 62, 182, 151, 210, 254, 190, 35, 73, 190, 247, 50, 81,
  302. 49, 217, 86, 229, 139, 203, 57, 194}),
  303. TensorValue({4}, dtype::Int32(), {0, 12, 20, 28}),
  304. TensorValue(
  305. {28}, dtype::Int32(),
  306. {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 0, 1,
  307. 8, 16, 9, 2, 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
  308. {}},
  309. Testcase{
  310. {},
  311. {},
  312. {},
  313. TensorValue(
  314. {1, 7, 1, 2, 4}, dtype::QuantizedS8(10.f),
  315. {100, -12, 7, 7, 104, 2, -2, -2, -7, -7, -3, 8,
  316. 12, -12, -5, -1, 5, -7, -1, 7, -7, -3, 6, 7,
  317. 94, -5, 6, -5, 98, 0, -3, -16, 5, 7, 13, -8,
  318. 1, 5, -5, -8, 108, -3, -8, -7, 110, 1, -2, 5,
  319. -0, 7, 8, -9, 14, -0, 1, -4})});
  320. }
  321. TEST_F(NAIVE, DCT_4x4) {
  322. Checker<DctChannelSelectForward> checker(
  323. handle(),
  324. /* check_dispatch */ false);
  325. DctChannelSelectForward::Param param;
  326. param.dct_block_size = 4;
  327. checker.set_param(param).exect(
  328. Testcase{
  329. TensorValue(
  330. {1, 1, 8, 8}, dtype::Uint8(),
  331. {186, 120, 112, 220, 69, 80, 201, 127, 246, 254, 175,
  332. 50, 240, 251, 76, 37, 34, 166, 250, 195, 231, 139,
  333. 128, 233, 75, 80, 3, 2, 19, 140, 193, 203, 115,
  334. 107, 250, 209, 14, 243, 199, 60, 234, 107, 174, 156,
  335. 81, 87, 13, 116, 96, 140, 197, 253, 113, 223, 229,
  336. 159, 249, 252, 89, 84, 45, 16, 41, 72}),
  337. {},
  338. {},
  339. {}},
  340. Testcase{
  341. {},
  342. {},
  343. {},
  344. TensorValue(
  345. {1, 16, 2, 2}, dtype::Float32(),
  346. {5.42000000e+02, 5.91750000e+02, 6.78000000e+02,
  347. 4.27750000e+02, 3.49953423e+01, -1.17686939e+01,
  348. -1.66842098e+01, -3.85316620e+01, -3.80000000e+01,
  349. -1.22500000e+01, 2.00000000e+01, -9.77500000e+01,
  350. -1.61191311e+01, -9.46695328e+00, 3.28882408e+01,
  351. -4.92537880e+01, 1.66958221e+02, -4.26609573e+01,
  352. 2.56999969e-01, 5.39384537e+01, 1.71819706e+01,
  353. 9.00009003e+01, -1.23818558e+02, 1.18912420e+01,
  354. 6.61014938e+01, -2.49261990e+01, 4.95798302e+00,
  355. -1.02324417e+02, 7.85859919e+00, 3.73140755e+01,
  356. 1.03783745e+02, -4.61430321e+01, -1.43000000e+02,
  357. -7.57500000e+01, -5.00000000e-01, -8.27500000e+01,
  358. 1.34834738e+01, -1.93409515e+02, 6.84791718e+01,
  359. -4.01652241e+00, 1.22000000e+02, -8.57500000e+01,
  360. -4.05000000e+01, -5.62500000e+01, -2.88564739e+01,
  361. 5.76532059e+01, -2.67414131e+01, 1.70877876e+01,
  362. 3.85416756e+01, 3.09300461e+01, 5.84670639e+00,
  363. 1.85747864e+02, -2.05141403e+02, -9.91859360e+01,
  364. -1.66716263e+02, -1.71430378e+01, 6.71520996e+00,
  365. 8.41980438e+01, -3.50666313e+01, -1.48387482e+02,
  366. 1.08180256e+01, 5.49991112e+01, -1.06814528e+01,
  367. 1.86087704e+01})});
  368. checker.set_param(param).exect(
  369. Testcase{
  370. TensorValue(
  371. {1, 1, 8, 8}, dtype::Uint8(),
  372. {186, 120, 112, 220, 69, 80, 201, 127, 246, 254, 175,
  373. 50, 240, 251, 76, 37, 34, 166, 250, 195, 231, 139,
  374. 128, 233, 75, 80, 3, 2, 19, 140, 193, 203, 115,
  375. 107, 250, 209, 14, 243, 199, 60, 234, 107, 174, 156,
  376. 81, 87, 13, 116, 96, 140, 197, 253, 113, 223, 229,
  377. 159, 249, 252, 89, 84, 45, 16, 41, 72}),
  378. TensorValue({2}, dtype::Int32(), {0, 6}),
  379. TensorValue({6}, dtype::Int32(), {0, 1, 8, 4, 2, 3}),
  380. {}},
  381. Testcase{
  382. {},
  383. {},
  384. {},
  385. TensorValue(
  386. {1, 6, 2, 2}, dtype::Float32(),
  387. {5.4200000e+02, 5.9175000e+02, 6.7800000e+02,
  388. 4.2775000e+02, 3.4995342e+01, -1.1768694e+01,
  389. -1.6684210e+01, -3.8531662e+01, -1.4300000e+02,
  390. -7.5750000e+01, -5.0000000e-01, -8.2750000e+01,
  391. 1.6695822e+02, -4.2660957e+01, 2.5699997e-01,
  392. 5.3938454e+01, -3.8000000e+01, -1.2250000e+01,
  393. 2.0000000e+01, -9.7750000e+01, -1.6119131e+01,
  394. -9.4669533e+00, 3.2888241e+01, -4.9253788e+01})});
  395. }
  396. TEST_F(NAIVE, DCT_WITH_MASK) {
  397. Checker<DctChannelSelectForward> checker(
  398. handle(),
  399. /* check_dispatch */ false);
  400. DctChannelSelectForward::Param param;
  401. checker.set_param(param).exect(
  402. Testcase{
  403. TensorValue(
  404. {1, 3, 8, 16}, dtype::Uint8(),
  405. {109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  406. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  407. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  408. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  409. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  410. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  411. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  412. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  413. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  414. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  415. 238, 181, 232, 191, 161, 57, 23, 204,
  416. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  417. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  418. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  419. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  420. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  421. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  422. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  423. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  424. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  425. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  426. 238, 181, 232, 191, 161, 57, 23, 204,
  427. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  428. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  429. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  430. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  431. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  432. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  433. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  434. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  435. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  436. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  437. 238, 181, 232, 191, 161, 57, 23, 204}),
  438. TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
  439. TensorValue(
  440. {32}, dtype::Int32(),
  441. {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
  442. 0, 1, 8, 16, 9, 2, 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
  443. {}},
  444. Testcase{
  445. {},
  446. {},
  447. {},
  448. TensorValue(
  449. {1, 32, 1, 2}, dtype::Float32(),
  450. {890.12494, 941.25, -7.0498576, 99.47632,
  451. -22.850792, -97.862236, -101.043236, -4.727012,
  452. 28.275675, -157.96654, 42.1377, 45.06531,
  453. -149.77373, 24.487143, -8.054966, -13.990831,
  454. -6.9395194, -3.9211385, 64.79172, -12.363858,
  455. -47.875, 59., 56.271786, -62.725567,
  456. 120.522675, 16.559765, 85.74334, 112.904495,
  457. 99.375, 29.499973, 2.0220923, -19.681704,
  458. 890.12494, 941.25, -7.0498576, 99.47632,
  459. -22.850792, -97.862236, -101.043236, -4.727012,
  460. 28.275675, -157.96654, 42.1377, 45.06531,
  461. -149.77373, 24.487143, -8.054966, -13.990831,
  462. 890.12494, 941.25, -7.0498576, 99.47632,
  463. -22.850792, -97.862236, -101.043236, -4.727012,
  464. 28.275675, -157.96654, 42.1377, 45.06531,
  465. -149.77373, 24.487143, -8.054966, -13.990831})});
  466. checker.set_param(param).exect(
  467. Testcase{
  468. TensorValue(
  469. {1, 3, 8, 16}, dtype::Uint8(),
  470. {109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  471. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  472. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  473. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  474. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  475. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  476. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  477. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  478. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  479. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  480. 238, 181, 232, 191, 161, 57, 23, 204,
  481. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  482. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  483. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  484. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  485. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  486. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  487. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  488. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  489. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  490. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  491. 238, 181, 232, 191, 161, 57, 23, 204,
  492. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  493. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  494. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  495. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  496. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  497. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  498. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  499. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  500. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  501. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  502. 238, 181, 232, 191, 161, 57, 23, 204}),
  503. TensorValue({4}, dtype::Int32(), {0, 8, 16, 24}),
  504. TensorValue({24}, dtype::Int32(), {17, 24, 32, 25, 18, 11, 4, 5,
  505. 0, 1, 8, 16, 9, 2, 3, 10,
  506. 0, 1, 8, 16, 9, 2, 3, 10}),
  507. {}},
  508. Testcase{
  509. {},
  510. {},
  511. {},
  512. TensorValue(
  513. {1, 24, 1, 2}, dtype::Float32(),
  514. {-6.9395194, -3.9211385, 64.79172, -12.363858,
  515. -47.875, 59., 56.271786, -62.725567,
  516. 120.522675, 16.559765, 85.74334, 112.904495,
  517. 99.375, 29.499973, 2.0220923, -19.681704,
  518. 890.12494, 941.25, -7.0498576, 99.47632,
  519. -22.850792, -97.862236, -101.043236, -4.727012,
  520. 28.275675, -157.96654, 42.1377, 45.06531,
  521. -149.77373, 24.487143, -8.054966, -13.990831,
  522. 890.12494, 941.25, -7.0498576, 99.47632,
  523. -22.850792, -97.862236, -101.043236, -4.727012,
  524. 28.275675, -157.96654, 42.1377, 45.06531,
  525. -149.77373, 24.487143, -8.054966, -13.990831})});
  526. }
  527. TEST_F(NAIVE, DCT_WITH_FIX_32_MASK) {
  528. Checker<DctChannelSelectForward> checker(
  529. handle(),
  530. /* check_dispatch */ false);
  531. using Param = DctChannelSelectForward::Param;
  532. Param param;
  533. param.fastImpl = Param::FastImpl::FIX_32_MASK;
  534. checker.set_param(param).exect(
  535. Testcase{
  536. TensorValue(
  537. {1, 3, 8, 16}, dtype::Uint8(),
  538. {109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  539. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  540. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  541. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  542. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  543. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  544. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  545. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  546. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  547. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  548. 238, 181, 232, 191, 161, 57, 23, 204,
  549. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  550. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  551. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  552. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  553. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  554. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  555. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  556. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  557. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  558. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  559. 238, 181, 232, 191, 161, 57, 23, 204,
  560. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  561. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  562. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  563. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  564. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  565. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  566. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  567. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  568. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  569. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  570. 238, 181, 232, 191, 161, 57, 23, 204}),
  571. TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
  572. TensorValue(
  573. {32}, dtype::Int32(),
  574. {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
  575. 0, 1, 8, 16, 9, 2, 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
  576. {}},
  577. Testcase{
  578. {},
  579. {},
  580. {},
  581. TensorValue(
  582. {1, 32, 1, 2}, dtype::Float32(),
  583. {890.12494, 941.25, -7.0498576, 99.47632,
  584. -22.850792, -97.862236, -101.043236, -4.727012,
  585. 28.275675, -157.96654, 42.1377, 45.06531,
  586. -149.77373, 24.487143, -8.054966, -13.990831,
  587. -6.9395194, -3.9211385, 64.79172, -12.363858,
  588. -47.875, 59., 56.271786, -62.725567,
  589. 120.522675, 16.559765, 85.74334, 112.904495,
  590. 99.375, 29.499973, 2.0220923, -19.681704,
  591. 890.12494, 941.25, -7.0498576, 99.47632,
  592. -22.850792, -97.862236, -101.043236, -4.727012,
  593. 28.275675, -157.96654, 42.1377, 45.06531,
  594. -149.77373, 24.487143, -8.054966, -13.990831,
  595. 890.12494, 941.25, -7.0498576, 99.47632,
  596. -22.850792, -97.862236, -101.043236, -4.727012,
  597. 28.275675, -157.96654, 42.1377, 45.06531,
  598. -149.77373, 24.487143, -8.054966, -13.990831})});
  599. }
  600. TEST_F(NAIVE, DCT_WITH_MASK2) {
  601. Checker<DctChannelSelectForward> checker(handle(), false);
  602. DctChannelSelectForward::Param param;
  603. UniformIntRNG rng_oc(0, 3 * 64);
  604. for (size_t n : {1, 3}) {
  605. for (size_t ic : {1, 3}) {
  606. for (size_t ih : {8, 16, 32, 512, 1024}) {
  607. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  608. int random_oc = static_cast<int>(rng_oc.gen_single_val());
  609. int max_oc = ic * 64;
  610. int mask_oc = (random_oc % max_oc) + 1;
  611. auto test_case = gen_dct_case(n, ic, ih, iw, mask_oc, param);
  612. checker.set_param(param).exect(
  613. test_case->testcase_in, test_case->testcase_out);
  614. }
  615. }
  616. }
  617. }
  618. }
  619. } // namespace test
  620. } // namespace megdnn
  621. // vim: syntax=cpp.doxygen