From 6cab1dd78fa9cd915fb1ada4007d1c0cd904a39f Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 17 Sep 2020 19:35:55 +0800
Subject: [PATCH] refactor(mge/sdk): update xor-deploy

GitOrigin-RevId: 372c37cdc5116834b47344c16c2a443c6e7ebfdb
---
 sdk/xor-deploy/xornet.py | 75 +++++++++++++++++++++++-------------------------
 1 file changed, 36 insertions(+), 39 deletions(-)

diff --git a/sdk/xor-deploy/xornet.py b/sdk/xor-deploy/xornet.py
index a032ef56..04835f73 100644
--- a/sdk/xor-deploy/xornet.py
+++ b/sdk/xor-deploy/xornet.py
@@ -1,6 +1,7 @@
 import numpy as np
 
 import megengine as mge
+import megengine.autodiff as ad
 import megengine.functional as F
 import megengine.module as M
 import megengine.optimizer as optim
@@ -35,57 +36,54 @@ class XORNet(M.Module):
         return x
 
 
-@trace(symbolic=True)
-def train_fun(data, label, net=None, opt=None):
-    net.train()
-    pred = net(data)
-    loss = F.cross_entropy_with_softmax(pred, label)
-    opt.backward(loss)
-    return pred, loss
-
-
-@trace(symbolic=True)
-def val_fun(data, label, net=None):
-    net.eval()
-    pred = net(data)
-    loss = F.cross_entropy_with_softmax(pred, label)
-    return pred, loss
-
-
-@trace(symbolic=True)
-def pred_fun(data, net=None):
-    net.eval()
-    pred = net(data)
-    pred_normalized = F.softmax(pred)
-    return pred_normalized
-
-
 def main():
     if not mge.is_cuda_available():
         mge.set_default_device("cpux")
 
     net = XORNet()
+    gm = ad.GradManager().attach(net.parameters())
     opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
     batch_size = 64
     train_dataset = minibatch_generator(batch_size)
     val_dataset = minibatch_generator(batch_size)
 
-    data = mge.tensor()
-    label = mge.tensor(np.zeros((batch_size,)), dtype=np.int32)
+    def train_fun(data, label):
+        opt.clear_grad()
+        with gm:
+            pred = net(data)
+            loss = F.cross_entropy_with_softmax(pred, label)
+            gm.backward(loss)
+        opt.step()
+        return pred, loss
+
+    def val_fun(data, label):
+        pred = net(data)
+        loss = F.cross_entropy_with_softmax(pred, label)
+        return pred, loss
+
+    @trace(symbolic=True, capture_as_const=True)
+    def pred_fun(data):
+        pred = net(data)
+        pred_normalized = F.softmax(pred)
+        return pred_normalized
+
+    data = np.random.random((batch_size, 2)).astype(np.float32)
+    label = np.zeros((batch_size,)).astype(np.int32)
 
     train_loss = []
     val_loss = []
     for step, minibatch in enumerate(train_dataset):
         if step > 1000:
             break
-        data.set_value(minibatch["data"])
-        label.set_value(minibatch["label"])
-        opt.zero_grad()
-        _, loss = train_fun(data, label, net=net, opt=opt)
+        data = minibatch["data"]
+        label = minibatch["label"]
+        net.train()
+        _, loss = train_fun(data, label)
         train_loss.append((step, loss.numpy()))
         if step % 50 == 0:
             minibatch = next(val_dataset)
-            _, loss = val_fun(data, label, net=net)
+            net.eval()
+            _, loss = val_fun(data, label)
             loss = loss.numpy()[0]
             val_loss.append((step, loss))
             print("Step: {} loss={}".format(step, loss))
@@ -108,8 +106,10 @@ def main():
         ]
     )
 
-    data.set_value(test_data)
-    out = pred_fun(data, net=net)
+    # tracing only accepts tensor as input
+    data = mge.tensor(test_data, dtype=np.float32)
+    net.eval()
+    out = pred_fun(data)
     pred_output = out.numpy()
     pred_label = np.argmax(pred_output, 1)
 
@@ -125,11 +125,8 @@ def main():
 
     model_name = "xornet_deploy.mge"
 
-    if pred_fun.enabled:
-        print("Dump model as {}".format(model_name))
-        pred_fun.dump(model_name, arg_names=["data"])
-    else:
-        print("pred_fun must be run with trace enabled in order to dump model")
+    print("Dump model as {}".format(model_name))
+    pred_fun.dump(model_name, arg_names=["data"])
 
 
 if __name__ == "__main__":