import numpy as np
from jina.executors.encoders.frameworks import BaseMindsporeEncoder


class MindsporeLeNet(BaseMindsporeEncoder):
    """
    :class:`MindsporeLeNet` encodes images into vectors using MindSpore.

    The network is a LeNet5 variant (see :meth:`get_cell`) whose forward pass
    stops after the second fully connected layer.
    """

    def encode(self, data, *args, **kwargs):
        """
        Encode a batch of flattened 28 x 28 images into embedding vectors.

        :param data: array of shape ``B x D`` with ``D = 28 * 28``
        :param args: extra positional arguments (unused here)
        :param kwargs: extra keyword arguments (unused here)
        :return: the result of ``self.model`` on the padded batch, converted
            with ``asnumpy()``
        """
        # Imported lazily, so mindspore is only needed once encoding runs.
        from mindspore import Tensor

        # LeNet only accepts BCHW input with H = W = 32, so the flat vectors
        # are first turned into single-channel 28 x 28 images ...
        images = data.reshape([-1, 1, 28, 28])
        # ... and then padded with 4 extra rows and 4 extra columns at the end
        # of the two spatial axes; batch and channel axes stay untouched.
        pad_width = [(0, 0), (0, 0), (0, 4), (0, 4)]
        batch = np.pad(images, pad_width).astype('float32')

        # NOTE(review): ``self.model`` is not defined in this class; presumably
        # the base encoder builds it from ``get_cell()`` - confirm.
        embeddings = self.model(Tensor(batch))
        return embeddings.asnumpy()

    def get_cell(self):
        """
        Build the network used for encoding.

        :return: a new ``LeNet5Embed`` instance, i.e. a ``LeNet5`` subclass
            with a shortened forward pass
        """
        from .lenet.src.lenet import LeNet5

        class LeNet5Embed(LeNet5):
            """``LeNet5`` whose forward pass ends after ``fc2`` and its ReLU."""

            def construct(self, x):
                # Two convolution -> ReLU -> max-pooling stages.
                x = self.max_pool2d(self.relu(self.conv1(x)))
                x = self.max_pool2d(self.relu(self.conv2(x)))
                # Flatten, then two fully connected layers, each with ReLU.
                x = self.flatten(x)
                x = self.relu(self.fc1(x))
                # The activated ``fc2`` output is returned as the embedding;
                # any further layers the parent defines are not applied.
                return self.relu(self.fc2(x))

        return LeNet5Embed()