You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_jit_trace.py 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import io
  2. import numpy as np
  3. import megengine.functional as F
  4. import megengine.module as M
  5. import megengine.utils.comp_graph_tools as cgtools
  6. from megengine.jit import trace
  7. from megengine.module import Module
  8. from megengine.traced_module import trace_module
  9. class MyBlock(Module):
  10. def __init__(self, in_channels, channels):
  11. super(MyBlock, self).__init__()
  12. self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
  13. self.bn1 = M.BatchNorm2d(channels)
  14. def forward(self, x):
  15. x = self.conv1(x)
  16. x = self.bn1(x)
  17. x = F.relu(x) + 1
  18. return x
  19. class MyModule(Module):
  20. def __init__(self):
  21. super(MyModule, self).__init__()
  22. self.block0 = MyBlock(8, 4)
  23. self.block1 = MyBlock(4, 2)
  24. def forward(self, x):
  25. x = self.block0(x)
  26. x = self.block1(x)
  27. return x
  28. def test_jit_trace():
  29. module = MyModule()
  30. module.eval()
  31. x = F.ones((1, 8, 14, 14))
  32. expect = module(x)
  33. traced_module = trace_module(module, x)
  34. func = trace(traced_module, capture_as_const=True)
  35. np.testing.assert_array_equal(func(x), expect)
  36. model = io.BytesIO()
  37. func.dump(model)
  38. model.seek(0)
  39. infer_cg = cgtools.GraphInference(model)
  40. np.testing.assert_allclose(
  41. list(infer_cg.run(x.numpy()).values())[0], expect, atol=1e-6
  42. )