You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

deduce_layout_proxy.h 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. /**
  2. * \file dnn/test/common/deduce_layout_proxy.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/basic_types.h"
  13. #include "test/common/utils.h"
  14. namespace megdnn {
  15. namespace test {
  16. template <typename Opr, size_t Arity, bool can_deduce_layout>
  17. struct DeduceLayoutProxy;
  18. template <typename Opr, size_t Arity>
  19. struct DeduceLayoutProxy<Opr, Arity, false> {
  20. static void deduce_layout(Opr*, TensorLayoutArray&) {}
  21. };
  22. template <typename Opr>
  23. struct DeduceLayoutProxy<Opr, 2, true> {
  24. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  25. megdnn_assert(layouts.size() == 2);
  26. opr->deduce_layout(layouts[0], layouts[1]);
  27. }
  28. };
  29. template <typename Opr>
  30. struct DeduceLayoutProxy<Opr, 3, true> {
  31. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  32. megdnn_assert(layouts.size() == 3);
  33. opr->deduce_layout(layouts[0], layouts[1], layouts[2]);
  34. }
  35. };
  36. template <typename Opr>
  37. struct DeduceLayoutProxy<Opr, 4, true> {
  38. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  39. megdnn_assert(layouts.size() == 4);
  40. opr->deduce_layout(layouts[0], layouts[1], layouts[2], layouts[3]);
  41. }
  42. };
  43. template <typename Opr>
  44. struct DeduceLayoutProxy<Opr, 5, true> {
  45. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  46. megdnn_assert(layouts.size() == 5);
  47. opr->deduce_layout(layouts[0], layouts[1], layouts[2], layouts[3], layouts[4]);
  48. }
  49. };
  50. template <typename Opr>
  51. struct DeduceLayoutProxy<Opr, 6, true> {
  52. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  53. megdnn_assert(layouts.size() == 6);
  54. opr->deduce_layout(
  55. layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5]);
  56. }
  57. };
  58. template <typename Opr>
  59. struct DeduceLayoutProxy<Opr, 5, false> {
  60. static void deduce_layout(Opr*, TensorLayoutArray&) {}
  61. };
  62. template <typename Opr>
  63. struct DeduceLayoutProxy<Opr, 6, false> {
  64. static void deduce_layout(Opr*, TensorLayoutArray&) {}
  65. };
  66. template <typename Opr>
  67. struct DeduceLayoutProxy<Opr, 6, true> {
  68. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  69. megdnn_assert(layouts.size() == 6);
  70. opr->deduce_layout(
  71. layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5]);
  72. }
  73. };
  74. template <typename Opr>
  75. struct DeduceLayoutProxy<Opr, 7, false> {
  76. static void deduce_layout(Opr*, TensorLayoutArray&) {}
  77. };
  78. template <typename Opr>
  79. struct DeduceLayoutProxy<Opr, 8, true> {
  80. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  81. megdnn_assert(layouts.size() == 8);
  82. opr->deduce_layout(
  83. layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5],
  84. layouts[6], layouts[7]);
  85. }
  86. };
  87. template <typename Opr>
  88. struct DeduceLayoutProxy<Opr, 9, true> {
  89. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  90. megdnn_assert(layouts.size() == 9);
  91. opr->deduce_layout(
  92. layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5],
  93. layouts[6], layouts[7], layouts[8]);
  94. }
  95. };
  96. } // namespace test
  97. } // namespace megdnn
  98. // vim: syntax=cpp.doxygen