You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_preprocess_1.py 2.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  import io
  import pickle

  import numpy as np

  import megengine.functional as F
  import megengine.module as M
  import megengine.utils.comp_graph_tools as cgtools
  from megengine.core._trace_option import set_symbolic_shape
  from megengine.jit import trace
  from megengine.traced_module import trace_module

  # Executed at import time, i.e. as soon as pytest collects this file.
  # NOTE(review): this looks like a process-wide switch (it makes
  # ``tensor.shape`` symbolic), so it would also affect tests that run after
  # this module is imported -- confirm, and prefer scoping it to the test.
  set_symbolic_shape(True)
  class Main(M.Module):
      """Identity module: ``forward`` hands its single input back unchanged.

      ``test_preprocess`` traces it with ``trace_module`` to get a minimal
      TracedModule, which ``Net`` then runs after ``PreProcess``.
      """

      def forward(self, x):
          # No computation at all: the traced graph is a pure pass-through.
          passthrough = x
          return passthrough
  class PreProcess(M.Module):
      """Crop ``roi`` out of each NHWC image and warp it back to full size.

      For every batch element a 3x3 perspective matrix is assembled from the
      region corners in ``roi``::

          [[scale, 0,     x_min],
           [0,     scale, y_min],
           [0,     0,     1    ]]

      where ``scale = max((x_max - x_min) / W, (y_max - y_min) / H)``. The
      matrices are passed to ``F.warp_perspective`` with an output size of
      ``(H, W)``, and the result is returned as ``float32`` in NCHW layout.
      """

      def __init__(self):
          super().__init__()
          # One-element seeds that forward() broadcasts per batch: ``I`` fills
          # the bottom-right (homogeneous) entry, ``M`` is the all-zero base.
          # NOTE(review): names kept as-is; they may appear in state_dict keys.
          self.I = F.ones((1,))
          self.M = F.zeros((1,))

      def forward(self, data, idx, roi):
          """Return the ROI-warped ``data`` as a float32 NCHW tensor.

          Args:
              data: image batch in NHWC layout (the warp is called with
                  ``format="NHWC"``).
              idx: forwarded as ``mat_idx`` to ``F.warp_perspective``;
                  presumably selects the source image for each matrix.
              roi: indexed as ``roi[:, corner, axis]``; corner 0 is read as
                  the (x, y) minimum and corner 1 as the (x, y) maximum.
          """
          batch, height, width, _ = data.shape
          x_min, y_min = roi[:, 0, 0], roi[:, 0, 1]
          x_max, y_max = roi[:, 1, 0], roi[:, 1, 1]
          scale = F.maximum((x_max - x_min) / width, (y_max - y_min) / height)

          ones = F.broadcast_to(self.I, (batch,))
          mat = F.broadcast_to(self.M, (batch, 3, 3))
          # Write the non-zero entries; every other entry stays 0 from the
          # zero seed.
          for (row, col), value in (
              ((0, 0), scale),
              ((0, 2), x_min),
              ((1, 1), scale),
              ((1, 2), y_min),
              ((2, 2), ones),
          ):
              mat[:, row, col] = value

          warped = F.warp_perspective(
              data,
              mat,
              (height, width),
              mat_idx=idx,
              border_mode="CONSTANT",
              format="NHWC",
          )
          # NHWC -> NCHW, then cast to float32.
          return warped.transpose(0, 3, 1, 2).astype(np.float32)
  class Net(M.Module):
      """Run ``PreProcess`` and feed its output into an already-traced module."""

      def __init__(self, traced_module):
          super().__init__()
          # Attribute names are kept: they name the submodules of this Net.
          self.pre_process = PreProcess()
          self.traced_module = traced_module

      def forward(self, data, idx, roi):
          """Preprocess ``(data, idx, roi)`` and return the traced module's output."""
          return self.traced_module(self.pre_process(data, idx, roi))
  def test_preprocess():
      """Check that every export path of ``Net`` reproduces the eager result.

      The eager output ``y`` is the reference. It must match:
      - the ``trace_module`` result;
      - the ``jit.trace`` (``capture_as_const=True``) result;
      - the dumped graph run through ``cgtools.GraphInference``.
      """
      # Scope the symbolic-shape switch to this test. An import-time toggle is
      # process-wide and order-dependent under pytest: another module imported
      # later could flip it back, and this one leaks into later tests.
      # ``set_symbolic_shape`` returns the previous setting, which is restored
      # even if an assertion fails.
      saved_symbolic_shape = set_symbolic_shape(True)
      try:
          # Trace the identity module and pickle round-trip it, so Net wraps a
          # deserialized TracedModule rather than the freshly traced object.
          data = F.ones((1, 14, 8, 8), dtype=np.uint8)
          traced_main = pickle.loads(pickle.dumps(trace_module(Main(), data)))

          module = Net(traced_main)
          module.eval()
          idx = F.zeros((1,), dtype=np.int32)
          # NOTE(review): with an all-ones roi, x_max == x_min and y_max == y_min,
          # so PreProcess computes scale == 0 and the warp is degenerate; with
          # all-ones ``data`` this mostly exercises the tracing/export plumbing
          # rather than the warp arithmetic.
          roi = F.ones((1, 2, 2), dtype=np.float32)
          y = module(data, idx, roi)

          traced_net = trace_module(module, data, idx, roi)
          np.testing.assert_array_equal(traced_net(data, idx, roi), y)

          func = trace(traced_net, capture_as_const=True)
          np.testing.assert_array_equal(func(data, idx, roi), y)

          model = io.BytesIO()
          func.dump(model, arg_names=("data", "idx", "roi"))
          model.seek(0)
          infer_cg = cgtools.GraphInference(model)
          outputs = infer_cg.run(
              inp_dict={"data": data.numpy(), "idx": idx.numpy(), "roi": roi.numpy()}
          )
          # Net.forward returns a single tensor, so take the first graph output.
          np.testing.assert_allclose(next(iter(outputs.values())), y, atol=1e-6)
      finally:
          set_symbolic_shape(saved_symbolic_shape)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。