You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

acl_cblas.h 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. /**
  2. * @file acl_cblas.h
  3. *
  4. * Copyright (C) Huawei Technologies Co., Ltd. 2019-2020. All Rights Reserved.
  5. *
  6. * This program is distributed in the hope that it will be useful,
  7. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  8. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  9. */
  10. #ifndef INC_EXTERNAL_ACL_OPS_ACL_CBLAS_H_
  11. #define INC_EXTERNAL_ACL_OPS_ACL_CBLAS_H_
  12. #include "../acl.h"
  13. #ifdef __cplusplus
  14. extern "C" {
  15. #endif
  16. typedef enum aclTransType {
  17. ACL_TRANS_N,
  18. ACL_TRANS_T,
  19. ACL_TRANS_NZ,
  20. ACL_TRANS_NZ_T
  21. } aclTransType;
  22. typedef enum aclComputeType {
  23. ACL_COMPUTE_HIGH_PRECISION,
  24. ACL_COMPUTE_LOW_PRECISION
  25. } aclComputeType;
  26. /**
  27. * @ingroup AscendCL
  28. * @brief perform the matrix-vector multiplication
  29. *
  30. * @param transA [IN] transpose type of matrix A
  31. * @param m [IN] number of rows of matrix A
  32. * @param n [IN] number of columns of matrix A
  33. * @param alpha [IN] pointer to scalar used for multiplication.
  34. * of same type as dataTypeC
  35. * @param a [IN] pointer to matrix A
  36. * @param lda [IN] leading dimension used to store the matrix A
  37. * @param dataTypeA [IN] datatype of matrix A
  38. * @param x [IN] pointer to vector x
  39. * @param incx [IN] stride between consecutive elements of vector x
  40. * @param dataTypeX [IN] datatype of vector x
  41. * @param beta [IN] pointer to scalar used for multiplication.
  42. * of same type as dataTypeC If beta == 0,
  43. * then y does not have to be a valid input
  44. * @param y [IN|OUT] pointer to vector y
  45. * @param incy [IN] stride between consecutive elements of vector y
  46. * @param dataTypeY [IN] datatype of vector y
  47. * @param type [IN] computation type
  48. * @param stream [IN] stream
  49. * @retval ACL_ERROR_NONE The function is successfully executed.
  50. * @retval OtherValues Failure
  51. */
  52. ACL_FUNC_VISIBILITY aclError aclblasGemvEx(aclTransType transA, int m, int n,
  53. const void *alpha, const void *a, int lda, aclDataType dataTypeA,
  54. const void *x, int incx, aclDataType dataTypeX,
  55. const void *beta, void *y, int incy, aclDataType dataTypeY,
  56. aclComputeType type, aclrtStream stream);
  57. /**
  58. * @ingroup AscendCL
  59. * @brief create a handle for performing the matrix-vector multiplication
  60. *
  61. * @param transA [IN] transpose type of matrix A
  62. * @param m [IN] number of rows of matrix A
  63. * @param n [IN] number of columns of matrix A
  64. * @param dataTypeA [IN] datatype of matrix A
  65. * @param dataTypeX [IN] datatype of vector x
  66. * @param dataTypeY [IN] datatype of vector y
  67. * @param type [IN] computation type
  68. * @param handle [OUT] pointer to the pointer to the handle
  69. * @retval ACL_ERROR_NONE The function is successfully executed.
  70. * @retval OtherValues Failure
  71. */
  72. ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForGemvEx(aclTransType transA,
  73. int m,
  74. int n,
  75. aclDataType dataTypeA,
  76. aclDataType dataTypeX,
  77. aclDataType dataTypeY,
  78. aclComputeType type,
  79. aclopHandle **handle);
  80. /**
  81. * @ingroup AscendCL
  82. * @brief perform the matrix-vector multiplication
  83. *
  84. * @param transA [IN] transpose type of matrix A
  85. * @param m [IN] number of rows of matrix A
  86. * @param n [IN] number of columns of matrix A
  87. * @param alpha [IN] pointer to scalar used for multiplication
  88. * @param a [IN] pointer to matrix A
  89. * @param lda [IN] leading dimension used to store the matrix A
  90. * @param x [IN] pointer to vector x
  91. * @param incx [IN] stride between consecutive elements of vector x
  92. * @param beta [IN] pointer to scalar used for multiplication.
  93. * If beta value == 0,
  94. * then y does not have to be a valid input
  95. * @param y [IN|OUT] pointer to vector y
  96. * @param incy [IN] stride between consecutive elements of vector y
  97. * @param type [IN] computation type
  98. * @param stream [IN] stream
  99. * @retval ACL_ERROR_NONE The function is successfully executed.
  100. * @retval OtherValues Failure
  101. */
  102. ACL_FUNC_VISIBILITY aclError aclblasHgemv(aclTransType transA,
  103. int m,
  104. int n,
  105. const aclFloat16 *alpha,
  106. const aclFloat16 *a,
  107. int lda,
  108. const aclFloat16 *x,
  109. int incx,
  110. const aclFloat16 *beta,
  111. aclFloat16 *y,
  112. int incy,
  113. aclComputeType type,
  114. aclrtStream stream);
  115. /**
  116. * @ingroup AscendCL
  117. * @brief create a handle for performing the matrix-vector multiplication
  118. *
  119. * @param transA [IN] transpose type of matrix A
  120. * @param m [IN] number of rows of matrix A
  121. * @param n [IN] number of columns of matrix A
  122. * @param type [IN] computation type
  123. * @param handle [OUT] pointer to the pointer to the handle
  124. * @retval ACL_ERROR_NONE The function is successfully executed.
  125. * @retval OtherValues Failure
  126. */
  127. ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForHgemv(aclTransType transA,
  128. int m,
  129. int n,
  130. aclComputeType type,
  131. aclopHandle **handle);
  132. /**
  133. * @ingroup AscendCL
  134. * @brief perform the matrix-vector multiplication
  135. *
  136. * @param transA [IN] transpose type of matrix A
  137. * @param m [IN] number of rows of matrix A
  138. * @param n [IN] number of columns of matrix A
  139. * @param alpha [IN] pointer to scalar used for multiplication
  140. * @param a [IN] pointer to matrix A
  141. * @param lda [IN] leading dimension used to store the matrix A
  142. * @param x [IN] pointer to vector x
  143. * @param incx [IN] stride between consecutive elements of vector x
  144. * @param beta [IN] pointer to scalar used for multiplication.
  145. * If beta value == 0,
  146. * then y does not have to be a valid input
  147. * @param y [IN|OUT] pointer to vector y
  148. * @param incy [IN] stride between consecutive elements of vector y
  149. * @param type [IN] computation type
  150. * @param stream [IN] stream
  151. * @retval ACL_ERROR_NONE The function is successfully executed.
  152. * @retval OtherValues Failure
  153. */
  154. ACL_FUNC_VISIBILITY aclError aclblasS8gemv(aclTransType transA,
  155. int m,
  156. int n,
  157. const int32_t *alpha,
  158. const int8_t *a,
  159. int lda,
  160. const int8_t *x,
  161. int incx,
  162. const int32_t *beta,
  163. int32_t *y,
  164. int incy,
  165. aclComputeType type,
  166. aclrtStream stream);
  167. /**
  168. * @ingroup AscendCL
  169. * @brief create a handle for performing the matrix-vector multiplication
  170. *
  171. * @param transA [IN] transpose type of matrix A
  172. * @param m [IN] number of rows of matrix A
  173. * @param n [IN] number of columns of matrix A
  174. * @param handle [OUT] pointer to the pointer to the handle
  175. * @param type [IN] computation type
  176. * @retval ACL_ERROR_NONE The function is successfully executed.
  177. * @retval OtherValues Failure
  178. */
  179. ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForS8gemv(aclTransType transA,
  180. int m,
  181. int n,
  182. aclComputeType type,
  183. aclopHandle **handle);
  184. /**
  185. * @ingroup AscendCL
  186. * @brief perform the matrix-matrix multiplication
  187. *
  188. * @param transA [IN] transpose type of matrix A
  189. * @param transB [IN] transpose type of matrix B
  190. * @param transC [IN] transpose type of matrix C
  191. * @param m [IN] number of rows of matrix A and matrix C
  192. * @param n [IN] number of columns of matrix B and matrix C
  193. * @param k [IN] number of columns of matrix A and rows of matrix B
  194. * @param alpha [IN] pointer to scalar used for multiplication. of same type as dataTypeC
  195. * @param matrixA [IN] pointer to matrix A
  196. * @param lda [IN] leading dimension array used to store matrix A
  197. * @param dataTypeA [IN] datatype of matrix A
  198. * @param matrixB [IN] pointer to matrix B
  199. * @param ldb [IN] leading dimension array used to store matrix B
  200. * @param dataTypeB [IN] datatype of matrix B
  201. * @param beta [IN] pointer to scalar used for multiplication.
  202. * of same type as dataTypeC If beta == 0,
  203. * then matrixC does not have to be a valid input
  204. * @param matrixC [IN|OUT] pointer to matrix C
  205. * @param ldc [IN] leading dimension array used to store matrix C
  206. * @param dataTypeC [IN] datatype of matrix C
  207. * @param type [IN] computation type
  208. * @param stream [IN] stream
  209. * @retval ACL_ERROR_NONE The function is successfully executed.
  210. * @retval OtherValues Failure
  211. */
  212. ACL_FUNC_VISIBILITY aclError aclblasGemmEx(aclTransType transA,
  213. aclTransType transB,
  214. aclTransType transC,
  215. int m,
  216. int n,
  217. int k,
  218. const void *alpha,
  219. const void *matrixA,
  220. int lda,
  221. aclDataType dataTypeA,
  222. const void *matrixB,
  223. int ldb,
  224. aclDataType dataTypeB,
  225. const void *beta,
  226. void *matrixC,
  227. int ldc,
  228. aclDataType dataTypeC,
  229. aclComputeType type,
  230. aclrtStream stream);
  231. /**
  232. * @ingroup AscendCL
  233. * @brief create a handle for performing the matrix-matrix multiplication
  234. *
  235. * @param transA [IN] transpose type of matrix A
  236. * @param transB [IN] transpose type of matrix B
  237. * @param transC [IN] transpose type of matrix C
  238. * @param m [IN] number of rows of matrix A and matrix C
  239. * @param n [IN] number of columns of matrix B and matrix C
  240. * @param k [IN] number of columns of matrix A and rows of matrix B
  241. * @param dataTypeA [IN] datatype of matrix A
  242. * @param dataTypeB [IN] datatype of matrix B
  243. * @param dataTypeC [IN] datatype of matrix C
  244. * @param type [IN] computation type
  245. * @param handle [OUT] pointer to the pointer to the handle
  246. * @param type [IN] computation type
  247. * @retval ACL_ERROR_NONE The function is successfully executed.
  248. * @retval OtherValues Failure
  249. */
  250. ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForGemmEx(aclTransType transA,
  251. aclTransType transB,
  252. aclTransType transC,
  253. int m,
  254. int n,
  255. int k,
  256. aclDataType dataTypeA,
  257. aclDataType dataTypeB,
  258. aclDataType dataTypeC,
  259. aclComputeType type,
  260. aclopHandle **handle);
  261. /**
  262. * @ingroup AscendCL
  263. * @brief perform the matrix-matrix multiplication
  264. *
  265. * @param transA [IN] transpose type of matrix A
  266. * @param transB [IN] transpose type of matrix B
  267. * @param transC [IN] transpose type of matrix C
  268. * @param m [IN] number of rows of matrix A and matrix C
  269. * @param n [IN] number of columns of matrix B and matrix C
  270. * @param k [IN] number of columns of matrix A and rows of matrix B
  271. * @param alpha [IN] pointer to scalar used for multiplication
  272. * @param matrixA [IN] pointer to matrix A
  273. * @param lda [IN] leading dimension used to store the matrix A
  274. * @param matrixB [IN] pointer to matrix B
  275. * @param ldb [IN] leading dimension used to store the matrix B
  276. * @param beta [IN] pointer to scalar used for multiplication.
  277. * If beta value == 0,
  278. * then matrixC does not have to be a valid input
  279. * @param matrixC [IN|OUT] pointer to matrix C
  280. * @param ldc [IN] leading dimension used to store the matrix C
  281. * @param type [IN] computation type
  282. * @param stream [IN] stream
  283. * @retval ACL_ERROR_NONE The function is successfully executed.
  284. * @retval OtherValues Failure
  285. */
  286. ACL_FUNC_VISIBILITY aclError aclblasHgemm(aclTransType transA,
  287. aclTransType transB,
  288. aclTransType transC,
  289. int m,
  290. int n,
  291. int k,
  292. const aclFloat16 *alpha,
  293. const aclFloat16 *matrixA,
  294. int lda,
  295. const aclFloat16 *matrixB,
  296. int ldb,
  297. const aclFloat16 *beta,
  298. aclFloat16 *matrixC,
  299. int ldc,
  300. aclComputeType type,
  301. aclrtStream stream);
  302. /**
  303. * @ingroup AscendCL
  304. * @brief create a handle for performing the matrix-matrix multiplication
  305. *
  306. * @param transA [IN] transpose type of matrix A
  307. * @param transB [IN] transpose type of matrix B
  308. * @param transC [IN] transpose type of matrix C
  309. * @param m [IN] number of rows of matrix A and matrix C
  310. * @param n [IN] number of columns of matrix B and matrix C
  311. * @param k [IN] number of columns of matrix A and rows of matrix B
  312. * @param type [IN] computation type
  313. * @param handle [OUT] pointer to the pointer to the handle
  314. * @retval ACL_ERROR_NONE The function is successfully executed.
  315. * @retval OtherValues Failure
  316. */
  317. ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForHgemm(aclTransType transA,
  318. aclTransType transB,
  319. aclTransType transC,
  320. int m,
  321. int n,
  322. int k,
  323. aclComputeType type,
  324. aclopHandle **handle);
  325. /**
  326. * @ingroup AscendCL
  327. * @brief perform the matrix-matrix multiplication
  328. *
  329. * @param transA [IN] transpose type of matrix A
  330. * @param transB [IN] transpose type of matrix B
  331. * @param transC [IN] transpose type of matrix C
  332. * @param m [IN] number of rows of matrix A and matrix C
  333. * @param n [IN] number of columns of matrix B and matrix C
  334. * @param k [IN] number of columns of matrix A and rows of matrix B
  335. * @param alpha [IN] pointer to scalar used for multiplication
  336. * @param matrixA [IN] pointer to matrix A
  337. * @param lda [IN] leading dimension used to store the matrix A
  338. * @param matrixB [IN] pointer to matrix B
  339. * @param ldb [IN] leading dimension used to store the matrix B
  340. * @param beta [IN] pointer to scalar used for multiplication.
  341. * If beta value == 0,
  342. * then matrixC does not have to be a valid input
  343. * @param matrixC [IN|OUT] pointer to matrix C
  344. * @param ldc [IN] leading dimension used to store the matrix C
  345. * @param type [IN] computation type
  346. * @param stream [IN] stream
  347. * @retval ACL_ERROR_NONE The function is successfully executed.
  348. * @retval OtherValues Failure
  349. */
  350. ACL_FUNC_VISIBILITY aclError aclblasS8gemm(aclTransType transA,
  351. aclTransType transB,
  352. aclTransType transC,
  353. int m,
  354. int n,
  355. int k,
  356. const int32_t *alpha,
  357. const int8_t *matrixA,
  358. int lda,
  359. const int8_t *matrixB,
  360. int ldb,
  361. const int32_t *beta,
  362. int32_t *matrixC,
  363. int ldc,
  364. aclComputeType type,
  365. aclrtStream stream);
  366. /**
  367. * @ingroup AscendCL
  368. * @brief create a handle for performing the matrix-matrix multiplication
  369. *
  370. * @param transA [IN] transpose type of matrix A
  371. * @param transB [IN] transpose type of matrix B
  372. * @param transC [IN] transpose type of matrix C
  373. * @param m [IN] number of rows of matrix A and matrix C
  374. * @param n [IN] number of columns of matrix B and matrix C
  375. * @param k [IN] number of columns of matrix A and rows of matrix B
  376. * @param type [IN] computation type
  377. * @param handle [OUT] pointer to the pointer to the handle
  378. * @retval ACL_ERROR_NONE The function is successfully executed.
  379. * @retval OtherValues Failure
  380. */
  381. ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForS8gemm(aclTransType transA,
  382. aclTransType transB,
  383. aclTransType transC,
  384. int m,
  385. int n,
  386. int k,
  387. aclComputeType type,
  388. aclopHandle **handle);
  389. #ifdef __cplusplus
  390. }
  391. #endif
  392. #endif // INC_EXTERNAL_ACL_OPS_ACL_CBLAS_H_

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。