You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conv_test_utils.h 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. #pragma once
  2. #include "megdnn/oprs/nn.h"
  3. #include "src/common/utils.h"
  4. #include "src/cuda/cudnn_with_check.h"
  5. #include "test/common/checker.h"
  6. #include "test/common/conv_bias.h"
  7. #include "test/common/tensor.h"
  8. #include "test/common/workspace_wrapper.h"
  9. #include "test/cuda/benchmark.h"
  10. #include "test/cuda/fixture.h"
  11. #include "test/cuda/utils.h"
  // Shared declarations for the CUDA convolution benchmark tests.
  //
  // NOTE: this header deliberately defines no helper macros of its own; short
  // names such as V/V1 would collide with (and, if #undef'd here, erase)
  // same-named macros in the including translation unit.
  namespace megdnn {
  namespace test {
  namespace conv {

  #if MEGDNN_WITH_BENCHMARK

  // One convolution benchmark case. Aggregate-initialized in the member order
  // below; keep that order stable.
  // NOTE(review): field meanings are inferred from their names — confirm
  // against the definitions in the matching .cpp.
  struct BenchArgs {
      size_t n;   // batch size (presumably)
      size_t ci;  // input channels (presumably)
      size_t hi;  // input height (presumably)
      size_t wi;  // input width (presumably)
      size_t co;  // output channels (presumably)
      size_t f;   // filter (kernel) size (presumably)
      size_t s;   // stride (presumably)
  };

  // Case lists for the benchmarks below. `batch` presumably ends up in
  // BenchArgs::n of every returned case — NOTE(review): confirm in the .cpp.
  // Per the names, these model ResNet-50 layers, detection-network layers and
  // the first layer(s) of a detection network respectively.
  std::vector<BenchArgs> get_resnet50_bench_args(size_t batch = 64);
  std::vector<BenchArgs> get_detection_bench_args(size_t batch = 16);
  std::vector<BenchArgs> get_det_first_bench_args(size_t batch = 16);

  // Benchmarks ConvBias on `handle` for every case in `args` using the given
  // src/filter/bias/dst dtypes and tensor `format` (NCHW4 by default).
  // `algo` presumably names the algorithm to force; nullptr presumably leaves
  // the choice to the default heuristic — NOTE(review): confirm in the .cpp.
  void benchmark_target_algo(
          Handle* handle, const std::vector<BenchArgs>& args, DType src_dtype,
          DType filter_dtype, DType bias_dtype, DType dst_dtype,
          const char* algo = nullptr,
          param::ConvBias::Format format = param::ConvBias::Format::NCHW4);

  // Same parameters as benchmark_target_algo for the target run. When
  // `with_cudnn` is true (the default) a cuDNN run is presumably benchmarked
  // alongside it for comparison; the change_cudnn_* parameters appear to
  // configure that run's algorithm, format and dtypes independently of the
  // target (defaults: no forced algorithm, NCHW4, Int8 everywhere).
  // NOTE(review): confirm the exact role of the cuDNN run in the .cpp.
  void benchmark_target_algo_with_cudnn_tsc(
          Handle* handle, const std::vector<BenchArgs>& args, DType src_dtype,
          DType filter_dtype, DType bias_dtype, DType dst_dtype,
          const char* algo = nullptr,
          param::ConvBias::Format format = param::ConvBias::Format::NCHW4,
          bool with_cudnn = true, const char* change_cudnn_algo = nullptr,
          param::ConvBias::Format change_cudnn_format = param::ConvBias::Format::NCHW4,
          DType change_cudnn_src_dtype = dtype::Int8(),
          DType change_cudnn_filter_dtype = dtype::Int8(),
          DType change_cudnn_bias_dtype = dtype::Int8(),
          DType change_cudnn_dst_dtype = dtype::Int8());

  #endif  // MEGDNN_WITH_BENCHMARK

  } // namespace conv
  } // namespace test
  } // namespace megdnn