You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conv_demo.py 1.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. #!/usr/bin/env python3
  import argparse
  import math
  import os

  import megengine.functional as F
  import megengine.module as M
  import numpy as np
  from megengine import jit, tensor
  class ConvNet(M.Module):
      """Minimal network made of a single bias-free convolution.

      The only layer, ``conv1``, is a 2-D convolution with 3 input channels,
      1 output channel and a kernel size of 3.
      """

      def __init__(self):
          """Initialise the base module and create the ``conv1`` layer."""
          super().__init__()
          conv_options = dict(in_channels=3, out_channels=1, kernel_size=3, bias=False)
          self.conv1 = M.Conv2d(**conv_options)

      def forward(self, input):
          """Apply ``conv1`` to ``input`` and return the result unchanged.

          NOTE(review): presumably ``input`` is a 4-D tensor with 3 channels
          in the channel dimension -- confirm against callers.
          """
          return self.conv1(input)
  if __name__ == "__main__":
      # Script entry point: build a ConvNet, trace it on one sample input and
      # dump the traced graph to <dir>/conv_demo_f32_without_data.mge.
      parser = argparse.ArgumentParser(
          # The description previously said "add_demo", a leftover from another
          # demo script; this file dumps the conv demo model.
          description="dump mge model for conv_demo",
          formatter_class=argparse.ArgumentDefaultsHelpFormatter,
      )
      parser.add_argument(
          "--dir",
          help="set the dir where the model to dump",
          default=".",
          type=str,
      )
      args = parser.parse_args()

      net = ConvNet()
      net.eval()

      @jit.trace(symbolic=True, capture_as_const=True)
      def fun(data):
          """Run ``net`` on ``data``; wrapped by ``jit.trace`` so it can be dumped."""
          return net(data)

      # 96 consecutive float32 values arranged as (2, 3, 4, 4); the second
      # dimension matches the 3 input channels ConvNet's convolution expects.
      inp = tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))

      # The traced function is called once before dump(); the result itself is
      # not used. NOTE(review): presumably the trace must be executed before it
      # can be dumped -- confirm against the MegEngine documentation.
      fun(inp)

      # os.path.join avoids a doubled separator when --dir ends with one.
      fun.dump(
          os.path.join(args.dir, "conv_demo_f32_without_data.mge"),
          arg_names=["data"],
          no_assert=True,
      )