You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

__init__.py 1.2 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637
  1. import numpy as np
  2. from jina.executors.encoders.frameworks import BaseMindsporeEncoder
  3. class MindsporeLeNet(BaseMindsporeEncoder):
  4. """
  5. :class:`MindsporeLeNet` Encoding image into vectors using mindspore.
  6. """
  7. def encode(self, data, *args, **kwargs):
  8. # data is B x D, where D = 28 * 28
  9. # LeNet only accepts BCHW format where H=W=32
  10. # hence we need to do some simple transform
  11. from mindspore import Tensor
  12. data = np.pad(data.reshape([-1, 1, 28, 28]),
  13. [(0, 0), (0, 0), (0, 4), (0, 4)]).astype('float32')
  14. return self.model(Tensor(data)).asnumpy()
  15. def get_cell(self):
  16. from .lenet.src.lenet import LeNet5
  17. class LeNet5Embed(LeNet5):
  18. def construct(self, x):
  19. x = self.conv1(x)
  20. x = self.relu(x)
  21. x = self.max_pool2d(x)
  22. x = self.conv2(x)
  23. x = self.relu(x)
  24. x = self.max_pool2d(x)
  25. x = self.flatten(x)
  26. x = self.fc1(x)
  27. x = self.relu(x)
  28. x = self.fc2(x)
  29. x = self.relu(x)
  30. return x
  31. return LeNet5Embed()