You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

elemwise.cpp 2.7 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. /**
  2. * \file dnn/test/fallback/elemwise.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/fallback/fixture.h"
  12. #include "test/common/checker.h"
  13. #include "test/common/tensor.h"
  14. #include "test/common/elemwise.h"
  15. #include <ctime>
  16. using namespace megdnn;
  17. using namespace test;
/**
 * Typed-test fixture for elemwise tests on the fallback handle.
 * The type parameter \p tag selects one case from elemwise::test_types;
 * all behavior comes from the FALLBACK base fixture.
 */
template<typename tag>
class FALLBACK_ELEMWISE: public FALLBACK {
};
// Instantiate the typed test suite once per type in elemwise::test_types.
TYPED_TEST_CASE(FALLBACK_ELEMWISE, elemwise::test_types);
// Run the shared elemwise correctness test (test/common/elemwise.h) against
// the fallback handle for the current type parameter.
TYPED_TEST(FALLBACK_ELEMWISE, run) {
    elemwise::run_test<TypeParam>(this->handle());
}
  25. #if MEGDNN_WITH_BENCHMARK
  26. TEST_F(FALLBACK, BENCHMARK_ELEMWISE) {
  27. auto naive_handle = create_cpu_handle(2);
  28. auto run = [&](const TensorShape &shp0, const TensorShape &shp1) {
  29. TensorShape shpo;
  30. Elemwise::deduce_shape({shp0, shp1}, shpo);
  31. Tensor<> op0(handle(), {shp0, dtype::Float32()}),
  32. op1(handle(), {shp1, dtype::Float32()}),
  33. out(handle(), {shpo, dtype::Float32()});
  34. auto opr_cur = handle()->create_operator<Elemwise>();
  35. auto opr_naive = naive_handle->create_operator<Elemwise>();
  36. opr_cur->param() = {Elemwise::Mode::ADD};
  37. opr_naive->param() = {Elemwise::Mode::ADD};
  38. auto timeit = [&](Elemwise *opr) {
  39. opr->exec({op0.tensornd(), op1.tensornd()}, out.tensornd());
  40. auto start = clock();
  41. opr->exec({op0.tensornd(), op1.tensornd()}, out.tensornd());
  42. auto stop = clock();
  43. return (stop - start) * 1e3 / CLOCKS_PER_SEC;
  44. };
  45. auto t0 = timeit(opr_cur.get()),
  46. t1 = timeit(opr_naive.get());
  47. double tot_size_gb_ms = (
  48. op0.layout().span().dist_byte() +
  49. op1.layout().span().dist_byte() +
  50. out.layout().span().dist_byte()) /
  51. 1024.0 / 1024.0 / 1024.0 * 1e3;
  52. printf("%15s+%-15s: fallback=%7.3fms,%5.2fGiB/s "
  53. "naive=%7.3fms,%5.2fGiB/s\n",
  54. shp0.to_string().c_str(), shp1.to_string().c_str(),
  55. t0, tot_size_gb_ms / t0, t1, tot_size_gb_ms / t1);
  56. };
  57. // contig
  58. run({1024, 1024, 32}, {1024, 1024, 32});
  59. // bcast 101
  60. run({1024, 1024, 32}, {1, 1024, 1});
  61. // bcast 01
  62. run({4096 * 4, 1024}, {4096 * 4, 1});
  63. // bcast 10
  64. run({4096 * 4, 1024}, {1, 1024});
  65. // non-contig, fallback to naive
  66. run({1024, 1024, 32}, {1024, 1, 32});
  67. }
  68. #endif
  69. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台