From 5c961059798d6d24a0ec63d546c7f1596239b093 Mon Sep 17 00:00:00 2001 From: bushuhui Date: Fri, 17 Dec 2021 10:42:05 +0800 Subject: [PATCH] Improve demo code --- demo_code/3_CNN_MNIST.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/demo_code/3_CNN_MNIST.py b/demo_code/3_CNN_MNIST.py index 73d7f26..4abbfdb 100644 --- a/demo_code/3_CNN_MNIST.py +++ b/demo_code/3_CNN_MNIST.py @@ -63,12 +63,16 @@ for e in range(100): for batch_idx, (data, target) in enumerate(train_loader): data, target = Variable(data), Variable(target) + # inference, loss calculation out = model(data) loss = criterion(out, target) + + # backward, optimize optim.zero_grad() loss.backward() optim.step() + # print loss, acc if batch_idx % 100 == 0: pred = out.data.max(1, keepdim=True)[1] c = float(pred.eq(target.data.view_as(pred)).cpu().sum() ) /out.size(0)