"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import inspect
import io
import os

import torch
import torch.nn.functional as functional
from PIL import Image
from torchvision import transforms

from imagenet1000_clsidx_to_labels import clsidx_2_labels
from logger import Logger

log = Logger().logger


# Only one class may be defined in this file.
class CommonInferenceService:
    """Image-classification inference service.

    Loads a checkpoint once at construction time and exposes ``inference``,
    which returns the top-k labels and probabilities for a single image.
    """

    # Receive ``args`` in ``__init__`` and load the model here. The model path
    # is ``args.model_path`` and GPU usage is ``args.use_gpu``; the model
    # loading strategy itself may be customized.
    def __init__(self, args):
        """
        Args:
            args: namespace providing ``model_path`` (checkpoint file or
                directory), ``use_gpu`` (bool) and ``reshape_size``
                (sequence of height, width used to resize input images).
        """
        self.args = args
        self.model = self.load_model()

    def load_data(self, data_path):
        """Read an image file and return it as a normalized 1 x 3 x H x W tensor.

        The image is converted to RGB, resized to ``args.reshape_size`` and
        normalized. The tensor is moved to the GPU when ``args.use_gpu`` is set.
        """
        # Read the bytes inside a context manager so the handle is closed;
        # decoding from an in-memory buffer keeps the PIL image valid afterwards.
        with open(data_path, 'rb') as image_file:
            image = Image.open(io.BytesIO(image_file.read()))
        if image.mode != 'RGB':
            image = image.convert("RGB")
        height, width = self.args.reshape_size[0], self.args.reshape_size[1]
        preprocess = transforms.Compose([
            transforms.Resize((height, width)),
            transforms.ToTensor(),
            # Commonly used ImageNet per-channel mean / std.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        image = preprocess(image)[None]  # add a leading batch dimension
        if self.args.use_gpu:
            image = image.cuda()
        return image

    def load_model(self):
        """Load the checkpoint at ``args.model_path`` and return the model in eval mode.

        ``model_path`` may be a checkpoint file or a directory. For a directory,
        the last regular file in sorted name order is loaded, so the choice is
        deterministic. The checkpoint must be a mapping holding the model object
        under ``"model"`` and its weights under ``"state_dict"``. It is also
        kept on ``self.checkpoint``.

        Raises:
            FileNotFoundError: if ``model_path`` is a directory with no files.
        """
        model_path = self.args.model_path
        if os.path.isfile(model_path):
            checkpoint_file = model_path
        else:
            candidates = sorted(
                name for name in os.listdir(model_path)
                if os.path.isfile(os.path.join(model_path, name))
            )
            if not candidates:
                raise FileNotFoundError("no checkpoint file found in " + model_path)
            # os.path.join is correct whether or not model_path ends with a separator.
            checkpoint_file = os.path.join(model_path, candidates[-1])

        # Keep tensors on CPU when the GPU is not used, so checkpoints saved
        # from a GPU still load on CPU-only hosts (None = torch's default).
        load_kwargs = {"map_location": None if self.args.use_gpu else "cpu"}
        # The checkpoint pickles a whole model object. torch >= 2.6 defaults to
        # weights_only=True, which refuses such objects, so opt out explicitly
        # where the argument exists. Only load checkpoints from trusted sources.
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        self.checkpoint = torch.load(checkpoint_file, **load_kwargs)

        model = self.checkpoint["model"]
        model.load_state_dict(self.checkpoint['state_dict'])
        # Inference only: freeze all parameters.
        for parameter in model.parameters():
            parameter.requires_grad = False
        if self.args.use_gpu:
            model.cuda()
        model.eval()
        return model

    # The method name ``inference`` is fixed; do not rename it.
    def inference(self, data, top_k=5):
        """Classify one image and return its top-k predictions.

        Args:
            data: dict with ``data_name`` (echoed back in the result) and
                ``data_path`` (path of the image file to classify).
            top_k: number of predictions to return; capped at the number of
                classes the model outputs.

        Returns:
            ``{"data_name": ..., "predictions": [{"label": ...,
            "probability": "0.123"}, ...]}`` with predictions ordered by
            descending probability.
        """
        result = {"data_name": data['data_name']}
        log.info("===============> start load " + data['data_name'] + " <===============")
        image = self.load_data(data['data_path'])
        with torch.no_grad():
            probabilities = functional.softmax(self.model(image), dim=1)
            # topk raises if k exceeds the number of classes.
            k = min(top_k, probabilities.size(1))
            top_probs, top_labels = torch.topk(probabilities, k=k, dim=1)
        # Only the first (and only) item of the batch is reported.
        result['predictions'] = [
            {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))}
            for prob, label in zip(top_probs[0], top_labels[0])
        ]
        return result


# Optional: for local debugging only.
if __name__ == "__main__":
    import argparse

    def str2bool(value):
        """Parse a command-line boolean (``type=bool`` would treat "False" as True)."""
        lowered = value.strip().lower()
        if lowered in ("1", "true", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)

    parser = argparse.ArgumentParser(description='dubhe serving')
    parser.add_argument('--model_path', type=str, default='./res4serving.pth', help="model path")
    parser.add_argument('--use_gpu', type=str2bool, default=True, help="use gpu or not")
    parser.add_argument('--reshape_size', type=int, nargs=2, default=[224, 224],
                        metavar=('HEIGHT', 'WIDTH'), help="resize input images to HEIGHT WIDTH")
    args = parser.parse_args()
    server = CommonInferenceService(args)
    image_path = "./cat.jpg"
    image = {"data_name": "cat.jpg", "data_path": image_path}
    result = server.inference(image)
    print(result)