You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

remap.h 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. /**
  2. * \file dnn/test/common/remap.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include <iostream>
  14. #include "megdnn/basic_types.h"
  15. #include "megdnn/opr_param_defs.h"
  16. #include "./rng.h"
  17. namespace megdnn {
  18. namespace test {
  19. namespace remap {
  20. struct TestArg {
  21. param::Remap param;
  22. TensorShape src;
  23. TensorShape map_xy;
  24. TensorShape dst;
  25. TestArg(param::Remap param_, TensorShape src_, TensorShape map_xy_,
  26. TensorShape dst_)
  27. : param(param_), src(src_), map_xy(map_xy_), dst(dst_) {}
  28. };
  29. static inline std::vector<TestArg> get_nchw_args() {
  30. std::vector<TestArg> args;
  31. param::Remap param;
  32. std::vector<param::Remap::Format> format_vec = {param::Remap::Format::NCHW};
  33. std::vector<param::Remap::BorderMode> border_mode_vec = {
  34. param::Remap::BorderMode::CONSTANT,
  35. param::Remap::BorderMode::REFLECT_101,
  36. param::Remap::BorderMode::REFLECT,
  37. param::Remap::BorderMode::WRAP,
  38. param::Remap::BorderMode::REPLICATE};
  39. // current do not test this.
  40. std::vector<float> scalar;
  41. for (auto fmt : format_vec) {
  42. for (auto border_type : border_mode_vec) {
  43. param.format = fmt;
  44. param.border_type = border_type;
  45. args.emplace_back(param, TensorShape{1, 1, 2, 2},
  46. TensorShape{1, 2, 2, 2}, TensorShape{1, 1, 2, 2});
  47. args.emplace_back(param, TensorShape{1, 3, 2, 2},
  48. TensorShape{1, 2, 2, 2}, TensorShape{1, 3, 2, 2});
  49. args.emplace_back(param, TensorShape{1, 10, 100, 100},
  50. TensorShape{1, 100, 100, 2},
  51. TensorShape{1, 10, 100, 100});
  52. args.emplace_back(param, TensorShape{2, 4, 100, 200},
  53. TensorShape{2, 100, 200, 2},
  54. TensorShape{2, 4, 100, 200});
  55. args.emplace_back(param, TensorShape{2, 4, 100, 200},
  56. TensorShape{2, 20, 30, 2},
  57. TensorShape{2, 4, 20, 30});
  58. args.emplace_back(param, TensorShape{2, 4, 10, 10},
  59. TensorShape{2, 20, 30, 2},
  60. TensorShape{2, 4, 20, 30});
  61. }
  62. }
  63. return args;
  64. }
  65. static inline std::vector<TestArg> get_nhwc_args() {
  66. std::vector<TestArg> args;
  67. param::Remap param;
  68. std::vector<param::Remap::Format> format_vec = {param::Remap::Format::NHWC};
  69. std::vector<param::Remap::BorderMode> border_mode_vec = {
  70. param::Remap::BorderMode::CONSTANT,
  71. param::Remap::BorderMode::REFLECT_101,
  72. param::Remap::BorderMode::REFLECT,
  73. param::Remap::BorderMode::WRAP,
  74. param::Remap::BorderMode::REPLICATE};
  75. // current do not test this.
  76. std::vector<float> scalar;
  77. for (auto fmt : format_vec) {
  78. for (auto border_type : border_mode_vec) {
  79. param.format = fmt;
  80. param.border_type = border_type;
  81. param.scalar = 12.f;
  82. args.emplace_back(param, TensorShape{1, 2, 2, 1},
  83. TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 2, 1});
  84. args.emplace_back(param, TensorShape{1, 2, 2, 3},
  85. TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 2, 3});
  86. args.emplace_back(param, TensorShape{1, 2, 2, 66},
  87. TensorShape{1, 2, 2, 2},
  88. TensorShape{1, 2, 2, 66});
  89. args.emplace_back(param, TensorShape{1, 100, 100, 10},
  90. TensorShape{1, 100, 100, 2},
  91. TensorShape{1, 100, 100, 10});
  92. args.emplace_back(param, TensorShape{2, 100, 200, 4},
  93. TensorShape{2, 100, 200, 2},
  94. TensorShape{2, 100, 200, 4});
  95. args.emplace_back(param, TensorShape{2, 100, 200, 4},
  96. TensorShape{2, 20, 30, 2},
  97. TensorShape{2, 20, 30, 4});
  98. args.emplace_back(param, TensorShape{2, 10, 10, 4},
  99. TensorShape{2, 20, 30, 2},
  100. TensorShape{2, 20, 30, 4});
  101. }
  102. }
  103. return args;
  104. }
  105. } // namespace remap
  106. } // namespace test
  107. } // namespace megdnn
  108. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台