You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.cpp 3.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. #include "test/common/tensor.h"
  2. #include "test/common/random_state.h"
  3. #include <random>
  4. using namespace megdnn;
  5. void test::init_gaussian(
  6. SyncedTensor<dt_float32>& tensor, dt_float32 mean, dt_float32 stddev) {
  7. auto ptr = tensor.ptr_mutable_host();
  8. auto n = tensor.layout().span().dist_elem();
  9. auto&& gen = RandomState::generator();
  10. std::normal_distribution<dt_float32> dist(mean, stddev);
  11. for (size_t i = 0; i < n; ++i) {
  12. ptr[i] = dist(gen);
  13. }
  14. }
  15. std::shared_ptr<TensorND> test::make_tensor_h2d(
  16. Handle* handle, const TensorND& htensor) {
  17. auto span = htensor.layout.span();
  18. uint8_t* mptr = static_cast<uint8_t*>(megdnn_malloc(handle, span.dist_byte()));
  19. megdnn_memcpy_H2D(
  20. handle, mptr, static_cast<uint8_t*>(htensor.raw_ptr()) + span.low_byte,
  21. span.dist_byte());
  22. TensorND ret{mptr + span.low_byte, htensor.layout};
  23. auto deleter = [handle, mptr](TensorND* p) {
  24. megdnn_free(handle, mptr);
  25. delete p;
  26. };
  27. return {new TensorND(ret), deleter};
  28. }
  29. std::shared_ptr<TensorND> test::make_tensor_d2h(
  30. Handle* handle, const TensorND& dtensor) {
  31. auto span = dtensor.layout.span();
  32. auto mptr = new uint8_t[span.dist_byte()];
  33. TensorND ret{mptr + span.low_byte, dtensor.layout};
  34. megdnn_memcpy_D2H(
  35. handle, mptr, static_cast<uint8_t*>(dtensor.raw_ptr()) + span.low_byte,
  36. span.dist_byte());
  37. auto deleter = [mptr](TensorND* p) {
  38. delete[] mptr;
  39. delete p;
  40. };
  41. return {new TensorND(ret), deleter};
  42. }
  43. std::vector<std::shared_ptr<TensorND>> test::load_tensors(const char* fpath) {
  44. FILE* fin = fopen(fpath, "rb");
  45. megdnn_assert(fin);
  46. std::vector<std::shared_ptr<TensorND>> ret;
  47. for (;;) {
  48. char dtype[128];
  49. size_t ndim;
  50. if (fscanf(fin, "%s %zu", dtype, &ndim) != 2)
  51. break;
  52. TensorLayout layout;
  53. do {
  54. #define cb(_dt) \
  55. if (!strcmp(DTypeTrait<dtype::_dt>::name, dtype)) { \
  56. layout.dtype = dtype::_dt(); \
  57. break; \
  58. }
  59. MEGDNN_FOREACH_DTYPE_NAME(cb)
  60. #undef cb
  61. char msg[256];
  62. sprintf(msg, "bad dtype on #%zu input: %s", ret.size(), dtype);
  63. ErrorHandler::on_megdnn_error(msg);
  64. } while (0);
  65. layout.ndim = ndim;
  66. for (size_t i = 0; i < ndim; ++i) {
  67. auto nr = fscanf(fin, "%zu", &layout.shape[i]);
  68. megdnn_assert(nr == 1);
  69. }
  70. auto ch = fgetc(fin);
  71. megdnn_assert(ch == '\n');
  72. layout.init_contiguous_stride();
  73. auto size = layout.span().dist_byte();
  74. auto mptr = new uint8_t[size];
  75. auto nr = fread(mptr, 1, size, fin);
  76. auto deleter = [mptr](TensorND* p) {
  77. delete[] mptr;
  78. delete p;
  79. };
  80. ret.emplace_back(new TensorND{mptr, layout}, deleter);
  81. megdnn_assert(nr == size);
  82. }
  83. fclose(fin);
  84. return ret;
  85. }
  86. // vim: syntax=cpp.doxygen