Browse Source

add code

pull/90/head
huxiaoman 4 years ago
parent
commit
c2bb60507e
8 changed files with 1 additions and 37 deletions
  1. +0
    -0
      mindspore-jina/MindsporeLeNet/Dockerfile
  2. +0
    -0
      mindspore-jina/MindsporeLeNet/config.yml
  3. +0
    -0
      mindspore-jina/MindsporeLeNet/lenet/.keep
  4. +0
    -0
      mindspore-jina/MindsporeLeNet/manifest.yml
  5. +1
    -0
      mindspore-jina/MindsporeLeNet/requirements.txt
  6. +0
    -0
      mindspore-jina/MindsporeLeNet/tests/__init__.py
  7. +0
    -0
      mindspore-jina/MindsporeLeNet/tests/test_mindsporelenet.py
  8. +0
    -37
      mindspore-jina/__init__.py

mindspore-jina/Dockerfile → mindspore-jina/MindsporeLeNet/Dockerfile View File


mindspore-jina/config.yml → mindspore-jina/MindsporeLeNet/config.yml View File


mindspore-jina/lenet/.keep → mindspore-jina/MindsporeLeNet/lenet/.keep View File


mindspore-jina/manifest.yml → mindspore-jina/MindsporeLeNet/manifest.yml View File


+ 1
- 0
mindspore-jina/MindsporeLeNet/requirements.txt View File

@@ -0,0 +1 @@
jina

mindspore-jina/tests/__init__.py → mindspore-jina/MindsporeLeNet/tests/__init__.py View File


mindspore-jina/tests/test_mindsporelenet.py → mindspore-jina/MindsporeLeNet/tests/test_mindsporelenet.py View File


+ 0
- 37
mindspore-jina/__init__.py View File

@@ -1,37 +0,0 @@
import numpy as np
from jina.executors.encoders.frameworks import BaseMindsporeEncoder


class MindsporeLeNet(BaseMindsporeEncoder):
"""
:class:`MindsporeLeNet` Encoding image into vectors using mindspore.
"""

def encode(self, data, *args, **kwargs):
# data is B x D, where D = 28 * 28
# LeNet only accepts BCHW format where H=W=32
# hence we need to do some simple transform
from mindspore import Tensor

data = np.pad(data.reshape([-1, 1, 28, 28]),
[(0, 0), (0, 0), (0, 4), (0, 4)]).astype('float32')
return self.model(Tensor(data)).asnumpy()

def get_cell(self):
from .lenet.src.lenet import LeNet5
class LeNet5Embed(LeNet5):
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
return x

return LeNet5Embed()

Loading…
Cancel
Save