You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

local.cpp 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. #include "test/cuda/fixture.h"
  2. #include <cuda_runtime_api.h>
  3. #include "megcore_cuda.h"
  4. #include "test/common/checker.h"
  5. #include "test/common/local.h"
  6. #include "test/cuda/local/local.h"
  7. namespace megdnn {
  8. namespace test {
  9. TEST_F(CUDA, LOCAL_FORWARD) {
  10. auto args = local::get_args_for_cuda();
  11. for (auto&& arg : args) {
  12. Checker<LocalForward> checker(handle_cuda());
  13. cudaStream_t stream;
  14. ASSERT_EQ(
  15. megcoreSuccess,
  16. megcoreGetCUDAStream(
  17. handle_cuda()->megcore_computing_handle(), &stream));
  18. pollute_shared_mem(stream);
  19. checker.set_param(arg.param).exec(
  20. TensorShapeArray{arg.sshape(), arg.fshape(), arg.dshape()});
  21. }
  22. }
  23. TEST_F(CUDA, LOCAL_BACKWARD_DATA) {
  24. using namespace local;
  25. // std::vector<TestArg> args;
  26. // args.emplace_back(param::Convolution{
  27. // param::Convolution::Mode::CROSS_CORRELATION,
  28. // 1, 1, 1, 1},
  29. // 64, 16, 8, 7, 16, 8, 7, 3, 3);
  30. auto args = local::get_args_bwd_data_for_cuda();
  31. for (auto&& arg : args) {
  32. Checker<LocalBackwardData> checker(handle_cuda());
  33. cudaStream_t stream;
  34. ASSERT_EQ(
  35. megcoreSuccess,
  36. megcoreGetCUDAStream(
  37. handle_cuda()->megcore_computing_handle(), &stream));
  38. pollute_shared_mem(stream);
  39. checker.set_param(arg.param).exec(
  40. TensorShapeArray{arg.fshape(), arg.dshape(), arg.sshape()});
  41. }
  42. }
  43. TEST_F(CUDA, LOCAL_BACKWARD_FILTER) {
  44. using namespace local;
  45. // std::vector<TestArg> args;
  46. // args.emplace_back(param::Convolution{
  47. // param::Convolution::Mode::CROSS_CORRELATION,
  48. // 1, 1, 1, 1},
  49. // 64, 16, 8, 7, 16, 8, 7, 3, 3);
  50. auto args = local::get_args_bwd_filter_for_cuda();
  51. for (auto&& arg : args) {
  52. Checker<LocalBackwardFilter> checker(handle_cuda());
  53. cudaStream_t stream;
  54. ASSERT_EQ(
  55. megcoreSuccess,
  56. megcoreGetCUDAStream(
  57. handle_cuda()->megcore_computing_handle(), &stream));
  58. pollute_shared_mem(stream);
  59. checker.set_param(arg.param).exec(
  60. TensorShapeArray{arg.sshape(), arg.dshape(), arg.fshape()});
  61. }
  62. }
  63. } // namespace test
  64. } // namespace megdnn
  65. // vim: syntax=cpp.doxygen