You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

indexing_multi_axis_vec.h 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #pragma once
  2. #include "test/common/opr_proxy.h"
  3. namespace megdnn {
  4. namespace test {
  5. struct OprProxyIndexingMultiAxisVecHelper {
  6. size_t axes[TensorLayout::MAX_NDIM];
  7. /*!
  8. * \brief OprProxy for indexing multi-vec family oprs
  9. *
  10. * \param init_axes axes that are indexed
  11. */
  12. OprProxyIndexingMultiAxisVecHelper(std::initializer_list<size_t> init_axes = {}) {
  13. size_t i = 0;
  14. for (auto ax : init_axes)
  15. axes[i++] = ax;
  16. }
  17. OprProxyIndexingMultiAxisVecHelper(SmallVector<size_t> init_axes) {
  18. size_t i = 0;
  19. for (auto ax : init_axes)
  20. axes[i++] = ax;
  21. }
  22. IndexingMultiAxisVec::IndexDesc make_index_desc(
  23. const TensorNDArray& tensors) const {
  24. megdnn_assert(tensors.size() >= 3);
  25. IndexingMultiAxisVec::IndexDesc ret;
  26. ret.resize(tensors.size() - 2);
  27. for (size_t i = 2; i < tensors.size(); ++i) {
  28. ret[i - 2] = {axes[i - 2], tensors[i]};
  29. }
  30. return ret;
  31. }
  32. size_t get_index_ndim(const TensorNDArray& tensors) const {
  33. megdnn_assert(tensors.size() >= 3);
  34. size_t ndim = 0;
  35. for (size_t i = 2; i < tensors.size(); ++i) {
  36. ndim = std::max(tensors[i].layout.ndim, ndim);
  37. }
  38. return ndim;
  39. }
  40. IndexingMultiAxisVec::IndexDescLayoutOnly make_index_layout(
  41. const TensorLayoutArray& layouts) const {
  42. megdnn_assert(layouts.size() >= 3);
  43. IndexingMultiAxisVec::IndexDescLayoutOnly ret;
  44. ret.resize(layouts.size() - 2);
  45. for (size_t i = 2; i < layouts.size(); ++i) {
  46. ret[i - 2] = {axes[i - 2], layouts[i]};
  47. }
  48. return ret;
  49. }
  50. };
  51. template <>
  52. struct OprProxy<IndexingMultiAxisVec> : public OprProxyIndexingMultiAxisVecHelper {
  53. using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper;
  54. void exec(IndexingMultiAxisVec* opr, const TensorNDArray& tensors) const {
  55. WorkspaceWrapper W(
  56. opr->handle(), opr->get_workspace_in_bytes(
  57. tensors[1].layout, axes, tensors.size() - 2,
  58. get_index_ndim(tensors)));
  59. opr->exec(tensors[0], make_index_desc(tensors), tensors[1], W.workspace());
  60. }
  61. void deduce_layout(IndexingMultiAxisVec* opr, TensorLayoutArray& layouts) {
  62. opr->deduce_layout(layouts[0], make_index_layout(layouts), layouts[1]);
  63. }
  64. };
  65. template <>
  66. struct OprProxy<IndexingIncrMultiAxisVec> : public OprProxyIndexingMultiAxisVecHelper {
  67. using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper;
  68. void exec(IndexingIncrMultiAxisVec* opr, const TensorNDArray& tensors) const {
  69. WorkspaceWrapper W(
  70. opr->handle(), opr->get_workspace_in_bytes(
  71. tensors[1].layout, axes, tensors.size() - 2,
  72. get_index_ndim(tensors)));
  73. opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace());
  74. }
  75. void deduce_layout(IndexingIncrMultiAxisVec*, TensorLayoutArray&) {}
  76. };
  77. template <>
  78. struct OprProxy<IndexingSetMultiAxisVec> : public OprProxyIndexingMultiAxisVecHelper {
  79. using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper;
  80. void exec(IndexingSetMultiAxisVec* opr, const TensorNDArray& tensors) const {
  81. WorkspaceWrapper W(
  82. opr->handle(), opr->get_workspace_in_bytes(
  83. tensors[1].layout, axes, tensors.size() - 2,
  84. get_index_ndim(tensors)));
  85. opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace());
  86. }
  87. void deduce_layout(IndexingSetMultiAxisVec*, TensorLayoutArray&) {}
  88. };
  89. } // namespace test
  90. } // namespace megdnn
  91. // vim: syntax=cpp.doxygen