You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conv_test_utils.h 2.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
/**
 * \file dnn/test/cuda/conv_test_utils.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
  12. #pragma once
  13. #include "megdnn/oprs/nn.h"
  14. #include "src/common/utils.h"
  15. #include "src/cuda/cudnn_with_check.h"
  16. #include "test/common/checker.h"
  17. #include "test/common/conv_bias.h"
  18. #include "test/common/tensor.h"
  19. #include "test/common/workspace_wrapper.h"
  20. #include "test/cuda/benchmark.h"
  21. #include "test/cuda/fixture.h"
  22. #include "test/cuda/utils.h"
  23. #define V1(x) #x
  24. #define V(x) V1(x)
  25. namespace megdnn {
  26. namespace test {
  27. namespace conv {
  28. #if MEGDNN_WITH_BENCHMARK
  29. struct BenchArgs {
  30. size_t n, ci, hi, wi, co, f, s;
  31. };
  32. std::vector<BenchArgs> get_resnet50_bench_args(size_t batch = 64);
  33. std::vector<BenchArgs> get_detection_bench_args(size_t batch = 16);
  34. std::vector<BenchArgs> get_det_first_bench_args(size_t batch = 16);
  35. void benchmark_target_algo(
  36. Handle* handle, const std::vector<BenchArgs>& args, DType src_dtype,
  37. DType filter_dtype, DType bias_dtype, DType dst_dtype,
  38. const char* algo = nullptr,
  39. param::ConvBias::Format format = param::ConvBias::Format::NCHW4);
  40. void benchmark_target_algo_with_cudnn_tsc(
  41. Handle* handle, const std::vector<BenchArgs>& args, DType src_dtype,
  42. DType filter_dtype, DType bias_dtype, DType dst_dtype,
  43. const char* algo = nullptr,
  44. param::ConvBias::Format format = param::ConvBias::Format::NCHW4,
  45. bool with_cudnn = true, const char* change_cudnn_algo = nullptr,
  46. param::ConvBias::Format change_cudnn_format = param::ConvBias::Format::NCHW4,
  47. DType change_cudnn_src_dtype = dtype::Int8(),
  48. DType change_cudnn_filter_dtype = dtype::Int8(),
  49. DType change_cudnn_bias_dtype = dtype::Int8(),
  50. DType change_cudnn_dst_dtype = dtype::Int8());
  51. #endif
  52. } // namespace conv
  53. } // namespace test
  54. } // namespace megdnn
  55. #undef V1
  56. #undef V