Browse Source

fix(mgb): append json file for dump and ready for midout open source

GitOrigin-RevId: 71ae7f1f4a
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
5b6ebeb563
9 changed files with 255 additions and 4 deletions
  1. +3
    -0
      .gitmodules
  2. +24
    -2
      CMakeLists.txt
  3. +3
    -0
      dnn/src/CMakeLists.txt
  4. +3
    -0
      dnn/test/CMakeLists.txt
  5. +1
    -0
      scripts/cmake-build/cross_build_android_arm_inference.sh
  6. +5
    -2
      sdk/load-and-run/dump_with_testcase_mge.py
  7. +1
    -0
      third_party/midout
  8. +1
    -0
      third_party/prepare.sh
  9. +214
    -0
      tools/gen_header_for_bin_reduce.py

+ 3
- 0
.gitmodules View File

@@ -31,3 +31,6 @@
[submodule "third_party/pybind11"]
path = third_party/pybind11
url = https://github.com/pybind/pybind11.git
[submodule "third_party/midout"]
path = third_party/midout
url = https://github.com/MegEngine/midout.git

+ 24
- 2
CMakeLists.txt View File

@@ -30,6 +30,8 @@ set (MGE_EXPORT_TARGETS MegEngine-targets)

option(MGE_WITH_JIT "Build MegEngine with JIT." ON)
option(MGE_WITH_HALIDE "Build MegEngine with Halide JIT" ON)
option(MGE_WITH_MIDOUT_PROFILE "Build MegEngine with Midout profile." OFF)
option(MGE_WITH_MINIMUM_SIZE "Swith off MGE_ENABLE_RTTI、MGE_ENABLE_EXCEPTIONS、MGE_ENABLE_LOGGING and switch on MGE_INFERENCE_ONLY so that compile minimum load_and_run. Take effect only when MGE_BIN_REDUCE was set" OFF)
option(MGE_ARMV8_2_FEATURE_FP16 "Enable armv8.2-a+fp16 support" OFF)
option(MGE_ARMV8_2_FEATURE_DOTPROD "enable armv8.2-a+dotprod support" OFF)
option(MGE_DISABLE_FLOAT16 "Disable MegEngine float16 support." OFF)
@@ -53,6 +55,26 @@ option(MGE_INFERENCE_ONLY "Build inference only library." OFF)
option(MGE_WITH_MKLDNN "Enable Intel MKL_DNN support," ON)
option(MGE_WITH_ROCM "Enable ROCM support" OFF)

if(NOT ${MGE_BIN_REDUCE} STREQUAL "")
message("build with BIN REDUCE")
if(MGE_WITH_MINIMUM_SIZE)
set(MGE_ENABLE_RTTI OFF)
set(MGE_ENABLE_LOGGING OFF)
set(MGE_ENABLE_EXCEPTIONS OFF)
set(MGE_INFERENCE_ONLY ON)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -include ${MGE_BIN_REDUCE}")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -include ${MGE_BIN_REDUCE}")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -flto=full")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -flto=full")
endif()

if(MGE_WITH_MIDOUT_PROFILE)
message("build with MIDOUT PROFILE")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMIDOUT_PROFILING")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMIDOUT_PROFILING")
endif()

if (APPLE)
set (BUILD_SHARED_LIBS OFF)
message("build static for xcode framework require")
@@ -235,7 +257,7 @@ if(NOT MGE_ENABLE_RTTI)
endif()

if(NOT MGE_ENABLE_EXCEPTIONS)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exception")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions")
endif()

if(MGE_WITH_TEST)
@@ -297,7 +319,7 @@ if(MGE_WITH_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -fno-rtti")
endif()
if(NOT MGE_ENABLE_EXCEPTIONS)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -fno-exception")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -fno-exceptions")
endif()

if(NOT MGE_CUDA_GENCODE)


+ 3
- 0
dnn/src/CMakeLists.txt View File

@@ -36,6 +36,9 @@ if(NOT ${MGE_ARCH} STREQUAL "naive")
endif()
endif()

