You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

inference_service_manager.py 4.1 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. """
  2. Copyright 2020 Tianshu AI Platform. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. """
  13. import json
  14. import time
  15. from service.oneflow_inference_service import OneFlowInferenceService
  16. from service.tensorflow_inference_service import TensorflowInferenceService
  17. from service.pytorch_inference_service import PytorchInferenceService
  18. import service.common_inference_service as common_inference_service
  19. from logger import Logger
  20. from utils import file_utils
  21. from utils.find_class_in_file import FindClassInFile
  22. log = Logger().logger
  23. class InferenceServiceManager:
  24. def __init__(self, args):
  25. self.inference_service = None
  26. self.args = args
  27. self.model_name_service_map = {}
  28. def init(self):
  29. if self.args.model_config_file != "":
  30. with open(self.args.model_config_file) as data_file:
  31. model_config_file_dict = json.load(data_file)
  32. model_config_list = model_config_file_dict["model_config_list"]
  33. for model_config in model_config_list:
  34. model_name = model_config["model_name"]
  35. model_path = model_config["model_path"]
  36. self.args.model_name = model_name
  37. self.args.model_path = model_path
  38. model_platform = model_config.get("platform")
  39. if model_platform == "oneflow":
  40. self.inference_service = OneFlowInferenceService(self.args)
  41. elif model_platform == "tensorflow" or model_platform == "keras":
  42. self.inference_service = TensorflowInferenceService(self.args)
  43. elif model_platform == "pytorch":
  44. self.inference_service = PytorchInferenceService(self.args)
  45. self.model_name_service_map[model_name] = self.inference_service
  46. else:
  47. # Read from command-line parameter
  48. if self.args.use_script:
  49. # 使用自定义推理脚本
  50. find_class_in_file = FindClassInFile()
  51. cls = find_class_in_file.find(common_inference_service)
  52. self.inference_service = cls[1](self.args)
  53. else :
  54. # 使用默认推理脚本
  55. if self.args.platform == "oneflow":
  56. self.inference_service = OneFlowInferenceService(self.args)
  57. elif self.args.platform == "tensorflow" or self.args.platform == "keras":
  58. self.inference_service = TensorflowInferenceService(self.args)
  59. elif self.args.platform == "pytorch":
  60. self.inference_service = PytorchInferenceService(self.args)
  61. self.model_name_service_map[self.args.model_name] = self.inference_service
  62. def inference(self, model_name, data_list):
  63. """
  64. 在线服务推理方法
  65. """
  66. inferenceService = self.model_name_service_map[model_name]
  67. result = list()
  68. for data in data_list:
  69. output = inferenceService.inference(data)
  70. if len(data_list) == 1:
  71. return output
  72. else:
  73. result.append(output)
  74. return result
  75. def inference_and_save_json(self, model_name, json_path, data_list):
  76. """
  77. 批量服务推理方法
  78. """
  79. inferenceService = self.model_name_service_map[model_name]
  80. for data in data_list:
  81. result = inferenceService.inference(data)
  82. file_utils.writer_json_file(json_path, data['data_name'], result)
  83. time.sleep(1)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能