Browse Source

Fix version problem of demo code

pull/10/MERGE
bushuhui 3 years ago
parent
commit
1fdaa0e772
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      demo_code/3_NN_FC_1.py

+ 2
- 2
demo_code/3_NN_FC_1.py View File

@@ -90,7 +90,7 @@ def train(epoch):
if batch_idx % 100 == 0: if batch_idx % 100 == 0:
print("Train epoch: %6d [%6d/%6d (%.0f %%)] \t Loss: %.6f" % ( print("Train epoch: %6d [%6d/%6d (%.0f %%)] \t Loss: %.6f" % (
epoch, batch_idx * len(data), len(train_loader.dataset), epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.data[0]) )
100. * batch_idx / len(train_loader), loss.item()) )




def test(): def test():
@@ -103,7 +103,7 @@ def test():
output = model(data) output = model(data)


# sum up batch loss # sum up batch loss
test_loss += criterion(output, target).data[0]
test_loss += criterion(output, target).item()


# get the index of the max # get the index of the max
pred = output.data.max(1, keepdim=True)[1] pred = output.data.max(1, keepdim=True)[1]


Loading…
Cancel
Save