"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import io
import os
from io import BytesIO

import requests
import torch
import torch.nn.functional as functional
from PIL import Image
from torchvision import transforms

from imagenet1000_clsidx_to_labels import clsidx_2_labels
from logger import Logger
from service.abstract_inference_service import AbstractInferenceService

log = Logger().logger


class PytorchInferenceService(AbstractInferenceService):
    """Inference service for the PyTorch framework.

    Loads a pickled model plus its ``state_dict`` from ``args.model_path``
    and classifies a single image per ``inference`` call. The top-k class
    indices are mapped to names through ``clsidx_2_labels``.
    """

    def __init__(self, args):
        """Store the runtime arguments and load the model eagerly.

        Args:
            args: Namespace-like object. This class reads ``model_name``,
                ``model_path``, ``model_structure``, ``use_gpu`` and
                ``reshape_size`` from it.
        """
        super().__init__()
        self.args = args
        self.model_name = args.model_name
        self.model_path = args.model_path
        # Declared before load_model(), which assigns it.
        self.checkpoint = None
        self.model = self.load_model()
        # load_model() leaves the raw checkpoint on self.checkpoint. The
        # weights have already been copied into self.model, so drop the
        # reference: the duplicate state_dict can then be garbage-collected.
        # This also preserves the previous behaviour (checkpoint is None
        # after construction).
        self.checkpoint = None

    def load_image(self, image_path, *, timeout=30):
        """Read an image from a URL or local path and turn it into a model input.

        Args:
            image_path: ``http://`` / ``https://`` URL or a local file path.
            timeout: Seconds to wait for the HTTP server (only used for URLs).
                Without it a stalled server would block the service forever.

        Returns:
            A float tensor of shape ``(1, 3, H, W)``, where ``(H, W)`` is
            ``args.reshape_size``. It is moved to CUDA when ``args.use_gpu``
            is set.

        Raises:
            requests.HTTPError: if the URL responds with a non-2xx status.
            OSError: if the local file cannot be read or decoded.
        """
        # Match the scheme explicitly so a local path such as
        # "http_images/a.jpg" is not mistaken for a URL.
        if image_path.startswith(("http://", "https://")):
            response = requests.get(image_path, timeout=timeout)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        else:
            # Read fully into memory, then close the handle right away
            # (Image.open decodes lazily).
            with open(image_path, 'rb') as image_file:
                image = Image.open(io.BytesIO(image_file.read()))

        # Grayscale / RGBA / palette images become 3-channel RGB.
        if image.mode != 'RGB':
            image = image.convert("RGB")

        preprocess = transforms.Compose([
            transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1])),
            transforms.ToTensor(),
            # Standard ImageNet per-channel RGB mean / std.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # [None] prepends a batch dimension of size 1.
        image = preprocess(image)[None]
        if self.args.use_gpu:
            image = image.cuda()
        log.info("===============> load image success <===============")
        return image

    @staticmethod
    def _resolve_checkpoint_file(model_path):
        """Return the single checkpoint file to load for *model_path*.

        A file path is returned unchanged. For a directory, the last regular
        file in sorted name order is chosen, so the selection is
        deterministic rather than dependent on ``os.listdir`` order.

        Raises:
            FileNotFoundError: if the directory contains no regular files or
                *model_path* does not exist.
        """
        if os.path.isfile(model_path):
            return model_path
        candidates = sorted(
            name for name in os.listdir(model_path)
            if os.path.isfile(os.path.join(model_path, name))
        )
        if not candidates:
            raise FileNotFoundError("no checkpoint file found in directory: " + model_path)
        if len(candidates) > 1:
            log.info("multiple files found in " + model_path + ", loading " + candidates[-1])
        return os.path.join(model_path, candidates[-1])

    def load_model(self):
        """Load the model object and its weights from ``args.model_path``.

        The checkpoint must be a mapping. It holds the model object under the
        key ``args.model_structure`` and its weights under ``'state_dict'``.

        Returns:
            The model with all gradients disabled, moved to CUDA when
            ``args.use_gpu`` is set, and switched to eval mode.

        Raises:
            FileNotFoundError: if no checkpoint file can be found.
        """
        log.info("===============> start load pytorch model :" + self.args.model_path + " <===============")
        checkpoint_file = self._resolve_checkpoint_file(self.model_path)
        # Without remapping, a checkpoint saved on a GPU either fails to load
        # on a CPU-only host or leaves the model on the GPU while inputs stay
        # on the CPU.
        map_location = None if self.args.use_gpu else "cpu"
        # SECURITY: torch.load unpickles arbitrary objects (the checkpoint
        # carries a whole nn.Module), so model_path must point at trusted files.
        # NOTE(review): torch >= 2.6 defaults to weights_only=True, which
        # refuses to unpickle a full module; pass weights_only=False there.
        self.checkpoint = torch.load(checkpoint_file, map_location=map_location)

        model = self.checkpoint[self.args.model_structure]
        model.load_state_dict(self.checkpoint['state_dict'])
        for parameter in model.parameters():
            parameter.requires_grad = False
        if self.args.use_gpu:
            model.cuda()
        model.eval()
        log.info("===============> load pytorch model success <===============")
        return model

    def inference(self, image, top_k=5):
        """Classify one image and return its top-k labels.

        Args:
            image: Mapping with ``'data_name'`` (echoed back and logged) and
                ``'data_path'`` (URL or local path passed to ``load_image``).
            top_k: Number of predictions to return. It is clamped to the
                number of classes the model outputs.

        Returns:
            ``{"data_name": ..., "predictions": [{"label": str,
            "probability": "0.123"}, ...]}``, ordered by decreasing
            probability. The probability is a string with three decimals.
        """
        data = {"data_name": image['data_name']}
        log.info("===============> start load " + image['data_name'] + " <===============")
        tensor = self.load_image(image['data_path'])
        # Pure inference: skip autograd bookkeeping.
        with torch.no_grad():
            preds = functional.softmax(self.model(tensor), dim=1)
        # torch.topk raises if k exceeds the size of the class dimension.
        k = min(top_k, preds.size(1))
        probs, labels = torch.topk(preds, k=k, dim=1)
        # Row 0: load_image always produces a batch of exactly one image.
        data['predictions'] = [
            {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))}
            for prob, label in zip(probs[0], labels[0])
        ]
        return data