|
@@ -75,8 +75,8 @@ def main(): |
|
|
for step, minibatch in enumerate(train_dataset): |
|
|
for step, minibatch in enumerate(train_dataset): |
|
|
if step > 1000: |
|
|
if step > 1000: |
|
|
break |
|
|
break |
|
|
data = minibatch["data"] |
|
|
|
|
|
label = minibatch["label"] |
|
|
|
|
|
|
|
|
data = mge.tensor(minibatch["data"]) |
|
|
|
|
|
label = mge.tensor(minibatch["label"]) |
|
|
net.train() |
|
|
net.train() |
|
|
_, loss = train_fun(data, label) |
|
|
_, loss = train_fun(data, label) |
|
|
train_loss.append((step, loss.numpy())) |
|
|
train_loss.append((step, loss.numpy())) |
|
@@ -128,6 +128,11 @@ def main(): |
|
|
print("Dump model as {}".format(model_name)) |
|
|
print("Dump model as {}".format(model_name)) |
|
|
pred_fun.dump(model_name, arg_names=["data"]) |
|
|
pred_fun.dump(model_name, arg_names=["data"]) |
|
|
|
|
|
|
|
|
|
|
|
model_with_testcase_name = "xornet_with_testcase.mge" |
|
|
|
|
|
|
|
|
|
|
|
print("Dump model with testcase as {}".format(model_with_testcase_name)) |
|
|
|
|
|
pred_fun.dump(model_with_testcase_name, arg_names=["data"], input_data=["#rand(0.1, 0.8, 4, 2)"]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
main() |