|
|
@@ -63,12 +63,16 @@ for e in range(100): |
|
|
|
for batch_idx, (data, target) in enumerate(train_loader): |
|
|
|
data, target = Variable(data), Variable(target) |
|
|
|
|
|
|
|
# inference, loss calculation |
|
|
|
out = model(data) |
|
|
|
loss = criterion(out, target) |
|
|
|
|
|
|
|
# backward, optimize |
|
|
|
optim.zero_grad() |
|
|
|
loss.backward() |
|
|
|
optim.step() |
|
|
|
|
|
|
|
# print loss, acc |
|
|
|
if batch_idx % 100 == 0: |
|
|
|
pred = out.data.max(1, keepdim=True)[1] |
|
|
|
c = float(pred.eq(target.data.view_as(pred)).cpu().sum() ) /out.size(0) |
|
|
|