You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634
  1. #include "megdnn/oprs/nn.h"
  2. #include "test/common/checker.h"
  3. #include "test/common/dct_ref.h"
  4. #include "test/common/rng.h"
  5. #include "test/common/tensor.h"
  6. #include "test/naive/fixture.h"
  7. namespace megdnn {
  8. namespace test {
  9. TEST_F(NAIVE, DCT) {
  10. Checker<DctChannelSelectForward> checker(
  11. handle(),
  12. /* check_dispatch */ false);
  13. DctChannelSelectForward::Param param;
  14. checker.set_param(param).exect(
  15. Testcase{
  16. TensorValue(
  17. {1, 1, 16, 16}, dtype::Uint8(),
  18. {87, 155, 59, 161, 24, 200, 58, 3, 40, 43, 156, 7,
  19. 176, 232, 226, 78, 73, 236, 185, 109, 196, 169, 62, 32,
  20. 167, 180, 96, 157, 101, 53, 150, 47, 26, 238, 218, 210,
  21. 204, 236, 249, 111, 16, 35, 169, 204, 117, 16, 3, 147,
  22. 12, 233, 135, 162, 58, 118, 184, 237, 90, 105, 156, 195,
  23. 196, 104, 138, 19, 82, 62, 126, 140, 220, 171, 206, 232,
  24. 105, 123, 2, 135, 137, 41, 26, 219, 167, 245, 104, 103,
  25. 24, 144, 141, 210, 208, 114, 169, 170, 22, 11, 69, 106,
  26. 236, 150, 57, 184, 75, 241, 28, 175, 178, 186, 190, 124,
  27. 187, 116, 112, 162, 214, 154, 207, 31, 43, 40, 15, 188,
  28. 81, 197, 20, 199, 246, 132, 159, 111, 79, 95, 148, 184,
  29. 171, 173, 203, 146, 150, 33, 178, 9, 141, 49, 237, 222,
  30. 72, 5, 23, 38, 248, 82, 93, 229, 70, 180, 149, 232,
  31. 245, 72, 196, 138, 4, 31, 160, 30, 8, 109, 153, 252,
  32. 204, 126, 15, 182, 145, 130, 179, 234, 21, 240, 144, 105,
  33. 77, 116, 155, 232, 168, 99, 159, 92, 251, 223, 119, 173,
  34. 166, 39, 228, 91, 34, 5, 62, 172, 131, 164, 143, 10,
  35. 161, 165, 221, 214, 178, 110, 185, 254, 152, 149, 46, 144,
  36. 173, 237, 76, 210, 221, 45, 200, 113, 58, 20, 47, 135,
  37. 228, 80, 91, 51, 238, 194, 222, 231, 174, 244, 139, 96,
  38. 71, 25, 25, 62, 172, 181, 71, 27, 86, 0, 121, 38,
  39. 199, 236, 93, 158}),
  40. {},
  41. {},
  42. {}},
  43. Testcase{
  44. {},
  45. {},
  46. {},
  47. TensorValue(
  48. {1, 64, 2, 2}, dtype::Float32(),
  49. {1.10687500e+03, 9.59500000e+02, 8.98125000e+02,
  50. 1.21912500e+03, 1.38846378e+01, 3.91629181e+01,
  51. -1.50343018e+02, -1.02085358e+02, 2.34341068e+01,
  52. -8.40960388e+01, -4.23510742e+01, 1.72630596e+01,
  53. -4.66624413e+01, -4.87857285e+01, -7.06332016e+01,
  54. 6.31493912e+01, -9.96249924e+01, 7.72499924e+01,
  55. 7.46250153e+01, 5.81250114e+01, -9.07061768e+01,
  56. -7.68266630e+00, -3.15778809e+01, -3.35406876e+01,
  57. 8.55864143e+00, -7.36760712e+01, 6.20557327e+01,
  58. -2.92043419e+01, -1.39985870e+02, 2.56675129e+01,
  59. 5.21866226e+01, 1.07624054e+02, -6.16851950e+00,
  60. -8.56008530e+01, 7.35654449e+01, -2.56767311e+01,
  61. -2.09981880e+01, -6.22950821e+01, -1.31617493e+02,
  62. -6.30962448e+01, -2.21552780e+02, -4.79528542e+01,
  63. 1.04179153e+02, 7.45253448e+01, 3.19730816e+01,
  64. 1.24306192e+01, -9.93905945e+01, -8.95680237e+01,
  65. -1.44870041e+02, -9.44738235e+01, -4.09417763e+01,
  66. 4.50356903e+01, -3.65339231e+00, 5.79474449e+01,
  67. -2.46253452e+01, 3.29394951e+01, -1.09065903e+02,
  68. 5.23808861e+01, -1.00386992e+01, -7.92311325e+01,
  69. -1.44292374e+01, 5.74285736e+01, 2.28798485e+01,
  70. 6.84826508e+01, -1.49241837e+02, 9.35751495e+01,
  71. -4.02763329e+01, -6.63586197e+01, 2.15622040e+02,
  72. -7.83887939e+01, -8.06824951e+01, -2.51097183e+01,
  73. 1.58941059e+01, -5.66967869e+00, -1.53566467e+02,
  74. -4.33494377e+01, 8.12108078e+01, 1.21169144e+02,
  75. 2.14673615e+02, -3.72018318e+01, 2.45811577e+01,
  76. -1.27189613e+02, 4.98553581e+01, -5.83694696e+00,
  77. -4.80477619e+00, -2.24601650e+01, -5.02191353e+00,
  78. 5.16259460e+01, 1.07266571e+02, -3.41748886e+01,
  79. -5.44621315e+01, 6.25573196e+01, -4.24649086e+01,
  80. 4.42625465e+01, 2.71147366e+01, 4.83264275e+01,
  81. -6.99711227e+01, -1.00299120e+01, 1.33173111e+02,
  82. 2.48003254e+01, -1.74687519e+01, 9.44530487e-01,
  83. 1.35930038e+02, 6.72219162e+01, 4.53297043e+01,
  84. 1.37072708e+02, -7.73253784e+01, 6.12967606e+01,
  85. 9.78184891e+01, 3.63894577e+01, -1.64039135e+01,
  86. -6.67858887e+01, 5.27859840e+01, -4.99117432e+01,
  87. 8.77927475e+01, -5.86666260e+01, 3.86430244e+01,
  88. 2.17759323e+01, 8.34562683e+01, 3.06256886e+01,
  89. 1.61030369e+01, 8.11268158e+01, 1.36932516e+01,
  90. -1.06112595e+02, -9.31621475e+01, 3.13674717e+01,
  91. -4.90609503e+00, 7.96453857e+01, -1.02625000e+02,
  92. 1.40000076e+01, 3.18749981e+01, -1.08375000e+02,
  93. -5.44420319e+01, -1.50944397e+02, 5.29974670e+01,
  94. -1.44041641e+02, 4.86086197e+01, -7.13610382e+01,
  95. 3.06417294e+01, 7.20477829e+01, -6.95384140e+01,
  96. 1.25441925e+02, -1.54897385e+01, 3.78566666e+01,
  97. 4.23749886e+01, -3.37500000e+01, -9.96250000e+01,
  98. -6.73750076e+01, 3.34241295e+01, -6.24825974e+01,
  99. 1.76387348e+01, -6.45708389e+01, 1.70728874e+01,
  100. -5.73032570e+01, -1.71570969e+01, 1.84064590e+02,
  101. 4.17566071e+01, 7.08248520e+00, -2.59306641e+01,
  102. 1.37766739e+02, -2.16669798e+00, 6.03565750e+01,
  103. 6.84421844e+01, 6.19825096e+01, -1.44220114e+01,
  104. -3.12404213e+01, -2.50061111e+01, 6.73021851e+01,
  105. 2.52050266e+01, -8.35850677e+01, -4.70746574e+01,
  106. 1.73889160e+01, 1.18955564e+01, 6.16792488e+00,
  107. -3.29667168e+01, 4.55779572e+01, -4.17868996e+00,
  108. -9.40233841e+01, -9.77727051e+01, 1.74934635e+01,
  109. 5.25992851e+01, 1.23662634e+01, 5.26129305e-01,
  110. 4.69518929e+01, -1.52657738e+01, 9.96897888e+01,
  111. -9.51726151e+01, 9.99432602e+01, -1.75949844e+02,
  112. 1.00472336e+02, -5.89417953e+01, -1.72231483e+01,
  113. 1.89282093e+01, -8.17851868e+01, 7.22908936e+01,
  114. -9.06294174e+01, 2.46093607e+00, -4.03946457e+01,
  115. 2.17710762e+01, -5.62999649e+01, 4.77665749e+01,
  116. -4.04248848e+01, 4.78787374e+00, 1.05557320e+02,
  117. -4.60584450e+01, -7.33774490e+01, -4.25107193e+01,
  118. 1.71907139e+01, -8.01314316e+01, 1.69647141e+01,
  119. -8.24824219e+01, 8.29206543e+01, 3.72900200e+01,
  120. 3.77470016e+01, 6.70151443e+01, 1.79784470e+01,
  121. -4.01441078e+01, 6.29196739e+01, 7.60664597e+01,
  122. -5.59005699e+01, 8.81600475e+00, -6.89491081e+00,
  123. -8.03825378e+01, -5.33856511e-01, 7.26196136e+01,
  124. -3.76809120e+01, -1.08401566e+02, 6.35455990e+00,
  125. -8.66767120e+01, -1.02679443e+02, -9.54313660e+00,
  126. -3.55650787e+01, -1.21355652e+02, 2.32628040e+01,
  127. 3.94072838e+01, 1.24754738e+02, 9.51344986e+01,
  128. -5.84752541e+01, -4.65028038e+01, 6.00556993e+00,
  129. 4.94889374e+01, 7.64868622e+01, -1.49546280e+01,
  130. -3.70648766e+01, 5.55572205e+01, -1.17196434e+02,
  131. 9.20216217e+01, 3.29843826e+01, 3.25113411e+01,
  132. 5.62059135e+01, 6.30202141e+01, 4.99030991e+01,
  133. 2.85804024e+01, -1.44606361e+01, 7.64952774e+01,
  134. -2.95697536e+01})});
  135. }
  136. TEST_F(NAIVE, DCT_INT8) {
  137. Checker<DctChannelSelectForward> checker(
  138. handle(),
  139. /* check_dispatch */ false);
  140. DctChannelSelectForward::Param param;
  141. param.format = DctChannelSelectForward::Param::Format::NCHW4;
  142. checker.set_param(param).exect(
  143. Testcase{
  144. TensorValue(
  145. {1, 1, 16, 16}, dtype::Uint8(),
  146. {113, 223, 229, 159, 249, 252, 89, 84, 45, 16, 41, 72,
  147. 184, 236, 70, 184, 86, 172, 218, 211, 47, 177, 18, 85,
  148. 174, 226, 37, 109, 38, 135, 228, 195, 133, 238, 47, 246,
  149. 244, 118, 175, 143, 34, 10, 28, 4, 82, 103, 89, 55,
  150. 235, 78, 151, 178, 249, 62, 183, 84, 105, 0, 121, 98,
  151. 249, 90, 161, 114, 121, 241, 21, 199, 196, 119, 231, 209,
  152. 250, 180, 192, 213, 116, 105, 114, 169, 1, 142, 3, 30,
  153. 140, 245, 201, 109, 19, 26, 224, 68, 123, 228, 64, 150,
  154. 184, 212, 136, 172, 241, 152, 222, 233, 15, 72, 130, 144,
  155. 107, 130, 242, 79, 195, 46, 226, 57, 183, 36, 88, 161,
  156. 121, 170, 2, 215, 109, 212, 35, 18, 76, 197, 117, 81,
  157. 208, 8, 237, 75, 15, 20, 16, 192, 61, 113, 96, 126,
  158. 211, 57, 49, 62, 185, 211, 155, 87, 233, 163, 164, 84,
  159. 61, 28, 1, 11, 190, 253, 145, 30, 38, 98, 153, 56,
  160. 231, 152, 12, 204, 96, 8, 47, 87, 25, 237, 21, 150,
  161. 173, 19, 41, 175, 164, 231, 39, 145, 39, 187, 210, 123,
  162. 165, 98, 87, 242, 38, 136, 182, 145, 41, 47, 147, 171,
  163. 172, 35, 170, 148, 26, 89, 107, 151, 130, 232, 65, 217,
  164. 27, 206, 68, 219, 60, 106, 3, 209, 175, 189, 191, 32,
  165. 119, 141, 56, 48, 105, 58, 94, 163, 185, 60, 83, 249,
  166. 112, 245, 137, 60, 178, 51, 177, 106, 199, 209, 4, 247,
  167. 3, 127, 88, 46}),
  168. {},
  169. {},
  170. {}},
  171. Testcase{
  172. {},
  173. {},
  174. {},
  175. TensorValue(
  176. {1, 16, 2, 2, 4}, dtype::QuantizedS8(10.f),
  177. {122, -1, -8, 4, 92, -13, -5, 7, 99, 4, 5, 3,
  178. 89, 7, 2, -6, 3, -8, -10, 2, -1, 0, 4, -3,
  179. -5, -8, -11, 1, 14, 4, -10, -18, 3, 12, -14, -2,
  180. -4, -9, 12, 4, -2, -2, 2, 6, -9, 6, 1, 5,
  181. -5, -1, 2, -12, 4, -5, -0, 4, 1, 5, -8, 5,
  182. -3, 4, 2, 6, -0, 9, -4, -7, -4, -5, -2, 8,
  183. 2, 4, 0, 7, -8, 4, -2, 3, -6, -5, 19, 5,
  184. -4, -4, -5, -16, -8, -3, -5, 19, 4, 3, 4, -6,
  185. 1, -12, -1, 7, 11, -5, -1, -8, 2, -12, -9, -2,
  186. -4, -20, -11, -15, -15, -9, -2, -9, -2, -3, 13, 2,
  187. 5, 6, 7, -4, 1, -7, 6, 4, 2, 6, 0, -0,
  188. 8, 8, -6, 5, 1, -2, -2, -12, 2, -12, -2, 6,
  189. 7, 3, 4, 14, 14, -3, 1, -3, 6, 0, -20, 2,
  190. -10, 10, -5, -5, 13, 0, -3, 7, -12, -17, -13, 1,
  191. -6, 10, -1, -9, 4, -16, 3, 2, 5, 1, -4, 9,
  192. -0, 1, 3, 15, -4, -13, -6, 4, 3, -2, -1, -4,
  193. -7, -7, -2, 8, -16, -4, -10, 5, 1, -3, 2, -9,
  194. -4, 1, -1, -1, -4, -6, -4, 1, 0, -9, 15, -1,
  195. -7, -3, -5, -0, 3, -0, -6, -17, 16, -3, 3, -2,
  196. -3, 5, 3, -2, 3, 13, 8, 1, -3, -8, -7, -4,
  197. 6, -6, -15, -7, 0, 4, -3, -3, -10, 14, 1, 3,
  198. 14, 4, -1, 14})});
  199. }
  200. TEST_F(NAIVE, DCT_INT8_MASK) {
  201. Checker<DctChannelSelectForward> checker(
  202. handle(),
  203. /* check_dispatch */ false);
  204. DctChannelSelectForward::Param param;
  205. param.format = DctChannelSelectForward::Param::Format::NCHW4;
  206. auto src_tensor = TensorValue(
  207. {1, 3, 8, 16}, dtype::Uint8(),
  208. {195, 165, 82, 30, 154, 60, 175, 195, 179, 165, 132, 37, 250, 107, 36,
  209. 80, 5, 54, 247, 218, 191, 211, 239, 76, 140, 33, 253, 85, 132, 101,
  210. 105, 177, 46, 183, 102, 99, 19, 175, 108, 252, 42, 238, 48, 251, 108,
  211. 90, 176, 2, 35, 46, 161, 252, 38, 225, 195, 174, 58, 165, 198, 249,
  212. 162, 118, 198, 41, 154, 10, 87, 24, 201, 12, 188, 1, 93, 179, 246,
  213. 134, 18, 178, 173, 36, 122, 89, 115, 46, 43, 205, 232, 55, 149, 30,
  214. 206, 97, 186, 125, 35, 209, 51, 48, 222, 222, 130, 173, 63, 0, 223,
  215. 19, 5, 162, 154, 143, 134, 63, 123, 102, 102, 212, 145, 80, 87, 212,
  216. 42, 26, 219, 225, 120, 94, 213, 238,
  217. 25, 172, 141, 45, 182, 203, 50, 94, 44, 88, 74, 76, 151, 105, 138,
  218. 87, 125, 55, 60, 211, 15, 158, 198, 37, 54, 203, 239, 79, 56, 6,
  219. 53, 201, 97, 233, 178, 74, 193, 46, 249, 65, 5, 208, 130, 67, 191,
  220. 168, 152, 129, 253, 195, 231, 3, 109, 229, 254, 193, 229, 202, 108, 22,
  221. 89, 251, 13, 53, 47, 192, 12, 81, 19, 53, 93, 104, 41, 217, 215,
  222. 184, 136, 249, 14, 244, 4, 220, 33, 53, 142, 219, 43, 28, 68, 198,
  223. 202, 88, 235, 7, 233, 47, 84, 127, 28, 17, 189, 135, 183, 192, 239,
  224. 116, 31, 118, 186, 49, 251, 233, 220, 27, 97, 30, 43, 193, 217, 48,
  225. 24, 225, 15, 3, 26, 71, 82, 104,
  226. 175, 125, 79, 195, 50, 236, 114, 179, 180, 177, 230, 173, 43, 195, 123,
  227. 111, 106, 5, 91, 254, 34, 76, 52, 82, 193, 179, 185, 71, 57, 215,
  228. 18, 5, 151, 13, 59, 206, 154, 95, 149, 40, 229, 16, 116, 144, 249,
  229. 67, 97, 223, 208, 144, 92, 174, 246, 77, 196, 211, 20, 123, 239, 250,
  230. 235, 65, 184, 54, 239, 168, 135, 17, 79, 117, 171, 173, 109, 39, 57,
  231. 13, 129, 79, 236, 117, 134, 123, 149, 113, 198, 160, 249, 242, 220, 226,
  232. 44, 113, 164, 217, 46, 249, 182, 22, 98, 228, 49, 78, 101, 236, 181,
  233. 5, 245, 72, 62, 182, 151, 210, 254, 190, 35, 73, 190, 247, 50, 81,
  234. 49, 217, 86, 229, 139, 203, 57, 194});
  235. checker.set_param(param).exect(
  236. Testcase{
  237. src_tensor,
  238. TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
  239. TensorValue(
  240. {32}, dtype::Int32(),
  241. {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
  242. 0, 1, 8, 16, 9, 2, 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
  243. {}},
  244. Testcase{
  245. {},
  246. {},
  247. {},
  248. TensorValue(
  249. {1, 8, 1, 2, 4}, dtype::QuantizedS8(10.f),
  250. {100, -12, 7, 7, 104, 2, -2, -2, -7, -7, -3, 8, 12,
  251. -12, -5, -1, 5, -7, -1, 7, -7, -3, 6, 7, -0, -2,
  252. -7, 11, 6, 3, -1, 7, 94, -5, 6, -5, 98, 0, -3,
  253. -16, 5, 7, 13, -8, 1, 5, -5, -8, 108, -3, -8, -7,
  254. 110, 1, -2, 5, -0, 7, 8, -9, 14, -0, 1, -4})});
  255. checker.set_param(param).exect(
  256. Testcase{
  257. TensorValue(
  258. {1, 3, 8, 16}, dtype::Uint8(),
  259. {195, 165, 82, 30, 154, 60, 175, 195, 179, 165, 132, 37,
  260. 250, 107, 36, 80, 5, 54, 247, 218, 191, 211, 239, 76,
  261. 140, 33, 253, 85, 132, 101, 105, 177, 46, 183, 102, 99,
  262. 19, 175, 108, 252, 42, 238, 48, 251, 108, 90, 176, 2,
  263. 35, 46, 161, 252, 38, 225, 195, 174, 58, 165, 198, 249,
  264. 162, 118, 198, 41, 154, 10, 87, 24, 201, 12, 188, 1,
  265. 93, 179, 246, 134, 18, 178, 173, 36, 122, 89, 115, 46,
  266. 43, 205, 232, 55, 149, 30, 206, 97, 186, 125, 35, 209,
  267. 51, 48, 222, 222, 130, 173, 63, 0, 223, 19, 5, 162,
  268. 154, 143, 134, 63, 123, 102, 102, 212, 145, 80, 87, 212,
  269. 42, 26, 219, 225, 120, 94, 213, 238,
  270. 25, 172, 141, 45, 182, 203, 50, 94, 44, 88, 74, 76,
  271. 151, 105, 138, 87, 125, 55, 60, 211, 15, 158, 198, 37,
  272. 54, 203, 239, 79, 56, 6, 53, 201, 97, 233, 178, 74,
  273. 193, 46, 249, 65, 5, 208, 130, 67, 191, 168, 152, 129,
  274. 253, 195, 231, 3, 109, 229, 254, 193, 229, 202, 108, 22,
  275. 89, 251, 13, 53, 47, 192, 12, 81, 19, 53, 93, 104,
  276. 41, 217, 215, 184, 136, 249, 14, 244, 4, 220, 33, 53,
  277. 142, 219, 43, 28, 68, 198, 202, 88, 235, 7, 233, 47,
  278. 84, 127, 28, 17, 189, 135, 183, 192, 239, 116, 31, 118,
  279. 186, 49, 251, 233, 220, 27, 97, 30, 43, 193, 217, 48,
  280. 24, 225, 15, 3, 26, 71, 82, 104,
  281. 175, 125, 79, 195, 50, 236, 114, 179, 180, 177, 230, 173,
  282. 43, 195, 123, 111, 106, 5, 91, 254, 34, 76, 52, 82,
  283. 193, 179, 185, 71, 57, 215, 18, 5, 151, 13, 59, 206,
  284. 154, 95, 149, 40, 229, 16, 116, 144, 249, 67, 97, 223,
  285. 208, 144, 92, 174, 246, 77, 196, 211, 20, 123, 239, 250,
  286. 235, 65, 184, 54, 239, 168, 135, 17, 79, 117, 171, 173,
  287. 109, 39, 57, 13, 129, 79, 236, 117, 134, 123, 149, 113,
  288. 198, 160, 249, 242, 220, 226, 44, 113, 164, 217, 46, 249,
  289. 182, 22, 98, 228, 49, 78, 101, 236, 181, 5, 245, 72,
  290. 62, 182, 151, 210, 254, 190, 35, 73, 190, 247, 50, 81,
  291. 49, 217, 86, 229, 139, 203, 57, 194}),
  292. TensorValue({4}, dtype::Int32(), {0, 12, 20, 28}),
  293. TensorValue(
  294. {28}, dtype::Int32(),
  295. {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 0, 1,
  296. 8, 16, 9, 2, 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
  297. {}},
  298. Testcase{
  299. {},
  300. {},
  301. {},
  302. TensorValue(
  303. {1, 7, 1, 2, 4}, dtype::QuantizedS8(10.f),
  304. {100, -12, 7, 7, 104, 2, -2, -2, -7, -7, -3, 8,
  305. 12, -12, -5, -1, 5, -7, -1, 7, -7, -3, 6, 7,
  306. 94, -5, 6, -5, 98, 0, -3, -16, 5, 7, 13, -8,
  307. 1, 5, -5, -8, 108, -3, -8, -7, 110, 1, -2, 5,
  308. -0, 7, 8, -9, 14, -0, 1, -4})});
  309. }
  310. TEST_F(NAIVE, DCT_4x4) {
  311. Checker<DctChannelSelectForward> checker(
  312. handle(),
  313. /* check_dispatch */ false);
  314. DctChannelSelectForward::Param param;
  315. param.dct_block_size = 4;
  316. checker.set_param(param).exect(
  317. Testcase{
  318. TensorValue(
  319. {1, 1, 8, 8}, dtype::Uint8(),
  320. {186, 120, 112, 220, 69, 80, 201, 127, 246, 254, 175,
  321. 50, 240, 251, 76, 37, 34, 166, 250, 195, 231, 139,
  322. 128, 233, 75, 80, 3, 2, 19, 140, 193, 203, 115,
  323. 107, 250, 209, 14, 243, 199, 60, 234, 107, 174, 156,
  324. 81, 87, 13, 116, 96, 140, 197, 253, 113, 223, 229,
  325. 159, 249, 252, 89, 84, 45, 16, 41, 72}),
  326. {},
  327. {},
  328. {}},
  329. Testcase{
  330. {},
  331. {},
  332. {},
  333. TensorValue(
  334. {1, 16, 2, 2}, dtype::Float32(),
  335. {5.42000000e+02, 5.91750000e+02, 6.78000000e+02,
  336. 4.27750000e+02, 3.49953423e+01, -1.17686939e+01,
  337. -1.66842098e+01, -3.85316620e+01, -3.80000000e+01,
  338. -1.22500000e+01, 2.00000000e+01, -9.77500000e+01,
  339. -1.61191311e+01, -9.46695328e+00, 3.28882408e+01,
  340. -4.92537880e+01, 1.66958221e+02, -4.26609573e+01,
  341. 2.56999969e-01, 5.39384537e+01, 1.71819706e+01,
  342. 9.00009003e+01, -1.23818558e+02, 1.18912420e+01,
  343. 6.61014938e+01, -2.49261990e+01, 4.95798302e+00,
  344. -1.02324417e+02, 7.85859919e+00, 3.73140755e+01,
  345. 1.03783745e+02, -4.61430321e+01, -1.43000000e+02,
  346. -7.57500000e+01, -5.00000000e-01, -8.27500000e+01,
  347. 1.34834738e+01, -1.93409515e+02, 6.84791718e+01,
  348. -4.01652241e+00, 1.22000000e+02, -8.57500000e+01,
  349. -4.05000000e+01, -5.62500000e+01, -2.88564739e+01,
  350. 5.76532059e+01, -2.67414131e+01, 1.70877876e+01,
  351. 3.85416756e+01, 3.09300461e+01, 5.84670639e+00,
  352. 1.85747864e+02, -2.05141403e+02, -9.91859360e+01,
  353. -1.66716263e+02, -1.71430378e+01, 6.71520996e+00,
  354. 8.41980438e+01, -3.50666313e+01, -1.48387482e+02,
  355. 1.08180256e+01, 5.49991112e+01, -1.06814528e+01,
  356. 1.86087704e+01})});
  357. checker.set_param(param).exect(
  358. Testcase{
  359. TensorValue(
  360. {1, 1, 8, 8}, dtype::Uint8(),
  361. {186, 120, 112, 220, 69, 80, 201, 127, 246, 254, 175,
  362. 50, 240, 251, 76, 37, 34, 166, 250, 195, 231, 139,
  363. 128, 233, 75, 80, 3, 2, 19, 140, 193, 203, 115,
  364. 107, 250, 209, 14, 243, 199, 60, 234, 107, 174, 156,
  365. 81, 87, 13, 116, 96, 140, 197, 253, 113, 223, 229,
  366. 159, 249, 252, 89, 84, 45, 16, 41, 72}),
  367. TensorValue({2}, dtype::Int32(), {0, 6}),
  368. TensorValue({6}, dtype::Int32(), {0, 1, 8, 4, 2, 3}),
  369. {}},
  370. Testcase{
  371. {},
  372. {},
  373. {},
  374. TensorValue(
  375. {1, 6, 2, 2}, dtype::Float32(),
  376. {5.4200000e+02, 5.9175000e+02, 6.7800000e+02,
  377. 4.2775000e+02, 3.4995342e+01, -1.1768694e+01,
  378. -1.6684210e+01, -3.8531662e+01, -1.4300000e+02,
  379. -7.5750000e+01, -5.0000000e-01, -8.2750000e+01,
  380. 1.6695822e+02, -4.2660957e+01, 2.5699997e-01,
  381. 5.3938454e+01, -3.8000000e+01, -1.2250000e+01,
  382. 2.0000000e+01, -9.7750000e+01, -1.6119131e+01,
  383. -9.4669533e+00, 3.2888241e+01, -4.9253788e+01})});
  384. }
  385. TEST_F(NAIVE, DCT_WITH_MASK) {
  386. Checker<DctChannelSelectForward> checker(
  387. handle(),
  388. /* check_dispatch */ false);
  389. DctChannelSelectForward::Param param;
  390. checker.set_param(param).exect(
  391. Testcase{
  392. TensorValue(
  393. {1, 3, 8, 16}, dtype::Uint8(),
  394. {109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  395. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  396. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  397. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  398. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  399. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  400. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  401. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  402. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  403. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  404. 238, 181, 232, 191, 161, 57, 23, 204,
  405. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  406. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  407. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  408. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  409. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  410. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  411. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  412. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  413. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  414. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  415. 238, 181, 232, 191, 161, 57, 23, 204,
  416. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  417. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  418. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  419. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  420. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  421. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  422. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  423. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  424. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  425. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  426. 238, 181, 232, 191, 161, 57, 23, 204}),
  427. TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
  428. TensorValue(
  429. {32}, dtype::Int32(),
  430. {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
  431. 0, 1, 8, 16, 9, 2, 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
  432. {}},
  433. Testcase{
  434. {},
  435. {},
  436. {},
  437. TensorValue(
  438. {1, 32, 1, 2}, dtype::Float32(),
  439. {890.12494, 941.25, -7.0498576, 99.47632,
  440. -22.850792, -97.862236, -101.043236, -4.727012,
  441. 28.275675, -157.96654, 42.1377, 45.06531,
  442. -149.77373, 24.487143, -8.054966, -13.990831,
  443. -6.9395194, -3.9211385, 64.79172, -12.363858,
  444. -47.875, 59., 56.271786, -62.725567,
  445. 120.522675, 16.559765, 85.74334, 112.904495,
  446. 99.375, 29.499973, 2.0220923, -19.681704,
  447. 890.12494, 941.25, -7.0498576, 99.47632,
  448. -22.850792, -97.862236, -101.043236, -4.727012,
  449. 28.275675, -157.96654, 42.1377, 45.06531,
  450. -149.77373, 24.487143, -8.054966, -13.990831,
  451. 890.12494, 941.25, -7.0498576, 99.47632,
  452. -22.850792, -97.862236, -101.043236, -4.727012,
  453. 28.275675, -157.96654, 42.1377, 45.06531,
  454. -149.77373, 24.487143, -8.054966, -13.990831})});
  455. checker.set_param(param).exect(
  456. Testcase{
  457. TensorValue(
  458. {1, 3, 8, 16}, dtype::Uint8(),
  459. {109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  460. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  461. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  462. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  463. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  464. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  465. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  466. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  467. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  468. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  469. 238, 181, 232, 191, 161, 57, 23, 204,
  470. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  471. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  472. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  473. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  474. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  475. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  476. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  477. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  478. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  479. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  480. 238, 181, 232, 191, 161, 57, 23, 204,
  481. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  482. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  483. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  484. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  485. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  486. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  487. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  488. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  489. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  490. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  491. 238, 181, 232, 191, 161, 57, 23, 204}),
  492. TensorValue({4}, dtype::Int32(), {0, 8, 16, 24}),
  493. TensorValue({24}, dtype::Int32(), {17, 24, 32, 25, 18, 11, 4, 5,
  494. 0, 1, 8, 16, 9, 2, 3, 10,
  495. 0, 1, 8, 16, 9, 2, 3, 10}),
  496. {}},
  497. Testcase{
  498. {},
  499. {},
  500. {},
  501. TensorValue(
  502. {1, 24, 1, 2}, dtype::Float32(),
  503. {-6.9395194, -3.9211385, 64.79172, -12.363858,
  504. -47.875, 59., 56.271786, -62.725567,
  505. 120.522675, 16.559765, 85.74334, 112.904495,
  506. 99.375, 29.499973, 2.0220923, -19.681704,
  507. 890.12494, 941.25, -7.0498576, 99.47632,
  508. -22.850792, -97.862236, -101.043236, -4.727012,
  509. 28.275675, -157.96654, 42.1377, 45.06531,
  510. -149.77373, 24.487143, -8.054966, -13.990831,
  511. 890.12494, 941.25, -7.0498576, 99.47632,
  512. -22.850792, -97.862236, -101.043236, -4.727012,
  513. 28.275675, -157.96654, 42.1377, 45.06531,
  514. -149.77373, 24.487143, -8.054966, -13.990831})});
  515. }
  516. TEST_F(NAIVE, DCT_WITH_FIX_32_MASK) {
  517. Checker<DctChannelSelectForward> checker(
  518. handle(),
  519. /* check_dispatch */ false);
  520. using Param = DctChannelSelectForward::Param;
  521. Param param;
  522. param.fastImpl = Param::FastImpl::FIX_32_MASK;
  523. checker.set_param(param).exect(
  524. Testcase{
  525. TensorValue(
  526. {1, 3, 8, 16}, dtype::Uint8(),
  527. {109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  528. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  529. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  530. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  531. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  532. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  533. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  534. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  535. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  536. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  537. 238, 181, 232, 191, 161, 57, 23, 204,
  538. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  539. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  540. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  541. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  542. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  543. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  544. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  545. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  546. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  547. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  548. 238, 181, 232, 191, 161, 57, 23, 204,
  549. 109, 39, 30, 115, 71, 15, 206, 139, 221, 5, 18, 16,
  550. 93, 185, 99, 102, 205, 172, 191, 29, 185, 6, 47, 84,
  551. 0, 47, 105, 203, 251, 73, 196, 83, 3, 211, 32, 181,
  552. 49, 111, 114, 83, 148, 232, 77, 17, 35, 2, 154, 100,
  553. 41, 135, 141, 206, 56, 91, 137, 199, 104, 192, 75, 122,
  554. 78, 65, 184, 69, 91, 82, 2, 172, 194, 240, 49, 145,
  555. 87, 210, 97, 190, 179, 93, 125, 105, 181, 207, 148, 178,
  556. 133, 53, 25, 198, 238, 151, 14, 120, 213, 195, 145, 20,
  557. 122, 107, 217, 185, 65, 5, 115, 110, 82, 206, 163, 86,
  558. 2, 2, 44, 125, 50, 38, 41, 106, 30, 5, 151, 243,
  559. 238, 181, 232, 191, 161, 57, 23, 204}),
  560. TensorValue({4}, dtype::Int32(), {0, 16, 24, 32}),
  561. TensorValue(
  562. {32}, dtype::Int32(),
  563. {0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5,
  564. 0, 1, 8, 16, 9, 2, 3, 10, 0, 1, 8, 16, 9, 2, 3, 10}),
  565. {}},
  566. Testcase{
  567. {},
  568. {},
  569. {},
  570. TensorValue(
  571. {1, 32, 1, 2}, dtype::Float32(),
  572. {890.12494, 941.25, -7.0498576, 99.47632,
  573. -22.850792, -97.862236, -101.043236, -4.727012,
  574. 28.275675, -157.96654, 42.1377, 45.06531,
  575. -149.77373, 24.487143, -8.054966, -13.990831,
  576. -6.9395194, -3.9211385, 64.79172, -12.363858,
  577. -47.875, 59., 56.271786, -62.725567,
  578. 120.522675, 16.559765, 85.74334, 112.904495,
  579. 99.375, 29.499973, 2.0220923, -19.681704,
  580. 890.12494, 941.25, -7.0498576, 99.47632,
  581. -22.850792, -97.862236, -101.043236, -4.727012,
  582. 28.275675, -157.96654, 42.1377, 45.06531,
  583. -149.77373, 24.487143, -8.054966, -13.990831,
  584. 890.12494, 941.25, -7.0498576, 99.47632,
  585. -22.850792, -97.862236, -101.043236, -4.727012,
  586. 28.275675, -157.96654, 42.1377, 45.06531,
  587. -149.77373, 24.487143, -8.054966, -13.990831})});
  588. }
  589. TEST_F(NAIVE, DCT_WITH_MASK2) {
  590. Checker<DctChannelSelectForward> checker(handle(), false);
  591. DctChannelSelectForward::Param param;
  592. UniformIntRNG rng_oc(0, 3 * 64);
  593. for (size_t n : {1, 3}) {
  594. for (size_t ic : {1, 3}) {
  595. for (size_t ih : {8, 16, 32, 512, 1024}) {
  596. for (size_t iw : {8, 16, 32, 64, 128, 256, 512, 1024}) {
  597. int random_oc = static_cast<int>(rng_oc.gen_single_val());
  598. int max_oc = ic * 64;
  599. int mask_oc = (random_oc % max_oc) + 1;
  600. auto test_case = gen_dct_case(n, ic, ih, iw, mask_oc, param);
  601. checker.set_param(param).exect(
  602. test_case->testcase_in, test_case->testcase_out);
  603. }
  604. }
  605. }
  606. }
  607. }
  608. } // namespace test
  609. } // namespace megdnn
  610. // vim: syntax=cpp.doxygen