You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_model_static.py 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. #!/usr/bin/env python3
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. """
  10. purpose: test whether a model contains dynamic operators. If it has no
  11. dynamic operators the model is static; otherwise the model is dynamic.
  12. """
  13. import argparse
  14. import logging
  15. import os
  16. import re
  17. import subprocess
  18. # test device
  19. device = {
  20. "name": "hwmt40p",
  21. "login_name": "hwmt40p-K9000-maliG78",
  22. "ip": "box86.br.megvii-inc.com",
  23. "port": 2200,
  24. "thread_number": 3,
  25. }
  26. class SshConnector:
  27. """imp ssh control master connector"""
  28. ip = None
  29. port = None
  30. login_name = None
  31. def setup(self, login_name, ip, port):
  32. self.ip = ip
  33. self.login_name = login_name
  34. self.port = port
  35. def copy(self, src_list, dst_dir):
  36. assert isinstance(src_list, list), "code issue happened!!"
  37. assert isinstance(dst_dir, str), "code issue happened!!"
  38. for src in src_list:
  39. cmd = 'rsync --progress -a -e "ssh -p {}" {} {}@{}:{}'.format(
  40. self.port, src, self.login_name, self.ip, dst_dir
  41. )
  42. logging.debug("ssh run cmd: {}".format(cmd))
  43. subprocess.check_call(cmd, shell=True)
  44. def cmd(self, cmd):
  45. assert isinstance(cmd, list), "code issue happened!!"
  46. try:
  47. for sub_cmd in cmd:
  48. p_cmd = 'ssh -p {} {}@{} "{}" '.format(
  49. self.port, self.login_name, self.ip, sub_cmd
  50. )
  51. logging.debug("ssh run cmd: {}".format(p_cmd))
  52. subprocess.check_call(p_cmd, shell=True)
  53. except:
  54. raise
  55. def main():
  56. parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
  57. parser.add_argument("--model_file", help="megengine model", required=True)
  58. parser.add_argument(
  59. "--load_and_run_file", help="path for load_and_run", required=True
  60. )
  61. args = parser.parse_args()
  62. assert os.path.isfile(
  63. args.model_file
  64. ), "invalid args for models_file, need a file for model"
  65. assert os.path.isfile(args.load_and_run_file), "invalid args for load_and_run_file"
  66. # init device
  67. ssh = SshConnector()
  68. ssh.setup(device["login_name"], device["ip"], device["port"])
  69. # create test dir
  70. workspace = "model_static_evaluation_workspace"
  71. ssh.cmd(["mkdir -p {}".format(workspace)])
  72. # copy load_and_run_file
  73. ssh.copy([args.load_and_run_file], workspace)
  74. model_file = args.model_file
  75. # copy model file
  76. ssh.copy([model_file], workspace)
  77. m = model_file.split('\\')[-1]
  78. # run single thread
  79. cmd = "cd {} && ./load_and_run {} --fast-run --record-comp-seq --iter 1 --warmup-iter 1".format(
  80. workspace, m
  81. )
  82. try:
  83. raw_log = ssh.cmd([cmd])
  84. except:
  85. print("model: {} is not static model, it has dynamic operator.".format(m))
  86. raise
  87. print("model: {} is static model.".format(m))
  88. if __name__ == "__main__":
  89. LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
  90. DATE_FORMAT = "%Y/%m/%d %H:%M:%S"
  91. logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT)
  92. main()