You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

indexing_one_hot.cpp 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. /**
  2. * \file dnn/test/cuda/indexing_one_hot.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/benchmarker.h"
  12. #include "test/common/indexing_one_hot.h"
  13. #include "test/cuda/fixture.h"
  14. #include "megcore_cuda.h"
  15. #include "megdnn/oprs/general.h"
  16. using namespace megdnn;
  17. using namespace test;
  18. TEST_F(CUDA, INDEXING_ONE_HOT) {
  19. run_indexing_one_hot_test(handle_cuda());
  20. }
  21. TEST_F(CUDA_ERROR_INFO, INDEXING_ONE_HOT) {
  22. ASSERT_EQ(0u, get_error_info().nr_error);
  23. bool failed = false;
  24. auto on_failure = [&failed, this]() {
  25. failed = true;
  26. auto err = get_error_info();
  27. ASSERT_GE(err.nr_error, 1u);
  28. printf("error msg: ");
  29. printf(err.msg, err.msg_args[0], err.msg_args[1], err.msg_args[2],
  30. err.msg_args[3]);
  31. printf("\n");
  32. };
  33. run_indexing_one_hot_test(handle_cuda(), on_failure);
  34. ASSERT_TRUE(failed);
  35. }
  36. TEST_F(CUDA, BENCHMARK_INDEXING_ONE_HOT) {
  37. Benchmarker<IndexingOneHot> bench{handle_cuda()};
  38. bench.set_times(1);
  39. UniformFloatRNG rng_val{-10, 10};
  40. UniformIntRNG rng_idx{0, 119};
  41. bench.set_param({2})
  42. .set_dtype(1, dtype::Int32{})
  43. .set_rng(1, &rng_idx)
  44. .set_rng(0, &rng_val);
  45. constexpr size_t A = 99, B = 41, C = 120, D = 191;
  46. auto time = bench.execs({{A, B, C, D}, {A, B, D}, {}}) * 1e-3;
  47. printf("bandwidth: %.2fGiB/s\n",
  48. A * B * D * sizeof(float) / 1024.0 / 1024 / 1024 / time);
  49. }
  50. TEST_F(CUDA, INDEXING_SET_ONE_HOT) {
  51. run_indexing_set_one_hot_test(handle_cuda());
  52. }
  53. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。