Browse Source

feat(mge/module): add module for extern-c-opr

GitOrigin-RevId: a2d9fa067a
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
7a0c7ef45c
1 changed files with 13 additions and 19 deletions
  1. +13
    -19
      sdk/c-opr-loaders/mace/dump_model.py

+ 13
- 19
sdk/c-opr-loaders/mace/dump_model.py View File

@@ -8,9 +8,10 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse import argparse


import megengine._internal as mgb
import numpy as np import numpy as np
import yaml import yaml
from megengine import jit
from megengine.module.external import ExternOprSubgraph




# "1,3,224,224" -> (1,3,224,224) # "1,3,224,224" -> (1,3,224,224)
@@ -89,26 +90,19 @@ def main():
+ raw_param + raw_param
) )


# cn not ensured
cn = mgb.comp_node("xpux")
cg = mgb.comp_graph()

inp = [
mgb.make_shared(
comp_node=cn,
comp_graph=cg,
shape=isizes[i],
name=input_names[i],
dtype=np.float32,
)
for i in range(len(isizes))
]
net = ExternOprSubgraph(wk_raw_content, "mace", osizes)
net.eval()


oup = mgb.opr.extern_c_opr_placeholder(
inp, osizes, dump_name="mace", dump_data=wk_raw_content,
)
@jit.trace(symbolic=True)
def inference(inputs):
return net(inputs)

inputs = [
np.random.random(isizes[i]).astype(np.float32) for i in range(len(isizes))
]


mgb.serialize_comp_graph_to_file(args.output, oup)
inference.trace(inputs)
inference.dump(args.output)




if __name__ == "__main__": if __name__ == "__main__":


Loading…
Cancel
Save