You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

group_local.h 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
/**
 * \file dnn/test/common/group_local.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
  11. #pragma once
  12. #include "megdnn/opr_param_defs.h"
  13. #include "megdnn/basic_types.h"
  14. #include <cstddef>
  15. namespace megdnn {
  16. namespace test {
  17. namespace group_local {
  18. struct TestArg {
  19. param::Convolution param;
  20. size_t n, ic, ih, iw, groups, ocpg, oh, ow, fh, fw;
  21. TestArg(param::Convolution param, size_t n, size_t ic, size_t ih, size_t iw,
  22. size_t groups, size_t ocpg, size_t oh, size_t ow, size_t fh,
  23. size_t fw)
  24. : param(param),
  25. n(n),
  26. ic(ic),
  27. ih(ih),
  28. iw(iw),
  29. groups(groups),
  30. ocpg(ocpg),
  31. oh(oh),
  32. ow(ow),
  33. fh(fh),
  34. fw(fw) {
  35. param.sparse = param::Convolution::Sparse::GROUP;
  36. }
  37. TensorShape sshape() const { return {n, ic, ih, iw}; }
  38. TensorShape fshape() const {
  39. size_t icpg = ic / groups;
  40. return {groups, oh, ow, icpg, fh, fw, ocpg};
  41. }
  42. TensorShape dshape() {
  43. size_t oc = ocpg * groups;
  44. return {n, oc, oh, ow};
  45. }
  46. };
  47. static inline std::vector<TestArg> get_args_for_fp16() {
  48. std::vector<TestArg> test_args;
  49. test_args.emplace_back(
  50. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 1,
  51. 1, 1, 1},
  52. 64, 16, 8, 7, 4, 4, 8, 7, 3, 3);
  53. test_args.emplace_back(
  54. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 0,
  55. 0, 1, 1},
  56. 15, 15, 7, 7, 5, 3, 5, 5, 3, 3);
  57. test_args.emplace_back(
  58. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 1,
  59. 1, 1, 1},
  60. 15, 15, 5, 5, 5, 3, 5, 5, 3, 3);
  61. test_args.emplace_back(
  62. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 0,
  63. 0, 2, 2},
  64. 15, 15, 7, 7, 5, 3, 3, 3, 3, 3);
  65. /*! \warning: this operator need reduce values along the axis of IC, so this
  66. * will results in large error in fp16 situation. so in the test cases, we
  67. * use small IC values.
  68. */
  69. // clang-format off
  70. for (size_t N: {1, 2})
  71. for (size_t OC: {16, 32, 48, 64})
  72. {
  73. test_args.emplace_back(
  74. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION,
  75. 0, 0, 1, 1},
  76. N, 16, 7, 7, 4, OC / 4, 5, 5, 3, 3);
  77. }
  78. // clang-format on
  79. return test_args;
  80. }
  81. static inline std::vector<TestArg> get_args() {
  82. std::vector<TestArg> test_args;
  83. test_args.emplace_back(
  84. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 1,
  85. 1, 1, 1},
  86. 64, 16, 8, 7, 4, 4, 8, 7, 3, 3);
  87. test_args.emplace_back(
  88. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 0,
  89. 0, 1, 1},
  90. 15, 15, 7, 7, 5, 3, 5, 5, 3, 3);
  91. test_args.emplace_back(
  92. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 1,
  93. 1, 1, 1},
  94. 15, 15, 5, 5, 5, 3, 5, 5, 3, 3);
  95. test_args.emplace_back(
  96. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 0,
  97. 0, 2, 2},
  98. 15, 15, 7, 7, 5, 3, 3, 3, 3, 3);
  99. // clang-format off
  100. for (size_t N: {1, 2})
  101. for (size_t OC: {16, 32, 48, 64})
  102. {
  103. test_args.emplace_back(
  104. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION,
  105. 0, 0, 1, 1},
  106. N, 32, 7, 7, 4, OC / 4, 5, 5, 3, 3);
  107. }
  108. // clang-format on
  109. return test_args;
  110. }
  111. } // namespace group_local
  112. } // namespace test
  113. } // namespace megdnn
  114. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台