You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.cpp 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. /**
  2. * \file dnn/test/common/tensor.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/common/tensor.h"
  12. #include "test/common/random_state.h"
  13. #include <random>
  14. using namespace megdnn;
  15. void test::init_gaussian(
  16. SyncedTensor<dt_float32>& tensor, dt_float32 mean, dt_float32 stddev) {
  17. auto ptr = tensor.ptr_mutable_host();
  18. auto n = tensor.layout().span().dist_elem();
  19. auto&& gen = RandomState::generator();
  20. std::normal_distribution<dt_float32> dist(mean, stddev);
  21. for (size_t i = 0; i < n; ++i) {
  22. ptr[i] = dist(gen);
  23. }
  24. }
  25. std::shared_ptr<TensorND> test::make_tensor_h2d(
  26. Handle* handle, const TensorND& htensor) {
  27. auto span = htensor.layout.span();
  28. uint8_t* mptr = static_cast<uint8_t*>(megdnn_malloc(handle, span.dist_byte()));
  29. megdnn_memcpy_H2D(
  30. handle, mptr, static_cast<uint8_t*>(htensor.raw_ptr()) + span.low_byte,
  31. span.dist_byte());
  32. TensorND ret{mptr + span.low_byte, htensor.layout};
  33. auto deleter = [handle, mptr](TensorND* p) {
  34. megdnn_free(handle, mptr);
  35. delete p;
  36. };
  37. return {new TensorND(ret), deleter};
  38. }
  39. std::shared_ptr<TensorND> test::make_tensor_d2h(
  40. Handle* handle, const TensorND& dtensor) {
  41. auto span = dtensor.layout.span();
  42. auto mptr = new uint8_t[span.dist_byte()];
  43. TensorND ret{mptr + span.low_byte, dtensor.layout};
  44. megdnn_memcpy_D2H(
  45. handle, mptr, static_cast<uint8_t*>(dtensor.raw_ptr()) + span.low_byte,
  46. span.dist_byte());
  47. auto deleter = [mptr](TensorND* p) {
  48. delete[] mptr;
  49. delete p;
  50. };
  51. return {new TensorND(ret), deleter};
  52. }
  53. std::vector<std::shared_ptr<TensorND>> test::load_tensors(const char* fpath) {
  54. FILE* fin = fopen(fpath, "rb");
  55. megdnn_assert(fin);
  56. std::vector<std::shared_ptr<TensorND>> ret;
  57. for (;;) {
  58. char dtype[128];
  59. size_t ndim;
  60. if (fscanf(fin, "%s %zu", dtype, &ndim) != 2)
  61. break;
  62. TensorLayout layout;
  63. do {
  64. #define cb(_dt) \
  65. if (!strcmp(DTypeTrait<dtype::_dt>::name, dtype)) { \
  66. layout.dtype = dtype::_dt(); \
  67. break; \
  68. }
  69. MEGDNN_FOREACH_DTYPE_NAME(cb)
  70. #undef cb
  71. char msg[256];
  72. sprintf(msg, "bad dtype on #%zu input: %s", ret.size(), dtype);
  73. ErrorHandler::on_megdnn_error(msg);
  74. } while (0);
  75. layout.ndim = ndim;
  76. for (size_t i = 0; i < ndim; ++i) {
  77. auto nr = fscanf(fin, "%zu", &layout.shape[i]);
  78. megdnn_assert(nr == 1);
  79. }
  80. auto ch = fgetc(fin);
  81. megdnn_assert(ch == '\n');
  82. layout.init_contiguous_stride();
  83. auto size = layout.span().dist_byte();
  84. auto mptr = new uint8_t[size];
  85. auto nr = fread(mptr, 1, size, fin);
  86. auto deleter = [mptr](TensorND* p) {
  87. delete[] mptr;
  88. delete p;
  89. };
  90. ret.emplace_back(new TensorND{mptr, layout}, deleter);
  91. megdnn_assert(nr == size);
  92. }
  93. fclose(fin);
  94. return ret;
  95. }
  96. // vim: syntax=cpp.doxygen