You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conv_bias.h 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. #pragma once
  2. #include "megdnn/basic_types.h"
  3. #include "megdnn/opr_param_defs.h"
  4. #include "test/common/checker.h"
  5. #include "src/fallback/conv_bias/opr_impl.h"
  6. #include <regex>
  7. namespace megdnn {
  8. namespace test {
  9. namespace conv_bias {
  10. struct TestArg {
  11. param::ConvBias param;
  12. TensorShape src, filter, bias;
  13. TestArg(param::ConvBias param, TensorShape src, TensorShape filter,
  14. TensorShape bias)
  15. : param(param), src(src), filter(filter), bias(bias) {}
  16. };
  17. std::vector<TestArg> get_args();
  18. std::vector<TestArg> get_args_1x1();
  19. std::vector<TestArg> get_chanwise_args();
  20. std::vector<TestArg> get_winograd_args(size_t kernel_size);
  21. std::vector<TestArg> get_winograd_mk_packed_args(size_t pack_size = 4);
  22. std::vector<TestArg> get_quantized_winograd_mk_packed_args(
  23. size_t pack_size = 4, bool compute_float32 = false);
  24. std::vector<TestArg> get_quantized_args_with_nlmode(
  25. param::ConvBias::NonlineMode nlmode);
  26. std::vector<TestArg> get_quantized_args();
  27. std::vector<TestArg> get_int8_nchw4_args(size_t kernel_size);
  28. std::vector<TestArg> get_int8_nchw4_args_check_bounds(size_t kernel_size);
  29. std::vector<TestArg> get_int8_nchw4_small_channel_args(size_t kernel_size);
  30. std::vector<TestArg> get_int8_nchw4_small_channel_args_check_bounds(size_t kernel_size);
  31. std::vector<TestArg> get_int8_nchw4_args_small_batch(size_t kernel_size);
  32. std::vector<TestArg> get_int8_chwn4_args(size_t kernel_size);
  33. std::vector<TestArg> get_int8_chwn4_args_check_bounds(size_t kernel_size);
  34. std::vector<TestArg> get_int8_chwn4_small_channel_args(size_t kernel_size);
  35. std::vector<TestArg> get_int8_chwn4_small_channel_args_check_bounds(size_t kernel_size);
  36. std::vector<TestArg> get_int8_chwn4_args_small_batch(size_t kernel_size);
  37. std::vector<TestArg> get_int8_nchw4_tensorcore_args(size_t kernel_size);
  38. std::vector<TestArg> get_int8_chwn4_tensorcore_args(size_t kernel_size);
  39. std::vector<TestArg> get_int8_nchw44_args(
  40. size_t kernel_size, size_t pack_size, bool compute_float32 = false,
  41. bool group_mode = false);
  42. void check_conv_bias_preprocess(
  43. std::vector<conv_bias::TestArg> args, Handle* handle, RNG* rng, float epsilon,
  44. DType type0, DType type1, DType type2, DType type3, const char* algo_name);
  45. template <typename Opr>
  46. using ConvBiasAlgoChecker = AlgoChecker<Opr>;
  47. void check_conv_bias(
  48. DType src_dtype, DType filter_dtype, DType bias_dtype, DType dst_dtype,
  49. Handle* handle, const char* algo = nullptr,
  50. param::ConvBias::Format format = param::ConvBias::Format::NCHW4,
  51. const std::vector<TestArg>& args = {}, bool fuse_z = false,
  52. bool stable_test = false);
  // Benchmark-only helpers: declared only when MEGDNN_WITH_BENCHMARK is
  // defined to a non-zero value (an undefined macro evaluates to 0 in #if).
  #if MEGDNN_WITH_BENCHMARK
  std::vector<TestArg> get_winograd_benchmark_args(
          size_t kernel, size_t pack_size = 1);
  void benchmark_winograd(
          const char* algo_name, Handle* handle, size_t kernel,
          size_t pack_size = 1);
  #endif  // MEGDNN_WITH_BENCHMARK
  60. std::vector<megdnn::test::conv_bias::TestArg> get_conv_bias_args(
  61. std::vector<size_t> kernel, size_t stride, bool no_pad, bool no_bias,
  62. bool no_nonlinemode, bool quantized_nlmod = false,
  63. bool only_broadcast_bias = false);
  64. std::vector<megdnn::test::conv_bias::TestArg> get_conv_bias_1x1_args(
  65. bool no_bias, bool no_nonlinemode, bool quantized_nlmod = false,
  66. bool only_broadcast_bias = false);
  67. void check_conv_bias(
  68. std::vector<megdnn::test::conv_bias::TestArg> args, megdnn::Handle* handle,
  69. const char* algo_name);
  70. void checker_conv_bias_int8x8x16(
  71. std::vector<megdnn::test::conv_bias::TestArg> args, megdnn::Handle* handle,
  72. const char* algo_name);
  73. void checker_conv_bias_common(
  74. std::vector<conv_bias::TestArg> args, Handle* handle, RNG* rng, float epsilon,
  75. DType type0, DType type1, DType type2, DType type3, const char* algo_name);
  76. std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
  77. std::vector<size_t> kernel_vec,
  78. std::vector<param::ConvBias::NonlineMode> nlmode_vec,
  79. std::vector<megdnn::BiasMode> biasmode_vec, size_t stride, bool no_pad = false,
  80. bool is_input_nchw = false, bool is_nchw44_dot = false);
  81. void checker_conv_bias_mul_int8x8x32(
  82. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name);
  83. void checker_conv_bias_int8x8x32_preprocess(
  84. std::vector<conv_bias::TestArg> args, Handle* handle, const char* algo_name);
  // Nonlinearity-mode presets. Each macro expands to a braced initializer
  // list of param::ConvBias::NonlineMode values, so it can initialize a
  // std::vector<param::ConvBias::NonlineMode> directly.
  // (Comments are kept outside the macros: line splicing happens before
  // comment removal, so a // comment on a continued line would swallow the
  // next line.)
  #define FULL_NLMODE \
      { \
          param::ConvBias::NonlineMode::IDENTITY, \
          param::ConvBias::NonlineMode::RELU, \
          param::ConvBias::NonlineMode::H_SWISH, \
          param::ConvBias::NonlineMode::SIGMOID \
      }
  #define QUAN_NLMODE \
      { \
          param::ConvBias::NonlineMode::IDENTITY, \
          param::ConvBias::NonlineMode::RELU, \
          param::ConvBias::NonlineMode::H_SWISH \
      }
  #define ONLY_IDENTITY_NLMODE \
      { \
          param::ConvBias::NonlineMode::IDENTITY \
      }

  // Bias-mode presets, expanding to braced lists of megdnn::BiasMode.
  // NOTE(review): despite its name, BR_AND_BIAS_BIASMODE expands to NO_BIAS
  // and BIAS rather than BROADCAST_CHANNEL_BIAS and BIAS. It is left as-is
  // because changing it would alter which bias modes every caller tests —
  // confirm the intended contents.
  #define ALL_BIASMODE \
      { \
          megdnn::BiasMode::NO_BIAS, \
          megdnn::BiasMode::BROADCAST_CHANNEL_BIAS, \
          megdnn::BiasMode::BIAS \
      }
  #define BR_AND_NO_BIASMODE \
      { \
          megdnn::BiasMode::NO_BIAS, \
          megdnn::BiasMode::BROADCAST_CHANNEL_BIAS \
      }
  #define BR_AND_BIAS_BIASMODE \
      { \
          megdnn::BiasMode::NO_BIAS, \
          megdnn::BiasMode::BIAS \
      }
  #define ONLY_BR_BIASMODE \
      { \
          megdnn::BiasMode::BROADCAST_CHANNEL_BIAS \
      }
  #define ONLY_NO_BIASMODE \
      { \
          megdnn::BiasMode::NO_BIAS \
      }
  #define ONLY_BIAS_BIASMODE \
      { \
          megdnn::BiasMode::BIAS \
      }
  113. } // namespace conv_bias
  114. } // namespace test
  115. } // namespace megdnn
  116. // vim: syntax=cpp.doxygen