You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

checksum.cpp 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. /**
  2. * \file dnn/test/rocm/checksum.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "hcc_detail/hcc_defs_prologue.h"
  12. #include "megdnn/oprs.h"
  13. #include "test/common/checker.h"
  14. #include "test/rocm/fixture.h"
  15. using namespace megdnn;
  16. using namespace test;
  17. TEST_F(ROCM, CHECKSUM_FORWARD) {
  18. auto rocm_opr = handle_rocm()->create_operator<megdnn::Checksum>(),
  19. naive_opr = handle_naive()->create_operator<megdnn::Checksum>();
  20. std::mt19937 rng(std::random_device{}());
  21. for (size_t size : {3, 8, 12345, 1024 * 1024, 1024 * 1024 * 10}) {
  22. auto aligned_size = size + ((512 - size % 512) % 512);
  23. auto run = [&](megdnn::Checksum* opr, void* ptr, bool log_size) {
  24. TensorND tensor;
  25. tensor.reset_ptr(ptr);
  26. tensor.layout.init_contiguous_stride({size});
  27. tensor.layout.dtype = dtype::Byte();
  28. WorkspaceWrapper workspace(
  29. handle_rocm(), opr->get_workspace_in_bytes(tensor.layout));
  30. if (log_size) {
  31. printf("checksum(%zu): workspace=%zu\n", size,
  32. workspace.workspace().size);
  33. }
  34. return opr->exec(tensor, workspace.workspace());
  35. };
  36. std::vector<uint8_t> buf(aligned_size);
  37. for (size_t i = 0; i < size; ++i)
  38. buf[i] = rng();
  39. auto run_offsset = [&](size_t offset) {
  40. void* dev_ptr = megdnn_malloc(handle_rocm(), buf.size() + offset);
  41. void* dev_buf = static_cast<char*>(dev_ptr) + offset;
  42. Checksum::Result res_rocm[2], res_naive[2];
  43. for (int change_last = 0; change_last < 2; ++change_last) {
  44. if (change_last)
  45. ++buf[size - 1];
  46. megdnn_memcpy_H2D(handle_rocm(), dev_buf, buf.data(), size);
  47. res_rocm[change_last] = run(rocm_opr.get(), dev_buf, !change_last);
  48. res_naive[change_last] = run(naive_opr.get(), buf.data(), false);
  49. }
  50. megdnn_free(handle_rocm(), dev_ptr);
  51. ASSERT_EQ(res_naive[0], res_rocm[0]) << "failed for size " << size;
  52. ASSERT_EQ(res_naive[1], res_rocm[1]);
  53. ASSERT_NE(res_rocm[0], res_rocm[1]);
  54. };
  55. for (size_t i = 0; i < 8; ++i) {
  56. run_offsset(i);
  57. }
  58. }
  59. }
  60. // vim: syntax=cpp.doxygen