diff --git a/sdk/c-opr-loaders/mace/dump_model.py b/sdk/c-opr-loaders/mace/dump_model.py index 19959722..cdc7afee 100644 --- a/sdk/c-opr-loaders/mace/dump_model.py +++ b/sdk/c-opr-loaders/mace/dump_model.py @@ -30,7 +30,7 @@ def main(): parser.add_argument("--input", help="mace model file") parser.add_argument("--param", help="mace param file") parser.add_argument( - "--output", help="converted model that can be fed to dump_with_testcase_mge.py" + "--output", help="converted mge model" ) parser.add_argument("--config", help="config file with yaml format") args = parser.parse_args() diff --git a/sdk/xor-deploy/xornet.py b/sdk/xor-deploy/xornet.py index 5608354f..57104205 100644 --- a/sdk/xor-deploy/xornet.py +++ b/sdk/xor-deploy/xornet.py @@ -75,8 +75,8 @@ def main(): for step, minibatch in enumerate(train_dataset): if step > 1000: break - data = minibatch["data"] - label = minibatch["label"] + data = mge.tensor(minibatch["data"]) + label = mge.tensor(minibatch["label"]) net.train() _, loss = train_fun(data, label) train_loss.append((step, loss.numpy())) @@ -128,6 +128,11 @@ def main(): print("Dump model as {}".format(model_name)) pred_fun.dump(model_name, arg_names=["data"]) + model_with_testcase_name = "xornet_with_testcase.mge" + + print("Dump model with testcase as {}".format(model_with_testcase_name)) + pred_fun.dump(model_with_testcase_name, arg_names=["data"], input_data=["#rand(0.1, 0.8, 4, 2)"]) + if __name__ == "__main__": main()