You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

elemwise.cpp 3.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. /**
  2. * \file dnn/test/fallback/elemwise.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  #include "test/fallback/fixture.h"
  #include <chrono>
  #include <ctime>
  #include "test/common/checker.h"
  #include "test/common/elemwise.h"
  #include "test/common/task_record_check.h"
  #include "test/common/tensor.h"
  17. using namespace megdnn;
  18. using namespace test;
  19. template <typename tag>
  20. class FALLBACK_ELEMWISE : public FALLBACK {};
  21. TYPED_TEST_CASE(FALLBACK_ELEMWISE, elemwise::test_types);
  22. TYPED_TEST(FALLBACK_ELEMWISE, run) {
  23. elemwise::run_test<TypeParam>(this->handle());
  24. }
  25. TEST_F(FALLBACK, ELEMWISE_RECORD) {
  26. TaskRecordChecker<Elemwise> checker{1};
  27. checker.set_param({Elemwise::Mode::ADD});
  28. checker.set_dtype(0, dtype::Float32());
  29. checker.set_dtype(1, dtype::Float32());
  30. checker.set_dtype(2, dtype::Float32());
  31. UniformIntRNG rng{-100, 100};
  32. checker.set_rng(0, &rng);
  33. checker.set_rng(1, &rng);
  34. checker.set_rng(2, &rng);
  35. checker.execs({{10, 10, 32}, {10, 10, 32}, {}});
  36. }
  #if MEGDNN_WITH_BENCHMARK
  /**
   * Benchmarks Elemwise ADD on the fixture's handle (reported as "fallback")
   * against a separately created CPU handle (reported as "naive").
   *
   * For every shape pair, each implementation is executed once as warm-up and
   * then timed over kBenchRuns executions with std::chrono::steady_clock
   * (elapsed wall time rather than process CPU time). The mean time per exec
   * and the resulting bandwidth are printed. Bandwidth counts the bytes
   * spanned by both inputs plus the output.
   */
  TEST_F(FALLBACK, BENCHMARK_ELEMWISE) {
      // Timed executions averaged per measurement; one sample is too noisy.
      constexpr int kBenchRuns = 5;
      // NOTE(review): presumably debug level 2 selects the naive CPU
      // implementation -- confirm against create_cpu_handle.
      auto naive_handle = create_cpu_handle(2);
      auto run = [&](const TensorShape& shp0, const TensorShape& shp1) {
          TensorShape shpo;
          Elemwise::deduce_shape({shp0, shp1}, shpo);
          // All tensors are allocated on the fixture handle and shared by both
          // operators, so both implementations read and write the same memory.
          Tensor<> op0(handle(), {shp0, dtype::Float32()}),
                  op1(handle(), {shp1, dtype::Float32()}),
                  out(handle(), {shpo, dtype::Float32()});
          auto opr_cur = handle()->create_operator<Elemwise>();
          auto opr_naive = naive_handle->create_operator<Elemwise>();
          opr_cur->param() = {Elemwise::Mode::ADD};
          opr_naive->param() = {Elemwise::Mode::ADD};
          // Returns the mean elapsed time of one exec, in milliseconds.
          auto timeit = [&](Elemwise* opr) {
              // Warm-up run (first-touch of memory, lazy init); not measured.
              opr->exec({op0.tensornd(), op1.tensornd()}, out.tensornd());
              auto start = std::chrono::steady_clock::now();
              for (int i = 0; i < kBenchRuns; ++i) {
                  opr->exec({op0.tensornd(), op1.tensornd()}, out.tensornd());
              }
              auto stop = std::chrono::steady_clock::now();
              return std::chrono::duration<double, std::milli>(stop - start)
                             .count() /
                     kBenchRuns;
          };
          auto t0 = timeit(opr_cur.get()), t1 = timeit(opr_naive.get());
          // Total GiB moved per exec, scaled by 1e3 so that dividing by a
          // time in milliseconds yields GiB/s.
          double tot_size_gb_ms =
                  (op0.layout().span().dist_byte() + op1.layout().span().dist_byte() +
                   out.layout().span().dist_byte()) /
                  1024.0 / 1024.0 / 1024.0 * 1e3;
          printf("%15s+%-15s: fallback=%7.3fms,%5.2fGiB/s "
                 "naive=%7.3fms,%5.2fGiB/s\n",
                 shp0.to_string().c_str(), shp1.to_string().c_str(), t0,
                 tot_size_gb_ms / t0, t1, tot_size_gb_ms / t1);
      };
      // contig
      run({1024, 1024, 32}, {1024, 1024, 32});
      // bcast 101
      run({1024, 1024, 32}, {1, 1024, 1});
      // bcast 01
      run({4096 * 4, 1024}, {4096 * 4, 1});
      // bcast 10
      run({4096 * 4, 1024}, {1, 1024});
      // non-contig, fallback to naive
      run({1024, 1024, 32}, {1024, 1, 32});
  }
  #endif
  79. // vim: syntax=cpp.doxygen