Browse Source

Improve demo code

savefigrue
bushuhui 3 years ago
parent
commit
5c96105979
1 changed files with 4 additions and 0 deletions
  1. +4
    -0
      demo_code/3_CNN_MNIST.py

+ 4
- 0
demo_code/3_CNN_MNIST.py View File

@@ -63,12 +63,16 @@ for e in range(100):
for batch_idx, (data, target) in enumerate(train_loader): for batch_idx, (data, target) in enumerate(train_loader):
data, target = Variable(data), Variable(target) data, target = Variable(data), Variable(target)


# inference, loss calculation
out = model(data) out = model(data)
loss = criterion(out, target) loss = criterion(out, target)
# backward, optimize
optim.zero_grad() optim.zero_grad()
loss.backward() loss.backward()
optim.step() optim.step()


# print loss, acc
if batch_idx % 100 == 0: if batch_idx % 100 == 0:
pred = out.data.max(1, keepdim=True)[1] pred = out.data.max(1, keepdim=True)[1]
c = float(pred.eq(target.data.view_as(pred)).cpu().sum() ) /out.size(0) c = float(pred.eq(target.data.view_as(pred)).cpu().sum() ) /out.size(0)


Loading…
Cancel
Save