You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rng.cpp 2.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. /**
  2. * \file dnn/test/naive/rng.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn.h"
  12. #include "test/naive/fixture.h"
  13. #include "test/naive/rng.h"
  14. #include "test/common/tensor.h"
  15. namespace megdnn {
  16. namespace test {
  /**
   * Assert that `size` samples starting at `src` look like draws from the
   * uniform distribution on (0, 1].
   *
   * Two checks are made:
   *  1. every sample satisfies 0 < x <= 1 (strict lower bound, inclusive
   *     upper bound);
   *  2. the statistics returned by get_mean_var are close to the U(0, 1)
   *     mean 1/2 and variance 1/12 (absolute tolerance 1e-3 for both).
   *
   * NOTE(review): get_mean_var is declared elsewhere; it presumably returns
   * (mean, variance) and its third argument is presumably the expected
   * mean used as a reference point -- confirm against test/naive/rng.h.
   *
   * @param src   pointer to the first sample
   * @param size  number of samples to inspect
   */
  template <typename ctype>
  void assert_uniform_correct(const ctype* src, size_t size) {
      // Range check. ASSERT_* returns from this function on the first
      // out-of-range sample.
      for (size_t i = 0; i != size; i++) {
          ASSERT_GT(src[i], ctype(0));
          ASSERT_LE(src[i], ctype(1));
      }

      // Moment check against U(0, 1): mean 0.5, variance 1/12.
      // NOTE(review): for n = 200000 the standard error of the mean is
      // ~6.5e-4, so the 1e-3 tolerance is only ~1.5 sigma; this presumably
      // relies on a deterministic generator seed -- confirm.
      const auto stat = get_mean_var(src, size, ctype(0.5));
      ASSERT_LE(std::abs(stat.first - 0.5), 1e-3);
      ASSERT_LE(std::abs(stat.second - 1.0 / 12), 1e-3);
  }
  27. namespace {
  28. template<typename dtype>
  29. void run_uniform(Handle *handle) {
  30. auto opr = handle->create_operator<UniformRNG>();
  31. Tensor<typename DTypeTrait<dtype>::ctype> t(
  32. handle, {TensorShape{200000}, dtype()});
  33. opr->exec(t.tensornd(), {});
  34. assert_uniform_correct(t.ptr(), t.layout().total_nr_elems());
  35. }
  36. template<typename dtype>
  37. void run_gaussian(Handle *handle) {
  38. using ctype = typename DTypeTrait<dtype>::ctype;
  39. auto opr = handle->create_operator<GaussianRNG>();
  40. opr->param().mean = 0.8;
  41. opr->param().std = 2.3;
  42. Tensor<ctype> t(handle, {TensorShape{200001}, dtype()});
  43. opr->exec(t.tensornd(), {});
  44. auto ptr = t.ptr();
  45. auto size = t.layout().total_nr_elems();
  46. for (size_t i = 0; i < size; ++ i) {
  47. ASSERT_LE(std::abs(ptr[i] - 0.8), ctype(15));
  48. }
  49. auto stat = get_mean_var(ptr, size, ctype(0.8));
  50. ASSERT_LE(std::abs(stat.first - 0.8), 5e-3);
  51. ASSERT_LE(std::abs(stat.second - 2.3 * 2.3), 5e-2);
  52. }
  53. }
  54. TEST_F(NAIVE, UNIFORM_RNG_F32) {
  55. run_uniform<dtype::Float32>(handle());
  56. }
  57. TEST_F(NAIVE, UNIFORM_RNG_F16) {
  58. MEGDNN_INC_FLOAT16(run_uniform<dtype::Float16>(handle()));
  59. }
  60. TEST_F(NAIVE, GAUSSIAN_RNG_F32) {
  61. run_gaussian<dtype::Float32>(handle());
  62. }
  63. TEST_F(NAIVE, GAUSSIAN_RNG_F16) {
  64. MEGDNN_INC_FLOAT16(run_gaussian<dtype::Float16>(handle()));
  65. }
  66. } // namespace test
  67. } // namespace megdnn
  68. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。