You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mask_conv.h 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. /**
  2. * \file dnn/test/common/mask_conv.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
#pragma once

#include <cstdio>
#include <memory>

#include "megdnn/oprs.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
  16. namespace {
  17. using namespace megdnn;
  18. using namespace test;
  19. std::vector<std::vector<int>> get_args() {
  20. std::vector<std::vector<int>> args;
  21. args.push_back({2, 1, 1, 5, 5, 3, 3, 1, 1, 0, 0, 1, 1});
  22. args.push_back({1, 2, 3, 24, 24, 3, 5, 1, 1, 0, 0, 1, 1});
  23. args.push_back({20, 3, 4, 24, 21, 5, 3, 1, 1, 0, 0, 1, 1});
  24. args.push_back({20, 3, 4, 24, 21, 5, 3, 2, 2, 0, 0, 1, 1});
  25. args.push_back({20, 3, 4, 24, 21, 5, 3, 2, 2, 2, 2, 1, 1});
  26. args.push_back({20, 3, 4, 24, 21, 5, 3, 2, 2, 1, 2, 1, 1});
  27. args.push_back({20, 3, 4, 24, 21, 5, 3, 2, 2, 1, 2, 2, 3});
  28. args.push_back({20, 3, 4, 24, 21, 5, 3, 2, 2, 1, 2, 3, 2});
  29. args.push_back({2, 108, 108, 14, 14, 3, 3, 1, 1, 0, 0, 1, 1});
  30. args.push_back({2, 108, 108, 14, 14, 3, 3, 1, 1, 2, 2, 1, 1});
  31. args.push_back({2, 108, 108, 14, 14, 3, 3, 2, 2, 2, 2, 1, 1});
  32. args.push_back({2, 108, 108, 14, 14, 3, 3, 2, 2, 0, 0, 1, 1});
  33. args.push_back({2, 3, 3, 224, 224, 3, 3, 1, 1, 0, 0, 1, 1});
  34. args.push_back({2, 3, 3, 224, 224, 3, 3, 2, 2, 0, 0, 1, 1});
  35. return args;
  36. }
  37. void mask_conv_test(Handle* handle) {
  38. auto run = [&](size_t N, size_t IC, size_t OC, size_t IH, size_t IW,
  39. size_t FH, size_t FW, size_t SH, size_t SW, size_t PH,
  40. size_t PW, size_t DH, size_t DW) {
  41. size_t OH = (IH + 2 * PH - ((FH - 1) * DH + 1)) / SH + 1;
  42. size_t OW = (IW + 2 * PW - ((FW - 1) * DW + 1)) / SW + 1;
  43. Checker<MaskConvolution> checker(handle);
  44. using Param = param::Convolution;
  45. Param param(Param::Mode::CROSS_CORRELATION,
  46. // pad
  47. PH, PW,
  48. // stride
  49. SH, SW,
  50. // dilate
  51. DH, DW, Param::Sparse::DENSE, Param::Format::NCHW);
  52. TensorShape src_shape({N, IC, IH, IW}), filter_shape({OC, IC, FH, FW}),
  53. mask({OH, OW}), dst({});
  54. auto rng = std::make_unique<BernoulliRNG>(0.5);
  55. checker.set_param(param);
  56. checker.set_dtype(2, dtype::Int8())
  57. .execs({src_shape, filter_shape, mask, dst});
  58. checker.set_dtype(2, dtype::Int16())
  59. .execs({src_shape, filter_shape, mask, dst});
  60. checker.set_dtype(2, dtype::Int32())
  61. .execs({src_shape, filter_shape, mask, dst});
  62. };
  63. auto test_args = get_args();
  64. for (auto&& arg : test_args) {
  65. run(arg[0], arg[1], arg[2], arg[3], arg[4], arg[5], arg[6], arg[7],
  66. arg[8], arg[9], arg[10], arg[11], arg[12]);
  67. }
  68. }
  69. #if MEGDNN_WITH_BENCHMARK
  70. void mask_conv_benchmark(Handle* handle) {
  71. auto benchmark = [&](size_t N, size_t IC, size_t OC, size_t IH, size_t IW,
  72. size_t FH, size_t FW, size_t SH, size_t SW, size_t PH,
  73. size_t PW, size_t DH, size_t DW) {
  74. size_t OH = (IH + 2 * PH - ((FH - 1) * DH + 1)) / SH + 1;
  75. size_t OW = (IW + 2 * PW - ((FW - 1) * DW + 1)) / SW + 1;
  76. Benchmarker<MaskConvolution> benchmark_fallback(handle);
  77. Benchmarker<Convolution> benchmark_naive(handle);
  78. using Param = param::Convolution;
  79. Param param(Param::Mode::CROSS_CORRELATION,
  80. // pad
  81. PH, PW,
  82. // stride
  83. SH, SW,
  84. // dilate
  85. DH, DW, Param::Sparse::DENSE, Param::Format::NCHW);
  86. TensorShape src_shape({N, IC, IH, IW}), filter_shape({OC, IC, FH, FW}),
  87. mask({OH, OW}), dst({});
  88. benchmark_fallback.set_param(param)
  89. .set_dtype(2, dtype::Int32())
  90. .set_times(20);
  91. printf("Execing mask conv: \n");
  92. #define test(p) \
  93. benchmark_fallback.set_rng(2, new BernoulliRNG(p)) \
  94. .execs({src_shape, filter_shape, mask, dst})
  95. for (auto p : {0.1, 0.2, 0.3, 0.4, 0.5, 0.99})
  96. test(p);
  97. printf("Execing normal conv: \n");
  98. benchmark_naive.set_param(param).set_times(20).execs(
  99. {src_shape, filter_shape, dst});
  100. #undef test
  101. };
  102. auto test_args = get_args();
  103. for (auto&& arg : test_args) {
  104. benchmark(arg[0], arg[1], arg[2], arg[3], arg[4], arg[5], arg[6],
  105. arg[7], arg[8], arg[9], arg[10], arg[11], arg[12]);
  106. }
  107. }
  108. #endif
  109. } // namespace

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。