You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_model_static.py 3.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. #!/usr/bin/env python3
  2. """
  3. purpose: use to test whether a model contain dynamic operator, if no dynamic
  4. operator the model is static, other wise the model is dynamic.
  5. """
  6. import argparse
  7. import logging
  8. import os
  9. import re
  10. import subprocess
  11. # test device
  12. device = {
  13. "name": "hwmt40p",
  14. "login_name": "hwmt40p-K9000-maliG78",
  15. "ip": "box86.br.megvii-inc.com",
  16. "port": 2200,
  17. "thread_number": 3,
  18. }
  19. class SshConnector:
  20. """imp ssh control master connector"""
  21. ip = None
  22. port = None
  23. login_name = None
  24. def setup(self, login_name, ip, port):
  25. self.ip = ip
  26. self.login_name = login_name
  27. self.port = port
  28. def copy(self, src_list, dst_dir):
  29. assert isinstance(src_list, list), "code issue happened!!"
  30. assert isinstance(dst_dir, str), "code issue happened!!"
  31. for src in src_list:
  32. cmd = 'rsync --progress -a -e "ssh -p {}" {} {}@{}:{}'.format(
  33. self.port, src, self.login_name, self.ip, dst_dir
  34. )
  35. logging.debug("ssh run cmd: {}".format(cmd))
  36. subprocess.check_call(cmd, shell=True)
  37. def cmd(self, cmd):
  38. assert isinstance(cmd, list), "code issue happened!!"
  39. try:
  40. for sub_cmd in cmd:
  41. p_cmd = 'ssh -p {} {}@{} "{}" '.format(
  42. self.port, self.login_name, self.ip, sub_cmd
  43. )
  44. logging.debug("ssh run cmd: {}".format(p_cmd))
  45. subprocess.check_call(p_cmd, shell=True)
  46. except:
  47. raise
  48. def main():
  49. parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
  50. parser.add_argument("--model_file", help="megengine model", required=True)
  51. parser.add_argument(
  52. "--load_and_run_file", help="path for load_and_run", required=True
  53. )
  54. args = parser.parse_args()
  55. assert os.path.isfile(
  56. args.model_file
  57. ), "invalid args for models_file, need a file for model"
  58. assert os.path.isfile(args.load_and_run_file), "invalid args for load_and_run_file"
  59. # init device
  60. ssh = SshConnector()
  61. ssh.setup(device["login_name"], device["ip"], device["port"])
  62. # create test dir
  63. workspace = "model_static_evaluation_workspace"
  64. ssh.cmd(["mkdir -p {}".format(workspace)])
  65. # copy load_and_run_file
  66. ssh.copy([args.load_and_run_file], workspace)
  67. model_file = args.model_file
  68. # copy model file
  69. ssh.copy([model_file], workspace)
  70. m = model_file.split('\\')[-1]
  71. # run single thread
  72. cmd = "cd {} && ./load_and_run {} --fast-run --record-comp-seq --iter 1 --warmup-iter 1".format(
  73. workspace, m
  74. )
  75. try:
  76. raw_log = ssh.cmd([cmd])
  77. except:
  78. print("model: {} is not static model, it has dynamic operator.".format(m))
  79. raise
  80. print("model: {} is static model.".format(m))
  81. if __name__ == "__main__":
  82. LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
  83. DATE_FORMAT = "%Y/%m/%d %H:%M:%S"
  84. logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT)
  85. main()