You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_jit_trace.py 1.7 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import io
  9. import numpy as np
  10. import megengine.functional as F
  11. import megengine.module as M
  12. import megengine.utils.comp_graph_tools as cgtools
  13. from megengine.jit import trace
  14. from megengine.module import Module
  15. from megengine.traced_module import trace_module
  16. class MyBlock(Module):
  17. def __init__(self, in_channels, channels):
  18. super(MyBlock, self).__init__()
  19. self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
  20. self.bn1 = M.BatchNorm2d(channels)
  21. def forward(self, x):
  22. x = self.conv1(x)
  23. x = self.bn1(x)
  24. x = F.relu(x) + 1
  25. return x
  26. class MyModule(Module):
  27. def __init__(self):
  28. super(MyModule, self).__init__()
  29. self.block0 = MyBlock(8, 4)
  30. self.block1 = MyBlock(4, 2)
  31. def forward(self, x):
  32. x = self.block0(x)
  33. x = self.block1(x)
  34. return x
  35. def test_jit_trace():
  36. module = MyModule()
  37. module.eval()
  38. x = F.ones((1, 8, 14, 14))
  39. expect = module(x)
  40. traced_module = trace_module(module, x)
  41. func = trace(traced_module, capture_as_const=True)
  42. np.testing.assert_array_equal(func(x), expect)
  43. model = io.BytesIO()
  44. func.dump(model)
  45. model.seek(0)
  46. infer_cg = cgtools.GraphInference(model)
  47. np.testing.assert_allclose(
  48. list(infer_cg.run(x.numpy()).values())[0], expect, atol=1e-6
  49. )