You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

cond_take.cpp 2.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
#include "./cond_take.h"

#include <cstddef>
#include <memory>

#include "./rng.h"
#include "./tensor.h"
#include "./utils.h"
  5. using namespace megdnn;
  6. using namespace test;
  7. using Param = CondTake::Param;
  8. std::vector<CondTakeTestcase> CondTakeTestcase::make() {
  9. std::vector<CondTakeTestcase> ret;
  10. for (uint32_t mode = 0; mode < Param::MODE_NR_MEMBER; ++mode) {
  11. ret.push_back({
  12. Param{static_cast<Param::Mode>(mode), 0.1f, 0.1f},
  13. TensorLayout{{1}, dtype::Int8()},
  14. TensorLayout{{1}, dtype::Float32()},
  15. });
  16. ret.push_back({
  17. Param{static_cast<Param::Mode>(mode), 0.1f, 0.1f},
  18. TensorLayout{{2, 3}, dtype::Int8()},
  19. TensorLayout{{2, 3}, dtype::Float32()},
  20. });
  21. ret.push_back({
  22. Param{static_cast<Param::Mode>(mode), 100},
  23. TensorLayout{{1024}, dtype::Float32()},
  24. TensorLayout{{1024}, dtype::Int32()},
  25. });
  26. }
  27. NormalRNG data_rng;
  28. UniformIntRNG rng_byte(0, 255);
  29. auto fill_data = [&](TensorND data) {
  30. auto sz = data.layout.span().dist_byte(), szf = sz / sizeof(dt_float32);
  31. auto pf = static_cast<dt_float32*>(data.raw_ptr());
  32. data_rng.fill_fast_float32(pf, szf);
  33. auto prem = reinterpret_cast<uint8_t*>(pf + szf);
  34. size_t szrem = sz % sizeof(dt_float32);
  35. for (size_t i = 0; i < szrem; ++i) {
  36. prem[i] = rng_byte.gen_single_val();
  37. }
  38. };
  39. for (auto&& i : ret) {
  40. auto size0 = i.m_data.layout.span().dist_byte(),
  41. size1 = i.m_mask.layout.span().dist_byte();
  42. i.m_mem.reset(new uint8_t[size0 + size1]);
  43. i.m_data.reset_ptr(i.m_mem.get());
  44. i.m_mask.reset_ptr(i.m_mem.get() + size0);
  45. fill_data(i.m_data);
  46. auto mean = i.m_param.val;
  47. if (i.m_mask.layout.dtype == dtype::Int32()) {
  48. UniformIntRNG rng(mean - 10, mean + 10);
  49. rng.gen(i.m_mask);
  50. } else {
  51. megdnn_assert(i.m_mask.layout.dtype == dtype::Float32());
  52. NormalRNG rng(mean);
  53. rng.gen(i.m_mask);
  54. }
  55. }
  56. return ret;
  57. }
  58. CondTakeTestcase::Result CondTakeTestcase::run(CondTake* opr) {
  59. auto handle = opr->handle();
  60. auto data = make_tensor_h2d(handle, m_data), mask = make_tensor_h2d(handle, m_mask);
  61. opr->param() = m_param;
  62. DynOutMallocPolicyImpl malloc_policy(handle);
  63. auto workspace_size = opr->get_workspace_in_bytes(data->layout, mask->layout);
  64. auto workspace_ptr = malloc_policy.alloc_workspace(workspace_size, nullptr);
  65. auto result = opr->exec(
  66. *data, *mask, {(dt_byte*)workspace_ptr, workspace_size}, &malloc_policy);
  67. malloc_policy.free_workspace(workspace_ptr, nullptr);
  68. return {make_tensor_d2h(handle, result[0]), make_tensor_d2h(handle, result[1])};
  69. }
  70. // vim: syntax=cpp.doxygen