You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

linspace.cpp 1.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. #include "hcc_detail/hcc_defs_prologue.h"
  2. #include "test/rocm/fixture.h"
  3. #include "megdnn/oprs.h"
  4. #include "test/common/checker.h"
  5. #include "test/rocm/benchmarker.h"
  6. namespace megdnn {
  7. namespace test {
  8. TEST_F(ROCM, LINSPACE) {
  9. Checker<Linspace> checker(handle_rocm());
  10. Linspace::Param param;
  11. param.start = 0.5;
  12. param.stop = 1.5;
  13. param.endpoint = true;
  14. for (DType dtype :
  15. std::vector<DType>{dtype::Float16(), dtype::Int32(), dtype::Float32()}) {
  16. checker.set_dtype(0, dtype).set_param(param).exec(TensorShapeArray{{11}});
  17. }
  18. param.endpoint = false;
  19. for (DType dtype :
  20. std::vector<DType>{dtype::Float16(), dtype::Int32(), dtype::Float32()}) {
  21. checker.set_dtype(0, dtype).set_param(param).exec(TensorShapeArray{{11}});
  22. }
  23. }
  24. TEST_F(ROCM, LINSPACE_BENCHMARK) {
  25. ROCMBenchmarker<Linspace> benchmarker(handle_rocm(), handle_naive(false));
  26. benchmarker.set_display(true);
  27. Linspace::Param param{0.1, 9999.9, true};
  28. size_t sz = 50000;
  29. auto time_ms =
  30. benchmarker.set_dtype(0, dtype::Float32()).set_param(param).execs({{sz}});
  31. double bytes = sz * dtype::Float32().size();
  32. printf("vec size = %ld, bandwidth = %.2f GB/s\n", sz,
  33. (float)(bytes / (time_ms * 1e6)));
  34. }
  35. } // namespace test
  36. } // namespace megdnn
  37. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}