You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_preprocess_1.py 2.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import io
  2. import pickle
  3. import numpy as np
  4. import megengine.functional as F
  5. import megengine.module as M
  6. import megengine.utils.comp_graph_tools as cgtools
  7. from megengine.core._trace_option import set_symbolic_shape
  8. from megengine.jit import trace
  9. from megengine.traced_module import trace_module
  class Main(M.Module):
      """Identity module: ``forward`` hands its input back unchanged.

      ``test_preprocess`` traces this module and round-trips it through
      pickle before nesting it inside ``Net``.
      """

      def forward(self, x):
          # No computation at all: the very same tensor object is returned.
          return x
  class PreProcess(M.Module):
      """Crop a region of interest from each image and resize it.

      ``forward`` builds one 3x3 scale-and-translate matrix per sample from
      ``roi`` and applies it with ``F.warp_perspective``. The output keeps the
      input's spatial size and is returned as float32 in NCHW layout.
      """

      def __init__(self):
          super().__init__()
          # One-element constants that ``forward`` broadcasts to the batch
          # size. ``I`` fills the bottom-right (homogeneous) entry of each
          # matrix; ``M`` is the all-zero starting matrix. The attribute names
          # are kept as-is because they are part of the module's state.
          self.I = F.ones((1,))
          self.M = F.zeros((1,))

      def forward(self, data, idx, roi):
          """Warp the ROI of each sample in ``data`` to a full-size output.

          Args:
              data: batch unpacked as ``(N, H, W, C)`` and warped with
                  ``format="NHWC"``.
              idx: forwarded as ``mat_idx`` to ``F.warp_perspective`` (per
                  the MegEngine docs, the source-image index of each matrix).
              roi: per-sample corners; ``roi[:, 0]`` is ``(xmin, ymin)`` and
                  ``roi[:, 1]`` is ``(xmax, ymax)``.

          Returns:
              The warped batch with spatial size ``(H, W)``, transposed to
              NCHW and cast to ``np.float32``.
          """
          # The channel count is not needed.
          batch, height, width, _ = data.shape
          xmin = roi[:, 0, 0]
          ymin = roi[:, 0, 1]
          xmax = roi[:, 1, 0]
          ymax = roi[:, 1, 1]
          # One scale for both axes (the larger of the two ratios) preserves
          # the aspect ratio and makes the sampled source window cover the
          # whole ROI.
          scale = F.maximum((xmax - xmin) / width, (ymax - ymin) / height)

          # The locals are named ``ones``/``mat`` rather than ``I``/``M`` so
          # the module-level ``M`` (megengine.module) alias is not shadowed.
          ones = F.broadcast_to(self.I, (batch,))
          mat = F.broadcast_to(self.M, (batch, 3, 3))
          # The matrix is [[scale, 0, xmin], [0, scale, ymin], [0, 0, 1]], so
          # output pixel (x, y) samples source (scale*x + xmin, scale*y + ymin).
          mat[:, 0, 0] = scale
          mat[:, 0, 2] = xmin
          mat[:, 1, 1] = scale
          mat[:, 1, 2] = ymin
          mat[:, 2, 2] = ones

          warped = F.warp_perspective(
              data,
              mat,
              (height, width),
              mat_idx=idx,
              border_mode="CONSTANT",
              format="NHWC",
          )
          # Convert NHWC to NCHW, then cast to float32.
          return warped.transpose(0, 3, 1, 2).astype(np.float32)
  class Net(M.Module):
      """Run ``PreProcess`` and feed its output to a wrapped model.

      Args:
          traced_module: the model applied after preprocessing. The test in
              this file passes a TracedModule restored from a pickle.
      """

      def __init__(self, traced_module):
          super().__init__()
          self.pre_process = PreProcess()
          self.traced_module = traced_module

      def forward(self, data, idx, roi):
          """Return ``traced_module(pre_process(data, idx, roi))``."""
          return self.traced_module(self.pre_process(data, idx, roi))
  def test_preprocess():
      """Check that ``Net`` gives the same result in every execution mode.

      The modes compared are eager execution, a TracedModule, ``jit.trace``,
      and a dumped graph run through ``GraphInference``. Symbolic shapes are
      enabled for the duration of the test. The previous setting is restored
      in ``finally``, so a failing assertion cannot leak the global flag into
      later tests.
      """
      # Keep the prior flag value so it can be put back afterwards.
      saved = set_symbolic_shape(True)
      try:
          data = F.ones((1, 14, 8, 8), dtype=np.uint8)
          # Trace the identity model and round-trip it through pickle, so Net
          # wraps a deserialized TracedModule. The pickle is produced locally
          # and is therefore trusted.
          inner = pickle.loads(pickle.dumps(trace_module(Main(), data)))

          net = Net(inner)
          net.eval()
          idx = F.zeros((1,), dtype=np.int32)
          roi = F.ones((1, 2, 2), dtype=np.float32)
          expected = net(data, idx, roi)

          # Stage 1: a TracedModule of the whole Net matches eager execution exactly.
          traced_net = trace_module(net, data, idx, roi)
          np.testing.assert_array_equal(traced_net(data, idx, roi), expected)

          # Stage 2: jit.trace output matches too. capture_as_const=True is
          # needed so the trace can be dumped below.
          func = trace(traced_net, capture_as_const=True)
          np.testing.assert_array_equal(func(data, idx, roi), expected)

          # Stage 3: the dumped graph, re-run via GraphInference, agrees
          # within tolerance.
          model = io.BytesIO()
          func.dump(model, arg_names=("data", "idx", "roi"))
          model.seek(0)
          infer_cg = cgtools.GraphInference(model)
          outputs = infer_cg.run(
              inp_dict={"data": data.numpy(), "idx": idx.numpy(), "roi": roi.numpy()}
          )
          np.testing.assert_allclose(list(outputs.values())[0], expected, atol=1e-6)
      finally:
          set_symbolic_shape(saved)