You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.cpp 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. /**
  2. * \file dnn/test/common/tensor.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/tensor.h"
  12. #include "test/common/random_state.h"
  13. #include <random>
  14. using namespace megdnn;
  15. void test::init_gaussian(
  16. SyncedTensor<dt_float32>& tensor, dt_float32 mean, dt_float32 stddev) {
  17. auto ptr = tensor.ptr_mutable_host();
  18. auto n = tensor.layout().span().dist_elem();
  19. auto&& gen = RandomState::generator();
  20. std::normal_distribution<dt_float32> dist(mean, stddev);
  21. for (size_t i = 0; i < n; ++i) {
  22. ptr[i] = dist(gen);
  23. }
  24. }
  25. std::shared_ptr<TensorND> test::make_tensor_h2d(
  26. Handle* handle, const TensorND& htensor) {
  27. auto span = htensor.layout.span();
  28. TensorND ret{nullptr, htensor.layout};
  29. uint8_t* mptr = static_cast<uint8_t*>(megdnn_malloc(handle, span.dist_byte()));
  30. megdnn_memcpy_H2D(
  31. handle, mptr, static_cast<uint8_t*>(htensor.raw_ptr) + span.low_byte,
  32. span.dist_byte());
  33. ret.raw_ptr = mptr + span.low_byte;
  34. auto deleter = [handle, mptr](TensorND* p) {
  35. megdnn_free(handle, mptr);
  36. delete p;
  37. };
  38. return {new TensorND(ret), deleter};
  39. }
  40. std::shared_ptr<TensorND> test::make_tensor_d2h(
  41. Handle* handle, const TensorND& dtensor) {
  42. auto span = dtensor.layout.span();
  43. TensorND ret{nullptr, dtensor.layout};
  44. auto mptr = new uint8_t[span.dist_byte()];
  45. ret.raw_ptr = mptr + span.low_byte;
  46. megdnn_memcpy_D2H(
  47. handle, mptr, static_cast<uint8_t*>(dtensor.raw_ptr) + span.low_byte,
  48. span.dist_byte());
  49. auto deleter = [mptr](TensorND* p) {
  50. delete[] mptr;
  51. delete p;
  52. };
  53. return {new TensorND(ret), deleter};
  54. }
  55. std::vector<std::shared_ptr<TensorND>> test::load_tensors(const char* fpath) {
  56. FILE* fin = fopen(fpath, "rb");
  57. megdnn_assert(fin);
  58. std::vector<std::shared_ptr<TensorND>> ret;
  59. for (;;) {
  60. char dtype[128];
  61. size_t ndim;
  62. if (fscanf(fin, "%s %zu", dtype, &ndim) != 2)
  63. break;
  64. TensorLayout layout;
  65. do {
  66. #define cb(_dt) \
  67. if (!strcmp(DTypeTrait<dtype::_dt>::name, dtype)) { \
  68. layout.dtype = dtype::_dt(); \
  69. break; \
  70. }
  71. MEGDNN_FOREACH_DTYPE_NAME(cb)
  72. #undef cb
  73. char msg[256];
  74. sprintf(msg, "bad dtype on #%zu input: %s", ret.size(), dtype);
  75. ErrorHandler::on_megdnn_error(msg);
  76. } while (0);
  77. layout.ndim = ndim;
  78. for (size_t i = 0; i < ndim; ++i) {
  79. auto nr = fscanf(fin, "%zu", &layout.shape[i]);
  80. megdnn_assert(nr == 1);
  81. }
  82. auto ch = fgetc(fin);
  83. megdnn_assert(ch == '\n');
  84. layout.init_contiguous_stride();
  85. auto size = layout.span().dist_byte();
  86. auto mptr = new uint8_t[size];
  87. auto nr = fread(mptr, 1, size, fin);
  88. auto deleter = [mptr](TensorND* p) {
  89. delete[] mptr;
  90. delete p;
  91. };
  92. ret.emplace_back(new TensorND{mptr, layout}, deleter);
  93. megdnn_assert(nr == size);
  94. }
  95. fclose(fin);
  96. return ret;
  97. }
  98. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。