You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.cpp 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. /**
  2. * \file dnn/test/common/tensor.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/tensor.h"
  12. #include "test/common/random_state.h"
  13. #include <random>
  14. using namespace megdnn;
  15. void test::init_gaussian(SyncedTensor<dt_float32> &tensor,
  16. dt_float32 mean, dt_float32 stddev) {
  17. auto ptr = tensor.ptr_mutable_host();
  18. auto n = tensor.layout().span().dist_elem();
  19. auto &&gen = RandomState::generator();
  20. std::normal_distribution<dt_float32> dist(mean, stddev);
  21. for (size_t i = 0; i < n; ++i) {
  22. ptr[i] = dist(gen);
  23. }
  24. }
  25. std::shared_ptr<TensorND> test::make_tensor_h2d(
  26. Handle *handle, const TensorND &htensor) {
  27. auto span = htensor.layout.span();
  28. TensorND ret{nullptr, htensor.layout};
  29. uint8_t* mptr = static_cast<uint8_t*>(
  30. megdnn_malloc(handle, span.dist_byte()));
  31. megdnn_memcpy_H2D(handle,
  32. mptr, static_cast<uint8_t*>(htensor.raw_ptr) + span.low_byte,
  33. span.dist_byte());
  34. ret.raw_ptr = mptr + span.low_byte;
  35. auto deleter = [handle, mptr](TensorND *p) {
  36. megdnn_free(handle, mptr);
  37. delete p;
  38. };
  39. return {new TensorND(ret), deleter};
  40. }
  41. std::shared_ptr<TensorND> test::make_tensor_d2h(
  42. Handle *handle, const TensorND &dtensor) {
  43. auto span = dtensor.layout.span();
  44. TensorND ret{nullptr, dtensor.layout};
  45. auto mptr = new uint8_t[span.dist_byte()];
  46. ret.raw_ptr = mptr + span.low_byte;
  47. megdnn_memcpy_D2H(handle,
  48. mptr, static_cast<uint8_t*>(dtensor.raw_ptr) + span.low_byte,
  49. span.dist_byte());
  50. auto deleter = [mptr](TensorND *p) {
  51. delete []mptr;
  52. delete p;
  53. };
  54. return {new TensorND(ret), deleter};
  55. }
  56. std::vector<std::shared_ptr<TensorND>> test::load_tensors(const char *fpath) {
  57. FILE *fin = fopen(fpath, "rb");
  58. megdnn_assert(fin);
  59. std::vector<std::shared_ptr<TensorND>> ret;
  60. for (; ; ) {
  61. char dtype[128];
  62. size_t ndim;
  63. if (fscanf(fin, "%s %zu", dtype, &ndim) != 2)
  64. break;
  65. TensorLayout layout;
  66. do {
  67. #define cb(_dt) if (!strcmp(DTypeTrait<dtype::_dt>::name, dtype)) { \
  68. layout.dtype = dtype::_dt(); \
  69. break; \
  70. }
  71. MEGDNN_FOREACH_DTYPE_NAME(cb)
  72. #undef cb
  73. char msg[256];
  74. sprintf(msg, "bad dtype on #%zu input: %s", ret.size(), dtype);
  75. ErrorHandler::on_megdnn_error(msg);
  76. } while(0);
  77. layout.ndim = ndim;
  78. for (size_t i = 0; i < ndim; ++ i) {
  79. auto nr = fscanf(fin, "%zu", &layout.shape[i]);
  80. megdnn_assert(nr == 1);
  81. }
  82. auto ch = fgetc(fin);
  83. megdnn_assert(ch == '\n');
  84. layout.init_contiguous_stride();
  85. auto size = layout.span().dist_byte();
  86. auto mptr = new uint8_t[size];
  87. auto nr = fread(mptr, 1, size, fin);
  88. auto deleter = [mptr](TensorND *p) {
  89. delete []mptr;
  90. delete p;
  91. };
  92. ret.emplace_back(new TensorND{mptr, layout}, deleter);
  93. megdnn_assert(nr == size);
  94. }
  95. fclose(fin);
  96. return ret;
  97. }
  98. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。