You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

local.cpp 2.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. #include "test/cpu/fixture.h"
  2. #include "test/common/benchmarker.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/local.h"
  5. #include "test/common/timer.h"
  6. namespace megdnn {
  7. namespace test {
  8. TEST_F(CPU, LOCAL) {
  9. auto args = local::get_args();
  10. for (auto&& arg : args) {
  11. Checker<Local> checker(handle());
  12. checker.set_param(arg.param).exec(
  13. TensorShapeArray{arg.sshape(), arg.fshape(), arg.dshape()});
  14. }
  15. }
  16. #if MEGDNN_WITH_BENCHMARK
  17. TEST_F(CPU, BENCHMARK_LOCAL) {
  18. size_t T = 10;
  19. float memcpy_bandwidth, local_bandwidth;
  20. {
  21. std::vector<float> src(1000000), dst(1000000);
  22. auto total_mem = (src.size() + dst.size()) * sizeof(float) * T;
  23. Timer timer;
  24. timer.start();
  25. for (size_t t = 0; t < T; ++t) {
  26. std::memcpy(dst.data(), src.data(), sizeof(float) * src.size());
  27. // to prevent compiler optimizing out memcpy above.
  28. asm volatile("");
  29. }
  30. timer.stop();
  31. auto time_in_ms = timer.get_time_in_us() / 1e3;
  32. auto bandwidth = total_mem / (time_in_ms / 1000.0f);
  33. std::cout << "Copy from src(" << src.data() << ") to dst(" << dst.data() << ")"
  34. << std::endl;
  35. std::cout << "Memcpy bandwidth is " << bandwidth / 1e9 << "GB/s" << std::endl;
  36. memcpy_bandwidth = bandwidth;
  37. }
  38. {
  39. Benchmarker<Local> benchmarker(handle());
  40. TensorShape src{2, 64, 7, 7}, filter{5, 5, 64, 3, 3, 64}, dst{2, 64, 5, 5};
  41. Local::Param param;
  42. param.pad_h = param.pad_w = 0;
  43. auto time_in_ms =
  44. benchmarker.set_times(T).set_param(param).set_display(false).exec(
  45. {src, filter, dst});
  46. auto total_mem = (src.total_nr_elems() + filter.total_nr_elems() +
  47. dst.total_nr_elems()) *
  48. sizeof(float) * T;
  49. auto bandwidth = total_mem / (time_in_ms / 1000.0f);
  50. std::cout << "Bandwidth is " << bandwidth / 1e9 << "GB/s" << std::endl;
  51. local_bandwidth = bandwidth;
  52. }
  53. float ratio = local_bandwidth / memcpy_bandwidth;
  54. ASSERT_GE(ratio, 0.05);
  55. }
  56. #endif
  57. } // namespace test
  58. } // namespace megdnn
  59. // vim: syntax=cpp.doxygen