You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

param_pack.cpp 5.3 kB

Lines 1–153
  1. /**
  2. * \file dnn/test/rocm/param_pack.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
#include "hcc_detail/hcc_defs_prologue.h"
#include "test/rocm/fixture.h"
#include "test/common/checker.h"
#include "test/common/utils.h"

#include <algorithm>
#include <limits>

using namespace megdnn;
using namespace test;
  17. namespace {
  18. template <class T>
  19. std::vector<int32_t> create_offsets(const TensorShapeArray& shapes, size_t alignment) {
  20. size_t dtype_size = sizeof(T);
  21. if (alignment < dtype_size)
  22. alignment = dtype_size;
  23. alignment /= dtype_size;
  24. auto get_aligned = [alignment](size_t v) {
  25. auto mod = v & (alignment - 1);
  26. return v + ((alignment - mod) & (alignment - 1));
  27. };
  28. std::vector<dt_int32> offsets(shapes.size() << 1);
  29. size_t offset = 0;
  30. for (size_t i = 0; i < shapes.size(); i++) {
  31. offset = get_aligned(offset);
  32. offsets[i << 1] = offset;
  33. offset += shapes[i].total_nr_elems();
  34. offsets[(i << 1) + 1] = offset;
  35. }
  36. return offsets;
  37. }
  38. template <class T>
  39. std::vector<T> create_pack(
  40. size_t pack_size, const std::vector<int32_t>& offsets,
  41. const std::vector<std::vector<T>>& ptr) {
  42. megdnn_assert(pack_size == static_cast<size_t>(offsets.back()));
  43. std::vector<T> data(pack_size, 0);
  44. for (size_t i = 0; i * 2 < offsets.size(); ++i) {
  45. size_t begin = offsets[i * 2], end = offsets[i * 2 + 1];
  46. for (size_t j = 0; j < end - begin; j++)
  47. data[begin + j] = ptr[i][j];
  48. }
  49. return data;
  50. }
  51. template <class T>
  52. std::vector<std::vector<T>> create_params(
  53. size_t nr_params, const TensorShapeArray& shapes) {
  54. std::vector<std::vector<T>> params;
  55. for (size_t i = 0; i < nr_params; ++i) {
  56. std::vector<T> expected_data;
  57. for (size_t x = 0; x < shapes[i].total_nr_elems(); ++x) {
  58. expected_data.push_back(rand());
  59. }
  60. params.push_back(std::move(expected_data));
  61. }
  62. return params;
  63. }
  64. template <class T>
  65. T* create_device_data(Handle* handle, const T* data, size_t size) {
  66. T* data_device = static_cast<T*>(test::megdnn_malloc(handle, size * sizeof(T)));
  67. if (data)
  68. test::megdnn_memcpy_H2D(handle, data_device, data, size * sizeof(T));
  69. return data_device;
  70. }
  71. template <class T>
  72. void test_param_pack_concat(
  73. Handle* handle, const TensorShapeArray& shapes, DType type) {
  74. auto concat = handle->create_operator<ParamPackConcat>();
  75. size_t nr_params = shapes.size();
  76. std::vector<T*> param_ptrs;
  77. std::vector<std::vector<T>> params = create_params<T>(nr_params, shapes);
  78. for (size_t i = 0; i < nr_params; ++i) {
  79. param_ptrs.push_back(create_device_data<T>(
  80. handle, params[i].data(), shapes[i].total_nr_elems()));
  81. }
  82. std::vector<int32_t> offsets =
  83. create_offsets<T>(shapes, handle->alignment_requirement());
  84. size_t pack_size = offsets.back();
  85. int32_t* offsets_gpu =
  86. create_device_data<int32_t>(handle, offsets.data(), offsets.size());
  87. std::vector<T> expected_pack = create_pack<T>(pack_size, offsets, params);
  88. T* pack_gpu = create_device_data<T>(handle, nullptr, expected_pack.size());
  89. TensorLayout dst_layout({pack_size}, type);
  90. TensorND dst_tensor(pack_gpu, dst_layout);
  91. TensorLayout offsets_layout({offsets.size()}, dtype::Int32());
  92. TensorND offsets_tensor(offsets_gpu, offsets_layout);
  93. test::WorkspaceWrapper workspace(
  94. handle,
  95. concat->get_workspace_in_bytes(shapes, offsets_layout, {pack_size}));
  96. TensorND src_tensor(param_ptrs.data(), TensorLayout({nr_params}, dtype::Int32()));
  97. concat->exec(src_tensor, offsets_tensor, dst_tensor, workspace.workspace());
  98. // check
  99. T* actual_pack = static_cast<T*>(malloc(pack_size * sizeof(T)));
  100. test::megdnn_memcpy_D2H(handle, actual_pack, pack_gpu, sizeof(T) * pack_size);
  101. for (size_t i = 0; i < pack_size; ++i) {
  102. ASSERT_EQ(actual_pack[i], expected_pack[i]);
  103. }
  104. free(actual_pack);
  105. test::megdnn_free(handle, pack_gpu);
  106. test::megdnn_free(handle, offsets_gpu);
  107. for (auto ptr : param_ptrs) {
  108. test::megdnn_free(handle, ptr);
  109. }
  110. }
  111. } // namespace
  112. TEST_F(ROCM, PARAM_PACK) {
  113. SmallVector<TensorShapeArray> shapes_vec;
  114. shapes_vec.push_back({{1}});
  115. shapes_vec.push_back({{129}, {21}});
  116. shapes_vec.push_back({{15}, {21}, {34}});
  117. shapes_vec.push_back({{1, 2}, {3, 5}, {5, 8}, {7, 11}, {9, 14}});
  118. shapes_vec.push_back(
  119. {{1, 2},
  120. {3, 5},
  121. {1},
  122. {3, 3, 3, 4},
  123. {71},
  124. {9, 14},
  125. {111, 111, 111},
  126. {128, 128, 128}});
  127. for (auto shapes : shapes_vec) {
  128. test_param_pack_concat<int32_t>(handle_rocm(), shapes, dtype::Int32());
  129. test_param_pack_concat<int16_t>(handle_rocm(), shapes, dtype::Int16());
  130. test_param_pack_concat<float>(handle_rocm(), shapes, dtype::Float32());
  131. }
  132. }
  133. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。