You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_preprocess_2.py 2.9 kB

(line numbers 1–94)
  1. import io
  2. import pickle
  3. import numpy as np
  4. import megengine as mge
  5. import megengine.functional as F
  6. import megengine.module as M
  7. import megengine.utils.comp_graph_tools as cgtools
  8. from megengine.core._trace_option import set_symbolic_shape
  9. from megengine.jit import trace
  10. from megengine.traced_module import trace_module
# Module-level side effect: enables MegEngine's `set_symbolic_shape` trace
# option at import time, before any module below is traced.
# NOTE(review): presumably required so that `quad.shape[0]` in
# PreProcess.forward is traced as a dynamic value rather than a constant —
# confirm.
set_symbolic_shape(True)
  12. class Main(M.Module):
  13. def forward(self, x):
  14. return x["data"]
  15. class PreProcess(M.Module):
  16. def __init__(self):
  17. super().__init__()
  18. self.A = F.zeros((1,))
  19. self.I = F.ones((1,))
  20. self.bb_out = mge.tensor(
  21. np.array([[[0, 0], [160, 0], [160, 48], [0, 48]]], dtype="float32")
  22. )
  23. def forward(self, data, quad):
  24. """
  25. data: (1, 3, 48, 160)
  26. quad: (1, 4, 2)
  27. """
  28. N = quad.shape[0]
  29. dst = F.repeat(self.bb_out, N, axis=0).reshape(-1, 4, 2)
  30. I = F.broadcast_to(self.I, quad.shape)
  31. A = F.broadcast_to(self.A, (N, 8, 8))
  32. A[:, 0:4, 0:2] = quad
  33. A[:, 4:8, 5:6] = I[:, :, 0:1]
  34. A[:, 0:4, 6:8] = -quad * dst[:, :, 0:1]
  35. A[:, 4:8, 3:5] = quad
  36. A[:, 0:4, 2:3] = I[:, :, 0:1]
  37. A[:, 4:8, 6:8] = -quad * dst[:, :, 1:2]
  38. B = dst.transpose(0, 2, 1).reshape(-1, 8, 1)
  39. M = F.concat([F.matmul(F.matinv(A), B)[:, :, 0], I[:, 0:1, 0]], axis=1).reshape(
  40. -1, 3, 3
  41. )
  42. new_data = F.warp_perspective(data, M, (48, 160)) # (N, 3, 48, 160)
  43. return {"data": new_data}
  44. class Net(M.Module):
  45. def __init__(self, traced_module):
  46. super().__init__()
  47. self.pre_process = PreProcess()
  48. self.traced_module = traced_module
  49. def forward(self, data, quad):
  50. x = self.pre_process(data, quad)
  51. x = self.traced_module(x)
  52. return x
  53. def test_preprocess():
  54. batch_size = 2
  55. module = Main()
  56. data = mge.tensor(
  57. np.random.randint(0, 256, size=(batch_size, 3, 48, 160)), dtype=np.float32
  58. )
  59. traced_module = trace_module(module, {"data": data})
  60. obj = pickle.dumps(traced_module)
  61. traced_module = pickle.loads(obj)
  62. module = Net(traced_module)
  63. module.eval()
  64. quad = mge.tensor(np.random.normal(size=(batch_size, 4, 2)), dtype=np.float32)
  65. expect = module(data, quad)
  66. traced_module = trace_module(module, data, quad)
  67. actual = traced_module(data, quad)
  68. for i, j in zip(expect, actual):
  69. np.testing.assert_array_equal(i, j)
  70. func = trace(traced_module, capture_as_const=True)
  71. actual = func(data, quad)
  72. for i, j in zip(expect, actual):
  73. np.testing.assert_array_equal(i, j)
  74. model = io.BytesIO()
  75. func.dump(model, arg_names=("data", "quad"))
  76. model.seek(0)
  77. infer_cg = cgtools.GraphInference(model)
  78. actual = list(
  79. infer_cg.run(inp_dict={"data": data.numpy(), "quad": quad.numpy()}).values()
  80. )[0]
  81. np.testing.assert_allclose(expect, actual)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。