You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

evaluation_model_parallelism.py 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. #!/usr/bin/env python3
  2. """
  3. Purpose: test whether a model has good parallelism. A model with good
  4. parallelism gets a large performance improvement when run with more threads.
  5. """
  6. import argparse
  7. import logging
  8. import os
  9. import re
  10. import subprocess
  11. # test device
  12. device = {
  13. "name": "hwmt40p",
  14. "login_name": "hwmt40p-K9000-maliG78",
  15. "ip": "box86.br.megvii-inc.com",
  16. "port": 2200,
  17. "thread_number": 3,
  18. }
  19. class SshConnector:
  20. """imp ssh control master connector"""
  21. ip = None
  22. port = None
  23. login_name = None
  24. def setup(self, login_name, ip, port):
  25. self.ip = ip
  26. self.login_name = login_name
  27. self.port = port
  28. def copy(self, src_list, dst_dir):
  29. assert isinstance(src_list, list), "code issue happened!!"
  30. assert isinstance(dst_dir, str), "code issue happened!!"
  31. for src in src_list:
  32. cmd = 'rsync --progress -a -e "ssh -p {}" {} {}@{}:{}'.format(
  33. self.port, src, self.login_name, self.ip, dst_dir
  34. )
  35. logging.debug("ssh run cmd: {}".format(cmd))
  36. subprocess.check_call(cmd, shell=True)
  37. def cmd(self, cmd):
  38. output = ""
  39. assert isinstance(cmd, list), "code issue happened!!"
  40. for sub_cmd in cmd:
  41. p_cmd = 'ssh -p {} {}@{} "{}" '.format(
  42. self.port, self.login_name, self.ip, sub_cmd
  43. )
  44. logging.debug("ssh run cmd: {}".format(p_cmd))
  45. output = output + subprocess.check_output(p_cmd, shell=True).decode("utf-8")
  46. return output
  47. def get_finally_bench_resulut_from_log(raw_log) -> float:
  48. # raw_log --> avg_time=23.331ms -->23.331ms
  49. h = re.findall(r"avg_time=.*ms ", raw_log)[-1][9:]
  50. # to 23.331
  51. h = h[: h.find("ms")]
  52. # to float
  53. h = float(h)
  54. return h
  55. def main():
  56. parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
  57. parser.add_argument("--model_file", help="model file", required=True)
  58. parser.add_argument(
  59. "--load_and_run_file", help="path for load_and_run", required=True
  60. )
  61. args = parser.parse_args()
  62. # init device
  63. ssh = SshConnector()
  64. ssh.setup(device["login_name"], device["ip"], device["port"])
  65. # create test dir
  66. workspace = "model_parallelism_test"
  67. ssh.cmd(["mkdir -p {}".format(workspace)])
  68. # copy load_and_run_file
  69. ssh.copy([args.load_and_run_file], workspace)
  70. # call test
  71. model_file = args.model_file
  72. # copy model file
  73. ssh.copy([args.model_file], workspace)
  74. m = model_file.split('\\')[-1]
  75. # run single thread
  76. result = []
  77. thread_number = [1, 2, 4]
  78. for b in thread_number :
  79. cmd = []
  80. cmd1 = "cd {} && ./load_and_run {} -multithread {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format(
  81. workspace, m, b
  82. )
  83. cmd2 = "cd {} && ./load_and_run {} -multithread {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess ".format(
  84. workspace, m, b
  85. )
  86. cmd.append(cmd1)
  87. cmd.append(cmd2)
  88. raw_log = ssh.cmd(cmd)
  89. # logging.debug(raw_log)
  90. ret = get_finally_bench_resulut_from_log(raw_log)
  91. logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret))
  92. result.append(ret)
  93. thread_2 = result[0]/result[1]
  94. thread_4 = result[0]/result[2]
  95. if thread_2 > 1.6 or thread_4 > 3.0:
  96. print("model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4))
  97. else:
  98. print("model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4))
  99. if __name__ == "__main__":
  100. LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
  101. DATE_FORMAT = "%Y/%m/%d %H:%M:%S"
  102. logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT)
  103. main()