You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

add_demo.py 1.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. #!/usr/bin/env python3
  2. import argparse
  3. import math
  4. import megengine.functional as F
  5. import megengine.hub as hub
  6. import megengine.module as M
  7. import numpy as np
  8. from megengine import jit, tensor, Parameter
  9. class Simple(M.Module):
  10. def __init__(self):
  11. super().__init__()
  12. self.a = Parameter([0,1,2], dtype=np.float32)
  13. def forward(self, x):
  14. x = x + self.a
  15. return x
  16. if __name__ == "__main__":
  17. parser = argparse.ArgumentParser(
  18. description="dump mge model for add_demo",
  19. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  20. )
  21. parser.add_argument(
  22. "--inputs",
  23. help="set the inputs data to get a model with testcase",
  24. default="",
  25. type=str,
  26. )
  27. parser.add_argument(
  28. "--dir",
  29. help="set the dir where the model to dump",
  30. default=".",
  31. type=str,
  32. )
  33. args = parser.parse_args()
  34. net = Simple()
  35. net.eval()
  36. @jit.trace(symbolic=True, capture_as_const=True)
  37. def fun(data):
  38. return net(data)
  39. data = tensor([3.0,4.0,5.0])
  40. fun(data)
  41. if args.inputs == "":
  42. fun.dump(
  43. args.dir + "/add_demo_f32_without_data.mge", arg_names=["data"],
  44. no_assert=True,
  45. )
  46. else:
  47. fun.dump(
  48. args.dir + "/add_demo_f32_with_data.mge", arg_names=["data"],
  49. input_data=[args.inputs], no_assert=True,
  50. )