
evaluation_model_parallelism.py 4.3 kB

#!/usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""
purpose: test whether a model has good parallelism; a model with good parallelism
will show a large performance improvement when run with multiple threads.
"""
import argparse
import logging
import os
import re
import subprocess

# test device
device = {
    "name": "hwmt40p",
    "login_name": "hwmt40p-K9000-maliG78",
    "ip": "box86.br.megvii-inc.com",
    "port": 2200,
    "thread_number": 3,
}

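# NOTE: SshConnector below simply shells out to ssh/rsync, so password-less
# (key-based) access to the device above is assumed to be configured already.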
class SshConnector:
    """implement ssh control master connector"""

    ip = None
    port = None
    login_name = None

    def setup(self, login_name, ip, port):
        self.ip = ip
        self.login_name = login_name
        self.port = port

    def copy(self, src_list, dst_dir):
        assert isinstance(src_list, list), "code issue happened!!"
        assert isinstance(dst_dir, str), "code issue happened!!"
        for src in src_list:
            cmd = 'rsync --progress -a -e "ssh -p {}" {} {}@{}:{}'.format(
                self.port, src, self.login_name, self.ip, dst_dir
            )
            logging.debug("ssh run cmd: {}".format(cmd))
            subprocess.check_call(cmd, shell=True)

    def cmd(self, cmd):
        output = ""
        assert isinstance(cmd, list), "code issue happened!!"
        for sub_cmd in cmd:
            p_cmd = 'ssh -p {} {}@{} "{}" '.format(
                self.port, self.login_name, self.ip, sub_cmd
            )
            logging.debug("ssh run cmd: {}".format(p_cmd))
            output = output + subprocess.check_output(p_cmd, shell=True).decode("utf-8")
        return output

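# load_and_run prints lines containing "avg_time=<x>ms"; the last such value in
# the captured output is taken as the benchmark result.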
def get_final_bench_result_from_log(raw_log) -> float:
    # raw_log --> avg_time=23.331ms --> 23.331ms
    h = re.findall(r"avg_time=.*ms ", raw_log)[-1][9:]
    # to 23.331
    h = h[: h.find("ms")]
    # to float
    h = float(h)
    return h

def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--model_file", help="model file", required=True)
    parser.add_argument(
        "--load_and_run_file", help="path for load_and_run", required=True
    )
    args = parser.parse_args()

    # init device
    ssh = SshConnector()
    ssh.setup(device["login_name"], device["ip"], device["port"])
    # create test dir
    workspace = "model_parallelism_test"
    ssh.cmd(["mkdir -p {}".format(workspace)])
    # copy load_and_run_file
    ssh.copy([args.load_and_run_file], workspace)
    # copy model file
    model_file = args.model_file
    ssh.copy([args.model_file], workspace)
    m = os.path.basename(model_file)

    # run with 1, 2 and 4 threads
    result = []
    thread_number = [1, 2, 4]
    for b in thread_number:
        cmd = []
        # first command fills fastrun.cache via --fast-run (1 iteration only),
        # second command does the timed run (20 iterations) reusing that cache
        cmd1 = "cd {} && ./load_and_run {} -multithread {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format(
            workspace, m, b
        )
        cmd2 = "cd {} && ./load_and_run {} -multithread {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess ".format(
            workspace, m, b
        )
        cmd.append(cmd1)
        cmd.append(cmd2)
        raw_log = ssh.cmd(cmd)
        # logging.debug(raw_log)
        ret = get_final_bench_result_from_log(raw_log)
        logging.debug("model: {} with {} threads result is: {}".format(m, b, ret))
        result.append(ret)

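    # speedup relative to the single-thread run; the 1.6x (2 threads) and
    # 3.0x (4 threads) thresholds below are this script's heuristic for
    # calling parallelism "good"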
    thread_2 = result[0] / result[1]
    thread_4 = result[0] / result[2]
    if thread_2 > 1.6 or thread_4 > 3.0:
        print("model: {} has good parallelism. 2-thread speedup is {}, 4-thread speedup is {}".format(m, thread_2, thread_4))
    else:
        print("model: {} has bad parallelism. 2-thread speedup is {}, 4-thread speedup is {}".format(m, thread_2, thread_4))

if __name__ == "__main__":
    LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
    DATE_FORMAT = "%Y/%m/%d %H:%M:%S"
    logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT)
    main()