if(MGE_WITH_MIDOUT_PROFILE)
list(APPEND SOURCES ${PROJECT_SOURCE_DIR}/third_party/midout/src/midout.cpp)
endif()

###############################################################################
# HIP_COMPILE


+ 3
- 0
dnn/test/CMakeLists.txt View File

@@ -25,6 +25,9 @@ if(MGE_WITH_CUDA)
list(APPEND SOURCES ${CUSOURCES})
endif()

if(MGE_WITH_MIDOUT_PROFILE)
list(APPEND SOURCES ${PROJECT_SOURCE_DIR}/third_party/midout/src/midout.cpp)
endif()

if(MGE_WITH_CAMBRICON)
file(GLOB_RECURSE SOURCES_ cambricon/*.cpp)


+ 1
- 0
scripts/cmake-build/cross_build_android_arm_inference.sh View File

@@ -119,6 +119,7 @@ function cmake_build() {
mkdir -p $BUILD_DIR
mkdir -p $INSTALL_DIR
cd $BUILD_DIR
unset IFS
cmake -G "$MAKEFILE_TYPE Makefiles" \
-DCMAKE_TOOLCHAIN_FILE="$NDK_ROOT/build/cmake/android.toolchain.cmake" \
-DANDROID_NDK="$NDK_ROOT" \


+ 5
- 2
sdk/load-and-run/dump_with_testcase_mge.py View File

@@ -471,8 +471,11 @@ def main():
assert not testcase, 'extra inputs provided in testcase: {}'.format(
testcase.keys()
)
mgb.serialize_comp_graph_to_file(args.output, output_mgbvars, append=True)

mgb.serialize_comp_graph_to_file(
args.output,
output_mgbvars,
append=True,
output_strip_info=args.output_strip_info)

if __name__ == '__main__':
main()

+ 1
- 0
third_party/midout

@@ -0,0 +1 @@
Subproject commit 3b8ae875a9e5c95031aca5edcc4233051d774eb5

+ 1
- 0
third_party/prepare.sh View File

@@ -15,6 +15,7 @@ git submodule foreach --recursive git reset --hard
git submodule foreach --recursive git clean -fd


git submodule update --init midout
git submodule update --init intel-mkl-dnn
git submodule update --init Halide
git submodule update --init protobuf


+ 214
- 0
tools/gen_header_for_bin_reduce.py View File

@@ -0,0 +1,214 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import re

if sys.version_info[0] != 3 or sys.version_info[1] < 5:
print('This script requires Python version 3.5')
sys.exit(1)

import argparse
import json
import os
import subprocess
import tempfile
from pathlib import Path

MIDOUT_TRACE_MAGIC = 'midout_trace v1\n'

class HeaderGen:
_dtypes = None
_oprs = None
_fout = None
_elemwise_modes = None
_has_netinfo = False
_midout_files = None

_file_without_hash = False

def __init__(self):
self._dtypes = set()
self._oprs = set()
self._elemwise_modes = set()
self._graph_hashes = set()
self._midout_files = []

_megvii3_root_cache = None
@classmethod
def get_megvii3_root(cls):
if cls._megvii3_root_cache is not None:
return cls._megvii3_root_cache
wd = Path(__file__).resolve().parent
while wd.parent != wd:
workspace_file = wd / 'WORKSPACE'
if workspace_file.is_file():
cls._megvii3_root_cache = str(wd)
return cls._megvii3_root_cache
wd = wd.parent
raise RuntimeError('This script is supposed to run in megvii3.')

def extend_netinfo(self, data):
self._has_netinfo = True
if 'hash' not in data:
self._file_without_hash = True
else:
self._graph_hashes.add(str(data['hash']))
for i in data['dtypes']:
self._dtypes.add(i)
for i in data['opr_types']:
self._oprs.add(i)
for i in data['elemwise_modes']:
self._elemwise_modes.add(i)

def extend_midout(self, fname):
self._midout_files.append(fname)

def generate(self, fout):
self._fout = fout
self._write_def('MGB_BINREDUCE_VERSION', '20190219')
if self._has_netinfo:
self._write_dtype()
self._write_elemwise_modes()
self._write_oprs()
self._write_hash()
self._write_midout()
del self._fout

def strip_opr_name_with_version(self, name):
pos = len(name)
t = re.search(r'V\d+$', name)
if t:
pos = t.start()
return name[:pos]

def _write_oprs(self):
defs = ['}', 'namespace opr {']
already_declare = set()
already_instance = set()
for i in self._oprs:
i = self.strip_opr_name_with_version(i)
if i in already_declare:
continue
else:
already_declare.add(i)

defs.append('class {};'.format(i))
defs.append('}')
defs.append('namespace serialization {')
defs.append("""
template<class Opr, class Callee>
struct OprRegistryCaller {
}; """)
for i in sorted(self._oprs):
i = self.strip_opr_name_with_version(i)
if i in already_instance:
continue
else:
already_instance.add(i)

defs.append("""
template<class Callee>
struct OprRegistryCaller<opr::{}, Callee>: public
OprRegistryCallerDefaultImpl<Callee> {{
}}; """.format(i))
self._write_def('MGB_OPR_REGISTRY_CALLER_SPECIALIZE', defs)

def _write_elemwise_modes(self):
with tempfile.NamedTemporaryFile() as ftmp:
fpath = os.path.realpath(ftmp.name)
subprocess.check_call(
['./brain/megbrain/dnn/scripts/gen_param_defs.py',
'--write-enum-items', 'Elemwise:Mode',
'./brain/megbrain/dnn/scripts/opr_param_defs.py',
fpath],
cwd=self.get_megvii3_root()
)

with open(fpath) as fin:
mode_list = [i.strip() for i in fin]

for i in mode_list:
if i in self._elemwise_modes:
content = '_cb({})'.format(i)
else:
content = ''
self._write_def(
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i), content)
self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)',
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)')

def _write_dtype(self):
if 'Float16' not in self._dtypes:
# MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16
# support in the past; however `FLOT16' is really a typo. We plan to
# change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon.
# To prevent issues in the transition, we decide to define both
# macros (`FLOT16' and `FLOAT16') here.
#
# In the future when the situation is settled and no one would ever
# use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be
# safely deleted.
self._write_def('MEGDNN_DISABLE_FLOT16', 1)
self._write_def('MEGDNN_DISABLE_FLOAT16', 1)

def _write_hash(self):
if self._file_without_hash:
print('WARNING: network info has no graph hash. Using json file '
'generated by MegBrain >= 7.28.0 is recommended')
else:
defs = 'ULL,'.join(self._graph_hashes) + 'ULL'
self._write_def('MGB_BINREDUCE_GRAPH_HASHES', defs)

def _write_def(self, name, val):
if isinstance(val, list):
val = '\n'.join(val)
val = str(val).strip().replace('\n', ' \\\n')
self._fout.write('#define {} {}\n'.format(name, val))

def _write_midout(self):
if not self._midout_files:
return

gen = os.path.join(self.get_megvii3_root(), 'brain', 'midout',
'gen_header.py')
cvt = subprocess.run(
[gen] + self._midout_files,
stdout=subprocess.PIPE, check=True,
).stdout.decode('utf-8')
self._fout.write('// midout \n')
self._fout.write(cvt)

def main():
parser = argparse.ArgumentParser(
description='generate header file for reducing binary size by '
'stripping unused oprs in a particular network; output file would '
'be written to bin_reduce.h',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'inputs', nargs='+',
help='input files that describe specific traits of the network; '
'can be one of the following:'
' 1. json files generated by '
'megbrain.serialize_comp_graph_to_file() in python; '
' 2. trace files generated by midout library')
parser.add_argument('-o', '--output', help='output file',
default=os.path.join(HeaderGen.get_megvii3_root(),
'utils', 'bin_reduce.h'))
args = parser.parse_args()

gen = HeaderGen()
for i in args.inputs:
print('==== processing {}'.format(i))
with open(i) as fin:
if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC:
gen.extend_midout(i)
else:
fin.seek(0)
gen.extend_netinfo(json.loads(fin.read()))

with open(args.output, 'w') as fout:
gen.generate(fout)

if __name__ == '__main__':
main()

Loading…
Cancel
Save