You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mssa-2021-006.patch 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. From 5aab6599e7280d2512a87434c174f13a0a2e7008 Mon Sep 17 00:00:00 2001
  2. From: lzk <liuzhongkai2@huawei.com>
  3. Date: Fri, 21 May 2021 01:25:06 -0700
  4. Subject: [PATCH] array cross the border
  5. ---
  6. .../cpu/nnacl/infer/transpose_infer.c | 70 +++++++++++--------
  7. 1 file changed, 40 insertions(+), 30 deletions(-)
  8. diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c
  9. index 04da736190..b1460bc8be 100644
  10. --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c
  11. +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c
  12. @@ -26,6 +26,45 @@ bool CheckPermTransFormat(const int *perm, const int *perm_transformat, const si
  13. return true;
  14. }
  15. +int SetOutputShape(int perms_num, const TensorC *input, TensorC *output, int *perm, size_t perm_size, int *out_shape) {
  16. + if (perms_num == 4) {
  17. + const int nchw2nhwc[4] = {0, 2, 3, 1};
  18. + const int nhwc2nchw[4] = {0, 3, 1, 2};
  19. + const int trans3d[3] = {0, 2, 1};
  20. + if (input->format_ == Format_NCHW && CheckPermTransFormat(perm, nchw2nhwc, perms_num)) {
  21. + output->format_ = Format_NHWC;
  22. + } else if (input->format_ == Format_NHWC && CheckPermTransFormat(perm, nhwc2nchw, perms_num)) {
  23. + output->format_ = Format_NCHW;
  24. + }
  25. + // though the perm is 4d in default, the input can be a 3d tensor. The op implementation should be adapted to this.
  26. + if (input->shape_size_ == 3) {
  27. + ShapeSet(perm, &perm_size, trans3d, 3);
  28. + }
  29. + }
  30. + // set output shape
  31. + size_t in_shape_size = input->shape_size_;
  32. + output->shape_size_ = in_shape_size;
  33. + if (perm_size == 0) {
  34. + for (size_t i = 0; i < in_shape_size; ++i) {
  35. + out_shape[in_shape_size - i - 1] = input->shape_[i];
  36. + }
  37. + } else if (perm_size != in_shape_size) {
  38. + for (size_t i = 0; i < in_shape_size; ++i) {
  39. + out_shape[i] = input->shape_[i];
  40. + }
  41. + } else {
  42. + output->shape_size_ = perm_size;
  43. + for (size_t i = 0; i < perm_size; ++i) {
  44. + if (perm[i] >= input->shape_size_) {
  45. + return NNACL_ERR;
  46. + } else {
  47. + out_shape[i] = input->shape_[perm[i]];
  48. + }
  49. + }
  50. + }
  51. + return NNACL_OK;
  52. +}
  53. +
  54. int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
  55. OpParameter *parameter) {
  56. #ifdef Debug
  57. @@ -60,38 +99,9 @@ int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
  58. for (size_t i = 0; i < perms_num; i++) {
  59. ShapePush(perm, &perm_size, perm_data[i]);
  60. }
  61. - const int nchw2nhwc[4] = {0, 2, 3, 1};
  62. - const int nhwc2nchw[4] = {0, 3, 1, 2};
  63. - const int trans3d[3] = {0, 2, 1};
  64. - if (perms_num == 4) {
  65. - if (input->format_ == Format_NCHW && CheckPermTransFormat(perm, nchw2nhwc, perms_num)) {
  66. - output->format_ = Format_NHWC;
  67. - } else if (input->format_ == Format_NHWC && CheckPermTransFormat(perm, nhwc2nchw, perms_num)) {
  68. - output->format_ = Format_NCHW;
  69. - }
  70. - // though the perm is 4d in default, the input can be a 3d tensor. The op implementation should be adapted to this.
  71. - if (input->shape_size_ == 3) {
  72. - ShapeSet(perm, &perm_size, trans3d, 3);
  73. - }
  74. - }
  75. // set output shape
  76. int out_shape[MAX_TRANSPOSE_DIM_SIZE] = {0};
  77. - size_t in_shape_size = input->shape_size_;
  78. - output->shape_size_ = in_shape_size;
  79. - if (perm_size == 0) {
  80. - for (size_t i = 0; i < in_shape_size; ++i) {
  81. - out_shape[in_shape_size - i - 1] = input->shape_[i];
  82. - }
  83. - } else if (perm_size != in_shape_size) {
  84. - for (size_t i = 0; i < in_shape_size; ++i) {
  85. - out_shape[i] = input->shape_[i];
  86. - }
  87. - } else {
  88. - output->shape_size_ = perm_size;
  89. - for (size_t i = 0; i < perm_size; ++i) {
  90. - out_shape[i] = input->shape_[perm[i]];
  91. - }
  92. - }
  93. + SetOutputShape(perms_num, input, output, perm, perm_size, out_shape);
  94. SetShapeArray(output, out_shape, output->shape_size_);
  95. return NNACL_OK;
  96. }
  97. --
  98. 2.17.1