You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

indexing_one_hot.cpp 1.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. #include "hcc_detail/hcc_defs_prologue.h"
  2. #include "test/common/benchmarker.h"
  3. #include "test/common/indexing_one_hot.h"
  4. #include "test/rocm/fixture.h"
  5. #include "megcore_rocm.h"
  6. #include "megdnn/oprs/general.h"
  7. #include "test/rocm/benchmarker.h"
  8. using namespace megdnn;
  9. using namespace test;
  10. TEST_F(ROCM, INDEXING_ONE_HOT) {
  11. run_indexing_one_hot_test(handle_rocm());
  12. }
  13. TEST_F(ROCM_ERROR_INFO, INDEXING_ONE_HOT) {
  14. ASSERT_EQ(0u, get_error_info().nr_error);
  15. bool failed = false;
  16. auto on_failure = [&failed, this]() {
  17. failed = true;
  18. auto err = get_error_info();
  19. ASSERT_GE(err.nr_error, 1u);
  20. printf("error msg: ");
  21. printf(err.msg, err.msg_args[0], err.msg_args[1], err.msg_args[2],
  22. err.msg_args[3]);
  23. printf("\n");
  24. };
  25. run_indexing_one_hot_test(handle_rocm(), on_failure);
  26. ASSERT_TRUE(failed);
  27. }
  28. TEST_F(ROCM, INDEXING_ONE_HOT_BENCHMARK) {
  29. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
  30. auto benchmarker =
  31. ROCMBenchmarker<IndexingOneHotForward>(handle_rocm(), handle_naive(false));
  32. UniformFloatRNG rng_val{-10, 10};
  33. UniformIntRNG rng_idx{0, 119};
  34. benchmarker.set_display(true);
  35. benchmarker.set_param({2})
  36. .set_dtype(1, dtype::Int32{})
  37. .set_rng(1, &rng_idx)
  38. .set_rng(0, &rng_val);
  39. constexpr size_t A = 99, B = 41, C = 120, D = 191;
  40. benchmarker.execs({{A, B, C, D}, {A, B, D}, {}});
  41. auto time = benchmarker.execs({{A, B, C, D}, {A, B, D}, {}});
  42. time = benchmarker.execs({{A, B, C, D}, {A, B, D}, {}});
  43. printf("bandwidth: %.2fGiB/s\n", A * B * D * sizeof(float) / (1e6 * time));
  44. }
  45. TEST_F(ROCM, INDEXING_SET_ONE_HOT) {
  46. run_indexing_set_one_hot_test(handle_rocm());
  47. }
  48. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}