GitOrigin-RevId: b092699dee
release-1.7
@@ -74,7 +74,6 @@ option(MGE_ENABLE_EXCEPTIONS "Build with exceptions" ON)
option(MGE_WITH_TEST "Enable test for MegEngine." OFF)
option(MGE_WITH_DISTRIBUTED "Build with distributed support" ON)
option(MGE_BUILD_IMPERATIVE_RT "Build _imperative_rt Python Module" ON)
option(MGE_BUILD_SDK "Build load_and_run" ON)
option(MGE_INFERENCE_ONLY "Build inference only library." OFF)
option(MGE_WITH_MKLDNN "Enable Intel MKL_DNN support" ON)
option(MGE_WITH_ROCM "Enable ROCM support" OFF)
@@ -542,6 +541,8 @@ if(MGE_WITH_TEST)
  include(cmake/gtest.cmake)
endif()
include(cmake/gflags.cmake)
if(MGE_BUILD_IMPERATIVE_RT)
  set(CMAKE_CXX_STANDARD 17)
endif()
@@ -1147,10 +1148,6 @@ endif()
add_subdirectory(src)
if(MGE_BUILD_SDK)
  add_subdirectory(sdk/load-and-run)
endif()
if(MGE_BUILD_IMPERATIVE_RT)
  add_subdirectory(imperative)
  message(STATUS "Enable imperative python wrapper runtime")
@@ -0,0 +1 @@
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/gflags ${CMAKE_CURRENT_BINARY_DIR}/gflags)
@@ -150,6 +150,9 @@ if(MGE_WITH_TEST)
  add_subdirectory(test)
endif()

# load_and_run
add_subdirectory(load_and_run)

# tools and example
add_executable(rc4_encryptor tools/rc4_encrypt.cpp)
@@ -0,0 +1,38 @@
load("//brain/megbrain/lite:flags.bzl", "pthread_select")

cc_library(
    name = "mgblar",
    copts = ["-std=c++14"],
    srcs = glob(["src/**/*.cpp"], exclude = ["src/main.cpp"]),
    hdrs = glob(["src/**/*.h"]),
    includes = ["src"],
    features = if_opt([
        "no_exceptions",
        "no_rtti",
    ]),
    defines = [
        "LITE_BUILD_WITH_MGE=1",
    ],
    deps = ["//brain/megbrain/lite:lite_static_test"] + pthread_select(
        ["@com_github_gflags_gflags//:gflags_nothreads"],
        ["//external:gflags"],
    ),
    alwayslink = 1,
    visibility = ["//visibility:public"],
)

cc_megvii_binary(
    name = "load_and_run",
    copts = ["-std=c++14"],
    srcs = ["src/main.cpp"],
    features = if_opt([
        "no_exceptions",
        "no_rtti",
    ]),
    internal_deps = [":mgblar"],
    visibility = ["//visibility:public"],
)
@@ -0,0 +1,29 @@
# build load_and_run for lite
include_directories(PUBLIC $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/lite/load_and_run/src>)

file(GLOB_RECURSE SOURCES ./*.cpp)
add_executable(load_and_run ${SOURCES})
target_link_libraries(load_and_run lite_static)
target_link_libraries(load_and_run megbrain)
target_link_libraries(load_and_run gflags)

if(LITE_BUILD_WITH_RKNPU)
  # rknn sdk 1.0.0 depends on libc++_shared; use the gold linker to avoid the
  # NEEDED shared-object symbol check
  target_link_options(load_and_run PRIVATE "-fuse-ld=gold")
endif()

if(MGE_WITH_ROCM)
  # FIXME: hip objects can not find cpp objects when linking only through lite_static
  target_link_libraries(load_and_run megdnn)
endif()

if(UNIX)
  if(APPLE OR ANDROID)
    target_link_libraries(load_and_run dl)
  else()
    target_link_libraries(load_and_run dl rt)
  endif()
endif()

install(TARGETS load_and_run EXPORT ${LITE_EXPORT_TARGETS} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
@@ -0,0 +1,404 @@
#!/usr/bin/env mdl
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megskull.graph import NodeFilter, FpropEnv
from megskull.opr.all import AssertEqual, DataProvider, BatchNormalization
from megskull.utils.logconf import get_logger
from meghair.utils import io
import megbrain as mgb

import argparse
import struct
import re
import os
import numpy as np
import cv2

logger = get_logger(__name__)

def auto_reformat_image(args, path, data, dst_shape):
    """reformat image to target shape

    :param data: image data as numpy array
    :param dst_shape: target shape
    """
    dim3_format = False  # required input format does not contain batch
    hwc_format = False  # required input format is NHWC

    if len(dst_shape) == 3:
        dst_shape = (1, ) + dst_shape
        dim3_format = True

    assert len(dst_shape) == 4, 'bad dst_shape: {}'.format(dst_shape)
    chl = dst_shape[1]
    if chl in [1, 3]:
        n, c, h, w = dst_shape
        dst_shape = (n, h, w, c)
    else:
        chl = dst_shape[3]
        assert chl in [1, 3], (
            'can not infer input format from shape: {}'.format(dst_shape))
        hwc_format = True

    # dst_shape has now been normalized to NHWC format
    if args.resize_input:
        h, w = dst_shape[1:3]
        data = cv2.resize(data, (w, h))
        logger.info('input {} resized to {}'.format(path, data.shape))

    if chl == 1:
        data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
        data = data[:, :, np.newaxis]

    assert data.ndim == 3
    data = data[np.newaxis]
    # data normalized to NHWC format

    if not hwc_format:
        data = np.transpose(data, (0, 3, 1, 2))

    if dim3_format:
        data = np.squeeze(data, 0)

    return data

def read_input_data(args, dst_shape, dtype, path, repeat):
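    """load one testcase input from ``path``

    ``path`` may be an image loadable by OpenCV, a pickled numpy.ndarray,
    or a random spec such as ``#rand(0, 255, 1, 3, 224, 224)``; see the
    ``--data`` option help for the full format.
    """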
    def check_shape_equal(dst_shape, data_shape):
        assert len(data_shape) == len(dst_shape), (
            'input/data shapes mismatch: {} vs {}'.format(
                dst_shape, data_shape))
        if data_shape[1:] != dst_shape[1:]:
            logger.warning('dst_shape is {}; data_shape is {}'.format(
                dst_shape, data_shape))

    if path.startswith('#'):
        assert not args.resize_input
        assert not args.input_transform
        spec = path
        m = re.match(
            r'^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$', spec)
        assert m, 'bad spec {}'.format(spec)

        rng_min = float(m.group(1))
        rng_max = float(m.group(2))
        if m.group(3):
            shape_str = m.group(3)
            try:
                shape = shape_str[1:].split(',')
                if shape[-1].strip() == '...':
                    shape = shape[:-1]
                    shape.extend(list(dst_shape[len(shape):]))
                data_shape = tuple(map(int, shape))
            except ValueError as e:
                raise ValueError('bad spec {}: {}'.format(spec, e.args))
        else:
            data_shape = dst_shape
        check_shape_equal(dst_shape, data_shape)
        return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype)

    # try to load an image
    data = cv2.imread(path, cv2.IMREAD_COLOR)
    if data is None:
        assert not args.resize_input
        data = io.load(path)
        assert isinstance(data, np.ndarray)
    else:
        # loading as an image succeeded, so we expect the input var to be in image format
        data = auto_reformat_image(args, path, data, dst_shape)

    data = np.repeat(data, repeat, axis=0)
    if repeat > 1:
        logger.info('repeat input for {} times, data shape is {}'.format(
            repeat, data.shape))

    check_shape_equal(dst_shape, data.shape)

    if args.input_transform:
        data = eval(args.input_transform, {'data': data, 'np': np})

    return data

def gen_one_testcase(args, inputs, spec):
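    """generate one testcase dict mapping input var name to its data

    ``spec`` has the form ``var0:file0;var1:file1...``; a single unnamed
    path, or a single ``#rand(...)`` spec, is broadcast to all input vars.
    """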
    paths = spec.split(';')
    if len(paths) != len(inputs):
        if len(paths) == 1 and paths[0].startswith('#'):
            paths = ['{}:{}'.format(name, paths[0]) for name in inputs.keys()]
    assert len(paths) == len(inputs), (
        'required inputs: {}; data paths: {}'.format(inputs.keys(), paths))
    if len(paths) == 1 and ':' not in paths[0]:
        paths[0] = next(iter(inputs.keys())) + ':' + paths[0]

    ret = {}
    for path in paths:
        var, path = path.split(':')
        if args.repeat:
            repeat = args.repeat
        else:
            repeat = 1
        ret[var] = read_input_data(args, inputs[var].imm_shape,
                                   inputs[var].dtype, path, repeat)
    return ret

def make_feeds(args):
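    """load the network, run it on the given input data to produce
    ground-truth outputs, and return a dict with ``outputs`` (output vars,
    with AssertEqual checks inserted unless --no-assert is given) and
    ``testcases`` (one input dict per testcase)."""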
    outputs = io.load_network(args.input).outputs
    if not args.no_assert:
        env = FpropEnv(verbose_fprop=False)
        # set flag so ExternCOprPlaceholder produces the expected output
        env.flags.user['extern_c_opr_eval'] = True
        func = env.comp_graph.compile(None, [mgb.copy_output(env.get_mgbvar(i))
                                             for i in outputs])

    def expect_name(var):
        return 'expect:{}'.format(var.name)

    nf = NodeFilter.make_all_deps(*outputs)
    inputs = {i.name: i for i in nf.data_provider()}

    if args.init_bn:
        for i in nf:
            if isinstance(i, BatchNormalization):
                if i._iter.get_value() == 0:
                    i._iter.set_value(1)
                    i._variance.set_value(np.ones(i._variance.shape))

    testcases = []
    np.set_printoptions(precision=2, threshold=4, suppress=True)

    data_list = []
    for item in args.data:
        if item.startswith('@'):
            with open(item[1:], 'r') as f:
                data_list.extend([line.rstrip() for line in f if line.rstrip() != ''])
        else:
            data_list.append(item)

    for inp_spec in data_list:
        cur_testcase = gen_one_testcase(args, inputs, inp_spec)
        assert len(cur_testcase) == len(inputs), (
            'required inputs: {}; given data: {}'.format(
                inputs.keys(), cur_testcase.keys()))

        if not args.no_assert:
            outputs_get = func(**cur_testcase)
            for var, val in zip(outputs, outputs_get):
                cur_testcase[expect_name(var)] = val
                logger.info(
                    'generate test groundtruth: var={} shape={} range=({}, {})'
                    ' mean={} var={}'.format(
                        var, val.shape, val.min(), val.max(),
                        np.mean(val), np.var(val)))

        testcases.append(cur_testcase)
        logger.info('add testcase: \n {}'.format(
            '\n '.join('{}: shape={} dtype={} range=({:.2f},{:.2f}) '
                       'mean={:.2f} sd={:.2f}'.format(
                           k, v.shape, v.dtype, v.min(), v.max(), np.mean(v),
                           np.std(v))
                       for k, v in sorted(cur_testcase.items()))))

    if not args.no_assert:
        def expect_shp(var):
            ret = var.partial_shape.determined_shape
            if ret:
                return ret
            return testcases[0][expect_name(var)].shape

        verbose = not args.silent
        outputs = [AssertEqual(DataProvider(expect_name(i), expect_shp(i),
                                            dtype=i.dtype,
                                            comp_node=i.comp_node),
                               i, verbose=verbose, maxerr=args.maxerr)
                   for i in outputs]

    return {'outputs': outputs, 'testcases': testcases}

def optimize_for_inference(args, outputs):
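    """map the --enable-* command line flags to mgb.optimize_for_inference
    keyword arguments and apply them when --optimize-for-inference is set"""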
    args_map = {
        'enable_io16xc32': 'f16_io_f32_comp',
        'enable_ioc16': 'f16_io_comp',
        'enable_hwcd4': 'use_nhwcd4',
        'enable_nchw4': 'use_nchw4',
        'enable_nchw88': 'use_nchw88',
        'enable_nchw44': 'use_nchw44',
        'enable_nchw44_dot': 'use_nchw44_dot',
        'enable_nchw32': 'use_nchw32',
        'enable_chwn4': 'use_chwn4',
        'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
        'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z',
        'enable_nchw64': 'use_nchw64',
        'enable_fuse_preprocess': 'fuse_preprocess',
    }
    kwargs = {}
    for k, v in args_map.items():
        if getattr(args, k):
            assert args.optimize_for_inference, (
                'optimize_for_inference should be set when {} is given'.format(
                    k))
            kwargs[v] = True

    if args.optimize_for_inference:
        return mgb.optimize_for_inference(outputs, **kwargs)

    return outputs

def main():
    parser = argparse.ArgumentParser(
        description='Pack computing graph, input values and expected output '
        'values into one file for checking correctness. README.md gives more '
        'details on the usage',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input', help='input file; see README for details')
    parser.add_argument('-o', '--output', help='output file', required=True)
    parser.add_argument('--init-bn', action='store_true',
                        help='initialize untrained batch-normalization, to '
                        'avoid NaN or Inf results')
    parser.add_argument(
        '-d', '--data', default=[], action='append',
        help='Given input test data when input file is a network, '
        'and current network output would be used as groundtruth. '
        'The format is var0:file0;var1:file1... to specify data files for '
        'input vars. It can also be #rand(min,max,shape...) for generating '
        'random input data, for example, #rand(0,255), '
        '#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means '
        'the remaining part of the original shape. '
        'If the shape is not specified, the shape of '
        'corresponding DataProvider in the network will be used. '
        'If there is only one input var, its name can be omitted. '
        'Each data file can either be an image which can be loaded by opencv, '
        'or a pickled numpy.ndarray. '
        'This option can be given multiple times to add multiple testcases. '
        ' *NOTE* '
        'If you start the data with the letter @, the rest should be a '
        'filename, and each line in the file should be a single datum in '
        'the format described above. '
    )
    parser.add_argument(
        '--repeat', type=int, default=1,
        help='Specify how many times the input image is repeated. '
        'Useful when running benchmark for batch size other than one. '
        'Has no effect on randomly generated input data.')
    parser.add_argument('--silent', action='store_true',
                        help='set verbose to False in AssertEqual opr')
    parser.add_argument('--optimize-for-inference', action='store_true',
                        help='enable optimization for inference')
    parser.add_argument('--no-assert', action='store_true',
                        help='do not insert AssertEqual opr to check result; '
                        'this option is useful for benchmarking')
    parser.add_argument('--maxerr', type=float, default=AssertEqual.maxerr,
                        help='max error for AssertEqual check during runtime')
    parser.add_argument('--resize-input', action='store_true',
                        help='resize input image to fit input var shape')
    parser.add_argument('--input-transform',
                        help='a python expression to transform the input data. '
                        'Example: data / np.std(data)')
    parser.add_argument('--discard-var-name', action='store_true',
                        help='discard variable and param names in the '
                        'generated output')
    parser.add_argument('--output-strip-info', action='store_true',
                        help='output code strip information')
    parser.add_argument('--enable-io16xc32', action='store_true',
                        help='transform the mode to float16 io float32 compute')
    parser.add_argument('--enable-ioc16', action='store_true',
                        help='transform the dtype of the model to float16 io '
                        'and compute')
    parser.add_argument('--enable-fuse-conv-bias-nonlinearity',
                        action='store_true',
                        help='fuse convolution bias and nonlinearity opr to a '
                        'conv_bias opr and compute')
    parser.add_argument('--enable-hwcd4', action='store_true',
                        help='transform the model format from NCHW to NHWCD4 '
                        'for inference; you may need to disable CUDA and set '
                        'MGB_USE_MEGDNN_DBG=2')
    parser.add_argument('--enable-nchw4', action='store_true',
                        help='transform the model format from NCHW to NCHW4 '
                        'for inference')
    parser.add_argument('--enable-nchw88', action='store_true',
                        help='transform the model format from NCHW to NCHW88 '
                        'for inference')
    parser.add_argument('--enable-nchw44', action='store_true',
                        help='transform the model format from NCHW to NCHW44 '
                        'for inference')
    parser.add_argument('--enable-nchw44-dot', action='store_true',
                        help='transform the model format from NCHW to NCHW44_DOT '
                        'for optimizing armv8.2 dot in inference')
    parser.add_argument('--enable-chwn4', action='store_true',
                        help='transform the model format to CHWN4 '
                        'for inference, mainly used for nvidia tensorcore')
    parser.add_argument('--enable-nchw32', action='store_true',
                        help='transform the model format from NCHW4 to NCHW32 '
                        'for inference on nvidia TensorCore')
    parser.add_argument('--enable-nchw64', action='store_true',
                        help='transform the model format from NCHW to NCHW64 '
                        'for inference on Nvidia GPU')
    parser.add_argument('--enable-fuse-conv-bias-with-z', action='store_true',
                        help='fuse conv_bias with z input for inference on '
                        'nvidia GPU (this optimization pass will result in mismatch '
                        'of the precision of output of training and inference)')
    parser.add_argument('--enable-fuse-preprocess', action='store_true',
                        help='fuse astype/pad_channel/dimshuffle and other oprs '
                        'from the h2d opr')
    args = parser.parse_args()

    if args.data:
        feeds = make_feeds(args)
    else:
        feeds = io.load(args.input)

    assert isinstance(feeds, dict) and feeds['testcases'], (
        'testcases can not be empty')

    env = FpropEnv(verbose_fprop=False)

    outputs = feeds['outputs']
    output_mgbvars = list(map(env.get_mgbvar, outputs))
    output_mgbvars = optimize_for_inference(args, output_mgbvars)

    inputs = sorted(((i.name, i.dtype) for i in
                     NodeFilter.make_all_deps(*outputs).data_provider()))
    if args.discard_var_name:
        sereg_kwargs = dict(keep_var_name=0, keep_param_name=False)
    else:
        sereg_kwargs = dict(keep_var_name=2, keep_param_name=True)

    with open(args.output, 'wb') as fout:
        fout.write(b'mgbtest0')
        fout.write(struct.pack('I', len(feeds['testcases'])))
    stat = mgb.serialize_comp_graph_to_file(
        args.output, output_mgbvars, append=True,
        output_strip_info=args.output_strip_info,
        **sereg_kwargs)
    logger.info('graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB'.
                format(stat.tot_bytes / 1024,
                       (stat.tot_bytes - stat.tensor_value_bytes) / 1024))

    for testcase in feeds['testcases']:
        assert isinstance(testcase, dict)
        cg = mgb.comp_graph()
        cn = mgb.comp_node('cpux')
        output_mgbvars = []
        for name, dtype in inputs:
            output_mgbvars.append(cg.make_shared(cn, value=testcase.pop(name),
                                                 dtype=dtype))
        assert not testcase, 'extra inputs provided in testcase: {}'.format(
            testcase.keys())
        mgb.serialize_comp_graph_to_file(
            args.output,
            output_mgbvars,
            append=True,
            output_strip_info=args.output_strip_info,
            append_json=True)

if __name__ == '__main__':
    main()
@@ -0,0 +1,535 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import os
import re
import struct

import cv2
import numpy as np

import megengine as mge
import megengine.core._imperative_rt as rt
import megengine.core.tensor.megbrain_graph as G
from megengine import tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.utils import comp_graph_tools as cgtools

logger = mge.get_logger(__name__)

def auto_reformat_image(args, path, data, dst_shape):
    """reformat image to target shape

    :param data: image data as numpy array
    :param dst_shape: target shape
    """
    dim3_format = False  # required input format does not contain batch
    hwc_format = False  # required input format is NHWC

    if not dst_shape:  # input tensor shape is not predefined
        if len(data.shape) == 2:
            chl = 1
            h = data.shape[0]
            w = data.shape[1]
        else:
            assert len(data.shape) == 3, "Input image must be of dimension 2 or 3"
            h, w, chl = data.shape
        dst_shape = (1, chl, h, w)

    if len(dst_shape) == 3:
        dst_shape = (1,) + dst_shape
        dim3_format = True

    assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape)
    chl = dst_shape[1]
    if chl in [1, 3]:
        n, c, h, w = dst_shape
        dst_shape = (n, h, w, c)
    else:
        chl = dst_shape[3]
        assert chl in [1, 3], "can not infer input format from shape: {}".format(
            dst_shape
        )
        hwc_format = True

    # dst_shape has now been normalized to NHWC format
    if args.resize_input:
        h, w = dst_shape[1:3]
        data = cv2.resize(data, (w, h))
        logger.info("input {} resized to {}".format(path, data.shape))

    if chl == 1:
        data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
        data = data[:, :, np.newaxis]

    assert data.ndim == 3
    data = data[np.newaxis]
    # data normalized to NHWC format

    if not hwc_format:
        data = np.transpose(data, (0, 3, 1, 2))

    if dim3_format:
        data = np.squeeze(data, 0)

    return data

def read_input_data(args, dst_shape, dtype, path, repeat):
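    """load one testcase input from ``path``

    ``path`` may be an image loadable by OpenCV, a file loadable by
    ``np.load``, or a random spec such as ``#rand(0, 255, 1, 3, 224, 224)``;
    see the ``--data`` option help for the full format.
    """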
    def check_shape_equal(dst_shape, data_shape):
        if len(dst_shape):
            assert len(data_shape) == len(
                dst_shape
            ), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape)
            if data_shape[1:] != dst_shape[1:]:
                logger.warning(
                    "dst_shape is {}; data_shape is {}".format(dst_shape, data_shape)
                )

    if path.startswith("#"):
        assert not args.resize_input
        assert not args.input_transform
        spec = path
        m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec)
        assert m, "bad spec {}".format(spec)

        rng_min = float(m.group(1))
        rng_max = float(m.group(2))
        if m.group(3):
            shape_str = m.group(3)
            try:
                shape = shape_str[1:].split(",")
                if shape[-1].strip() == "...":
                    shape = shape[:-1]
                    shape.extend(list(dst_shape[len(shape) :]))
                data_shape = tuple(map(int, shape))
            except ValueError as e:
                raise ValueError("bad spec {}: {}".format(spec, e.args))
        else:
            data_shape = dst_shape
        check_shape_equal(dst_shape, data_shape)
        return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype)

    # try to load an image
    data = cv2.imread(path, cv2.IMREAD_COLOR)
    if data is None:
        assert not args.resize_input
        data = np.load(path)
        assert isinstance(data, np.ndarray)
    else:
        # loading as an image succeeded, so we expect the input var to be in image format
        data = auto_reformat_image(args, path, data, dst_shape)

    data = np.repeat(data, repeat, axis=0)
    if repeat > 1:
        logger.info(
            "repeat input for {} times, data shape is {}".format(repeat, data.shape)
        )

    check_shape_equal(dst_shape, data.shape)

    if args.input_transform:
        data = eval(args.input_transform, {"data": data, "np": np})

    return data

def gen_one_testcase(args, inputs, spec):
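    """generate one testcase dict mapping input var name to its data

    ``spec`` has the form ``var0:file0;var1:file1...``; a single unnamed
    path, or a single ``#rand(...)`` spec, is broadcast to all input vars.
    """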
paths = spec.split(";") | |||||
if len(paths) != len(inputs): | |||||
if len(paths) == 1 and paths[0].startswith("#"): | |||||
paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()] | |||||
assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format( | |||||
inputs.keys(), paths | |||||
) | |||||
if len(paths) == 1 and ":" not in paths[0]: | |||||
paths[0] = next(iter(inputs.keys())) + ":" + paths[0] | |||||
ret = {} | |||||
for path in paths: | |||||
var, path = path.split(":") | |||||
if args.repeat: | |||||
repeat = args.repeat | |||||
else: | |||||
repeat = 1 | |||||
ret[var] = read_input_data( | |||||
args, inputs[var].shape, inputs[var].dtype, path, repeat | |||||
) | |||||
return ret | |||||
def make_feeds(args): | |||||
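    """load the dumped graph, run it on the given input data to produce
    ground-truth outputs, and return a dict with ``outputs`` (output vars,
    with assert_equal checks inserted unless --no-assert is given) and
    ``testcases`` (one input dict per testcase)."""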
    ret = G.load_graph(args.input)
    cg_rt, outputs = ret.graph, ret.output_vars_list
    inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
    inputs = {i.name: i for i in inputs}

    if not args.no_assert:
        replace_varmap = {}
        inp_map = {}
        # replace each Host2DeviceCopy var with an InputNode
        for name, var in inputs.items():
            inp = G.InputNode(
                device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt
            )
            replace_varmap[var] = inp.outputs[0]
            inp_map[name] = inp

        new = cgtools.replace_vars(outputs, replace_varmap)
        if isinstance(new, rt.VarNode):
            new = list(new)

        output_nodes = [G.OutputNode(var) for var in new]
        func = cg_rt.compile([node.outputs[0] for node in output_nodes])

        def make_dev_tensor(value, dtype=None, device=None):
            return tensor(value, dtype=dtype, device=device)._dev_tensor()

        def calculate(*args, **kwargs):
            output_val = []
            # set input values
            for name, var in inputs.items():
                val = kwargs.pop(name, None)
                assert val is not None, "missing input name {}".format(name)
                dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux")
                inp_map[name].set_value(dev_tensor)

            func.execute()

            for res in output_nodes:
                output_val.append(res.get_value().numpy())
            return output_val

    def expect_name(var):
        return "{}:expect".format(var.name)

    testcases = []
    np.set_printoptions(precision=2, threshold=4, suppress=True)

    data_list = []
    for item in args.data:
        if item.startswith("@"):
            with open(item[1:], "r") as f:
                data_list.extend([line.rstrip() for line in f if line.rstrip() != ""])
        else:
            data_list.append(item)

    for inp_spec in data_list:
        cur_testcase = gen_one_testcase(args, inputs, inp_spec)
        assert len(cur_testcase) == len(
            inputs
        ), "required inputs: {}; given data: {}".format(
            inputs.keys(), cur_testcase.keys()
        )

        if not args.no_assert:
            outputs_get = calculate(**cur_testcase)
            for var, val in zip(outputs, outputs_get):
                cur_testcase[expect_name(var)] = val
                logger.info(
                    "generate test groundtruth: var={} shape={} range=({}, {})"
                    " mean={} var={}".format(
                        var, val.shape, val.min(), val.max(), np.mean(val), np.var(val)
                    )
                )

        testcases.append(cur_testcase)
        logger.info(
            "add testcase: \n {}".format(
                "\n ".join(
                    "{}: shape={} dtype={} range=({:.2f},{:.2f}) "
                    "mean={:.2f} sd={:.2f}".format(
                        k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v)
                    )
                    for k, v in sorted(cur_testcase.items())
                )
            )
        )

    if not args.no_assert:

        def expect_shp(var):
            ret = var.shape
            if ret:
                return ret
            return testcases[0][expect_name(var)].shape

        def assert_equal(expect, real, **kwargs):
            op = builtin.AssertEqual(**kwargs)
            (res,) = G.apply_normal_varnode(op, expect, real)
            return res

        verbose = not args.silent

        outputs_new = []
        for i in outputs:
            device = rt.CompNode("xpux")
            dtype = i.dtype
            name = expect_name(i)
            shape = expect_shp(i)
            # make the expected output one more input of the model
            expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name)
            # insert an assert opr to compare the expected and real outputs
            outputs_new.append(
                assert_equal(
                    expect_get,
                    i,
                    verbose=verbose,
                    maxerr=args.maxerr,
                )
            )
            inputs[expect_name(i)] = expect_get
        outputs = outputs_new

    return {"outputs": outputs, "testcases": testcases}

def optimize_for_inference(args, outputs):
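    """pass the --enable-* command line flags through as keyword arguments
    to G.optimize_for_inference when --optimize-for-inference is set"""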
    args_list = [
        "enable_io16xc32",
        "enable_ioc16",
        "enable_hwcd4",
        "enable_nchw4",
        "enable_nchw88",
        "enable_nchw44",
        "enable_nchw44_dot",
        "enable_nchw32",
        "enable_chwn4",
        "enable_fuse_conv_bias_nonlinearity",
        "enable_fuse_conv_bias_with_z",
        "enable_fuse_preprocess",
    ]
    kwargs = {}
    for k in args_list:
        if getattr(args, k):
            kwargs[k] = True

    if args.optimize_for_inference:
        outputs = G.optimize_for_inference(outputs, **kwargs)

    return outputs

def main():
    parser = argparse.ArgumentParser(
        description="Pack computing graph, input values and expected output "
        "values into one file for checking correctness. README.md gives more "
        "details on the usage",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("input", help="MegEngine dumped model file")
    parser.add_argument("-o", "--output", help="output file", required=True)
    parser.add_argument(
        "-d",
        "--data",
        default=[],
        action="append",
        required=True,
        help="Given input test data when input file is a network, "
        "and current network output would be used as groundtruth. "
        "The format is var0:file0;var1:file1... to specify data files for "
        "input vars. It can also be #rand(min,max,shape...) for generating "
        "random input data, for example, #rand(0,255), "
        "#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means "
        "the remaining part of the original shape. "
        "If the shape is not specified, the shape of "
        "corresponding input tensors in the network will be used. "
        "If there is only one input var, its name can be omitted. "
        "Each data file can either be an image which can be loaded by opencv, "
        "or a pickled numpy.ndarray. "
        "This option can be given multiple times to add multiple testcases. "
        " *NOTE* "
        "If you start the data with the letter @, the rest should be a "
        "filename, and each line in the file should be a single datum in "
        "the format described above. ",
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=1,
        help="Specify how many times the input image is repeated. "
        "Useful when running benchmark for batch size other than one. "
        "Has no effect on randomly generated input data.",
    )
    parser.add_argument(
        "--silent",
        action="store_true",
        help="set verbose to False in assert_equal opr",
    )
    parser.add_argument(
        "--optimize-for-inference",
        action="store_true",
        help="enable optimization for inference",
    )
    parser.add_argument(
        "--no-assert",
        action="store_true",
        help="do not insert assert_equal opr to check result; "
        "this option is useful for benchmarking",
    )
    parser.add_argument(
        "--maxerr",
        type=float,
        default=1e-4,
        help="max error for assert_equal check during runtime",
    )
    parser.add_argument(
        "--resize-input",
        action="store_true",
        help="resize input image to fit input var shape",
    )
    parser.add_argument(
        "--input-transform",
        help="a python expression to transform the input data. "
        "Example: data / np.std(data)",
    )
    parser.add_argument(
        "--discard-var-name",
        action="store_true",
        help="discard variable and param names in the generated output",
    )
    parser.add_argument(
        "--output-strip-info", action="store_true", help="output code strip information"
    )
    parser.add_argument(
        "--enable-io16xc32",
        action="store_true",
        help="transform the mode to float16 io float32 compute",
    )
    parser.add_argument(
        "--enable-ioc16",
        action="store_true",
        help="transform the dtype of the model to float16 io and compute",
    )
    parser.add_argument(
        "--enable-fuse-conv-bias-nonlinearity",
        action="store_true",
        help="fuse convolution bias and nonlinearity opr to a "
        "conv_bias opr and compute",
    )
    parser.add_argument(
        "--enable-hwcd4",
        action="store_true",
        help="transform the model format from NCHW to NHWCD4 "
        "for inference; you may need to disable CUDA and set "
        "MGB_USE_MEGDNN_DBG=2",
    )
    parser.add_argument(
        "--enable-nchw4",
        action="store_true",
        help="transform the model format from NCHW to NCHW4 for inference",
    )
    parser.add_argument(
        "--enable-nchw88",
        action="store_true",
        help="transform the model format from NCHW to NCHW88 for inference",
    )
    parser.add_argument(
        "--enable-nchw44",
        action="store_true",
        help="transform the model format from NCHW to NCHW44 for inference",
    )
    parser.add_argument(
        "--enable-nchw44-dot",
        action="store_true",
        help="transform the model format from NCHW to NCHW44_DOT "
        "for optimizing armv8.2 dot in inference",
    )
    parser.add_argument(
        "--enable-nchw32",
        action="store_true",
        help="transform the model format from NCHW4 to NCHW32 "
        "for inference on nvidia TensorCore",
    )
    parser.add_argument(
        "--enable-chwn4",
        action="store_true",
        help="transform the model format to CHWN4 "
        "for inference, mainly used for nvidia tensorcore",
    )
    parser.add_argument(
        "--enable-fuse-conv-bias-with-z",
        action="store_true",
        help="fuse conv_bias with z input for inference on "
        "nvidia GPU (this optimization pass will result in mismatch "
        "of the precision of output of training and inference)",
    )
    parser.add_argument(
        "--enable-fuse-preprocess",
        action="store_true",
        help="fuse astype/pad_channel/dimshuffle and other oprs from the h2d opr",
    )
    args = parser.parse_args()

    feeds = make_feeds(args)

    assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty"

    output_mgbvars = feeds["outputs"]
    output_mgbvars = optimize_for_inference(args, output_mgbvars)

    inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy")
    inputs = sorted((i.name, i.dtype) for i in inputs)

    if args.discard_var_name:
        sereg_kwargs = dict(keep_var_name=0, keep_param_name=False)
    else:
        sereg_kwargs = dict(keep_var_name=2, keep_param_name=True)

    strip_info_file = args.output + ".json" if args.output_strip_info else None

    with open(args.output, "wb") as fout:
        fout.write(b"mgbtest0")
        fout.write(struct.pack("I", len(feeds["testcases"])))
        dump_content, stat = G.dump_graph(
            output_mgbvars,
            append_json=True,
            strip_info_file=strip_info_file,
            **sereg_kwargs,
        )
        fout.write(dump_content)

        logger.info(
            "graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format(
                stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024
            )
        )

    def make_dev_tensor(value, dtype=None, device=None):
        return tensor(value, dtype=dtype, device=device)._dev_tensor()

    for testcase in feeds["testcases"]:
        assert isinstance(testcase, dict)
        cg = G.Graph()
        output_mgbvars = []
        for name, dtype in inputs:
            output_mgbvars.append(
                cg.make_const(
                    make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux")
                )
            )
        assert not testcase, "extra inputs provided in testcase: {}".format(
            testcase.keys()
        )

        with open(args.output, "ab") as fout:
            dump_content, _ = G.dump_graph(
                output_mgbvars, strip_info_file=strip_info_file, append_json=True
            )
            fout.write(dump_content)

if __name__ == "__main__":
    main()
@@ -0,0 +1,74 @@
/**
 * \file lite/load_and_run/src/helpers/common.h
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */
#pragma once

#include <gflags/gflags.h>
#include <memory>

DECLARE_int32(thread);

namespace lar {
/*!
 * \brief: state of model running
 */
enum class RunStage {
    BEFORE_MODEL_LOAD = 0,
    AFTER_MODEL_LOAD = 1,
    BEFORE_OUTSPEC_SET = 2,
    //! used for dumping static memory information into an svg file
    AFTER_OUTSPEC_SET = 3,
    //! used for external c opr libraries
    MODEL_RUNNING = 4,
    //! used for the output dumper
    AFTER_RUNNING_WAIT = 5,
    //! used for external c opr libraries
    AFTER_RUNNING_ITER = 6,
    AFTER_MODEL_RUNNING = 7,
};

/*!
 * \brief: type of different model
 */
enum class ModelType {
    LITE_MODEL = 0,
    MEGDL_MODEL,
    UNKNOWN,
};

/*!
 * \brief: param for running model
 */
struct RuntimeParam {
    RunStage stage = RunStage::AFTER_MODEL_LOAD;
    size_t warmup_iter;             //! number of warm-up iterations before running the model
    size_t run_iter;                //! number of iterations for running the model
    size_t threads = FLAGS_thread;  //! number of threads for running the model (NOTE: it's
                                    //! different from the multithread device)
    size_t testcase_num = 1;        //! number of testcases for a model with testcases
};

/*!
 * \brief: layout type for running model optimization
 */
enum class OptLayoutType {
    NCHW4 = 1 << 0,
    CHWN4 = 1 << 1,
    NCHW44 = 1 << 2,
    NCHW88 = 1 << 3,
    NCHW32 = 1 << 4,
    NCHW64 = 1 << 5,
    NHWCD4 = 1 << 6,
    NCHW44_DOT = 1 << 7
};
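//! note: OptLayoutType values are single-bit flags, so several layouts can be
//! combined with bitwise OR (e.g. NCHW4 | NCHW32)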
}  // namespace lar

// vim: syntax=cpp.doxygen
@@ -0,0 +1,266 @@
/**
 * \file lite/load_and_run/src/helpers/data_parser.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "data_parser.h"
#include <sstream>
#include "json_loader.h"
#include "npy.h"

using namespace lar;

/*!
 * \brief feed data to the parser matching its format
 * \param path data file path or data string, optionally prefixed with "<name>:"
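 *        e.g. "data:image.ppm", "data:input.npy" or "data:[0,0,227,227]"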
 */
void DataParser::feed(const std::string& path) {
    std::string blob_name = "data", blob_string = path;
    size_t sep = path.find(":");
    if (sep != std::string::npos) {
        blob_name = path.substr(0, sep);
        blob_string = path.substr(sep + 1);
    }

    auto endWith = [blob_string](std::string suffix) -> bool {
        return blob_string.rfind(suffix) == (blob_string.length() - suffix.length());
    };

    if (endWith(".ppm") || endWith(".pgm")) {
        parse_image(blob_name, blob_string);
    } else if (endWith(".json")) {
        parse_json(blob_string);
    } else if (endWith(".npy")) {
        parse_npy(blob_name, blob_string);
    } else {
        parse_string(blob_name, blob_string);
    }
}

void DataParser::parse_json(const std::string& path) {
    mgb::JsonLoader json;
    std::shared_ptr<mgb::JsonLoader::Value> root = json.load(path.c_str());

    mgb_assert(root != nullptr, "parse json %s fail", path.c_str());

    // parse json to data map
    const std::string SHAPE = "shape", TYPE = "type", RAW = "raw";
    for (auto& item : root->objects()) {
        auto&& value = *item.second;
        auto&& shape = value[SHAPE];
        mgb_assert(shape->is_array());

        auto&& type = value[TYPE];
        mgb_assert(type->is_str());

        auto&& raw = value[RAW];
        mgb_assert(raw->is_array());

        megdnn::SmallVector<size_t> data_shape;
        for (auto&& shape_ptr : shape->array()) {
            data_shape.append({static_cast<size_t>(std::round(shape_ptr->number()))});
        }

        // get type
        const std::map<std::string, megdnn::DType> type_map = {
                {"float32", mgb::dtype::Float32()}, {"float", mgb::dtype::Float32()},
                {"int32", mgb::dtype::Int32()},     {"int", mgb::dtype::Int32()},
                {"int8", mgb::dtype::Int8()},       {"uint8", mgb::dtype::Uint8()}};

        const std::string& type_str = type->str();
        mgb_assert(
                type_map.find(type_str) != type_map.end(),
                "unknown json data type for --input");

        mgb::DType datatype = type_map.at(type_str);
        mgb::HostTensorND hv;
        hv.comp_node(mgb::CompNode::default_cpu(), true)
                .dtype(datatype)
                .resize(data_shape);
        mgb::dt_byte* raw_ptr = hv.raw_ptr();
        size_t elem_size = datatype.size();

        // get raw
        const size_t array_size = raw->len();
        for (size_t idx = 0; idx < array_size; ++idx) {
            double tmp = (*raw)[idx]->number();

            switch (datatype.enumv()) {
                case megdnn::DTypeEnum::Int32: {
                    int32_t ival = std::round(tmp);
                    memcpy(((char*)raw_ptr) + idx * elem_size, &ival, elem_size);
                } break;
                case megdnn::DTypeEnum::Uint8:
                case megdnn::DTypeEnum::Int8: {
                    int8_t cval = std::round(tmp);
                    memcpy(((char*)raw_ptr) + idx, &cval, sizeof(int8_t));
                } break;
                case megdnn::DTypeEnum::Float32: {
                    float fval = tmp;
                    memcpy(((char*)raw_ptr) + idx * elem_size, &fval, elem_size);
                } break;
                default:
                    break;
            }
        }

        inputs.insert(std::make_pair(item.first, std::move(hv)));
    }
}

void DataParser::parse_image(const std::string& name, const std::string& path) {
    // load binary ppm/pgm
    std::ifstream fin;
    fin.open(path, std::ifstream::binary | std::ifstream::in);
    mgb_assert(fin.is_open(), "open file %s failed for --input", path.c_str());

    size_t w = 0, h = 0, channel = 0;
    char buf[128] = {0};

    fin.getline(buf, 128);
    if ('5' == buf[1]) {
        channel = 1;
    } else if ('6' == buf[1]) {
        channel = 3;
    } else {
        mgb_assert(0, "not a valid binary ppm/pgm");
    }

    while (fin.getline(buf, 128)) {
        // skip comment lines
        if (buf[0] == '#') {
            continue;
        }
        break;
    }

    std::stringstream ss;
    ss << std::string(buf);
    ss >> w;
    ss >> h;

    mgb_assert(w > 0 and h > 0);

    mgb::HostTensorND hv;
    hv.comp_node(mgb::CompNode::default_cpu(), true)
            .dtype(mgb::dtype::Uint8())
            .resize({1, h, w, channel});

    fin.read((char*)(hv.raw_ptr()), hv.layout().total_nr_elems());
    fin.close();
    inputs.insert(std::make_pair(name, std::move(hv)));
}

void DataParser::parse_npy(const std::string& name, const std::string& path) {
    std::string type_str;
    std::vector<npy::ndarray_len_t> stl_shape;
    std::vector<int8_t> raw;
    npy::LoadArrayFromNumpy(path, type_str, stl_shape, raw);

    megdnn::SmallVector<size_t> shape;
    for (auto val : stl_shape) {
        shape.append({static_cast<size_t>(val)});
    }

    const std::map<std::string, megdnn::DType> type_map = {
            {"f4", mgb::dtype::Float32()}, {"i4", mgb::dtype::Int32()},
            {"i2", mgb::dtype::Int16()},   {"u2", mgb::dtype::Uint16()},
            {"i1", mgb::dtype::Int8()},    {"u1", mgb::dtype::Uint8()}};

    megdnn::DType hv_type;
    for (auto& item : type_map) {
        if (type_str.find(item.first) != std::string::npos) {
            hv_type = item.second;
            break;
        }
    }

    mgb::HostTensorND hv;
    hv.comp_node(mgb::CompNode::default_cpu(), true).dtype(hv_type).resize(shape);
    mgb::dt_byte* raw_ptr = hv.raw_ptr();
    memcpy(raw_ptr, raw.data(), raw.size());

    inputs.insert(std::make_pair(name, std::move(hv)));
}

void DataParser::parse_string(const std::string name, const std::string& str) {
    // infer data type: treat the literal as float if it contains a decimal point
    megdnn::DType data_type = mgb::dtype::Int32();
    if (str.find(".") != std::string::npos) {
        data_type = mgb::dtype::Float32();
    }

    // shape
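    // infer the tensor shape from the bracket nesting: build a tree with one
    // node per '[', derive each leading dimension from the child count at that
    // level, and use the remaining element count as the innermost dimension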
    size_t number_cnt = 0;

    std::shared_ptr<Brace> brace_root = std::make_shared<Brace>();
    std::shared_ptr<Brace> cur = brace_root;
    for (size_t i = 0; i < str.size(); ++i) {
        char c = str[i];
        if (c == '[') {
            std::shared_ptr<Brace> child = std::make_shared<Brace>();
            child->parent = cur;
            cur->chidren.emplace_back(child);
            cur = child;
        } else if (c == ']') {
            cur = cur->parent.lock();
        } else if (c == ',') {
            number_cnt++;
        }
    }
    ++number_cnt;

    mgb_assert(cur == brace_root, "braces not closed for --input");
    megdnn::SmallVector<size_t> shape;
    cur = brace_root;
    while (not cur->chidren.empty()) {
        shape.append({cur->chidren.size()});
        number_cnt /= cur->chidren.size();
        cur = cur->chidren[0];
    }
    mgb_assert(number_cnt > 0);
    shape.append({number_cnt});

    // data
    std::string json_arr;
    for (size_t i = 0; i < str.size(); ++i) {
        char c = str[i];
        if (c != '[' and c != ']') {
            json_arr += c;
        }
    }
    json_arr = "[" + json_arr + "]";

    // reuse the json parser to resolve the raw data
    mgb::JsonLoader json;
    std::shared_ptr<mgb::JsonLoader::Value> json_root =
            json.load(json_arr.data(), json_arr.size());
    mgb_assert(json_root != nullptr, "parse json fail in parse_string");

    mgb::HostTensorND hv;
    hv.comp_node(mgb::CompNode::default_cpu(), true).dtype(data_type).resize(shape);
    mgb::dt_byte* raw_ptr = hv.raw_ptr();

    const size_t array_len = json_root->len();
    const size_t elem_size = data_type.size();
    for (size_t idx = 0; idx < array_len; ++idx) {
        double tmp = json_root->array()[idx]->number();

        switch (data_type.enumv()) {
            case megdnn::DTypeEnum::Int32: {
                int32_t ival = std::round(tmp);
                memcpy(((char*)raw_ptr) + idx * elem_size, &ival, elem_size);
            } break;
            case megdnn::DTypeEnum::Float32: {
                float fval = tmp;
                memcpy(((char*)raw_ptr) + idx * elem_size, &fval, elem_size);
            } break;
            default:
                break;
        }
    }
    inputs.insert(std::make_pair(name, std::move(hv)));
}
@@ -0,0 +1,48 @@
/**
 * \file lite/load_and_run/src/helpers/data_parser.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once

#include <memory>
#include <unordered_map>
#include <vector>

#include "megbrain/opr/io.h"

namespace lar {
/*!
 * \brief data parser for --input
 * supports .json|.ppm|.pgm|.npy data and user-defined data strings
 * data string format: [0,0,227,227]
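 * e.g. parser.feed("data:input.npy"); or parser.feed("data:[0,0,227,227]");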
 */
struct DataParser {
    struct Brace {
        std::weak_ptr<Brace> parent;
        std::vector<std::shared_ptr<Brace>> chidren;
    };

    void feed(const std::string& path);

    std::unordered_map<std::string, mgb::HostTensorND> inputs;

private:
    //! parser for json data
    void parse_json(const std::string& path);

    //! parser for .ppm .pgm images
    void parse_image(const std::string& name, const std::string& path);

    //! parser for .npy data
    void parse_npy(const std::string& name, const std::string& path);

    //! parser for user-defined strings
    void parse_string(const std::string name, const std::string& str);
};
}  // namespace lar
@@ -0,0 +1,297 @@
/**
 * \file lite/load_and_run/src/helpers/json_loader.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "json_loader.h"

using namespace mgb;

template <typename T>
T* JsonLoader::Value::safe_cast() {
    T* ptr = (T*)(this);
    if (nullptr == ptr) {
        fprintf(stderr, "cast ptr is null\n");
    }
    return ptr;
}

std::unique_ptr<JsonLoader::Value>& JsonLoader::Value::operator[](
        const std::string& key) {
    mgb_assert(Type::OBJECT == m_type);
    auto t = safe_cast<JsonLoader::ObjectValue>();
    return t->m_obj.at(key);
}

std::unique_ptr<JsonLoader::Value>& JsonLoader::Value::operator[](const size_t index) {
    mgb_assert(Type::ARRAY == m_type);
    auto t = safe_cast<JsonLoader::ArrayValue>();
    return t->m_obj[index];
}

std::map<std::string, std::unique_ptr<JsonLoader::Value>>& JsonLoader::Value::
        objects() {
    mgb_assert(Type::OBJECT == m_type);
    auto t = safe_cast<JsonLoader::ObjectValue>();
    return t->m_obj;
}

size_t JsonLoader::Value::len() {
    if (Type::ARRAY == m_type) {
        auto t = safe_cast<JsonLoader::ArrayValue>();
        return t->m_obj.size();
    } else if (Type::OBJECT == m_type) {
        auto t = safe_cast<JsonLoader::ObjectValue>();
        return t->m_obj.size();
    }
    return 0;
}

megdnn::SmallVector<std::unique_ptr<JsonLoader::Value>>& JsonLoader::Value::array() {
    mgb_assert(Type::ARRAY == m_type);
    auto t = safe_cast<JsonLoader::ArrayValue>();
    return t->m_obj;
}

double JsonLoader::Value::number() {
    mgb_assert(Type::NUMBER == m_type);
    auto t = safe_cast<JsonLoader::NumberValue>();
    return t->value();
}

std::string JsonLoader::Value::str() {
    if (Type::STRING == m_type) {
        auto t = safe_cast<StringValue>();
        return t->value();
    }
    return std::string();
}

void JsonLoader::expect(char c) {
    mgb_assert(c == (*m_buf));
    m_buf++;
}

void JsonLoader::skip_whitespace() {
    const char* p = m_buf;
    while (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r') {
        ++p;
    }
    m_buf = p;
}

std::unique_ptr<JsonLoader::Value> JsonLoader::parse_object() {
    expect('{');
    skip_whitespace();

    std::unique_ptr<JsonLoader::Value> ret;
    JsonLoader::ObjectValue* pObject = new JsonLoader::ObjectValue();

    if ('}' == *m_buf) {
        m_buf = m_buf + 1;
        ret.reset((JsonLoader::Value*)(pObject));
        return ret;
    }

    while (true) {
        std::unique_ptr<JsonLoader::Value> key = parse_string();
        if (m_state != State::OK) {
            return ret;
        }

        skip_whitespace();
        if (':' != (*m_buf)) {
            m_state = State::MISS_COLON;
            return ret;
        }
        m_buf++;
        skip_whitespace();

        std::unique_ptr<JsonLoader::Value> pVal = parse_value();
        if (m_state != State::OK) {
            return ret;
        }
        // reject duplicate keys within one object
        if (pObject->m_obj.find(key->str()) != pObject->m_obj.end()) {
            m_state = State::KEY_NOT_UNIQUE;
            return ret;
        }

        pObject->m_obj.insert(std::make_pair(key->str(), std::move(pVal)));

        skip_whitespace();
        if (',' == (*m_buf)) {
            m_buf++;
            skip_whitespace();
        } else if ('}' == (*m_buf)) {
            m_buf++;
            break;
        } else {
            m_state = State::MISS_BRACE;
            break;
        }
    }
    ret.reset((JsonLoader::Value*)(pObject));
    return ret;
}

std::unique_ptr<JsonLoader::Value> JsonLoader::parse_array() {
    expect('[');
    skip_whitespace();

    std::unique_ptr<JsonLoader::Value> ret;
    JsonLoader::ArrayValue* pArray = new JsonLoader::ArrayValue();

    if (']' == *m_buf) {
        m_buf = m_buf + 1;
        ret.reset((JsonLoader::Value*)(pArray));
        return ret;
    }

    while (true) {
        std::unique_ptr<JsonLoader::Value> pVal = parse_value();
        if (m_state != State::OK) {
            mgb_assert(0, "parse value failed during parse array");
            return ret;
        }

        pArray->m_obj.emplace_back(pVal.get());
        pVal.release();

        skip_whitespace();
        if (',' == *m_buf) {
            m_buf++;
            skip_whitespace();
        } else if (']' == *m_buf) {
            m_buf++;
            break;
        } else {
            m_state = State::BAD_ARRAY;
            return ret;
        }
    }
    ret.reset((JsonLoader::Value*)(pArray));
    return ret;
}

std::unique_ptr<JsonLoader::Value> JsonLoader::parse_string() {
    expect('\"');

    std::unique_ptr<JsonLoader::Value> ret;
    JsonLoader::StringValue* pStr = new JsonLoader::StringValue();

    const char* p = m_buf;
    while (true) {
        if (*p == '\"') {
            p++;
            break;
        } else {
            pStr->m_value += (*p);
            p++;
        }
    }
    m_buf = p;
    ret.reset((JsonLoader::Value*)(pStr));
    return ret;
}
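// parse a JSON number of the form [-]int[.frac][(e|E)[+|-]exp]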
std::unique_ptr<JsonLoader::Value> JsonLoader::parse_number() { | |||||
const char* p = m_buf; | |||||
auto loop_digit = [this](const char*& p) { | |||||
if (not std::isdigit(*p)) { | |||||
m_state = State::BAD_DIGIT; | |||||
return; | |||||
} | |||||
while (std::isdigit(*p)) { | |||||
p++; | |||||
} | |||||
return; | |||||
}; | |||||
if (*p == '-') | |||||
p++; | |||||
if (*p == '0') | |||||
p++; | |||||
else { | |||||
loop_digit(p); | |||||
} | |||||
if (*p == '.') { | |||||
p++; | |||||
loop_digit(p); | |||||
} | |||||
if (*p == 'e' || *p == 'E') { | |||||
p++; | |||||
if (*p == '+' || *p == '-') | |||||
p++; | |||||
loop_digit(p); | |||||
} | |||||
JsonLoader::NumberValue* pNum = new JsonLoader::NumberValue(); | |||||
pNum->m_value = strtod(m_buf, nullptr); | |||||
m_buf = p; | |||||
std::unique_ptr<JsonLoader::Value> ret; | |||||
ret.reset((JsonLoader::Value*)(pNum)); | |||||
return ret; | |||||
} | |||||
std::unique_ptr<JsonLoader::Value> JsonLoader::parse_value() { | |||||
switch (*m_buf) { | |||||
case '[': | |||||
return parse_array(); | |||||
case '{': | |||||
return parse_object(); | |||||
case '\"': | |||||
return parse_string(); | |||||
case '\0': | |||||
m_state = State::BAD_TYPE; | |||||
break; | |||||
default: | |||||
return parse_number(); | |||||
} | |||||
return nullptr; | |||||
} | |||||
std::unique_ptr<JsonLoader::Value> JsonLoader::load( | |||||
const char* content, const size_t size) { | |||||
m_buf = content; | |||||
skip_whitespace(); | |||||
std::unique_ptr<JsonLoader::Value> value = parse_value(); | |||||
skip_whitespace(); | |||||
if (m_state != State::OK) { | |||||
return nullptr; | |||||
} | |||||
mgb_assert(size == static_cast<size_t>(m_buf - content)); | |||||
return value; | |||||
} | |||||
std::unique_ptr<JsonLoader::Value> JsonLoader::load(const char* path) { | |||||
std::unique_ptr<std::FILE, void (*)(std::FILE*)> fin( | |||||
std::fopen(path, "rb"), [](std::FILE* fp) { std::fclose(fp); }); | |||||
mgb_assert(fin.get(), "failed to open %s: %s", path, strerror(errno)); | |||||
std::fseek(fin.get(), 0, SEEK_END); | |||||
const size_t size = ftell(fin.get()); | |||||
std::fseek(fin.get(), 0, SEEK_SET); | |||||
//! use a free()-based deleter (the buffer comes from malloc) and reserve | |||||
//! one extra byte so the parser sees a NUL terminator at end of input | |||||
std::unique_ptr<char, void (*)(void*)> buf( | |||||
static_cast<char*>(malloc(size + 1)), free); | |||||
auto nr = std::fread(buf.get(), 1, size, fin.get()); | |||||
buf.get()[size] = '\0'; | |||||
mgb_assert(nr == size); | |||||
return load(buf.get(), size); | |||||
} |
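For reference, here is a minimal usage sketch of the loader above (hypothetical calling code, not part of this patch; the document literal and key names are made up). Note that parse_string does not handle escape sequences, so only plain strings round-trip:

#include <cstdio>
#include <cstring>
#include "json_loader.h"

void json_loader_example() {
    const char* text = "{\"shape\": [1, 3, 224, 224], \"name\": \"data\"}";
    mgb::JsonLoader loader;
    auto root = loader.load(text, strlen(text));
    mgb_assert(root && root->is_object(), "parse failed");
    // walk the "shape" array (key names are illustrative)
    auto& shape = (*root)["shape"];
    for (size_t i = 0; i < shape->len(); ++i) {
        printf("dim[%zu] = %g\n", i, (*shape)[i]->number());
    }
    printf("name = %s\n", (*root)["name"]->str().c_str());
}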
@@ -0,0 +1,183 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/helpers/json_loader.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <cctype> | |||||
#include <fstream> | |||||
#include <functional> | |||||
#include <iostream> | |||||
#include <map> | |||||
#include <memory> | |||||
#include "megbrain/common.h" | |||||
#include "megdnn/thin/small_vector.h" | |||||
namespace mgb { | |||||
/*! | |||||
* \brief JSON format data loader for --input | |||||
*/ | |||||
class JsonLoader { | |||||
public: | |||||
// base class for different value format | |||||
class Value { | |||||
protected: | |||||
enum struct Type : uint8_t { UNKNOWN, NUMBER, STRING, OBJECT, ARRAY }; | |||||
Type m_type; | |||||
public: | |||||
template <typename T> | |||||
T* safe_cast(); | |||||
Value() { m_type = Type::UNKNOWN; } | |||||
Value(Type type) : m_type(type) {} | |||||
virtual ~Value() {} | |||||
bool is_array() { return Type::ARRAY == m_type; } | |||||
bool is_object() { return Type::OBJECT == m_type; } | |||||
bool is_number() { return Type::NUMBER == m_type; } | |||||
bool is_str() { return Type::STRING == m_type; } | |||||
std::unique_ptr<Value>& operator[](const std::string& key); | |||||
std::unique_ptr<Value>& operator[](const size_t index); | |||||
std::map<std::string, std::unique_ptr<Value>>& objects(); | |||||
size_t len(); | |||||
megdnn::SmallVector<std::unique_ptr<Value>>& array(); | |||||
double number(); | |||||
std::string str(); | |||||
}; | |||||
void expect(char c); | |||||
void skip_whitespace(); | |||||
std::unique_ptr<Value> parse_object(); | |||||
std::unique_ptr<Value> parse_array(); | |||||
std::unique_ptr<Value> parse_string(); | |||||
std::unique_ptr<Value> parse_number(); | |||||
std::unique_ptr<Value> parse_value(); | |||||
enum struct State : uint8_t { | |||||
OK = 0, | |||||
BAD_TYPE, | |||||
BAD_DIGIT, | |||||
BAD_ARRAY, | |||||
MISS_COLON, | |||||
MISS_BRACE, | |||||
KEY_NOT_UNIQUE | |||||
}; | |||||
JsonLoader() { m_state = State::OK; } | |||||
std::unique_ptr<Value> load(const char* content, const size_t size); | |||||
std::unique_ptr<Value> load(const char* path); | |||||
class NumberValue final : public Value { | |||||
friend std::unique_ptr<Value> JsonLoader::parse_number(); | |||||
double m_value; | |||||
public: | |||||
NumberValue() : Value(Type::NUMBER) {} | |||||
double value() { return m_value; } | |||||
}; | |||||
class StringValue final : public Value { | |||||
std::string m_value; | |||||
public: | |||||
StringValue() : Value(Type::STRING) {} | |||||
std::string value() { return m_value; } | |||||
friend std::unique_ptr<Value> JsonLoader::parse_string(); | |||||
}; | |||||
class ArrayValue final : public Value { | |||||
megdnn::SmallVector<std::unique_ptr<Value>> m_obj; | |||||
public: | |||||
ArrayValue() : Value(Type::ARRAY) {} | |||||
ArrayValue(ArrayValue& arr) : Value(arr) { | |||||
m_obj.clear(); | |||||
for (auto& item : arr.m_obj) { | |||||
m_obj.emplace_back(std::move(item)); | |||||
} | |||||
} | |||||
ArrayValue(ArrayValue&& arr) : Value(arr) { | |||||
m_obj.clear(); | |||||
for (auto& item : arr.m_obj) { | |||||
m_obj.emplace_back(std::move(item)); | |||||
} | |||||
} | |||||
friend std::unique_ptr<Value> JsonLoader::parse_array(); | |||||
friend std::unique_ptr<JsonLoader::Value>& JsonLoader::Value::operator[]( | |||||
const size_t index); | |||||
friend megdnn::SmallVector<std::unique_ptr<JsonLoader::Value>>& JsonLoader:: | |||||
Value::array(); | |||||
friend size_t JsonLoader::Value::len(); | |||||
}; | |||||
class ObjectValue final : public Value { | |||||
std::map<std::string, std::unique_ptr<Value>> m_obj; | |||||
public: | |||||
ObjectValue() : Value(Type::OBJECT) {} | |||||
ObjectValue(ObjectValue& arr) : Value(arr) { | |||||
m_obj.clear(); | |||||
for (auto itra = arr.m_obj.begin(); itra != arr.m_obj.end(); ++itra) { | |||||
m_obj.emplace(std::make_pair(itra->first, std::move(itra->second))); | |||||
} | |||||
} | |||||
ObjectValue(ObjectValue&& arr) : Value(arr) { | |||||
m_obj.clear(); | |||||
for (auto itra = arr.m_obj.begin(); itra != arr.m_obj.end(); ++itra) { | |||||
m_obj.emplace(std::make_pair(itra->first, std::move(itra->second))); | |||||
} | |||||
} | |||||
friend std::unique_ptr<Value> JsonLoader::parse_object(); | |||||
friend std::unique_ptr<JsonLoader::Value>& JsonLoader::Value::operator[]( | |||||
const std::string&); | |||||
friend std::map<std::string, std::unique_ptr<JsonLoader::Value>>& JsonLoader:: | |||||
Value::objects(); | |||||
friend size_t JsonLoader::Value::len(); | |||||
}; | |||||
private: | |||||
const char* m_buf; | |||||
State m_state; | |||||
}; | |||||
} // namespace mgb |
@@ -0,0 +1,615 @@ | |||||
/* | |||||
Copyright 2017 Leon Merten Lohse | |||||
Permission is hereby granted, free of charge, to any person obtaining a copy | |||||
of this software and associated documentation files (the "Software"), to deal | |||||
in the Software without restriction, including without limitation the rights | |||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |||||
copies of the Software, and to permit persons to whom the Software is | |||||
furnished to do so, subject to the following conditions: | |||||
The above copyright notice and this permission notice shall be included in | |||||
all copies or substantial portions of the Software. | |||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |||||
SOFTWARE. | |||||
*/ | |||||
#ifndef NPY_H | |||||
#define NPY_H | |||||
#include <algorithm> | |||||
#include <complex> | |||||
#include <cstdint> | |||||
#include <cstring> | |||||
#include <fstream> | |||||
#include <iostream> | |||||
#include <regex> | |||||
#include <sstream> | |||||
#include <stdexcept> | |||||
#include <string> | |||||
#include <unordered_map> | |||||
#include <vector> | |||||
namespace npy { | |||||
/* Compile-time test for byte order. | |||||
If your compiler does not define these per default, you may want to define | |||||
one of these constants manually. | |||||
Defaults to little endian order. */ | |||||
#if defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN || \ | |||||
defined(__BIG_ENDIAN__) || defined(__ARMEB__) || defined(__THUMBEB__) || \ | |||||
defined(__AARCH64EB__) || defined(_MIBSEB) || defined(__MIBSEB) || \ | |||||
defined(__MIBSEB__) | |||||
const bool big_endian = true; | |||||
#else | |||||
const bool big_endian = false; | |||||
#endif | |||||
const char magic_string[] = "\x93NUMPY"; | |||||
const size_t magic_string_length = 6; | |||||
const char little_endian_char = '<'; | |||||
const char big_endian_char = '>'; | |||||
const char no_endian_char = '|'; | |||||
constexpr char host_endian_char = (big_endian ? big_endian_char : little_endian_char); | |||||
/* npy array length */ | |||||
typedef unsigned long int ndarray_len_t; | |||||
inline void write_magic( | |||||
std::ostream& ostream, unsigned char v_major = 1, unsigned char v_minor = 0) { | |||||
ostream.write(magic_string, magic_string_length); | |||||
ostream.put(v_major); | |||||
ostream.put(v_minor); | |||||
} | |||||
inline void read_magic( | |||||
std::istream& istream, unsigned char& v_major, unsigned char& v_minor) { | |||||
char buf[magic_string_length + 2]; | |||||
istream.read(buf, magic_string_length + 2); | |||||
if (!istream) { | |||||
fprintf(stderr, "io error: failed reading file"); | |||||
} | |||||
if (0 != std::memcmp(buf, magic_string, magic_string_length)) { | |||||
fprintf(stderr, "this file does not have a valid npy format."); | |||||
} | |||||
v_major = buf[magic_string_length]; | |||||
v_minor = buf[magic_string_length + 1]; | |||||
} | |||||
// typestring magic | |||||
struct Typestring { | |||||
private: | |||||
char c_endian; | |||||
char c_type; | |||||
int len; | |||||
public: | |||||
inline std::string str() { | |||||
const size_t max_buflen = 16; | |||||
char buf[max_buflen]; | |||||
std::sprintf(buf, "%c%c%d", c_endian, c_type, len); | |||||
return std::string(buf); | |||||
} | |||||
Typestring(const std::vector<float>&) | |||||
: c_endian{host_endian_char}, c_type{'f'}, len{sizeof(float)} {} | |||||
Typestring(const std::vector<double>&) | |||||
: c_endian{host_endian_char}, c_type{'f'}, len{sizeof(double)} {} | |||||
Typestring(const std::vector<long double>&) | |||||
: c_endian{host_endian_char}, c_type{'f'}, len{sizeof(long double)} {} | |||||
Typestring(const std::vector<char>&) | |||||
: c_endian{no_endian_char}, c_type{'i'}, len{sizeof(char)} {} | |||||
Typestring(const std::vector<short>&) | |||||
: c_endian{host_endian_char}, c_type{'i'}, len{sizeof(short)} {} | |||||
Typestring(const std::vector<int>&) | |||||
: c_endian{host_endian_char}, c_type{'i'}, len{sizeof(int)} {} | |||||
Typestring(const std::vector<long>&) | |||||
: c_endian{host_endian_char}, c_type{'i'}, len{sizeof(long)} {} | |||||
Typestring(const std::vector<long long>&) | |||||
: c_endian{host_endian_char}, c_type{'i'}, len{sizeof(long long)} {} | |||||
Typestring(const std::vector<unsigned char>&) | |||||
: c_endian{no_endian_char}, c_type{'u'}, len{sizeof(unsigned char)} {} | |||||
Typestring(const std::vector<unsigned short>&) | |||||
: c_endian{host_endian_char}, c_type{'u'}, len{sizeof(unsigned short)} {} | |||||
Typestring(const std::vector<unsigned int>&) | |||||
: c_endian{host_endian_char}, c_type{'u'}, len{sizeof(unsigned int)} {} | |||||
Typestring(const std::vector<unsigned long>&) | |||||
: c_endian{host_endian_char}, c_type{'u'}, len{sizeof(unsigned long)} {} | |||||
Typestring(const std::vector<unsigned long long>&) | |||||
: c_endian{host_endian_char}, | |||||
c_type{'u'}, | |||||
len{sizeof(unsigned long long)} {} | |||||
Typestring(const std::vector<std::complex<float>>&) | |||||
: c_endian{host_endian_char}, | |||||
c_type{'c'}, | |||||
len{sizeof(std::complex<float>)} {} | |||||
Typestring(const std::vector<std::complex<double>>&) | |||||
: c_endian{host_endian_char}, | |||||
c_type{'c'}, | |||||
len{sizeof(std::complex<double>)} {} | |||||
Typestring(const std::vector<std::complex<long double>>&) | |||||
: c_endian{host_endian_char}, | |||||
c_type{'c'}, | |||||
len{sizeof(std::complex<long double>)} {} | |||||
}; | |||||
inline void parse_typestring(std::string typestring) { | |||||
std::regex re("'([<>|])([ifuc])(\\d+)'"); | |||||
std::smatch sm; | |||||
std::regex_match(typestring, sm, re); | |||||
if (sm.size() != 4) { | |||||
fprintf(stderr, "invalid typestring"); | |||||
} | |||||
} | |||||
namespace pyparse { | |||||
/** | |||||
Removes leading and trailing whitespaces | |||||
*/ | |||||
inline std::string trim(const std::string& str) { | |||||
const std::string whitespace = " \t"; | |||||
auto begin = str.find_first_not_of(whitespace); | |||||
if (begin == std::string::npos) | |||||
return ""; | |||||
auto end = str.find_last_not_of(whitespace); | |||||
return str.substr(begin, end - begin + 1); | |||||
} | |||||
inline std::string get_value_from_map(const std::string& mapstr) { | |||||
size_t sep_pos = mapstr.find_first_of(":"); | |||||
if (sep_pos == std::string::npos) | |||||
return ""; | |||||
std::string tmp = mapstr.substr(sep_pos + 1); | |||||
return trim(tmp); | |||||
} | |||||
/** | |||||
Parses the string representation of a Python dict | |||||
The keys need to be known and may not appear anywhere else in the data. | |||||
*/ | |||||
inline std::unordered_map<std::string, std::string> parse_dict( | |||||
std::string in, std::vector<std::string>& keys) { | |||||
std::unordered_map<std::string, std::string> map; | |||||
if (keys.size() == 0) | |||||
return map; | |||||
in = trim(in); | |||||
// unwrap dictionary | |||||
if ((in.front() == '{') && (in.back() == '}')) | |||||
in = in.substr(1, in.length() - 2); | |||||
else { | |||||
fprintf(stderr, "Not a Python dictionary."); | |||||
} | |||||
std::vector<std::pair<size_t, std::string>> positions; | |||||
for (auto const& value : keys) { | |||||
size_t pos = in.find("'" + value + "'"); | |||||
if (pos == std::string::npos) { | |||||
fprintf(stderr, "Missing %s key.", value.c_str()); | |||||
} | |||||
std::pair<size_t, std::string> position_pair{pos, value}; | |||||
positions.push_back(position_pair); | |||||
} | |||||
// sort by position in dict | |||||
std::sort(positions.begin(), positions.end()); | |||||
for (size_t i = 0; i < positions.size(); ++i) { | |||||
std::string raw_value; | |||||
size_t begin{positions[i].first}; | |||||
size_t end{std::string::npos}; | |||||
std::string key = positions[i].second; | |||||
if (i + 1 < positions.size()) | |||||
end = positions[i + 1].first; | |||||
raw_value = in.substr(begin, end - begin); | |||||
raw_value = trim(raw_value); | |||||
if (raw_value.back() == ',') | |||||
raw_value.pop_back(); | |||||
map[key] = get_value_from_map(raw_value); | |||||
} | |||||
return map; | |||||
} | |||||
/** | |||||
Parses the string representation of a Python boolean | |||||
*/ | |||||
inline bool parse_bool(const std::string& in) { | |||||
if (in == "True") | |||||
return true; | |||||
if (in == "False") | |||||
return false; | |||||
fprintf(stderr, "Invalid python boolan."); | |||||
return false; | |||||
} | |||||
/** | |||||
Parses the string representation of a Python str | |||||
*/ | |||||
inline std::string parse_str(const std::string& in) { | |||||
if ((in.front() == '\'') && (in.back() == '\'')) | |||||
return in.substr(1, in.length() - 2); | |||||
fprintf(stderr, "Invalid python string."); | |||||
return ""; | |||||
} | |||||
/** | |||||
Parses the string representation of a Python tuple into a vector of its items | |||||
*/ | |||||
inline std::vector<std::string> parse_tuple(std::string in) { | |||||
std::vector<std::string> v; | |||||
const char separator = ','; | |||||
in = trim(in); | |||||
if ((in.front() == '(') && (in.back() == ')')) | |||||
in = in.substr(1, in.length() - 2); | |||||
else { | |||||
fprintf(stderr, "Invalid Python tuple."); | |||||
} | |||||
std::istringstream iss(in); | |||||
for (std::string token; std::getline(iss, token, separator);) { | |||||
v.push_back(token); | |||||
} | |||||
return v; | |||||
} | |||||
template <typename T> | |||||
inline std::string write_tuple(const std::vector<T>& v) { | |||||
if (v.size() == 0) | |||||
return ""; | |||||
std::ostringstream ss; | |||||
if (v.size() == 1) { | |||||
ss << "(" << v.front() << ",)"; | |||||
} else { | |||||
const std::string delimiter = ", "; | |||||
// v.size() > 1 | |||||
ss << "("; | |||||
std::copy( | |||||
v.begin(), v.end() - 1, | |||||
std::ostream_iterator<T>(ss, delimiter.c_str())); | |||||
ss << v.back(); | |||||
ss << ")"; | |||||
} | |||||
return ss.str(); | |||||
} | |||||
inline std::string write_boolean(bool b) { | |||||
if (b) | |||||
return "True"; | |||||
else | |||||
return "False"; | |||||
} | |||||
} // namespace pyparse | |||||
inline void parse_header(std::string header, std::string& descr) { | |||||
/* | |||||
The first 6 bytes are a magic string: exactly "\x93NUMPY". | |||||
The next 1 byte is an unsigned byte: the major version number of the file | |||||
format, e.g. \x01. The next 1 byte is an unsigned byte: the minor version | |||||
number of the file format, e.g. \x00. Note: the version of the file format | |||||
is not tied to the version of the numpy package. The next 2 bytes form a | |||||
little-endian unsigned short int: the length of the header data | |||||
HEADER_LEN. The next HEADER_LEN bytes form the header data describing the | |||||
array's format. It is an ASCII string which contains a Python literal | |||||
expression of a dictionary. It is terminated by a newline ('\n') and | |||||
padded with spaces | |||||
('\x20') to make the total length of the magic string + 4 + HEADER_LEN be | |||||
evenly divisible by 16 for alignment purposes. The dictionary contains | |||||
three keys: | |||||
"descr" : dtype.descr | |||||
An object that can be passed as an argument to the numpy.dtype() | |||||
constructor to create the array's dtype. For repeatability and | |||||
readability, this dictionary is formatted using pprint.pformat() so the | |||||
keys are in alphabetic order. | |||||
*/ | |||||
// remove trailing newline | |||||
if (header.back() != '\n') | |||||
fprintf(stderr, "invalid header"); | |||||
header.pop_back(); | |||||
// parse the dictionary | |||||
std::vector<std::string> keys{"descr"}; | |||||
auto dict_map = npy::pyparse::parse_dict(header, keys); | |||||
if (dict_map.size() == 0) | |||||
fprintf(stderr, "invalid dictionary in header"); | |||||
std::string descr_s = dict_map["descr"]; | |||||
parse_typestring(descr_s); | |||||
// remove the quotes around the typestring | |||||
descr = npy::pyparse::parse_str(descr_s); | |||||
return; | |||||
} | |||||
inline void parse_header( | |||||
std::string header, std::string& descr, bool& fortran_order, | |||||
std::vector<ndarray_len_t>& shape) { | |||||
/* | |||||
The first 6 bytes are a magic string: exactly "\x93NUMPY". | |||||
The next 1 byte is an unsigned byte: the major version number of the file | |||||
format, e.g. \x01. The next 1 byte is an unsigned byte: the minor version | |||||
number of the file format, e.g. \x00. Note: the version of the file format | |||||
is not tied to the version of the numpy package. The next 2 bytes form a | |||||
little-endian unsigned short int: the length of the header data | |||||
HEADER_LEN. The next HEADER_LEN bytes form the header data describing the | |||||
array's format. It is an ASCII string which contains a Python literal | |||||
expression of a dictionary. It is terminated by a newline ('\n') and | |||||
padded with spaces | |||||
('\x20') to make the total length of the magic string + 4 + HEADER_LEN be | |||||
evenly divisible by 16 for alignment purposes. The dictionary contains | |||||
three keys: | |||||
"descr" : dtype.descr | |||||
An object that can be passed as an argument to the numpy.dtype() | |||||
constructor to create the array's dtype. "fortran_order" : bool Whether | |||||
the array data is Fortran-contiguous or not. Since Fortran-contiguous | |||||
arrays are a common form of non-C-contiguity, we allow them to be written | |||||
directly to disk for efficiency. "shape" : tuple of int The shape of the | |||||
array. For repeatability and readability, this dictionary is formatted | |||||
using pprint.pformat() so the keys are in alphabetic order. | |||||
*/ | |||||
// remove trailing newline | |||||
if (header.back() != '\n') | |||||
fprintf(stderr, "invalid header"); | |||||
header.pop_back(); | |||||
// parse the dictionary | |||||
std::vector<std::string> keys{"descr", "fortran_order", "shape"}; | |||||
auto dict_map = npy::pyparse::parse_dict(header, keys); | |||||
if (dict_map.size() == 0) | |||||
fprintf(stderr, "invalid dictionary in header"); | |||||
std::string descr_s = dict_map["descr"]; | |||||
std::string fortran_s = dict_map["fortran_order"]; | |||||
std::string shape_s = dict_map["shape"]; | |||||
// TODO: extract info from typestring | |||||
parse_typestring(descr_s); | |||||
// remove the quotes around the typestring | |||||
descr = npy::pyparse::parse_str(descr_s); | |||||
// convert literal Python bool to C++ bool | |||||
fortran_order = npy::pyparse::parse_bool(fortran_s); | |||||
// parse the shape tuple | |||||
auto shape_v = npy::pyparse::parse_tuple(shape_s); | |||||
if (shape_v.size() == 0) | |||||
fprintf(stderr, "invalid shape tuple in header"); | |||||
for (auto item : shape_v) { | |||||
ndarray_len_t dim = static_cast<ndarray_len_t>(std::stoul(item)); | |||||
shape.push_back(dim); | |||||
} | |||||
} | |||||
inline std::string write_header_dict( | |||||
const std::string& descr, bool fortran_order, | |||||
const std::vector<ndarray_len_t>& shape) { | |||||
std::string s_fortran_order = npy::pyparse::write_boolean(fortran_order); | |||||
std::string shape_s = npy::pyparse::write_tuple(shape); | |||||
return "{'descr': '" + descr + "', 'fortran_order': " + s_fortran_order + | |||||
", 'shape': " + shape_s + ", }"; | |||||
} | |||||
inline void write_header( | |||||
std::ostream& out, const std::string& descr, bool fortran_order, | |||||
const std::vector<ndarray_len_t>& shape_v) { | |||||
std::string header_dict = write_header_dict(descr, fortran_order, shape_v); | |||||
size_t length = magic_string_length + 2 + 2 + header_dict.length() + 1; | |||||
unsigned char version[2] = {1, 0}; | |||||
if (length >= 255 * 255) { | |||||
length = magic_string_length + 2 + 4 + header_dict.length() + 1; | |||||
version[0] = 2; | |||||
version[1] = 0; | |||||
} | |||||
size_t padding_len = 16 - length % 16; | |||||
std::string padding(padding_len, ' '); | |||||
// write magic | |||||
write_magic(out, version[0], version[1]); | |||||
// write header length | |||||
if (version[0] == 1 && version[1] == 0) { | |||||
char header_len_le16[2]; | |||||
uint16_t header_len = | |||||
static_cast<uint16_t>(header_dict.length() + padding.length() + 1); | |||||
header_len_le16[0] = (header_len >> 0) & 0xff; | |||||
header_len_le16[1] = (header_len >> 8) & 0xff; | |||||
out.write(reinterpret_cast<char*>(header_len_le16), 2); | |||||
} else { | |||||
char header_len_le32[4]; | |||||
uint32_t header_len = | |||||
static_cast<uint32_t>(header_dict.length() + padding.length() + 1); | |||||
header_len_le32[0] = (header_len >> 0) & 0xff; | |||||
header_len_le32[1] = (header_len >> 8) & 0xff; | |||||
header_len_le32[2] = (header_len >> 16) & 0xff; | |||||
header_len_le32[3] = (header_len >> 24) & 0xff; | |||||
out.write(reinterpret_cast<char*>(header_len_le32), 4); | |||||
} | |||||
out << header_dict << padding << '\n'; | |||||
} | |||||
inline std::string read_header(std::istream& istream) { | |||||
// check magic bytes and version number | |||||
unsigned char v_major, v_minor; | |||||
read_magic(istream, v_major, v_minor); | |||||
uint32_t header_length = 0; | |||||
if (v_major == 1 && v_minor == 0) { | |||||
char header_len_le16[2]; | |||||
istream.read(header_len_le16, 2); | |||||
header_length = static_cast<unsigned char>(header_len_le16[0]) | (static_cast<unsigned char>(header_len_le16[1]) << 8); | |||||
if ((magic_string_length + 2 + 2 + header_length) % 16 != 0) { | |||||
// TODO: display warning | |||||
} | |||||
} else if (v_major == 2 && v_minor == 0) { | |||||
char header_len_le32[4]; | |||||
istream.read(header_len_le32, 4); | |||||
header_length = static_cast<unsigned char>(header_len_le32[0]) | | |||||
(static_cast<unsigned char>(header_len_le32[1]) << 8) | | |||||
(static_cast<unsigned char>(header_len_le32[2]) << 16) | | |||||
(static_cast<uint32_t>(static_cast<unsigned char>(header_len_le32[3])) << 24); | |||||
if ((magic_string_length + 2 + 4 + header_length) % 16 != 0) { | |||||
// TODO: display warning | |||||
} | |||||
} else { | |||||
fprintf(stderr, "unsupported file format version"); | |||||
} | |||||
auto buf_v = std::vector<char>(); | |||||
buf_v.resize(header_length); | |||||
istream.read(buf_v.data(), header_length); | |||||
std::string header(buf_v.data(), header_length); | |||||
return header; | |||||
} | |||||
inline ndarray_len_t comp_size(const std::vector<ndarray_len_t>& shape) { | |||||
ndarray_len_t size = 1; | |||||
for (ndarray_len_t i : shape) | |||||
size *= i; | |||||
return size; | |||||
} | |||||
template <typename Scalar> | |||||
inline void SaveArrayAsNumpy( | |||||
const std::string& filename, bool fortran_order, unsigned int n_dims, | |||||
const unsigned long shape[], const std::vector<Scalar>& data) { | |||||
Typestring typestring_o(data); | |||||
std::string typestring = typestring_o.str(); | |||||
std::ofstream stream(filename, std::ofstream::binary); | |||||
if (!stream) { | |||||
fprintf(stderr, "io error: failed to open a file."); | |||||
} | |||||
std::vector<ndarray_len_t> shape_v(shape, shape + n_dims); | |||||
write_header(stream, typestring, fortran_order, shape_v); | |||||
auto size = static_cast<size_t>(comp_size(shape_v)); | |||||
stream.write(reinterpret_cast<const char*>(data.data()), sizeof(Scalar) * size); | |||||
} | |||||
template <typename Scalar> | |||||
inline void LoadArrayFromNumpy( | |||||
const std::string& filename, std::vector<unsigned long>& shape, | |||||
std::vector<Scalar>& data) { | |||||
bool fortran_order; | |||||
LoadArrayFromNumpy<Scalar>(filename, shape, fortran_order, data); | |||||
} | |||||
template <typename Scalar> | |||||
inline void LoadArrayFromNumpy( | |||||
const std::string& filename, std::vector<unsigned long>& shape, | |||||
bool& fortran_order, std::vector<Scalar>& data) { | |||||
std::ifstream stream(filename, std::ifstream::binary); | |||||
if (!stream) { | |||||
fprintf(stderr, "io error: failed to open a file."); | |||||
} | |||||
std::string header = read_header(stream); | |||||
// parse header | |||||
std::string typestr; | |||||
parse_header(header, typestr, fortran_order, shape); | |||||
// check if the typestring matches the given one | |||||
Typestring typestring_o{data}; | |||||
std::string expect_typestr = typestring_o.str(); | |||||
if (typestr != expect_typestr) { | |||||
fprintf(stderr, "formatting error: typestrings not matching"); | |||||
} | |||||
// compute the data size based on the shape | |||||
auto size = static_cast<size_t>(comp_size(shape)); | |||||
data.resize(size); | |||||
// read the data | |||||
stream.read(reinterpret_cast<char*>(data.data()), sizeof(Scalar) * size); | |||||
} | |||||
inline void LoadArrayFromNumpy( | |||||
const std::string& filename, std::string& type_str, | |||||
std::vector<ndarray_len_t>& shape, std::vector<int8_t>& data) { | |||||
std::ifstream stream(filename, std::ifstream::binary); | |||||
if (!stream) { | |||||
fprintf(stderr, "io error: failed to open a file."); | |||||
} | |||||
std::string header = read_header(stream); | |||||
bool fortran_order; | |||||
// parse header | |||||
parse_header(header, type_str, fortran_order, shape); | |||||
// parse the element size from the typestring, e.g. "<f4" -> 4 bytes | |||||
size_t elem_size = std::stoul(type_str.substr(2)); | |||||
// compute the data size based on the shape | |||||
auto byte_size = elem_size * static_cast<size_t>(comp_size(shape)); | |||||
data.resize(byte_size); | |||||
// read the data | |||||
stream.read(reinterpret_cast<char*>(data.data()), byte_size); | |||||
} | |||||
} // namespace npy | |||||
#endif // NPY_H |
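As a quick sanity check of the npy helpers above, a hypothetical round trip might look like this (the file path is illustrative):

#include <vector>
#include "npy.h"

void npy_roundtrip_example() {
    // save a 2x3 float array in C order...
    std::vector<float> data{0.f, 1.f, 2.f, 3.f, 4.f, 5.f};
    const unsigned long shape[] = {2, 3};
    npy::SaveArrayAsNumpy("/tmp/example.npy", /*fortran_order=*/false, 2, shape, data);
    // ...and load it back; loaded_shape/loaded should match what was written
    std::vector<unsigned long> loaded_shape;
    std::vector<float> loaded;
    npy::LoadArrayFromNumpy("/tmp/example.npy", loaded_shape, loaded);
}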
@@ -0,0 +1,48 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/helpers/outdumper.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "outdumper.h" | |||||
#include "megbrain/utils/debug.h" | |||||
using namespace lar; | |||||
void OutputDumper::set(mgb::SymbolVarArray& symb_var) { | |||||
for (auto&& i : symb_var) { | |||||
auto&& var = i.node(); | |||||
DumpInfo info; | |||||
info.var_info = mgb::cg::dump_var_info({var}); | |||||
info.owner_inputs_info = mgb::cg::dump_var_info(var->owner_opr()->input()); | |||||
info.id = var->id(); | |||||
m_infos.push_back(info); | |||||
} | |||||
} | |||||
mgb::ComputingGraph::Callback OutputDumper::bind() { | |||||
auto& info = m_infos.at(m_bind_id++); | |||||
mgb::ComputingGraph::Callback cb = [&info](const mgb::DeviceTensorND& dv) { | |||||
info.hv.copy_from(dv); | |||||
}; | |||||
return cb; | |||||
} | |||||
void OutputDumper::write_to_file() { | |||||
if (!dump_file.empty()) { | |||||
for (auto&& info : m_infos) { | |||||
auto value = mgb::debug::dump_tensor( | |||||
info.hv, | |||||
mgb::ssprintf( | |||||
"var=%s owner_opr_inputs= %s", info.var_info.c_str(), | |||||
info.owner_inputs_info.c_str())); | |||||
mgb::debug::write_to_file( | |||||
mgb::ssprintf( | |||||
"%s/run%zu-var %zd", dump_file.c_str(), m_run_id, info.id) | |||||
.c_str(), | |||||
value); | |||||
} | |||||
} | |||||
m_run_id++; | |||||
} |
@@ -0,0 +1,42 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/helpers/outdumper.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain/serialization/serializer.h" | |||||
namespace lar { | |||||
/*! | |||||
* \brief dumper for the model outputs, used for --bin-out-dump | |||||
*/ | |||||
class OutputDumper { | |||||
public: | |||||
struct DumpInfo { | |||||
mgb::HostTensorND hv = {}; | |||||
std::string var_info; | |||||
std::string owner_inputs_info; | |||||
size_t id; | |||||
}; | |||||
//! init the dump_file path | |||||
OutputDumper(const char* file) { dump_file = file; } | |||||
//! set the dump information | |||||
void set(mgb::SymbolVarArray& symb_var); | |||||
//! callback function for the specified outputs when compiling the computing graph | |||||
mgb::ComputingGraph::Callback bind(); | |||||
//! write dumped output into dump_file | |||||
void write_to_file(); | |||||
private: | |||||
mgb::SmallVector<DumpInfo> m_infos; | |||||
size_t m_run_id = 0; | |||||
size_t m_bind_id = 0; | |||||
std::string dump_file; | |||||
}; | |||||
} // namespace lar |
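To show how OutputDumper is meant to be wired in, here is a hypothetical sketch (load_result stands in for a GraphLoader::LoadResult produced elsewhere): one callback is bound per output var, and write_to_file() runs after each execution:

void dump_outputs_example(
        mgb::serialization::GraphLoader::LoadResult& load_result,
        const char* dump_dir) {
    lar::OutputDumper dumper(dump_dir);
    dumper.set(load_result.output_var_list);
    mgb::ComputingGraph::OutputSpec spec;
    for (auto&& var : load_result.output_var_list)
        spec.emplace_back(var, dumper.bind());  // one callback per output
    auto func = load_result.graph_compile(spec);
    func->execute();
    func->wait();
    dumper.write_to_file();  // writes "run0-var<id>" files under dump_dir
}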
@@ -0,0 +1,119 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/helpers/text_table.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "text_table.h" | |||||
using namespace mgb; | |||||
namespace { | |||||
inline void mid(std::ostream& os, const std::string& str, size_t max_w) { | |||||
size_t l = (max_w - str.length()) / 2 + str.length(); | |||||
size_t r = max_w - l; | |||||
os << std::setw(l) << std::right << str; | |||||
if (r > 0) | |||||
os << std::setw(r) << ' '; | |||||
} | |||||
inline size_t char_length(char c) { | |||||
return c ? 1 : 0; | |||||
} | |||||
} // namespace | |||||
void TextTable::adjust_last_row() { | |||||
if (m_rows.empty()) | |||||
return; | |||||
auto& row = m_rows.back(); | |||||
if (row.params.horizontal == 0 or row.params.vertical == 0) { | |||||
row.params.corner = 0; | |||||
} | |||||
if (row.params.horizontal != 0 && row.params.vertical != 0 && | |||||
row.params.corner == 0) { | |||||
row.params.corner = row.params.horizontal; | |||||
} | |||||
} | |||||
void TextTable::show(std::ostream& os) { | |||||
if (m_rows.empty()) | |||||
return; | |||||
auto& last_row = m_rows.front(); | |||||
bool first = true; | |||||
for (auto& row : m_rows) { | |||||
auto& lrow = | |||||
(last_row.values.size() * char_length(last_row.params.horizontal)) > | |||||
(row.values.size() * char_length(row.params.horizontal)) | |||||
? last_row | |||||
: row; | |||||
// line before row | |||||
if (lrow.params.horizontal) { | |||||
if (not first) | |||||
os << std::endl; | |||||
os << m_prefix; | |||||
if (lrow.params.corner) | |||||
os << lrow.params.corner; | |||||
size_t skip_size = 0; | |||||
// table name | |||||
if (first) { | |||||
os << m_name; | |||||
skip_size = m_name.length(); | |||||
} | |||||
for (size_t i = 0; i < lrow.values.size(); ++i) { | |||||
auto max_w = m_cols_max_w.at(i) + m_padding * 2; | |||||
if (max_w + char_length(lrow.params.corner) <= skip_size) { | |||||
skip_size = skip_size - max_w - char_length(lrow.params.corner); | |||||
continue; | |||||
} | |||||
size_t rest = max_w + char_length(lrow.params.corner) - skip_size; | |||||
skip_size = 0; | |||||
if (rest > char_length(lrow.params.corner)) { | |||||
os << std::string( | |||||
rest - char_length(lrow.params.corner), | |||||
lrow.params.horizontal); | |||||
rest = char_length(lrow.params.corner); | |||||
} | |||||
if (rest > 0 && lrow.params.corner) | |||||
os << lrow.params.corner; | |||||
} | |||||
} else if (first) { | |||||
os << m_prefix << ' ' << m_name; | |||||
} | |||||
first = false; | |||||
os << std::endl << m_prefix; | |||||
if (row.params.vertical) | |||||
os << row.params.vertical; | |||||
// row | |||||
for (size_t i = 0; i < row.values.size(); ++i) { | |||||
auto& str = row.values.at(i); | |||||
auto max_w = m_cols_max_w.at(i) + 2 * m_padding; | |||||
if (row.params.align == Align::Mid) { | |||||
mid(os, str, max_w); | |||||
} else if (row.params.align == Align::Left) { | |||||
os << std::setw(max_w) << std::left << str; | |||||
} else { | |||||
os << std::setw(max_w) << std::right << str; | |||||
} | |||||
if (row.params.vertical) | |||||
os << row.params.vertical; | |||||
} | |||||
last_row = row; | |||||
} | |||||
if (last_row.params.horizontal) { | |||||
os << std::endl << m_prefix; | |||||
if (last_row.params.corner) | |||||
os << last_row.params.corner; | |||||
for (size_t i = 0; i < last_row.values.size(); ++i) { | |||||
auto max_w = m_cols_max_w.at(i); | |||||
std::string tmp(max_w + m_padding * 2, last_row.params.horizontal); | |||||
os << tmp; | |||||
if (last_row.params.corner) | |||||
os << last_row.params.corner; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,133 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/helpers/text_table.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <array> | |||||
#include <iomanip> | |||||
#include <ostream> | |||||
#include <sstream> | |||||
#include <string> | |||||
#include <tuple> | |||||
#include <type_traits> | |||||
#include <vector> | |||||
#include "megbrain/common.h" | |||||
namespace mgb { | |||||
class TextTable { | |||||
public: | |||||
enum Level { Summary, Detail }; | |||||
enum class Align : int { Left, Right, Mid }; | |||||
explicit TextTable(const std::string& table_name) : m_name(table_name) {} | |||||
TextTable& horizontal(char c) { | |||||
m_row.params.horizontal = c; | |||||
return *this; | |||||
} | |||||
TextTable& vertical(char c) { | |||||
m_row.params.vertical = c; | |||||
return *this; | |||||
} | |||||
TextTable& corner(char c) { | |||||
m_row.params.corner = c; | |||||
return *this; | |||||
} | |||||
TextTable& align(Align v) { | |||||
m_row.params.align = v; | |||||
return *this; | |||||
} | |||||
TextTable& padding(size_t w) { | |||||
m_padding = w; | |||||
return *this; | |||||
} | |||||
TextTable& prefix(const std::string& str) { | |||||
m_prefix = str; | |||||
return *this; | |||||
} | |||||
//! non-arithmetic cells; numbers go to the overloads below | |||||
template < | |||||
typename T, | |||||
typename std::enable_if< | |||||
!std::is_integral<T>::value && !std::is_floating_point<T>::value, | |||||
bool>::type = 0> | |||||
TextTable& add(const T& value) { | |||||
m_row.values.emplace_back(value); | |||||
if (m_cols_max_w.size() < m_row.values.size()) { | |||||
m_cols_max_w.emplace_back(m_row.values.back().length()); | |||||
} else { | |||||
mgb_assert(m_row.values.size() >= 1); | |||||
size_t i = m_row.values.size() - 1; | |||||
m_cols_max_w[i] = std::max(m_cols_max_w[i], m_row.values.back().length()); | |||||
} | |||||
return *this; | |||||
} | |||||
template < | |||||
typename T, | |||||
typename std::enable_if<std::is_floating_point<T>::value, bool>::type = 0> | |||||
TextTable& add(const T& value) { | |||||
std::stringstream ss; | |||||
ss << std::setiosflags(std::ios::fixed) << std::setprecision(2); | |||||
ss << value; | |||||
m_row.values.emplace_back(ss.str()); | |||||
if (m_cols_max_w.size() < m_row.values.size()) { | |||||
m_cols_max_w.emplace_back(m_row.values.back().length()); | |||||
} else { | |||||
mgb_assert(m_row.values.size() >= 1); | |||||
size_t i = m_row.values.size() - 1; | |||||
m_cols_max_w[i] = std::max(m_cols_max_w[i], m_row.values.back().length()); | |||||
} | |||||
return *this; | |||||
} | |||||
template < | |||||
typename T, | |||||
typename std::enable_if<std::is_integral<T>::value, bool>::type = 0> | |||||
TextTable& add(const T& value) { | |||||
m_row.values.emplace_back(std::to_string(value)); | |||||
//! keep the column width bookkeeping in sync, as the other overloads do | |||||
if (m_cols_max_w.size() < m_row.values.size()) { | |||||
m_cols_max_w.emplace_back(m_row.values.back().length()); | |||||
} else { | |||||
size_t i = m_row.values.size() - 1; | |||||
m_cols_max_w[i] = std::max(m_cols_max_w[i], m_row.values.back().length()); | |||||
} | |||||
return *this; | |||||
} | |||||
void eor() { | |||||
m_rows.emplace_back(m_row); | |||||
adjust_last_row(); | |||||
m_row.values.clear(); | |||||
} | |||||
void reset() { | |||||
m_row = {}; | |||||
m_cols_max_w.clear(); | |||||
m_padding = 0; | |||||
m_rows.clear(); | |||||
} | |||||
void show(std::ostream& os); | |||||
private: | |||||
void adjust_last_row(); | |||||
std::string m_name; | |||||
std::vector<size_t> m_cols_max_w; | |||||
size_t m_padding = 0; | |||||
std::string m_prefix = ""; | |||||
struct Row { | |||||
std::vector<std::string> values; | |||||
struct Params { | |||||
Align align = Align::Left; | |||||
char horizontal = '-', vertical = '|', corner = '+'; | |||||
} params; | |||||
}; | |||||
std::vector<Row> m_rows; | |||||
Row m_row; | |||||
}; | |||||
inline std::ostream& operator<<(std::ostream& stream, TextTable& table) { | |||||
table.show(stream); | |||||
return stream; | |||||
} | |||||
} // namespace mgb |
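For illustration, a hypothetical table built with this API (the rows are made up; numeric cells go through the arithmetic add() overloads above, with floating-point values formatted to two decimal places):

#include <iostream>
// assumes "text_table.h" is included

void text_table_example() {
    mgb::TextTable table("bench");
    table.padding(1).align(mgb::TextTable::Align::Mid).add("layer").add("time(ms)").eor();
    table.add("conv1").add(0.35).eor();
    table.add("pool1").add(0.12).eor();
    std::cout << table << std::endl;  // prints rows with '-', '|', '+' borders
}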
@@ -0,0 +1,31 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/main.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include <gflags/gflags.h> | |||||
#include <string> | |||||
#include "strategys/strategy.h" | |||||
int main(int argc, char** argv) { | |||||
std::string usage = "load_and_run <model_path> [options...]"; | |||||
if (argc < 2) { | |||||
printf("usage: %s\n", usage.c_str()); | |||||
return -1; | |||||
} | |||||
gflags::SetUsageMessage(usage); | |||||
gflags::SetVersionString("1.0"); | |||||
gflags::ParseCommandLineFlags(&argc, &argv, true); | |||||
std::string model_path = argv[1]; | |||||
auto strategy = lar::StrategyBase::create_strategy(model_path); | |||||
strategy->run(); | |||||
gflags::ShutDownCommandLineFlags(); | |||||
return 0; | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
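Given the argument handling above, a typical invocation follows the usage string, e.g. `load_and_run ./model.mge --lite` (the --lite flag is defined in model.cpp below; the remaining flags are registered by the options sources later in this patch).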
@@ -0,0 +1,60 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/models/model.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "model.h" | |||||
#include <iostream> | |||||
#include <memory> | |||||
#include "model_lite.h" | |||||
#include "model_mdl.h" | |||||
using namespace lar; | |||||
ModelType ModelBase::get_model_type(std::string model_path) { | |||||
//! read magic number of dump file | |||||
FILE* fin = fopen(model_path.c_str(), "rb"); | |||||
mgb_assert(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno)); | |||||
char buf[16]; | |||||
mgb_assert(fread(buf, 1, 16, fin) == 16, "read model failed"); | |||||
fclose(fin); | |||||
// get model type | |||||
// uint32_t MGB_MAGIC = 0x5342474D | |||||
std::string tag(buf, sizeof(buf)); | |||||
ModelType type; | |||||
if (tag.substr(0, 7) == std::string("mgb0001") || | |||||
tag.substr(0, 8) == std::string("mgb0000a") || | |||||
tag.substr(0, 4) == std::string("MGBS") || | |||||
tag.substr(0, 8) == std::string("mgbtest0")) { | |||||
type = ModelType::MEGDL_MODEL; | |||||
} else { | |||||
type = ModelType::LITE_MODEL; | |||||
} | |||||
return type; | |||||
} | |||||
std::shared_ptr<ModelBase> ModelBase::create_model(std::string model_path) { | |||||
mgb_log_debug("model path %s\n", model_path.c_str()); | |||||
auto model_type = get_model_type(model_path); | |||||
if (ModelType::LITE_MODEL == model_type) { | |||||
return std::make_shared<ModelLite>(model_path); | |||||
} else if (ModelType::MEGDL_MODEL == model_type) { | |||||
if (FLAGS_lite) | |||||
return std::make_shared<ModelLite>(model_path); | |||||
else | |||||
return std::make_shared<ModelMdl>(model_path); | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
} | |||||
DEFINE_bool(lite, false, "using lite model to run mdl model"); | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,49 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/models/model.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | |||||
#include <gflags/gflags.h> | |||||
#include <string> | |||||
#include "helpers/common.h" | |||||
DECLARE_bool(lite); | |||||
namespace lar { | |||||
/*! | |||||
* \brief: base class of model | |||||
*/ | |||||
class ModelBase { | |||||
public: | |||||
//! get model type by the magic number in dump file | |||||
static ModelType get_model_type(std::string model_path); | |||||
//! create model by different model type | |||||
static std::shared_ptr<ModelBase> create_model(std::string model_path); | |||||
//! type of the model | |||||
virtual ModelType type() = 0; | |||||
//! set whether the model file is loaded into (shared) memory first | |||||
virtual void set_shared_mem(bool state) = 0; | |||||
//! load model interface for load and run strategy | |||||
virtual void load_model() = 0; | |||||
//! run model interface for load and run strategy | |||||
virtual void run_model() = 0; | |||||
//! wait for asynchronous execution to finish (load and run strategy interface) | |||||
virtual void wait() = 0; | |||||
virtual ~ModelBase() = default; | |||||
}; | |||||
} // namespace lar | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
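A hypothetical driver for this interface (the path argument is illustrative): create_model() dispatches on the dump magic, and run_model()/wait() are paired because execution may be asynchronous:

#include <cstdio>
#include <string>
// assumes "model.h" is included

void run_once(const std::string& model_path) {
    auto model = lar::ModelBase::create_model(model_path);
    if (!model) {
        printf("unrecognized model format\n");
        return;
    }
    model->set_shared_mem(false);  // read the dump from the filesystem
    model->load_model();
    model->run_model();
    model->wait();  // block until the (possibly asynchronous) run finishes
}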
@@ -0,0 +1,50 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/models/model_lite.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "model_lite.h" | |||||
#include <gflags/gflags.h> | |||||
#include <cstring> | |||||
#include "misc.h" | |||||
DECLARE_bool(share_param_mem); | |||||
using namespace lar; | |||||
ModelLite::ModelLite(const std::string& path) : model_path(path) { | |||||
LITE_WARN("create lite model using CPU as default comp node"); | |||||
} | |||||
void ModelLite::load_model() { | |||||
m_network = std::make_shared<lite::Network>(config, IO); | |||||
if (share_model_mem) { | |||||
//! WARNING: sharing the param memory this way may not always be safe | |||||
LITE_WARN("enable share model memory"); | |||||
FILE* fin = fopen(model_path.c_str(), "rb"); | |||||
LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno)); | |||||
fseek(fin, 0, SEEK_END); | |||||
size_t size = ftell(fin); | |||||
fseek(fin, 0, SEEK_SET); | |||||
void* ptr = malloc(size); | |||||
std::shared_ptr<void> buf{ptr, free}; | |||||
auto nr = fread(buf.get(), 1, size, fin); | |||||
LITE_ASSERT(nr == size, "read model file failed"); | |||||
fclose(fin); | |||||
m_network->load_model(buf.get(), size); | |||||
} else { | |||||
m_network->load_model(model_path); | |||||
} | |||||
} | |||||
void ModelLite::run_model() { | |||||
m_network->forward(); | |||||
} | |||||
void ModelLite::wait() { | |||||
m_network->wait(); | |||||
} |
@@ -0,0 +1,73 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/models/model_lite.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | |||||
#include <string> | |||||
#include "helpers/common.h" | |||||
#include "helpers/data_parser.h" | |||||
#include "lite/network.h" | |||||
#include "model.h" | |||||
namespace lar { | |||||
/*! | |||||
* \brief: megengine lite model | |||||
*/ | |||||
class ModelLite : public ModelBase { | |||||
public: | |||||
using Strategy = LiteAlgoSelectStrategy; | |||||
ModelLite(const std::string& path); | |||||
//! model type | |||||
ModelType type() override { return ModelType::LITE_MODEL; } | |||||
//! set to load from shared memory | |||||
void set_shared_mem(bool state) override { share_model_mem = state; } | |||||
//! load model from dump file | |||||
void load_model() override; | |||||
//! run model with given runtime parameter | |||||
void run_model() override; | |||||
//! wait for the end of asynchronous function execution | |||||
void wait() override; | |||||
//! get the network of lite model | |||||
std::shared_ptr<lite::Network> get_lite_network() { return m_network; } | |||||
//! get the config of lite model | |||||
lite::Config& get_config() { return config; } | |||||
//! get the networkIO of lite model | |||||
lite::NetworkIO& get_networkIO() { return IO; } | |||||
//! get the data parser | |||||
DataParser& get_input_parser() { return parser; } | |||||
//! set the strategy before load model | |||||
void set_lite_strategy(Strategy& u_strategy) { m_strategy = u_strategy; } | |||||
//! get algo strategy | |||||
Strategy& get_lite_strategy() { return m_strategy; } | |||||
private: | |||||
bool share_model_mem; | |||||
std::string model_path; | |||||
DataParser parser; | |||||
lite::Config config; | |||||
lite::NetworkIO IO; | |||||
std::shared_ptr<lite::Network> m_network; | |||||
Strategy m_strategy; | |||||
}; | |||||
} // namespace lar | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,105 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/models/model_mdl.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "model_mdl.h" | |||||
#include <gflags/gflags.h> | |||||
#include <iostream> | |||||
DECLARE_bool(share_param_mem); | |||||
using namespace lar; | |||||
ModelMdl::ModelMdl(const std::string& path) : model_path(path) { | |||||
mgb_log_warn("creat mdl model use XPU as default comp node"); | |||||
m_load_config.comp_graph = mgb::ComputingGraph::make(); | |||||
m_load_config.comp_graph->options().graph_opt_level = 0; | |||||
testcase_num = 0; | |||||
} | |||||
void ModelMdl::load_model() { | |||||
//! read dump file | |||||
if (share_model_mem) { | |||||
mgb_log_warn("enable share model memory"); | |||||
FILE* fin = fopen(model_path.c_str(), "rb"); | |||||
mgb_assert(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno)); | |||||
fseek(fin, 0, SEEK_END); | |||||
size_t size = ftell(fin); | |||||
fseek(fin, 0, SEEK_SET); | |||||
void* ptr = malloc(size); | |||||
std::shared_ptr<void> buf{ptr, free}; | |||||
auto nr = fread(buf.get(), 1, size, fin); | |||||
mgb_assert(nr == size, "read model file failed"); | |||||
fclose(fin); | |||||
m_model_file = mgb::serialization::InputFile::make_mem_proxy(buf, size); | |||||
} else { | |||||
m_model_file = mgb::serialization::InputFile::make_fs(model_path.c_str()); | |||||
} | |||||
//! get dump_with_testcase model testcase number | |||||
char magic[8]; | |||||
m_model_file->read(magic, sizeof(magic)); | |||||
if (strncmp(magic, "mgbtest0", 8)) { | |||||
m_model_file->rewind(); | |||||
} else { | |||||
m_model_file->read(&testcase_num, sizeof(testcase_num)); | |||||
} | |||||
auto format = | |||||
mgb::serialization::GraphLoader::identify_graph_dump_format(*m_model_file); | |||||
mgb_assert( | |||||
format.valid(), | |||||
"invalid format, please make sure model is dumped by GraphDumper"); | |||||
//! load computing graph of model | |||||
m_loader = mgb::serialization::GraphLoader::make( | |||||
std::move(m_model_file), format.val()); | |||||
m_load_result = m_loader->load(m_load_config, false); | |||||
m_load_config.comp_graph.reset(); | |||||
// get testcase input generated by dump_with_testcase.py | |||||
if (testcase_num) { | |||||
for (auto&& i : m_load_result.tensor_map) { | |||||
test_input_tensors.emplace_back(i.first, i.second.get()); | |||||
} | |||||
std::sort(test_input_tensors.begin(), test_input_tensors.end()); | |||||
} | |||||
// initialize output callback | |||||
for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) { | |||||
mgb::ComputingGraph::Callback cb; | |||||
m_callbacks.push_back(cb); | |||||
} | |||||
} | |||||
void ModelMdl::make_output_spec() { | |||||
for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) { | |||||
auto item = m_load_result.output_var_list[i]; | |||||
m_output_spec.emplace_back(item, std::move(m_callbacks[i])); | |||||
} | |||||
m_asyc_exec = m_load_result.graph_compile(m_output_spec); | |||||
} | |||||
std::shared_ptr<mgb::serialization::GraphLoader>& ModelMdl::reset_loader() { | |||||
m_loader = mgb::serialization::GraphLoader::make( | |||||
m_loader->reset_file(), m_loader->format()); | |||||
return m_loader; | |||||
} | |||||
void ModelMdl::run_model() { | |||||
mgb_assert( | |||||
m_asyc_exec != nullptr, | |||||
"empty asychronous function to execute after graph compiled"); | |||||
m_asyc_exec->execute(); | |||||
} | |||||
void ModelMdl::wait() { | |||||
m_asyc_exec->wait(); | |||||
} |
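The expected call sequence for ModelMdl, mirroring the members above (a hypothetical sketch; option plugins such as the profiler would be attached between load_model() and make_output_spec()):

void run_mdl_example(const std::string& path) {
    lar::ModelMdl model(path);
    model.set_shared_mem(false);
    model.load_model();        // parse the dump and prepare empty callbacks
    model.make_output_spec();  // compile the graph into the async function
    model.run_model();
    model.wait();
}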
@@ -0,0 +1,117 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/models/model_mdl.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | |||||
#include <string> | |||||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||||
#include "megbrain/plugin/opr_io_dump.h" | |||||
#include "megbrain/serialization/extern_c_opr.h" | |||||
#include "megbrain/serialization/serializer.h" | |||||
#include "megbrain/utils/debug.h" | |||||
#include "megbrain/plugin/num_range_checker.h" | |||||
#include "megbrain/plugin/profiler.h" | |||||
#include "helpers/common.h" | |||||
#include "helpers/data_parser.h" | |||||
#include "model.h" | |||||
namespace lar { | |||||
class ModelMdl : public ModelBase { | |||||
public: | |||||
using Strategy = mgb::opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||||
//! interface implement of ModelBase | |||||
ModelMdl(const std::string& path); | |||||
ModelType type() override { return ModelType::MEGDL_MODEL; } | |||||
void set_shared_mem(bool state) override { share_model_mem = state; } | |||||
void load_model() override; | |||||
void make_output_spec(); | |||||
void run_model() override; | |||||
void wait() override; | |||||
//! get load result for megDL model | |||||
mgb::serialization::GraphLoader::LoadResult& get_mdl_load_result() { | |||||
return m_load_result; | |||||
} | |||||
//! get load config for megDL model | |||||
mgb::serialization::GraphLoadConfig& get_mdl_config() { return m_load_config; } | |||||
//! reset the graph loader for dump_with_testcase model | |||||
std::shared_ptr<mgb::serialization::GraphLoader>& reset_loader(); | |||||
//! algo strategy for running model | |||||
void set_mdl_strategy(Strategy& u_strategy) { m_strategy = u_strategy; } | |||||
Strategy& get_mdl_strategy() { return m_strategy; } | |||||
//! get data parser | |||||
DataParser& get_input_parser() { return parser; } | |||||
uint32_t get_testcase_num() { return testcase_num; } | |||||
std::vector<std::pair<std::string, mgb::HostTensorND*>>& get_test_input() { | |||||
return test_input_tensors; | |||||
} | |||||
//! get output specified configuration | |||||
mgb::ComputingGraph::OutputSpec& get_output_spec() { return m_output_spec; } | |||||
std::unique_ptr<mgb::cg::AsyncExecutable>& get_async_func() { return m_asyc_exec; } | |||||
void set_output_callback(std::vector<mgb::ComputingGraph::Callback>& cb) { | |||||
mgb_assert( | |||||
m_callbacks.size() == cb.size(), | |||||
"invalid output callback list to set!!"); | |||||
for (size_t i = 0; i < cb.size(); i++) { | |||||
m_callbacks[i] = cb[i]; | |||||
} | |||||
} | |||||
#if MGB_ENABLE_JSON | |||||
std::unique_ptr<mgb::GraphProfiler>& get_profiler() { return m_profiler; } | |||||
void set_profiler() { | |||||
m_profiler = | |||||
std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get()); | |||||
} | |||||
#endif | |||||
void set_num_range_checker(float range) { | |||||
m_num_range_checker = std::make_unique<mgb::NumRangeChecker>( | |||||
m_load_config.comp_graph.get(), range); | |||||
} | |||||
private: | |||||
bool share_model_mem; | |||||
std::string model_path; | |||||
std::unique_ptr<mgb::serialization::InputFile> m_model_file; | |||||
mgb::serialization::GraphLoadConfig m_load_config; | |||||
mgb::serialization::GraphLoader::LoadResult m_load_result; | |||||
std::shared_ptr<mgb::serialization::GraphLoader> m_loader; | |||||
std::unique_ptr<mgb::cg::AsyncExecutable> m_asyc_exec; | |||||
uint32_t testcase_num; | |||||
std::vector<std::pair<std::string, mgb::HostTensorND*>> test_input_tensors; | |||||
DataParser parser; | |||||
Strategy m_strategy = Strategy::HEURISTIC; | |||||
std::vector<mgb::ComputingGraph::Callback> m_callbacks; | |||||
mgb::ComputingGraph::OutputSpec m_output_spec; | |||||
std::unique_ptr<mgb::NumRangeChecker> m_num_range_checker; | |||||
#if MGB_ENABLE_JSON | |||||
std::unique_ptr<mgb::GraphProfiler> m_profiler; | |||||
#endif | |||||
}; | |||||
} // namespace lar | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,200 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/device_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include <iostream> | |||||
#include <sstream> | |||||
#include "lite/global.h" | |||||
#include "megbrain/comp_node_env.h" | |||||
#include "misc.h" | |||||
#include "device_options.h" | |||||
#include "models/model_lite.h" | |||||
#include "models/model_mdl.h" | |||||
DECLARE_bool(weight_preprocess); | |||||
using namespace lar; | |||||
/////////////////// XPUDeviceOption ////////////////////// | |||||
namespace lar { | |||||
template <> | |||||
void XPUDeviceOption::config_model_internel<ModelLite>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
if ((enable_cpu) || (enable_cpu_default) || (enable_multithread) || | |||||
(enable_multithread_default)) { | |||||
LITE_WARN("using cpu device\n"); | |||||
model->get_config().device_type = LiteDeviceType::LITE_CPU; | |||||
} | |||||
#if MGE_WITH_CUDA | |||||
if (enable_cuda) { | |||||
model->get_config().device_type = LiteDeviceType::LITE_CUDA; | |||||
} | |||||
#endif | |||||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||||
auto network = model->get_lite_network(); | |||||
if (enable_cpu_default) { | |||||
LITE_WARN("using cpu default device\n"); | |||||
lite::Runtime::set_cpu_inplace_mode(network); | |||||
} | |||||
if (enable_multithread) { | |||||
LITE_WARN("using multithread device\n"); | |||||
lite::Runtime::set_cpu_threads_number(network, thread_num); | |||||
} | |||||
if (enable_multithread_default) { | |||||
LITE_WARN("using multithread default device\n"); | |||||
lite::Runtime::set_cpu_inplace_mode(network); | |||||
lite::Runtime::set_cpu_threads_number(network, thread_num); | |||||
} | |||||
if (enable_set_core_ids) { | |||||
std::string core_str; | |||||
for (auto id : core_ids) { | |||||
core_str += std::to_string(id) + ","; | |||||
} | |||||
LITE_WARN("multi thread core ids: %s\n", core_str.c_str()); | |||||
lite::ThreadAffinityCallback affinity_callback = [&](size_t thread_id) { | |||||
mgb::sys::set_cpu_affinity({core_ids[thread_id]}); | |||||
}; | |||||
lite::Runtime::set_runtime_thread_affinity(network, affinity_callback); | |||||
} | |||||
} | |||||
} | |||||
template <> | |||||
void XPUDeviceOption::config_model_internel<ModelMdl>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
if (enable_cpu) { | |||||
mgb_log_warn("using cpu device\n"); | |||||
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { | |||||
loc.type = mgb::CompNode::DeviceType::CPU; | |||||
}; | |||||
} | |||||
#if MGE_WITH_CUDA | |||||
if (enable_cuda) { | |||||
mgb_log_warn("using cuda device\n"); | |||||
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { | |||||
loc.type = mgb::CompNode::DeviceType::CUDA; | |||||
}; | |||||
} | |||||
#endif | |||||
if (enable_cpu_default) { | |||||
mgb_log_warn("using cpu default device\n"); | |||||
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { | |||||
loc.type = mgb::CompNode::DeviceType::CPU; | |||||
loc.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT; | |||||
}; | |||||
} | |||||
if (enable_multithread) { | |||||
mgb_log_warn("using multithread device\n"); | |||||
model->get_mdl_config().comp_node_mapper = | |||||
[&](mgb::CompNode::Locator& loc) { | |||||
loc.type = mgb::CompNode::DeviceType::MULTITHREAD; | |||||
loc.device = 0; | |||||
loc.stream = thread_num; | |||||
}; | |||||
} | |||||
if (enable_multithread_default) { | |||||
mgb_log_warn("using multithread default device\n"); | |||||
model->get_mdl_config().comp_node_mapper = | |||||
[&](mgb::CompNode::Locator& loc) { | |||||
loc.type = mgb::CompNode::DeviceType::MULTITHREAD; | |||||
loc.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT; | |||||
loc.stream = thread_num; | |||||
}; | |||||
} | |||||
if (enable_set_core_ids) { | |||||
std::string core_str; | |||||
for (auto id : core_ids) { | |||||
core_str += std::to_string(id) + ","; | |||||
} | |||||
mgb_log_warn("set multi thread core ids:%s\n", core_str.c_str()); | |||||
auto affinity_callback = [&](size_t thread_id) { | |||||
mgb::sys::set_cpu_affinity({core_ids[thread_id]}); | |||||
}; | |||||
mgb::CompNode::Locator loc; | |||||
model->get_mdl_config().comp_node_mapper(loc); | |||||
auto comp_node = mgb::CompNode::load(loc); | |||||
mgb::CompNodeEnv::from_comp_node(comp_node).cpu_env().set_affinity( | |||||
affinity_callback); | |||||
} | |||||
} | |||||
} | |||||
} // namespace lar | |||||
XPUDeviceOption::XPUDeviceOption() { | |||||
m_option_name = "xpu_device"; | |||||
enable_cpu = FLAGS_cpu; | |||||
#if MGE_WITH_CUDA | |||||
enable_cuda = FLAGS_cuda; | |||||
#endif | |||||
enable_cpu_default = FLAGS_cpu_default; | |||||
if (FLAGS_multithread >= 0) { | |||||
thread_num = FLAGS_multithread; | |||||
enable_multithread = true; | |||||
} | |||||
if (FLAGS_multithread_default >= 0) { | |||||
thread_num = FLAGS_multithread_default; | |||||
enable_multithread_default = true; | |||||
} | |||||
if (!FLAGS_multi_thread_core_ids.empty()) { | |||||
mgb_assert(enable_multithread, "core ids should be set after --multithread"); | |||||
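        //! parse the comma-separated core id list, one id per worker thread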
std::stringstream id_stream(FLAGS_multi_thread_core_ids); | |||||
std::string id; | |||||
size_t thread_cnt = 0; | |||||
while (getline(id_stream, id, ',')) { | |||||
thread_cnt++; | |||||
core_ids.push_back(atoi(id.c_str())); | |||||
} | |||||
        mgb_assert(
                thread_cnt == thread_num,
                "the number of core ids should match the thread number set "
                "before");
enable_set_core_ids = true; | |||||
} | |||||
} | |||||
bool XPUDeviceOption::is_valid() { | |||||
bool ret = FLAGS_cpu || FLAGS_cpu_default; | |||||
#if MGE_WITH_CUDA | |||||
ret = ret || FLAGS_cuda; | |||||
#endif | |||||
ret = ret || FLAGS_multithread >= 0; | |||||
ret = ret || FLAGS_multithread_default >= 0; | |||||
ret = ret || !FLAGS_multi_thread_core_ids.empty(); | |||||
return ret; | |||||
} | |||||
std::shared_ptr<OptionBase> XPUDeviceOption::create_option() { | |||||
static std::shared_ptr<lar::XPUDeviceOption> option(new XPUDeviceOption); | |||||
if (XPUDeviceOption::is_valid()) { | |||||
return std::static_pointer_cast<lar::OptionBase>(option); | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
} | |||||
void XPUDeviceOption::config_model( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||||
CONFIG_MODEL_FUN; | |||||
} | |||||
///////////////////////// xpu gflags //////////////////////////// | |||||
DEFINE_bool(cpu, false, "set CPU device as running device"); | |||||
#if MGE_WITH_CUDA | |||||
DEFINE_bool(cuda, false, "set CUDA device as running device");
#endif | |||||
DEFINE_bool(cpu_default, false, "set running device as CPU device with inplace mode"); | |||||
DEFINE_int32(multithread, -1, "set multithread device as running device"); | |||||
DEFINE_int32( | |||||
multithread_default, -1, | |||||
"set multithread device as running device with inplace mode"); | |||||
DEFINE_string(multi_thread_core_ids, "", "set the core ids for multithread device (comma separated)");
REGIST_OPTION_CREATOR(xpu_device, lar::XPUDeviceOption::create_option); |
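//! Usage sketch (model path is hypothetical; flag names match the gflags
//! defined above):
//!   load_and_run model.mge --cpu
//!   load_and_run model.mge --multithread 4 --multi_thread_core_ids 0,1,2,3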
@@ -0,0 +1,49 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/device_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | |||||
#include <gflags/gflags.h> | |||||
#include "models/model.h" | |||||
#include "option_base.h" | |||||
DECLARE_bool(cpu); | |||||
#if MGE_WITH_CUDA | |||||
DECLARE_bool(cuda); | |||||
#endif | |||||
DECLARE_bool(cpu_default); | |||||
DECLARE_int32(multithread); | |||||
DECLARE_int32(multithread_default); | |||||
DECLARE_string(multi_thread_core_ids); | |||||
namespace lar { | |||||
class XPUDeviceOption final : public OptionBase { | |||||
public: | |||||
static bool is_valid(); | |||||
static std::shared_ptr<OptionBase> create_option(); | |||||
void config_model( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||||
std::string option_name() const override { return m_option_name; }; | |||||
private: | |||||
XPUDeviceOption(); | |||||
template <typename ModelImpl> | |||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||||
bool enable_cpu; | |||||
#if MGE_WITH_CUDA | |||||
bool enable_cuda; | |||||
#endif | |||||
bool enable_cpu_default; | |||||
bool enable_multithread; | |||||
bool enable_multithread_default; | |||||
bool enable_set_core_ids; | |||||
size_t thread_num; | |||||
std::vector<int> core_ids; | |||||
std::string m_option_name; | |||||
}; | |||||
} // namespace lar |
@@ -0,0 +1,216 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/extern_c_opr_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "extern_c_opr_options.h" | |||||
#include "megbrain/utils/debug.h" | |||||
#include "misc.h" | |||||
#include "models/model_lite.h" | |||||
#include "models/model_mdl.h" | |||||
namespace lar { | |||||
template <> | |||||
void COprLibOption::config_model_internel( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||||
MGB_MARK_USED_VAR(model); | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
if (!lib_path.empty()) { | |||||
lite::set_loader_lib_path(lib_path); | |||||
} | |||||
if (c_opr_args.is_run_c_opr_with_param) { | |||||
            LITE_THROW(
                    "lite model doesn't support running with external c opr "
                    "parameter");
} | |||||
} | |||||
} | |||||
template <> | |||||
void COprLibOption::config_model_internel( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
if (!lib_path.empty()) { | |||||
load_lib(); | |||||
} | |||||
if (c_opr_args.is_run_c_opr_with_param) { | |||||
mgb_assert( | |||||
c_opr_args.is_run_c_opr && | |||||
c_opr_args.copr_param_device_ptr_malloc && | |||||
c_opr_args.copr_param_device_ptr_free && | |||||
c_opr_args.copr_param_device_ptr_h2d, | |||||
"--c-opr-lib-with-param need config with --c-opr-lib, also " | |||||
"extern c opr loader need implemente " | |||||
"copr_param_device_ptr_malloc, copr_param_device_ptr_free " | |||||
"and copr_param_device_ptr_h2d symbols"); | |||||
} | |||||
} else if (runtime_param.stage == RunStage::MODEL_RUNNING) { | |||||
if (model->get_testcase_num() && c_opr_args.is_run_c_opr_with_param) { | |||||
init_extern_param(model); | |||||
set_Copr_IO(model); | |||||
} | |||||
} else if (runtime_param.stage == RunStage::AFTER_RUNNING_ITER) { | |||||
if (model->get_testcase_num() && c_opr_args.is_run_c_opr_with_param) { | |||||
c_opr_args.copr_param_device_ptr_free(c_opr_param.get()); | |||||
free(c_opr_param->input); | |||||
} | |||||
} | |||||
} | |||||
} // namespace lar | |||||
using namespace lar; | |||||
MGBDType COprLibOption::dtype_cpp2c(megdnn::DType dtype) { | |||||
switch (dtype.enumv()) { | |||||
case megdnn::DTypeEnum::Float32: | |||||
return MGB_DTYPE_FLOAT32; | |||||
case megdnn::DTypeEnum::Int32: | |||||
return MGB_DTYPE_INT32; | |||||
case megdnn::DTypeEnum::Int16: | |||||
return MGB_DTYPE_INT16; | |||||
case megdnn::DTypeEnum::Uint8: | |||||
return MGB_DTYPE_UINT8; | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
case megdnn::DTypeEnum::Float16: | |||||
return MGB_DTYPE_FLOAT16; | |||||
#endif | |||||
default: | |||||
mgb_throw( | |||||
mgb::InternalError, "unsupported dtype for extern C API: %s", | |||||
dtype.name()); | |||||
} | |||||
} | |||||
void COprLibOption::tensor_shape_to_c( | |||||
const megdnn::TensorShape& shape, MGBTensorShape& mgb_shape) { | |||||
mgb_assert( | |||||
shape.ndim <= MGB_TENSOR_MAX_NDIM, "shape ndim too large: %zu", shape.ndim); | |||||
mgb_shape.ndim = shape.ndim; | |||||
for (size_t i = 0; i < shape.ndim; ++i) { | |||||
mgb_shape.shape[i] = shape[i]; | |||||
} | |||||
} | |||||
void COprLibOption::init_extern_param(std::shared_ptr<ModelBase> model_ptr) { | |||||
auto model = std::static_pointer_cast<ModelMdl>(model_ptr); | |||||
auto inp_tensors = model->get_test_input(); | |||||
c_opr_param = std::make_shared<ExternCOprParam>(); | |||||
memset(c_opr_param.get(), 0, sizeof(ExternCOprParam)); | |||||
    //! only inputs are tested in the npu case (outputs are not), so only the
    //! input shape and dtype need to be initialized here
c_opr_param->nr_input = inp_tensors.size(); | |||||
c_opr_param->input = (ExternDeviceTensor*)malloc( | |||||
sizeof(ExternDeviceTensor) * inp_tensors.size()); | |||||
memset(c_opr_param->input, 0, sizeof(ExternDeviceTensor) * inp_tensors.size()); | |||||
//! init input ExternDeviceTensor shape and dtype | |||||
for (size_t input_idx = 0; input_idx < inp_tensors.size(); input_idx++) { | |||||
auto& mgb_tensor_layout = c_opr_param->input[input_idx].layout; | |||||
auto host_tensor_nd_p = inp_tensors[input_idx].second; | |||||
mgb_tensor_layout.dtype = dtype_cpp2c(host_tensor_nd_p->dtype()); | |||||
tensor_shape_to_c( | |||||
inp_tensors[input_idx].second->shape(), mgb_tensor_layout.shape); | |||||
} | |||||
c_opr_param->nr_output = 0; | |||||
//! now call copr_param_device_ptr_malloc to malloc | |||||
//! device_ptr | |||||
c_opr_args.copr_param_device_ptr_malloc(c_opr_param.get()); | |||||
} | |||||
void COprLibOption::load_lib() { | |||||
auto handle = dlopen(lib_path.c_str(), RTLD_LAZY); | |||||
mgb_assert(handle, "failed to open c opr lib %s: %s", lib_path.c_str(), dlerror()); | |||||
const char* entry = MGB_C_OPR_INIT_FUNC_STR; | |||||
auto func = dlsym(handle, entry); | |||||
mgb_assert(func, "can not resolve %s: %s", entry, dlerror()); | |||||
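    //! hand the versioned API getter to the library's init entry so the
    //! loaded library can bind to the host extern-c-opr API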
typedef void (*entry_f_t)(void*); | |||||
reinterpret_cast<entry_f_t>(func)( | |||||
reinterpret_cast<void*>(&mgb_get_extern_c_opr_api_versioned)); | |||||
printf("loaded C opr library: %s\n", lib_path.c_str()); | |||||
entry = "copr_param_device_ptr_malloc"; | |||||
func = dlsym(handle, entry); | |||||
if (func) { | |||||
printf("get %s from: %s\n", entry, lib_path.c_str()); | |||||
c_opr_args.copr_param_device_ptr_malloc = | |||||
reinterpret_cast<COprArgs::COPR_PARAM_DEVICE_PTR_MEM_T>(func); | |||||
} | |||||
entry = "copr_param_device_ptr_free"; | |||||
func = dlsym(handle, entry); | |||||
if (func) { | |||||
printf("get %s from: %s\n", entry, lib_path.c_str()); | |||||
c_opr_args.copr_param_device_ptr_free = | |||||
reinterpret_cast<COprArgs::COPR_PARAM_DEVICE_PTR_MEM_T>(func); | |||||
} | |||||
entry = "copr_param_device_ptr_h2d"; | |||||
func = dlsym(handle, entry); | |||||
if (func) { | |||||
printf("get %s from: %s\n", entry, lib_path.c_str()); | |||||
c_opr_args.copr_param_device_ptr_h2d = | |||||
reinterpret_cast<COprArgs::COPR_PARAM_DEVICE_PTR_H2D_T>(func); | |||||
} | |||||
} | |||||
void COprLibOption::set_Copr_IO(std::shared_ptr<ModelBase> model_ptr) { | |||||
auto model = std::static_pointer_cast<ModelMdl>(model_ptr); | |||||
auto inp_tensors = model->get_test_input(); | |||||
auto loader = model->reset_loader(); | |||||
auto testcase = loader->load(model->get_mdl_config(), false); | |||||
mgb_assert(testcase.output_var_list.size() == inp_tensors.size()); | |||||
for (size_t i = 0; i < inp_tensors.size(); ++i) { | |||||
auto&& opr = testcase.output_var_list[i] | |||||
.node() | |||||
->owner_opr() | |||||
->cast_final_safe<mgb::opr::SharedDeviceTensor>(); | |||||
c_opr_args.copr_param_device_ptr_h2d( | |||||
c_opr_param.get(), opr.dev_data()->raw_ptr(), i); | |||||
} | |||||
//! now config c opr dynamic param | |||||
config_extern_c_opr_dynamic_param(model->get_async_func(), c_opr_param); | |||||
} | |||||
COprLibOption::COprLibOption() { | |||||
m_option_name = "c_opr_lib"; | |||||
lib_path = FLAGS_c_opr_lib; | |||||
c_opr_args.is_run_c_opr = !lib_path.empty(); | |||||
c_opr_args.is_run_c_opr_with_param = FLAGS_c_opr_lib_with_param; | |||||
} | |||||
bool COprLibOption::is_valid() { | |||||
return !FLAGS_c_opr_lib.empty() || FLAGS_c_opr_lib_with_param; | |||||
} | |||||
std::shared_ptr<OptionBase> COprLibOption::create_option() { | |||||
static std::shared_ptr<COprLibOption> option(new COprLibOption); | |||||
if (COprLibOption::is_valid()) { | |||||
return std::static_pointer_cast<OptionBase>(option); | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
} | |||||
void COprLibOption::config_model( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||||
CONFIG_MODEL_FUN; | |||||
} | |||||
DEFINE_string( | |||||
c_opr_lib, "", | |||||
"Load external operator library. It must implement " | |||||
"MGB_C_OPR_INIT_FUNC_STR as the entry point"); | |||||
DEFINE_bool(
        c_opr_lib_with_param, false,
        "Run the c opr lib with param; used to benchmark speed and check "
        "results. The c opr loader must implement the "
        "`copr_param_device_ptr_malloc', `copr_param_device_ptr_free' and "
        "`copr_param_device_ptr_h2d' symbols");
REGIST_OPTION_CREATOR(c_opr_lib, lar::COprLibOption::create_option); |
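//! Usage sketch (model and library paths are hypothetical):
//!   load_and_run model.mge --c_opr_lib libcustom_loader.so
//!   load_and_run model.mge --c_opr_lib libcustom_loader.so --c_opr_lib_with_param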
@@ -0,0 +1,64 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/extern_c_opr_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | |||||
#include <gflags/gflags.h> | |||||
#include "megbrain/graph/extern_copr_api.h" | |||||
#include "models/model.h" | |||||
#include "option_base.h" | |||||
DECLARE_bool(c_opr_lib_with_param); | |||||
DECLARE_string(c_opr_lib); | |||||
namespace lar { | |||||
struct COprArgs { | |||||
//! for run c opr | |||||
bool is_run_c_opr = false; | |||||
bool is_run_c_opr_with_param = false; | |||||
typedef void (*COPR_PARAM_DEVICE_PTR_MEM_T)(ExternCOprParam* param); | |||||
typedef void (*COPR_PARAM_DEVICE_PTR_H2D_T)( | |||||
ExternCOprParam* param, void* host_ptr, size_t extern_device_tensor_id); | |||||
COPR_PARAM_DEVICE_PTR_MEM_T copr_param_device_ptr_malloc = nullptr; | |||||
COPR_PARAM_DEVICE_PTR_MEM_T copr_param_device_ptr_free = nullptr; | |||||
COPR_PARAM_DEVICE_PTR_H2D_T copr_param_device_ptr_h2d = nullptr; | |||||
}; | |||||
class COprLibOption final : public OptionBase { | |||||
public: | |||||
static bool is_valid(); | |||||
static std::shared_ptr<OptionBase> create_option(); | |||||
void config_model( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||||
std::string option_name() const override { return m_option_name; }; | |||||
private: | |||||
COprLibOption(); | |||||
template <typename ModelImpl> | |||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||||
void load_lib(); | |||||
MGBDType dtype_cpp2c(megdnn::DType dtype); | |||||
void tensor_shape_to_c(const megdnn::TensorShape& shape, MGBTensorShape& mgb_shape); | |||||
void init_extern_param(std::shared_ptr<ModelBase> model); | |||||
void set_Copr_IO(std::shared_ptr<ModelBase> model); | |||||
std::string m_option_name; | |||||
COprArgs c_opr_args; | |||||
std::string lib_path; | |||||
std::shared_ptr<ExternCOprParam> c_opr_param; | |||||
}; | |||||
} // namespace lar |
@@ -0,0 +1,231 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/fastrun_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include <gflags/gflags.h> | |||||
#if defined(_WIN32) | |||||
#include <io.h> | |||||
#define F_OK 0 | |||||
#define access(a, b) _access(a, b) | |||||
#elif __linux__ || __unix__ || __APPLE__ | |||||
#include <unistd.h> | |||||
#endif | |||||
#include "fastrun_options.h" | |||||
#include "megbrain/gopt/inference.h" | |||||
#include "megbrain/utils/infile_persistent_cache.h" | |||||
#include "misc.h" | |||||
#include "models/model_lite.h" | |||||
#include "models/model_mdl.h" | |||||
namespace lar { | |||||
template <> | |||||
void FastRunOption::config_model_internel<ModelLite>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
//! set the algo policy before model load | |||||
using Strategy = ModelLite::Strategy; | |||||
uint32_t strategy = 0; | |||||
#if MGB_ENABLE_FASTRUN | |||||
if (enable_full_run) { | |||||
LITE_WARN("enable full-run strategy for algo profile"); | |||||
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) | strategy; | |||||
} else if (enable_fast_run) { | |||||
LITE_WARN("enable fast-run strategy for algo profile"); | |||||
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) | | |||||
static_cast<uint32_t>(Strategy::LITE_ALGO_OPTIMIZED) | strategy; | |||||
} else { | |||||
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy; | |||||
} | |||||
#else | |||||
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy; | |||||
#endif | |||||
if (batch_binary_equal || enable_reproducible) { | |||||
LITE_WARN("enable reproducible strategy for algo profile"); | |||||
if (batch_binary_equal) | |||||
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_REPRODUCIBLE) | | |||||
strategy; | |||||
} | |||||
auto lite_strategy = static_cast<Strategy>(strategy); | |||||
model->set_lite_strategy(lite_strategy); | |||||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||||
auto lite_network = model->get_lite_network(); | |||||
auto lite_strategy = model->get_lite_strategy(); | |||||
//! set algo policy for model | |||||
lite::Runtime::set_network_algo_policy( | |||||
lite_network, lite_strategy, share_batch_size, batch_binary_equal); | |||||
if (!m_fast_run_cache.empty()) { | |||||
if (!access(m_fast_run_cache.c_str(), F_OK)) { | |||||
lite::set_persistent_cache(m_fast_run_cache); | |||||
} else { | |||||
lite::set_persistent_cache(m_fast_run_cache, true); | |||||
} | |||||
            //! TODO: this comes from the mdl model settings but has no
            //! matching setting in the lite model
// if (!enable_full_run && !enable_fast_run) | |||||
// mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); | |||||
} | |||||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||||
#if MGB_ENABLE_FASTRUN | |||||
//! dump algo cache | |||||
if (!m_fast_run_cache.empty()) { | |||||
lite::dump_persistent_cache(m_fast_run_cache); | |||||
} | |||||
#endif | |||||
} | |||||
} | |||||
template <> | |||||
void FastRunOption::config_model_internel<ModelMdl>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
//! set the algo policy before model load | |||||
using Strategy = ModelMdl::Strategy; | |||||
auto strategy = static_cast<Strategy>(0); | |||||
#if MGB_ENABLE_FASTRUN | |||||
if (enable_full_run) { | |||||
mgb_log_warn("enable full-run strategy for algo profile"); | |||||
strategy = Strategy::PROFILE | strategy; | |||||
} else if (enable_fast_run) { | |||||
mgb_log_warn("enable fast-run strategy for algo profile"); | |||||
strategy = Strategy::PROFILE | Strategy::OPTIMIZED | strategy; | |||||
} else { | |||||
strategy = Strategy::HEURISTIC | strategy; | |||||
} | |||||
#else | |||||
strategy = Strategy::HEURISTIC | strategy; | |||||
#endif | |||||
if (batch_binary_equal || enable_reproducible) { | |||||
mgb_log_warn("enable reproducible strategy for algo profile"); | |||||
strategy = Strategy::REPRODUCIBLE | strategy; | |||||
} | |||||
model->set_mdl_strategy(strategy); | |||||
//! set binary_equal_between_batch and shared_batch_size | |||||
if (batch_binary_equal) { | |||||
mgb_log_warn("enable batch binary equal"); | |||||
model->get_mdl_config() | |||||
.comp_graph->options() | |||||
.fast_run_config.binary_equal_between_batch = true; | |||||
} | |||||
if (share_batch_size > 0) { | |||||
mgb_log_warn("set shared shared batch"); | |||||
model->get_mdl_config() | |||||
.comp_graph->options() | |||||
.fast_run_config.shared_batch_size = share_batch_size; | |||||
} | |||||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||||
auto vars = model->get_mdl_load_result().output_var_list; | |||||
auto strategy = model->get_mdl_strategy(); | |||||
mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy); | |||||
// set algo cache path | |||||
if (!m_fast_run_cache.empty()) { | |||||
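            //! access() returns 0 when the cache file exists: load the cache
            //! from it; otherwise start from an empty in-memory cache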
if (!access(m_fast_run_cache.c_str(), F_OK)) { | |||||
mgb::PersistentCache::set_impl( | |||||
std::make_shared<mgb::InFilePersistentCache>( | |||||
m_fast_run_cache.c_str())); | |||||
} else { | |||||
mgb::PersistentCache::set_impl( | |||||
std::make_shared<mgb::InFilePersistentCache>()); | |||||
} | |||||
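            //! when neither fast-run nor full-run is requested, let oprs read
            //! their algo policy from the loaded profiling cache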
#if MGB_ENABLE_FASTRUN | |||||
if (!enable_full_run && !enable_fast_run) | |||||
#endif | |||||
mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); | |||||
} | |||||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||||
#if MGB_ENABLE_FASTRUN | |||||
//! dump algo cache | |||||
if (!m_fast_run_cache.empty()) { | |||||
static_cast<mgb::InFilePersistentCache&>(mgb::PersistentCache::inst()) | |||||
.dump_cache(m_fast_run_cache.c_str()); | |||||
} | |||||
#endif | |||||
} | |||||
} | |||||
} // namespace lar | |||||
using namespace lar; | |||||
FastRunOption::FastRunOption() { | |||||
m_option_name = "fastrun"; | |||||
#if MGB_ENABLE_FASTRUN | |||||
enable_fast_run = FLAGS_fast_run; | |||||
enable_full_run = FLAGS_full_run; | |||||
#endif | |||||
batch_binary_equal = FLAGS_binary_equal_between_batch; | |||||
enable_reproducible = FLAGS_reproducible; | |||||
m_fast_run_cache = FLAGS_fast_run_algo_policy; | |||||
share_batch_size = FLAGS_fast_run_shared_batch_size; | |||||
#if MGB_ENABLE_FASTRUN | |||||
    //! if the fast-run cache path is set but the file does not exist yet,
    //! profiling must be enabled to generate it
if (!m_fast_run_cache.empty() && access(m_fast_run_cache.c_str(), F_OK)) { | |||||
mgb_assert( | |||||
enable_full_run || enable_fast_run, | |||||
"--fast-run or --full-run should be enabled"); | |||||
} | |||||
if (share_batch_size) { | |||||
mgb_assert( | |||||
enable_full_run || enable_fast_run || !m_fast_run_cache.empty(), | |||||
"--fast-run-shared-batch-size should be used with " | |||||
"--fast-run|--full-run|--fast-run-algo-policy"); | |||||
} | |||||
#endif | |||||
} | |||||
bool FastRunOption::is_valid() { | |||||
bool ret = false; | |||||
#if MGB_ENABLE_FASTRUN | |||||
ret = ret || FLAGS_fast_run; | |||||
ret = ret || FLAGS_full_run; | |||||
#endif | |||||
ret = ret || FLAGS_binary_equal_between_batch; | |||||
ret = ret || FLAGS_fast_run_shared_batch_size > 0; | |||||
ret = ret || FLAGS_reproducible; | |||||
ret = ret || FLAGS_fast_run_algo_policy.size() > 0; | |||||
return ret; | |||||
} | |||||
std::shared_ptr<OptionBase> FastRunOption::create_option() { | |||||
static std::shared_ptr<FastRunOption> option(new FastRunOption); | |||||
if (FastRunOption::is_valid()) { | |||||
return std::static_pointer_cast<OptionBase>(option); | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
} | |||||
void FastRunOption::config_model( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||||
CONFIG_MODEL_FUN; | |||||
} | |||||
#if MGB_ENABLE_FASTRUN | |||||
DEFINE_bool(fast_run, false, "whether to use fast-run in model run"); | |||||
DEFINE_bool(full_run, false, "whether to use full-run in model run"); | |||||
#endif | |||||
DEFINE_bool( | |||||
binary_equal_between_batch, false, | |||||
"Each batch of output is promised binary equal if each batch of " | |||||
"input is binary equal\n Note that if this option is turned on, " | |||||
"`--reproducible` will also be turned on."); | |||||
DEFINE_bool(
        reproducible, false,
        "Enable choosing only reproducible algorithms. This is mainly used "
        "for cudnn algorithms. See "
        "https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/"
        "index.html#reproducibility "
        "for more details.");
DEFINE_uint32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun"); | |||||
DEFINE_string(fast_run_algo_policy, "", "fast-run cache path."); | |||||
REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option); |
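//! Usage sketch (model and cache paths are hypothetical): the first run
//! profiles algorithms and dumps the cache; the second run reuses it:
//!   load_and_run model.mge --full_run --fast_run_algo_policy algo_cache.bin
//!   load_and_run model.mge --fast_run_algo_policy algo_cache.bin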
@@ -0,0 +1,57 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/fastrun_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | |||||
#include <gflags/gflags.h> | |||||
#include "models/model.h" | |||||
#include "option_base.h" | |||||
#if MGB_ENABLE_FASTRUN | |||||
DECLARE_bool(fast_run); | |||||
DECLARE_bool(full_run); | |||||
#endif | |||||
DECLARE_bool(reproducible); | |||||
DECLARE_bool(binary_equal_between_batch); | |||||
DECLARE_uint32(fast_run_shared_batch_size); | |||||
DECLARE_string(fast_run_algo_policy); | |||||
namespace lar { | |||||
class FastRunOption final : public OptionBase { | |||||
public: | |||||
    //! check whether the command line args enable this option
    static bool is_valid();
    //! create the option from the command line args
    static std::shared_ptr<OptionBase> create_option();
    //! configure the model for different runtime_param
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
    //! get the option name for quick search
std::string option_name() const override { return m_option_name; } | |||||
private: | |||||
FastRunOption(); | |||||
    //! config template for different models
template <typename ModelImpl> | |||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {} | |||||
#if MGB_ENABLE_FASTRUN | |||||
    bool enable_fast_run;  //!< fast-run strategy flag
    bool enable_full_run;  //!< full-run strategy flag
#endif
    bool batch_binary_equal;       //!< binary-equal-between-batch strategy flag
    bool enable_reproducible;      //!< reproducible strategy flag
    size_t share_batch_size;       //!< shared batch size for fast-run strategy
    std::string m_fast_run_cache;  //!< fast-run cache file path
    std::string m_option_name;     //!< option name
}; | |||||
} // namespace lar |
@@ -0,0 +1,295 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/io_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include <map> | |||||
#include "helpers/data_parser.h" | |||||
#include "misc.h" | |||||
#include "models/model_lite.h" | |||||
#include "models/model_mdl.h" | |||||
#include "io_options.h" | |||||
namespace lar { | |||||
template <> | |||||
void InputOption::config_model_internel<ModelLite>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
        auto& parser = model->get_input_parser();
        auto&& io = model->get_networkIO();
for (size_t idx = 0; idx < data_path.size(); ++idx) { | |||||
parser.feed(data_path[idx].c_str()); | |||||
} | |||||
auto inputs = parser.inputs; | |||||
bool is_host = true; | |||||
for (auto& i : inputs) { | |||||
io.inputs.push_back({i.first, is_host}); | |||||
} | |||||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||||
auto config = model->get_config(); | |||||
        auto& parser = model->get_input_parser();
        auto network = model->get_lite_network();
        //! data type map from mgb data type to lite data type
std::map<megdnn::DTypeEnum, LiteDataType> type_map = { | |||||
{megdnn::DTypeEnum::Float32, LiteDataType::LITE_FLOAT}, | |||||
{megdnn::DTypeEnum::Int32, LiteDataType::LITE_INT}, | |||||
{megdnn::DTypeEnum::Int8, LiteDataType::LITE_INT8}, | |||||
{megdnn::DTypeEnum::Uint8, LiteDataType::LITE_UINT8}}; | |||||
for (auto& i : parser.inputs) { | |||||
//! get tensor information from data parser | |||||
auto tensor = i.second; | |||||
auto data_type = tensor.dtype(); | |||||
auto tensor_shape = tensor.shape(); | |||||
mgb::dt_byte* src = tensor.raw_ptr(); | |||||
//! set lite layout | |||||
lite::Layout layout; | |||||
layout.ndim = tensor_shape.ndim; | |||||
for (size_t idx = 0; idx < tensor_shape.ndim; idx++) { | |||||
layout.shapes[idx] = tensor_shape[idx]; | |||||
} | |||||
layout.data_type = type_map[data_type.enumv()]; | |||||
//! set network input tensor | |||||
std::shared_ptr<lite::Tensor> input_tensor = | |||||
network->get_io_tensor(i.first); | |||||
input_tensor->reset(src, layout); | |||||
} | |||||
} | |||||
} | |||||
template <> | |||||
void InputOption::config_model_internel<ModelMdl>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
        auto& parser = model->get_input_parser();
for (size_t idx = 0; idx < data_path.size(); ++idx) { | |||||
parser.feed(data_path[idx].c_str()); | |||||
} | |||||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||||
        auto& parser = model->get_input_parser();
auto network = model->get_mdl_load_result(); | |||||
auto tensormap = network.tensor_map; | |||||
for (auto& i : parser.inputs) { | |||||
mgb_assert( | |||||
tensormap.find(i.first) != tensormap.end(), | |||||
"can't find tesnor named %s", i.first.c_str()); | |||||
auto& in = tensormap.find(i.first)->second; | |||||
in->copy_from(i.second); | |||||
} | |||||
} | |||||
} | |||||
template <> | |||||
void IOdumpOption::config_model_internel<ModelLite>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||||
if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||||
if (enable_io_dump) { | |||||
LITE_WARN("enable text io dump"); | |||||
lite::Runtime::enable_io_txt_dump(model->get_lite_network(), dump_path); | |||||
} | |||||
if (enable_bin_io_dump) { | |||||
LITE_WARN("enable binary io dump"); | |||||
lite::Runtime::enable_io_bin_dump(model->get_lite_network(), dump_path); | |||||
} | |||||
        //! FIXME: complete this when the corresponding API is added to lite
        if (enable_io_dump_stdout || enable_io_dump_stderr) {
            LITE_THROW("lite model doesn't support stdout or stderr io dump");
        }
        if (enable_bin_out_dump) {
            LITE_THROW("lite model doesn't support binary output dump");
        }
        if (enable_copy_to_host) {
            LITE_WARN("lite model copies outputs to host by default");
        }
} | |||||
} | |||||
template <> | |||||
void IOdumpOption::config_model_internel<ModelMdl>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
if (enable_io_dump) { | |||||
mgb_log_warn("enable text io dump"); | |||||
auto iodump = std::make_unique<mgb::TextOprIODump>( | |||||
model->get_mdl_config().comp_graph.get(), dump_path.c_str()); | |||||
iodump->print_addr(false); | |||||
io_dumper = std::move(iodump); | |||||
} | |||||
if (enable_io_dump_stdout) { | |||||
mgb_log_warn("enable text io dump to stdout"); | |||||
std::shared_ptr<FILE> std_out(stdout, [](FILE*) {}); | |||||
auto iodump = std::make_unique<mgb::TextOprIODump>( | |||||
model->get_mdl_config().comp_graph.get(), std_out); | |||||
iodump->print_addr(false); | |||||
io_dumper = std::move(iodump); | |||||
} | |||||
if (enable_io_dump_stderr) { | |||||
mgb_log_warn("enable text io dump to stderr"); | |||||
std::shared_ptr<FILE> std_err(stderr, [](FILE*) {}); | |||||
auto iodump = std::make_unique<mgb::TextOprIODump>( | |||||
model->get_mdl_config().comp_graph.get(), std_err); | |||||
iodump->print_addr(false); | |||||
io_dumper = std::move(iodump); | |||||
} | |||||
if (enable_bin_io_dump) { | |||||
mgb_log_warn("enable binary io dump"); | |||||
auto iodump = std::make_unique<mgb::BinaryOprIODump>( | |||||
model->get_mdl_config().comp_graph.get(), dump_path); | |||||
io_dumper = std::move(iodump); | |||||
} | |||||
if (enable_bin_out_dump) { | |||||
mgb_log_warn("enable binary output dump"); | |||||
out_dumper = std::make_unique<OutputDumper>(dump_path.c_str()); | |||||
} | |||||
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||||
if (enable_bin_out_dump) { | |||||
auto load_result = model->get_mdl_load_result(); | |||||
out_dumper->set(load_result.output_var_list); | |||||
std::vector<mgb::ComputingGraph::Callback> cb; | |||||
for (size_t i = 0; i < load_result.output_var_list.size(); i++) { | |||||
cb.push_back(out_dumper->bind()); | |||||
} | |||||
model->set_output_callback(cb); | |||||
} | |||||
if (enable_copy_to_host) { | |||||
auto load_result = model->get_mdl_load_result(); | |||||
std::vector<mgb::ComputingGraph::Callback> cb; | |||||
for (size_t i = 0; i < load_result.output_var_list.size(); i++) { | |||||
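                //! each callback captures its own host tensor and copies the
                //! device output into it, forcing a device-to-host transfer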
mgb::HostTensorND val; | |||||
auto callback = [val](const mgb::DeviceTensorND& dv) mutable { | |||||
val.copy_from(dv); | |||||
}; | |||||
cb.push_back(callback); | |||||
} | |||||
model->set_output_callback(cb); | |||||
} | |||||
} else if (runtime_param.stage == RunStage::AFTER_RUNNING_WAIT) { | |||||
if (enable_bin_out_dump) { | |||||
out_dumper->write_to_file(); | |||||
} | |||||
} | |||||
} | |||||
} // namespace lar | |||||
////////////////////// Input options //////////////////////// | |||||
using namespace lar; | |||||
InputOption::InputOption() { | |||||
m_option_name = "input"; | |||||
size_t start = 0; | |||||
auto end = FLAGS_input.find(";", start); | |||||
while (end != std::string::npos) { | |||||
std::string path = FLAGS_input.substr(start, end - start); | |||||
data_path.emplace_back(path); | |||||
start = end + 1; | |||||
end = FLAGS_input.find(";", start); | |||||
} | |||||
data_path.emplace_back(FLAGS_input.substr(start)); | |||||
} | |||||
std::shared_ptr<lar::OptionBase> lar::InputOption::create_option() { | |||||
static std::shared_ptr<InputOption> m_option(new InputOption); | |||||
if (InputOption::is_valid()) { | |||||
return std::static_pointer_cast<OptionBase>(m_option); | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
} | |||||
void InputOption::config_model( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||||
CONFIG_MODEL_FUN; | |||||
} | |||||
////////////////////// OprIOdump options //////////////////////// | |||||
IOdumpOption::IOdumpOption() { | |||||
m_option_name = "iodump"; | |||||
size_t valid_flag = 0; | |||||
if (!FLAGS_io_dump.empty()) { | |||||
dump_path = FLAGS_io_dump; | |||||
enable_io_dump = true; | |||||
valid_flag = valid_flag | (1 << 0); | |||||
} | |||||
if (!FLAGS_bin_io_dump.empty()) { | |||||
dump_path = FLAGS_bin_io_dump; | |||||
enable_bin_io_dump = true; | |||||
valid_flag = valid_flag | (1 << 1); | |||||
} | |||||
if (!FLAGS_bin_out_dump.empty()) { | |||||
dump_path = FLAGS_bin_out_dump; | |||||
enable_bin_out_dump = true; | |||||
valid_flag = valid_flag | (1 << 2); | |||||
} | |||||
if (FLAGS_io_dump_stdout) { | |||||
enable_io_dump_stdout = FLAGS_io_dump_stdout; | |||||
valid_flag = valid_flag | (1 << 3); | |||||
} | |||||
if (FLAGS_io_dump_stderr) { | |||||
enable_io_dump_stderr = FLAGS_io_dump_stderr; | |||||
valid_flag = valid_flag | (1 << 4); | |||||
} | |||||
    // more than one io dump option was set
    if (valid_flag && (valid_flag & (valid_flag - 1))) {
        mgb_log_warn(
                "ONLY the last io dump option is valid and the others are "
                "skipped!");
    }
enable_copy_to_host = FLAGS_copy_to_host; | |||||
} | |||||
bool IOdumpOption::is_valid() { | |||||
bool ret = !FLAGS_io_dump.empty(); | |||||
ret = ret || FLAGS_io_dump_stdout; | |||||
ret = ret || FLAGS_io_dump_stderr; | |||||
ret = ret || !FLAGS_bin_io_dump.empty(); | |||||
ret = ret || !FLAGS_bin_out_dump.empty(); | |||||
ret = ret || FLAGS_copy_to_host; | |||||
return ret; | |||||
} | |||||
std::shared_ptr<OptionBase> IOdumpOption::create_option() { | |||||
static std::shared_ptr<IOdumpOption> option(new IOdumpOption); | |||||
if (IOdumpOption::is_valid()) { | |||||
return std::static_pointer_cast<OptionBase>(option); | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
} | |||||
void IOdumpOption::config_model( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||||
CONFIG_MODEL_FUN; | |||||
} | |||||
////////////////////// Input gflags //////////////////////// | |||||
DEFINE_string(
        input, "", "set up input data for the model: --input [file_path|data_string]");
////////////////////// OprIOdump gflags //////////////////////// | |||||
DEFINE_string(io_dump, "", "set the io dump file path in text format"); | |||||
DEFINE_bool(io_dump_stdout, false, "dump io opr to stdout in text format"); | |||||
DEFINE_bool(io_dump_stderr, false, "dump io opr to stderr in text format"); | |||||
DEFINE_string(bin_io_dump, "", "set the io dump file path in binary format"); | |||||
DEFINE_string(bin_out_dump, "", "set the out dump file path in binary format"); | |||||
DEFINE_bool(copy_to_host, false, "copy device data to host"); | |||||
REGIST_OPTION_CREATOR(input, lar::InputOption::create_option); | |||||
REGIST_OPTION_CREATOR(iodump, lar::IOdumpOption::create_option); |
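//! Usage sketch (file names are hypothetical; multiple inputs are joined
//! with ';'):
//!   load_and_run model.mge --input "data0.npy;data1.npy" --io_dump dump.txt
//!   load_and_run model.mge --input data0.npy --bin_out_dump out_dir --copy_to_host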
@@ -0,0 +1,78 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/io_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | |||||
#include <gflags/gflags.h> | |||||
#include "helpers/outdumper.h" | |||||
#include "megbrain/plugin/opr_io_dump.h" | |||||
#include "models/model.h" | |||||
#include "option_base.h" | |||||
DECLARE_string(input); | |||||
DECLARE_string(io_dump); | |||||
DECLARE_bool(io_dump_stdout); | |||||
DECLARE_bool(io_dump_stderr); | |||||
DECLARE_string(bin_io_dump); | |||||
DECLARE_string(bin_out_dump); | |||||
DECLARE_bool(copy_to_host); | |||||
namespace lar { | |||||
/*!
 * \brief: input option for the --input flag
 */
class InputOption final : public OptionBase {
public:
    //! static functions used to register the option
    static bool is_valid() { return !FLAGS_input.empty(); };
static std::shared_ptr<OptionBase> create_option(); | |||||
void config_model( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||||
    //! interface implemented from OptionBase
std::string option_name() const override { return m_option_name; }; | |||||
private: | |||||
InputOption(); | |||||
template <typename ModelImpl> | |||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||||
std::string m_option_name; | |||||
std::vector<std::string> data_path; // data string or data file path | |||||
}; | |||||
class IOdumpOption : public OptionBase { | |||||
public: | |||||
static bool is_valid(); | |||||
static std::shared_ptr<OptionBase> create_option(); | |||||
    //! configure the model, dispatching to the configuration code of the
    //! concrete model type
void config_model( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||||
std::string option_name() const override { return m_option_name; }; | |||||
private: | |||||
IOdumpOption(); | |||||
template <typename ModelImpl> | |||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||||
bool enable_io_dump; | |||||
bool enable_io_dump_stdout; | |||||
bool enable_io_dump_stderr; | |||||
bool enable_bin_io_dump; | |||||
bool enable_bin_out_dump; | |||||
bool enable_copy_to_host; | |||||
std::string m_option_name; | |||||
std::string dump_path; | |||||
std::unique_ptr<mgb::OprIODumpBase> io_dumper; | |||||
std::unique_ptr<OutputDumper> out_dumper; | |||||
}; | |||||
} // namespace lar |
@@ -0,0 +1,171 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/layout_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include <gflags/gflags.h> | |||||
#include "misc.h" | |||||
#include "models/model_lite.h" | |||||
#include "models/model_mdl.h" | |||||
#include "layout_options.h" | |||||
namespace lar { | |||||
template <> | |||||
void LayoutOption::config_model_internel<ModelLite>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
#define ENABLE_LAYOUT(layout) \ | |||||
LITE_WARN("enable " #layout " optimization"); \ | |||||
model->get_config().options.enable_##layout = true; \ | |||||
break; | |||||
switch (option_flag) { | |||||
case OptLayoutType::NCHW4: | |||||
ENABLE_LAYOUT(nchw4) | |||||
case OptLayoutType::CHWN4: | |||||
LITE_THROW("lite model unsupport chwn4 layout"); | |||||
break; | |||||
case OptLayoutType::NCHW44: | |||||
ENABLE_LAYOUT(nchw44) | |||||
case OptLayoutType::NCHW88: | |||||
ENABLE_LAYOUT(nchw88) | |||||
case OptLayoutType::NCHW32: | |||||
ENABLE_LAYOUT(nchw32) | |||||
case OptLayoutType::NCHW64: | |||||
ENABLE_LAYOUT(nchw64) | |||||
case OptLayoutType::NHWCD4: | |||||
ENABLE_LAYOUT(nhwcd4) | |||||
case OptLayoutType::NCHW44_DOT: | |||||
ENABLE_LAYOUT(nchw44_dot) | |||||
default: | |||||
break; | |||||
} | |||||
#undef ENABLE_LAYOUT | |||||
} | |||||
} | |||||
template <> | |||||
void lar::LayoutOption::config_model_internel<ModelMdl>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
mgb_log_debug("mdl layout config start"); | |||||
#define ENABLE_LAYOUT(layout) \ | |||||
mgb_log_warn("enable " #layout " optimization"); \ | |||||
model->get_mdl_config().comp_graph->options().graph_opt.enable_##layout(); \ | |||||
break; | |||||
switch (option_flag) { | |||||
case OptLayoutType::NCHW4: | |||||
ENABLE_LAYOUT(nchw4) | |||||
case OptLayoutType::CHWN4: | |||||
ENABLE_LAYOUT(chwn4) | |||||
case OptLayoutType::NCHW44: | |||||
ENABLE_LAYOUT(nchw44) | |||||
case OptLayoutType::NCHW88: | |||||
ENABLE_LAYOUT(nchw88) | |||||
case OptLayoutType::NCHW32: | |||||
ENABLE_LAYOUT(nchw32) | |||||
case OptLayoutType::NCHW64: | |||||
ENABLE_LAYOUT(nchw64) | |||||
case OptLayoutType::NHWCD4: | |||||
ENABLE_LAYOUT(nhwcd4) | |||||
case OptLayoutType::NCHW44_DOT: | |||||
ENABLE_LAYOUT(nchw44_dot) | |||||
default: | |||||
break; | |||||
} | |||||
mgb_log_debug("mdl layout config end"); | |||||
#undef ENABLE_LAYOUT | |||||
} | |||||
} | |||||
} // namespace lar | |||||
using namespace lar; | |||||
OptLayoutType LayoutOption::option_flag; | |||||
LayoutOption::LayoutOption() { | |||||
m_option_name = "layout"; | |||||
} | |||||
bool LayoutOption::is_valid() { | |||||
size_t valid_flag = 0; | |||||
if (FLAGS_enable_nchw4) { | |||||
valid_flag = valid_flag | (1 << 0); | |||||
} | |||||
if (FLAGS_enable_chwn4) { | |||||
valid_flag = valid_flag | (1 << 1); | |||||
} | |||||
if (FLAGS_enable_nchw44) { | |||||
valid_flag = valid_flag | (1 << 2); | |||||
} | |||||
if (FLAGS_enable_nchw88) { | |||||
valid_flag = valid_flag | (1 << 3); | |||||
} | |||||
if (FLAGS_enable_nchw32) { | |||||
valid_flag = valid_flag | (1 << 4); | |||||
} | |||||
if (FLAGS_enable_nchw64) { | |||||
valid_flag = valid_flag | (1 << 5); | |||||
} | |||||
if (FLAGS_enable_nhwcd4) { | |||||
valid_flag = valid_flag | (1 << 6); | |||||
} | |||||
if (FLAGS_enable_nchw44_dot) { | |||||
valid_flag = valid_flag | (1 << 7); | |||||
} | |||||
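    //! valid only when exactly one layout flag is set, i.e. valid_flag is a
    //! nonzero power of two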
bool ret = valid_flag && !(valid_flag & (valid_flag - 1)); | |||||
if (ret) { | |||||
option_flag = static_cast<OptLayoutType>(valid_flag); | |||||
} else { | |||||
option_flag = static_cast<OptLayoutType>(0); | |||||
} | |||||
return ret; | |||||
}; | |||||
std::shared_ptr<OptionBase> LayoutOption::create_option() { | |||||
static std::shared_ptr<LayoutOption> option(new LayoutOption); | |||||
if (LayoutOption::is_valid()) { | |||||
return std::static_pointer_cast<OptionBase>(option); | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
} | |||||
void LayoutOption::config_model( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||||
CONFIG_MODEL_FUN; | |||||
} | |||||
DEFINE_bool(enable_nchw4, false, "enable nchw4 layout optimization");
DEFINE_bool(enable_chwn4, false, "enable chwn4 layout optimization");
DEFINE_bool(enable_nchw44, false, "enable nchw44 layout optimization");
DEFINE_bool(enable_nchw88, false, "enable nchw88 layout optimization");
DEFINE_bool(enable_nchw32, false, "enable nchw32 layout optimization");
DEFINE_bool(enable_nchw64, false, "enable nchw64 layout optimization");
DEFINE_bool(enable_nhwcd4, false, "enable nhwcd4 layout optimization");
DEFINE_bool(enable_nchw44_dot, false, "enable nchw44-dot layout optimization");
REGIST_OPTION_CREATOR(layout, lar::LayoutOption::create_option); |
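//! Usage sketch (model path is hypothetical; exactly one layout flag may be
//! set at a time):
//!   load_and_run model.mge --enable_nchw44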
@@ -0,0 +1,56 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/layout_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | |||||
#include <gflags/gflags.h> | |||||
#include "helpers/common.h" | |||||
#include "models/model.h" | |||||
#include "option_base.h" | |||||
DECLARE_bool(enable_nchw4); | |||||
DECLARE_bool(enable_chwn4); | |||||
DECLARE_bool(enable_nchw44); | |||||
DECLARE_bool(enable_nchw88); | |||||
DECLARE_bool(enable_nchw32); | |||||
DECLARE_bool(enable_nchw64); | |||||
DECLARE_bool(enable_nhwcd4); | |||||
DECLARE_bool(enable_nchw44_dot); | |||||
namespace lar { | |||||
/*! | |||||
* \brief: layout option for optimization | |||||
*/ | |||||
class LayoutOption final : public OptionBase { | |||||
public: | |||||
    //! check the validity of the option flags
    static bool is_valid();
    //! create the option when it is used
    static std::shared_ptr<OptionBase> create_option();
    //! configure the model, dispatching to the concrete model implementation
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
//! get option name | |||||
std::string option_name() const override { return m_option_name; }; | |||||
private: | |||||
//! Constructor | |||||
LayoutOption(); | |||||
    //! configuration for different model implementations
template <typename ModelImpl> | |||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||||
static OptLayoutType option_flag; | |||||
std::string m_option_name; | |||||
}; | |||||
} // namespace lar |
@@ -0,0 +1,600 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/optimize_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "megbrain/gopt/inference.h" | |||||
#if MGB_ENABLE_TENSOR_RT | |||||
#include "megbrain/tensorrt/tensorrt_engine_cache.h" | |||||
#endif | |||||
#include "lite/global.h" | |||||
#include "misc.h" | |||||
#include "models/model_lite.h" | |||||
#include "models/model_mdl.h" | |||||
#include "optimize_options.h" | |||||
///////////////////////// fuse and preprocess optimize options /////////////// | |||||
namespace lar { | |||||
template <> | |||||
void FusePreprocessOption::config_model_internel<ModelLite>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
if (enable_fuse_preprocess) { | |||||
LITE_WARN("enable fuse-preprocess optimization"); | |||||
model->get_config().options.fuse_preprocess = true; | |||||
} | |||||
} | |||||
} | |||||
template <> | |||||
void FusePreprocessOption::config_model_internel<ModelMdl>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||||
if (enable_fuse_preprocess) { | |||||
mgb_log_warn("enable fuse-preprocess optimization"); | |||||
graph_option.graph_opt.enable_fuse_preprocess(); | |||||
} | |||||
} | |||||
} | |||||
} // namespace lar | |||||
using namespace lar; | |||||
FusePreprocessOption::FusePreprocessOption() { | |||||
m_option_name = "fuse_preprocess"; | |||||
enable_fuse_preprocess = FLAGS_enable_fuse_preprocess; | |||||
} | |||||
bool FusePreprocessOption::is_valid() { | |||||
bool ret = FLAGS_enable_fuse_preprocess; | |||||
return ret; | |||||
} | |||||
std::shared_ptr<OptionBase> FusePreprocessOption::create_option() { | |||||
static std::shared_ptr<FusePreprocessOption> option(new FusePreprocessOption); | |||||
if (FusePreprocessOption::is_valid()) { | |||||
return std::static_pointer_cast<OptionBase>(option); | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
} | |||||
void FusePreprocessOption::config_model( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||||
CONFIG_MODEL_FUN; | |||||
} | |||||
///////////////////////// weight preprocess optimize options /////////////// | |||||
namespace lar { | |||||
template <> | |||||
void WeightPreprocessOption::config_model_internel<ModelLite>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
if (weight_preprocess) { | |||||
LITE_WARN("enable weight-preprocess optimization"); | |||||
model->get_config().options.weight_preprocess = true; | |||||
            //! FIXME: the algo searcher should enable weight preprocess for
            //! opencl (the implementation below has some problems):
// #if MGB_OPENCL | |||||
// megdnn::opencl::algo_searcher::AlgoSearcherBase:: | |||||
// enable_weight_preprocess(); | |||||
// #endif | |||||
} | |||||
} | |||||
} | |||||
template <> | |||||
void WeightPreprocessOption::config_model_internel<ModelMdl>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||||
auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||||
if (weight_preprocess) { | |||||
mgb_log_warn("enable weight-preprocess optimization"); | |||||
graph_option.graph_opt.enable_weight_preprocess(); | |||||
            //! FIXME: this implementation is not right
// #if MGB_OPENCL | |||||
// megdnn::opencl::algo_searcher::AlgoSearcherBase:: | |||||
// enable_weight_preprocess(); | |||||
// #endif | |||||
} | |||||
} | |||||
} | |||||
} // namespace lar | |||||
WeightPreprocessOption::WeightPreprocessOption() { | |||||
m_option_name = "weight_preprocess"; | |||||
weight_preprocess = FLAGS_weight_preprocess; | |||||
} | |||||
bool WeightPreprocessOption::is_valid() { | |||||
bool ret = FLAGS_weight_preprocess; | |||||
return ret; | |||||
} | |||||
std::shared_ptr<OptionBase> WeightPreprocessOption::create_option() { | |||||
static std::shared_ptr<WeightPreprocessOption> option(new WeightPreprocessOption); | |||||
if (WeightPreprocessOption::is_valid()) { | |||||
return std::static_pointer_cast<OptionBase>(option); | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
} | |||||
void WeightPreprocessOption::config_model( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||||
CONFIG_MODEL_FUN; | |||||
} | |||||
///// fuse conv bias and nonlinear activation opr optimize options ////////
namespace lar {
template <>
void FuseConvBiasNonlinearOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    LITE_MARK_USED_VAR(model);
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        if (enable_fuse_conv_bias_nonlinearity) {
            LITE_THROW("fuse conv+bias+nonlinearity not supported in lite model");
        }
    }
}

template <>
void FuseConvBiasNonlinearOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        auto&& graph_option = model->get_mdl_config().comp_graph->options();
        if (enable_fuse_conv_bias_nonlinearity) {
            mgb_log_warn("enable fuse conv+bias+nonlinearity optimization");
            graph_option.graph_opt.enable_fuse_conv_bias_nonlinearity();
        }
    }
}
}  // namespace lar

FuseConvBiasNonlinearOption::FuseConvBiasNonlinearOption() {
    m_option_name = "fuse_conv_bias_nonlinear";
    enable_fuse_conv_bias_nonlinearity = FLAGS_enable_fuse_conv_bias_nonlinearity;
}

bool FuseConvBiasNonlinearOption::is_valid() {
    bool ret = FLAGS_enable_fuse_conv_bias_nonlinearity;
    return ret;
}

std::shared_ptr<OptionBase> FuseConvBiasNonlinearOption::create_option() {
    static std::shared_ptr<FuseConvBiasNonlinearOption> option(
            new FuseConvBiasNonlinearOption);
    if (FuseConvBiasNonlinearOption::is_valid()) {
        return std::static_pointer_cast<OptionBase>(option);
    } else {
        return nullptr;
    }
}

void FuseConvBiasNonlinearOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    CONFIG_MODEL_FUN;
}

///////////////// fuse conv bias with z (elemwise add) optimize options //////
namespace lar {
template <>
void FuseConvBiasElemwiseAddOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    LITE_MARK_USED_VAR(model);
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        if (enable_fuse_conv_bias_with_z) {
            LITE_THROW(
                    "fuse conv+bias+z optimization not supported in lite "
                    "model");
        }
    }
}

template <>
void FuseConvBiasElemwiseAddOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        auto&& graph_option = model->get_mdl_config().comp_graph->options();
        if (enable_fuse_conv_bias_with_z) {
            mgb_log_warn("enable fuse conv+bias+z optimization");
            graph_option.graph_opt.enable_fuse_conv_bias_with_z();
        }
    }
}
}  // namespace lar

FuseConvBiasElemwiseAddOption::FuseConvBiasElemwiseAddOption() {
    m_option_name = "fuse_conv_bias_z";
    enable_fuse_conv_bias_with_z = FLAGS_enable_fuse_conv_bias_with_z;
}

bool FuseConvBiasElemwiseAddOption::is_valid() {
    bool ret = FLAGS_enable_fuse_conv_bias_with_z;
    return ret;
}

std::shared_ptr<OptionBase> FuseConvBiasElemwiseAddOption::create_option() {
    static std::shared_ptr<FuseConvBiasElemwiseAddOption> option(
            new FuseConvBiasElemwiseAddOption);
    if (FuseConvBiasElemwiseAddOption::is_valid()) {
        return std::static_pointer_cast<OptionBase>(option);
    } else {
        return nullptr;
    }
}

void FuseConvBiasElemwiseAddOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    CONFIG_MODEL_FUN;
}
///////////////////////// graph record options /////////////////////////
namespace lar {
template <>
void GraphRecordOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        auto&& config_option = model->get_config().options;
        if (const_shape) {
            LITE_WARN("enable const var shape");
            config_option.const_shape = true;
        }
        if (fake_first) {
            LITE_WARN("enable fake-first optimization");
            config_option.fake_next_exec = true;
        }
        if (no_sanity_check) {
            LITE_WARN("disable var sanity check optimization");
            config_option.var_sanity_check_first_run = false;
        }
        if (m_record_comp_seq == 1) {
            LITE_WARN("set record_comp_seq_level to 1");
        }
        if (m_record_comp_seq == 2) {
            mgb_assert(
                    no_sanity_check,
                    "--no-sanity-check must be set when using "
                    "--record-comp-seq2");
            LITE_WARN("set record_comp_seq_level to 2");
        }
        config_option.comp_node_seq_record_level = m_record_comp_seq;
    }
}

template <>
void GraphRecordOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        auto&& graph_option = model->get_mdl_config().comp_graph->options();
        if (const_shape) {
            mgb_log_warn("enable const var shape");
            model->get_mdl_config().const_var_shape = true;
        }
        if (fake_first) {
            mgb_log_warn("enable fake-first optimization");
            graph_option.fake_next_exec = true;
        }
        if (no_sanity_check) {
            mgb_log_warn("disable var sanity check optimization");
            graph_option.var_sanity_check_first_run = false;
        }
        if (m_record_comp_seq == 1) {
            mgb_log_warn("set record_comp_seq_level to 1");
        }
        if (m_record_comp_seq == 2) {
            mgb_assert(
                    no_sanity_check && !fake_first,
                    "--no-sanity-check must be set and --fake-first must not "
                    "be set when using --record-comp-seq2");
            mgb_log_warn("set record_comp_seq_level to 2");
        }
        graph_option.comp_node_seq_record_level = m_record_comp_seq;
    }
}
}  // namespace lar

GraphRecordOption::GraphRecordOption() {
    m_option_name = "graph_record";
    m_record_comp_seq = 0;
    const_shape = FLAGS_const_shape;
    fake_first = FLAGS_fake_first;
    no_sanity_check = FLAGS_no_sanity_check;
    if (FLAGS_record_comp_seq) {
        m_record_comp_seq = 1;
    }
    if (FLAGS_record_comp_seq2) {
        m_record_comp_seq = 2;
    }
}

bool GraphRecordOption::is_valid() {
    bool ret = FLAGS_const_shape;
    ret = ret || FLAGS_fake_first;
    ret = ret || FLAGS_no_sanity_check;
    ret = ret || FLAGS_record_comp_seq;
    ret = ret || FLAGS_record_comp_seq2;
    return ret;
}

std::shared_ptr<OptionBase> GraphRecordOption::create_option() {
    static std::shared_ptr<GraphRecordOption> option(new GraphRecordOption);
    if (GraphRecordOption::is_valid()) {
        return std::static_pointer_cast<OptionBase>(option);
    } else {
        return nullptr;
    }
}

void GraphRecordOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    CONFIG_MODEL_FUN;
}
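
//! Example invocations (assuming load_and_run's usual CLI; the model path is
//! a placeholder):
//!     load_and_run model.mge --record_comp_seq
//!     load_and_run model.mge --no_sanity_check --record_comp_seq2
//! Level-2 recording requires the var sanity check to be disabled, and on the
//! MDL path it is additionally incompatible with --fake_first, as asserted
//! above.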
///////////////////////// memory optimize options /////////////////////////
namespace lar {
template <>
void MemoryOptimizeOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    LITE_MARK_USED_VAR(model);
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        if (disable_mem_opt) {
            LITE_THROW("lite model doesn't support disabling memory optimization");
        }
    } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
        if (workspace_limit != SIZE_MAX) {
            LITE_WARN("set workspace limit to %zu",
                      static_cast<size_t>(workspace_limit));
            lite::Runtime::set_network_algo_workspace_limit(
                    model->get_lite_network(), workspace_limit);
        }
    }
}

template <>
void MemoryOptimizeOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        auto&& graph_option = model->get_mdl_config().comp_graph->options();
        if (disable_mem_opt) {
            mgb_log_warn("disable memory optimization");
            graph_option.seq_opt.enable_mem_plan_opt = false;
            graph_option.seq_opt.enable_mem_reuse_alloc = false;
        }
        if (workspace_limit < SIZE_MAX) {
            mgb_log_warn("set workspace limit to %zu",
                         static_cast<size_t>(workspace_limit));
            auto output_spec = model->get_output_spec();
            mgb::SymbolVarArray vars;
            for (auto i : output_spec) {
                vars.push_back(i.first);
            }
            mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, workspace_limit);
        }
    }
}
}  // namespace lar

MemoryOptimizeOption::MemoryOptimizeOption() {
    m_option_name = "memory_optimize";
    disable_mem_opt = FLAGS_disable_mem_opt;
    workspace_limit = FLAGS_workspace_limit;
}

bool MemoryOptimizeOption::is_valid() {
    bool ret = FLAGS_disable_mem_opt;
    ret = ret || FLAGS_workspace_limit < SIZE_MAX;
    return ret;
}

std::shared_ptr<OptionBase> MemoryOptimizeOption::create_option() {
    static std::shared_ptr<MemoryOptimizeOption> option(new MemoryOptimizeOption);
    if (MemoryOptimizeOption::is_valid()) {
        return std::static_pointer_cast<OptionBase>(option);
    } else {
        return nullptr;
    }
}

void MemoryOptimizeOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    CONFIG_MODEL_FUN;
}
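
//! Example: cap the per-operator algo workspace at 64 MiB (the value is in
//! bytes; the model path is a placeholder):
//!     load_and_run model.mge --workspace_limit 67108864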
///////////////////////// other options for optimization /////////////////
namespace lar {
template <>
void JITOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        auto&& config_option = model->get_config().options;
        if (enable_jit) {
            LITE_WARN("enable JIT (level 1)");
            config_option.jit_level = 1;
        }
    }
}

template <>
void JITOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        auto&& graph_option = model->get_mdl_config().comp_graph->options();
        if (enable_jit) {
            mgb_log_warn("enable JIT (level 1)");
            graph_option.graph_opt.jit = 1;
        }
    }
}
}  // namespace lar

JITOption::JITOption() {
    m_option_name = "JIT";
    enable_jit = FLAGS_enable_jit;
}

bool JITOption::is_valid() {
    bool ret = FLAGS_enable_jit;
    return ret;
}

std::shared_ptr<OptionBase> JITOption::create_option() {
    static std::shared_ptr<JITOption> option(new JITOption);
    if (JITOption::is_valid()) {
        return std::static_pointer_cast<OptionBase>(option);
    } else {
        return nullptr;
    }
}

void JITOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    CONFIG_MODEL_FUN;
}
///////////////////////// TensorRT options for optimization /////////////////
#if MGB_ENABLE_TENSOR_RT
namespace lar {
template <>
void TensorRTOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        if (!tensorrt_cache.empty()) {
            LITE_WARN("set tensorrt cache as %s", tensorrt_cache.c_str());
            lite::set_tensor_rt_cache(tensorrt_cache);
        }
    } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
        if (enable_tensorrt) {
            LITE_WARN("enable TensorRT");
            lite::Runtime::use_tensorrt(model->get_lite_network());
        }
    } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
        if (!tensorrt_cache.empty()) {
            lite::dump_tensor_rt_cache();
        }
    }
}

template <>
void TensorRTOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        auto&& graph_option = model->get_mdl_config().comp_graph->options();
        if (enable_tensorrt) {
            mgb_log_warn("using TensorRT");
            graph_option.graph_opt.tensorrt = true;
        }
        if (!tensorrt_cache.empty()) {
            mgb_log_warn("use tensorrt cache: %s", tensorrt_cache.c_str());
            mgb::TensorRTEngineCache::enable_engine_cache(true);
            mgb::TensorRTEngineCache::set_impl(
                    std::make_shared<mgb::TensorRTEngineCacheIO>(
                            tensorrt_cache.c_str()));
        }
    } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
        if (!tensorrt_cache.empty()) {
            if (mgb::TensorRTEngineCache::enable_engine_cache()) {
                mgb::TensorRTEngineCache::inst().dump_cache();
            }
        }
    }
}
}  // namespace lar

TensorRTOption::TensorRTOption() {
    m_option_name = "tensorRT";
    enable_tensorrt = FLAGS_tensorrt;
    tensorrt_cache = FLAGS_tensorrt_cache;
}

bool TensorRTOption::is_valid() {
    bool ret = FLAGS_tensorrt;
    ret = ret || !FLAGS_tensorrt_cache.empty();
    return ret;
}

std::shared_ptr<OptionBase> TensorRTOption::create_option() {
    static std::shared_ptr<TensorRTOption> option(new TensorRTOption);
    if (TensorRTOption::is_valid()) {
        return std::static_pointer_cast<OptionBase>(option);
    } else {
        return nullptr;
    }
}

void TensorRTOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    CONFIG_MODEL_FUN;
}
#endif
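
//! Example: run through TensorRT and reuse a serialized engine cache across
//! runs (paths are placeholders; requires a build with MGB_ENABLE_TENSOR_RT):
//!     load_and_run model.mge --tensorrt --tensorrt_cache trt_engine.cache
//! The cache file is written back in the AFTER_MODEL_RUNNING stage above, so
//! the first invocation creates it and later ones reuse it.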
///////////////////////// fuse and preprocess optimize options ///////////////
DEFINE_bool(
        enable_fuse_preprocess, false,
        "Fuse astype, pad_channel, dimshuffle and similar oprs starting from "
        "the h2d opr");
DEFINE_bool(
        weight_preprocess, false,
        "Execute operators with weight preprocessing, which can reduce "
        "operator execution time with algorithms such as winograd and "
        "im2col, but may consume more memory.");
DEFINE_bool(
        enable_fuse_conv_bias_nonlinearity, false,
        "whether to fuse conv+bias+nonlinearity");
DEFINE_bool(
        enable_fuse_conv_bias_with_z, false,
        "fuse conv, bias (elemwise add) and z (elemwise add) into one opr "
        "(only supported on GPU)");
///////////////////////// graph record options /////////////////////////
DEFINE_bool(
        const_shape, false,
        "set const_var_shape to reduce memory usage, since some static "
        "inference data structures can be omitted");
DEFINE_bool(
        fake_first, false,
        "Enable fake exec for the first run. In fake exec mode, some "
        "initialization work is done, but no actual computation is "
        "performed.");
DEFINE_bool(no_sanity_check, false, "Disable var sanity check on the first run");
DEFINE_bool(
        record_comp_seq, false,
        "Record the computing sequence at level 1. This reduces the overhead "
        "of API calls on some asynchronous computing devices");
DEFINE_bool(
        record_comp_seq2, false,
        "Record the computing sequence at level 2; the computing graph can "
        "then be destructed to reduce memory usage");
DEFINE_bool(disable_mem_opt, false, "disable memory optimization");
DEFINE_uint64(
        workspace_limit, SIZE_MAX,
        "set the upper bound of the workspace size in bytes");
///////////////////////// other options for optimization /////////////////
DEFINE_bool(
        enable_jit, false,
        "Execute supported operators with JIT (currently only NVRTC is "
        "supported). Can only be used on NVIDIA GPUs");
#if MGB_ENABLE_ANDROID_NN
DEFINE_bool(
        android_nn, false,
        "Execute supported operators with Android NN. Can only be used "
        "with --cpu.");
#endif
#if MGB_ENABLE_TENSOR_RT
DEFINE_bool(
        tensorrt, false,
        "Execute supported operators with TensorRT. Can only be used on "
        "NVIDIA GPUs, i.e. the comp node is xpu or gpu.");
DEFINE_string(
        tensorrt_cache, "",
        "Set the TensorRT engine cache path for the serialized prebuilt "
        "ICudaEngine");
#endif

REGIST_OPTION_CREATOR(fuse_preprocess, lar::FusePreprocessOption::create_option);
REGIST_OPTION_CREATOR(weight_preprocess, lar::WeightPreprocessOption::create_option);
REGIST_OPTION_CREATOR(
        fuse_conv_bias_nonlinear, lar::FuseConvBiasNonlinearOption::create_option);
REGIST_OPTION_CREATOR(
        fuse_conv_bias_z, lar::FuseConvBiasElemwiseAddOption::create_option);
REGIST_OPTION_CREATOR(graph_record, lar::GraphRecordOption::create_option);
REGIST_OPTION_CREATOR(memory_optimize, lar::MemoryOptimizeOption::create_option);
REGIST_OPTION_CREATOR(JIT, lar::JITOption::create_option);
#if MGB_ENABLE_TENSOR_RT
REGIST_OPTION_CREATOR(tensorRT, lar::TensorRTOption::create_option);
#endif
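
//! Note: each create_option() above returns nullptr while its flags are left
//! at their defaults, so the strategy only keeps the options that were
//! actually requested on the command line; e.g. (model path is a placeholder)
//!     load_and_run model.mge --enable_fuse_preprocess --enable_jit
//! constructs only the fuse_preprocess and JIT option objects.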
@@ -0,0 +1,207 @@
/**
 * \file lite/load_and_run/src/options/optimize_options.h
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */
#pragma once

#include <gflags/gflags.h>

#include "helpers/common.h"
#include "models/model.h"
#include "option_base.h"

DECLARE_bool(enable_fuse_preprocess);
DECLARE_bool(weight_preprocess);
DECLARE_bool(enable_fuse_conv_bias_nonlinearity);
DECLARE_bool(enable_fuse_conv_bias_with_z);
DECLARE_bool(const_shape);
DECLARE_bool(fake_first);
DECLARE_bool(no_sanity_check);
DECLARE_bool(record_comp_seq);
DECLARE_bool(record_comp_seq2);
DECLARE_bool(disable_mem_opt);
DECLARE_uint64(workspace_limit);
DECLARE_bool(enable_jit);
#if MGB_ENABLE_TENSOR_RT
DECLARE_bool(tensorrt);
DECLARE_string(tensorrt_cache);
#endif
namespace lar {
///////////////////////// fuse_preprocess optimize options //////////////
class FusePreprocessOption final : public OptionBase {
public:
    static bool is_valid();
    static std::shared_ptr<OptionBase> create_option();
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
    std::string option_name() const override { return m_option_name; }

private:
    FusePreprocessOption();
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
    std::string m_option_name;
    bool enable_fuse_preprocess;
};

///////////////////////// weight preprocess optimize options //////////////
class WeightPreprocessOption final : public OptionBase {
public:
    static bool is_valid();
    static std::shared_ptr<OptionBase> create_option();
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
    std::string option_name() const override { return m_option_name; }

private:
    WeightPreprocessOption();
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
    std::string m_option_name;
    bool weight_preprocess;
};

/////////////// fuse_conv_bias_nonlinearity optimize options ///////////////
class FuseConvBiasNonlinearOption final : public OptionBase {
public:
    static bool is_valid();
    static std::shared_ptr<OptionBase> create_option();
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
    std::string option_name() const override { return m_option_name; }

private:
    FuseConvBiasNonlinearOption();
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
    std::string m_option_name;
    bool enable_fuse_conv_bias_nonlinearity;
};

///////////////////////// fuse_conv_bias_with_z optimize options //////////////
class FuseConvBiasElemwiseAddOption final : public OptionBase {
public:
    static bool is_valid();
    static std::shared_ptr<OptionBase> create_option();
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
    std::string option_name() const override { return m_option_name; }

private:
    FuseConvBiasElemwiseAddOption();
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
    std::string m_option_name;
    bool enable_fuse_conv_bias_with_z;
};

///////////////////////// graph record options ///////////////////////////
class GraphRecordOption final : public OptionBase {
public:
    static bool is_valid();
    static std::shared_ptr<OptionBase> create_option();
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
    std::string option_name() const override { return m_option_name; }

private:
    GraphRecordOption();
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
    std::string m_option_name;
    size_t m_record_comp_seq;
    bool const_shape;
    bool fake_first;
    bool no_sanity_check;
};

///////////////////////// memory optimize options /////////////////////////
class MemoryOptimizeOption final : public OptionBase {
public:
    static bool is_valid();
    static std::shared_ptr<OptionBase> create_option();
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
    std::string option_name() const override { return m_option_name; }

private:
    MemoryOptimizeOption();
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
    std::string m_option_name;
    bool disable_mem_opt;
    uint64_t workspace_limit;
};

///////////////////////// other options for optimization /////////////////
class JITOption final : public OptionBase {
public:
    static bool is_valid();
    static std::shared_ptr<OptionBase> create_option();
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
    std::string option_name() const override { return m_option_name; }

private:
    JITOption();
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
    std::string m_option_name;
    bool enable_jit;
};

///////////////////////// TensorRT options for optimization /////////////////
#if MGB_ENABLE_TENSOR_RT
class TensorRTOption final : public OptionBase {
public:
    static bool is_valid();
    static std::shared_ptr<OptionBase> create_option();
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
    std::string option_name() const override { return m_option_name; }

private:
    TensorRTOption();
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
    std::string m_option_name;
    bool enable_tensorrt;
    std::string tensorrt_cache;
};
#endif
}  // namespace lar
@@ -0,0 +1,87 @@
/**
 * \file lite/load_and_run/src/options/option_base.h
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */
#pragma once

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "megbrain/common.h"

#include "helpers/common.h"
#include "models/model.h"

namespace lar {
/*!
 * \brief: base class of options
 */
class OptionBase {
public:
    //! configure the model in different runtime stages
    virtual void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) = 0;
    //! get the options this option depends on
    virtual std::vector<std::string> depend_option() const { return {}; }
    //! get option name
    virtual std::string option_name() const = 0;
    virtual ~OptionBase() = default;
};

/*!
 * \brief: singleton option factory used to register options before main() runs
 */
class OptionFactory {
public:
    using OptionCreator = std::function<std::shared_ptr<OptionBase>()>;
    using OptionMap = std::unordered_map<std::string, OptionCreator>;

    //! get the singleton option factory
    static OptionFactory& get_Instance() {
        static OptionFactory instance;
        return instance;
    }

    //! register an option creator into the option map
    void registe_options(std::string name, OptionCreator creator) {
        if (option_creator_map.count(name) == 0) {
            option_creator_map[name] = creator;
        }
    }

    //! get the creator map
    OptionMap* get_option_creator_map() { return &option_creator_map; }

private:
    OptionFactory() {}
    OptionMap option_creator_map;
};
}  // namespace lar

#define REGIST_OPTION_CREATOR(name_, creator_)                                     \
    struct OptionRegister_##name_ {                                                \
        OptionRegister_##name_() {                                                 \
            lar::OptionFactory::get_Instance().registe_options(#name_, creator_);  \
        }                                                                          \
    };                                                                             \
    OptionRegister_##name_ name_;

#define CONFIG_MODEL_FUN                                                       \
    if (model->type() == ModelType::LITE_MODEL) {                              \
        config_model_internel<ModelLite>(                                      \
                runtime_param, std::static_pointer_cast<ModelLite>(model));    \
    } else if (model->type() == ModelType::MEGDL_MODEL) {                      \
        config_model_internel<ModelMdl>(                                       \
                runtime_param, std::static_pointer_cast<ModelMdl>(model));     \
    }

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
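
//! Illustrative sketch (not part of the build): how a new option would plug
//! into the two macros above. The class name "MyOption", the flag
//! "FLAGS_my_flag" and the option name "my_option" are all hypothetical.
//!
//!     DEFINE_bool(my_flag, false, "enable my optimization");
//!
//!     namespace lar {
//!     class MyOption final : public OptionBase {
//!     public:
//!         static bool is_valid() { return FLAGS_my_flag; }
//!         static std::shared_ptr<OptionBase> create_option() {
//!             static std::shared_ptr<MyOption> option(new MyOption);
//!             return MyOption::is_valid()
//!                          ? std::static_pointer_cast<OptionBase>(option)
//!                          : nullptr;
//!         }
//!         void config_model(
//!                 RuntimeParam& runtime_param,
//!                 std::shared_ptr<ModelBase> model) override {
//!             CONFIG_MODEL_FUN;
//!         }
//!         std::string option_name() const override { return "my_option"; }
//!
//!     private:
//!         MyOption() = default;
//!         template <typename ModelImpl>
//!         void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
//!     };
//!     }  // namespace lar
//!
//!     REGIST_OPTION_CREATOR(my_option, lar::MyOption::create_option);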
@@ -0,0 +1,401 @@
/**
 * \file lite/load_and_run/src/options/plugin_options.cpp
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */
#include <sstream>

#include "plugin_options.h"
#include "misc.h"
#include "models/model_lite.h"
#include "models/model_mdl.h"

///////////////////// Plugin options ///////////////////////////
namespace lar {
template <>
void PluginOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        LITE_ASSERT(
                range == 0, "lite model doesn't support the NumRangeChecker plugin");
        LITE_ASSERT(
                !enable_check_dispatch,
                "lite model doesn't support the CPUDispatchChecker plugin");
        LITE_ASSERT(
                var_value_check_str.empty(),
                "lite model doesn't support the VarValueChecker plugin");
    }
#if MGB_ENABLE_JSON
    else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
        if (!profile_path.empty()) {
            if (!enable_profile_host) {
                LITE_WARN("enable profiling");
                model->get_lite_network()->enable_profile_performance(profile_path);
            } else {
                LITE_WARN("enable profiling for host");
                model->get_lite_network()->enable_profile_performance(profile_path);
            }
        }
    }
#endif
}

template <>
void PluginOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        auto config = model->get_mdl_config();
        if (range > 0) {
            mgb_log_warn("enable number range check");
            model->set_num_range_checker(float(range));
        }
        if (enable_check_dispatch) {
            mgb_log_warn("enable cpu dispatch check");
            cpu_dispatch_checker =
                    std::make_unique<mgb::CPUDispatchChecker>(config.comp_graph.get());
        }
        if (!var_value_check_str.empty()) {
            mgb_log_warn("enable variable value check");
            size_t init_idx = 0, switch_interval;
            auto sep = var_value_check_str.find(':');
            if (sep != std::string::npos) {
                switch_interval = std::stoul(var_value_check_str.substr(0, sep));
                init_idx = std::stoul(var_value_check_str.substr(sep + 1));
            } else {
                switch_interval = std::stoul(var_value_check_str);
            }
            var_value_checker = std::make_unique<mgb::VarValueChecker>(
                    config.comp_graph.get(), switch_interval, init_idx);
        }
#if MGB_ENABLE_JSON
        if (!profile_path.empty()) {
            if (!enable_profile_host) {
                mgb_log_warn("enable profiling");
            } else {
                mgb_log_warn("enable profiling for host");
            }
            model->set_profiler();
        }
#endif
    } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
#if MGB_ENABLE_JSON
        if (!profile_path.empty()) {
            mgb_log_warn("profile filename: %s", profile_path.c_str());
            if (model->get_profiler()) {
                model->get_profiler()
                        ->to_json_full(model->get_async_func().get())
                        ->writeto_fpath(profile_path);
                mgb_log_warn("profiling result written to %s", profile_path.c_str());
            }
        }
#endif
    }
}
}  // namespace lar
using namespace lar;

PluginOption::PluginOption() {
    m_option_name = "plugin";
    range = FLAGS_range;
    enable_check_dispatch = FLAGS_check_dispatch;
    var_value_check_str = FLAGS_check_var_value;
#if MGB_ENABLE_JSON
    enable_profile_host = false;
    if (!FLAGS_profile.empty()) {
        profile_path = FLAGS_profile;
    }
    if (!FLAGS_profile_host.empty()) {
        enable_profile_host = true;
        profile_path = FLAGS_profile_host;
    }
#endif
}

bool PluginOption::is_valid() {
    bool ret = FLAGS_check_dispatch;
    ret = ret || FLAGS_range > 0;
    ret = ret || !FLAGS_check_var_value.empty();
#if MGB_ENABLE_JSON
    ret = ret || !FLAGS_profile.empty();
    ret = ret || !FLAGS_profile_host.empty();
#endif
    return ret;
}

std::shared_ptr<OptionBase> PluginOption::create_option() {
    static std::shared_ptr<PluginOption> option(new PluginOption);
    if (PluginOption::is_valid()) {
        return std::static_pointer_cast<OptionBase>(option);
    } else {
        return nullptr;
    }
}

void PluginOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    CONFIG_MODEL_FUN;
}
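
//! Example: check var values, switching the checked var every 10 executions
//! and starting from var index 0 (the "interval:init_idx" format parsed
//! above; the model path is a placeholder):
//!     load_and_run model.mge --check_var_value 10:0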
///////////////////// Debug options ///////////////////////////
namespace lar {
template <>
void DebugOption::format_and_print(
        const std::string& tablename, std::shared_ptr<ModelLite> model) {
    auto table = mgb::TextTable(tablename);
    auto network = model->get_lite_network();
    table.padding(1);
    table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor();

    auto to_string = [&](lite::Layout& layout) {
        std::string shape("{");
        for (size_t i = 0; i < layout.ndim; i++) {
            if (i)
                shape.append(",");
            shape.append(std::to_string(layout.shapes[i]));
        }
        shape.append("}");
        return shape;
    };

    auto input_name = network->get_all_input_name();
    for (auto& i : input_name) {
        auto layout = network->get_io_tensor(i)->get_layout();
        table.align(mgb::TextTable::Align::Mid)
                .add("INPUT")
                .add(i)
                .add(to_string(layout))
                .eor();
    }

    auto output_name = network->get_all_output_name();
    for (auto& i : output_name) {
        auto layout = network->get_io_tensor(i)->get_layout();
        table.align(mgb::TextTable::Align::Mid)
                .add("OUTPUT")
                .add(i)
                .add(to_string(layout))
                .eor();
    }

    std::stringstream ss;
    ss << table;
    printf("%s\n\n", ss.str().c_str());
}

template <>
void DebugOption::format_and_print(
        const std::string& tablename, std::shared_ptr<ModelMdl> model) {
    auto table = mgb::TextTable(tablename);
    table.padding(1);
    table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor();

    for (auto&& i : model->get_mdl_load_result().tensor_map) {
        table.align(mgb::TextTable::Align::Mid)
                .add("INPUT")
                .add(i.first)
                .add(i.second->shape().to_string())
                .eor();
    }

    for (auto&& i : model->get_mdl_load_result().output_var_list) {
        table.align(mgb::TextTable::Align::Mid)
                .add("OUTPUT")
                .add(i.node()->name())
                .add(i.shape().to_string())
                .eor();
    }

    std::stringstream ss;
    ss << table;
    printf("%s\n\n", ss.str().c_str());
}
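
//! Illustrative output of format_and_print (the shapes are made up and the
//! exact borders depend on mgb::TextTable's rendering):
//!     Runtime Model Info
//!      type   | name    | shape
//!      INPUT  | data    | {1,3,224,224}
//!      OUTPUT | softmax | {1,1000}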
template <>
void DebugOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        LITE_ASSERT(
                !disable_assert_throw,
                "lite model doesn't support disabling assert throw");
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
        LITE_ASSERT(
                static_mem_log_dir_path.empty(),
                "lite model doesn't support exporting static memory information");
#endif
#endif
        if (enable_verbose) {
            LITE_WARN("enable verbose");
            lite::set_log_level(LiteLogLevel::DEBUG);
        }

#if __linux__ || __unix__
        if (enable_wait_gdb) {
            printf("wait for gdb attach (pid=%d): ", getpid());
            getchar();
        }
#endif
    } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
        if (enable_display_model_info) {
            LITE_WARN("enable display model information");
            format_and_print<ModelLite>("Runtime Model Info", model);
        }
    } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
        if (enable_display_model_info) {
            format_and_print<ModelLite>("Runtime Model Info", model);
        }
    }
}

template <>
void DebugOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        auto config = model->get_mdl_config();
        if (enable_verbose) {
            mgb_log_warn("enable verbose");
            mgb::set_log_level(mgb::LogLevel::DEBUG);
        }

#if __linux__ || __unix__
        if (enable_wait_gdb) {
            printf("wait for gdb attach (pid=%d): ", getpid());
            getchar();
        }
#endif
    } else if (runtime_param.stage == RunStage::AFTER_OUTSPEC_SET) {
        if (enable_display_model_info) {
            mgb_log_warn("enable display model information");
            format_and_print<ModelMdl>("Runtime Model Info", model);
        }

        if (disable_assert_throw) {
            mgb_log_warn("disable assert throw");
            auto on_opr = [](mgb::cg::OperatorNodeBase* opr) {
                if (opr->same_type<mgb::opr::AssertEqual>()) {
                    opr->cast_final<mgb::opr::AssertEqual>().disable_throw_on_error();
                }
            };
            mgb::cg::DepOprIter iter{on_opr};
            for (auto&& i : model->get_output_spec()) {
                iter.add(i.first.node()->owner_opr());
            }
        }

        //! FIXME: this doesn't work for a CPU-only build (nothing is dumped);
        //! megbrain/sdk origin code will assert(m_recorded) in
        //! EventImplHelper::finished();
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
        if (!static_mem_log_dir_path.empty()) {
            mgb_log_warn("enable dumping static memory information");
            model->get_async_func()->get_static_memory_alloc_info(
                    static_mem_log_dir_path);
        }
#endif
#endif
    } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
        if (enable_display_model_info) {
            format_and_print<ModelMdl>("Runtime Model Info", model);
        }
    }
}
}  // namespace lar
DebugOption::DebugOption() {
    m_option_name = "debug";
    enable_display_model_info = FLAGS_model_info;
    enable_verbose = FLAGS_verbose;
    disable_assert_throw = FLAGS_disable_assert_throw;
#if __linux__ || __unix__
    enable_wait_gdb = FLAGS_wait_gdb;
#endif
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
    static_mem_log_dir_path = FLAGS_get_static_mem_info;
#endif
#endif
}

bool DebugOption::is_valid() {
    bool ret = FLAGS_model_info;
    ret = ret || FLAGS_verbose;
    ret = ret || FLAGS_disable_assert_throw;
#if __linux__ || __unix__
    ret = ret || FLAGS_wait_gdb;
#endif
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
    ret = ret || !FLAGS_get_static_mem_info.empty();
#endif
#endif
    return ret;
}

std::shared_ptr<OptionBase> DebugOption::create_option() {
    static std::shared_ptr<DebugOption> option(new DebugOption);
    if (DebugOption::is_valid()) {
        return std::static_pointer_cast<OptionBase>(option);
    } else {
        return nullptr;
    }
}

void DebugOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    CONFIG_MODEL_FUN;
}

///////////////////// Plugin gflags ///////////////////////////
DEFINE_double(
        range, 0,
        "check whether the absolute value of every number in the computing "
        "graph is within the given range");
DEFINE_bool(
        check_dispatch, false,
        "check whether operators dispatch kernels on CPU comp nodes");
DEFINE_string(
        check_var_value, "",
        "--check-var-value [interval]|[interval:init_idx], enable the "
        "VarValueChecker plugin. Refer to its doc for more details");
#if MGB_ENABLE_JSON
DEFINE_string(
        profile, "",
        "Write profiling results to the given file. The output file is in "
        "JSON format");
DEFINE_string(profile_host, "", "Focus on host-side time profiling for some backends");
#endif

///////////////////// Debug gflags ///////////////////////////
DEFINE_bool(
        model_info, false,
        "Format and display model input/output tensor information");
DEFINE_bool(verbose, false, "print more information for debugging");
DEFINE_bool(disable_assert_throw, false, "disable throwing on AssertEqual check failure");
#if __linux__ || __unix__
DEFINE_bool(wait_gdb, false, "print current process PID and wait for gdb attach");
#endif
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
DEFINE_string(
        get_static_mem_info, "",
        "Dump the computing graph's static memory allocation information "
        "into the given directory");
#endif
#endif

REGIST_OPTION_CREATOR(plugin, lar::PluginOption::create_option);
REGIST_OPTION_CREATOR(debug, lar::DebugOption::create_option);
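
//! Example: display I/O tensor info and dump a profiling JSON in one run
//! (paths are placeholders; --profile requires a build with MGB_ENABLE_JSON):
//!     load_and_run model.mge --model_info --profile profile.json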
@@ -0,0 +1,105 @@
/**
 * \file lite/load_and_run/src/options/plugin_options.h
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */
#pragma once

#include <gflags/gflags.h>
#if __linux__ || __unix__
#include <unistd.h>
#endif

#include "megbrain/plugin/cpu_dispatch_checker.h"
#include "megbrain/plugin/var_value_checker.h"

#include "helpers/common.h"
#include "helpers/text_table.h"
#include "models/model.h"
#include "option_base.h"

DECLARE_bool(check_dispatch);
DECLARE_double(range);
DECLARE_string(check_var_value);
#if MGB_ENABLE_JSON
DECLARE_string(profile);
DECLARE_string(profile_host);
#endif
DECLARE_bool(model_info);
DECLARE_bool(verbose);
DECLARE_bool(disable_assert_throw);
#if __linux__ || __unix__
DECLARE_bool(wait_gdb);
#endif
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
DECLARE_string(get_static_mem_info);
#endif
#endif

namespace lar {
class PluginOption final : public OptionBase {
public:
    static bool is_valid();
    static std::shared_ptr<OptionBase> create_option();
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
    std::string option_name() const override { return m_option_name; }

private:
    PluginOption();
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
    double range;
    bool enable_check_dispatch;
#if MGB_ENABLE_JSON
    bool enable_profile_host;
    std::string profile_path;
#endif
    std::string var_value_check_str;
    std::string m_option_name;
    std::unique_ptr<mgb::VarValueChecker> var_value_checker;
    std::unique_ptr<mgb::CPUDispatchChecker> cpu_dispatch_checker;
};

class DebugOption final : public OptionBase {
public:
    static bool is_valid();
    static std::shared_ptr<OptionBase> create_option();
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
    std::string option_name() const override { return m_option_name; }

private:
    DebugOption();
    template <typename ModelImpl>
    void format_and_print(const std::string&, std::shared_ptr<ModelImpl>) {}
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
    bool enable_display_model_info;
    bool enable_verbose;
    bool disable_assert_throw;
#if __linux__ || __unix__
    bool enable_wait_gdb;
#endif
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
    std::string static_mem_log_dir_path;
#endif
#endif
    std::string m_option_name;
};
}  // namespace lar
@@ -0,0 +1,96 @@
/**
 * \file lite/load_and_run/src/options/strategy_options.cpp
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */
#include "strategy_options.h"
#include "models/model_mdl.h"

using namespace lar;

DECLARE_bool(c_opr_lib_with_param);

StrategyOption::StrategyOption() {
    m_option_name = "run_strategy";
    warmup_iter = FLAGS_warmup_iter;
    run_iter = FLAGS_iter;
    threads = FLAGS_thread;
}

std::shared_ptr<OptionBase> StrategyOption::create_option() {
    static std::shared_ptr<StrategyOption> option(new StrategyOption);
    return std::static_pointer_cast<OptionBase>(option);
}

void StrategyOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        model->set_shared_mem(FLAGS_share_param_mem);
        runtime_param.warmup_iter = warmup_iter;
        runtime_param.run_iter = run_iter;
        runtime_param.threads = threads;
        runtime_param.testcase_num = 1;
    } else if (runtime_param.stage == RunStage::BEFORE_OUTSPEC_SET) {
        if (model->type() == ModelType::MEGDL_MODEL) {
            auto model_ptr = std::static_pointer_cast<ModelMdl>(model);
            auto num = model_ptr->get_testcase_num();
            if (num != 0)
                runtime_param.testcase_num = num;
            model_ptr->make_output_spec();
        }
    }
}

TestcaseOption::TestcaseOption() {
    m_option_name = "run_testcase";
}

std::shared_ptr<OptionBase> TestcaseOption::create_option() {
    static std::shared_ptr<TestcaseOption> option(new TestcaseOption);
    return std::static_pointer_cast<OptionBase>(option);
}

void TestcaseOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    if (model->type() == ModelType::MEGDL_MODEL) {
        auto model_ptr = std::static_pointer_cast<ModelMdl>(model);
        if (model_ptr->get_testcase_num() && !FLAGS_c_opr_lib_with_param) {
            if (runtime_param.stage == RunStage::MODEL_RUNNING) {
                auto load_result = model_ptr->get_mdl_load_result();
                auto input_tensor = model_ptr->get_test_input();
                auto loader = model_ptr->reset_loader();
                auto testcase = loader->load(model_ptr->get_mdl_config(), false);
                mgb_assert(testcase.output_var_list.size() == input_tensor.size());
                for (size_t i = 0; i < input_tensor.size(); ++i) {
                    auto&& opr =
                            testcase.output_var_list[i]
                                    .node()
                                    ->owner_opr()
                                    ->cast_final_safe<mgb::opr::SharedDeviceTensor>();
                    input_tensor[i].second->copy_from(
                            mgb::HostTensorND::make_proxy(*opr.dev_data()));
                }
            }
        }
    }
}

DEFINE_int32(iter, 10, "number of iterations to run the model");
DEFINE_int32(warmup_iter, 1, "number of warm-up iterations before the timed runs");
DEFINE_int32(
        thread, 1,
        "number of host threads used to run the model where <thread> is "
        "supported (NOTE: this is not a device setting; it only makes "
        "load_and_run execute the model in multiple threads)");
DEFINE_bool(share_param_mem, false, "load the model from shared memory");

REGIST_OPTION_CREATOR(run_strategy, lar::StrategyOption::create_option);
REGIST_OPTION_CREATOR(run_testcase, lar::TestcaseOption::create_option);
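
//! Example: 5 warm-up iterations, 50 timed iterations, running the model in
//! 4 host threads (the model path is a placeholder):
//!     load_and_run model.mge --warmup_iter 5 --iter 50 --thread 4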
@@ -0,0 +1,68 @@
/**
 * \file lite/load_and_run/src/options/strategy_options.h
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */
#pragma once

#include <gflags/gflags.h>

#include "models/model.h"
#include "option_base.h"

DECLARE_int32(iter);
DECLARE_int32(warmup_iter);
DECLARE_int32(thread);
DECLARE_bool(share_param_mem);

namespace lar {
/*!
 * \brief: strategy option for running the model
 */
class StrategyOption final : public OptionBase {
public:
    //! create the option when it is used
    static std::shared_ptr<OptionBase> create_option();

    //! configure the model, dispatching the configuration to the matching
    //! model implementation
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;

    //! get option name
    std::string option_name() const override { return m_option_name; }

private:
    //! constructor
    StrategyOption();

    std::string m_option_name;  //! name of this option
    size_t warmup_iter;         //! number of warm-up iterations before timing
    size_t run_iter;            //! number of timed iterations to run the model
    size_t threads;             //! number of host threads running the model
                                //! (NOTE: different from the multithread device)
};

class TestcaseOption final : public OptionBase {
public:
    //! create the option when it is used
    static std::shared_ptr<OptionBase> create_option();

    //! configure the model, dispatching the configuration to the matching
    //! model implementation
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;

    //! get option name
    std::string option_name() const override { return m_option_name; }

private:
    //! constructor
    TestcaseOption();

    std::string m_option_name;  //! name of this option
};
}  // namespace lar
@@ -0,0 +1,24 @@
/**
 * \file lite/load_and_run/src/strategys/strategy.cpp
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */
#include "strategy.h"

#include <iostream>

using namespace lar;

std::shared_ptr<StrategyBase> StrategyBase::create_strategy(std::string model_path) {
    if (FLAGS_fitting) {
        return std::make_shared<FittingStrategy>(model_path);
    } else {
        return std::make_shared<NormalStrategy>(model_path);
    }
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -0,0 +1,63 @@
/**
 * \file lite/load_and_run/src/strategys/strategy.h
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */
#pragma once

#include <gflags/gflags.h>
#include <string>
#include <unordered_map>

#include "helpers/common.h"
#include "models/model.h"
#include "options/option_base.h"

DECLARE_bool(fitting);

namespace lar {
/*!
 * \brief: load and run strategy base class
 */
class StrategyBase {
public:
    static std::shared_ptr<StrategyBase> create_strategy(std::string model_path);
    virtual void run() = 0;
    virtual ~StrategyBase() = default;

    RuntimeParam m_runtime_param;
    std::unordered_map<std::string, std::shared_ptr<OptionBase>> m_options;
};

/*!
 * \brief: normal strategy for running
 */
class NormalStrategy : public StrategyBase {
public:
    NormalStrategy(std::string model_path);
    //! run the model with the current runtime parameters
    void run() override;

private:
    //! worker that runs the model once; launched once per thread in
    //! multi-thread mode
    void run_subline();
    std::string m_model_path;
};

/*!
 * \brief: fitting strategy for running
 */
class FittingStrategy : public StrategyBase {
public:
    FittingStrategy(std::string model_path);
    void run() override;
};
}  // namespace lar

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
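
//! A minimal sketch of how the entry point in src/main.cpp presumably wires
//! these pieces together (an assumption based on the interfaces above, not a
//! copy of the actual main):
//!
//!     int main(int argc, char** argv) {
//!         gflags::ParseCommandLineFlags(&argc, &argv, true);
//!         std::string model_path = argv[1];  // hypothetical positional arg
//!         auto strategy = lar::StrategyBase::create_strategy(model_path);
//!         strategy->run();
//!         return 0;
//!     }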
@@ -0,0 +1,24 @@
/**
 * \file lite/load_and_run/src/strategys/strategy_fitting.cpp
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */
#include "strategy.h"

using namespace lar;

FittingStrategy::FittingStrategy(std::string) {
    mgb_assert(false, "this version doesn't support the fitting strategy");
}

void FittingStrategy::run() {
    mgb_assert(false, "this version doesn't support the fitting strategy");
}

DEFINE_bool(
        fitting, false,
        "whether to use fitting mode, which automatically profiles the model "
        "and picks the best option set");
@@ -0,0 +1,167 @@
/**
 * \file lite/load_and_run/src/strategys/strategy_normal.cpp
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */
#include <cmath>
#include <iostream>
#include <limits>
#include <thread>
#include <vector>

#include "megbrain/common.h"
#include "megbrain/utils/timer.h"
#include "megbrain/version.h"
#include "megdnn/version.h"
#include "misc.h"
#include "strategy.h"

using namespace lar;

NormalStrategy::NormalStrategy(std::string model_path) {
    mgb::set_log_level(mgb::LogLevel::WARN);
    lite::set_log_level(LiteLogLevel::WARN);
    m_model_path = model_path;
    auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map();
    mgb_log_debug("option map size: %zu", option_creator_map->size());
    auto construct_option = [&](std::string name) -> void {
        auto& creator = (*option_creator_map)[name];
        auto option = creator();
        if (option) {
            m_options.insert({name, option});
        }
    };

    for (auto& creator : *option_creator_map) {
        auto name = creator.first;
        if (m_options.count(name) == 0) {
            construct_option(name);
        }
    }
}

void NormalStrategy::run_subline() {
    auto model = ModelBase::create_model(m_model_path);
    mgb_assert(model != nullptr, "create model failed");

    auto stage_config_model = [&]() {
        for (auto& option : m_options) {
            option.second->config_model(m_runtime_param, model);
        }
    };

    //! configure before load
    m_runtime_param.stage = RunStage::BEFORE_MODEL_LOAD;
    stage_config_model();

    mgb::RealTimer timer;
    model->load_model();
    printf("load model: %.3fms\n", timer.get_msecs_reset());

    //! configure after load
    m_runtime_param.stage = RunStage::AFTER_MODEL_LOAD;
    stage_config_model();

    m_runtime_param.stage = RunStage::BEFORE_OUTSPEC_SET;
    stage_config_model();

    //! for options that dump static memory information
    m_runtime_param.stage = RunStage::AFTER_OUTSPEC_SET;
    stage_config_model();

    auto warm_up = [&]() {
        auto warmup_num = m_runtime_param.warmup_iter;
        for (size_t i = 0; i < warmup_num; i++) {
            printf("=== prepare: %.3fms; going to warmup\n\n", timer.get_msecs_reset());
            model->run_model();
            model->wait();
            printf("warm up %zu %.3fms\n", i, timer.get_msecs_reset());
            m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT;
            stage_config_model();
        }
    };

    auto run_iter = [&](int idx) {
        double time_sqrsum = 0, time_sum = 0,
               min_time = std::numeric_limits<double>::max(), max_time = 0;
        auto run_num = m_runtime_param.run_iter;
        for (size_t i = 0; i < run_num; i++) {
            timer.reset();
            model->run_model();
            auto exec_time = timer.get_msecs();
            model->wait();
            m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT;
            stage_config_model();
            auto cur = timer.get_msecs();
            printf("iter %zu/%zu: %.3fms (exec=%.3fms)\n", i, run_num, cur, exec_time);
            time_sum += cur;
            time_sqrsum += cur * cur;
            fflush(stdout);
            min_time = std::min(min_time, cur);
            max_time = std::max(max_time, cur);
        }
        printf("\n=== finished test #%d: time=%.3fms avg_time=%.3fms "
               "sexec=%.3fms min=%.3fms max=%.3fms\n\n",
               idx, time_sum, time_sum / run_num,
               std::sqrt(
                       (time_sqrsum * run_num - time_sum * time_sum) /
                       (run_num * (run_num - 1))),
               min_time, max_time);
        return time_sum;
    };

    //! run the testcases bundled with the model
    size_t iter_num = m_runtime_param.testcase_num;
    double tot_time = 0;
    for (size_t idx = 0; idx < iter_num; idx++) {
        //! configure while the model is running
        mgb_log_warn("run testcase: %zu ", idx);
        m_runtime_param.stage = RunStage::MODEL_RUNNING;
        stage_config_model();
        if (!idx) {
            warm_up();
        }
        tot_time += run_iter(idx);

        m_runtime_param.stage = RunStage::AFTER_RUNNING_ITER;
        stage_config_model();
    }
    printf("=== total time: %.3fms\n", tot_time);

    //! configure after the whole run
    m_runtime_param.stage = RunStage::AFTER_MODEL_RUNNING;
    stage_config_model();
}
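
//! Note on the stage pipeline driven above: option hooks fire in the order
//! BEFORE_MODEL_LOAD -> AFTER_MODEL_LOAD -> BEFORE_OUTSPEC_SET ->
//! AFTER_OUTSPEC_SET, then per testcase MODEL_RUNNING -> (AFTER_RUNNING_WAIT
//! after every warm-up and timed wait) -> AFTER_RUNNING_ITER, and finally
//! AFTER_MODEL_RUNNING once all testcases are done.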
void NormalStrategy::run() {
    auto v0 = mgb::get_version();
    auto v1 = megdnn::get_version();
    printf("megbrain/lite/load_and_run:\nusing MegBrain "
           "%d.%d.%d(%d) and MegDNN %d.%d.%d\n",
           v0.major, v0.minor, v0.patch, v0.is_dev, v1.major, v1.minor, v1.patch);

    size_t thread_num = m_runtime_param.threads;
    auto run_sub = [&]() { run_subline(); };
    if (thread_num == 1) {
        run_sub();
    } else if (thread_num > 1) {
#if MGB_HAVE_THREAD
        std::vector<std::thread> threads;
        for (size_t i = 0; i < thread_num; ++i) {
            threads.emplace_back(run_sub);
        }
        for (auto&& i : threads) {
            i.join();
        }
#else
        mgb_log_error(
                "%zu threads requested, but load_and_run was compiled "
                "without <thread> support.",
                thread_num);
#endif
    } else {
        mgb_assert(false, "--thread requires a positive number");
    }
}