You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

local.cpp 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
/**
 * \file dnn/test/cuda/local.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
  11. #include "test/cuda/fixture.h"
  12. #include <cuda_runtime_api.h>
  13. #include "megcore_cuda.h"
  14. #include "test/common/checker.h"
  15. #include "test/common/local.h"
  16. #include "test/cuda/local/local.h"
  17. namespace megdnn {
  18. namespace test {
  19. TEST_F(CUDA, LOCAL_FORWARD) {
  20. auto args = local::get_args_for_cuda();
  21. for (auto&& arg : args) {
  22. Checker<LocalForward> checker(handle_cuda());
  23. cudaStream_t stream;
  24. ASSERT_EQ(
  25. megcoreSuccess,
  26. megcoreGetCUDAStream(
  27. handle_cuda()->megcore_computing_handle(), &stream));
  28. pollute_shared_mem(stream);
  29. checker.set_param(arg.param).exec(
  30. TensorShapeArray{arg.sshape(), arg.fshape(), arg.dshape()});
  31. }
  32. }
  33. TEST_F(CUDA, LOCAL_BACKWARD_DATA) {
  34. using namespace local;
  35. // std::vector<TestArg> args;
  36. // args.emplace_back(param::Convolution{
  37. // param::Convolution::Mode::CROSS_CORRELATION,
  38. // 1, 1, 1, 1},
  39. // 64, 16, 8, 7, 16, 8, 7, 3, 3);
  40. auto args = local::get_args_bwd_data_for_cuda();
  41. for (auto&& arg : args) {
  42. Checker<LocalBackwardData> checker(handle_cuda());
  43. cudaStream_t stream;
  44. ASSERT_EQ(
  45. megcoreSuccess,
  46. megcoreGetCUDAStream(
  47. handle_cuda()->megcore_computing_handle(), &stream));
  48. pollute_shared_mem(stream);
  49. checker.set_param(arg.param).exec(
  50. TensorShapeArray{arg.fshape(), arg.dshape(), arg.sshape()});
  51. }
  52. }
  53. TEST_F(CUDA, LOCAL_BACKWARD_FILTER) {
  54. using namespace local;
  55. // std::vector<TestArg> args;
  56. // args.emplace_back(param::Convolution{
  57. // param::Convolution::Mode::CROSS_CORRELATION,
  58. // 1, 1, 1, 1},
  59. // 64, 16, 8, 7, 16, 8, 7, 3, 3);
  60. auto args = local::get_args_bwd_filter_for_cuda();
  61. for (auto&& arg : args) {
  62. Checker<LocalBackwardFilter> checker(handle_cuda());
  63. cudaStream_t stream;
  64. ASSERT_EQ(
  65. megcoreSuccess,
  66. megcoreGetCUDAStream(
  67. handle_cuda()->megcore_computing_handle(), &stream));
  68. pollute_shared_mem(stream);
  69. checker.set_param(arg.param).exec(
  70. TensorShapeArray{arg.sshape(), arg.dshape(), arg.fshape()});
  71. }
  72. }
  73. } // namespace test
  74. } // namespace megdnn
  75. // vim: syntax=cpp.doxygen