You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

cond_take.cpp 3.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. /**
  2. * \file dnn/test/common/cond_take.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./cond_take.h"
  12. #include "./utils.h"
  13. #include "./tensor.h"
  14. #include "./rng.h"
  15. using namespace megdnn;
  16. using namespace test;
  17. using Param = CondTake::Param;
  18. std::vector<CondTakeTestcase> CondTakeTestcase::make() {
  19. std::vector<CondTakeTestcase> ret;
  20. for (uint32_t mode = 0; mode < Param::MODE_NR_MEMBER; ++ mode) {
  21. ret.push_back({
  22. Param{static_cast<Param::Mode>(mode), 0.1f, 0.1f},
  23. TensorLayout{{1}, dtype::Int8()},
  24. TensorLayout{{1}, dtype::Float32()},
  25. });
  26. ret.push_back({
  27. Param{static_cast<Param::Mode>(mode), 0.1f, 0.1f},
  28. TensorLayout{{2, 3}, dtype::Int8()},
  29. TensorLayout{{2, 3}, dtype::Float32()},
  30. });
  31. ret.push_back({
  32. Param{static_cast<Param::Mode>(mode), 100},
  33. TensorLayout{{1024}, dtype::Float32()},
  34. TensorLayout{{1024}, dtype::Int32()},
  35. });
  36. }
  37. NormalRNG data_rng;
  38. UniformIntRNG rng_byte(0, 255);
  39. auto fill_data = [&](TensorND data) {
  40. auto sz = data.layout.span().dist_byte(),
  41. szf = sz / sizeof(dt_float32);
  42. auto pf = static_cast<dt_float32*>(data.raw_ptr);
  43. data_rng.fill_fast_float32(pf, szf);
  44. auto prem = reinterpret_cast<uint8_t*>(pf + szf);
  45. size_t szrem = sz % sizeof(dt_float32);
  46. for (size_t i = 0; i < szrem; ++ i) {
  47. prem[i] = rng_byte.gen_single_val();
  48. }
  49. };
  50. for (auto &&i: ret) {
  51. auto size0 = i.m_data.layout.span().dist_byte(),
  52. size1 = i.m_mask.layout.span().dist_byte();
  53. i.m_mem.reset(new uint8_t[size0 + size1]);
  54. i.m_data.raw_ptr = i.m_mem.get();
  55. i.m_mask.raw_ptr = i.m_mem.get() + size0;
  56. fill_data(i.m_data);
  57. auto mean = i.m_param.val;
  58. if (i.m_mask.layout.dtype == dtype::Int32()) {
  59. UniformIntRNG rng(mean - 10, mean + 10);
  60. rng.gen(i.m_mask);
  61. } else {
  62. megdnn_assert(i.m_mask.layout.dtype == dtype::Float32());
  63. NormalRNG rng(mean);
  64. rng.gen(i.m_mask);
  65. }
  66. }
  67. return ret;
  68. }
  69. CondTakeTestcase::Result CondTakeTestcase::run(CondTake* opr) {
  70. auto handle = opr->handle();
  71. auto data = make_tensor_h2d(handle, m_data),
  72. mask = make_tensor_h2d(handle, m_mask);
  73. opr->param() = m_param;
  74. DynOutMallocPolicyImpl malloc_policy(handle);
  75. auto workspace_size = opr->get_workspace_in_bytes(data->layout);
  76. auto workspace_ptr = malloc_policy.alloc_workspace(workspace_size, nullptr);
  77. auto result =
  78. opr->exec(*data, *mask, {(dt_byte*)workspace_ptr, workspace_size},
  79. &malloc_policy);
  80. malloc_policy.free_workspace(workspace_ptr, nullptr);
  81. return {make_tensor_d2h(handle, result[0]),
  82. make_tensor_d2h(handle, result[1])};
  83. }
  84. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台