You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

model.py 1.7 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import os
  2. import argparse
  3. import logging
  4. import sys
  5. from collections import OrderedDict
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. import torch.optim as optim
  10. from torchvision import datasets, transforms
  11. from pytorch.mutables import LayerChoice, InputChoice
  12. from mutator import ClassicMutator
  13. import numpy as np
  14. class Net(nn.Module):
  15. def __init__(self, hidden_size):
  16. super(Net, self).__init__()
  17. # two options of conv1
  18. self.conv1 = LayerChoice(OrderedDict([
  19. ("conv5x5", nn.Conv2d(1, 20, 5, 1)),
  20. ("conv3x3", nn.Conv2d(1, 20, 3, 1))
  21. ]), key='conv1')
  22. # two options of mid_conv
  23. self.mid_conv = LayerChoice(OrderedDict([
  24. ("conv3x3",nn.Conv2d(20, 20, 3, 1, padding=1)),
  25. ("conv5x5",nn.Conv2d(20, 20, 5, 1, padding=2))
  26. ]), key='mid_conv')
  27. self.conv2 = nn.Conv2d(20, 50, 5, 1)
  28. self.fc1 = nn.Linear(4*4*50, hidden_size)
  29. self.fc2 = nn.Linear(hidden_size, 10)
  30. # skip connection over mid_conv
  31. self.input_switch = InputChoice(n_candidates=2,
  32. n_chosen=1,
  33. key='skip')
  34. def forward(self, x):
  35. x = F.relu(self.conv1(x))
  36. x = F.max_pool2d(x, 2, 2)
  37. old_x = x
  38. x = F.relu(self.mid_conv(x))
  39. zero_x = torch.zeros_like(old_x)
  40. skip_x = self.input_switch([zero_x, old_x])
  41. x = torch.add(x, skip_x)
  42. x = F.relu(self.conv2(x))
  43. x = F.max_pool2d(x, 2, 2)
  44. x = x.view(-1, 4*4*50)
  45. x = F.relu(self.fc1(x))
  46. x = self.fc2(x)
  47. return F.log_softmax(x, dim=1)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能