You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rng.cpp 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. #include "test/naive/rng.h"
  2. #include "hcc_detail/hcc_defs_prologue.h"
  3. #include "megdnn/oprs.h"
  4. #include "test/common/tensor.h"
  5. #include "test/rocm/fixture.h"
  6. namespace megdnn {
  7. namespace test {
  8. TEST_F(ROCM, UNIFORM_RNG_F32) {
  9. auto opr = handle_rocm()->create_operator<UniformRNG>();
  10. SyncedTensor<> t(handle_rocm(), {TensorShape{200000}, dtype::Float32()});
  11. opr->exec(t.tensornd_dev(), {});
  12. assert_uniform_correct(t.ptr_mutable_host(), t.layout().total_nr_elems());
  13. }
  14. TEST_F(ROCM, GAUSSIAN_RNG_F32) {
  15. auto opr = handle_rocm()->create_operator<GaussianRNG>();
  16. opr->param().mean = 0.8;
  17. opr->param().std = 2.3;
  18. for (size_t size : {1, 200000, 200001}) {
  19. TensorLayout ly{{size}, dtype::Float32()};
  20. Tensor<dt_byte> workspace(
  21. handle_rocm(),
  22. {TensorShape{opr->get_workspace_in_bytes(ly)}, dtype::Byte()});
  23. SyncedTensor<> t(handle_rocm(), ly);
  24. opr->exec(
  25. t.tensornd_dev(),
  26. {workspace.ptr(), workspace.layout().total_nr_elems()});
  27. auto ptr = t.ptr_mutable_host();
  28. if (size >= 1000) {
  29. auto stat = get_mean_var(ptr, size, 0.8f);
  30. ASSERT_LE(std::abs(stat.first - 0.8), 5e-3);
  31. ASSERT_LE(std::abs(stat.second - 2.3 * 2.3), 5e-2);
  32. }
  33. }
  34. }
  35. } // namespace test
  36. } // namespace megdnn
  37. // vim: syntax=cpp.doxygen