
elemwise.cpp 2.7 kB

/**
 * \file dnn/test/fallback/elemwise.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/fallback/fixture.h"

#include "test/common/checker.h"
#include "test/common/elemwise.h"
#include "test/common/tensor.h"

#include <ctime>

using namespace megdnn;
using namespace test;

template <typename tag>
class FALLBACK_ELEMWISE : public FALLBACK {};

TYPED_TEST_CASE(FALLBACK_ELEMWISE, elemwise::test_types);

TYPED_TEST(FALLBACK_ELEMWISE, run) {
    elemwise::run_test<TypeParam>(this->handle());
}

TEST_F(FALLBACK, BENCHMARK_ELEMWISE) {
    auto naive_handle = create_cpu_handle(2);
    auto run = [&](const TensorShape& shp0, const TensorShape& shp1) {
        TensorShape shpo;
        Elemwise::deduce_shape({shp0, shp1}, shpo);
        Tensor<> op0(handle(), {shp0, dtype::Float32()}),
                op1(handle(), {shp1, dtype::Float32()}),
                out(handle(), {shpo, dtype::Float32()});
        auto opr_cur = handle()->create_operator<Elemwise>();
        auto opr_naive = naive_handle->create_operator<Elemwise>();
        opr_cur->param() = {Elemwise::Mode::ADD};
        opr_naive->param() = {Elemwise::Mode::ADD};
        auto timeit = [&](Elemwise* opr) {
            // warm-up run, then time a single execution
            opr->exec({op0.tensornd(), op1.tensornd()}, out.tensornd());
            auto start = clock();
            opr->exec({op0.tensornd(), op1.tensornd()}, out.tensornd());
            auto stop = clock();
            return (stop - start) * 1e3 / CLOCKS_PER_SEC;
        };
        auto t0 = timeit(opr_cur.get()), t1 = timeit(opr_naive.get());
        // total bytes touched, converted to GiB and scaled by 1e3 so that
        // dividing by a time in milliseconds yields GiB/s
        double tot_size_gb_ms =
                (op0.layout().span().dist_byte() + op1.layout().span().dist_byte() +
                 out.layout().span().dist_byte()) /
                1024.0 / 1024.0 / 1024.0 * 1e3;
        printf("%15s+%-15s: fallback=%7.3fms,%5.2fGiB/s "
               "naive=%7.3fms,%5.2fGiB/s\n",
               shp0.to_string().c_str(), shp1.to_string().c_str(), t0,
               tot_size_gb_ms / t0, t1, tot_size_gb_ms / t1);
    };
    // contig
    run({1024, 1024, 32}, {1024, 1024, 32});
    // bcast 101
    run({1024, 1024, 32}, {1, 1024, 1});
    // bcast 01
    run({4096 * 4, 1024}, {4096 * 4, 1});
    // bcast 10
    run({4096 * 4, 1024}, {1, 1024});
    // non-contig, fallback to naive
    run({1024, 1024, 32}, {1024, 1, 32});
}

// vim: syntax=cpp.doxygen

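For reference, the timing pattern used by `timeit` in BENCHMARK_ELEMWISE above can be reproduced outside the MegDNN test harness. The sketch below is a minimal, standalone illustration (plain C++ with no MegDNN dependency; the buffer size and the `time_ms` helper are hypothetical, not part of the library): it performs one warm-up pass, times a second pass with `std::clock`, and converts bytes moved per millisecond into GiB/s with the same formula as `tot_size_gb_ms / t`.

// Standalone sketch of the warm-up + clock() timing pattern used in
// BENCHMARK_ELEMWISE; buffer size and helper names are illustrative only.
#include <cstdio>
#include <ctime>
#include <vector>

static double time_ms(const std::vector<float>& a, const std::vector<float>& b,
                      std::vector<float>& out) {
    auto add_once = [&]() {
        for (size_t i = 0; i < out.size(); ++i)
            out[i] = a[i] + b[i];
    };
    add_once();                  // warm-up: touch memory, fill caches/TLB
    auto start = std::clock();
    add_once();                  // timed run
    auto stop = std::clock();
    return (stop - start) * 1e3 / CLOCKS_PER_SEC;
}

int main() {
    const size_t n = size_t(1024) * 1024 * 32;
    std::vector<float> a(n, 1.f), b(n, 2.f), out(n);
    double ms = time_ms(a, b, out);
    // bytes read + written, in GiB, scaled by 1e3 so dividing by
    // milliseconds gives GiB/s (same convention as tot_size_gb_ms)
    double gib_ms = 3.0 * n * sizeof(float) / 1024.0 / 1024.0 / 1024.0 * 1e3;
    std::printf("add: %.3f ms, %.2f GiB/s\n", ms, gib_ms / ms);
    return 0;
}

Note that std::clock measures processor time with coarse resolution, which is acceptable here because each pass moves hundreds of MiB and runs for several milliseconds.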
The MegEngine installation package bundles the CUDA environment needed to run code on the GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep-learning development on a cloud GPU compute platform, visit the MegStudio platform.
