From 8d825246e8165cf4f57c34365e40d269873d6abb Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 1 Nov 2021 20:11:21 +0800 Subject: [PATCH] fix(sdk/load_and_run): remove docs of dump_with_testcase_with_mge.py GitOrigin-RevId: 4a1138cb55238a44cc988d4e1976bb488695ee55 --- sdk/c-opr-loaders/mace/dump_model.py | 2 +- sdk/xor-deploy/xornet.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/sdk/c-opr-loaders/mace/dump_model.py b/sdk/c-opr-loaders/mace/dump_model.py index 19959722..cdc7afee 100644 --- a/sdk/c-opr-loaders/mace/dump_model.py +++ b/sdk/c-opr-loaders/mace/dump_model.py @@ -30,7 +30,7 @@ def main(): parser.add_argument("--input", help="mace model file") parser.add_argument("--param", help="mace param file") parser.add_argument( - "--output", help="converted model that can be fed to dump_with_testcase_mge.py" + "--output", help="converted mge model" ) parser.add_argument("--config", help="config file with yaml format") args = parser.parse_args() diff --git a/sdk/xor-deploy/xornet.py b/sdk/xor-deploy/xornet.py index 5608354f..57104205 100644 --- a/sdk/xor-deploy/xornet.py +++ b/sdk/xor-deploy/xornet.py @@ -75,8 +75,8 @@ def main(): for step, minibatch in enumerate(train_dataset): if step > 1000: break - data = minibatch["data"] - label = minibatch["label"] + data = mge.tensor(minibatch["data"]) + label = mge.tensor(minibatch["label"]) net.train() _, loss = train_fun(data, label) train_loss.append((step, loss.numpy())) @@ -128,6 +128,11 @@ def main(): print("Dump model as {}".format(model_name)) pred_fun.dump(model_name, arg_names=["data"]) + model_with_testcase_name = "xornet_with_testcase.mge" + + print("Dump model with testcase as {}".format(model_with_testcase_name)) + pred_fun.dump(model_with_testcase_name, arg_names=["data"], input_data=["#rand(0.1, 0.8, 4, 2)"]) + if __name__ == "__main__": main()