You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

correlation.h 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. /**
  2. * \file dnn/test/common/correlation.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megdnn/basic_types.h"
  14. #include "megdnn/opr_param_defs.h"
  15. namespace megdnn {
  16. namespace test {
  17. namespace correlation {
  18. struct TestArg {
  19. param::Correlation param;
  20. TensorShape data1, data2;
  21. TestArg(param::Correlation param, TensorShape data1, TensorShape data2)
  22. : param(param), data1(data1), data2(data2) {}
  23. };
  24. inline static std::vector<TestArg> get_args() {
  25. std::vector<TestArg> args;
  26. param::Correlation cur_param;
  27. for (size_t batch_size : {2}) {
  28. for (size_t channel : {2}) {
  29. for (size_t height : {160}) {
  30. for (size_t width : {160}) {
  31. cur_param.is_multiply = true;
  32. cur_param.kernel_size = 3;
  33. cur_param.max_displacement = 3;
  34. cur_param.pad_size = 0;
  35. cur_param.stride1 = 1;
  36. cur_param.stride2 = 1;
  37. cur_param.format = megdnn::param::Correlation::Format::NCHW;
  38. args.emplace_back(
  39. cur_param,
  40. TensorShape{batch_size, channel, height, width},
  41. TensorShape{batch_size, channel, height, width});
  42. }
  43. }
  44. }
  45. }
  46. return args;
  47. }
  48. } // namespace correlation
  49. } // namespace test
  50. } // namespace megdnn
  51. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。