You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

resize.cpp 2.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
/**
 * \file dnn/test/naive/resize.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/common/resize.h"

#include <memory>
#include <vector>

#include "megdnn/oprs/cv.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"

using namespace megdnn;
using namespace test;
  17. TEST_F(NAIVE, RESIZE_NCHW4) {
  18. Checker<Resize> checker(handle());
  19. auto args = resize::get_nchw4_args();
  20. auto convert_true_format = [](const TensorLayout& layout) {
  21. return layout.reshape({layout[0], layout[1] / 4, layout[2], layout[3], 4})
  22. .dimshuffle({0, 1, 4, 2, 3});
  23. };
  24. for (auto&& arg : args) {
  25. auto extra_impl = [this, param = arg.param,
  26. convert_true_format](const TensorNDArray& tensors) {
  27. auto resize = handle()->create_operator<Resize>();
  28. resize->param().imode = param.imode;
  29. resize->param().format = Resize::Param::Format::NCHW;
  30. TensorNDArray nchw_tensors;
  31. for (size_t i = 0; i < tensors.size(); ++i) {
  32. auto layout = tensors[i].layout;
  33. layout = layout.reshape(
  34. {layout[0], layout[1] * 4, layout[2], layout[3]});
  35. layout.dtype = dtype::Int8();
  36. nchw_tensors.emplace_back(malloc(layout.span().dist_byte()), layout);
  37. }
  38. TensorNDArray nchw4_tensors;
  39. for (size_t i = 0; i < tensors.size(); ++i) {
  40. auto layout = convert_true_format(nchw_tensors[i].layout);
  41. nchw4_tensors.emplace_back(tensors[i].raw_ptr(), std::move(layout));
  42. }
  43. auto relayout = handle()->create_operator<RelayoutForward>();
  44. relayout->exec(nchw4_tensors[0], nchw_tensors[0]);
  45. auto workspace_size = resize->get_workspace_in_bytes(
  46. nchw_tensors[0].layout, nchw_tensors[1].layout);
  47. dt_byte* workspace_ptr = static_cast<dt_byte*>(malloc(workspace_size));
  48. Workspace workspace{workspace_ptr, workspace_size};
  49. resize->exec(nchw_tensors[0], nchw_tensors[1], workspace);
  50. relayout->exec(nchw_tensors[1], nchw4_tensors[1]);
  51. free(workspace_ptr);
  52. for (auto&& tensor : nchw_tensors) {
  53. free(tensor.raw_ptr());
  54. }
  55. };
  56. checker.set_extra_opr_impl(extra_impl);
  57. checker.set_param(arg.param)
  58. .set_dtype(0, dtype::QuantizedS8(0.1f))
  59. .set_dtype(1, dtype::QuantizedS8(0.1f))
  60. .set_epsilon(1 + 1e-3)
  61. .execs({arg.src, arg.dst});
  62. }
  63. }