@@ -0,0 +1,33 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
network config setting, will be used in train.py | |||
""" | |||
from easydict import EasyDict as edict | |||
mnist_cfg = edict({ | |||
'num_classes': 10, | |||
'lr': 0.01, | |||
'momentum': 0.9, | |||
'epoch_size': 1, | |||
'batch_size': 32, | |||
'buffer_size': 1000, | |||
'image_height': 32, | |||
'image_width': 32, | |||
'save_checkpoint_steps': 1875, | |||
'keep_checkpoint_max': 10, | |||
'air_name': "lenet.air", | |||
}) |
@@ -0,0 +1,60 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
Produce the dataset | |||
""" | |||
import mindspore.dataset as ds | |||
import mindspore.dataset.vision.c_transforms as CV | |||
import mindspore.dataset.transforms.c_transforms as C | |||
from mindspore.dataset.vision import Inter | |||
from mindspore.common import dtype as mstype | |||
def create_dataset(data_path, batch_size=32, repeat_size=1, | |||
num_parallel_workers=1): | |||
""" | |||
create dataset for train or test | |||
""" | |||
# define dataset | |||
mnist_ds = ds.MnistDataset(data_path) | |||
resize_height, resize_width = 32, 32 | |||
rescale = 1.0 / 255.0 | |||
shift = 0.0 | |||
rescale_nml = 1 / 0.3081 | |||
shift_nml = -1 * 0.1307 / 0.3081 | |||
# define map operations | |||
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode | |||
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) | |||
rescale_op = CV.Rescale(rescale, shift) | |||
hwc2chw_op = CV.HWC2CHW() | |||
type_cast_op = C.TypeCast(mstype.int32) | |||
# apply map operations on images | |||
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers) | |||
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers) | |||
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers) | |||
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers) | |||
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers) | |||
# apply DatasetOps | |||
buffer_size = 10000 | |||
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script | |||
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True) | |||
mnist_ds = mnist_ds.repeat(repeat_size) | |||
return mnist_ds |
@@ -0,0 +1,61 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
"""LeNet.""" | |||
import mindspore.nn as nn | |||
from mindspore.common.initializer import Normal | |||
class LeNet5(nn.Cell): | |||
""" | |||
Lenet network | |||
Args: | |||
num_class (int): Number of classes. Default: 10. | |||
num_channel (int): Number of channels. Default: 1. | |||
Returns: | |||
Tensor, output tensor | |||
Examples: | |||
>>> LeNet(num_class=10) | |||
""" | |||
def __init__(self, num_class=10, num_channel=1, include_top=True): | |||
super(LeNet5, self).__init__() | |||
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') | |||
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') | |||
self.relu = nn.ReLU() | |||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||
self.include_top = include_top | |||
if self.include_top: | |||
self.flatten = nn.Flatten() | |||
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) | |||
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) | |||
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) | |||
def construct(self, x): | |||
x = self.conv1(x) | |||
x = self.relu(x) | |||
x = self.max_pool2d(x) | |||
x = self.conv2(x) | |||
x = self.relu(x) | |||
x = self.max_pool2d(x) | |||
if not self.include_top: | |||
return x | |||
x = self.flatten(x) | |||
x = self.relu(self.fc1(x)) | |||
x = self.relu(self.fc2(x)) | |||
x = self.fc3(x) | |||
return x |
@@ -0,0 +1,65 @@ | |||
# Copyright 2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
######################## train lenet example ######################## | |||
train lenet and get network model files(.ckpt) : | |||
python train.py --data_path /YourDataPath | |||
""" | |||
import os | |||
import argparse | |||
from src.config import mnist_cfg as cfg | |||
from src.dataset import create_dataset | |||
from src.lenet import LeNet5 | |||
import mindspore.nn as nn | |||
from mindspore import context | |||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | |||
from mindspore.train import Model | |||
from mindspore.nn.metrics import Accuracy | |||
from mindspore.common import set_seed | |||
parser = argparse.ArgumentParser(description='MindSpore Lenet Example') | |||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], | |||
help='device where the code will be implemented (default: Ascend)') | |||
parser.add_argument('--data_path', type=str, default="./Data", | |||
help='path where the dataset is saved') | |||
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ | |||
path where the trained ckpt file') | |||
args = parser.parse_args() | |||
set_seed(1) | |||
if __name__ == "__main__": | |||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size) | |||
if ds_train.get_dataset_size() == 0: | |||
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") | |||
network = LeNet5(cfg.num_classes) | |||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | |||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) | |||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, | |||
keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=args.ckpt_path, config=config_ck) | |||
if args.device_target != "Ascend": | |||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||
else: | |||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2") | |||
print("============== Starting Training ==============") | |||
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()]) |