GitOrigin-RevId: f3b6e492d7
tags/v1.0.0-rc1
@@ -47,10 +47,9 @@ option(MGE_DEBUG_UTIL "Enable debug utility" ON) | |||
option(MGE_ENABLE_EXCEPTIONS "Build with exceptions" ON) | |||
option(MGE_WITH_TEST "Enable test for MegEngine." OFF) | |||
option(MGE_WITH_DISTRIBUTED "Build with distributed support" ON) | |||
option(MGE_BUILD_IMPERATIVE_RT "Build _imperative_rt.so instead of _mgb.so " OFF) | |||
option(MGE_BUILD_IMPERATIVE_RT "Build _imperative_rt Python Module " ON) | |||
option(MGE_BUILD_SDK "Build load_and_run" ON) | |||
option(MGE_INFERENCE_ONLY "Build inference only library." OFF) | |||
option(MGE_WITH_PYTHON_MODULE "Build MegEngine Python Module." ON) | |||
option(MGE_WITH_MKLDNN "Enable Intel MKL_DNN support," ON) | |||
option(MGE_WITH_ROCM "Enable ROCM support" OFF) | |||
@@ -256,8 +255,8 @@ endif() | |||
if(MGE_INFERENCE_ONLY) | |||
message("-- Disable distributed support for inference only build.") | |||
set(MGE_WITH_DISTRIBUTED OFF) | |||
message("-- Disable python module for inference only build.") | |||
set(MGE_WITH_PYTHON_MODULE OFF) | |||
message("-- Disable imperative_rt python module for inference only build.") | |||
set(MGE_BUILD_IMPERATIVE_RT OFF) | |||
endif() | |||
if(MGE_WITH_DISTRIBUTED) | |||
@@ -694,43 +693,18 @@ if(MGE_BUILD_SDK) | |||
add_subdirectory(sdk/load-and-run) | |||
endif() | |||
if(MGE_WITH_PYTHON_MODULE) | |||
if(MGE_BUILD_IMPERATIVE_RT) | |||
add_subdirectory(imperative) | |||
message("-- Enable imperative python wrapper runtime") | |||
else() | |||
add_subdirectory(python_module) | |||
message("-- Enable legacy python wrapper runtime") | |||
endif() | |||
if(MGE_BUILD_IMPERATIVE_RT) | |||
add_subdirectory(imperative) | |||
message("-- Enable imperative python wrapper runtime") | |||
endif() | |||
if(MGE_WITH_TEST AND MGE_ENABLE_RTTI) | |||
add_subdirectory(test) | |||
endif() | |||
if(TARGET mgb) | |||
add_custom_target( | |||
develop | |||
COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/$<TARGET_FILE_NAME:mgb> | |||
${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/$<TARGET_FILE_NAME:mgb> | |||
COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/mgb.py | |||
${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/mgb.py | |||
COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/opr.py | |||
${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/opr.py | |||
COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/opr_param_defs.py | |||
${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/opr_param_defs.py | |||
COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/include | |||
${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/include | |||
DEPENDS mgb | |||
VERBATIM | |||
) | |||
elseif(TARGET _imperative_rt) | |||
if(TARGET _imperative_rt) | |||
add_custom_target( | |||
develop | |||
COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
@@ -183,25 +183,16 @@ def typename(type): | |||
# parse typing.Union | |||
if sys.version_info < (3, 6): | |||
def parse_union(ann): | |||
def parse_union(ann): | |||
if hasattr(typing, "UnionMeta"): | |||
if type(ann) is not typing.UnionMeta: | |||
return | |||
return ann.__union_params__ | |||
elif sys.version_info < (3, 7): | |||
def parse_union(ann): | |||
elif hasattr(typing, "_Union"): | |||
if type(ann) is not typing._Union: | |||
return | |||
return ann.__args__ | |||
elif sys.version_info < (3, 8): | |||
def parse_union(ann): | |||
elif hasattr(typing, "_GenericAlias"): | |||
if type(ann) is not typing._GenericAlias: | |||
if type(ann) is not typing.Union: | |||
return | |||
@@ -209,11 +200,9 @@ elif sys.version_info < (3, 8): | |||
if ann.__origin__ is not typing.Union: | |||
return | |||
return ann.__args__ | |||
else: | |||
def parse_union(ann): | |||
elif hasattr(typing, "Union"): | |||
if typing.get_origin(ann) is not typing.Union: | |||
return | |||
return typing.get_args(ann) | |||
else: | |||
raise NotImplementedError("unsupported Python version") |
@@ -6,6 +6,7 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import os | |||
import re | |||
import pathlib | |||
@@ -55,11 +56,13 @@ package_data = [ | |||
str(f.relative_to('megengine')) | |||
for f in pathlib.Path('megengine', 'core', 'include').glob('**/*') | |||
] | |||
package_data += [ | |||
str(f.relative_to('megengine')) | |||
for f in pathlib.Path('megengine', 'core', 'lib').glob('**/*') | |||
] | |||
with open('requires.txt') as f: | |||
requires = f.read().splitlines() | |||
with open('requires-style.txt') as f: | |||
@@ -67,6 +70,7 @@ with open('requires-style.txt') as f: | |||
with open('requires-test.txt') as f: | |||
requires_test = f.read().splitlines() | |||
prebuild_modules=[PrecompiledExtesion('megengine.core._imperative_rt')] | |||
setup_kwargs = dict( | |||
name=package_name, | |||
version=__version__, | |||
@@ -78,7 +82,7 @@ setup_kwargs = dict( | |||
package_data={ | |||
'megengine': package_data, | |||
}, | |||
ext_modules=[PrecompiledExtesion('megengine.core._imperative_rt')], | |||
ext_modules=prebuild_modules, | |||
install_requires=requires, | |||
extras_require={ | |||
'dev': requires_style + requires_test, | |||
@@ -87,6 +91,7 @@ setup_kwargs = dict( | |||
cmdclass={'build_ext': build_ext}, | |||
) | |||
setup_kwargs.update(dict( | |||
classifiers=[ | |||
'Development Status :: 3 - Alpha', | |||
@@ -0,0 +1,21 @@ | |||
#!/bin/bash -e | |||
test_dirs="test" | |||
TEST_PLAT=$1 | |||
if [[ "$TEST_PLAT" == cpu ]]; then | |||
echo "only test cpu pytest" | |||
elif [[ "$TEST_PLAT" == cuda ]]; then | |||
echo "test both cpu and gpu pytest" | |||
else | |||
log "Argument must cpu or cuda" | |||
exit 1 | |||
fi | |||
pushd $(dirname "${BASH_SOURCE[0]}")/.. >/dev/null | |||
PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest $test_dirs -m 'not isolated_distributed' | |||
if [[ "$TEST_PLAT" == cuda ]]; then | |||
echo "test GPU pytest now" | |||
PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest $test_dirs -m 'isolated_distributed' | |||
fi | |||
popd >/dev/null |
@@ -1,8 +0,0 @@ | |||
/megbrain/_mgb.so | |||
/megbrain/_mgb.*.so | |||
/MegBrain.egg-info/ | |||
/dist | |||
/dist_cuda | |||
/dist_nocuda | |||
/wheel_dist | |||
.cache |
@@ -1,113 +0,0 @@ | |||
cmake_policy(SET CMP0086 NEW) | |||
find_package(PythonLibs ${PYTHON_VERSION_STRING} EXACT REQUIRED) | |||
find_package(Git) | |||
if(GIT_FOUND) | |||
message("git found: ${GIT_EXECUTABLE}") | |||
endif() | |||
find_package(NumPy REQUIRED) | |||
find_package(SWIG REQUIRED) | |||
set(SWIG_SRC src/swig/mgb.i) | |||
if(MSVC OR WIN32) | |||
set(CMAKE_SWIG_FLAGS -Wall -threads -py3 -DSWIGWORDSIZE64) | |||
message("WARN: swig have some define issue at windows(64) env") | |||
message("Please refs scripts/whl/BUILD_PYTHON_WHL_README.md to init windows build env") | |||
else() | |||
set(CMAKE_SWIG_FLAGS -Wall -threads -py3 -modern -DSWIGWORDSIZE64) | |||
endif() | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") | |||
file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl") | |||
file(GLOB_RECURSE PYTHON_SRCS setup.py | |||
src/python/*.py | |||
test/*.py | |||
megengine/*.py) | |||
list(REMOVE_ITEM PYTHON_SRCS | |||
${CMAKE_CURRENT_SOURCE_DIR}/megengine/_internal/mgb.py | |||
${CMAKE_CURRENT_SOURCE_DIR}/megengine/_internal/opr.py | |||
${CMAKE_CURRENT_SOURCE_DIR}/megengine/_internal/opr_param_defs.py | |||
) | |||
list(APPEND PYTHON_SRCS ${MGB_SRCS}) | |||
file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h | |||
${PROJECT_SOURCE_DIR}/src/core/include/* | |||
${PROJECT_SOURCE_DIR}/src/opr/include/* | |||
${PROJECT_SOURCE_DIR}/src/serialization/include/* | |||
${PROJECT_SOURCE_DIR}/src/plugin/include/* | |||
${PROJECT_SOURCE_DIR}/dnn/include/*) | |||
file(COPY ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) | |||
file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS) | |||
file(APPEND ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${CONTENTS}) | |||
add_custom_command( | |||
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr_param_defs.py | |||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/src/python ${CMAKE_CURRENT_BINARY_DIR}/src/python | |||
COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal | |||
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/src/python/genopr.py ${OPR_DECL_SRCS} | |||
COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_param_defs.py -t py ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr_param_defs.py | |||
DEPENDS ${OPR_DECL_SRCS} | |||
VERBATIM | |||
) | |||
add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py) | |||
set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/bfloat16.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp) | |||
include(UseSWIG) | |||
set_property(SOURCE ${SWIG_SRC} PROPERTY CPLUSPLUS ON) | |||
# cmake < 3.12 do not honor INCLUDE_DIRECTORIES property, just add include directory into SWIG_FLAGS | |||
# Add -I${PROJECT_BINARY_DIR}/genfiles in order to include megbrain_build_config.h so that we don't need to pass cmake flags by -D. | |||
set_property(SOURCE ${SWIG_SRC} PROPERTY SWIG_FLAGS -I${PROJECT_SOURCE_DIR}/src/serialization/include -I${PROJECT_BINARY_DIR}/genfiles) | |||
set(SWIG_OUTFILE_DIR ${CMAKE_CURRENT_BINARY_DIR}) | |||
set(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal) | |||
swig_add_library(mgb LANGUAGE python SOURCES ${SWIG_SRC} ${SRCS}) | |||
set(VERSION_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/version.ld) | |||
add_custom_target(version_ld SOURCES ${VERSION_SCRIPT}) | |||
set_target_properties(mgb PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal) | |||
if (APPLE) | |||
target_link_libraries(mgb megbrain megdnn) | |||
set_target_properties(mgb PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") | |||
elseif (MSVC OR WIN32) | |||
target_link_libraries(mgb megbrain megdnn) | |||
else() | |||
target_link_libraries(mgb megbrain megdnn -Wl,--version-script=${VERSION_SCRIPT}) | |||
endif() | |||
target_include_directories(mgb PRIVATE ${PYTHON_INCLUDE_DIRS} src/cpp ${CMAKE_CURRENT_BINARY_DIR} ${NUMPY_INCLUDE_DIR}) | |||
# only windows need link PYTHON_LIBRARIES | |||
if(MSVC OR WIN32) | |||
target_link_libraries(mgb ${PYTHON_LIBRARIES}) | |||
endif() | |||
if (MGE_WITH_DISTRIBUTED) | |||
target_link_libraries(mgb megray) | |||
endif() | |||
add_dependencies(mgb mgb_opr_py version_ld) | |||
add_custom_command( | |||
TARGET mgb POST_BUILD | |||
COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/LICENSE ${PROJECT_SOURCE_DIR}/ACKNOWLEDGMENTS ${PROJECT_BINARY_DIR} | |||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/megengine ${CMAKE_CURRENT_BINARY_DIR}/megengine | |||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/test ${CMAKE_CURRENT_BINARY_DIR}/test | |||
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/setup.py ${CMAKE_CURRENT_BINARY_DIR}/setup.py | |||
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/requires.txt ${CMAKE_CURRENT_BINARY_DIR}/requires.txt | |||
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/requires-style.txt ${CMAKE_CURRENT_BINARY_DIR}/requires-style.txt | |||
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/requires-test.txt ${CMAKE_CURRENT_BINARY_DIR}/requires-test.txt | |||
COMMAND ${CMAKE_COMMAND} -E remove_directory ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include | |||
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/src/cpp/megbrain_pubapi.h ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include/megbrain_pubapi.h | |||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/src/core/include ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include | |||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/src/opr/include ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include | |||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/src/serialization/include ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include | |||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/src/plugin/include ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include | |||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/dnn/include ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include | |||
COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_BINARY_DIR}/genfiles/megbrain_build_config.h ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/include/megbrain_build_config.h | |||
) | |||
@@ -1,11 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .core import * | |||
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level | |||
from .version import __version__ |
@@ -1,729 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
"""the megbrain python package | |||
Note that all the submodules are automatically imported, so you usually only | |||
need to ``import megengine._internal as mgb``. | |||
""" | |||
import collections | |||
import json | |||
import os | |||
import sys | |||
import platform | |||
import ctypes | |||
if sys.platform == "win32": | |||
lib_path = os.path.join(os.path.dirname(__file__), "lib") | |||
Lib_path = os.path.join(os.path.dirname(__file__), "Lib") | |||
dll_paths = list(filter(os.path.exists, [lib_path, Lib_path])) | |||
assert len(dll_paths) > 0 | |||
kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) | |||
has_load_library_attr = hasattr(kernel32, "AddDllDirectory") | |||
old_error_mode = kernel32.SetErrorMode(0x0001) | |||
kernel32.LoadLibraryW.restype = ctypes.c_void_p | |||
if has_load_library_attr: | |||
kernel32.AddDllDirectory.restype = ctypes.c_void_p | |||
kernel32.LoadLibraryExW.restype = ctypes.c_void_p | |||
for dll_path in dll_paths: | |||
if sys.version_info >= (3, 8): | |||
os.add_dll_directory(dll_path) | |||
elif has_load_library_attr: | |||
res = kernel32.AddDllDirectory(dll_path) | |||
if res is None: | |||
err = ctypes.WinError(ctypes.get_last_error()) | |||
err.strerror += ' Error adding "{}" to the DLL search PATH.'.format( | |||
dll_path | |||
) | |||
raise err | |||
else: | |||
print("WARN: python or OS env have some issue, may load DLL failed!!!") | |||
import glob | |||
dlls = glob.glob(os.path.join(lib_path, "*.dll")) | |||
path_patched = False | |||
for dll in dlls: | |||
is_loaded = False | |||
if has_load_library_attr: | |||
res = kernel32.LoadLibraryExW(dll, None, 0x00001100) | |||
last_error = ctypes.get_last_error() | |||
if res is None and last_error != 126: | |||
err = ctypes.WinError(last_error) | |||
err.strerror += ' Error loading "{}" or one of its dependencies.'.format( | |||
dll | |||
) | |||
raise err | |||
elif res is not None: | |||
is_loaded = True | |||
if not is_loaded: | |||
if not path_patched: | |||
os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]]) | |||
path_patched = True | |||
res = kernel32.LoadLibraryW(dll) | |||
if res is None: | |||
err = ctypes.WinError(ctypes.get_last_error()) | |||
err.strerror += ' Error loading "{}" or one of its dependencies.'.format( | |||
dll | |||
) | |||
raise err | |||
kernel32.SetErrorMode(old_error_mode) | |||
import numpy as np | |||
from . import comp_graph_tools as cgtools | |||
from . import config, craniotome, dtype | |||
from . import global_init as _global_init | |||
from . import helper as _helper | |||
from . import mgb as _detail | |||
from . import opr, opr_extra, opr_param_defs, plugin | |||
from .exc import MegBrainError | |||
from .logconf import get_logger | |||
from .mgb import ( | |||
CompGraph, | |||
CompNode, | |||
SharedND, | |||
SharedScalar, | |||
SymbolVar, | |||
TensorValueDumperContext, | |||
TensorValueLoaderContext, | |||
) | |||
from .mgb import as_comp_node as comp_node | |||
from .mgb_helper import SharedNDLazyInitializer, callback_lazycopy, copy_output | |||
from .plugin import CompGraphProfiler | |||
from .plugin import GlobalInfkernFinder as _GlobalInfkernFinder | |||
from .plugin import NumRangeChecker | |||
from .version import __version__, version_info | |||
if sys.version_info.major < 3: | |||
raise ImportError("megbrain requires python 3") | |||
class ProxySharedNDAndSymbolVar(_detail.SymbolVar): | |||
"""this is a :class:`.SymbolVar` with a corresponding :class:`.SharedND`. | |||
It can participate in graph computating and also provides :meth:`set_value` | |||
and :meth:`get_value`. It should be constructed by :func:`make_shared`. | |||
""" | |||
__shared_nd = None | |||
__kwargs = None | |||
def __init__(self, snd, comp_graph, name, **kwargs): | |||
self.__shared_nd = snd | |||
self.__kwargs = kwargs | |||
self.this = snd.symvar(comp_graph=comp_graph, name=name, **kwargs).this | |||
def set_value(self, v, **kwargs): | |||
ret = self.__shared_nd.set_value(v, **kwargs) | |||
self._reeval_if_eager_eval() | |||
return ret | |||
def get_value(self): | |||
return self.__shared_nd.get_value() | |||
def reset_zero(self): | |||
self.__shared_nd.reset_zero() | |||
def make_shared( | |||
comp_node, | |||
*, | |||
dtype=None, | |||
shape=None, | |||
value=None, | |||
comp_graph=None, | |||
name=None, | |||
volatile=None | |||
): | |||
"""make a shared tensor which is stored on device and could be modified | |||
later, either as a :class:`.SymbolVar` or a :class:`.SharedND` object | |||
:param comp_node: computing node | |||
:type comp_node: :class:`.CompNode` | |||
:param dtype: data type; if it is None, then dtype of value would be used | |||
if value is not None, and float32 would be used as default dtype if | |||
value is None | |||
:type dtype: :class:`numpy.dtype` compatible | |||
:param value: initializing value | |||
:type value: None or :class:`numpy.ndarray` | |||
:param comp_graph: the computing graph to which this shared value should | |||
belong; if provided, the retuned object could be used as a | |||
:class:`.SymbolVar` | |||
:type comp_graph: None or :class:`.CompGraph` | |||
:param name: node name to be used in computing graph; only meaningful if | |||
*comp_graph* is provided | |||
:param volatile: if *comp_graph* is given then *volatile* indicates whether | |||
shape or mem ptr of this SharedND can be changed | |||
:rtype: :class:`.SharedND` if *comp_graph* is not given; or | |||
:class:`ProxySharedNDAndSymbolVar` otherwise | |||
""" | |||
if dtype is None: | |||
if value is not None: | |||
value = np.ascontiguousarray(value) | |||
dtype = to_mgb_supported_dtype(value.dtype) | |||
else: | |||
dtype = np.float32 | |||
comp_node = _detail.as_comp_node(comp_node) | |||
rst = _detail.SharedND(comp_node, dtype) | |||
if value is not None: | |||
assert shape is None, "could not provide both value and shape" | |||
rst.set_value(value) | |||
elif shape is not None: | |||
rst._set_init_shape(shape) | |||
if comp_graph is None: | |||
assert name is None and volatile is None | |||
return rst | |||
assert isinstance(comp_graph, CompGraph), "expect CompGraph but got {}".format( | |||
comp_graph | |||
) | |||
if volatile is None: | |||
volatile = False | |||
else: | |||
assert isinstance(volatile, bool) | |||
return ProxySharedNDAndSymbolVar(rst, comp_graph, name, volatile=volatile) | |||
def make_immutable(comp_node, comp_graph, value, *, dtype=None, name=None): | |||
"""make a graph node containing an immutable tensor from host tensor value | |||
:param dtype: required data type; if not None, the data would be converted | |||
to that type; otherwise | |||
""" | |||
comp_node = _detail.as_comp_node(comp_node) | |||
assert isinstance( | |||
comp_graph, _detail.CompGraph | |||
), "expect CompGraph but got {!r}".format(comp_graph) | |||
config = _detail.make_opr_config(name, comp_node) | |||
return _helper.cvt_opr_result( | |||
_detail._make_immutable(comp_graph, value, dtype, config) | |||
) | |||
def make_arg( | |||
comp_node, | |||
comp_graph, | |||
*, | |||
dtype=np.float32, | |||
shape=None, | |||
name=None, | |||
value=None, | |||
enable_static_infer=True | |||
): | |||
"""make an argument to be passed to compiled function during runtime; | |||
:type shape: None or tuple of int | |||
:param shape: expected tensor shape to be used for shape inferring; actual | |||
tesor shape could be different | |||
:type name: str | |||
:param name: name of the generated var node | |||
:type value: None or ndarray-compatible | |||
:param value: initial value used for static inference; if not given, static | |||
infer would be deferred to first graph execution | |||
:param enable_static_infer: whether to enable static inference for this var | |||
""" | |||
comp_node = _detail.as_comp_node(comp_node) | |||
host_val = mgb._HostSharedND(comp_node, dtype) | |||
if value is not None: | |||
value = np.ascontiguousarray(value, dtype=dtype) | |||
if shape is None: | |||
shape = value.shape | |||
else: | |||
assert shape == value.shape | |||
if shape is not None: | |||
host_val._resize(shape) | |||
if value is not None: | |||
host_val.set_value(value) | |||
return _helper.cvt_opr_result( | |||
ProxySharedNDAndSymbolVar( | |||
host_val, comp_graph, name, enable_static_infer=enable_static_infer | |||
) | |||
) | |||
def comp_graph(*, extra_opts=None, check_env_var=True): | |||
"""allocate a new computing graph | |||
:param extra_opts: extra options to be set; would be updated (modified | |||
inplace) from ``MGB_COMP_GRAPH_OPT`` environment var. See | |||
:func:`.set_comp_graph_option` for list of supported options. | |||
:type extra_opts: dict | |||
:param check_env_var: whether to check environment vars | |||
:type check_env_var: bool | |||
:return: the comp graph object | |||
:rtype: :class:`.CompGraph` | |||
""" | |||
cg = _detail.CompGraph() | |||
if extra_opts is None: | |||
extra_opts = {} | |||
if check_env_var: | |||
setting = os.getenv("MGB_COMP_GRAPH_OPT") | |||
if setting: | |||
for item in setting.split(";"): | |||
k, v = item.split("=", 1) | |||
extra_opts.setdefault(k, v) | |||
get_logger().warning( | |||
"set comp graph option from env: {}".format(extra_opts) | |||
) | |||
user_data = os.getenv("MGB_COMP_GRAPH_USER_DATA") | |||
if user_data: | |||
storage = cg.user_data | |||
for ud in user_data.split(";"): | |||
k, v = ud.split("=", 1) | |||
storage[k] = eval(v) | |||
_GlobalInfkernFinder.add_graph(cg) | |||
for k, v in extra_opts.items(): | |||
cg.set_option(k, v) | |||
return cg | |||
def grad( | |||
target, wrt, warn_mid_wrt=True, use_virtual_grad=None, return_zero_for_nodep=True | |||
): | |||
r"""compute symbolic grad | |||
:param target: grad target var | |||
:type target: :class:`.SymbolVar` | |||
:param wrt: with respect to which to compute the grad | |||
:type wrt: :class:`.SymbolVar` or Iterable[SymbolVar] | |||
:param warn_mid_wrt: whether to give warning if *wrt* is not endpoint | |||
:type warn_mid_wrt: bool | |||
:param use_virtual_grad: whether to use virtual grad opr, so fwd graph can | |||
be optimized before applying grad; if ``None`` is given, then virtual | |||
grad would be used if ``graph_opt_level >= 2`` | |||
:type use_virtual_grad: :class:`bool` or ``None`` | |||
:param return_zero_for_nodep: if *target* does not depend on *wrt*, set to True to return | |||
a zero-valued `.SymbolVar` rather than ``None``; can't be set to False when using | |||
virtual grad opr. | |||
:type return_zero_for_nodep: bool | |||
:rtype: :class:`.SymbolVar` or None | |||
:return: :math:`\frac{\partial\text{target}}{\partial\text{wrt}}` | |||
""" | |||
if use_virtual_grad is None: | |||
use_virtual_grad = -1 | |||
else: | |||
use_virtual_grad = 1 if use_virtual_grad else 0 | |||
if isinstance(wrt, SymbolVar): | |||
wrts = [ | |||
wrt, | |||
] | |||
else: | |||
wrts = wrt | |||
assert isinstance(wrts, collections.Iterable) | |||
# return a invalid SymbolVar (with nullptr VarNode*) when return_zero_for_nodep is False | |||
# and target doesn't depend on wrt | |||
grads = _detail._grad( | |||
target, wrts, bool(warn_mid_wrt), use_virtual_grad, return_zero_for_nodep | |||
) | |||
grads = list(grads) | |||
for i in range(len(grads)): | |||
if not grads[i].valid: | |||
assert ( | |||
not return_zero_for_nodep | |||
), "invalid grad SymbolVar: target={}, wrt={}".format(target, wrts[i]) | |||
grads[i] = None | |||
if len(grads) == 1: | |||
grads = grads[0] | |||
return grads | |||
def current_grad_target(comp_graph): | |||
"""get current target var to compute grad, used for implementing custom | |||
gradient""" | |||
return _detail._current_grad_target(comp_graph) | |||
def add_device_map(map_location): | |||
"""add map location while loading models""" | |||
_detail.CompNode.cn_thread_local.__setattr__("map_location", map_location) | |||
def del_device_map(): | |||
"""delete map location""" | |||
_detail.CompNode.cn_thread_local.__delattr__("map_location") | |||
def inter_graph_trans_var(dest_graph, src): | |||
"""get the corresponding var of *src* in *dest_graph*; assuming | |||
*dest_graph* is a copy of owner graph of *src*; usually used in callback of | |||
set_grad to get grad of vars in loop | |||
:param dest_graph: target computing graph | |||
:type dest_graph: :class:`.CompGraph` | |||
:param src: source var node | |||
:type src: :class:`.SymbolVar` | |||
:return: corresponding var in *dest_graph* | |||
:rtype: :class:`.SymbolVar` | |||
""" | |||
return _detail._inter_graph_trans_var(dest_graph, src) | |||
def get_graph_optimizer_replaced_var(src): | |||
"""get optimized var corresponding to given var; usually used in callback | |||
of set_grad to get grad w.r.t. some var | |||
:param src: source var node | |||
:type src: :class:`.SymbolVar` | |||
:rtype: :class:`.SymbolVar` | |||
""" | |||
return _detail._get_graph_optimizer_replaced_var(src) | |||
CompGraphSerializationResult = collections.namedtuple( | |||
"CompGraphSerializationResult", | |||
[ | |||
"nr_opr", | |||
"tot_bytes", | |||
"tensor_value_bytes", | |||
"content_hash", | |||
"inputs", | |||
"outputs", | |||
"params", | |||
], | |||
) | |||
def serialize_comp_graph_to_file( | |||
fpath, | |||
output_vars, | |||
*, | |||
keep_var_name=1, | |||
keep_param_name=False, | |||
keep_opr_priority=False, | |||
tensor_value_dumper=None, | |||
output_strip_info=False, | |||
append=False, | |||
format=None, | |||
**kwargs | |||
): | |||
"""serialize this computing graph and write result to a file. Note: | |||
``kwargs`` exists for backward compatibility; there is no additional | |||
arguments. | |||
:parma fpath: path for the output file | |||
:type fpath: ``str`` | |||
:param output_vars: output variables that need to be retrieved when | |||
deserializing | |||
.. note:: | |||
The underlying C++ API only accepts a var list. If a dict is given, | |||
the vars would be renamed to given names. | |||
:type output_vars: dict(name => :class:`.SymbolVar`), or a list of vars | |||
:param keep_var_name: level for keeping variable names: | |||
* 0: none of the names are kept | |||
* 1: keep names of output vars | |||
* 2: keep names of all (output and internal) vars | |||
:param keep_param_name: whether to keep param names, so param values can be | |||
easily manipulated after loading model | |||
:param keep_opr_priority: whether to keep priority setting for operators | |||
:param tensor_value_dumper: a callable to dump tensor values; it should | |||
only write the tensor value without layout information. It would be | |||
given a :class:`.TensorValueDumperContext` object as its sole argument. | |||
:param output_strip_info: if set to True, then a json file containing | |||
information for code strip would be written to ``fpath+'.json'`` | |||
:param append: whether to open output file in append mode | |||
:return: an instance of namedtuple :class:`CompGraphSerializationResult`, | |||
whose fields are: | |||
* ``nr_opr`` number of operators dumped | |||
* ``tot_bytes`` total bytes for the whole graph | |||
* ``tensor_value_bytes`` bytes consumed for dumping tensor values | |||
* ``inputs`` names of input tensors | |||
* ``params`` list of names of dumped params | |||
* ``outputs`` names of output vars | |||
:param format: serialization format of the resulting model, should be either | |||
"mdl" or "fbs"; none means default. | |||
:type format: ``str`` | |||
""" | |||
assert isinstance(fpath, str), "bad file path: {!r}".format(fpath) | |||
ov = _detail._VectorSymbolVar() | |||
SUPPORTED_FORMATS = { | |||
# default | |||
None: _detail.GraphDumpFormat_FLATBUFFERS, | |||
"fbs": _detail.GraphDumpFormat_FLATBUFFERS, | |||
} | |||
resolved_fmt = SUPPORTED_FORMATS.get(format, None) | |||
if resolved_fmt is None: | |||
raise ValueError( | |||
"unknown format {} requested, supported ones are {}".format( | |||
format, list(filter(None, SUPPORTED_FORMATS.keys())) | |||
) | |||
) | |||
if isinstance(output_vars, dict): | |||
used_vars = set() | |||
for name, var in output_vars.items(): | |||
assert isinstance(var, _detail.SymbolVar), "bad output var: {!r}".format( | |||
var | |||
) | |||
assert var.id not in used_vars, ( | |||
"var name is associated with a var object, so we can not have " | |||
"two names given to the same var: {}".format(var) | |||
) | |||
used_vars.add(var.id) | |||
var.rename(name) | |||
ov.push_back(var) | |||
else: | |||
for i in output_vars: | |||
assert isinstance(i, _detail.SymbolVar), "bad output var: {!r}".format(i) | |||
ov.push_back(i) | |||
if tensor_value_dumper is not None: | |||
assert isinstance(tensor_value_dumper, collections.Callable) | |||
class Callback(_detail._TensorValueDumperCallback): | |||
def call(self, ctx, *, _f=tensor_value_dumper): | |||
_f(ctx) | |||
tensor_value_dumper = Callback() | |||
# for backward compatibility | |||
mangle_opr_name = kwargs.pop("mangle_opr_name", ov) | |||
if mangle_opr_name is not ov: | |||
get_logger().warning("mangle_opr_name is deprecated; use keep_var_name instead") | |||
keep_var_name = 1 if mangle_opr_name else 2 | |||
mangle_param_name = kwargs.pop("mangle_param_name", ov) | |||
assert ( | |||
not kwargs | |||
), "extra kwargs provided to serialize_comp_graph_to_file: {}".format(kwargs) | |||
if mangle_param_name is not ov: | |||
get_logger().warning( | |||
"mangle_param_name is deprecated; use keep_param_name instead" | |||
) | |||
keep_param_name = not mangle_param_name | |||
inputs = _detail._VectorString() | |||
outputs = _detail._VectorString() | |||
params = _detail._VectorString() | |||
stat = _detail._VectorSizeT() | |||
_detail._serialize_comp_graph_to_file( | |||
fpath, | |||
append, | |||
resolved_fmt, | |||
ov, | |||
keep_var_name, | |||
keep_param_name, | |||
keep_opr_priority, | |||
tensor_value_dumper, | |||
stat, | |||
inputs, | |||
outputs, | |||
params, | |||
) | |||
dump_ret = CompGraphSerializationResult( | |||
*stat, list(inputs), list(outputs), list(params) | |||
) | |||
if output_strip_info: | |||
with open(fpath + ".json", "w") as fout: | |||
strip_info = _detail._get_info_for_strip(ov) | |||
strip_info_dict = json.loads(strip_info) | |||
strip_info_dict["hash"] = dump_ret.content_hash | |||
json.dump(strip_info_dict, fout) | |||
return dump_ret | |||
# result of load_comp_graph_from_file(): the loaded graph plus its output
# vars, both as a name->var mapping and in original serialization order
CompGraphLoadResult = collections.namedtuple(
    "CompGraphLoadResult", "graph output_vars_dict output_vars_list"
)
def load_comp_graph_from_file(
    fpath, *, comp_node_mapper=None, tensor_value_loader=None
):
    """Load a serialized computing graph from file.

    :param fpath: path of the serialized graph file
    :type fpath: ``str``
    :param comp_node_mapper: a callable to modify comp node locator, takes old
        locator as argument and returns new locator.
    :type comp_node_mapper: Callable[[str], str]
    :param tensor_value_loader: a callable to load tensor values. It should
        read the tensor value with the given shape and dtype and return it as
        NumPy ndarray. It would be given a :class:`.TensorValueLoaderContext`
        object as its sole argument.
    :type tensor_value_loader: Callable[[TensorValueLoaderContext], numpy.ndarray]
    :return: an instance of namedtuple :class:`CompGraphLoadResult`,
        whose fields are:

        * ``graph`` loaded CompGraph
        * ``output_vars_dict`` a Python dict, mapping name to output SymbolVar
        * ``output_vars_list`` a Python list, containing output vars in the
          order passed to serialize_comp_graph_to_file
    """
    assert isinstance(fpath, str), "bad file path: {!r}".format(fpath)
    if comp_node_mapper is not None:
        # fix: collections.Callable was removed in Python 3.10; use callable()
        assert callable(comp_node_mapper)

        # adapt the plain callable to the SWIG callback interface
        class Callback(_detail._CompNodeMapperCallback):
            def call(self, desc, *, _f=comp_node_mapper):
                return _f(desc)

        comp_node_mapper = Callback()
    if tensor_value_loader is not None:
        assert callable(tensor_value_loader)

        class Callback(_detail._TensorValueLoaderCallback):
            def call(self, ctx, *, _f=tensor_value_loader):
                return _f(ctx)

        tensor_value_loader = Callback()
    output_vars_map = _detail._VectorPairStringSymbolVar()
    output_vars_list = _detail._VectorSymbolVar()
    cg = _detail._load_comp_graph_from_file(
        fpath, comp_node_mapper, tensor_value_loader, output_vars_map, output_vars_list
    )
    return CompGraphLoadResult(cg, dict(list(output_vars_map)), list(output_vars_list))
def optimize_for_inference(
    output_vars,
    *,
    f16_io_f32_comp=False,
    f16_io_comp=False,
    use_nhwcd4=False,
    fuse_conv_bias_nonlinearity=False,
    use_nchw32=False,
    fuse_conv_bias_with_z=False,
    use_nchw4=False,
    use_nchw88=False,
    use_nchw44=False,
    use_nchw44_dot=False,
    use_chwn4=False
):
    """Optimize a computing graph for inference.

    This applies a predefined set of optimization passes. Refer to the mnist
    sdk example and C++ code for fine-grained control.

    :param output_vars: output symvars
    :type output_vars: list of :class:`.SymbolVar`
    :param f16_io_f32_comp: whether to use float16 for I/O between oprs and use
        float32 as internal computation precision. Note the output var would be
        changed to float16
    :param f16_io_comp: whether to use float16 for both I/O and computation
        precision
    :param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
        into one opr. This is supported only in NHWCD4 format.
    :param fuse_conv_bias_with_z: whether to also fuse the z input of conv.
    :param use_nhwcd4: whether to use NHWCD4 data format. This is faster on
        some OpenCL devices
    :param use_nchw4: whether to use NCHW4 tensor format.
    :param use_nchw88: whether to use NCHW88 tensor format. This may be faster
        sometimes.
    :param use_nchw44: whether to use NCHW44 tensor format. This may be faster
        sometimes.
    :param use_nchw44_dot: whether to use NCHW44_DOT tensor format. This format
        is optimized for inference in armv8.2
    :param use_nchw32: whether to use NCHW32 tensor format. Mainly used for
        nvidia tensorcore.
    :param use_chwn4: whether to use CHWN4 tensor format. Mainly used for
        nvidia tensorcore.
    :return: list of transformed vars corresponding to given output vars
    """
    assert isinstance(output_vars, (list, tuple))
    opt = _detail._OptimizeForInferenceOptions()
    # independent on/off passes: enable each one whose flag is set
    for enabled, pass_name in [
        (f16_io_f32_comp, "f16_io_f32_comp"),
        (f16_io_comp, "f16_io_comp"),
        (fuse_conv_bias_nonlinearity, "fuse_conv_bias_nonlinearity"),
        (fuse_conv_bias_with_z, "fuse_conv_bias_with_z"),
    ]:
        if enabled:
            getattr(opt, "enable_{}".format(pass_name))()
    # layout transforms are mutually exclusive: at most one may be requested
    chosen = None
    for enabled, flag_name, layout in [
        (use_nchw4, "use_nchw4", "nchw4"),
        (use_nhwcd4, "use_nhwcd4", "nhwcd4"),
        (use_nchw32, "use_nchw32", "nchw32"),
        (use_nchw88, "use_nchw88", "nchw88"),
        (use_nchw44, "use_nchw44", "nchw44"),
        (use_nchw44_dot, "use_nchw44_dot", "nchw44_dot"),
        (use_chwn4, "use_chwn4", "chwn4"),
    ]:
        if enabled:
            assert (
                not chosen
            ), "Only one layout transform supported, both {} and {}".format(
                chosen, flag_name
            )
            getattr(opt, "enable_{}".format(layout))()
            chosen = flag_name
    vec = _detail._VectorSymbolVar()
    for sym in output_vars:
        assert isinstance(sym, _detail.SymbolVar), "bad var: {}".format(sym)
        vec.push_back(sym)
    return list(_detail._optimize_for_inference(vec, opt))
def get_opr_fp_graph_exec(comp_graph, output_vars):
    """get opr footprint and graph exec info

    This function will recompile the compute graph; any AsyncExecutable
    compiled before will be invalid.

    :param comp_graph: ComputingGraph
    :param output_vars: list of :class:`.SymbolVar`
    :return: parsed JSON info as a Python object
    """
    assert isinstance(output_vars, (list, tuple))
    vec = _detail._VectorSymbolVar()
    for i in output_vars:
        assert isinstance(i, _detail.SymbolVar), "bad var: {}".format(i)
        vec.push_back(i)
    # bug fix: pass the populated _VectorSymbolVar (previously built but
    # unused, with the raw Python list passed instead)
    return json.loads(_detail._get_opr_fp_graph_exec(comp_graph, vec))
def to_mgb_supported_dtype(dtype_):
    """Return the dtype supported by megbrain that is nearest to ``dtype_``."""
    # lowbit / quantized / bfloat16 dtypes are passed through unchanged
    keep_as_is = (
        dtype.is_lowbit(dtype_)
        or dtype.is_quantize(dtype_)
        or dtype.is_bfloat16(dtype_)
    )
    if keep_as_is:
        return dtype_
    return _detail._to_mgb_supported_dtype(dtype_)
def return_free_memory():
    """Release free memory chunks on all devices back to the OS.

    Best-effort: consecutive free chunks are returned to the operating
    system, while small fragments may be kept. Memory currently in use is
    never moved.
    """
    _detail.CompNode._try_coalesce_all_free_memory()
@@ -1,37 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse | |||
import os | |||
import sys | |||
import megengine._internal.mgb as _mgb | |||
try: | |||
from setproctitle import setproctitle | |||
except ImportError: | |||
setproctitle = None | |||
def main():
    """Entry point for the fork-exec callback used by TimedFuncInvoker;
    not meant to be invoked directly by users."""
    parser = argparse.ArgumentParser(
        description="entry point for fork-exec callback in TimedFuncInvoker;"
        " this file should not be used directly by normal user."
    )
    parser.add_argument("user_data")
    parsed = parser.parse_args()
    if setproctitle is not None:
        setproctitle("megbrain:timed_func_exec:ppid={}".format(os.getppid()))
    _mgb._timed_func_exec_cb(parsed.user_data)
    # the callback is expected to take over the process and never return
    raise SystemError("_timed_func_exec_cb returned")
if __name__ == "__main__":
    main()
@@ -1,274 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""tools for graph manipulation""" | |||
import collections | |||
from . import mgb as _mgb | |||
def get_dep_vars(var, var_type=None):
    """Return the :class:`.SymbolVar` dependencies of ``var`` whose owner opr
    type is ``var_type``. If ``var_type`` is None, all types are returned.

    :type var: an instance or iterable of :class:`.SymbolVar`
    :type var_type: ``str`` or an iterable of ``str``
    :rtype: list of :class:`.SymbolVar`
    """
    if isinstance(var, _mgb.SymbolVar):
        var = [var]
    if isinstance(var_type, str):
        var_type = [var_type]
    result = []
    seen = set()
    stack = list(var)
    while stack:
        cur = stack.pop()
        if cur in seen:
            continue
        seen.add(cur)
        stack.extend(get_inputs(cur))
        # keep the var when no filter is given or its opr type matches
        if var_type is None or get_type(cur) in var_type:
            result.append(cur)
    return result
def get_inputs(var):
    """Return the input vars of the operator that owns ``var``.

    :type var: :class:`.SymbolVar`
    :rtype: list of :class:`.SymbolVar`
    """
    assert isinstance(var, _mgb.SymbolVar)
    return _mgb._get_owner_opr_inputs(var)
def get_type(var):
    """Return the type name of the operator that owns ``var``.

    :type var: :class:`.SymbolVar`
    :rtype: ``str``
    """
    assert isinstance(var, _mgb.SymbolVar)
    return _mgb._get_owner_opr_type(var)
def get_opr_type(opr):
    """Return the type name of an operator.

    :type opr: :class:`.Operator`
    :rtype: ``str``
    """
    assert isinstance(opr, _mgb.Operator)
    return _mgb._get_opr_type(opr)
def graph_traversal(outputs):
    """Traverse the computing graph reachable from ``outputs`` and collect
    the book-keeping structures used by :func:`get_oprs_seq`.

    :param outputs: model outputs
    :type outputs: :class:`.Symbolvar`
    :return: tuple ``(map_oprs, map_vars, var2oprs, opr2receivers,
        indegree2opr, opr2indegree)`` where

        * ``map_oprs``: opr_id -> opr
        * ``map_vars``: var_id -> var
        * ``var2oprs``: var_id -> list of ``(consumer opr_id, input index)``
        * ``opr2receivers``: opr_id -> ids of oprs consuming its outputs
        * ``indegree2opr``: in-degree -> set of opr ids with that in-degree
        * ``opr2indegree``: opr_id -> in-degree

        The last two are only used by the topological sort in
        :func:`get_oprs_seq`.
    """
    map_oprs = collections.defaultdict(set)
    map_vars = collections.defaultdict(set)
    var2oprs = collections.defaultdict(list)
    opr2receivers = collections.defaultdict(list)
    indegree2opr = collections.defaultdict(set)
    opr2indegree = {}
    # BFS over the graph, starting from the owner oprs of the outputs
    pending = collections.deque(x.owner_opr for x in outputs)
    visited = set(opr.id for opr in pending)
    while pending:
        cur_opr = pending.popleft()
        map_oprs[cur_opr.id] = cur_opr
        degree = 0
        for input_idx, input_var in enumerate(cur_opr.inputs):
            map_vars[input_var.id] = input_var
            var2oprs[input_var.id].append((cur_opr.id, input_idx))
            producer = input_var.owner_opr
            if producer.id not in visited:
                visited.add(producer.id)
                pending.append(producer)
            # every input edge counts, even from an already-visited producer
            degree += 1
            opr2receivers[producer.id].append(cur_opr.id)
        indegree2opr[degree].add(cur_opr.id)
        opr2indegree[cur_opr.id] = degree
    return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
def get_oprs_seq(outputs, prune_reshape=False):
    """get oprs in some topological order for a dumped model

    :param outputs: model outputs
    :param prune_reshape: whether to prune the operators useless during inference
    :return: opr list with some correct execution order
    """

    def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
        # generate an execution order with topological sort algorithm;
        # mutates indegree2opr/opr2indegree (Kahn's algorithm)
        oprs_seq = []
        nr_remain = len(map_oprs)
        while indegree2opr[0]:
            opr_id = indegree2opr[0].pop()
            opr = map_oprs[opr_id]
            nr_remain -= 1
            # skip const value generation operator
            if get_opr_type(opr) != "ImmutableTensor":
                oprs_seq.append(opr)
            # decrement the in-degree of every receiver, moving it between
            # the indegree buckets accordingly
            for post_id in opr2receivers[opr_id]:
                indegree = opr2indegree[post_id]
                indegree2opr[indegree].remove(post_id)
                indegree -= 1
                indegree2opr[indegree].add(post_id)
                opr2indegree[post_id] = indegree
        assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
            nr_remain
        )
        return oprs_seq

    # reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
    # when inferencing, shape of output_tensor is already known, so one can
    # prune some operators related to dest_shape in the loaded graph
    def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
        def iterative_pruning(cur_opr, post_opr, marked_opr_ids):
            # an opr is useless when none of its (non-workspace) outputs is
            # consumed by anyone after removing the edge to post_opr
            useless = True
            for oup in cur_opr.outputs:
                if "workspace" not in oup.name:
                    var_idx = post_opr.inputs.index(oup)
                    var2oprs[oup.id].remove((post_opr.id, var_idx))
                    useless = useless and (len(var2oprs[oup.id]) == 0)
            if useless:
                marked_opr_ids.append(cur_opr.id)
                # recursively try to prune the producers of this opr
                for inp in cur_opr.inputs:
                    iterative_pruning(inp.owner_opr, cur_opr, marked_opr_ids)

        # start pruning from the dest_shape input (inputs[1]) of each Reshape
        reshape_vars = get_dep_vars(outputs, "Reshape")
        reshape_oprs = [var.owner_opr for var in reshape_vars]
        marked_opr_ids = []
        for reshape_opr in reshape_oprs:
            iterative_pruning(
                reshape_opr.inputs[1].owner_opr, reshape_opr, marked_opr_ids
            )
        # filter out all marked oprs
        return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))

    map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
        outputs
    )
    oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
    if prune_reshape is True:
        # pass a copy since pruning mutates the var2oprs mapping
        oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
    return oprs_seq
def replace_vars(dst, varmap):
    """Replace vars in the graph.

    :param dst: target vars representing the graph
    :type dst: list of :class:`.SymbolVar`
    :param varmap: the map that specifies how to replace the vars
    :type varmap: dict that maps from src var to dst var
    :return: new vars that correspond to ``dst`` with all the dependencies
        replaced
    :rtype: list of :class:`.SymbolVar`
    """
    dst_vec = _mgb._VectorSymbolVar()
    for sym in dst:
        assert isinstance(sym, _mgb.SymbolVar)
        dst_vec.push_back(sym)
    src_vec = _mgb._VectorSymbolVar()
    rep_vec = _mgb._VectorSymbolVar()
    # accept either a dict or an iterable of (src, dst) pairs
    pairs = getattr(varmap, "items", lambda: varmap)
    for old_var, new_var in pairs():
        assert isinstance(old_var, _mgb.SymbolVar)
        assert isinstance(new_var, _mgb.SymbolVar)
        src_vec.push_back(old_var)
        rep_vec.push_back(new_var)
    return _mgb._replace_vars(src_vec, rep_vec, dst_vec)
def replace_oprs(dst, oprmap):
    """Replace operators in the graph; the operator analogue of
    :func:`replace_vars`.

    :param dst: target vars representing the graph
    :type dst: list of :class:`.SymbolVar`
    :param oprmap: the map that specifies how to replace the operators
    :type oprmap: dict that maps from src operator to dst operator
    :return: new vars that correspond to ``dst`` with all the dependencies
        replaced
    :rtype: list of :class:`.SymbolVar`
    """
    dst_vec = _mgb._VectorSymbolVar()
    for sym in dst:
        assert isinstance(sym, _mgb.SymbolVar)
        dst_vec.push_back(sym)
    src_vec = _mgb._VectorOperator()
    rep_vec = _mgb._VectorOperator()
    # accept either a dict or an iterable of (src, dst) pairs
    pairs = getattr(oprmap, "items", lambda: oprmap)
    for old_opr, new_opr in pairs():
        assert isinstance(old_opr, _mgb.Operator)
        assert isinstance(new_opr, _mgb.Operator)
        src_vec.push_back(old_opr)
        rep_vec.push_back(new_opr)
    return _mgb._replace_oprs(src_vec, rep_vec, dst_vec)
def set_priority_to_id(dest_vars):
    """For every opr in the subgraph constructed from ``dest_vars``, set its
    priority to its id if its original priority is zero.

    :param dest_vars: target vars representing the graph
    """
    vec = _mgb._VectorSymbolVar()
    for sym in dest_vars:
        assert isinstance(sym, _mgb.SymbolVar)
        vec.push_back(sym)
    _mgb._set_priority_to_id(vec)
@@ -1,439 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections | |||
import os | |||
from . import mgb as _mgb | |||
# default device type used when integer device ids are passed to
# set_device_map(); updated by set_default_device()/set_default_device_type()
_default_device_type = "CUDA"
def set_device_map(logical_dev, physical_dev, device_type=None):
    """Map from *logical_dev* to *physical_dev* for future comp node loading.

    example::

        set_device_map(0, 2, 'CPU')     # cpu0 -> cpu2
        set_device_map('gpu3', 'gpu0')  # gpu3 -> gpu0

    :param device_type: specify the device type if devices are given by
        integers; if devices are given by integers and ``device_type`` is not
        given, the default value ``'CUDA'`` would be used. Possible values are
        ``'CUDA'`` and ``'CPU'``.
    """
    if device_type is None:
        device_type = _default_device_type
    if device_type == "CUDA":
        prefix = "gpu"
    else:
        assert device_type == "CPU"
        prefix = "cpu"

    def strip_prefix(dev):
        # accept an int directly, or a string like 'gpu3' / 'xpu3'
        if not isinstance(dev, str):
            return dev
        assert dev.startswith(prefix) or dev.startswith("xpu"), (
            "bad comp node in set_device_map: "
            "device_type={} comp_node={}".format(device_type, dev)
        )
        return dev[3:]

    _mgb.CompNode._set_device_map(
        device_type, int(strip_prefix(logical_dev)), int(strip_prefix(physical_dev))
    )
def set_default_device(physical_dev, device_type=None):
    """Set the physical device for xpux.

    When *device_type* is None and *physical_dev* starts with ``gpu`` or
    ``cpu``, the default device type is updated accordingly for future calls
    to :func:`set_device_map` that remap device numbers.
    """
    global _default_device_type
    infer_type = (
        device_type is None
        and isinstance(physical_dev, str)
        and not physical_dev.isdigit()
        and not physical_dev.startswith("xpu")
    )
    if infer_type:
        prefix = physical_dev[:3]
        if prefix == "gpu":
            _default_device_type = "CUDA"
        else:
            assert prefix == "cpu", "bad physical_dev: {}".format(physical_dev)
            _default_device_type = "CPU"
        set_default_device_type(_default_device_type)
        device_type = _default_device_type
    set_device_map(-1, physical_dev, device_type)
def set_default_device_type(device_type):
    """Set the device type used for xpu."""
    global _default_device_type
    normalized = device_type.upper()
    _mgb.CompNode._set_unspec_device_type(normalized)
    _default_device_type = normalized
def set_fork_cuda_warning_flag(flag):
    """Configure the warning printed at fork time if cuda has been initialized.

    :type flag: int
    :param flag: how the warning should be handled:

        * 0: disable warning
        * 1: print warning to log
        * 2: print warning to log and raise exception
    """
    _mgb._config.set_fork_cuda_warning_flag(int(flag))
def get_device_count(device_type="xpu", warn=True):
    """Return the number of devices of ``device_type`` installed on this system.

    :param device_type: device type, one of 'xpu', 'gpu' or 'cpu'
    :type device_type: str
    """
    return _mgb.CompNode._get_device_count(device_type.upper(), warn)
def parse_locator(device_name: str) -> tuple:
    """Translate a device name into its tensor locator expression.

    :param device_name: device name, like 'cpu0', 'gpu1' and 'xpux'
    :type device_name: str
    :return: tuple ``(device_type, dev_num, stream_num)``
    """
    return _mgb.CompNode._parse_locator(device_name)
def set_mem_reserve_size(size):
    """Set the memory reserve size.

    * ``size > 1``: absolute amount of memory to reserve, in MB;
    * ``0 < size < 1``: ratio of total memory to reserve;
    * ``size == 0``: disable memory reservation and pre-allocation;
    * ``size == -1``: disable the custom memory allocator and use cuda APIs
      directly.
    """
    _mgb._config.set_mem_reserve_size(float(size))
def set_comp_graph_option(comp_graph, name, val):
    """set computing graph option and return its old value

    :type comp_graph: :class:`.CompGraph`
    :param comp_graph: the computing graph whose option should be modified
    :type name: str
    :param name: option name
        Currently supported options are:

        * "no_profiling_on_shape_change": bool;
          When execution strategy is set to profiling, always use the
          initial profile result and do not re-run profiling even if input
          shape changes.
        * "seq_opt.enable_mem_plan_opt": bool
        * "seq_opt.enable_mem_reuse_alloc": bool
        * "seq_opt.enable_seq_comp_node_opt": bool
        * "force_dynamic_alloc": bool
        * "var_sanity_check_first_run": bool
        * "enable_sublinear_memory_opt": bool
        * "enable_memory_swap": bool; whether to enable memory swap; it
          usually performs worse than sublinear memory
        * "enable_var_mem_defragment": bool
        * "allocate_static_mem_after_graph_compile": bool
        * "enable_grad_var_static_reshape": bool:
          If set to ``True``, dynamically-shaped gradients whose original
          shape is statically inferrable would be reshaped, so static
          shape inference can continue
        * "async_exec_level": int

          * ``0``: do not dispatch asynchronously
          * ``1``: async dispatch if there are more than 1 cuda comp
            nodes
          * mask ``0b10``: async for comp nodes with unlimited queue
            (e.g. CPU comp nodes)
          * mask ``0b100``: async for even one comp node
        * "log_level": int

          * ``0``: no log info for graph construction/compiling
          * ``1``: static memory allocation status,
            WorkspaceLimitGetter summary, and optimizer summary
          * ``2``: optimizer details and duplicated operators that are
            removed
        * "graph_opt.jit": whether to enable JIT
        * "graph_opt.tensorrt": whether to enable fine-grained automatic
          replacement for TensorRT operators
        * "graph_opt.android_nn": whether to enable fine-grained automatic
          replacement for Android NN operators
        * "graph_opt_level": int

          * ``0``: disable
          * ``1``: level-1: inplace arith transformations during graph
            construction
          * ``2``: (default) level-2: level-1, plus global optimization
            before graph compiling
          * ``3``: also enable JIT
    :param val: new option value
    :return: old option value
    """
    # accept legacy option names for backward compatibility
    if name == "log_static_mem_alloc":
        name = "log_level"
    if name == "enable_async_exec":
        name = "async_exec_level"
    # all supported option values are bool/int, hence the int() coercion
    return _mgb._config.set_comp_graph_option(comp_graph, name, int(val))
def comp_graph_is_eager(comp_graph):
    """Return whether ``comp_graph`` is in eager evaluation mode; delegates
    directly to ``_mgb._config.comp_graph_is_eager``."""
    return _mgb._config.comp_graph_is_eager(comp_graph)
def add_extra_vardep(var, dep):
    """Make *dep* an extra dependency of *var*: if *var* is required to
    compute the final output when compiling a comp graph, *dep* would also be
    included in the computing sequence. Note that the order of computing these
    two vars is not guaranteed.
    """
    # both must be SymbolVars belonging to the same graph
    assert isinstance(var, _mgb.SymbolVar)
    assert isinstance(dep, _mgb.SymbolVar)
    assert var.owner_graph == dep.owner_graph
    return _mgb._config.add_extra_vardep(var, dep)
class _GraphPropertyBase:
    """helper class for implementing operator property setter context managers"""

    # graph this instance applies to; stays None when constructed with None,
    # which turns __enter__/__exit__ into no-ops
    _cur_graph = None

    _graph2stack = None
    """class attribute that maintains mapping from graph to property stack;
    should be defined by child classes"""

    __prop_setup__ = None
    """overwritten by subclass to setup property"""

    __prop_clear__ = None
    """overwritten by subclass to clear property"""

    def __init__(self, comp_graph, prop):
        """:param comp_graph: computing graph, or None to not set this
        property"""
        if comp_graph is not None:
            assert isinstance(
                comp_graph, _mgb.CompGraph
            ), "invalid comp graph: {!r}".format(comp_graph)
            self._cur_graph = comp_graph
            # push this property onto the graph's stack; popped in __exit__
            self._graph2stack.setdefault(comp_graph, []).append(prop)

    def __setup(self, prop):
        self.__prop_setup__(self._cur_graph, prop)

    def __clear(self):
        self.__prop_clear__(self._cur_graph)

    def __enter__(self):
        # no-op when constructed without a graph
        if self._cur_graph is None:
            return
        stack = self._graph2stack[self._cur_graph]
        if len(stack) > 1:
            # clear nested property
            self.__clear()
        self.__setup(stack[-1])

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if self._cur_graph is None:
            return
        stack = self._graph2stack[self._cur_graph]
        self.__clear()
        stack.pop()
        if stack:
            # restore nested property
            self.__setup(stack[-1])
        else:
            # last scope for this graph: drop its stack entirely
            del self._graph2stack[self._cur_graph]
class exc_opr_tracker_scope(_GraphPropertyBase):
    """context manager for associating an object with all operators created
    within this context; so when an exception is raised, information about the
    corresponding operator could be retrieved from
    :attr:`.MegBrainError.tracker`

    :param comp_graph: the computing graph where the operators should be tracked
    :type comp_graph: :class:`.CompGraph`
    :param tracker: an arbitrary python object to track the operators
    """

    _graph2stack = {}

    def __init__(self, comp_graph, tracker):
        # a None tracker would be indistinguishable from "no tracker"
        assert (
            tracker is not None
        ), "bad args for exc_opr_tracker_scope: {!r} {!r}".format(comp_graph, tracker)
        super().__init__(comp_graph, tracker)

    # hook the base-class property protocol up to the native tracker API
    __prop_setup__ = staticmethod(_mgb._config.begin_set_exc_opr_tracker)
    __prop_clear__ = staticmethod(_mgb._config.end_set_exc_opr_tracker)
class opr_priority_scope(_GraphPropertyBase):
    """context manager for setting priority for all operators created in this
    context

    :param comp_graph: the computing graph for which operator priority should
        be set
    :type comp_graph: :class:`.CompGraph`
    :param priority: operator priority. Smaller number means higher priority.
        Default value is 0. Grad operator would use negative priority by
        default.
    """

    _graph2stack = {}

    LOWEST_PRIORITY = 2 ** 31 - 1
    """lowest priority (i.e. max possible value)"""

    HIGHEST_PRIORITY = -LOWEST_PRIORITY
    """highest priority (i.e. min possible value)"""

    def __init__(self, comp_graph, priority):
        super().__init__(comp_graph, int(priority))

    # hook the base-class property protocol up to the native priority API
    __prop_setup__ = staticmethod(_mgb._config.begin_set_opr_priority)
    __prop_clear__ = staticmethod(_mgb._config.end_set_opr_priority)
# result type returned by get_opr_tracker()
OprTrackerResult = collections.namedtuple(
    "OprTrackerResult", "msg tracker grad_tracker"
)
def get_opr_tracker(cg, var_id):
    """Get the tracking object associated with the owner operator of a var.

    :param cg: the computing graph
    :param var_id: id of the var whose owner opr tracker should be found
    :return: ``None`` if no var is found; otherwise an
        :class:`OprTrackerResult` object
    """
    assert isinstance(cg, _mgb.CompGraph)
    raw = _mgb._config.get_opr_tracker(cg, int(var_id))
    return None if raw is None else OprTrackerResult(*raw)
def set_opr_sublinear_memory_endpoint(var):
    """Mark the owner operator of ``var`` as an endpoint of the sublinear
    memory optimizer.

    :type var: :class:`.SymbolVar`
    """
    _mgb._config.set_opr_sublinear_memory_endpoint(var)
def max_size_t():
    """Return the maximum value of the ``size_t`` type on the local
    architecture."""
    return _mgb.max_size_t()
def is_cuda_ctx_set():
    """Return whether the current thread has an active cuda driver context."""
    return _mgb._config.is_cuda_ctx_set()
def get_include_path():
    """Return the include path for building megbrain extensions."""
    pkg_dir = os.path.realpath(os.path.dirname(__file__))
    return os.path.join(pkg_dir, "include")
def get_cuda_gencode(only_cap=False):
    """Return ``-gencode`` options to be passed to nvcc for compiling on the
    local machine.

    :param only_cap: if True, return only a list of cuda compute capability
        strings (like ``['35', '52']``)
    """
    caps = _mgb._config.get_cuda_gencode().split()
    if only_cap:
        return caps
    return " ".join(map("-gencode arch=compute_{0},code=sm_{0}".format, caps))
def get_cuda_lib_path():
    """Return the cuda lib64 path, found by locating nvcc."""
    return _mgb._config.get_cuda_lib_path()
def get_cuda_include_path():
    """Return the cuda include paths found by locating nvcc: the parent path
    and ``<parent path>/include``."""
    return _mgb._config.get_cuda_include_path()
def get_cuda_version():
    """Return the runtime cuda version."""
    return _mgb._config.get_cuda_version()
def is_local_cuda_env_ok():
    """Return whether the local cuda environment is usable, checked by
    locating nvcc."""
    return _mgb._config.is_local_cuda_env_ok()
def is_compiled_with_cuda():
    """Return whether cuda support was enabled at compile time."""
    return _mgb._config.is_compiled_with_cuda()
def load_opr_library(path):
    """Load an external operator library. This essentially sets megbrain
    symbols as public and loads the library.

    :param path: path to the shared object; if it is None, then only megbrain
        symbols are made public.
    """
    self_so = os.path.realpath(
        os.path.join(os.path.dirname(__file__), "_mgb.so")
    )
    _mgb._config.load_opr_library(self_so, path)
def dump_registered_oprs():
    """Return all registered oprs as a ``dict`` mapping id to name."""
    return dict(_mgb._config.dump_registered_oprs())
def create_mm_server(server_addr, port):
    """
    create mm server with server address

    throw exception if server_addr is already used
    """
    cfg = _mgb._config
    return cfg.create_mm_server(server_addr, port)
def group_barrier(server_addr, port, size, rank):
    """
    block until all ranks reach this barrier
    """
    cfg = _mgb._config
    return cfg.group_barrier(server_addr, port, size, rank)
@@ -1,432 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""used for creating a megbrain operator from python""" | |||
import copy | |||
import itertools | |||
from abc import ABCMeta, abstractmethod, abstractproperty | |||
from . import helper as _helper | |||
from . import mgb as _mgb | |||
class _CraniotomeBaseMeta(ABCMeta):
    """metaclass for :class:`CraniotomeBase`: validates every subclass
    definition so that internal methods are not accidentally overwritten and
    misspelled dunder names are caught early"""

    # set to True once CraniotomeBase itself has been created; the checks
    # below only apply to subclasses, not to the base class definition
    _base_created = False

    def __init__(cls, name, bases, member_dict):
        if _CraniotomeBaseMeta._base_created:
            # subclasses customize construction via setup(), never __init__
            assert "__init__" not in member_dict, (
                "Craniotome operators should not overwrite __init__ method; "
                "use setup() instead."
            )
            # names like _hash/_execute are glue to the C++ side and must not
            # be shadowed (single leading underscore, not dunder)
            forbidden = set(
                k for k in dir(CraniotomeBase) if k[0] == "_" and k[1] != "_"
            )
            forbidden.add("get_io_vars")
            # subclasses may opt out of the dunder-name check by defining
            # __check_key__ = False
            check_key = member_dict.get("__check_key__", True)
            whitelist = ["__classcell__"]
            for k in member_dict.keys():
                assert k not in forbidden, "{} could not be overwritten".format(k)
                # a dunder name that the base class does not define is almost
                # certainly a typo (e.g. __nr_output__); fail loudly
                if (
                    check_key
                    and k.startswith("__")
                    and k.endswith("__")
                    and k not in whitelist
                    and not hasattr(CraniotomeBase, k)
                ):
                    raise KeyError(
                        "name {} in class {} does not exist in the baseclass".format(
                            k, name
                        )
                    )
        else:
            # first class created with this metaclass is CraniotomeBase itself
            _CraniotomeBaseMeta._base_created = True
        super().__init__(name, bases, member_dict)
class CraniotomeBase(_mgb.CraniotomeDesc, metaclass=_CraniotomeBaseMeta):
    """base class used for extending megbrain core operators in python

    Note: all names starting and ending with two underscores in the subclasses
    would be checked and KeyError would be raised if the name does not exist in
    the base class. This behavior can be disabled by setting ``__check_key__``
    to ``False`` (see the testcase for more details)
    """

    # methods and attributes to be overwritten by subclasses

    __expand_single_outputs__ = True
    """if :attr:`__nr_outputs__` is 1, whether to return a single
    :class:`.SymbolVar` instead of a tuple in :meth:`make`"""

    __is_dynamic_output_shape__ = False
    """whether output shape could not be inferred from input shape. If value of
    this attribute is ``False``, :meth:`infer_shape` must be implemented. If
    this attribute is ``True`` but the operator has no inputs, then
    :meth:`infer_shape` would also be called to infer output shape before
    operator execution.
    """

    __disable_sys_mem_alloc__ = False
    """whether to disable system memory allocator. This is used when
    :attr:`__is_dynamic_output_shape__` is ``False`` but the output memory
    should not be managed by megbrain system (so it can be forwarded from
    external buffer)"""

    __allow_duplicate__ = True
    """whether this operator can be duplicated (e.g. used in sublinear
    memory)"""

    __allow_empty_out__ = False
    """whether empty output shape is allowed; if it is set as ``False``, then
    an exception would be raised if output var is empty to prevent erroneously
    forgetting initializing output vars"""

    @abstractproperty
    def __nr_inputs__(self):
        """number of input vars"""

    @abstractproperty
    def __nr_outputs__(self):
        """number of output vars"""

    @abstractmethod
    def execute(self, inputs, outputs):
        """execute the operator, read values from *inputs* by calling
        :meth:`.CompGraphCallbackValueProxy.get_value` and write results into
        *outputs* by calling :meth:`.SharedND.set_value`

        :param inputs: values for each input var
        :type inputs: tuple of :class:`.CompGraphCallbackValueProxy`
        :param outputs: values for each output var
        :type outputs: tuple of :class:`.SharedND`
        """

    def setup(self):
        """overwritten by subclass to accept kwargs passed to :meth:`make` to
        setup the operator"""

    def infer_shape(self, inp_shapes):
        """infer output shape from input shapes

        :type inp_shapes: tuple of tuple of ints
        :param inp_shapes: input shapes for each input var
        :rtype: tuple of tuple of ints
        :return: output shapes for each output var
        """
        raise NotImplementedError(
            "{}: infer_shape() not implemented; for operators with dynamic "
            "output shape, __is_dynamic_output_shape__ should be set to True".format(
                self
            )
        )

    def grad(self, wrt_idx, inputs, outputs, out_grad):
        """compute symbolic gradient; should be overwritten by differentiable
        subclasses

        :type wrt_idx: int
        :param wrt_idx: the input var with respect to which the gradient should
            be computed; please also see the notes below
        :type inputs: tuple of :class:`.SymbolVar`
        :param inputs: input symbol vars
        :type outputs: tuple of :class:`.SymbolVar`
        :param outputs: output symbol vars
        :type out_grad: tuple of (:class:`.SymbolVar` or None)
        :param out_grad: gradients of loss with respect to each output var

        .. note::

            In case when loss does not depend on some var (i.e. zero grad),
            the corresponding value in *out_grad* would be ``None``. It is
            guaranteed that at least one element in *out_grad* is not
            ``None``.

        .. note::

            This function can return either of the following:

            1. Gradient of the input specified by ``wrt_idx``
            2. A list containing gradients of all inputs. In this case,
               ``wrt_idx`` can be ignored.

            And the so called gradient can be either one of:

            1. A :class:`.SymbolVar` representing the symbolic gradient
               value
            2. ``0`` representing zero gradient
        """
        raise NotImplementedError("grad for {} not implemented".format(self))

    def init_output_dtype(self, input_dtypes):
        """infer output dtypes from input dtypes; return None to use default
        infer function in megbrain.

        .. note::

            This method must be implemented if there is no input var

        :param input_dtypes: input dtypes
        :type input_dtypes: list of :class:`numpy.dtype`
        :rtype: None or list of :class:`numpy.dtype`-compatible
        """

    def get_serialize_params(self):
        """get params for megbrain graph serialization. This function should
        return a list or tuple, containing one or two elements: the first
        element must be a string, representing the name passed to
        ``opr_loader_maker`` during deserializing; the second element, if
        exists, must be convertible to ``bytes`` and is used for dumping any
        extra opr params, which can be retrieved by ``load_buf_with_len``
        during deserializing.
        """
        raise NotImplementedError(
            "get_serialize_params() for {} not implemented".format(self)
        )

    def copy(self):
        """copy this craniotome descriptor; the default implementation creates
        a new object, and copies object ``__dict__``"""
        ret = type(self)()
        d0 = self.__dict__.copy()
        # "this" is the SWIG handle to the underlying C++ object; it must not
        # be shared between two descriptors
        d0.pop("this")
        ret.__dict__.update(copy.deepcopy(d0))
        return ret

    def on_graph_compiled(self, used_outputs):
        """a callback that would be invoked when the graph is compiled; it
        would always have a matching :meth:`on_compiled_func_deleted` call

        :param used_outputs: indices of outputs that are needed for the
            computation
        :type used_outputs: ``tuple of int``
        """

    def on_compiled_func_deleted(self):
        """a callback that would be invoked when the compiled function is
        destructed; it would always have a matching :meth:`on_graph_compiled`
        call"""

    def get_io_vars(self):
        """get input vars, comp order dep vars and output vars

        :return: a dict with keys ``'input'``, ``'output'`` and
            ``'comp_order'`` that maps to corresponding list of vars
        """
        all_vars = list(self._get_all_io_vars())
        nr_inp = self.__nr_inputs__
        nr_comp_order = self._get_nr_dev_comp_order_deps()
        s0 = nr_inp + nr_comp_order
        return dict(
            input=all_vars[:nr_inp],
            comp_order=all_vars[nr_inp:s0],
            output=all_vars[s0:],
        )

    @property
    def owner_opr_id(self):
        """ID of the operator that owns this descriptor"""
        return self._get_opr_id()

    @property
    def comp_node(self):
        """comp node on which this operator runs"""
        return self._get_comp_node()

    # below are methods that should not be changed

    def _hash(self):
        # clamp python hash into an unsigned 64-bit value for the C++ side
        return int(hash(self)) % (1 << 64)

    def _setup_self(self, dst):
        dst.append(self)

    def _is_same(self, rhs):
        return bool(self == rhs)

    def _node_flag(self):
        # pack the boolean class attributes into the bit mask expected by the
        # C++ CraniotomeDesc
        return (
            (int(bool(self.__is_dynamic_output_shape__)) << 0)
            | (int(not self.__allow_duplicate__) << 1)
            | (int(bool(self.__allow_empty_out__)) << 2)
            | (int(bool(self.__disable_sys_mem_alloc__)) << 3)
        )

    def _get_opr_type_name(self):
        return str(self.__class__.__name__)

    def _get_nr_outputs(self):
        return int(self.__nr_outputs__)

    def _execute(self, inputs, outputs):
        inputs = tuple(inputs)
        outputs = tuple(outputs)
        if not self.__is_dynamic_output_shape__:
            # static-shape operators must not change output shapes during
            # execute(); record shapes to verify afterwards
            out_shapes = [i.shape for i in outputs]
        self.execute(inputs, outputs)
        if not self.__is_dynamic_output_shape__:
            new_shapes = [i.shape for i in outputs]
            assert (
                out_shapes == new_shapes
            ), "output shape changed after executing {}: before={} after={}".format(
                self, out_shapes, new_shapes
            )

    def _infer_shape(self, inp_shapes):
        inp_shapes = tuple(tuple(map(int, i)) for i in inp_shapes)
        oshp_get = self.infer_shape(inp_shapes)
        assert (
            len(oshp_get) == self.__nr_outputs__
        ), "{}: expect {} outputs; got {}(val: {}) from infer_shape".format(
            self, self.__nr_outputs__, len(oshp_get), oshp_get
        )
        return _helper.cvt_to_vector_of_shape(oshp_get)

    def _grad(self, wrt_idx, inputs, outputs, out_grad):
        # map invalid grad vars (loss does not depend on that output) to None
        og = []
        for i in out_grad:
            if i.valid:
                og.append(i)
            else:
                og.append(None)
        rst = self.grad(int(wrt_idx), tuple(inputs), tuple(outputs), tuple(og))
        if not isinstance(rst, (list, tuple)):
            rst = [rst]
        else:
            assert len(rst) == len(
                inputs
            ), "{}: opr has {} inputs but {} grads are returned".format(
                self, len(inputs), len(rst)
            )
            # grad() may legally return a tuple; copy to a mutable list since
            # zero grads are replaced in-place below (the old code would raise
            # TypeError on item assignment for tuples)
            rst = list(rst)
        for i in range(len(rst)):
            cur = rst[i]
            # ``0`` is the documented shorthand for zero gradient; test for an
            # int zero explicitly instead of ``cur is 0`` (identity test on a
            # literal relies on CPython int caching and is a SyntaxWarning
            # since Python 3.8); ``==`` alone is wrong since SymbolVar
            # overloads comparison
            if type(cur) is int and cur == 0:
                rst[i] = _mgb.SymbolVar()
            else:
                assert isinstance(cur, _mgb.SymbolVar), (
                    "{}: invalid grad result; it should be either "
                    "0 or a SymbolVar, got {!r} instead".format(self, cur)
                )
        return rst

    def _get_nr_dev_comp_order_deps(self):
        # overwritten per-instance in make() when dev_comp_order_deps is given
        return 0

    def _init_output_dtype(self, input_dtypes, ret):
        get = self.init_output_dtype(input_dtypes)
        if get is not None:
            assert isinstance(ret, (list, tuple)) and len(get) == len(ret)
            ret[:] = get
            return True
        assert self.__nr_inputs__, (
            "{}: init_output_dtype must be implemented "
            "if there is no input var".format(self)
        )
        return False

    def _setup_serialize_params(self, output):
        val = list(self.get_serialize_params())
        assert len(val) in [1, 2]
        name = val[0]
        assert isinstance(name, str)
        output.append(name)
        if len(val) == 2:
            output.append(bytes(val[1]))

    def _copy(self):
        ret = self.copy()
        assert type(ret) is type(
            self
        ), "copy() returned different type: src={} copied={}".format(
            type(self), type(ret)
        )
        assert ret is not self
        # ownership of the copy is transferred to the C++ side
        ret.__disown__()
        self._set_copy_result(ret)

    def _on_graph_compile_or_func_del(self, used_outputs):
        if used_outputs:
            self.on_graph_compiled(used_outputs)
        else:
            self.on_compiled_func_deleted()

    def __repr__(self):
        # note: the prefix was historically misspelled as "cranoiotome"
        return "craniotome:{}".format(self.__class__.__name__)

    @classmethod
    def make(
        cls,
        *inputs,
        comp_graph=None,
        name=None,
        comp_node=None,
        config=None,
        dev_comp_order_deps=None,
        **kwargs
    ):
        """apply this operator on some input vars and return corresponding
        output vars

        :type inputs: tuple of :class:`.SymbolVar`
        :param inputs: input symvars; immediate values could also be accepted,
            as long as there is symvar to infer comp node and comp graph
        :param comp_graph: if there is no input vars, *comp_graph* must be
            provided to specify which computing graph to insert this operator
        :param dev_comp_order_deps: vars that must have been computed
            before executing this operator
        :param kwargs: extra keyword arguments to be passed to :meth:`setup` of
            this class
        :param name: name of the resulting operator
        :rtype: tuple of :class:`.SymbolVar`
        :return: output symvars
        """
        # None sentinel replaces the old mutable-[] default; behavior for
        # callers is unchanged
        if dev_comp_order_deps is None:
            dev_comp_order_deps = []
        if not inputs and not dev_comp_order_deps:
            # note: this message used to reference the undefined name ``self``
            # and would raise NameError instead of AssertionError
            assert isinstance(
                comp_graph, _mgb.CompGraph
            ), "{}: comp_graph must be given if no inputs provided".format(cls)
        desc = cls()
        desc.setup(**kwargs)
        assert (
            len(inputs) == desc.__nr_inputs__
        ), "{}: expected {} inputs, got {}".format(
            desc, desc.__nr_inputs__, len(inputs)
        )
        config = _helper.gen_config(name, comp_node, config)
        # canonize inputs together with comp-order deps into a SWIG vector
        inp_vec = _mgb._VectorSymbolVar()
        for i in _helper.canonize_input_vars(
            itertools.chain(inputs, dev_comp_order_deps),
            comp_graph=comp_graph,
            config=config,
        ):
            inp_vec.push_back(i)
        # bind the number of comp-order deps on this instance; the keyword-only
        # default captures the value instead of a late-bound closure
        desc._get_nr_dev_comp_order_deps = lambda *, val=len(dev_comp_order_deps): val
        if comp_graph is not None:
            desc._get_comp_graph = lambda: comp_graph
        expand_single_outputs = desc.__expand_single_outputs__
        # the C++ side takes ownership of desc from here on
        desc.__disown__()
        rst = _mgb.make_opr_from_craniotome_desc(desc, inp_vec, config)
        if expand_single_outputs and len(rst) == 1:
            return rst[0]
        return tuple(rst)
def make_opr(cls):
    """decorator used to wrap a :class:`.CraniotomeBase` subclass and return
    its :meth:`~.CraniotomeBase.make` method
    """
    ok = issubclass(cls, CraniotomeBase)
    assert ok
    return cls.make
@@ -1,286 +0,0 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections | |||
from typing import Union | |||
import numpy as np | |||
from .mgb import bfloat16, intb1, intb2, intb4 | |||
_QuantDtypeMetadata = collections.namedtuple( | |||
"QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",] | |||
) | |||
_metadata_dict = { | |||
"quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255), | |||
"qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127), | |||
"quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15), | |||
"qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7), | |||
"qint32": _QuantDtypeMetadata( | |||
"QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1, | |||
), | |||
# NOTE: int2 is not supported for model dump yet | |||
"quint2": _QuantDtypeMetadata(None, "uint8", True, 0, 3), | |||
"qint2": _QuantDtypeMetadata(None, "int8", False, -2, 1), | |||
} | |||
def is_quantize(dtype):
    """whether ``dtype`` carries megbrain quantization metadata
    (an ``"mgb_dtype"`` entry in ``dtype.metadata``)"""
    meta = getattr(dtype, "metadata", None)
    return meta is not None and "mgb_dtype" in meta
def is_lowbit(dtype):
    """whether ``dtype`` is one of the megbrain low-bit int dtype objects"""
    # identity comparison is intentional: these are singleton dtype objects
    return any(dtype is lowbit for lowbit in (intb1, intb2, intb4))
def is_bfloat16(dtype):
    """whether ``dtype`` is the megbrain bfloat16 dtype object
    (identity comparison against the singleton imported from ``.mgb``)"""
    return dtype is bfloat16
def get_scale(dtype):
    """return the ``scale`` stored in a quantized dtype's metadata"""
    assert is_quantize(dtype)
    info = dtype.metadata["mgb_dtype"]
    return info["scale"]
def get_zero_point(dtype):
    """return the ``zero_point`` stored in an asymmetric quantized dtype's
    metadata; only asymmetric (unsigned) dtypes carry a zero point"""
    assert is_quantize(dtype)
    info = dtype.metadata["mgb_dtype"]
    assert info["name"] in ("Quantized8Asymm", "Quantized4Asymm")
    return info["zero_point"]
def _check_zero_point(zp: int, dtype_str: str):
    """raise ValueError unless ``zp`` lies within the quantized range of
    ``dtype_str``"""
    meta = _metadata_dict[dtype_str]
    if not (meta.qmin <= zp <= meta.qmax):
        raise ValueError(
            "zero_point should be within [{}, {}] for {}".format(
                meta.qmin, meta.qmax, dtype_str
            )
        )
def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]):
    r"""
    Get quantized dtype with metadata attribute according to _metadata_dict.

    Note that unsigned dtype must have ``zero_point`` and signed dtype must
    not have ``zero_point``, to be consistent with tensor generated by calling
    compiled function from `CompGraph.compile(inputs, outspec)`.

    :param dtype_str: a string indicating which dtype to return
    :param scale: a number for scale to store in dtype's metadata
    :param zp: a number for zero_point to store in dtype's metadata
    """
    meta = _metadata_dict[dtype_str]
    mgb_info = {"name": meta.name, "scale": float(scale)}
    if meta.is_unsigned:
        # asymmetric dtypes require an integral zero point inside the range
        if zp is None or int(zp) != zp:
            raise ValueError("zero_point should be an integer")
        zp = int(zp)
        _check_zero_point(zp, dtype_str)
        mgb_info["zero_point"] = zp
    # a provided zp is deliberately ignored for signed (symmetric) dtypes
    return np.dtype(meta.np_dtype_str, metadata={"mgb_dtype": mgb_info})
def quint8(scale, zero_point):
    """
    Construct a quantized unsigned int8 data type with ``scale`` (float) and
    ``zero_point`` (uint8). The real value represented by a quint8 data type is

    float_val = scale * (uint8_val - zero_point)
    """
    return get_quantized_dtype("quint8", scale, zero_point)
def qint8(scale):
    """
    Construct a quantized int8 data type with ``scale`` (float). The real value
    represented by a qint8 data type is

    float_val = scale * int8_val
    """
    return get_quantized_dtype("qint8", scale, None)
def qint32(scale):
    """
    Construct a quantized int32 data type with ``scale`` (float). The real value
    represented by a qint32 data type is

    float_val = scale * int32_val
    """
    return get_quantized_dtype("qint32", scale, None)
def quint4(scale, zero_point):
    """
    Construct a quantized unsigned int4 data type with ``scale`` (float) and
    ``zero_point`` (uint8). The real value represented by a quint4 data type is

    float_val = scale * (uint4_val - zero_point)
    """
    return get_quantized_dtype("quint4", scale, zero_point)
def qint4(scale):
    """
    Construct a quantized int4 data type with ``scale`` (float). The real value
    represented by a qint4 data type is

    float_val = scale * int4_val
    """
    return get_quantized_dtype("qint4", scale, None)
def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
    """quantize a float ndarray into ``dtype`` using the scale (and, for
    asymmetric dtypes, zero point) stored in the dtype's metadata

    :param arr: input float ndarray
    :param dtype: target quantized dtype; must carry ``mgb_dtype`` metadata
        whose name matches ``dtype_str``
    :param dtype_str: key into ``_metadata_dict`` naming the expected dtype
    :raises ValueError: if ``arr`` is not an ndarray or ``dtype`` is not the
        expected quantized dtype
    """
    metadata = _metadata_dict[dtype_str]
    if not isinstance(arr, np.ndarray):
        raise ValueError("arr parameter should be instance of np.ndarray")
    # validate before dereferencing dtype.metadata["mgb_dtype"]; the old code
    # read the metadata first and raised TypeError on plain dtypes instead of
    # the intended ValueError
    if not is_quantize(dtype):
        raise ValueError("dtype parameter should be a {} dtype".format(dtype_str))
    arr_metadata = dtype.metadata["mgb_dtype"]
    if arr_metadata["name"] != metadata.name:
        raise ValueError("dtype parameter should be a {} dtype".format(dtype_str))
    is_unsigned = metadata.is_unsigned
    if is_unsigned:
        scale, zp = (
            arr_metadata["scale"],
            arr_metadata["zero_point"],
        )
        return (
            (np.round(arr / scale) + zp)
            .clip(metadata.qmin, metadata.qmax)
            .astype(dtype)
        )
    else:
        # don't trick to combine with is_unsigned, seeing ``get_quantized_dtype``
        scale = arr_metadata["scale"]
        return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype)
def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str):
    """dequantize ``arr`` (whose dtype must be the quantized dtype named by
    ``dtype_str``) into a float32 ndarray

    :param arr: input quantized ndarray
    :param dtype_str: key into ``_metadata_dict`` naming the expected dtype
    :raises ValueError: if ``arr`` is not an ndarray or its dtype is not the
        expected quantized dtype
    """
    metadata = _metadata_dict[dtype_str]
    if not isinstance(arr, np.ndarray):
        raise ValueError("arr parameter should be instance of np.ndarray")
    # validate before dereferencing arr.dtype.metadata["mgb_dtype"]; the old
    # code read the metadata first and raised TypeError on plain dtypes
    # instead of the intended ValueError
    if not is_quantize(arr.dtype):
        raise ValueError("arr's dtype should be a {} dtype".format(dtype_str))
    arr_metadata = arr.dtype.metadata["mgb_dtype"]
    if arr_metadata["name"] != metadata.name:
        raise ValueError("arr's dtype should be a {} dtype".format(dtype_str))
    is_unsigned = metadata.is_unsigned
    if is_unsigned:
        scale, zp = (
            arr_metadata["scale"],
            arr_metadata["zero_point"],
        )
        return (arr.astype(np.float32) - zp) * scale
    else:
        # don't trick to combine with is_unsigned, seeing ``get_quantized_dtype``
        scale = arr_metadata["scale"]
        return (arr.astype(np.float32)) * scale
def convert_to_quint8(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a quint8 one with specified params.

    :param arr: Input ndarray.
    :param q: Target data type, should be a quint8.
    :return: quantized ndarray with dtype ``q``.
    """
    return _convert_to_quantized_dtype(arr, q, "quint8")
def convert_from_quint8(arr: np.ndarray):
    """
    Dequantize a quint8 NumPy ndarray into a float one.

    :param arr: Input ndarray.
    :return: float32 ndarray of dequantized values.
    """
    return _convert_from_quantized_dtype(arr, "quint8")
def convert_to_qint8(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a qint8 one with specified params.

    :param arr: Input ndarray.
    :param q: Target data type, should be a qint8.
    :return: quantized ndarray with dtype ``q``.
    """
    return _convert_to_quantized_dtype(arr, q, "qint8")
def convert_from_qint8(arr: np.ndarray):
    """
    Dequantize a qint8 NumPy ndarray into a float one.

    :param arr: Input ndarray.
    :return: float32 ndarray of dequantized values.
    """
    return _convert_from_quantized_dtype(arr, "qint8")
def convert_to_qint32(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a qint32 one with specified params.

    :param arr: Input ndarray.
    :param q: Target data type, should be a qint32.
    :return: quantized ndarray with dtype ``q``.
    """
    return _convert_to_quantized_dtype(arr, q, "qint32")
def convert_from_qint32(arr):
    """
    Dequantize a qint32 NumPy ndarray into a float one.

    :param arr: Input ndarray.
    :return: float32 ndarray of dequantized values.
    """
    return _convert_from_quantized_dtype(arr, "qint32")
def convert_to_quint4(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a quint4 one with specified params.

    :param arr: Input ndarray.
    :param q: Target data type, should be a quint4.
    :return: quantized ndarray with dtype ``q``.
    """
    return _convert_to_quantized_dtype(arr, q, "quint4")
def convert_from_quint4(arr: np.ndarray):
    """
    Dequantize a quint4 NumPy ndarray into a float one.

    :param arr: Input ndarray.
    :return: float32 ndarray of dequantized values.
    """
    return _convert_from_quantized_dtype(arr, "quint4")
def convert_to_qint4(arr: np.ndarray, q: np.dtype):
    """
    Quantize a float NumPy ndarray into a qint4 one with specified params.

    :param arr: Input ndarray.
    :param q: Target data type, should be a qint4.
    :return: quantized ndarray with dtype ``q``.
    """
    return _convert_to_quantized_dtype(arr, q, "qint4")
def convert_from_qint4(arr: np.ndarray):
    """
    Dequantize a qint4 NumPy ndarray into a float one.

    :param arr: Input ndarray.
    :return: float32 ndarray of dequantized values.
    """
    return _convert_from_quantized_dtype(arr, "qint4")
@@ -1,947 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# Copyright [2001] [Cython] | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# --------------------------------------------------------------------- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# | |||
# This file has been modified by Megvii ("Megvii Modifications"). | |||
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
# ---------------------------------------------------------------------- | |||
import sys | |||
from functools import reduce | |||
from operator import or_ as _or_ | |||
from types import DynamicClassAttribute, MappingProxyType | |||
# try _collections first to reduce startup cost | |||
try: | |||
from _collections import OrderedDict | |||
except ImportError: | |||
from collections import OrderedDict | |||
# public API of this vendored enum backport
__all__ = [
    "EnumMeta", "Enum", "IntEnum", "Flag", "IntFlag", "auto", "unique",
]
def _is_descriptor(obj): | |||
"""Returns True if obj is a descriptor, False otherwise.""" | |||
return ( | |||
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") | |||
) | |||
def _is_dunder(name): | |||
"""Returns True if a __dunder__ name, False otherwise.""" | |||
return ( | |||
name[:2] == name[-2:] == "__" | |||
and name[2:3] != "_" | |||
and name[-3:-2] != "_" | |||
and len(name) > 4 | |||
) | |||
def _is_sunder(name): | |||
"""Returns True if a _sunder_ name, False otherwise.""" | |||
return ( | |||
name[0] == name[-1] == "_" | |||
and name[1:2] != "_" | |||
and name[-2:-1] != "_" | |||
and len(name) > 2 | |||
) | |||
def _make_class_unpicklable(cls): | |||
"""Make the given class un-picklable.""" | |||
def _break_on_call_reduce(self, proto): | |||
raise TypeError("%r cannot be pickled" % self) | |||
cls.__reduce_ex__ = _break_on_call_reduce | |||
cls.__module__ = "<unknown>" | |||
_auto_null = object() | |||
class auto: | |||
""" | |||
Instances are replaced with an appropriate value in Enum class suites. | |||
""" | |||
value = _auto_null | |||
class _EnumDict(dict):
    """Track enum member order and ensure member names are not reused.

    EnumMeta will use the names found in self._member_names as the
    enumeration member names.
    """

    def __init__(self):
        super().__init__()
        # member names in definition order; read by EnumMeta.__new__
        self._member_names = []
        # values assigned so far, passed to _generate_next_value_ for auto()
        self._last_values = []

    def __setitem__(self, key, value):
        """Changes anything not dundered or not a descriptor.

        If an enum member name is used twice, an error is raised; duplicate
        values are not checked for.

        Single underscore (sunder) names are reserved.
        """
        if _is_sunder(key):
            # only these sunder names are allowed in an Enum class body
            if key not in (
                "_order_",
                "_create_pseudo_member_",
                "_generate_next_value_",
                "_missing_",
            ):
                raise ValueError("_names_ are reserved for future Enum use")
            if key == "_generate_next_value_":
                # stash it on the dict itself so auto() values below can use it
                setattr(self, "_generate_next_value", value)
        elif _is_dunder(key):
            # legacy spelling of the member-order declaration
            if key == "__order__":
                key = "_order_"
        elif key in self._member_names:
            # descriptor overwriting an enum?
            raise TypeError("Attempted to reuse key: %r" % key)
        elif not _is_descriptor(value):
            if key in self:
                # enum overwriting a descriptor?
                raise TypeError("%r already defined as: %r" % (key, self[key]))
            if isinstance(value, auto):
                # materialize auto() using the generator captured above
                if value.value == _auto_null:
                    value.value = self._generate_next_value(
                        key, 1, len(self._member_names), self._last_values[:]
                    )
                value = value.value
            self._member_names.append(key)
            self._last_values.append(value)
        super().__setitem__(key, value)
# Dummy value for Enum as EnumMeta explicitly checks for it, but of course
# until EnumMeta finishes running the first time the Enum class doesn't exist.
# This is also why there are checks in EnumMeta like `if Enum is not None`.
# The real Enum class defined below rebinds this name.
Enum = None
class EnumMeta(type):
    """Metaclass for Enum"""

    @classmethod
    def __prepare__(metacls, cls, bases):
        # create the namespace dict (an _EnumDict tracks member definitions)
        enum_dict = _EnumDict()
        # inherit previous flags and _generate_next_value_ function
        member_type, first_enum = metacls._get_mixins_(bases)
        if first_enum is not None:
            enum_dict["_generate_next_value_"] = getattr(
                first_enum, "_generate_next_value_", None
            )
        return enum_dict

    def __new__(metacls, cls, bases, classdict):
        # an Enum class is final once enumeration items have been defined; it
        # cannot be mixed with other types (int, float, etc.) if it has an
        # inherited __new__ unless a new __new__ is defined (or the resulting
        # class will fail).
        member_type, first_enum = metacls._get_mixins_(bases)
        __new__, save_new, use_args = metacls._find_new_(
            classdict, member_type, first_enum
        )
        # save enum items into separate mapping so they don't get baked into
        # the new class
        enum_members = {k: classdict[k] for k in classdict._member_names}
        for name in classdict._member_names:
            del classdict[name]
        # adjust the sunders
        _order_ = classdict.pop("_order_", None)
        # check for illegal enum names (any others?)
        invalid_names = set(enum_members) & {
            "mro",
        }
        if invalid_names:
            raise ValueError(
                "Invalid enum member name: {0}".format(",".join(invalid_names))
            )
        # create a default docstring if one has not been provided
        if "__doc__" not in classdict:
            classdict["__doc__"] = "An enumeration."
        # create our new Enum type
        enum_class = super().__new__(metacls, cls, bases, classdict)
        enum_class._member_names_ = []  # names in definition order
        enum_class._member_map_ = OrderedDict()  # name->value map
        enum_class._member_type_ = member_type
        # save attributes from super classes so we know if we can take
        # the shortcut of storing members in the class dict
        base_attributes = {a for b in enum_class.mro() for a in b.__dict__}
        # Reverse value->name map for hashable values.
        enum_class._value2member_map_ = {}
        # If a custom type is mixed into the Enum, and it does not know how
        # to pickle itself, pickle.dumps will succeed but pickle.loads will
        # fail.  Rather than have the error show up later and possibly far
        # from the source, sabotage the pickle protocol for this class so
        # that pickle.dumps also fails.
        #
        # However, if the new class implements its own __reduce_ex__, do not
        # sabotage -- it's on them to make sure it works correctly.  We use
        # __reduce_ex__ instead of any of the others as it is preferred by
        # pickle over __reduce__, and it handles all pickle protocols.
        if "__reduce_ex__" not in classdict:
            if member_type is not object:
                methods = (
                    "__getnewargs_ex__",
                    "__getnewargs__",
                    "__reduce_ex__",
                    "__reduce__",
                )
                if not any(m in member_type.__dict__ for m in methods):
                    _make_class_unpicklable(enum_class)
        # instantiate them, checking for duplicates as we go
        # we instantiate first instead of checking for duplicates first in case
        # a custom __new__ is doing something funky with the values -- such as
        # auto-numbering ;)
        for member_name in classdict._member_names:
            value = enum_members[member_name]
            if not isinstance(value, tuple):
                args = (value,)
            else:
                args = value
            if member_type is tuple:  # special case for tuple enums
                args = (args,)  # wrap it one more time
            if not use_args:
                enum_member = __new__(enum_class)
                if not hasattr(enum_member, "_value_"):
                    enum_member._value_ = value
            else:
                enum_member = __new__(enum_class, *args)
                if not hasattr(enum_member, "_value_"):
                    if member_type is object:
                        enum_member._value_ = value
                    else:
                        enum_member._value_ = member_type(*args)
            value = enum_member._value_
            enum_member._name_ = member_name
            enum_member.__objclass__ = enum_class
            enum_member.__init__(*args)
            # If another member with the same value was already defined, the
            # new member becomes an alias to the existing one.
            for name, canonical_member in enum_class._member_map_.items():
                if canonical_member._value_ == enum_member._value_:
                    enum_member = canonical_member
                    break
            else:
                # Aliases don't appear in member names (only in __members__).
                enum_class._member_names_.append(member_name)
            # performance boost for any member that would not shadow
            # a DynamicClassAttribute
            if member_name not in base_attributes:
                setattr(enum_class, member_name, enum_member)
            # now add to _member_map_
            enum_class._member_map_[member_name] = enum_member
            try:
                # This may fail if value is not hashable. We can't add the value
                # to the map, and by-value lookups for this value will be
                # linear.
                enum_class._value2member_map_[value] = enum_member
            except TypeError:
                pass
        # double check that repr and friends are not the mixin's or various
        # things break (such as pickle)
        for name in ("__repr__", "__str__", "__format__", "__reduce_ex__"):
            class_method = getattr(enum_class, name)
            obj_method = getattr(member_type, name, None)
            enum_method = getattr(first_enum, name, None)
            if obj_method is not None and obj_method is class_method:
                setattr(enum_class, name, enum_method)
        # replace any other __new__ with our own (as long as Enum is not None,
        # anyway) -- again, this is to support pickle
        if Enum is not None:
            # if the user defined their own __new__, save it before it gets
            # clobbered in case they subclass later
            if save_new:
                enum_class.__new_member__ = __new__
            enum_class.__new__ = Enum.__new__
        # py3 support for definition order (helps keep py2/py3 code in sync)
        if _order_ is not None:
            if isinstance(_order_, str):
                _order_ = _order_.replace(",", " ").split()
            if _order_ != enum_class._member_names_:
                raise TypeError("member order does not match _order_")
        return enum_class

    def __bool__(self):
        """
        classes/types should always be True.
        """
        return True

    def __call__(
        cls, value, names=None, *, module=None, qualname=None, type=None, start=1
    ):
        """Either returns an existing member, or creates a new enum class.
        This method is used both when an enum class is given a value to match
        to an enumeration member (i.e. Color(3)) and for the functional API
        (i.e. Color = Enum('Color', names='RED GREEN BLUE')).
        When used for the functional API:
        `value` will be the name of the new class.
        `names` should be either a string of white-space/comma delimited names
        (values will start at `start`), or an iterator/mapping of name, value pairs.
        `module` should be set to the module this class is being created in;
        if it is not set, an attempt to find that module will be made, but if
        it fails the class will not be picklable.
        `qualname` should be set to the actual location this class can be found
        at in its module; by default it is set to the global scope.  If this is
        not correct, unpickling will fail in some circumstances.
        `type`, if set, will be mixed in as the first base class.
        """
        if names is None:  # simple value lookup
            return cls.__new__(cls, value)
        # otherwise, functional API: we're creating a new Enum type
        return cls._create_(
            value, names, module=module, qualname=qualname, type=type, start=start
        )

    def __contains__(cls, member):
        # membership means "is an actual member instance of this enum class"
        return isinstance(member, cls) and member._name_ in cls._member_map_

    def __delattr__(cls, attr):
        # nicer error message when someone tries to delete an attribute
        # (see issue19025).
        if attr in cls._member_map_:
            raise AttributeError("%s: cannot delete Enum member." % cls.__name__)
        super().__delattr__(attr)

    def __dir__(self):
        # expose only the dunder basics plus the canonical member names
        return [
            "__class__",
            "__doc__",
            "__members__",
            "__module__",
        ] + self._member_names_

    def __getattr__(cls, name):
        """Return the enum member matching `name`
        We use __getattr__ instead of descriptors or inserting into the enum
        class' __dict__ in order to support `name` and `value` being both
        properties for enum members (which live in the class' __dict__) and
        enum members themselves.
        """
        if _is_dunder(name):
            raise AttributeError(name)
        try:
            return cls._member_map_[name]
        except KeyError:
            raise AttributeError(name) from None

    def __getitem__(cls, name):
        # name-based lookup, e.g. Color['RED']; raises KeyError if missing
        return cls._member_map_[name]

    def __iter__(cls):
        # iterate canonical members in definition order (aliases excluded)
        return (cls._member_map_[name] for name in cls._member_names_)

    def __len__(cls):
        # number of canonical (non-alias) members
        return len(cls._member_names_)

    @property
    def __members__(cls):
        """Returns a mapping of member name->value.
        This mapping lists all enum members, including aliases. Note that this
        is a read-only view of the internal mapping.
        """
        return MappingProxyType(cls._member_map_)

    def __repr__(cls):
        return "<enum %r>" % cls.__name__

    def __reversed__(cls):
        # canonical members, reverse definition order
        return (cls._member_map_[name] for name in reversed(cls._member_names_))

    def __setattr__(cls, name, value):
        """Block attempts to reassign Enum members.
        A simple assignment to the class namespace only changes one of the
        several possible ways to get an Enum member from the Enum class,
        resulting in an inconsistent Enumeration.
        """
        member_map = cls.__dict__.get("_member_map_", {})
        if name in member_map:
            raise AttributeError("Cannot reassign members.")
        super().__setattr__(name, value)

    def _create_(
        cls, class_name, names=None, *, module=None, qualname=None, type=None, start=1
    ):
        """Convenience method to create a new Enum class.
        `names` can be:
        * A string containing member names, separated either with spaces or
          commas.  Values are incremented by 1 from `start`.
        * An iterable of member names.  Values are incremented by 1 from `start`.
        * An iterable of (member name, value) pairs.
        * A mapping of member name -> value pairs.
        """
        metacls = cls.__class__
        bases = (cls,) if type is None else (type, cls)
        _, first_enum = cls._get_mixins_(bases)
        classdict = metacls.__prepare__(class_name, bases)
        # special processing needed for names?
        if isinstance(names, str):
            names = names.replace(",", " ").split()
        if isinstance(names, (tuple, list)) and names and isinstance(names[0], str):
            # bare names: synthesize values via _generate_next_value_
            original_names, names = names, []
            last_values = []
            for count, name in enumerate(original_names):
                value = first_enum._generate_next_value_(
                    name, start, count, last_values[:]
                )
                last_values.append(value)
                names.append((name, value))
        # Here, names is either an iterable of (name, value) or a mapping.
        for item in names:
            if isinstance(item, str):
                member_name, member_value = item, names[item]
            else:
                member_name, member_value = item
            classdict[member_name] = member_value
        enum_class = metacls.__new__(metacls, class_name, bases, classdict)
        # TODO: replace the frame hack if a blessed way to know the calling
        # module is ever developed
        if module is None:
            try:
                module = sys._getframe(2).f_globals["__name__"]
            except (AttributeError, ValueError) as exc:
                pass
        if module is None:
            # without a module the class cannot be located for unpickling
            _make_class_unpicklable(enum_class)
        else:
            enum_class.__module__ = module
        if qualname is not None:
            enum_class.__qualname__ = qualname
        return enum_class

    @staticmethod
    def _get_mixins_(bases):
        """Returns the type for creating enum members, and the first inherited
        enum class.
        bases: the tuple of bases that was given to __new__
        """
        if not bases:
            return object, Enum
        # double check that we are not subclassing a class with existing
        # enumeration members; while we're at it, see if any other data
        # type has been mixed in so we can use the correct __new__
        member_type = first_enum = None
        for base in bases:
            if base is not Enum and issubclass(base, Enum) and base._member_names_:
                raise TypeError("Cannot extend enumerations")
        # base is now the last base in bases
        if not issubclass(base, Enum):
            raise TypeError(
                "new enumerations must be created as "
                "`ClassName([mixin_type,] enum_type)`"
            )
        # get correct mix-in type (either mix-in type of Enum subclass, or
        # first base if last base is Enum)
        if not issubclass(bases[0], Enum):
            member_type = bases[0]  # first data type
            first_enum = bases[-1]  # enum type
        else:
            for base in bases[0].__mro__:
                # most common: (IntEnum, int, Enum, object)
                # possible:    (<Enum 'AutoIntEnum'>, <Enum 'IntEnum'>,
                #               <class 'int'>, <Enum 'Enum'>,
                #               <class 'object'>)
                if issubclass(base, Enum):
                    if first_enum is None:
                        first_enum = base
                else:
                    if member_type is None:
                        member_type = base
        return member_type, first_enum

    @staticmethod
    def _find_new_(classdict, member_type, first_enum):
        """Returns the __new__ to be used for creating the enum members.
        classdict: the class dictionary given to __new__
        member_type: the data type whose __new__ will be used by default
        first_enum: enumeration to check for an overriding __new__
        """
        # now find the correct __new__, checking to see of one was defined
        # by the user; also check earlier enum classes in case a __new__ was
        # saved as __new_member__
        __new__ = classdict.get("__new__", None)
        # should __new__ be saved as __new_member__ later?
        save_new = __new__ is not None
        if __new__ is None:
            # check all possibles for __new_member__ before falling back to
            # __new__
            for method in ("__new_member__", "__new__"):
                for possible in (member_type, first_enum):
                    target = getattr(possible, method, None)
                    if target not in {
                        None,
                        None.__new__,
                        object.__new__,
                        Enum.__new__,
                    }:
                        __new__ = target
                        break
                if __new__ is not None:
                    break
            else:
                __new__ = object.__new__
        # if a non-object.__new__ is used then whatever value/tuple was
        # assigned to the enum member name will be passed to __new__ and to the
        # new enum member's __init__
        if __new__ is object.__new__:
            use_args = False
        else:
            use_args = True
        return __new__, save_new, use_args
class Enum(metaclass=EnumMeta):
    """Generic enumeration.
    Derive from this class to define new enumerations.
    """

    def __new__(cls, value):
        # all enum instances are actually created during class construction
        # without calling this method; this method is called by the metaclass'
        # __call__ (i.e. Color(3) ), and by pickle
        if type(value) is cls:
            # For lookups like Color(Color.RED)
            return value
        # by-value search for a matching enum member
        # see if it's in the reverse mapping (for hashable values)
        try:
            if value in cls._value2member_map_:
                return cls._value2member_map_[value]
        except TypeError:
            # not there, now do long search -- O(n) behavior
            for member in cls._member_map_.values():
                if member._value_ == value:
                    return member
        # still not found -- try _missing_ hook
        return cls._missing_(value)

    def _generate_next_value_(name, start, count, last_values):
        # default auto() behavior: previous value + 1; falls back to `start`
        # when no prior value supports addition
        for last_value in reversed(last_values):
            try:
                return last_value + 1
            except TypeError:
                pass
        else:
            return start

    @classmethod
    def _missing_(cls, value):
        # subclass hook for mapping unknown values to members; the default
        # rejects the value outright
        raise ValueError("%r is not a valid %s" % (value, cls.__name__))

    def __repr__(self):
        return "<%s.%s: %r>" % (self.__class__.__name__, self._name_, self._value_)

    def __str__(self):
        return "%s.%s" % (self.__class__.__name__, self._name_)

    def __dir__(self):
        # public names added by subclasses, excluding member names
        added_behavior = [
            m
            for cls in self.__class__.mro()
            for m in cls.__dict__
            if m[0] != "_" and m not in self._member_map_
        ]
        return ["__class__", "__doc__", "__module__"] + added_behavior

    def __format__(self, format_spec):
        # mixed-in Enums should use the mixed-in type's __format__, otherwise
        # we can get strange results with the Enum name showing up instead of
        # the value
        # pure Enum branch
        if self._member_type_ is object:
            cls = str
            val = str(self)
        # mix-in branch
        else:
            cls = self._member_type_
            val = self._value_
        return cls.__format__(val, format_spec)

    def __hash__(self):
        return hash(self._name_)

    def __reduce_ex__(self, proto):
        # pickle by value so unpickling goes through the by-value lookup
        return self.__class__, (self._value_,)

    # DynamicClassAttribute is used to provide access to the `name` and
    # `value` properties of enum members while keeping some measure of
    # protection from modification, while still allowing for an enumeration
    # to have members named `name` and `value`.  This works because enumeration
    # members are not set directly on the enum class -- __getattr__ is
    # used to look them up.
    @DynamicClassAttribute
    def name(self):
        """The name of the Enum member."""
        return self._name_

    @DynamicClassAttribute
    def value(self):
        """The value of the Enum member."""
        return self._value_

    @classmethod
    def _convert(cls, name, module, filter, source=None):
        """
        Create a new Enum subclass that replaces a collection of global constants
        """
        # convert all constants from source (or module) that pass filter() to
        # a new Enum called name, and export the enum and its members back to
        # module;
        # also, replace the __reduce_ex__ method so unpickling works in
        # previous Python versions
        module_globals = vars(sys.modules[module])
        if source:
            source = vars(source)
        else:
            source = module_globals
        # We use an OrderedDict of sorted source keys so that the
        # _value2member_map is populated in the same order every time
        # for a consistent reverse mapping of number to name when there
        # are multiple names for the same number rather than varying
        # between runs due to hash randomization of the module dictionary.
        members = [(name, source[name]) for name in source.keys() if filter(name)]
        try:
            # sort by value
            members.sort(key=lambda t: (t[1], t[0]))
        except TypeError:
            # unless some values aren't comparable, in which case sort by name
            members.sort(key=lambda t: t[0])
        cls = cls(name, members, module=module)
        cls.__reduce_ex__ = _reduce_ex_by_name
        module_globals.update(cls.__members__)
        module_globals[name] = cls
        return cls
class IntEnum(int, Enum):
    """Enum where members are also (and must be) ints"""
def _reduce_ex_by_name(self, proto): | |||
return self.name | |||
class Flag(Enum):
    """Support for flags"""

    def _generate_next_value_(name, start, count, last_values):
        """
        Generate the next value when not given.
        name: the name of the member
        start: the initial start value or None
        count: the number of existing members
        last_values: the list of values assigned so far
        """
        if not count:
            # first member: use `start`, defaulting to 1
            return start if start is not None else 1
        for last_value in reversed(last_values):
            try:
                high_bit = _high_bit(last_value)
                break
            except Exception:
                raise TypeError("Invalid Flag value: %r" % last_value) from None
        # next unused power of two
        return 2 ** (high_bit + 1)

    @classmethod
    def _missing_(cls, value):
        # a negative value is interpreted as the inverse of its complement
        original_value = value
        if value < 0:
            value = ~value
        possible_member = cls._create_pseudo_member_(value)
        if original_value < 0:
            possible_member = ~possible_member
        return possible_member

    @classmethod
    def _create_pseudo_member_(cls, value):
        """
        Create a composite member iff value contains only members.
        """
        pseudo_member = cls._value2member_map_.get(value, None)
        if pseudo_member is None:
            # verify all bits are accounted for
            _, extra_flags = _decompose(cls, value)
            if extra_flags:
                raise ValueError("%r is not a valid %s" % (value, cls.__name__))
            # construct a singleton enum pseudo-member
            pseudo_member = object.__new__(cls)
            pseudo_member._name_ = None
            pseudo_member._value_ = value
            # use setdefault in case another thread already created a composite
            # with this value
            pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
        return pseudo_member

    def __contains__(self, other):
        # only flags of the same class are comparable
        if not isinstance(other, self.__class__):
            return NotImplemented
        return other._value_ & self._value_ == other._value_

    def __repr__(self):
        cls = self.__class__
        if self._name_ is not None:
            return "<%s.%s: %r>" % (cls.__name__, self._name_, self._value_)
        # unnamed composite: show the named flags it decomposes into
        members, uncovered = _decompose(cls, self._value_)
        return "<%s.%s: %r>" % (
            cls.__name__,
            "|".join([str(m._name_ or m._value_) for m in members]),
            self._value_,
        )

    def __str__(self):
        cls = self.__class__
        if self._name_ is not None:
            return "%s.%s" % (cls.__name__, self._name_)
        members, uncovered = _decompose(cls, self._value_)
        if len(members) == 1 and members[0]._name_ is None:
            return "%s.%r" % (cls.__name__, members[0]._value_)
        else:
            return "%s.%s" % (
                cls.__name__,
                "|".join([str(m._name_ or m._value_) for m in members]),
            )

    def __bool__(self):
        # a flag is falsy iff no bits are set
        return bool(self._value_)

    def __or__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__class__(self._value_ | other._value_)

    def __and__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__class__(self._value_ & other._value_)

    def __xor__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__class__(self._value_ ^ other._value_)

    def __invert__(self):
        # inverse is the union of all members that do not overlap this value
        members, uncovered = _decompose(self.__class__, self._value_)
        inverted_members = [
            m
            for m in self.__class__
            if m not in members and not m._value_ & self._value_
        ]
        inverted = reduce(_or_, inverted_members, self.__class__(0))
        return self.__class__(inverted)
class IntFlag(int, Flag):
    """Support for integer-based Flags"""

    @classmethod
    def _missing_(cls, value):
        # unlike Flag, any int is accepted; unnamed bits are kept as-is
        if not isinstance(value, int):
            raise ValueError("%r is not a valid %s" % (value, cls.__name__))
        new_member = cls._create_pseudo_member_(value)
        return new_member

    @classmethod
    def _create_pseudo_member_(cls, value):
        # build (and cache) pseudo-members for the composite value and for
        # every individual unnamed bit it contains
        pseudo_member = cls._value2member_map_.get(value, None)
        if pseudo_member is None:
            need_to_create = [value]
            # get unaccounted for bits
            _, extra_flags = _decompose(cls, value)
            while extra_flags:
                bit = _high_bit(extra_flags)
                flag_value = 2 ** bit
                if (
                    flag_value not in cls._value2member_map_
                    and flag_value not in need_to_create
                ):
                    need_to_create.append(flag_value)
                if extra_flags == -flag_value:
                    # clearing the last bit of a negative value would loop
                    # forever via ^; zero it directly
                    extra_flags = 0
                else:
                    extra_flags ^= flag_value
            for value in reversed(need_to_create):
                # construct singleton pseudo-members
                pseudo_member = int.__new__(cls, value)
                pseudo_member._name_ = None
                pseudo_member._value_ = value
                # use setdefault in case another thread already created a composite
                # with this value
                pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
        return pseudo_member

    def __or__(self, other):
        if not isinstance(other, (self.__class__, int)):
            return NotImplemented
        result = self.__class__(self._value_ | self.__class__(other)._value_)
        return result

    def __and__(self, other):
        if not isinstance(other, (self.__class__, int)):
            return NotImplemented
        return self.__class__(self._value_ & self.__class__(other)._value_)

    def __xor__(self, other):
        if not isinstance(other, (self.__class__, int)):
            return NotImplemented
        return self.__class__(self._value_ ^ self.__class__(other)._value_)

    # reflected operators accept plain ints on the left-hand side
    __ror__ = __or__
    __rand__ = __and__
    __rxor__ = __xor__

    def __invert__(self):
        result = self.__class__(~self._value_)
        return result
def _high_bit(value): | |||
"""returns index of highest bit, or -1 if value is zero or negative""" | |||
return value.bit_length() - 1 | |||
def unique(enumeration):
    """Class decorator for enumerations ensuring unique member values.

    Raises ValueError listing every alias if two members share a value;
    otherwise returns the enumeration unchanged.
    """
    # an alias is any entry whose key differs from its canonical member name
    aliases = [
        (name, member.name)
        for name, member in enumeration.__members__.items()
        if name != member.name
    ]
    if aliases:
        details = ", ".join("%s -> %s" % pair for pair in aliases)
        raise ValueError(
            "duplicate values found in %r: %s" % (enumeration, details)
        )
    return enumeration
def _decompose(flag, value):
    """Extract all members from the value.

    Returns (members, not_covered): the flag members whose bits are set in
    `value` (largest value first) and the leftover bits no member covers.
    """
    # _decompose is only called if the value is not named
    not_covered = value
    negative = value < 0
    # issue29167: wrap accesses to _value2member_map_ in a list to avoid race
    # conditions between iterating over it and having more pseudo-
    # members added to it
    if negative:
        # only check for named flags
        flags_to_check = [
            (m, v)
            for v, m in list(flag._value2member_map_.items())
            if m.name is not None
        ]
    else:
        # check for named flags and powers-of-two flags
        flags_to_check = [
            (m, v)
            for v, m in list(flag._value2member_map_.items())
            if m.name is not None or _power_of_two(v)
        ]
    members = []
    for member, member_value in flags_to_check:
        # a member matches when all of its bits are present in value
        if member_value and member_value & value == member_value:
            members.append(member)
            not_covered &= ~member_value
    if not members and value in flag._value2member_map_:
        members.append(flag._value2member_map_[value])
    members.sort(key=lambda m: m._value_, reverse=True)
    if len(members) > 1 and members[0].value == value:
        # we have the breakdown, don't need the value member itself
        members.pop(0)
    return members, not_covered
def _power_of_two(value):
    """Return True iff value is a positive power of two (1, 2, 4, ...)."""
    if value < 1:
        return False
    # a positive power of two has exactly one bit set
    return value & (value - 1) == 0
@@ -1,58 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
"""exception handling""" | |||
from . import mgb as _mgb | |||
class MegBrainError(Exception):
    """Exception class used by the megbrain library."""

    # tracker set up by :func:`.set_exc_opr_tracker` when the related
    # operator is created
    tracker = None

    # if this operator was created by taking a gradient, this holds the
    # tracker of the operator that caused the grad
    tracker_grad_orig = None

    def __init__(self, msg, tracker, tracker_grad_orig):
        assert isinstance(msg, str)
        super().__init__(msg, tracker, tracker_grad_orig)
        self.tracker = tracker
        self.tracker_grad_orig = tracker_grad_orig

    @classmethod
    def _format_tracker(cls, tracker):
        # prefix every line of the tracker's string form with a "| " gutter
        for line in str(tracker).split("\n"):
            yield "| " + line

    def __str__(self):
        out = list(self.args[0].split("\n"))
        if self.tracker is not None:
            out.append("Exception tracker:")
            out.extend(self._format_tracker(self.tracker))
        if self.tracker_grad_orig is not None:
            out.append(
                "Exception caused by taking grad of another operator with tracker:"
            )
            out.extend(self._format_tracker(self.tracker_grad_orig))
        # trim trailing blank lines
        while not out[-1].strip():
            out.pop()
        # mark the first backtrace line with "+ " and gutter everything below
        for idx, line in enumerate(out):
            if line.startswith("bt:"):
                out[idx] = "+ " + line
                for rest in range(idx + 1, len(out)):
                    out[rest] = "| " + out[rest]
                break
        return "\n".join(out)
_mgb._reg_exception_class(MegBrainError) |
@@ -1,41 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
"""global initialization work; classes/functions defined in this module should | |||
not be used by user code""" | |||
import atexit | |||
import os | |||
import sys | |||
import traceback | |||
from . import mgb | |||
from .logconf import get_logger | |||
from .persistent_cache import PersistentCacheOnServer | |||
class PyStackExtracterImpl(mgb._PyStackExtracter):
    # callback used by native code to capture the current Python stack
    def extract(self):
        # drop the last frame, which is this extract() call itself
        return "".join(traceback.format_stack()[:-1])

mgb._register_logger(get_logger())
# a valid interpreter path is required to fork-exec timed functions
assert sys.executable
mgb._timed_func_set_fork_exec_path(
    sys.executable,
    os.path.join(os.path.dirname(__file__), "_timed_func_fork_exec_entry.py"),
)
persistent_cache_impl_ins = PersistentCacheOnServer()
mgb._PersistentCache.reg(persistent_cache_impl_ins)
PyStackExtracterImplIns = PyStackExtracterImpl()
PyStackExtracterImpl.reg(PyStackExtracterImplIns)
# release native resources at interpreter shutdown
atexit.register(mgb._mgb_global_finalize)
@@ -1,316 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
import numpy as np | |||
from . import mgb | |||
from .exc import MegBrainError | |||
from .mgb import SharedND, SymbolVar | |||
from .opr_param_defs import OptionalAxisV1 | |||
def canonize_reshape(inputs, *, comp_graph, config):
    # normalize a (src, target_shape) pair: convert target_shape into a
    # shape-describing SymbolVar associated with src's graph/comp node
    src, tshape = inputs
    tshape = cvt_to_shape_desc(tshape, src, comp_graph, config)
    return src, tshape
def canonize_shape_input(inputs, *, comp_graph, config):
    # single-input variant: convert the sole input to a shape descriptor var;
    # no reference var is available, so graph/config must identify the target
    assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
    return [cvt_to_shape_desc(inputs[0], None, comp_graph, config)]
def cvt_to_shape_desc(val, inpvar, graph, config):
    """convert some python object to a :class:`SymbolVar` that describes tensor
    shape

    :param val: the python object to be converted from; may be a SymbolVar, an
        object exposing ``__mgb_symvar__``/``symvar``, an int, or an iterable
        mixing ints and symbolic vars
    :param inpvar, graph, config: provide graph and comp node information; can
        be None if not known. Either inpvar or (graph, config) must be provided.
    :return: a new var corresponding to *val*
    :rtype: :class:`.SymbolVar`
    """
    # fix: collections.Iterable was deprecated since 3.3 and removed in
    # Python 3.10; collections.abc.Iterable is the supported location
    from collections.abc import Iterable

    if hasattr(val, "__mgb_symvar__"):
        val = val.__mgb_symvar__()
    elif hasattr(val, "symvar"):
        val = val.symvar
    if isinstance(val, SymbolVar):
        return val
    if not isinstance(val, Iterable):
        val = [val]
    components = []
    has_sym = False
    # normalize every element to either an int or a SymbolVar
    for i in val:
        if hasattr(i, "__mgb_symvar__"):
            i = i.__mgb_symvar__()
        elif hasattr(i, "symvar"):
            i = i.symvar
        if isinstance(i, SymbolVar):
            has_sym = True
            components.append(i)
        else:
            assert isinstance(i, int), (
                "shape desc could contain either int or SymbolVar, got {}"
                " actually".format(repr(i))
            )
            components.append(i)
    assert components, "shape desc could not be empty"

    if inpvar is not None:
        # borrow graph and comp node from the input var
        assert isinstance(inpvar, SymbolVar)
        if graph is None:
            graph = inpvar.owner_graph
        else:
            assert graph == inpvar.owner_graph
        config = mgb.make_opr_config(comp_node=inpvar.comp_node)
    else:
        assert isinstance(graph, mgb.CompGraph), "graph must be provided"
        assert isinstance(config, mgb.OperatorNodeConfig)

    if not has_sym:
        # fully static shape: build a single immutable int32 tensor
        shape = np.ascontiguousarray(components, dtype=np.int32)
        assert np.all(shape == components), "failed to convert to shape: {}".format(
            components
        )
        return mgb._make_immutable(graph, shape, None, config)

    # mixed static/symbolic: lift ints to immutable scalars and concat
    for idx, v in enumerate(components):
        if not isinstance(v, SymbolVar):
            vi = int(v)
            assert vi == v, "could not convert {} to int".format(v)
            components[idx] = mgb._make_immutable(graph, vi, None, config)
    from . import opr as O

    return O.concat(components, axis=0, config=config)
def canonize_input_vars(inputs, *, comp_graph, config):
    """convert immediate numbers and SharedND to SymbolVar in inputs; at least
    one of the inputs must be SymbolVar, so comp node and comp graph can
    be inferred

    :return: list of converted vars
    """
    from . import make_immutable

    if (
        isinstance(inputs, (list, tuple))
        and len(inputs) == 1
        and isinstance(inputs[0], (list, tuple))
    ):
        # handle the case when a list is passed to a function with
        # variable-length argument (e.g. concat has signature concat(*inputs)
        # and is called with concat([a, b]))
        inputs = inputs[0]

    if isinstance(inputs, SymbolVar):
        return [inputs]

    old_inputs = inputs
    inputs = []
    get_comp_node = None
    need_cvt = False
    # first pass: find a SymbolVar to borrow comp node / comp graph from, and
    # detect whether any element needs conversion at all
    for i in old_inputs:
        if isinstance(i, SymbolVar):
            # bind the comp node eagerly via default-arg so later rebinding
            # of ``i`` does not change the captured value
            get_comp_node = lambda cn=i.comp_node: cn
            if comp_graph is not None:
                assert comp_graph == i.owner_graph
            else:
                comp_graph = i.owner_graph
        else:
            need_cvt = True
        inputs.append(i)
    if not need_cvt:
        return inputs

    if get_comp_node is None:
        # no SymbolVar among the inputs: lazily query the comp node from
        # config on first use, then memoize it
        def get_comp_node():
            nonlocal get_comp_node
            cn = config.require_comp_node()
            get_comp_node = lambda: cn
            return cn

    # second pass: convert every non-SymbolVar element in place
    for idx, var in enumerate(inputs):
        if not isinstance(var, SymbolVar):
            if isinstance(var, SharedND):
                var = var.symvar(comp_graph)
            elif isinstance(var, mgb.SharedScalar):
                var = var._as_sym_var(comp_graph, get_comp_node())
            elif hasattr(var, "__mgb_symvar__"):
                try:
                    cn = get_comp_node()
                except MegBrainError:
                    # comp node unknown; let the object decide its own
                    cn = None
                var = var.__mgb_symvar__(comp_graph=comp_graph, comp_node=cn)
            elif hasattr(var, "symvar"):
                var = var.symvar
            else:
                # immediate value (number / ndarray-like)
                var = make_immutable(get_comp_node(), comp_graph, var)
            inputs[idx] = var
    return inputs
def cvt_to_vector_of_shape(shapes):
    """convert ``[[int]]`` to nested ``std::vector`` of ``size_t``"""
    vec = mgb._VectorTensorShape()
    for shape in shapes:
        val = tuple(shape)
        # every dim must be a positive int and the shape non-empty
        good = bool(val) and all(j > 0 and isinstance(j, int) for j in val)
        assert good, "something returns bad shape in infer_shape(): {}".format(val)
        vec.push_back(val)
    return vec
def cvt_to_opr_param_def(param, ptype, kwargs):
    """coerce *param* into an instance of the opr param type *ptype*

    when *param* is None, the param is built from matching entries popped
    out of *kwargs*; otherwise *param* is returned as-is when it already is
    a *ptype*, or wrapped as the sole positional argument
    """
    if param is None:
        # collect keyword values matching the param fields, removing them
        # from kwargs so the caller sees only leftovers
        fields = {}
        for slot in ptype.__slots__:
            if slot in kwargs:
                fields[slot] = kwargs.pop(slot)
        return ptype(**fields)

    if isinstance(param, ptype):
        return param
    wrapped = [param]
    assert len(wrapped) == len(
        ptype.__slots__
    ), "{} needs {} params, but {} are provided".format(
        ptype, len(ptype.__slots__), len(wrapped)
    )
    return ptype(*wrapped)
def cvt_getitem_to_idx_desc(inpvar, tuple_val, *, allow_newaxis=True):
    """convert ``__getitem__`` args to index desc

    :return: ``(new_var, index_desc)`` where new_var is inpvar with
        ``np.newaxis`` applied; note that ``index_desc`` can be ``None``.
    """
    assert isinstance(inpvar, SymbolVar), "bad input: {!r}".format(inpvar)
    if not isinstance(tuple_val, tuple):
        tuple_val = (tuple_val,)

    axis_indexer = mgb._VectorAxisIndexer()
    config = mgb.make_opr_config(comp_node=inpvar.comp_node)
    graph = inpvar.owner_graph

    def as_symvar(v, *, allow_list=True):
        # convert an index item to a SymbolVar, requiring integral values
        if isinstance(v, SymbolVar):
            return v
        vi = np.ascontiguousarray(v, dtype=np.int32)
        assert np.abs(vi - v).max() == 0, "bad index: {!r}".format(v)
        return mgb._make_immutable(graph, vi, None, config)

    def _s(v):  # convert slice item
        if v is None:
            # an invalid (default) SymbolVar marks an unspecified bound
            return SymbolVar()
        return as_symvar(v, allow_list=False)

    new_axes = []
    cur_axis = -1
    for i_idx, i in enumerate(tuple_val):
        cur_axis += 1
        if i is np.newaxis:
            if cur_axis >= 0:
                new_axes.append(cur_axis)
            continue

        if i is Ellipsis:
            # items after the ellipsis index from the back: walk the
            # remaining items in reverse with negative axis numbers
            cur_axis = -1
            for j in tuple_val[:i_idx:-1]:
                if j is Ellipsis:
                    raise IndexError("only one ellipsis is allowed")
                if j is np.newaxis:
                    new_axes.append(cur_axis)
                cur_axis -= 1
            continue

        if isinstance(i, slice):
            if i.start is None and i.stop is None and i.step is None:
                # full slice is a no-op on this axis
                continue
            cur = mgb._AxisIndexer.make_interval(
                cur_axis, _s(i.start), _s(i.stop), _s(i.step)
            )
        else:
            cur = mgb._AxisIndexer.make_index(cur_axis, as_symvar(i))
        axis_indexer.push_back(cur)

    if new_axes:
        if not allow_newaxis:
            raise IndexError("newaxis is not allowed here")
        inpvar = mgb._Opr.add_axis(inpvar, new_axes, mgb.make_opr_config())
    if axis_indexer.empty():
        axis_indexer = None
    return inpvar, axis_indexer
def cvt_to_reshape_unspec_axis(unspec_axis, tshape):
    """resolve the unspecified (-1) axis of a reshape target shape

    scans *tshape* (when it is a concrete shape, not a SymbolVar) for a -1
    entry and records its index in the returned :class:`OptionalAxisV1`
    """
    assert isinstance(unspec_axis, OptionalAxisV1), repr(unspec_axis)
    axis = unspec_axis.axis
    assert abs(axis) <= OptionalAxisV1.MAX_NDIM
    if not isinstance(tshape, SymbolVar):
        for idx, dim in enumerate(tshape):
            if dim == -1:
                # only one dimension may be left for inference
                assert (
                    axis == OptionalAxisV1.INVALID_AXIS
                ), "multiple unknown dimensions for reshape"
                axis = idx
    return OptionalAxisV1(axis)
def gen_config(name, comp_node, config, output_dtype=None):
    """build an :class:`.OperatorNodeConfig` from scalar args, or validate a
    caller-provided one (in which case the scalar args must be unset)"""
    if config is not None:
        assert isinstance(config, mgb.OperatorNodeConfig)
        assert name is None and comp_node is None
        return config
    return mgb.make_opr_config(name, comp_node, output_dtype)
def cvt_opr_result(rst, *, explode_single=True):
    """post-process operator results before handing them to user code

    :param explode_single: whether to return the content of a single-item
        list rather than the list itself
    :return: the processed result; None for invalid vars
    """
    if not isinstance(rst, mgb.SymbolVar):
        # recursively process list/tuple results
        assert isinstance(rst, (list, tuple))
        if len(rst) == 1 and explode_single:
            return cvt_opr_result(rst[0])
        return tuple(map(cvt_opr_result, rst))
    if not rst.valid:
        return None

    # TODO Because the __init__ of SwigObject can not be modified to keep the
    # reference of graph, we get owner graph explicitly here. The correct
    # handling is moving the reference to SwigWrapper, but it is unsupported to
    # add a member variable to SwigWrapper, so we should wrap the SymbolVar
    # manually in megbrain_wrap.h
    rst.owner_graph

    f32 = np.float32
    if not hasattr(cvt_opr_result, "_cvt_to_float32"):
        # read the env var once and cache the decision on the function object
        import os
        from .logconf import get_logger

        cvt_opr_result._cvt_to_float32 = os.getenv("MGB_ALL_FLOAT32")
        if cvt_opr_result._cvt_to_float32:
            # fix: Logger.warn is a deprecated alias; use warning()
            get_logger().warning(
                "\n"
                "+=====================================================+\n"
                "| MGB_ALL_FLOAT32 is set, so all megbrain opr result |\n"
                "| would to converted to float32; this should only be |\n"
                "| used for loading old models. |\n"
                "+=====================================================+"
            )
    if cvt_opr_result._cvt_to_float32 and rst.dtype != f32:
        rst = rst.astype(f32)
    return rst
@@ -1,54 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import logging | |||
import os | |||
_replaced_logger = None  # cached default logger, or a user-substituted one


def get_logger():
    """return the megbrain logger, creating and caching it on first use"""
    global _replaced_logger
    if _replaced_logger is None:
        logger = logging.getLogger("megbrain")
        logger.propagate = False
        logger.setLevel(logging.INFO)
        handler = logging.StreamHandler()
        handler.setFormatter(MgbLogFormatter(datefmt="%d %H:%M:%S"))
        handler.setLevel(0)
        # drop any pre-existing handlers so ours is the only one
        del logger.handlers[:]
        logger.addHandler(handler)
        _replaced_logger = logger
    return _replaced_logger
class MgbLogFormatter(logging.Formatter):
    """colorized formatter used by the default megbrain logger; tags
    DEBUG/WARNING/ERROR records with a colored level marker"""

    def format(self, record):
        prefix = "\x1b[32m[%(asctime)s %(lineno)d@%(filename)s:%(name)s]\x1b[0m"
        body = "%(message)s"
        level_tags = {
            logging.DEBUG: "\x1b[32mDBG\x1b[0m",
            logging.WARNING: "\x1b[1;31mWRN\x1b[0m",
            logging.ERROR: "\x1b[1;4;31mERR\x1b[0m",
        }
        tag = level_tags.get(record.levelno)
        if tag is None:
            fmt = prefix + " " + body
        else:
            fmt = "{} {} {}".format(prefix, tag, body)
        # swap the format string in before delegating to the base class
        self._style._fmt = fmt
        return super().format(record)
def set_logger(logger):
    """replace the logger"""
    global _replaced_logger
    _replaced_logger = logger

    # also propagate the replacement into the native library so C++ log
    # messages are routed through the same logger
    from .mgb import _register_logger

    _register_logger(logger)
@@ -1,87 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
"""helper utils for the core mgb module""" | |||
import collections | |||
import inspect | |||
import json | |||
import threading | |||
from abc import ABCMeta, abstractmethod | |||
class callback_lazycopy:
    """wraps around a callable to be passed to :meth:`.CompGraph.compile`.

    This is used to disable eager copy, so we could get rid of an h2d copy and
    a d2h if values are to be passed from one callback to another
    :class:`.SharedND`.
    """

    def __init__(self, func):
        # fix: collections.Callable was removed in Python 3.10; the builtin
        # callable() check is the portable equivalent
        assert callable(func)
        self.__func = func

    @property
    def func(self):
        """the wrapped callable"""
        return self.__func
class SharedNDLazyInitializer(metaclass=ABCMeta):
    """lazy initialization policy for :class:`.SharedND`

    subclasses supply the shape up-front and defer loading the value until
    it is actually needed
    """

    @abstractmethod
    def get_shape(self):
        """return the tensor shape, without loading the value"""

    @abstractmethod
    def get_value(self):
        """return the tensor value as a numpy ndarray"""
class copy_output:
    """wraps a :class:`.SymbolVar` in outspec for :meth:`.CompGraph.compile`,
    to copy the output to function return value"""

    # the wrapped symbolic variable
    symvar = None
    # whether the returned value may borrow (alias) device memory; see
    # :meth:`.CompGraphCallbackValueProxy.get_value`
    borrow_mem = None

    def __init__(self, symvar, *, borrow_mem=False):
        """
        :param borrow_mem: see :meth:`.CompGraphCallbackValueProxy.get_value`
        """
        from .mgb import SymbolVar

        assert isinstance(
            symvar, SymbolVar
        ), "copy_output expects an SymbolVar, got {} instead".format(symvar)
        self.symvar = symvar
        self.borrow_mem = borrow_mem
class FuncOutputSaver:
    """instance could be used as callbacks for :meth:`.CompGraph.compile` to
    copy output to host buffer
    """

    _value = None       # saved output; None until the callback fires
    _borrow_mem = None  # forwarded to get_value()

    def __init__(self, borrow_mem=False):
        self._borrow_mem = borrow_mem

    def __call__(self, v):
        # invoked by the compiled function with a value proxy
        self._value = v.get_value(borrow_mem=self._borrow_mem)

    def get(self):
        """return the saved value; the callback must have fired already"""
        val = self._value
        assert val is not None, "{} not called; maybe due to unwaited async func".format(
            self
        )
        return val
@@ -1,3 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# Copyright (c) 2015-2019 Megvii Inc. All rights reserved. | |||
@@ -1,90 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import argparse | |||
import getpass | |||
import json | |||
import os | |||
import shelve | |||
from .logconf import get_logger | |||
from .mgb import _PersistentCache | |||
from .version import __version__ | |||
class _FakeRedisConn: | |||
def __init__(self): | |||
try: | |||
from ..hub.hub import _get_megengine_home | |||
cache_dir = os.path.expanduser( | |||
os.path.join(_get_megengine_home(), "persistent_cache") | |||
) | |||
os.makedirs(cache_dir, exist_ok=True) | |||
cache_file = os.path.join(cache_dir, "cache") | |||
self._dict = shelve.open(cache_file) | |||
self._is_shelve = True | |||
except: | |||
self._dict = {} | |||
self._is_shelve = False | |||
def get(self, key): | |||
if self._is_shelve and isinstance(key, bytes): | |||
key = key.decode("utf-8") | |||
return self._dict.get(key) | |||
def set(self, key, val): | |||
if self._is_shelve and isinstance(key, bytes): | |||
key = key.decode("utf-8") | |||
self._dict[key] = val | |||
def __del__(self): | |||
if self._is_shelve: | |||
self._dict.close() | |||
class PersistentCacheOnServer(_PersistentCache):
    """persistent cache implementation backed by a (fake) redis connection"""

    _cached_conn = None
    _prefix = None
    # holds the most recent get() result so the native side can keep
    # borrowing its memory
    _prev_get_refkeep = None

    @property
    def _conn(self):
        """get redis connection, lazily creating it (and the key prefix)"""
        if self._cached_conn is None:
            self._cached_conn = _FakeRedisConn()
            self._prefix = self.make_user_prefix()
        return self._cached_conn

    @classmethod
    def make_user_prefix(cls):
        return "mgbcache:{}".format(getpass.getuser())

    def _make_key(self, category, key):
        prefix_with_version = "{}:MGB{}".format(self._prefix, __version__)
        return b"@".join(
            (prefix_with_version.encode("ascii"), category.encode("ascii"), key)
        )

    def put(self, category, key, value):
        conn = self._conn  # must run first: it initializes self._prefix
        conn.set(self._make_key(category, key), value)

    def get(self, category, key):
        conn = self._conn  # must run first: it initializes self._prefix
        self._prev_get_refkeep = conn.get(self._make_key(category, key))
        return self._prev_get_refkeep
@@ -1,261 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
"""plugins associated with computing graph""" | |||
import atexit | |||
import collections | |||
import json | |||
import os | |||
import platform | |||
import signal | |||
import struct | |||
import numpy as np | |||
from . import mgb as _mgb | |||
from .logconf import get_logger | |||
# record of one input value captured by InfkernFinder: the producing var's
# name and id, the run in which it was recorded, and the value itself
# (numpy array, or None if not recorded)
InfkernFinderInputValueRec = collections.namedtuple(
    "InfkernFinderInputValueRec", ["var_name", "var_id", "run_id", "value"]
)
class CompGraphProfiler(_mgb._CompGraphProfilerImpl):
    """a plugin to profile computing graphs"""

    def __init__(self, comp_graph):
        super().__init__(comp_graph)

    def get(self):
        """get visualizable profiling result on a function"""
        raw = self._get_result()
        return json.loads(raw)

    def write_json(self, fobj):
        """write the result to a json file

        :param fobj: a file-like object, or a string
        """
        if isinstance(fobj, str):
            # reopen by path and recurse with the file object
            with open(fobj, "w") as fout:
                self.write_json(fout)
            return
        fobj.write(self._get_result())
class NumRangeChecker(_mgb._NumRangeCheckerImpl):
    """check that all numberical float values of variables in a computing graph
    are within given range"""

    def __init__(self, comp_graph, max_abs_val):
        """:param max_abs_val: max absolute value"""
        # coerce eagerly so a bad argument fails here rather than inside C++
        super().__init__(comp_graph, float(max_abs_val))
class TextOprIODump(_mgb._TextOprIODumpImpl):
    """dump all internal results as text to a file

    :param print_addr: if not None, forwarded to :meth:`print_addr`
    :param max_size: if not None, forwarded to :meth:`max_size`
    """

    def __init__(self, comp_graph, fpath, *, print_addr=None, max_size=None):
        super().__init__(comp_graph, fpath)
        # both settings are optional; None keeps the native defaults
        if print_addr is not None:
            self.print_addr(print_addr)
        if max_size is not None:
            self.max_size(max_size)

    def print_addr(self, flag):
        """set whether to print var address

        :return: self
        """
        self._print_addr(flag)
        return self

    def max_size(self, size):
        """set the number of elements to be printed for each var

        :return: self
        """
        self._max_size(size)
        return self
class BinaryOprIODump(_mgb._BinaryOprIODumpImpl):
    """dump all internal results binary files to a directory; the values can be
    loaded by :func:`load_tensor_binary`
    """

    def __init__(self, comp_graph, dir_path):
        # dir_path: directory receiving one binary file per dumped value
        super().__init__(comp_graph, dir_path)
class InfkernFinder(_mgb._InfkernFinderImpl):
    """a plugin to find kernels that cause infinite loops"""

    def __init__(self, comp_graph, record_input_value):
        """
        :param record_input_value: whether need to record input var values of
            all operators
        :type record_input_value: bool
        """
        super().__init__(comp_graph, record_input_value)

    def write_to_file(self, fpath):
        """write current execution status to a text file

        :return: ID of the first operator that is still not finished,
            or None if all oprs are finished
        :rtype: int or None
        """
        # the native side returns 0 for "all finished", otherwise opr_id + 1
        raw = self._write_to_file(fpath)
        return None if raw == 0 else raw - 1

    def get_input_values(self, opr_id):
        """get recorded input values of a given operator. Return a list
        of :class:`InfkernFinderInputValueRec`. Note that the value in
        each item is either None (if it is not recorded) or a numpy
        array
        """
        records = []
        for idx in range(self._get_input_values_prepare(opr_id)):
            var_name = self._get_input_values_var_name(idx)
            var_id = self._get_input_values_var_idx(idx)
            run_id = self._get_input_values_run_id(idx)
            raw_val = self._get_input_values_val(idx)
            # an empty shape marks a value that was not recorded
            value = raw_val.get_value() if raw_val.shape else None
            records.append(
                InfkernFinderInputValueRec(var_name, var_id, run_id, value)
            )
        return records
def fast_signal_hander(signum, callback):
    """bypass python's signal handling system and register a handler that is
    called ASAP in a dedicated thread (in contrary, python calls handlers in
    the main thread)

    :param callback: signal callback, taking the signal number as its sole
        argument
    """

    def cb_wrapped():
        try:
            callback(signum)
        except:
            # never let an exception escape into the native caller thread
            get_logger().exception("error calling signal handler for {}".format(signum))

    _mgb._FastSignal.register_handler(signum, cb_wrapped)


# tear down the native signal-handling thread at interpreter exit
atexit.register(_mgb._FastSignal.shutdown)
class GlobalInfkernFinder:
    """
    manage a list of :class:`InfkernFinder` objects; when this process is
    signaled with SIGUSR1, an interactive IPython shell would be presented for
    further investigation
    """

    # signal used to enter the interactive session; SIGUSR1 does not exist
    # on Windows, so CTRL_C_EVENT is used there instead
    _signal = None
    if platform.system() != "Windows":
        _signal = signal.SIGUSR1
    else:
        _signal = signal.CTRL_C_EVENT

    # all registered finders; also keeps their graphs alive
    _registry = []
    # factory for the embedded IPython shell, set on first registration
    _shell_maker = None

    @classmethod
    def add_graph(cls, comp_graph):
        """register a graph so it can be tracked by :class:`InfkernFinder`"""
        # opt-in via env var: "1" tracks oprs only, "2" also records inputs
        enabled = os.getenv("MGB_DBG_INFKERN_FINDER")
        if not enabled:
            return

        if enabled == "1":
            record_input_value = False
        else:
            assert enabled == "2", (
                "MGB_DBG_INFKERN_FINDER must be either 1 or 2, indicating "
                "whether to record input values"
            )
            record_input_value = True

        finder = InfkernFinder(comp_graph, record_input_value)
        get_logger().warning(
            "interactive InfkernFinder {} registered to graph {}; all input "
            "var values would be recorded and the graph would never be "
            "reclaimed. You can enter the interactive debug session by "
            'executing "kill -{} {}". record_input_value={}'.format(
                finder, comp_graph, cls._signal, os.getpid(), record_input_value
            )
        )

        if not cls._registry:
            # first registration: import IPython lazily and hook the signal
            from IPython.terminal.embed import InteractiveShellEmbed

            cls._shell_maker = InteractiveShellEmbed
            fast_signal_hander(cls._signal, cls._on_signal)

        cls._registry.append(finder)

    @classmethod
    def _on_signal(cls, signum):
        # runs on the fast-signal thread; opens an embedded IPython shell
        # with the registered finders bound to variable ``f``
        shell = cls._shell_maker()
        shell(
            header="Enter interactive InfkernFinder session; the registered "
            "finder objects can be found in variable f",
            local_ns={"f": cls._registry},
        )
def load_tensor_binary(fobj):
    """load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual
    tensor value dump is implemented by ``mgb::debug::dump_tensor``.

    Multiple values can be compared by ``tools/compare_binary_iodump.py``.

    :param fobj: file object, or a string that contains the file name
    :return: tuple ``(tensor_value, tensor_name)``
    """
    if isinstance(fobj, str):
        with open(fobj, "rb") as fin:
            return load_tensor_binary(fin)

    # mapping from megdnn dtype enum value to numpy dtype
    DTYPE_LIST = {
        0: np.float32,
        1: np.uint8,
        2: np.int8,
        3: np.int16,
        4: np.int32,
        5: _mgb.intb1,
        6: _mgb.intb2,
        7: _mgb.intb4,
        8: None,  # Byte (untyped) can not be loaded
        9: np.float16,
        # quantized dtype start from 100000
        # see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in
        # dnn/include/megdnn/dtype.h
        100000: np.uint8,
        100001: np.int32,
        100002: np.int8,
    }

    # header: name length, dtype enum, max ndim (all uint32)
    header_fmt = struct.Struct("III")
    name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size))
    assert (
        DTYPE_LIST[dtype] is not None
    ), "Cannot load this tensor: dtype Byte is unsupported."
    shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4)))
    # the dump always writes max_ndim slots, padding unused trailing dims
    # with zero; fix: guard against an empty/all-zero shape so stripping the
    # padding cannot raise IndexError
    while shape and shape[-1] == 0:
        shape.pop(-1)
    name = fobj.read(name_len).decode("ascii")
    return np.fromfile(fobj, dtype=DTYPE_LIST[dtype]).reshape(shape), name
@@ -1,57 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
"""version information for MegBrain package""" | |||
import collections | |||
from . import mgb as _mgb | |||
class Version(
    collections.namedtuple("VersionBase", ["major", "minor", "patch", "dev"])
):
    """simple sematic version object supporting comparison against strings
    ("major.minor.patch"), int iterables and other Version instances"""

    @classmethod
    def __normalize(cls, v):
        """coerce *v* to a :class:`Version`: accepts an existing Version, a
        ``"a.b.c"`` string, or an iterable of 3 or 4 ints"""
        if isinstance(v, cls):
            # fix: a Version rhs used to be unpacked as exactly 3 values and
            # raised ValueError (it has 4 fields), so Version == Version
            # crashed
            return v
        if isinstance(v, str):
            v = v.split(".")
        return cls(*map(int, v))

    def __eq__(self, rhs):
        return super().__eq__(self.__normalize(rhs))

    def __ne__(self, rhs):
        return super().__ne__(self.__normalize(rhs))

    def __lt__(self, rhs):
        return super().__lt__(self.__normalize(rhs))

    def __le__(self, rhs):
        return super().__le__(self.__normalize(rhs))

    def __gt__(self, rhs):
        return super().__gt__(self.__normalize(rhs))

    def __ge__(self, rhs):
        return super().__ge__(self.__normalize(rhs))

    # fix: defining __eq__ implicitly sets __hash__ to None; restore the
    # tuple hash so Version stays usable as a dict key / set member
    __hash__ = tuple.__hash__

    def __str__(self):
        rst = "{}.{}.{}".format(self.major, self.minor, self.patch)
        if self.dev:
            rst += "-dev{}".format(self.dev)
        return rst


Version.__new__.__defaults__ = (0,)  # dev defaults to 0
# query the native library for its (major, minor, patch[, dev]) version
version_info = Version(*_mgb._get_mgb_version())
# PEP 396 style version string, e.g. "8.4.1" or "8.4.1-dev1"
__version__ = str(version_info)
@@ -1,20 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .device import ( | |||
get_default_device, | |||
get_device_count, | |||
is_cuda_available, | |||
set_default_device, | |||
) | |||
from .function import Function | |||
from .graph import Graph, dump | |||
from .serialization import load, save | |||
from .tensor import Tensor, TensorDict, tensor, wrap_io_tensor | |||
from .tensor_factory import ones, zeros | |||
from .tensor_nn import Buffer, Parameter |
@@ -1,60 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import os | |||
import megengine._internal as mgb | |||
# default device string; "xpux" means any available device. Overridable via
# the MGE_DEFAULT_DEVICE env var and set_default_device()
_default_device = os.getenv("MGE_DEFAULT_DEVICE", "xpux")
def get_device_count(device_type: str) -> int:
    """Gets number of devices installed on this system.

    :param device_type: device type, one of 'gpu' or 'cpu'
    :return: number of installed devices of that type
    """
    valid_types = ("cpu", "gpu")
    assert device_type in valid_types, "device must be one of {}".format(valid_types)
    return mgb.config.get_device_count(device_type)
def is_cuda_available() -> bool:
    """Returns whether cuda device is available on this system."""
    gpu_count = mgb.config.get_device_count("gpu", warn=False)
    return gpu_count > 0
def set_default_device(device: str = "xpux"):
    r"""Sets default computing node.

    :param device: default device type. The type can be 'cpu0', 'cpu1', etc.,
        or 'gpu0', 'gpu1', etc., to specify the particular cpu or gpu to use.
        'cpux' and 'gpux' can also be used to specify any number of cpu or gpu
        devices.

        The 'multithread' device type is available for inference; it implements
        multi-threading parallelism at the operator level. For example,
        'multithread4' will compute with 4 threads.

        The default value is 'xpux' to specify any device available.

        It can also be set by the environment variable `MGE_DEFAULT_DEVICE`.
    """
    global _default_device  # pylint: disable=global-statement
    _default_device = device
def get_default_device() -> str:
    r"""Gets default computing node.

    It returns the value set by :func:`~.set_default_device`.
    """
    # reads the module-level _default_device, which is seeded from the
    # MGE_DEFAULT_DEVICE env var at import time
    return _default_device
@@ -1,176 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import copy | |||
from abc import ABCMeta, abstractmethod | |||
from typing import Iterable, Tuple, Union | |||
import megengine._internal as mgb | |||
from .tensor import Tensor | |||
class _OverrideGradientCraniotome(mgb.craniotome.CraniotomeBase):
    # custom operator that forwards its trailing inputs to its outputs and
    # delegates gradient computation to a user-supplied function
    __nr_inputs__ = None
    __nr_outputs__ = None
    __expand_single_outputs__ = False
    __allow_duplicate__ = False

    # user-supplied callable computing gradients w.r.t. the original inputs
    grad_func = None

    def setup(self, nr_inputs, nr_outputs, grad_func):
        # the opr's inputs are the original inputs followed by the outputs,
        # hence nr_inputs + nr_outputs in total
        self.__nr_inputs__ = nr_inputs + nr_outputs
        self.__nr_outputs__ = nr_outputs
        self.grad_func = grad_func

    def infer_shape(self, inp_shapes):
        # outputs mirror the trailing (forwarded) inputs
        return inp_shapes[-self.__nr_outputs__ :]

    def init_output_dtype(self, input_dtypes):
        # dtypes likewise mirror the trailing inputs
        return input_dtypes[-self.__nr_outputs__ :]

    def execute(self, inputs, outputs):
        # identity forward: copy each forwarded input into its output slot
        for ivar, ovar in zip(inputs[-self.__nr_outputs__ :], outputs):
            ovar.set_value(ivar)

    def grad(self, wrt_idx, inputs, outputs, out_grad):
        # TODO: Make sure grad_values really have values in eager mode.
        # Porting to the new imperative engine would solve this, but if it
        # don't happen, EagerEvalManager should be changed.
        grads = self.grad_func(
            *(Tensor(x) if x is not None else None for x in out_grad)
        )
        # a single Tensor / None / literal 0 means a single-input function
        # pylint: disable=literal-comparison
        if isinstance(grads, Tensor) or grads is None or grads is 0:
            grads = (grads,)
        assert (
            len(grads) == self.__nr_inputs__ - self.__nr_outputs__
        ), "Function.backward should return a tuple with len = {}, got {}".format(
            self.__nr_inputs__ - self.__nr_outputs__, len(grads)
        )
        # map missing grads (None or 0) to literal 0 for the native side, and
        # append zeros for the forwarded outputs
        # pylint: disable=literal-comparison
        return (
            list(x._symvar if x is not None and x is not 0 else 0 for x in grads)
            + [0] * self.__nr_outputs__
        )

    def get_serialize_params(self):
        raise NotImplementedError("Serialization of Function is not implemented")
class Function(metaclass=ABCMeta):
    """
    Defines a block of operations with customizable differentiation.

    The computation should be defined in ``forward`` method, with gradient
    computation defined in ``backward`` method.

    Each instance of ``Function`` should be used only once during forwarding.

    Examples:

    .. testcode::

        class Sigmoid(Function):
            def forward(self, x):
                y = 1 / (1 + F.exp(-x))
                self.save_for_backward(y)
                return y

            def backward(self, output_grads):
                (y, ) = self.saved_tensors
                return output_grads * y * (1-y)
    """
    # Set to True after the first __call__; guards against reuse.
    _has_saved_state = False
    # Tensors stashed by save_for_backward() for use in backward().
    saved_tensors = None
    def __init__(self):
        self.saved_tensors = ()
    @abstractmethod
    def forward(self, *inputs: Iterable[Tensor]) -> Union[Tuple[Tensor], Tensor]:
        """
        Applies operations to ``inputs`` and returns results. It must be overriden by all subclasses.
        Users can call :meth:`~.function.Function.save_for_backward` in this method to save tensors.

        :param input: Input tensors.
        :return: A tuple of Tensor or a single Tensor.

        .. note::
            This method should return a tuple of Tensor or a single Tensor representing the output
            of the function.
        """
        raise NotImplementedError
    @abstractmethod
    def backward(
        self, *output_grads: Iterable[Union[Tensor, None]]
    ) -> Union[Tuple[Tensor], Tensor]:
        """
        Compute the gradient of the forward function. It must be overriden by all subclasses.

        :param output_grads: gradients of outputs that are returned by :meth:`~.function.Function.forward`

        .. note::
            In case when some tensors of outputs are not related to loss function, the corresponding
            values in ``output_grads`` would be ``None``.

        .. note::
            This method should return a tuple which containing the gradients of all inputs, in the same order
            as the ``inputs`` argument of :meth:`~.function.Function.forward` . A ``Tensor`` could be returned
            instead if there is only one input. If users want to stop the propagation of some gradients,
            the corresponding returned values should be set ``None`` .
        """
        raise NotImplementedError
    def save_for_backward(self, *tensors: Iterable[Tensor]):
        """
        Saves tensors needed for gradient computation. This method should be called only
        once in :meth:`~.function.Function.forward`, additional calls will replace values saved previously.

        The saved tensors can be accessed through the ``saved_tensors`` attribute.
        """
        self.saved_tensors = tensors
    def __deepcopy__(self, memo):
        """
        Defines how the operator is deeply copied
        """
        cls = self.__class__
        result = cls.__new__(cls)
        # Detach saved_tensors before copying __dict__ so they are shared by
        # reference instead of deep-copied, then restore them on both objects.
        tmp = self.saved_tensors
        self.saved_tensors = None
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        setattr(result, "saved_tensors", tmp)
        self.saved_tensors = tmp
        return result
    def __call__(self, *inputs):
        assert (
            not self._has_saved_state
        ), "A Function instance should not be called multiple times"
        outputs = self.forward(*inputs)
        if isinstance(outputs, Tensor):
            outputs = (outputs,)
        self._has_saved_state = True
        # Feed inputs followed by forward outputs into the craniotome so
        # that backward() can be dispatched during gradient computation.
        sv = (x._symvar for x in inputs + outputs)
        outputs = _OverrideGradientCraniotome.make(
            *sv, nr_inputs=len(inputs), nr_outputs=len(outputs), grad_func=self.backward
        )
        outputs = tuple(map(Tensor, outputs))
        if len(outputs) == 1:
            outputs = outputs[0]
        return outputs
@@ -1,158 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
import threading | |||
import megengine._internal as mgb | |||
from .device import get_default_device | |||
class _DefaultGraph(threading.local): | |||
r""" | |||
An implicit thread-local graph | |||
""" | |||
def __init__(self): | |||
super(_DefaultGraph, self).__init__() | |||
self._default_graph = None | |||
def get_default(self): | |||
r"""Returns a default Graph object for eager evaluation. | |||
""" | |||
if self._default_graph is None: | |||
self._default_graph = Graph() | |||
return self._default_graph | |||
_default_graph = _DefaultGraph() | |||
class Graph(mgb.CompGraph):
    r"""
    A computing graph that supports context management.

    :param check_env_var: whether to check environment vars including ``MGB_COMP_GRAPH_OPT``.
    :param eager_evaluation: use dynamic graph(``True``) or static graph(``False``).

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        from megengine.core import Graph

        with Graph(eager_evaluation=True):
            x = tensor([1, 2])
            print(x)

    Outputs:

    .. testoutput::

        Tensor([1 2], dtype=int32)
    """
    # Graph that was active before __enter__; restored by __exit__.
    __saved_graph = None
    def __new__(
        cls, *, check_env_var: bool = True, eager_evaluation: bool = True, **kwargs
    ):
        # mgb.comp_graph builds the underlying CompGraph instance; its class
        # is then rebound to this subclass so the context-manager methods
        # below become available on it.
        kwargs.update(eager_evaluation=eager_evaluation)
        self = mgb.comp_graph(extra_opts=kwargs, check_env_var=check_env_var)
        self.__class__ = cls
        return self
    def __init__(
        self, *, check_env_var: bool = True, eager_evaluation: bool = True, **kwargs
    ):
        # All construction is done in __new__; deliberately do not call
        # CompGraph.__init__ again.
        # pylint: disable=super-init-not-called
        pass
    def __enter__(self):
        # Install this graph as the thread-local default, remembering the
        # previous one for restoration on exit.
        self.__saved_graph = _default_graph._default_graph
        _default_graph._default_graph = self
        return self
    def __exit__(self, type, value, traceback):
        _default_graph._default_graph = self.__saved_graph
        del self.__saved_graph
def _use_default_if_none(device, comp_graph): | |||
if device is None: | |||
device = get_default_device() | |||
if comp_graph is None: | |||
comp_graph = get_default_graph() | |||
return device, comp_graph | |||
def dump(outputs, fpath, optimize_options=None, **kwargs):
    r"""
    Serializes this computing graph and writes it to a file.

    :type outputs: ``Tensor`` or a collection of ``Tensor``
    :param outputs: output variables that need to be retrieved when
        deserializing
    :type fpath: ``str``
    :param fpath: path for the output file
    :type optimize_options: ``list``
    :param optimize_options: ``['f16_io_f32_comp', 'f16_io_comp', 'use_nhwcd4', 'fuse_conv_bias_nonlinearity']`` , four elements are optional, it can be an empty list, None or a list containing any of them.

    .. note::

        ``f16_io_f32_comp`` – whether to use float16 for I/O between oprs and use float32 as internal computation precision. Note the output var would be changed to float16;

        ``f16_io_comp`` – whether to use float16 for both I/O and computation precision;

        ``use_nhwcd4`` – whether to use NHWCD4 data format. This is faster on some OpenCL devices;

        ``fuse_conv_bias_nonlinearity`` – whether to fuse conv+bias+nonlinearty into one opr. This is supported only when ``use_nhwcd4`` is set.
    """
    # Imported here to avoid a circular import with .tensor.
    from .tensor import Tensor

    assert optimize_options is None or isinstance(
        optimize_options, list
    ), "optimize_options must be a list"
    if isinstance(outputs, Tensor):
        outputs = [outputs]
    else:
        # collections.Iterable was removed in Python 3.10; the ABC lives in
        # collections.abc.
        assert isinstance(outputs, collections.abc.Iterable), "{} not iterable".format(
            outputs
        )
        outputs = list(outputs)
    for output in outputs:
        assert isinstance(output, Tensor), "All outputs must be Tensors."
    outputs = [o._symvar for o in outputs]
    if optimize_options:
        # Every listed option is enabled (mapped to True).
        opt_dict = dict.fromkeys(optimize_options, True)
        mgb.optimize_for_inference(outputs, **opt_dict)
    mgb.serialize_comp_graph_to_file(fpath, outputs, **kwargs)
def set_default_graph(default_graph):
    r"""
    Sets a global default Graph object.
    """
    # NOTE(review): the ``global`` statement is redundant here — only an
    # attribute of ``_default_graph`` is assigned, the name is never rebound.
    global _default_graph  # pylint: disable=global-statement
    _default_graph._default_graph = default_graph
def get_default_graph():
    r"""
    Returns a default Graph object, most probably for eager evaluation.
    """
    # Lazily constructs the per-thread graph on first access.
    return _default_graph.get_default()
@@ -1,128 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import pickle | |||
import megengine._internal as mgb | |||
from ..utils.max_recursion_limit import max_recursion_limit | |||
from .device import get_default_device | |||
def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.HIGHEST_PROTOCOL):
    r"""Save an object to disk file.

    :type obj: object
    :param obj: object to save. Only ``module`` or ``state_dict`` are allowed.
    :type f: text file object
    :param f: a string of file name or a text file object to which ``obj`` is saved to.
    :type pickle_module:
    :param pickle_module: Default: ``pickle``.
    :type pickle_protocol:
    :param pickle_protocol: Default: ``pickle.HIGHEST_PROTOCOL``.
    """
    if isinstance(f, str):
        # A path was given: open it ourselves and recurse with the handle.
        with open(f, "wb") as fout:
            save(obj, fout, pickle_module=pickle_module, pickle_protocol=pickle_protocol)
        return
    # Deeply nested objects may exceed the default recursion limit while
    # pickling; temporarily raise it.
    with max_recursion_limit():
        assert hasattr(f, "write"), "{} does not support write".format(f)
        pickle_module.dump(obj, f, pickle_protocol)
class dmap:
    """Context manager installing a device-mapping rule in mgb while active."""

    def __init__(self, map_location):
        self.map_location = map_location

    def __enter__(self):
        # Register the mapping for the duration of the block.
        mgb.add_device_map(self.map_location)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        mgb.del_device_map()
def _get_callable_map_location(map_location): | |||
if map_location is None: | |||
def callable_map_location(state): | |||
return str(get_default_device()) | |||
elif isinstance(map_location, str): | |||
def callable_map_location(state): | |||
return map_location | |||
elif isinstance(map_location, dict): | |||
locator_map = {} | |||
for key, value in map_location.items(): | |||
locator_key = mgb.config.parse_locator(key)[:2] | |||
locator_map[locator_key] = value | |||
def callable_map_location(state): | |||
orig = mgb.config.parse_locator(state)[:2] | |||
if orig in locator_map.keys(): | |||
state = locator_map[orig] | |||
return state | |||
else: | |||
assert callable(map_location), "map_location should be str, dict or function" | |||
callable_map_location = map_location | |||
return callable_map_location | |||
def load(f, map_location=None, pickle_module=pickle):
    r"""Load an object saved with save() from a file.

    :type f: text file object
    :param f: a string of file name or a text file object from which to load.
    :type map_location: str, dict or a function specifying the map rules
    :param map_location: Default: ``None``.

    .. note::

        map_location will change the logical locator when loading models,
        avoiding tensors be loading on non-existent device. If you want to
        add the mapping relationship between logical locator and physical
        locator in runtime, please call :func:`mge.set_device_map()`

    :type pickle_module:
    :param pickle_module: Default: ``pickle``.

    .. note::

        If you will call :func:`mge.set_default_device()`, please do it
        before :func:`mge.load()`.

    Examples:

    .. testcode:

        import megengine as mge
        mge.load('model.mge')
        # Load all tensors based on logical location.
        mge.load('model.mge', map_location='gpu0')
        # Load all tensors onto the device: GPU0
        mge.load('model.mge', map_location={'gpu0':'cpu0'})
        # Load all tensors based on logical location, but 'GPU0' will be renamed to 'CPU0'
        mge.load('model.mge', map_location=lambda dev: 'cpu0')
        # Load all tensors onto the device" CPU0
    """
    if isinstance(f, str):
        # A path was given: open it ourselves and recurse with the handle.
        with open(f, "rb") as fin:
            return load(fin, map_location=map_location, pickle_module=pickle_module)
    # Install the device-mapping rule while unpickling so tensors are placed
    # on the requested devices.
    with dmap(_get_callable_map_location(map_location)):
        return pickle_module.load(f)
@@ -1,771 +0,0 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
import copy | |||
import functools | |||
import itertools | |||
import weakref | |||
from typing import Callable, Tuple, Union | |||
import numpy as np | |||
import megengine._internal as mgb | |||
from .graph import _use_default_if_none, get_default_graph | |||
def wrap_io_tensor(func):
    r"""A wrapper to make ``func`` compatible with functions in ``_internal.opr``.

    Tensor arguments are attached to a common computing graph (converted to
    symbolic vars) before the call; SymbolVar results are wrapped back into
    Tensors.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Use the graph of the first graph-attached Tensor argument; the
        # for-else falls back to the thread-default graph when no break ran.
        comp_graph = None
        for i in itertools.chain(args, kwargs.values()):
            if isinstance(i, Tensor) and i._comp_graph:
                comp_graph = i._comp_graph
                break
        else:
            comp_graph = get_default_graph()
        new_args = (
            arg._attach(comp_graph) if isinstance(arg, Tensor) else arg for arg in args
        )
        new_kwargs = {
            k: v._attach(comp_graph) if isinstance(v, Tensor) else v
            for k, v in kwargs.items()
        }
        ret = func(*new_args, **new_kwargs)
        # Re-wrap SymbolVar results, preserving list/tuple container types.
        if isinstance(ret, mgb.SymbolVar):
            ret = Tensor(ret)
        elif isinstance(ret, list):
            ret = [Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret]
        elif isinstance(ret, tuple):
            ret = tuple(Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret)
        return ret
    return wrapper
def _wrap_symbolvar_binary_op(f):
    # Lift a mgb.SymbolVar binary operator so it accepts Tensor operands.
    @functools.wraps(f)
    def wrapped(self, other):
        # and/or chain picks, in order: the other operand's graph (when it is
        # a graph-attached Tensor), this tensor's graph, the default graph.
        comp_graph = (
            isinstance(other, Tensor)
            and other._comp_graph
            or self._comp_graph
            or get_default_graph()
        )
        if isinstance(other, Tensor):
            other = other._attach(comp_graph)
        return Tensor(f(self._attach(comp_graph), other))
    return wrapped
def _wrap_slice(inp: slice):
    r"""Convert any Tensor endpoints of ``inp`` into raw symbolic vars."""
    def unwrap(v):
        return v._symvar if isinstance(v, Tensor) else v
    return slice(unwrap(inp.start), unwrap(inp.stop), unwrap(inp.step))
def _wrap_idx(idx: Tuple[Union[int, "Tensor"]]): | |||
r""" | |||
A wrapper to handle Tensor values in ``idx``. | |||
""" | |||
if not isinstance(idx, tuple): | |||
idx = (idx,) | |||
idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx) | |||
idx = tuple(_wrap_slice(i) if isinstance(i, slice) else i for i in idx) | |||
return idx | |||
class _MGBIndexWrapper: | |||
r""" | |||
A wrapper class to handle ``__getitem__`` for index containing Tensor values. | |||
:param dest: a destination Tensor to do indexing on. | |||
:param mgb_index: an ``_internal`` helper function indicating how to index. | |||
:param val: a optional Tensor parameter used for ``mgb_index``. | |||
""" | |||
def __init__(self, dest: "Tensor", mgb_index: Callable, val=None): | |||
self.dest = dest | |||
self.val = val | |||
self.mgb_index = mgb_index | |||
def __getitem__(self, idx): | |||
if self.val is None: | |||
return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)( | |||
_wrap_idx(idx) | |||
) | |||
else: | |||
return wrap_io_tensor( | |||
self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__ | |||
)(_wrap_idx(idx)) | |||
class _Guard: | |||
r""" | |||
A wrapper class with custom ``__del__`` method calling ``deleter``. | |||
:param deleter: a function to be called in ``__del__``. | |||
""" | |||
def __init__(self, deleter: Callable): | |||
self.deleter = deleter | |||
def __del__(self): | |||
self.deleter() | |||
class Tensor:
    r"""The main data container in MegEngine.
    Use :func:`~.tensor` to create a Tensor with existed data.
    """
    # Whether gradients should be computed for this tensor.
    requires_grad = False
    # Accumulated gradient Tensor, populated by the autodiff machinery.
    grad = None
    def __init__(self, val=None, *, requires_grad=None):
        self._reset(val, requires_grad=requires_grad)
        # Quantization metadata — presumably filled by quantization utilities;
        # not touched elsewhere in this file.
        self.q_dict = {"mode": None, "scale": None, "zero_point": None}
def _reset(self, val=None, *, requires_grad=None):
    # Re-initialize storage: exactly one of __val (eager SharedND value) or
    # __sym (graph SymbolVar) holds the data afterwards.
    self.__sym_override = None
    if val is None:
        self.__val = None
        self.__sym = None
    elif isinstance(val, mgb.SharedND):
        self.__val = val
        self.__sym = None
    elif isinstance(val, mgb.SymbolVar):
        self.__val = None
        self.__sym = val
    else:
        raise TypeError("must be initialized with SymbolVar or SharedND")
    self.requires_grad = requires_grad
def _as_tensor(self, obj):
    r"""Convert the data into a ``Tensor``. If the data is already a Tensor
    with the same dtype and device, no copy will be performed. Otherwise a
    new Tensor will be returned with computational graph retained.
    """
    if isinstance(obj, Tensor):
        return obj
    if isinstance(obj, mgb.SymbolVar):
        return Tensor(obj)
    if isinstance(obj, mgb.SharedScalar):
        # Materialize the scalar as a symvar in this tensor's graph/device.
        return Tensor(obj._as_sym_var(self._comp_graph, self._comp_node))
    # Plain data (e.g. numpy array or Python scalar): build a new tensor on
    # the same device.
    return tensor(data=obj, device=self.device)
def numpy(self):
    r"""Return the tensor value in numpy.ndarray format.
    """
    if self.__val is not None:
        assert self.__sym is None
        return self.__val.get_value()
    if self.__sym is None:
        raise ValueError("uninitialized")
    if self.__sym.eager_val is not None:
        # Eager mode: the symbolic var already carries a concrete value.
        return self.__sym.eager_val.get_value()
    # Static mode: fall back to the statically inferred value (may be None).
    return self.__sym.inferred_value
def item(self):
    r"""If tensor only has only one value, return it."""
    return self.numpy().item()
def _attach(self, comp_graph, *, volatile=True):
    # Return a SymbolVar for this tensor inside ``comp_graph``. An existing
    # symvar (or trace override) must already belong to that graph.
    sym = self.__sym_override or self.__sym
    if sym:
        if sym.owner_graph != comp_graph:
            raise RuntimeError("internal error")
        return sym
    if self.__val:
        return self.__val.symvar(comp_graph, volatile=volatile)
    else:
        raise ValueError("uninitialized")
@property
def _symvar(self):
    # Symbolic var backing this tensor; value-backed tensors are attached to
    # the thread-default graph on demand.
    if self.__sym_override:
        return self.__sym_override
    if self.__sym:
        assert not self.__val
        return self.__sym
    if not self.__val:
        raise ValueError("uninitialized")
    return self._attach(get_default_graph())
def __mgb_symvar__(self, comp_graph=None, **_):
    # Protocol hook used by mgb.opr to obtain the underlying SymbolVar.
    if self.__sym_override:
        return self.__sym_override
    if self.__val and comp_graph:
        return self._attach(comp_graph)
    return self._symvar  # read by mgb.opr
def _override_symvar_during_trace(self, trace, symvar):
    # Temporarily make this tensor read from ``symvar`` while ``trace`` is
    # active; a _Guard stored in the trace's user cache undoes the override
    # when the trace releases its cache.
    assert self.__val and not self.__sym
    assert trace is type(trace)._active_instance
    deleters = trace._user_cache.setdefault(Tensor, set())
    self_ref = weakref.ref(self)
    def restore():
        # weakref: the tensor may already be collected when the trace ends.
        self = self_ref()
        if self is not None:
            self.__sym_override = None
    deleters.add(_Guard(restore))
    self.__sym_override = symvar
@property
def dtype(self):
    r"""Return the data type of the tensor.
    """
    if self.__val is not None:
        return self.__val.dtype
    return self._symvar.dtype
@dtype.setter
def dtype(self, dtype: str = None):
    r"""Set the data type of the tensor.
    """
    if self.__val is not None:
        # Value-backed: rebuild the SharedND with the converted value.
        self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy())
    elif self.__sym_override is not None:
        self.__sym_override = self.__sym_override.astype(dtype)
    elif self.__sym is not None:
        self.__sym = self.__sym.astype(dtype)
@property
def name(self):
    r"""Get the tensor name, does not support Parameter and Buffer.
    """
    return self._symvar.name
@name.setter
def name(self, name: str = None):
    r"""Set the tensor name, does not support Parameter and Buffer.
    """
    if self.__val is not None:
        raise ValueError("name setting is not available for Parameter or Buffer.")
    if self.__sym_override is not None:
        self.__sym_override = self.__sym_override.rename(name)
    if self.__sym is not None:
        assert not self.__val
        self.__sym = self.__sym.rename(name)
@property
def _comp_node(self):
    # Computing node (device) holding this tensor's data.
    if self.__val is not None:
        return self.__val.comp_node
    return self._symvar.comp_node
device = _comp_node
@property
def _comp_graph(self):
    # Owning graph for symbolic tensors; None for value-backed tensors.
    if self.__sym is not None:
        return self.__sym.owner_graph
    return None
@property
def shape(self):
    r"""Return an int tuple that is the shape/layout of the tensor.
    Could be invalid in static graph mode.
    """
    from ..jit import trace
    if trace._active_instance:  # pylint: disable=protected-access
        # NOTE: this is a hack — under an active trace, each dimension is
        # returned as a symbolic Tensor rather than a plain int.
        shape = mgb.opr.get_var_shape(self._symvar)
        return tuple(Tensor(shape[i]) for i in range(self.ndim))
    return self._symvar.imm_shape
def set_value(self, value, *, sync=True, inplace=False, share=False):
    r"""Set value to the tensor. Only available for value-backed
    (SharedND) tensors.
    """
    if not self.__val:
        raise ValueError("not detached")
    if isinstance(value, Tensor):
        # Unwrap another Tensor to its SharedND or eager value.
        value = value.__val or value.__sym.eager_val
    self.__val.set_value(value, sync=sync, inplace=inplace, share=share)
def fill(self, value):
    r"""Fills the tensor with the specified value.
    """
    self.set_value(np.full(self.shape, value, dtype=self.dtype))
def reset_zero(self):
    r"""Reset the tensor and fills with zeros.
    """
    if not self.__val:
        raise ValueError("not detached")
    self.__val.reset_zero()
def to(self, device):
    r"""Performs Tensor device conversion, returns Tensor with the specified device.
    """
    return wrap_io_tensor(mgb.opr.copy)(self, comp_node=device)
# https://docs.python.org/3/reference/datamodel.html#object.__hash__
# > If a class does not define an __eq__() method it should not define a
# > __hash__() operation either
__hash__ = None  # type: ignore[assignment]
def __eq__(self, rhs):
    # Element-wise equality: builds an "EQ" graph node, returns a Tensor
    # rather than a bool (hence __hash__ is disabled above).
    rhs = self._as_tensor(rhs)
    return Tensor(self._symvar._binary_opr("EQ", rhs._symvar))
def __ne__(self, rhs):
    # Element-wise negation of the element-wise equality result.
    return 1 - self.__eq__(rhs)
def __len__(self):
    if self._symvar.eager_val is not None:
        return self._symvar.eager_val.shape[0]
    raise TypeError(
        "__len__ and __iter__ is not available for tensors on non eager graph."
    )
# Arithmetic, comparison and reduction operators are lifted from
# mgb.SymbolVar; the wrappers attach Tensor operands to a common graph and
# wrap SymbolVar results back into Tensors.
__add__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__add__)
__radd__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__radd__)
__sub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__sub__)
__rsub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rsub__)
__mul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mul__)
__rmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmul__)
__matmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__matmul__)
__rmatmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmatmul__)
__lshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lshift__)
__rshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rshift__)
__truediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__truediv__)
__rtruediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rtruediv__)
__floordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__floordiv__)
__rfloordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rfloordiv__)
__mod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mod__)
__rmod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmod__)
__pow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__pow__)
__rpow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rpow__)
__lt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lt__)
__gt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__gt__)
__le__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__le__)
__ge__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__ge__)
__neg__ = wrap_io_tensor(mgb.SymbolVar.__neg__)
sum = wrap_io_tensor(mgb.SymbolVar.sum)
"""
Sum up the given tensors.
"""
max = wrap_io_tensor(mgb.SymbolVar.max)
"""
Return the maximum value of given tensor.
"""
min = wrap_io_tensor(mgb.SymbolVar.min)
"""
Return the minimum value of given tensor.
"""
prod = wrap_io_tensor(mgb.SymbolVar.prod)
"""
Return the product value of the given tensor.
"""
mean = wrap_io_tensor(mgb.SymbolVar.mean)
"""
Return the mean value of the given tensor.
"""
dimshuffle = wrap_io_tensor(mgb.SymbolVar.dimshuffle)
"""
See more details in :func:`~.functional.tensor.dimshuffle`.
"""
astype = wrap_io_tensor(mgb.SymbolVar.astype)
"""
Cast the tensor to a specified type.
"""
def reshape(self, *target_shape):
    r"""Return a tensor which has given target shape.

    :param target_shape: target shape given either as separate int/Tensor
        arguments or as a single tuple.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        inp = tensor(np.arange(1, 17, dtype=np.int32).reshape(4,4))
        out = tensor(np.arange(100, 116, dtype=np.int32).reshape(1,16))
        out = out.reshape(inp.shape)
        print(out.numpy())

    .. testoutput::

        [[100 101 102 103]
         [104 105 106 107]
         [108 109 110 111]
         [112 113 114 115]]
    """
    # Explicit error instead of a bare IndexError on target_shape[0].
    if not target_shape:
        raise ValueError("reshape requires at least one shape argument")
    if isinstance(target_shape[0], tuple):
        if len(target_shape) > 1:
            raise ValueError("Only single tuple is accepted in reshape")
        target_shape = target_shape[0]
    # Unwrap Tensor dimensions into raw symbolic vars for the mgb call.
    target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape)
    return Tensor(mgb.SymbolVar.reshape(self._symvar, *target_shape))
def broadcast(self, *target_shape):
    r"""Return a tensor broadcasted by the current tensor to the given target shape.

    :param target_shape: target shape given either as separate int/Tensor
        arguments or as a single tuple.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        data = tensor(np.arange(100, 104, dtype=np.int32).reshape(1,4))
        data = data.broadcast((4,4))
        print(data.numpy())

    .. testoutput::

        [[100 101 102 103]
         [100 101 102 103]
         [100 101 102 103]
         [100 101 102 103]]
    """
    # Explicit error instead of a bare IndexError on target_shape[0].
    if not target_shape:
        raise ValueError("broadcast requires at least one shape argument")
    if isinstance(target_shape[0], tuple):
        if len(target_shape) > 1:
            raise ValueError("Only single tuple is accepted in broadcast")
        target_shape = target_shape[0]
    # Unwrap Tensor dimensions into raw symbolic vars for the mgb call.
    target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape)
    return Tensor(mgb.SymbolVar.broadcast(self._symvar, *target_shape))
# Prefer operators on Tensor instead of convert to numpy
__array_priority__ = 1000
# mgb indexing family
def __getitem__(self, idx):
    # Basic indexing; Tensor/slice components are unwrapped by _wrap_idx.
    return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx))
def set_subtensor(self, val: "Tensor") -> _MGBIndexWrapper: | |||
r""" | |||
Return a object which supports using ``__getitem__`` to set subtensor. | |||
``c = a.set_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] = b``. | |||
""" | |||
return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val) | |||
def incr_subtensor(self, val: "Tensor") -> _MGBIndexWrapper: | |||
r""" | |||
Return a object which supports using ``__getitem__`` to increase subtensor. | |||
``c = a.incr_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] += b``. | |||
""" | |||
return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val) | |||
@property
def ai(self) -> _MGBIndexWrapper:
    r"""
    Return an object which supports complex index methods to get a subtensor
    (advanced indexing).

    Examples:

    .. testcode::

        from megengine import tensor
        a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
        print(a.ai[:, [2, 3]])

    Outputs:

    .. testoutput::

        Tensor([[ 2.  3.]
         [ 6.  7.]
         [10. 11.]
         [14. 15.]])
    """
    return _MGBIndexWrapper(self, mgb.opr.advanced_indexing)
def set_ai(self, val: "Tensor") -> _MGBIndexWrapper:
    r"""
    Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing.
    """
    return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)
def incr_ai(self, val: "Tensor") -> _MGBIndexWrapper:
    r"""
    Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing.
    """
    return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)
@property
def mi(self) -> _MGBIndexWrapper:
    r"""
    Return an object which supports getting a subtensor by
    the coordinates which are the Cartesian product of the given indices.

    Examples:

    .. testcode::

        from megengine import tensor
        a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
        print(a.mi[[1, 2], [2, 3]])
        # is equal to elements on [1, 2] * [2, 3] = [[(1,2), (1, 3)], [(2, 2), (2, 3)]]
        # a[1,2] = 6, a[1,3] = 7, a[2,2] = 10, a[2,3] = 11

    Outputs:

    .. testoutput::

        Tensor([[ 6.  7.]
         [10. 11.]])
    """
    return _MGBIndexWrapper(self, mgb.opr.mesh_indexing)
def set_mi(self, val: "Tensor") -> _MGBIndexWrapper:
    r"""
    Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing.
    """
    return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)
def incr_mi(self, val: "Tensor") -> _MGBIndexWrapper:
    r"""
    Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing.
    """
    return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val)
@property
def batched_mi(self) -> _MGBIndexWrapper:
    r"""
    Return an object which supports getting a subtensor by
    batched mesh indexing.

    For Tensor ``a`` and index ``idx``, each value of the ``idx`` need to be a 2-dim matrix or slice.
    Cartesian product ``... * idx[k-1][i] * idx[k][i] * idx[k+1][i] * ...`` will be a subtensor from ``a[i]``.
    Each matrix ``idx[k]`` should have the size of ``batched_dim`` rows as ``idx[0]`` indicated.
    And for slice value, it will apply same slice for each ``batched_dim``. For more details see the example below.

    Examples:

    .. testcode::

        from megengine import tensor
        a = tensor(np.arange(144, dtype=np.float32).reshape((3, 3, 4, 4)))
        print(a.batched_mi[:2, [[0],[1]],[[0,1],[2,3]],[[0],[1]]])
        # is equal to elements from a[0] with ``[0] * [0,1] * [0] = [[[(0,0,0)], [(0,1,0)]]]``(shape is [1,2,1])
        # and from a[1] with ``[1] * [2,3] * [1] = [[[(1,2,1)], [(1,3,1)]]]``(shape is also [1,2,1])
        # a[0,0,0,0] = 0, a[0,0,1,0] = 4, a[1,1,2,1] = 73, a[1,1,3,1] = 77
        print(a.batched_mi[:2, [[0],[1]], :2, :1])
        # is equal to ``a.batched_mi[:2, [[0],[1]], [[0,1],[0,1]],[[0],[0]]]``

    Outputs:

    .. testoutput::

        Tensor([[[[ 0.]
           [ 4.]]]
         [[[73.]
           [77.]]]])
        Tensor([[[[ 0.]
           [ 4.]]]
         [[[64.]
           [68.]]]])
    """
    return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing)
def batched_set_mi(self, val: "Tensor") -> _MGBIndexWrapper:
    r"""
    Equal to :meth:`~.Tensor.set_subtensor` which using batched mesh indexing.
    """
    return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val)
def batched_incr_mi(self, val: "Tensor") -> _MGBIndexWrapper:
    r"""
    Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
    """
    return _MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val)
def __array__(self, dtype=None):
    # numpy conversion protocol: enables np.array(tensor) / np.asarray(tensor).
    if dtype is None:
        return self.numpy()
    else:
        return self.numpy().astype(dtype, copy=False)
def __int__(self):
    return int(self.item())
def __index__(self):
    # Allows use of a scalar tensor as a sequence index or in range().
    return int(self.item())
def __round__(self, ndigits=0):
    if ndigits != 0:
        raise ValueError("ndigits must be 0 for Tensor.round")
    return Tensor(mgb.opr.elemwise([self._symvar], mode="ROUND"))
round = __round__
def sqrt(self): | |||
r"""Return a tensor that each element is the square root of its | |||
original value. | |||
""" | |||
return Tensor(mgb.opr.sqrt(self._symvar)) | |||
def shapeof(self, axis=None): | |||
r"""Return a Tensor that represent the shape of the tensor. | |||
""" | |||
return Tensor(mgb.opr.get_var_shape(self._symvar, axis=axis)) | |||
@property | |||
def ndim(self): | |||
r"""Return the number of dimensions of the tensor. | |||
""" | |||
return len(self._symvar.imm_shape) | |||
def __repr__(self): | |||
piece = "Tensor(" | |||
with np.printoptions(precision=4, suppress=True): | |||
piece += "{}".format(str(self.numpy())) | |||
if self.dtype != np.float32: | |||
piece += ", dtype={}".format(np.dtype(self.dtype).name) | |||
if self._comp_node.locator_logical != ("XPU", -1, 0): | |||
piece += ", device={}".format(self.device) | |||
piece += ")" | |||
return piece | |||
def __bool__(self): | |||
raise RuntimeError( | |||
"Tensor object should not be converted to bool or used in a if statement. Use .numpy(), int() or float() if you want to use its value in if statement, be aware that this may lead to incorrect result in non-eager mode." | |||
) | |||
def __getstate__(self): | |||
r""" __getstate__ will be called for pickle serialization or deep copy | |||
""" | |||
assert (self.__val is not None) and ( | |||
self.__sym is None | |||
), "Only SharedND initialized Tensor can be serialized or deep copied" | |||
metadata = {"requires_grad": self.requires_grad} | |||
state = { | |||
"data": self.numpy(), | |||
"device": self.device, | |||
"dtype": self.dtype, | |||
"metadata": metadata, | |||
} | |||
return state | |||
def __setstate__(self, state): | |||
data = state.pop("data") | |||
device = state.pop("device") | |||
dtype = state.pop("dtype") | |||
metadata = state.pop("metadata", {}) | |||
requires_grad = metadata.pop("requires_grad", None) | |||
snd = mgb.make_shared(device, value=data, dtype=dtype) | |||
self._reset(snd, requires_grad=requires_grad) | |||
def __deepcopy__(self, memo): | |||
""" | |||
The default deepcopy will ignore other attributes except those defined at | |||
__getstate__ and __setstate__ method. | |||
So we need to add __deepcopy__ method to deepcopy correct attributes. | |||
""" | |||
assert (self.__val is not None) and ( | |||
self.__sym is None | |||
), "Only SharedND initialized Tensor can be serialized or deep copied" | |||
cls = self.__class__ | |||
result = cls.__new__(cls) | |||
memo[id(self)] = result | |||
for k, v in self.__dict__.items(): | |||
setattr(result, k, copy.deepcopy(v, memo)) | |||
return result | |||
def tensor(
    data: Union[list, np.ndarray] = None,
    *,
    dtype: str = None,
    device: mgb.CompNode = None,
    requires_grad: bool = None
):
    r"""A helper function to create a :class:`~.Tensor` using existing data.

    :param data: an existing data array, must be Python list, NumPy array or None.
    :param dtype: target Tensor data type, one of ``("uint8", "int8", "int16", "int32", "float32", "float16")``.
    :param device: target device for Tensor storing.
    :param requires_grad: whether its gradient will be calculated during :meth:`~.Optimizer.backward`
    """
    supported_dtypes = ("uint8", "int8", "int16", "int32", "float32", "float16")
    if isinstance(data, Tensor):
        # constructing from an existing Tensor is not supported yet
        raise NotImplementedError
    if dtype is not None and np.dtype(dtype).name not in supported_dtypes:
        raise TypeError("unsupported dtype {}".format(dtype))
    if data is not None:
        if not isinstance(data, np.ndarray):
            data = np.array(data, dtype=dtype)
            # In order to accept tensor([1]),
            # automatically convert to 32-bit number instead of numpy's default 64-bit when input data is not nparray.
            dtype = mgb.to_mgb_supported_dtype(data.dtype)
        if dtype is None:
            # ndarray input with no explicit dtype: validate the array's own dtype
            if data.dtype.name not in supported_dtypes:
                raise TypeError("unsupported dtype {}".format(data.dtype))
    device, _ = _use_default_if_none(device, None)
    shared_nd = mgb.make_shared(device, value=data, dtype=dtype)
    return Tensor(shared_nd, requires_grad=requires_grad)
class TensorDict(collections.abc.MutableMapping):
    r"""
    A helper class to maintain dict with Tensor key.

    Tensor keys are tracked by object identity (``id``) through the internal
    ``keyfn`` wrapper; non-Tensor keys behave like ordinary dict keys.

    .. note::
        ``collections.MutableMapping`` was a deprecated alias removed in
        Python 3.10; the ABC lives in :mod:`collections.abc`.
    """

    def __init__(self, *args, **kwargs):
        # maps keyfn(key) -> (original_key, value) so iteration can
        # recover the original key objects
        self.data = {}
        for i in args:
            self.update(i)
        self.update(**kwargs)

    class keyfn:
        def __new__(cls, x: Tensor):
            # non-Tensor keys are returned unchanged and hash normally
            if not isinstance(x, Tensor):
                return x
            return super().__new__(cls)

        def __init__(self, x: Tensor):
            self._data = x  # do not save id directly to make pickle work

        def __hash__(self):
            return id(self._data)

        def __eq__(self, other):
            return isinstance(other, type(self)) and id(self._data) == id(other._data)

    def __getitem__(self, key):
        _, v = self.data[self.keyfn(key)]
        return v

    def __setitem__(self, key, value):
        self.data[self.keyfn(key)] = key, value

    def __delitem__(self, key):
        del self.data[self.keyfn(key)]

    def __iter__(self):
        for _, (k, _) in self.data.items():
            yield k

    def __len__(self):
        return len(self.data)
@@ -1,109 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, Optional, Union | |||
import megengine._internal as mgb | |||
from .graph import _use_default_if_none | |||
from .tensor import Tensor | |||
__all__ = ["zeros", "ones"] | |||
def scalar(
    value,
    dtype: type = None,
    device: Optional[mgb.CompNode] = None,
    comp_graph: Optional[mgb.CompGraph] = None,
) -> Tensor:
    """
    Wrap ``value`` into an immutable :class:`~.Tensor` on the given device/graph.
    """
    device, comp_graph = _use_default_if_none(device, comp_graph)
    imm = mgb.make_immutable(device, comp_graph, value, dtype=dtype, name=None)
    return Tensor(imm)
def zeros(
    shape: Union[int, Iterable[int], Tensor],
    dtype: type = None,
    device: Optional[mgb.CompNode] = None,
    comp_graph: Optional[mgb.CompGraph] = None,
) -> Tensor:
    """
    Create a tensor filled with 0.

    :param shape: tensor shape
    :param dtype: data type, Default: "int32"
    :param device: Compute node of the matrix, Default: None
    :param comp_graph: Compute graph of the matrix, Default: None
    :return: tensor of zeros

    Examples:

    .. testcode::

        import megengine as mge

        t = mge.zeros((2, 2), dtype="int32")
        print(t.numpy())

    Outputs:

    .. testoutput::

        [[0 0]
         [0 0]]

    """
    device, comp_graph = _use_default_if_none(device, comp_graph)
    # scalar-like shapes (int or Tensor) are promoted to a 1-tuple
    if isinstance(shape, (int, Tensor)):
        shape = (shape,)
    # build a scalar zero, then broadcast it to the requested shape
    filler = scalar(0, dtype=dtype, device=device, comp_graph=comp_graph)
    return filler.broadcast(*shape)
def ones(
    shape: Union[int, Iterable[int], Tensor],
    dtype: type = None,
    device: Optional[mgb.CompNode] = None,
    comp_graph: Optional[mgb.CompGraph] = None,
) -> Tensor:
    """
    Create a tensor filled with 1.

    :param shape: tensor shape
    :param dtype: data type, Default: "int32"
    :param device: Compute node of the matrix, Default: None
    :param comp_graph: Compute graph of the matrix, Default: None
    :return: tensor of ones

    Examples:

    .. testcode::

        import megengine as mge

        t = mge.ones((2, 2), dtype="float32")
        print(t.numpy())

    Outputs:

    .. testoutput::

        [[1. 1.]
         [1. 1.]]

    """
    device, comp_graph = _use_default_if_none(device, comp_graph)
    # scalar-like shapes (int or Tensor) are promoted to a 1-tuple
    if isinstance(shape, (int, Tensor)):
        shape = (shape,)
    # build a scalar one, then broadcast it to the requested shape
    filler = scalar(1, dtype=dtype, device=device, comp_graph=comp_graph)
    return filler.broadcast(*shape)
@@ -1,45 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .tensor import Tensor, tensor | |||
class Buffer(Tensor):
    r"""A kind of Tensor with ``requires_grad=False``.
    """

    def __init__(self, value, *, dtype=None, device=None, requires_grad=False):
        # pylint: disable=super-init-not-called
        # build an ordinary Tensor, then adopt its internal state wholesale
        src = tensor(value, dtype=dtype, device=device, requires_grad=requires_grad)
        self.__dict__.update(src.__dict__)
class Parameter(Tensor):
    r"""A kind of Tensor that is to be considered a module parameter.
    """

    def __init__(self, value, *, dtype=None, device=None, requires_grad=True):
        # pylint: disable=super-init-not-called
        # adopt an existing Tensor as-is; otherwise construct one from raw data
        if isinstance(value, Tensor):
            src = value
        else:
            src = tensor(value, dtype=dtype, device=device, requires_grad=requires_grad)
        self.__dict__.update(src.__dict__)
        # broadcast and allreduce will not be performed in optimizer if replica_mode is False
        self.replica_mode = True

    @property
    def shape(self):
        r"""Return shape of parameter.
        """
        if self._Tensor__val is not None:
            return self._Tensor__val.shape
        if self._Tensor__sym is not None:
            return self._Tensor__sym.imm_shape
        return None
@@ -1,17 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .collator import Collator | |||
from .dataloader import DataLoader | |||
from .sampler import ( | |||
Infinite, | |||
RandomSampler, | |||
ReplacementSampler, | |||
Sampler, | |||
SequentialSampler, | |||
) |
@@ -1,144 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import binascii | |||
import os | |||
import queue | |||
import subprocess | |||
from multiprocessing import Queue | |||
import pyarrow | |||
import pyarrow.plasma as plasma | |||
MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000)) # 4GB | |||
# Each process only need to start one plasma store, so we set it as a global variable. | |||
# TODO: how to share between different processes? | |||
MGE_PLASMA_STORE_MANAGER = None | |||
def _clear_plasma_store():
    # `_PlasmaStoreManager.__del__` will not be called automatically in subprocess,
    # so this function should be called explicitly
    global MGE_PLASMA_STORE_MANAGER
    if MGE_PLASMA_STORE_MANAGER is not None and MGE_PLASMA_STORE_MANAGER.refcount == 0:
        # drop the last reference first so __del__ runs (which kills the
        # plasma-store-server process), then reset the global so a new
        # manager can be lazily created by a later PlasmaShmQueue
        del MGE_PLASMA_STORE_MANAGER
        MGE_PLASMA_STORE_MANAGER = None
class _PlasmaStoreManager:
    # guards __del__ against partially-constructed instances
    __initialized = False

    def __init__(self):
        # unique socket path per manager so parallel stores never collide
        self.socket_name = "/tmp/mge_plasma_{}".format(
            binascii.hexlify(os.urandom(8)).decode()
        )
        debug_flag = bool(os.environ.get("MGE_DATALOADER_PLASMA_DEBUG", 0))
        # NOTE: this is a hack. Directly use `plasma_store` may make subprocess
        # difficult to handle the exception happened in `plasma-store-server`.
        # For `plasma_store` is just a wrapper of `plasma-store-server`, which use
        # `os.execv` to call the executable `plasma-store-server`.
        server_path = os.path.join(pyarrow.__path__[0], "plasma-store-server")
        self.plasma_store = subprocess.Popen(
            [server_path, "-s", self.socket_name, "-m", str(MGE_PLASMA_MEMORY)],
            stdout=None if debug_flag else subprocess.DEVNULL,
            stderr=None if debug_flag else subprocess.DEVNULL,
        )
        self.__initialized = True
        self.refcount = 1

    def __del__(self):
        # only kill the server if it was started and has not already exited
        if self.__initialized and self.plasma_store.returncode is None:
            self.plasma_store.kill()
class PlasmaShmQueue:
    def __init__(self, maxsize: int = 0):
        r"""Use pyarrow in-memory plasma store to implement shared memory queue.

        Compared to native `multiprocess.Queue`, `PlasmaShmQueue` avoid pickle/unpickle
        and communication overhead, leading to better performance in multi-process
        application.

        :type maxsize: int
        :param maxsize: maximum size of the queue, `None` means no limit. (default: ``None``)
        """
        # Lazy start the plasma store manager
        global MGE_PLASMA_STORE_MANAGER
        if MGE_PLASMA_STORE_MANAGER is None:
            try:
                MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager()
            except Exception as e:
                err_info = (
                    "Please make sure pyarrow installed correctly!\n"
                    "You can try reinstall pyarrow and see if you can run "
                    "`plasma_store -s /tmp/mge_plasma_xxx -m 1000` normally."
                )
                # chain the original exception (`from e`) so the root cause of
                # the store startup failure is preserved in the traceback
                raise RuntimeError(
                    "Exception happened in starting plasma_store: {}\n"
                    "Tips: {}".format(str(e), err_info)
                ) from e
        else:
            MGE_PLASMA_STORE_MANAGER.refcount += 1

        self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name

        # TODO: how to catch the exception happened in `plasma.connect`?
        self.client = None

        # Used to store the header for the data.(ObjectIDs)
        self.queue = Queue(maxsize)  # type: Queue

    def put(self, data, block=True, timeout=None):
        # connect lazily so that the client is created inside the process
        # that actually uses the queue end
        if self.client is None:
            self.client = plasma.connect(self.socket_name)
        try:
            object_id = self.client.put(data)
        except plasma.PlasmaStoreFull as e:
            raise RuntimeError("plasma store out of memory!") from e
        try:
            self.queue.put(object_id, block, timeout)
        except queue.Full:
            # roll back the plasma allocation, then re-raise the original error
            self.client.delete([object_id])
            raise

    def get(self, block=True, timeout=None):
        if self.client is None:
            self.client = plasma.connect(self.socket_name)
        object_id = self.queue.get(block, timeout)
        if not self.client.contains(object_id):
            raise RuntimeError(
                "ObjectID: {} not found in plasma store".format(object_id)
            )
        data = self.client.get(object_id)
        # free the shared-memory object once it has been handed out
        self.client.delete([object_id])
        return data

    def qsize(self):
        return self.queue.qsize()

    def empty(self):
        return self.queue.empty()

    def join(self):
        self.queue.join()

    def disconnect_client(self):
        if self.client is not None:
            self.client.disconnect()

    def close(self):
        # release this end's reference; the store is torn down when the
        # global refcount drops to zero (see _clear_plasma_store)
        self.queue.close()
        self.disconnect_client()
        global MGE_PLASMA_STORE_MANAGER
        MGE_PLASMA_STORE_MANAGER.refcount -= 1
        _clear_plasma_store()

    def cancel_join_thread(self):
        self.queue.cancel_join_thread()
@@ -1,76 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# Copyright (c) 2016- Facebook, Inc (Adam Paszke) | |||
# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) | |||
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) | |||
# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) | |||
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) | |||
# Copyright (c) 2011-2013 NYU (Clement Farabet) | |||
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) | |||
# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) | |||
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) | |||
# --------------------------------------------------------------------- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# | |||
# This file has been modified by Megvii ("Megvii Modifications"). | |||
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
# ---------------------------------------------------------------------- | |||
import collections.abc | |||
import re | |||
import numpy as np | |||
np_str_obj_array_pattern = re.compile(r"[aO]") | |||
default_collate_err_msg_format = ( | |||
"default_collator: inputs must contain numpy arrays, numbers, " | |||
"Unicode strings, bytes, dicts or lists; found {}" | |||
) | |||
class Collator:
    r"""
    Used for merging a list of samples to form a mini-batch of Tensor(s).
    Used when using batched loading from a dataset.

    Modified from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
    """

    def apply(self, inputs):
        """
        Recursively merge a sequence of samples into batched form.

        input : sequence_N(tuple(CHW, C, CK))
        output : tuple(NCHW, NC, NCK)

        :raises TypeError: if an element type is not a numpy array, number,
            str/bytes, mapping, or sequence.
        """
        elem = inputs[0]
        elem_type = type(elem)
        if (
            elem_type.__module__ == "numpy"
            and elem_type.__name__ != "str_"
            and elem_type.__name__ != "string_"
        ):
            # NOTE: the original code re-assigned `elem = inputs[0]` here a
            # second time; that duplicate dead assignment has been removed.
            if elem_type.__name__ == "ndarray":
                # array of string classes and object
                if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                    raise TypeError(default_collate_err_msg_format.format(elem.dtype))
                return np.ascontiguousarray(np.stack(inputs))
            elif elem.shape == ():  # scalars
                return np.array(inputs)
        elif isinstance(elem, float):
            return np.array(inputs, dtype=np.float64)
        elif isinstance(elem, int):
            return np.array(inputs)
        elif isinstance(elem, (str, bytes)):
            return inputs
        elif isinstance(elem, collections.abc.Mapping):
            return {key: self.apply([d[key] for d in inputs]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
            return elem_type(*(self.apply(samples) for samples in zip(*inputs)))
        elif isinstance(elem, collections.abc.Sequence):
            transposed = zip(*inputs)
            return [self.apply(samples) for samples in transposed]
        raise TypeError(default_collate_err_msg_format.format(elem_type))
@@ -1,500 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections | |||
import math | |||
import multiprocessing | |||
import queue | |||
import random | |||
import time | |||
import numpy as np | |||
from ..logger import get_logger | |||
from ..random.rng import _random_seed_generator | |||
from .collator import Collator | |||
from .dataset import Dataset | |||
from .sampler import Sampler, SequentialSampler | |||
from .transform import PseudoTransform, Transform | |||
logger = get_logger(__name__) | |||
MP_QUEUE_GET_TIMEOUT = 5 | |||
class DataLoader:
    __initialized = False

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        divide: bool = False,
    ):
        r"""Provides a convenient way to iterate on a given dataset.

        `DataLoader` combines a dataset with sampler, transform and collator,
        make it flexible to get minibatch continually from a dataset.

        :type dataset: Dataset
        :param dataset: dataset from which to load the minibatch.
        :type sampler: Sampler
        :param sampler: defines the strategy to sample data from the dataset.
            If specified, :attr:`shuffle` must be ``False``.
        :type transform: Transform
        :param transform: defined the transforming strategy for a sampled batch.
            (default: ``None``)
        :type collator: Collator
        :param collator: defined the merging strategy for a transformed batch.
            (default: ``None``)
        :type num_workers: int
        :param num_workers: the number of sub-process to load, transform and collate
            the batch. ``0`` means using single-process. (default: ``0``)
        :type timeout: int
        :param timeout: if positive, means the timeout value(second) for collecting a
            batch from workers. (default: 0)
        :type divide: bool
        :param divide: define the paralleling strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces, and
            the workers will process these pieces parallelly. ``False`` means
            different sub-process will process different batch. (default: ``False``)
        """
        # argument sanity checks up front
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")
        if timeout < 0:
            raise ValueError("timeout should not be negative")
        if divide and num_workers <= 1:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.divide = divide

        # fall back to a one-sample sequential sampler when none is given
        self.sampler = (
            sampler
            if sampler is not None
            else SequentialSampler(dataset, batch_size=1, drop_last=False)
        )

        if divide:
            if self.sampler.batch_size <= self.num_workers:
                raise ValueError(
                    "batch size must not smaller than num_workers in divide mode."
                )
            elif self.sampler.batch_size % self.num_workers:
                logger.warning(
                    "batch size is not divisible by num_workers, may lose performance in divide mode."
                )

        self.transform = transform if transform is not None else PseudoTransform()
        self.collator = collator if collator is not None else Collator()

        self.__initialized = True

    def __iter__(self):
        # single-process vs. multi-process iteration strategy
        if self.num_workers == 0:
            return _SerialDataLoaderIter(self)
        return _ParallelDataLoaderIter(self)

    def __len__(self):
        return len(self.sampler)
class _BaseDataLoaderIter:
    """Common iterator state shared by serial and parallel loaders."""

    def __init__(self, loader):
        # snapshot everything we need from the owning DataLoader
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = next(_random_seed_generator())
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.divide = loader.divide
        self.num_processed = 0

    def _get_next_batch(self):
        # subclasses implement the actual batch production
        raise NotImplementedError

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        if self.num_processed >= len(self):
            raise StopIteration
        batch = self._get_next_batch()
        self.num_processed += 1
        return batch
class _SerialDataLoaderIter(_BaseDataLoaderIter):
    """Single-process iterator: fetch, transform and collate inline."""

    def __init__(self, loader):
        super().__init__(loader)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
        batch_indices = next(self.indices_iter)
        samples = [self.dataset[i] for i in batch_indices]
        transformed = self.transform.apply_batch(samples)
        return self.collator.apply(transformed)
class _ParallelDataLoaderIter(_BaseDataLoaderIter):
    """Multi-process iterator.

    Pipeline: a task-feeding process pushes batch indices into per-worker
    task queues; each worker loads + transforms its share; a collecting
    process (gathering in divide mode, selecting otherwise) collates and
    pushes finished batches into a shared-memory batch queue.
    """

    # guards __del__ against partially-constructed instances
    __initialized = False

    def __init__(self, loader):
        super(_ParallelDataLoaderIter, self).__init__(loader)

        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
        ]

        # shared counters coordinating feeding / collecting progress
        self.feed_batch_idx = multiprocessing.Value("i", 0)
        self.target_batch_idx = multiprocessing.Value("i", 0)
        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        # use shared-memory queue implemented by pyarrow plasma store.
        from ._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.task_feeding_worker = multiprocessing.Process(
            target=_task_feeding_loop,
            args=(
                iter(self.sampler),
                self.task_queues,
                self.num_workers,
                self.divide,
                self.shutdown_flag,
                self.feed_batch_idx,
            ),
            daemon=True,
        )
        self.task_feeding_worker.start()

        self.workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=_worker_loop,
                args=(
                    self.dataset,
                    self.task_queues[worker_id],
                    self.trans_data_queues[worker_id],
                    self.transform,
                    self.seed + worker_id + 1,
                    self.shutdown_flag,
                ),
                daemon=True,
            )
            worker.start()
            self.workers.append(worker)

        if self.divide:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_gathering_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        else:
            self.data_collecting_worker = multiprocessing.Process(
                target=_data_selecting_loop,
                args=(
                    self.trans_data_queues,
                    self.batch_queue,
                    self.collator,
                    len(self),
                    self.num_workers,
                    self.shutdown_flag,
                    self.target_batch_idx,
                ),
                daemon=True,
            )
        self.data_collecting_worker.start()

        self.__initialized = True

    def _check_workers(self):
        # Check the status of each worker.
        if not self.data_collecting_worker.is_alive():
            # BUGFIX: previously read task_feeding_worker.exitcode here,
            # which could mask a crash of the data collecting worker.
            exitcode = self.data_collecting_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("data collecting worker died. {}".format(exitcode))

        if not self.task_feeding_worker.is_alive():
            exitcode = self.task_feeding_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("task feeding worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))

        logger.debug("all workers are alive.")

    def _try_get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0:
                if waited_time > self.timeout:
                    raise RuntimeError("get_next_batch timeout!")

    def _get_next_batch(self):
        batch_data = self._try_get_next_batch()
        return batch_data

    def _shutdown(self):
        # signal all subprocesses, then terminate and reap each of them
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.task_feeding_worker.is_alive():
            self.task_feeding_worker.terminate()
        self.task_feeding_worker.join()

        if self.data_collecting_worker.is_alive():
            self.data_collecting_worker.terminate()
        self.data_collecting_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.task_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
def _task_feeding_loop(
    indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
    # Feed the indices into the task queues.
    # Runs in its own process; exits when the sampler is exhausted or
    # shutdown_flag is raised by the parent iterator.
    while True:
        if shutdown_flag.value == 1:
            break
        batch_idx = feed_batch_idx.value
        try:
            indices = next(indices_iter)
        except StopIteration:
            break
        if divide:
            # make sure all task_queues is ready for put
            # NOTE(review): this is a busy-wait; the queues are bounded
            # (maxsize=2) so it only spins while workers are saturated
            while any([q.full() for q in task_queues]):
                if shutdown_flag.value == 1:
                    return
            # divide into small pieces, feed to different workers.
            sub_num = math.ceil(len(indices) / num_workers)
            for worker_id in range(num_workers):
                sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
                task_queues[worker_id].put((batch_idx, sub_indices))
        else:
            # distribute tasks to different workers uniformly.
            target_id = batch_idx % num_workers
            while task_queues[target_id].full():
                if shutdown_flag.value == 1:
                    return
            task_queues[target_id].put((batch_idx, indices))
        with feed_batch_idx.get_lock():
            feed_batch_idx.value += 1
def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
    # Get dataset items and do the transform.
    # Runs in a worker subprocess; `seed` is distinct per worker so random
    # augmentations differ across workers.
    random.seed(seed)
    np.random.seed(seed)
    while True:
        if shutdown_flag.value == 1:
            break
        try:
            batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
        except queue.Empty:
            # no task yet; re-check the shutdown flag and keep polling
            continue
        if len(indices) > 0:
            items = [dataset[idx] for idx in indices]
            trans_items = transform.apply_batch(items)
        else:
            # in case of incomplete last batch
            trans_items = ()

        # retry the put with a short timeout so shutdown stays responsive
        while True:
            try:
                trans_data_queue.put((batch_idx, trans_items), timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch part queue is full!")
def _data_gathering_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Gathering the small pieces of batch data into full batch data
    # (divide mode): each of the `num_workers` queues contributes one piece
    # of every batch, in batch-index order.
    while True:
        if shutdown_flag.value == 1:
            break

        target_batch_idx = target_idx.value

        if target_batch_idx >= length:
            break

        full_trans_items = []
        for worker_id in range(num_workers):
            while True:
                try:
                    batch_idx, trans_items = trans_data_queues[worker_id].get(
                        timeout=MP_QUEUE_GET_TIMEOUT
                    )
                    break
                except queue.Empty:
                    if shutdown_flag.value == 1:
                        break
                    logger.debug(
                        "worker:{} data queue get timeout! target batch idx:{}".format(
                            worker_id, target_batch_idx
                        )
                    )
            if batch_idx != target_batch_idx:
                # fixed typo in the error message ("Unexperted" -> "Unexpected")
                raise RuntimeError(
                    "Unexpected batch_idx in data gathering loop. worker_id:{}.".format(
                        worker_id
                    )
                )
            else:
                full_trans_items.extend(trans_items)

        # Merge different parts into a batch.
        full_batch = collator.apply(full_trans_items)

        while True:
            try:
                batch_queue.put(full_batch, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")

        with target_idx.get_lock():
            target_idx.value += 1

    batch_queue.disconnect_client()
def _data_selecting_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    r"""Emit full batches in exactly the order the indices were generated.

    Batch ``i`` is produced by worker ``i % num_workers``; that worker's
    queue is polled, the items are collated and the batch is pushed into
    ``batch_queue``, stopping after ``length`` batches or on shutdown.
    """
    # Make sure that batch is generated exactly with the same order as generated indices
    while True:
        if shutdown_flag.value == 1:
            break
        target_batch_idx = target_idx.value
        if target_batch_idx >= length:
            break
        # round-robin: worker i handles batches i, i + num_workers, ...
        target_worker_id = target_batch_idx % num_workers
        # batch_idx stays None if shutdown fires before data arrives
        batch_idx = None
        batch_data = None
        while True:
            try:
                batch_idx, trans_items = trans_data_queues[target_worker_id].get(
                    timeout=MP_QUEUE_GET_TIMEOUT
                )
                batch_data = collator.apply(trans_items)
                break
            except queue.Empty:
                if shutdown_flag.value == 1:
                    break
                logger.debug(
                    "worker:{} data queue get timeout! target batch idx:{}".format(
                        target_worker_id, target_batch_idx
                    )
                )
        if batch_idx is None:
            # shutdown requested before any data arrived; the previous code
            # raised NameError on the comparison below in this situation
            break
        if batch_idx != target_batch_idx:
            raise RuntimeError(
                "batch_idx {} mismatch the target_batch_idx {}".format(
                    batch_idx, target_batch_idx
                )
            )
        while True:
            try:
                batch_queue.put(batch_data, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")
        with target_idx.get_lock():
            target_idx.value += 1
    batch_queue.disconnect_client()
@@ -1,10 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .meta_dataset import ArrayDataset, Dataset, MapDataset, StreamDataset | |||
from .vision import * |
@@ -1,73 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABC, abstractmethod | |||
from typing import Tuple | |||
class Dataset(ABC):
    r"""Abstract base class that every dataset implementation derives from."""

    @abstractmethod
    def __init__(self):
        # concrete subclasses supply their own initialisation
        pass
class MapDataset(Dataset):
    r"""Abstract base class for map-style (random-access) datasets.

    Concrete subclasses must additionally implement ``__getitem__``
    and ``__len__``.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __getitem__(self, index):
        pass

    @abstractmethod
    def __len__(self):
        pass
class StreamDataset(Dataset):
    r"""Abstract base class for stream-style (iterable) datasets.

    Concrete subclasses must additionally implement ``__iter__``.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __iter__(self):
        pass
class ArrayDataset(MapDataset):
    r"""Dataset backed by one or more equally sized array-like objects.

    ``__getitem__`` returns the ``index``-th element of every array as a tuple.
    """

    def __init__(self, *arrays):
        r"""
        ArrayDataset is a dataset for numpy array data, one or more numpy arrays
        are needed to initiate the dataset. And the dimensions represented sample number
        are expected to be the same.

        :param arrays: one or more array-likes with identical length.
        :raises ValueError: if no array is given or lengths are inconsistent.
        """
        super().__init__()
        if len(arrays) == 0:
            # previously this fell through to arrays[0] and raised IndexError
            raise ValueError("at least one array is needed to build an ArrayDataset")
        if not all(len(arrays[0]) == len(array) for array in arrays):
            raise ValueError("lengths of input arrays are inconsistent")
        self.arrays = arrays

    def __getitem__(self, index: int) -> Tuple:
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        # all arrays share the same length by construction
        return len(self.arrays[0])
@@ -1,17 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .cifar import CIFAR10, CIFAR100 | |||
from .cityscapes import Cityscapes | |||
from .coco import COCO | |||
from .folder import ImageFolder | |||
from .imagenet import ImageNet | |||
from .meta_vision import VisionDataset | |||
from .mnist import MNIST | |||
from .objects365 import Objects365 | |||
from .voc import PascalVOC |
@@ -1,171 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os | |||
import pickle | |||
import tarfile | |||
from typing import Tuple | |||
import numpy as np | |||
from ....logger import get_logger | |||
from .meta_vision import VisionDataset | |||
from .utils import _default_dataset_root, load_raw_data_from_url | |||
logger = get_logger(__name__)  # module-level logger for this dataset module
class CIFAR10(VisionDataset):
    r"""``Dataset`` for CIFAR10 meta data.

    Optionally downloads and extracts the raw CIFAR10 archive, then serves
    ``(image, label)`` pairs; images are HWC arrays in BGR channel order.
    """

    url_path = "http://www.cs.utoronto.ca/~kriz/"
    raw_file_name = "cifar-10-python.tar.gz"
    raw_file_md5 = "c58f30108f718f92721af3b95e74349a"
    raw_file_dir = "cifar-10-batches-py"
    train_batch = [
        "data_batch_1",
        "data_batch_2",
        "data_batch_3",
        "data_batch_4",
        "data_batch_5",
    ]
    test_batch = ["test_batch"]
    meta_info = {"name": "batches.meta"}

    def __init__(
        self,
        root: str = None,
        train: bool = True,
        download: bool = True,
        timeout: int = 500,
    ):
        r"""
        :param root: dataset directory; ``None`` uses the default dataset root.
        :param train: load the train split when ``True``, else the test split.
        :param download: fetch the raw archive when it is missing.
        :param timeout: download timeout in seconds.
        """
        super().__init__(root, order=("image", "image_category"))
        self.timeout = timeout

        # process the root path
        if root is None:
            self.root = self._default_root
            if not os.path.exists(self.root):
                os.makedirs(self.root)
        else:
            self.root = root
            if not os.path.exists(self.root):
                if download:
                    logger.debug(
                        "dir %s does not exist, will be automatically created",
                        self.root,
                    )
                    os.makedirs(self.root)
                else:
                    raise ValueError("dir %s does not exist" % self.root)

        self.target_file = os.path.join(self.root, self.raw_file_dir)

        # check existence of target pickle dir, if exists load the
        # pickle file no matter what download is set
        if os.path.exists(self.target_file):
            if train:
                self.arrays = self.bytes2array(self.train_batch)
            else:
                self.arrays = self.bytes2array(self.test_batch)
        else:
            if download:
                self.download()
                if train:
                    self.arrays = self.bytes2array(self.train_batch)
                else:
                    self.arrays = self.bytes2array(self.test_batch)
            else:
                raise ValueError(
                    "dir does not contain target file %s, please set download=True"
                    % (self.target_file)
                )

    def __getitem__(self, index: int) -> Tuple:
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        return len(self.arrays[0])

    @property
    def _default_root(self):
        # <default dataset root>/<class name>, e.g. .../CIFAR10
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        # Load the pickled meta information shipped inside the raw archive.
        meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"])
        with open(meta_path, "rb") as f:
            meta = pickle.load(f, encoding="bytes")
        return meta

    def download(self):
        # Fetch the raw tarball (md5-checked) and extract it under root.
        url = self.url_path + self.raw_file_name
        load_raw_data_from_url(
            url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout
        )
        self.process()

    def untar(self, file_path, dirs):
        assert file_path.endswith(".tar.gz")
        logger.debug("untar file %s to %s", file_path, dirs)
        # use a context manager so the archive handle is always closed
        # (the previous implementation leaked the open tarfile object)
        with tarfile.open(file_path) as t:
            t.extractall(path=dirs)

    def bytes2array(self, filenames):
        # Decode pickle batches into (list of HWC BGR images, int32 labels).
        data = []
        label = []
        for filename in filenames:
            path = os.path.join(self.root, self.raw_file_dir, filename)
            logger.debug("unpickle file %s", path)
            with open(path, "rb") as fo:
                dic = pickle.load(fo, encoding="bytes")
                batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
                # channel reverse: RGB -> BGR
                data.extend(list(batch_data[..., [2, 1, 0]]))
                label.extend(dic[b"labels"])
        label = np.array(label, dtype=np.int32)
        return (data, label)

    def process(self):
        logger.info("process raw data ...")
        self.untar(os.path.join(self.root, self.raw_file_name), self.root)
class CIFAR100(CIFAR10):
    r"""``Dataset`` for CIFAR100 meta data; shares download/extract logic
    with :class:`CIFAR10` and differs only in archive layout and labels."""

    url_path = "http://www.cs.utoronto.ca/~kriz/"
    raw_file_name = "cifar-100-python.tar.gz"
    raw_file_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    raw_file_dir = "cifar-100-python"
    train_batch = ["train"]
    test_batch = ["test"]
    meta_info = {"name": "meta"}

    @property
    def meta(self):
        # Load the pickled meta information shipped inside the raw archive.
        meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"])
        with open(meta_path, "rb") as f:
            return pickle.load(f, encoding="bytes")

    def bytes2array(self, filenames):
        # Decode pickle batches into (images, fine labels, coarse labels).
        images, fine, coarse = [], [], []
        for name in filenames:
            batch_path = os.path.join(self.root, self.raw_file_dir, name)
            logger.debug("unpickle file %s", batch_path)
            with open(batch_path, "rb") as fp:
                record = pickle.load(fp, encoding="bytes")
            raw = record[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
            # channel reverse: RGB -> BGR
            images.extend(list(raw[..., [2, 1, 0]]))
            fine.extend(record[b"fine_labels"])
            coarse.extend(record[b"coarse_labels"])
        return (
            images,
            np.array(fine, dtype=np.int32),
            np.array(coarse, dtype=np.int32),
        )
@@ -1,151 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# --------------------------------------------------------------------- | |||
# Part of the following code in this file refs to torchvision | |||
# BSD 3-Clause License | |||
# | |||
# Copyright (c) Soumith Chintala 2016, | |||
# All rights reserved. | |||
# --------------------------------------------------------------------- | |||
import json | |||
import os | |||
import cv2 | |||
import numpy as np | |||
from .meta_vision import VisionDataset | |||
class Cityscapes(VisionDataset):
    r"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    Supported ``order`` entries: ``image`` (BGR), ``mask`` (HxWx1 uint8,
    19 contiguous train ids, 255 = ignore) and ``info``
    (``[height, width, image_path]``).
    """

    supported_order = (
        "image",
        "mask",
        "info",
    )

    def __init__(self, root, image_set, mode, *, order=None):
        r"""
        :param root: root directory of the extracted Cityscapes data.
        :param image_set: split directory name, e.g. ``train`` / ``val``.
        :param mode: annotation quality dir, ``gtFine`` or ``gtCoarse``.
        :param order: targets returned by ``__getitem__``; see ``supported_order``.
        """
        super().__init__(root, order=order, supported_order=self.supported_order)

        city_root = self.root
        if not os.path.isdir(city_root):
            raise RuntimeError("Dataset not found or corrupted.")

        self.mode = mode
        self.images_dir = os.path.join(city_root, "leftImg8bit", image_set)
        self.masks_dir = os.path.join(city_root, self.mode, image_set)
        self.images, self.masks = [], []
        # self.target_type = ["instance", "semantic", "polygon", "color"]

        # for semantic segmentation
        if mode == "gtFine":
            valid_modes = ("train", "test", "val")
        else:
            valid_modes = ("train", "train_extra", "val")
        # NOTE(review): valid_modes is computed but image_set is never checked
        # against it — confirm whether validation was intended here.

        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            mask_dir = os.path.join(self.masks_dir, city)
            for file_name in os.listdir(img_dir):
                mask_name = "{}_{}".format(
                    file_name.split("_leftImg8bit")[0],
                    self._get_target_suffix(self.mode, "semantic"),
                )
                self.images.append(os.path.join(img_dir, file_name))
                self.masks.append(os.path.join(mask_dir, mask_name))

    def __getitem__(self, index):
        target = []
        # cache the decoded image so the "info" branch can reuse it; also
        # fixes a NameError when "info" was requested without "image"
        image = None
        for k in self.order:
            if k == "image":
                image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
                target.append(image)
            elif k == "mask":
                mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
                mask = self._trans_mask(mask)
                mask = mask[:, :, np.newaxis]
                target.append(mask)
            elif k == "info":
                if image is None:
                    image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
                info = [image.shape[0], image.shape[1], self.images[index]]
                target.append(info)
            else:
                raise NotImplementedError
        return tuple(target)

    def __len__(self):
        return len(self.images)

    def _trans_mask(self, mask):
        # Map raw Cityscapes label ids to 19 contiguous train ids; any other
        # id becomes the ignore value 255.
        trans_labels = [
            7,
            8,
            11,
            12,
            13,
            17,
            19,
            20,
            21,
            22,
            23,
            24,
            25,
            26,
            27,
            28,
            31,
            32,
            33,
        ]
        label = np.ones(mask.shape) * 255
        for i, tl in enumerate(trans_labels):
            label[mask == tl] = i
        return label.astype(np.uint8)

    def _get_target_suffix(self, mode, target_type):
        # File-name suffix for a given annotation type.
        if target_type == "instance":
            return "{}_instanceIds.png".format(mode)
        elif target_type == "semantic":
            return "{}_labelIds.png".format(mode)
        elif target_type == "color":
            return "{}_color.png".format(mode)
        else:
            return "{}_polygons.json".format(mode)

    def _load_json(self, path):
        with open(path, "r") as file:
            data = json.load(file)
        return data

    class_names = (
        "road",
        "sidewalk",
        "building",
        "wall",
        "fence",
        "pole",
        "traffic light",
        "traffic sign",
        "vegetation",
        "terrain",
        "sky",
        "person",
        "rider",
        "car",
        "truck",
        "bus",
        "train",
        "motorcycle",
        "bicycle",
    )
@@ -1,366 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# --------------------------------------------------------------------- | |||
# Part of the following code in this file refs to maskrcnn-benchmark | |||
# MIT License | |||
# | |||
# Copyright (c) 2018 Facebook | |||
# --------------------------------------------------------------------- | |||
import json | |||
import os | |||
from collections import defaultdict | |||
import cv2 | |||
import numpy as np | |||
from .meta_vision import VisionDataset | |||
# An image must contain at least this many visible keypoints to count as
# validly annotated for keypoint tasks (see has_valid_annotation).
min_keypoints_per_image = 10
def _count_visible_keypoints(anno): | |||
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) | |||
def has_valid_annotation(anno, order):
    """Return True when ``anno`` is usable for the targets listed in ``order``."""
    # if it"s empty, there is no annotation
    if not anno:
        return False
    first = anno[0]
    needs_boxes = "boxes" in order or "boxes_category" in order
    if needs_boxes and "bbox" not in first:
        return False
    if "keypoints" in order:
        if "keypoints" not in first:
            return False
        # for keypoint detection tasks, only consider valid images those
        # containing at least min_keypoints_per_image
        if _count_visible_keypoints(anno) < min_keypoints_per_image:
            return False
    return True
class COCO(VisionDataset):
    r"""`MS COCO <http://cocodataset.org/#home>`_ Dataset.

    Parses a COCO-format JSON annotation file and serves, per image, the
    targets listed in ``order`` (see ``supported_order``).
    """
    supported_order = (
        "image",
        "boxes",
        "boxes_category",
        "keypoints",
        # TODO: need to check
        # "polygons",
        "info",
    )
    def __init__(
        self, root, ann_file, remove_images_without_annotations=False, *, order=None
    ):
        # root: directory containing the image files.
        # ann_file: path to the COCO-format JSON annotation file.
        # remove_images_without_annotations: drop images whose non-crowd,
        #   positive-area annotations are unusable for the requested order.
        super().__init__(root, order=order, supported_order=self.supported_order)
        with open(ann_file, "r") as f:
            dataset = json.load(f)
        # image id -> image record (stripped of unused fields)
        self.imgs = dict()
        for img in dataset["images"]:
            # for saving memory
            if "license" in img:
                del img["license"]
            if "coco_url" in img:
                del img["coco_url"]
            if "date_captured" in img:
                del img["date_captured"]
            if "flickr_url" in img:
                del img["flickr_url"]
            self.imgs[img["id"]] = img
        # image id -> list of its annotation records
        self.img_to_anns = defaultdict(list)
        for ann in dataset["annotations"]:
            # for saving memory
            if (
                "boxes" not in self.order
                and "boxes_category" not in self.order
                and "bbox" in ann
            ):
                del ann["bbox"]
            if "polygons" not in self.order and "segmentation" in ann:
                del ann["segmentation"]
            self.img_to_anns[ann["image_id"]].append(ann)
        # category id -> category record
        self.cats = dict()
        for cat in dataset["categories"]:
            self.cats[cat["id"]] = cat
        self.ids = list(sorted(self.imgs.keys()))
        # filter images without detection annotations
        if remove_images_without_annotations:
            ids = []
            for img_id in self.ids:
                anno = self.img_to_anns[img_id]
                # filter crowd annotations
                anno = [obj for obj in anno if obj["iscrowd"] == 0]
                # drop degenerate (zero width/height) boxes
                anno = [
                    obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0
                ]
                # NOTE(review): this passes the raw `order` argument (possibly
                # None) rather than `self.order` — confirm this is intentional.
                if has_valid_annotation(anno, order):
                    ids.append(img_id)
                    self.img_to_anns[img_id] = anno
                else:
                    del self.imgs[img_id]
                    del self.img_to_anns[img_id]
            self.ids = ids
        # remap sparse JSON category ids to contiguous ids 1..N, and back
        self.json_category_id_to_contiguous_id = {
            v: i + 1 for i, v in enumerate(sorted(self.cats.keys()))
        }
        self.contiguous_category_id_to_json_id = {
            v: k for k, v in self.json_category_id_to_contiguous_id.items()
        }
    def __getitem__(self, index):
        # Build the target tuple in the exact order requested by self.order.
        img_id = self.ids[index]
        anno = self.img_to_anns[img_id]
        target = []
        for k in self.order:
            if k == "image":
                file_name = self.imgs[img_id]["file_name"]
                path = os.path.join(self.root, file_name)
                image = cv2.imread(path, cv2.IMREAD_COLOR)
                target.append(image)
            elif k == "boxes":
                boxes = [obj["bbox"] for obj in anno]
                boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
                # transfer boxes from xywh to xyxy
                boxes[:, 2:] += boxes[:, :2]
                target.append(boxes)
            elif k == "boxes_category":
                boxes_category = [obj["category_id"] for obj in anno]
                boxes_category = [
                    self.json_category_id_to_contiguous_id[c] for c in boxes_category
                ]
                boxes_category = np.array(boxes_category, dtype=np.int32)
                target.append(boxes_category)
            elif k == "keypoints":
                # (num_instances, num_keypoints, 3) where the last axis is (x, y, v)
                keypoints = [obj["keypoints"] for obj in anno]
                keypoints = np.array(keypoints, dtype=np.float32).reshape(
                    -1, len(self.keypoint_names), 3
                )
                target.append(keypoints)
            elif k == "polygons":
                polygons = [obj["segmentation"] for obj in anno]
                polygons = [
                    [np.array(p, dtype=np.float32).reshape(-1, 2) for p in ps]
                    for ps in polygons
                ]
                target.append(polygons)
            elif k == "info":
                info = self.imgs[img_id]
                info = [info["height"], info["width"], info["file_name"]]
                target.append(info)
            else:
                raise NotImplementedError
        return tuple(target)
    def __len__(self):
        return len(self.ids)
    def get_img_info(self, index):
        # Return the raw image record (height/width/file_name etc.) by index.
        img_id = self.ids[index]
        img_info = self.imgs[img_id]
        return img_info
    # Names of the 80 COCO object categories, in contiguous-id order.
    class_names = (
        "person",
        "bicycle",
        "car",
        "motorcycle",
        "airplane",
        "bus",
        "train",
        "truck",
        "boat",
        "traffic light",
        "fire hydrant",
        "stop sign",
        "parking meter",
        "bench",
        "bird",
        "cat",
        "dog",
        "horse",
        "sheep",
        "cow",
        "elephant",
        "bear",
        "zebra",
        "giraffe",
        "backpack",
        "umbrella",
        "handbag",
        "tie",
        "suitcase",
        "frisbee",
        "skis",
        "snowboard",
        "sports ball",
        "kite",
        "baseball bat",
        "baseball glove",
        "skateboard",
        "surfboard",
        "tennis racket",
        "bottle",
        "wine glass",
        "cup",
        "fork",
        "knife",
        "spoon",
        "bowl",
        "banana",
        "apple",
        "sandwich",
        "orange",
        "broccoli",
        "carrot",
        "hot dog",
        "pizza",
        "donut",
        "cake",
        "chair",
        "couch",
        "potted plant",
        "bed",
        "dining table",
        "toilet",
        "tv",
        "laptop",
        "mouse",
        "remote",
        "keyboard",
        "cell phone",
        "microwave",
        "oven",
        "toaster",
        "sink",
        "refrigerator",
        "book",
        "clock",
        "vase",
        "scissors",
        "teddy bear",
        "hair drier",
        "toothbrush",
    )
    # Mapping from class name to the original (sparse) COCO JSON category id.
    classes_originID = {
        "person": 1,
        "bicycle": 2,
        "car": 3,
        "motorcycle": 4,
        "airplane": 5,
        "bus": 6,
        "train": 7,
        "truck": 8,
        "boat": 9,
        "traffic light": 10,
        "fire hydrant": 11,
        "stop sign": 13,
        "parking meter": 14,
        "bench": 15,
        "bird": 16,
        "cat": 17,
        "dog": 18,
        "horse": 19,
        "sheep": 20,
        "cow": 21,
        "elephant": 22,
        "bear": 23,
        "zebra": 24,
        "giraffe": 25,
        "backpack": 27,
        "umbrella": 28,
        "handbag": 31,
        "tie": 32,
        "suitcase": 33,
        "frisbee": 34,
        "skis": 35,
        "snowboard": 36,
        "sports ball": 37,
        "kite": 38,
        "baseball bat": 39,
        "baseball glove": 40,
        "skateboard": 41,
        "surfboard": 42,
        "tennis racket": 43,
        "bottle": 44,
        "wine glass": 46,
        "cup": 47,
        "fork": 48,
        "knife": 49,
        "spoon": 50,
        "bowl": 51,
        "banana": 52,
        "apple": 53,
        "sandwich": 54,
        "orange": 55,
        "broccoli": 56,
        "carrot": 57,
        "hot dog": 58,
        "pizza": 59,
        "donut": 60,
        "cake": 61,
        "chair": 62,
        "couch": 63,
        "potted plant": 64,
        "bed": 65,
        "dining table": 67,
        "toilet": 70,
        "tv": 72,
        "laptop": 73,
        "mouse": 74,
        "remote": 75,
        "keyboard": 76,
        "cell phone": 77,
        "microwave": 78,
        "oven": 79,
        "toaster": 80,
        "sink": 81,
        "refrigerator": 82,
        "book": 84,
        "clock": 85,
        "vase": 86,
        "scissors": 87,
        "teddy bear": 88,
        "hair drier": 89,
        "toothbrush": 90,
    }
    # Names of the 17 COCO person keypoints, in annotation order.
    keypoint_names = (
        "nose",
        "left_eye",
        "right_eye",
        "left_ear",
        "right_ear",
        "left_shoulder",
        "right_shoulder",
        "left_elbow",
        "right_elbow",
        "left_wrist",
        "right_wrist",
        "left_hip",
        "right_hip",
        "left_knee",
        "right_knee",
        "left_ankle",
        "right_ankle",
    )
@@ -1,90 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# BSD 3-Clause License | |||
# Copyright (c) Soumith Chintala 2016, | |||
# All rights reserved. | |||
# --------------------------------------------------------------------- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# | |||
# This file has been modified by Megvii ("Megvii Modifications"). | |||
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
# --------------------------------------------------------------------- | |||
import os | |||
from typing import Dict, List, Tuple | |||
import cv2 | |||
import numpy as np | |||
from .meta_vision import VisionDataset | |||
from .utils import is_img | |||
class ImageFolder(VisionDataset):
    r"""Dataset that reads images and labels from a class-per-subfolder layout.

    The folder is expected to be organized as ``root/cls/xxx.img_ext``;
    labels are the indices of the alphabetically sorted class directories
    (or the class names themselves when ``class_name=True``).
    """

    def __init__(self, root: str, check_valid_func=None, class_name: bool = False):
        r"""
        :param root: root directory of an image folder.
        :param check_valid_func: a function used to check if files in folder are
            expected image files; if ``None``, a default extension-based check
            (``is_img``) is used.
        :param class_name: if ``True``, return class name instead of class index.
        """
        super().__init__(root, order=("image", "image_category"))
        self.root = root
        # fall back to the default extension check when no predicate is given
        self.check_valid = check_valid_func if check_valid_func is not None else is_img
        self.class_name = class_name
        self.class_dict = self.collect_class()
        self.samples = self.collect_samples()

    def collect_samples(self) -> List:
        # Walk every class directory and gather (path, label) pairs.
        samples = []
        base = os.path.expanduser(self.root)
        for cls in sorted(self.class_dict.keys()):
            cls_dir = os.path.join(base, cls)
            if not os.path.isdir(cls_dir):
                continue
            for cur_dir, _, files in sorted(os.walk(cls_dir, followlinks=True)):
                for fname in sorted(files):
                    fpath = os.path.join(cur_dir, fname)
                    if not self.check_valid(fpath):
                        continue
                    label = cls if self.class_name else self.class_dict[cls]
                    samples.append((fpath, label))
        return samples

    def collect_class(self) -> Dict:
        # Map each immediate subdirectory name to a stable int32 index.
        names = sorted(d.name for d in os.scandir(self.root) if d.is_dir())
        return {name: np.int32(idx) for idx, name in enumerate(names)}

    def __getitem__(self, index: int) -> Tuple:
        path, label = self.samples[index]
        return cv2.imread(path, cv2.IMREAD_COLOR), label

    def __len__(self):
        return len(self.samples)
@@ -1,248 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# BSD 3-Clause License | |||
# | |||
# Copyright (c) Soumith Chintala 2016, | |||
# All rights reserved. | |||
# --------------------------------------------------------------------- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# | |||
# This file has been modified by Megvii ("Megvii Modifications"). | |||
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
# --------------------------------------------------------------------- | |||
import os | |||
import shutil | |||
from tqdm import tqdm | |||
from ....core.serialization import load, save | |||
from ....distributed.util import is_distributed | |||
from ....logger import get_logger | |||
from .folder import ImageFolder | |||
from .utils import _default_dataset_root, calculate_md5, untar, untargz | |||
logger = get_logger(__name__)  # module-level logger for this dataset module
class ImageNet(ImageFolder): | |||
r""" | |||
Load ImageNet from raw files or folder, expected folder looks like | |||
.. code-block:: bash | |||
${root}/ | |||
| [REQUIRED TAR FILES] | |||
|- ILSVRC2012_img_train.tar | |||
|- ILSVRC2012_img_val.tar | |||
|- ILSVRC2012_devkit_t12.tar.gz | |||
| [OPTIONAL IMAGE FOLDERS] | |||
|- train/cls/xxx.${img_ext} | |||
|- val/cls/xxx.${img_ext} | |||
|- ILSVRC2012_devkit_t12/data/meta.mat | |||
|- ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt | |||
If the image folders don't exist, raw tar files are required to get extracted and processed. | |||
""" | |||
raw_file_meta = { | |||
"train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"), | |||
"val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"), | |||
"devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"), | |||
} # ImageNet raw files | |||
default_train_dir = "train" | |||
default_val_dir = "val" | |||
default_devkit_dir = "ILSVRC2012_devkit_t12" | |||
def __init__(self, root: str = None, train: bool = True, **kwargs): | |||
r""" | |||
initialization: | |||
* if ``root`` contains ``self.target_folder`` depent on ``train``: | |||
* initialize ImageFolder with target_folder | |||
* else: | |||
* if all raw files are in ``root``: | |||
* parse ``self.target_folder`` from raw files | |||
* initialize ImageFolder with ``self.target_folder`` | |||
* else: | |||
* raise error | |||
:param root: root directory of imagenet data, if root is ``None``, used default_dataset_root | |||
:param train: if ``True``, load the train split, otherwise load the validation split | |||
""" | |||
# process the root path | |||
if root is None: | |||
self.root = self._default_root | |||
else: | |||
self.root = root | |||
if not os.path.exists(self.root): | |||
raise FileNotFoundError("dir %s does not exist" % self.root) | |||
self.devkit_dir = os.path.join(self.root, self.default_devkit_dir) | |||
if not os.path.exists(self.devkit_dir): | |||
logger.warning("devkit directory %s does not exists", self.devkit_dir) | |||
self._prepare_devkit() | |||
self.train = train | |||
if train: | |||
self.target_folder = os.path.join(self.root, self.default_train_dir) | |||
else: | |||
self.target_folder = os.path.join(self.root, self.default_val_dir) | |||
if not os.path.exists(self.target_folder): | |||
logger.warning( | |||
"expected image folder %s does not exist, try to load from raw file", | |||
self.target_folder, | |||
) | |||
if not self.check_raw_file(): | |||
raise FileNotFoundError( | |||
"expected image folder %s does not exist, and raw files do not exist in %s" | |||
% (self.target_folder, self.root) | |||
) | |||
elif is_distributed(): | |||
raise RuntimeError( | |||
"extracting raw file shouldn't be done in distributed mode, use single process instead" | |||
) | |||
elif train: | |||
self._prepare_train() | |||
else: | |||
self._prepare_val() | |||
super().__init__(self.target_folder, **kwargs) | |||
@property | |||
def _default_root(self): | |||
return os.path.join(_default_dataset_root(), self.__class__.__name__) | |||
@property | |||
def valid_ground_truth(self): | |||
groud_truth_path = os.path.join( | |||
self.devkit_dir, "data", "ILSVRC2012_validation_ground_truth.txt" | |||
) | |||
if os.path.exists(groud_truth_path): | |||
with open(groud_truth_path, "r") as f: | |||
val_labels = f.readlines() | |||
return [int(val_label) for val_label in val_labels] | |||
else: | |||
raise FileNotFoundError( | |||
"valid ground truth file %s does not exist" % groud_truth_path | |||
) | |||
@property
def meta(self):
    # Devkit metadata as a tuple (idx_to_wnid, wnid_to_classes).
    # A cached pickle ("meta.pkl") under devkit_dir is preferred; on a cache
    # miss the devkit's MATLAB "meta.mat" is parsed and the result is cached
    # for subsequent runs.
    try:
        return load(os.path.join(self.devkit_dir, "meta.pkl"))
    except FileNotFoundError:
        import scipy.io

        meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
        if not os.path.exists(meta_path):
            raise FileNotFoundError("meta file %s does not exist" % meta_path)
        meta = scipy.io.loadmat(meta_path, squeeze_me=True)["synsets"]
        # column 4 of the synsets records is presumably num_children;
        # keeping only records with 0 children selects the leaf synsets
        # (the actual classes) -- TODO confirm against devkit docs
        nums_children = list(zip(*meta))[4]
        meta = [
            meta[idx]
            for idx, num_children in enumerate(nums_children)
            if num_children == 0
        ]
        idcs, wnids, classes = list(zip(*meta))[:3]
        # each class entry is a comma-separated list of synonym names
        classes = [tuple(clss.split(", ")) for clss in classes]
        idx_to_wnid = dict(zip(idcs, wnids))
        wnid_to_classes = dict(zip(wnids, classes))
        logger.info(
            "saving cached meta file to %s",
            os.path.join(self.devkit_dir, "meta.pkl"),
        )
        save(
            (idx_to_wnid, wnid_to_classes),
            os.path.join(self.devkit_dir, "meta.pkl"),
        )
        return idx_to_wnid, wnid_to_classes
def check_raw_file(self) -> bool:
    # True when every raw archive listed in ``raw_file_meta`` (values are
    # (filename, md5) pairs) is present under ``self.root``.
    for filename, _ in self.raw_file_meta.values():
        if not os.path.exists(os.path.join(self.root, filename)):
            return False
    return True
def _organize_val_data(self):
    # Move the flat validation images into one subdirectory per wnid so the
    # validation set gets the same folder layout as the training set.
    id2wnid = self.meta[0]
    val_idcs = self.valid_ground_truth
    val_wnids = [id2wnid[idx] for idx in val_idcs]

    # sorted() aligns the i-th file with the i-th ground-truth label;
    # assumes validation files sort in label order (ILSVRC2012_val_%08d
    # naming) -- TODO confirm. os.listdir order alone is arbitrary.
    val_images = sorted(
        [
            os.path.join(self.target_folder, image)
            for image in os.listdir(self.target_folder)
        ]
    )

    logger.debug("mkdir for val set wnids")
    for wnid in set(val_wnids):
        os.makedirs(os.path.join(self.root, self.default_val_dir, wnid))

    logger.debug("mv val images into wnids dir")
    for wnid, img_file in tqdm(zip(val_wnids, val_images)):
        shutil.move(
            img_file,
            os.path.join(
                self.root, self.default_val_dir, wnid, os.path.basename(img_file)
            ),
        )
def _prepare_val(self):
    """Verify and extract the validation tarball, then reorganize it.

    Checks the archive's md5 against ``raw_file_meta["val"]``, extracts it
    into ``self.target_folder`` and moves images into per-wnid folders.
    """
    assert not self.train
    raw_filename, checksum = self.raw_file_meta["val"]
    raw_file = os.path.join(self.root, raw_filename)
    logger.info("checksum valid tar file %s ...", raw_file)
    assert (
        calculate_md5(raw_file) == checksum
    ), "checksum mismatch, {} may be damaged".format(raw_file)
    logger.info("extract valid tar file... this may take 10-20 minutes")
    # ``raw_file`` is already rooted at self.root; the previous
    # os.path.join(self.root, raw_file) duplicated the prefix whenever
    # self.root was a relative path.
    untar(raw_file, self.target_folder)
    self._organize_val_data()
def _prepare_train(self):
    """Verify and extract the training tarball, then unpack per-class tars.

    The training archive contains one inner tar per wnid; each is unpacked
    into a directory named after the wnid and then deleted.
    """
    assert self.train
    raw_filename, checksum = self.raw_file_meta["train"]
    raw_file = os.path.join(self.root, raw_filename)
    logger.info("checksum train tar file %s ...", raw_file)
    assert (
        calculate_md5(raw_file) == checksum
    ), "checksum mismatch, {} may be damaged".format(raw_file)
    logger.info("extract train tar file.. this may take several hours")
    # ``raw_file`` already contains self.root; the previous
    # os.path.join(self.root, raw_file) duplicated the prefix whenever
    # self.root was a relative path.
    untar(raw_file, self.target_folder)
    paths = [
        os.path.join(self.target_folder, child_dir)
        for child_dir in os.listdir(self.target_folder)
    ]
    for path in tqdm(paths):
        untar(path, os.path.splitext(path)[0], remove=True)
def _prepare_devkit(self):
    """Verify the devkit archive's md5 and extract it next to itself."""
    raw_filename, checksum = self.raw_file_meta["devkit"]
    raw_file = os.path.join(self.root, raw_filename)
    logger.info("checksum devkit tar file %s ...", raw_file)
    assert (
        calculate_md5(raw_file) == checksum
    ), "checksum mismatch, {} may be damaged".format(raw_file)
    logger.info("extract devkit file..")
    # reuse the path computed above instead of re-deriving it from
    # raw_file_meta (same value, consistent with _prepare_train/_prepare_val)
    untargz(raw_file)
@@ -1,41 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc | |||
import os | |||
from ..meta_dataset import MapDataset | |||
class VisionDataset(MapDataset):
    """Base class for vision datasets.

    Normalizes ``root`` (user-dir expansion), validates the requested data
    ``order`` and, when ``supported_order`` is given, rejects any entry of
    ``order`` that the concrete dataset cannot produce.
    """

    _repr_indent = 4

    def __init__(self, root, *, order=None, supported_order=None):
        if isinstance(root, (str, bytes)):
            root = os.path.expanduser(root)
        self.root = root

        # default: return images only
        if order is None:
            order = ("image",)
        if not isinstance(order, collections.abc.Sequence):
            raise ValueError(
                "order should be a sequence, but got order={}".format(order)
            )

        if supported_order is not None:
            assert isinstance(supported_order, collections.abc.Sequence)
            unsupported = [k for k in order if k not in supported_order]
            if unsupported:
                raise NotImplementedError(
                    "{} is unsupported data type".format(unsupported[0])
                )

        self.order = order

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
@@ -1,197 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gzip | |||
import os | |||
import struct | |||
from typing import Tuple | |||
import numpy as np | |||
from tqdm import tqdm | |||
from ....logger import get_logger | |||
from .meta_vision import VisionDataset | |||
from .utils import _default_dataset_root, load_raw_data_from_url | |||
logger = get_logger(__name__) | |||
class MNIST(VisionDataset):
    r""" ``Dataset`` for MNIST meta data
    """

    url_path = "http://yann.lecun.com/exdb/mnist/"
    """
    url prefix for downloading raw file
    """
    raw_file_name = [
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    ]
    """
    raw file names of both training set and test set (10k)
    """
    raw_file_md5 = [
        "f68b3c2dcbeaaa9fbdd348bbdeb94873",
        "d53e105ee54ea40749a09fcbcd1e9432",
        "9fb629c4189551a2d022fa330f9573f3",
        "ec29112dd5afa0611ce80d1b7f02629c",
    ]
    """
    md5 for checking raw files
    """

    def __init__(
        self,
        root: str = None,
        train: bool = True,
        download: bool = True,
        timeout: int = 500,
    ):
        r"""
        :param root: path for mnist dataset downloading or loading, if ``None``,
            set ``root`` to the ``_default_root``
        :param train: if ``True``, loading trainingset, else loading test set
        :param download: if raw files do not exists and download sets to ``True``,
            download raw files and process, otherwise raise ValueError, default is True
        :param timeout: HTTP read timeout (seconds) used when downloading
        """
        super().__init__(root, order=("image", "image_category"))
        self.timeout = timeout

        # process the root path
        if root is None:
            # default root is created on demand
            self.root = self._default_root
            if not os.path.exists(self.root):
                os.makedirs(self.root)
        else:
            self.root = root
            if not os.path.exists(self.root):
                if download:
                    logger.debug(
                        "dir %s does not exist, will be automatically created",
                        self.root,
                    )
                    os.makedirs(self.root)
                else:
                    raise ValueError("dir %s does not exist" % self.root)

        # use cached raw files when present; otherwise download (if allowed)
        if self._check_raw_files():
            self.process(train)
        elif download:
            self.download()
            self.process(train)
        else:
            raise ValueError(
                "root does not contain valid raw files, please set download=True"
            )

    def __getitem__(self, index: int) -> Tuple:
        # each sample is (image, label); self.arrays holds the parallel sequences
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        return len(self.arrays[0])

    @property
    def _default_root(self):
        # <dataset-root>/MNIST
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        # idx-file header metadata populated by process()
        return self._meta_data

    def _check_raw_files(self):
        # True when all four gzipped idx files already exist under self.root
        return all(
            [
                os.path.exists(os.path.join(self.root, path))
                for path in self.raw_file_name
            ]
        )

    def download(self):
        # fetch each raw file (cached copies are reused) and verify its md5
        for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
            url = self.url_path + file_name
            load_raw_data_from_url(url, file_name, md5, self.root, self.timeout)

    def process(self, train):
        # load raw files and transform them into meta data and datasets Tuple(np.array)
        logger.info("process the raw files of %s set...", "train" if train else "test")
        if train:
            meta_data_images, images = parse_idx3(
                os.path.join(self.root, self.raw_file_name[0])
            )
            meta_data_labels, labels = parse_idx1(
                os.path.join(self.root, self.raw_file_name[1])
            )
        else:
            meta_data_images, images = parse_idx3(
                os.path.join(self.root, self.raw_file_name[2])
            )
            meta_data_labels, labels = parse_idx1(
                os.path.join(self.root, self.raw_file_name[3])
            )

        self._meta_data = {
            "images": meta_data_images,
            "labels": meta_data_labels,
        }
        # labels normalized to int32
        self.arrays = (images, labels.astype(np.int32))
def parse_idx3(idx3_file):
    """Parse a gzipped idx3 (image) file.

    Returns ``(meta_data, images)`` where ``meta_data`` holds the header
    fields (magic, imgs, height, width) and ``images`` is a list of uint8
    arrays of shape (height, width, 1).
    """
    logger.debug("parse idx3 file %s ...", idx3_file)
    assert idx3_file.endswith(".gz")
    with gzip.open(idx3_file, "rb") as f:
        raw = f.read()

    # 16-byte big-endian header: magic, image count, height, width
    header_fmt = ">iiii"
    magic, imgs, height, width = struct.unpack_from(header_fmt, raw, 0)
    meta_data = {"magic": magic, "imgs": imgs, "height": height, "width": width}

    # body: one record of height*width unsigned bytes per image
    pixel_fmt = ">" + str(height * width) + "B"
    body = raw[struct.calcsize(header_fmt):]
    images = []
    progress = tqdm(total=imgs, ncols=80)
    for pixels in struct.iter_unpack(pixel_fmt, body):
        images.append(np.array(pixels, dtype=np.uint8).reshape((height, width, 1)))
        progress.update()
    progress.close()
    return meta_data, images
def parse_idx1(idx1_file):
    """Parse a gzipped idx1 (label) file.

    Returns ``(meta_data, labels)`` where ``meta_data`` holds the header
    fields (magic, imgs) and ``labels`` is a 1-D integer array of length imgs.
    """
    logger.debug("parse idx1 file %s ...", idx1_file)
    assert idx1_file.endswith(".gz")
    with gzip.open(idx1_file, "rb") as f:
        raw = f.read()

    # 8-byte big-endian header: magic, label count
    header_fmt = ">ii"
    magic, imgs = struct.unpack_from(header_fmt, raw, 0)
    meta_data = {"magic": magic, "imgs": imgs}

    # body: one unsigned byte per label
    body = raw[struct.calcsize(header_fmt):]
    labels = np.empty(imgs, dtype=int)
    progress = tqdm(total=imgs, ncols=80)
    for i, (value,) in enumerate(struct.iter_unpack(">B", body)):
        labels[i] = value
        progress.update()
    progress.close()
    return meta_data, labels
@@ -1,498 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# --------------------------------------------------------------------- | |||
# Part of the following code in this file refs to maskrcnn-benchmark | |||
# MIT License | |||
# | |||
# Copyright (c) 2018 Facebook | |||
# --------------------------------------------------------------------- | |||
import json | |||
import os | |||
from collections import defaultdict | |||
import cv2 | |||
import numpy as np | |||
from .meta_vision import VisionDataset | |||
class Objects365(VisionDataset):
    r"""`Objects365 <https://www.objects365.org/overview.html>`_ Dataset.
    """

    # data fields __getitem__ can produce (COCO-style annotation terms)
    supported_order = (
        "image",
        "boxes",
        "boxes_category",
        "info",
    )
def __init__(
    self, root, ann_file, remove_images_without_annotations=False, *, order=None
):
    """
    :param root: directory containing the images.
    :param ann_file: path of the COCO-format json annotation file.
    :param remove_images_without_annotations: drop images that have no
        usable annotation (non-crowd boxes with positive width/height).
    :param order: data fields to return, subset of ``supported_order``.
    """
    super().__init__(root, order=order, supported_order=self.supported_order)

    with open(ann_file, "r") as f:
        dataset = json.load(f)

    # image id -> raw "images" record
    self.imgs = dict()
    for img in dataset["images"]:
        self.imgs[img["id"]] = img

    # image id -> list of its annotation records
    self.img_to_anns = defaultdict(list)
    for ann in dataset["annotations"]:
        # for saving memory
        if (
            "boxes" not in self.order
            and "boxes_category" not in self.order
            and "bbox" in ann
        ):
            del ann["bbox"]
        self.img_to_anns[ann["image_id"]].append(ann)

    # category id -> raw "categories" record
    self.cats = dict()
    for cat in dataset["categories"]:
        self.cats[cat["id"]] = cat

    self.ids = list(sorted(self.imgs.keys()))

    # filter images without detection annotations
    if remove_images_without_annotations:
        ids = []
        for img_id in self.ids:
            anno = self.img_to_anns[img_id]
            # filter crowd annotations
            anno = [obj for obj in anno if obj["iscrowd"] == 0]
            # drop degenerate boxes (non-positive width/height)
            anno = [
                obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0
            ]
            if len(anno) > 0:
                ids.append(img_id)
                self.img_to_anns[img_id] = anno
            else:
                del self.imgs[img_id]
                del self.img_to_anns[img_id]
        self.ids = ids

    # sparse json category ids -> contiguous ids starting at 1
    # (presumably leaving 0 for background -- confirm against consumers)
    self.json_category_id_to_contiguous_id = {
        v: i + 1 for i, v in enumerate(sorted(self.cats.keys()))
    }
    self.contiguous_category_id_to_json_id = {
        v: k for k, v in self.json_category_id_to_contiguous_id.items()
    }
def __getitem__(self, index):
    # Return the requested fields for the index-th image, in the sequence
    # given by ``self.order``.
    img_id = self.ids[index]
    anno = self.img_to_anns[img_id]

    target = []
    for k in self.order:
        if k == "image":
            file_name = self.imgs[img_id]["file_name"]
            path = os.path.join(self.root, file_name)
            # cv2.imread returns a BGR HWC uint8 array (None if unreadable)
            image = cv2.imread(path, cv2.IMREAD_COLOR)
            target.append(image)
        elif k == "boxes":
            boxes = [obj["bbox"] for obj in anno]
            # reshape(-1, 4) keeps shape (0, 4) for images without boxes
            boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
            # transfer boxes from xywh to xyxy
            boxes[:, 2:] += boxes[:, :2]
            target.append(boxes)
        elif k == "boxes_category":
            boxes_category = [obj["category_id"] for obj in anno]
            # map sparse json ids to the contiguous ids built in __init__
            boxes_category = [
                self.json_category_id_to_contiguous_id[c] for c in boxes_category
            ]
            boxes_category = np.array(boxes_category, dtype=np.int32)
            target.append(boxes_category)
        elif k == "info":
            info = self.imgs[img_id]
            info = [info["height"], info["width"], info["file_name"]]
            target.append(info)
        else:
            raise NotImplementedError
    return tuple(target)
def __len__(self): | |||
return len(self.ids) | |||
def get_img_info(self, index): | |||
img_id = self.ids[index] | |||
img_info = self.imgs[img_id] | |||
return img_info | |||
# The 365 Objects365 category names. Presumably indexed so that
# class_names[i] corresponds to contiguous category id i + 1 -- confirm
# against the annotation file's sorted category ids. Note some entries
# contain "/" variants and "tissue " keeps its trailing space verbatim.
class_names = (
    "person",
    "sneakers",
    "chair",
    "hat",
    "lamp",
    "bottle",
    "cabinet/shelf",
    "cup",
    "car",
    "glasses",
    "picture/frame",
    "desk",
    "handbag",
    "street lights",
    "book",
    "plate",
    "helmet",
    "leather shoes",
    "pillow",
    "glove",
    "potted plant",
    "bracelet",
    "flower",
    "tv",
    "storage box",
    "vase",
    "bench",
    "wine glass",
    "boots",
    "bowl",
    "dining table",
    "umbrella",
    "boat",
    "flag",
    "speaker",
    "trash bin/can",
    "stool",
    "backpack",
    "couch",
    "belt",
    "carpet",
    "basket",
    "towel/napkin",
    "slippers",
    "barrel/bucket",
    "coffee table",
    "suv",
    "toy",
    "tie",
    "bed",
    "traffic light",
    "pen/pencil",
    "microphone",
    "sandals",
    "canned",
    "necklace",
    "mirror",
    "faucet",
    "bicycle",
    "bread",
    "high heels",
    "ring",
    "van",
    "watch",
    "sink",
    "horse",
    "fish",
    "apple",
    "camera",
    "candle",
    "teddy bear",
    "cake",
    "motorcycle",
    "wild bird",
    "laptop",
    "knife",
    "traffic sign",
    "cell phone",
    "paddle",
    "truck",
    "cow",
    "power outlet",
    "clock",
    "drum",
    "fork",
    "bus",
    "hanger",
    "nightstand",
    "pot/pan",
    "sheep",
    "guitar",
    "traffic cone",
    "tea pot",
    "keyboard",
    "tripod",
    "hockey",
    "fan",
    "dog",
    "spoon",
    "blackboard/whiteboard",
    "balloon",
    "air conditioner",
    "cymbal",
    "mouse",
    "telephone",
    "pickup truck",
    "orange",
    "banana",
    "airplane",
    "luggage",
    "skis",
    "soccer",
    "trolley",
    "oven",
    "remote",
    "baseball glove",
    "paper towel",
    "refrigerator",
    "train",
    "tomato",
    "machinery vehicle",
    "tent",
    "shampoo/shower gel",
    "head phone",
    "lantern",
    "donut",
    "cleaning products",
    "sailboat",
    "tangerine",
    "pizza",
    "kite",
    "computer box",
    "elephant",
    "toiletries",
    "gas stove",
    "broccoli",
    "toilet",
    "stroller",
    "shovel",
    "baseball bat",
    "microwave",
    "skateboard",
    "surfboard",
    "surveillance camera",
    "gun",
    "life saver",
    "cat",
    "lemon",
    "liquid soap",
    "zebra",
    "duck",
    "sports car",
    "giraffe",
    "pumpkin",
    "piano",
    "stop sign",
    "radiator",
    "converter",
    "tissue ",
    "carrot",
    "washing machine",
    "vent",
    "cookies",
    "cutting/chopping board",
    "tennis racket",
    "candy",
    "skating and skiing shoes",
    "scissors",
    "folder",
    "baseball",
    "strawberry",
    "bow tie",
    "pigeon",
    "pepper",
    "coffee machine",
    "bathtub",
    "snowboard",
    "suitcase",
    "grapes",
    "ladder",
    "pear",
    "american football",
    "basketball",
    "potato",
    "paint brush",
    "printer",
    "billiards",
    "fire hydrant",
    "goose",
    "projector",
    "sausage",
    "fire extinguisher",
    "extension cord",
    "facial mask",
    "tennis ball",
    "chopsticks",
    "electronic stove and gas stove",
    "pie",
    "frisbee",
    "kettle",
    "hamburger",
    "golf club",
    "cucumber",
    "clutch",
    "blender",
    "tong",
    "slide",
    "hot dog",
    "toothbrush",
    "facial cleanser",
    "mango",
    "deer",
    "egg",
    "violin",
    "marker",
    "ship",
    "chicken",
    "onion",
    "ice cream",
    "tape",
    "wheelchair",
    "plum",
    "bar soap",
    "scale",
    "watermelon",
    "cabbage",
    "router/modem",
    "golf ball",
    "pine apple",
    "crane",
    "fire truck",
    "peach",
    "cello",
    "notepaper",
    "tricycle",
    "toaster",
    "helicopter",
    "green beans",
    "brush",
    "carriage",
    "cigar",
    "earphone",
    "penguin",
    "hurdle",
    "swing",
    "radio",
    "CD",
    "parking meter",
    "swan",
    "garlic",
    "french fries",
    "horn",
    "avocado",
    "saxophone",
    "trumpet",
    "sandwich",
    "cue",
    "kiwi fruit",
    "bear",
    "fishing rod",
    "cherry",
    "tablet",
    "green vegetables",
    "nuts",
    "corn",
    "key",
    "screwdriver",
    "globe",
    "broom",
    "pliers",
    "volleyball",
    "hammer",
    "eggplant",
    "trophy",
    "dates",
    "board eraser",
    "rice",
    "tape measure/ruler",
    "dumbbell",
    "hamimelon",
    "stapler",
    "camel",
    "lettuce",
    "goldfish",
    "meat balls",
    "medal",
    "toothpaste",
    "antelope",
    "shrimp",
    "rickshaw",
    "trombone",
    "pomegranate",
    "coconut",
    "jellyfish",
    "mushroom",
    "calculator",
    "treadmill",
    "butterfly",
    "egg tart",
    "cheese",
    "pig",
    "pomelo",
    "race car",
    "rice cooker",
    "tuba",
    "crosswalk sign",
    "papaya",
    "hair drier",
    "green onion",
    "chips",
    "dolphin",
    "sushi",
    "urinal",
    "donkey",
    "electric drill",
    "spring rolls",
    "tortoise/turtle",
    "parrot",
    "flute",
    "measuring cup",
    "shark",
    "steak",
    "poker card",
    "binoculars",
    "llama",
    "radish",
    "noodles",
    "yak",
    "mop",
    "crab",
    "microscope",
    "barbell",
    "bread/bun",
    "baozi",
    "lion",
    "red cabbage",
    "polar bear",
    "lighter",
    "seal",
    "mangosteen",
    "comb",
    "eraser",
    "pitaya",
    "scallop",
    "pencil case",
    "saw",
    "table tennis paddle",
    "okra",
    "starfish",
    "eagle",
    "monkey",
    "durian",
    "game board",
    "rabbit",
    "french horn",
    "ambulance",
    "asparagus",
    "hoverboard",
    "pasta",
    "target",
    "hotair balloon",
    "chainsaw",
    "lobster",
    "iron",
    "flashlight",
)
@@ -1,89 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import hashlib | |||
import os | |||
import tarfile | |||
from ....distributed.util import is_distributed | |||
from ....logger import get_logger | |||
from ....utils.http_download import download_from_url | |||
# file suffixes recognized as images by is_img() (matched case-insensitively)
IMG_EXT = (".jpg", ".png", ".jpeg", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")

logger = get_logger(__name__)
def _default_dataset_root(): | |||
default_dataset_root = os.path.expanduser( | |||
os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "megengine") | |||
) | |||
return default_dataset_root | |||
def load_raw_data_from_url(
    url: str, filename: str, target_md5: str, raw_data_dir: str, timeout: int
):
    # Ensure ``raw_data_dir/filename`` exists and matches ``target_md5``.
    # A missing file is downloaded from ``url``; an existing file is
    # checksummed. On md5 mismatch the cached file is deleted and a
    # RuntimeError is raised.
    cached_file = os.path.join(raw_data_dir, filename)
    logger.debug(
        "load_raw_data_from_url: downloading to or using cached %s ...", cached_file
    )
    if not os.path.exists(cached_file):
        if is_distributed():
            # every rank would download concurrently -- warn but proceed
            logger.warning(
                "Downloading raw data in DISTRIBUTED mode\n"
                " File may be downloaded multiple times. We recommend\n"
                " users to download in single process first."
            )
        md5 = download_from_url(url, cached_file, http_read_timeout=timeout)
    else:
        md5 = calculate_md5(cached_file)
    if target_md5 == md5:
        logger.debug("%s exists with correct md5: %s", filename, target_md5)
    else:
        os.remove(cached_file)
        raise RuntimeError("{} exists but fail to match md5".format(filename))
def calculate_md5(filename):
    # md5 hex digest of the file, streamed in 4 KiB chunks to bound memory use
    digest = hashlib.md5()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            digest.update(chunk)
    return digest.hexdigest()
def is_img(filename):
    # case-insensitive suffix check against the known image extensions
    lowered = filename.lower()
    return any(lowered.endswith(ext) for ext in IMG_EXT)
def untar(path, to=None, remove=False):
    """Extract tar archive ``path`` into directory ``to`` (defaults to the
    archive's own directory); optionally delete the archive afterwards.

    NOTE(review): extractall trusts member paths -- only use on archives
    from trusted sources.
    """
    destination = os.path.dirname(path) if to is None else to
    with tarfile.open(path, "r") as archive:
        archive.extractall(path=destination)
    if remove:
        os.remove(path)
def untargz(path, to=None, remove=False):
    """Extract gzipped tar archive ``path`` into directory ``to`` (defaults
    to the archive's own directory); optionally delete the archive afterwards.

    :raises ValueError: if ``path`` does not end with ``.tar.gz``.
    """
    if path.endswith(".tar.gz"):
        if to is None:
            to = os.path.dirname(path)
        with tarfile.open(path, "r:gz") as tar:
            tar.extractall(path=to)
    else:
        # the suffix being checked is ".tar.gz"; the old message wrongly
        # claimed ".tar"
        raise ValueError("path %s does not end with .tar.gz" % path)
    if remove:
        os.remove(path)
@@ -1,185 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# --------------------------------------------------------------------- | |||
# Part of the following code in this file refs to torchvision | |||
# BSD 3-Clause License | |||
# | |||
# Copyright (c) Soumith Chintala 2016, | |||
# All rights reserved. | |||
# --------------------------------------------------------------------- | |||
import collections.abc | |||
import os | |||
import xml.etree.ElementTree as ET | |||
import cv2 | |||
import numpy as np | |||
from .meta_vision import VisionDataset | |||
class PascalVOC(VisionDataset):
    r"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset.
    """

    supported_order = (
        "image",
        "boxes",
        "boxes_category",
        "mask",
        "info",
    )

    def __init__(self, root, image_set, *, order=None):
        """
        :param root: VOCdevkit-style root directory (contains JPEGImages,
            Annotations, ImageSets, ...).
        :param image_set: split name, e.g. "train"/"val"/"train_aug".
        :param order: data fields to return; boxes and mask are mutually
            exclusive.
        """
        # Default ``order`` before the membership tests below. Previously
        # `"boxes" in order` with the default order=None raised TypeError;
        # this applies the same default the base class uses.
        if order is None:
            order = ("image",)
        if ("boxes" in order or "boxes_category" in order) and "mask" in order:
            raise ValueError(
                "PascalVOC only supports boxes & boxes_category or mask, not both."
            )

        super().__init__(root, order=order, supported_order=self.supported_order)

        if not os.path.isdir(self.root):
            raise RuntimeError("Dataset not found or corrupted.")

        self.image_set = image_set
        image_dir = os.path.join(self.root, "JPEGImages")

        if "boxes" in order or "boxes_category" in order:
            # detection task: pair each image with its XML annotation
            annotation_dir = os.path.join(self.root, "Annotations")
            splitdet_dir = os.path.join(self.root, "ImageSets/Main")
            split_f = os.path.join(splitdet_dir, image_set.rstrip("\n") + ".txt")
            with open(os.path.join(split_f), "r") as f:
                self.file_names = [x.strip() for x in f.readlines()]
            self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]
            self.annotations = [
                os.path.join(annotation_dir, x + ".xml") for x in self.file_names
            ]
            assert len(self.images) == len(self.annotations)
        elif "mask" in order:
            # segmentation task: pair each image with its mask png
            if "aug" in image_set:
                mask_dir = os.path.join(self.root, "SegmentationClass_aug")
            else:
                mask_dir = os.path.join(self.root, "SegmentationClass")
            splitmask_dir = os.path.join(self.root, "ImageSets/Segmentation")
            split_f = os.path.join(splitmask_dir, image_set.rstrip("\n") + ".txt")
            with open(os.path.join(split_f), "r") as f:
                self.file_names = [x.strip() for x in f.readlines()]
            self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]
            self.masks = [os.path.join(mask_dir, x + ".png") for x in self.file_names]
            assert len(self.images) == len(self.masks)
        else:
            raise NotImplementedError

        # lazily filled cache: index -> {height, width, file_name}
        self.img_infos = dict()

    def __getitem__(self, index):
        target = []
        # ``image`` is cached so a later "info" entry can reuse the decoded
        # image; previously it was unbound (NameError) when "info" was
        # requested without "image" earlier in the order.
        image = None
        for k in self.order:
            if k == "image":
                image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
                target.append(image)
            elif k == "boxes":
                anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
                boxes = [obj["bndbox"] for obj in anno["annotation"]["object"]]
                # boxes type xyxy
                boxes = [
                    (bb["xmin"], bb["ymin"], bb["xmax"], bb["ymax"]) for bb in boxes
                ]
                boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
                target.append(boxes)
            elif k == "boxes_category":
                anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
                boxes_category = [obj["name"] for obj in anno["annotation"]["object"]]
                # +1: ids are 1-based, presumably leaving 0 for background
                boxes_category = [
                    self.class_names.index(bc) + 1 for bc in boxes_category
                ]
                boxes_category = np.array(boxes_category, dtype=np.int32)
                target.append(boxes_category)
            elif k == "mask":
                if "aug" in self.image_set:
                    # aug masks already store label ids as grayscale
                    mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
                else:
                    # color-coded masks are translated to label ids
                    mask = cv2.imread(self.masks[index], cv2.IMREAD_COLOR)
                    mask = self._trans_mask(mask)
                mask = mask[:, :, np.newaxis]
                target.append(mask)
            elif k == "info":
                # image is None unless "image" appeared earlier in the order;
                # get_img_info decodes the file itself in that case
                info = self.get_img_info(index, image)
                info = [info["height"], info["width"], info["file_name"]]
                target.append(info)
            else:
                raise NotImplementedError
        return tuple(target)

    def __len__(self):
        return len(self.images)

    def get_img_info(self, index, image=None):
        # memoized; decodes the image only when its size is not cached yet
        if index not in self.img_infos:
            if image is None:
                image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
            self.img_infos[index] = dict(
                height=image.shape[0],
                width=image.shape[1],
                file_name=self.file_names[index],
            )
        return self.img_infos[index]

    def _trans_mask(self, mask):
        # Map each BGR color to its class index; unmatched pixels get 255
        # (the conventional "ignore" label). ``class_colors`` is expected to
        # be defined on the class but is not visible here -- TODO confirm.
        label = np.ones(mask.shape[:2]) * 255
        for i in range(len(self.class_colors)):
            b, g, r = self.class_colors[i]
            label[
                (mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r)
            ] = i
        return label.astype(np.uint8)

    def parse_voc_xml(self, node):
        # Recursively convert an ElementTree node into nested dicts; under
        # "annotation" the "object" entry is always wrapped so it stays a
        # list even when the image has a single object.
        voc_dict = {}
        children = list(node)
        if children:
            def_dic = collections.defaultdict(list)
            for dc in map(self.parse_voc_xml, children):
                for ind, v in dc.items():
                    def_dic[ind].append(v)
            if node.tag == "annotation":
                def_dic["object"] = [def_dic["object"]]
            voc_dict = {
                node.tag: {
                    ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()
                }
            }
        if node.text:
            text = node.text.strip()
            if not children:
                voc_dict[node.tag] = text
        return voc_dict

    class_names = (
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    )
@@ -1,274 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc | |||
import math | |||
from abc import ABC | |||
from typing import Any, Generator, Iterator, List, Union | |||
import numpy as np | |||
import megengine.distributed as dist | |||
class Sampler(ABC):
    r"""Abstract base class for MegEngine samplers.

    A sampler produces dataset indices (:meth:`sample`), optionally splits
    them across distributed ranks (:meth:`scatter`), and groups them into
    batches (:meth:`batch`).
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        num_samples=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        r"""
        :type dataset: `dataset`
        :param dataset: dataset to sample from
        :type batch_size: positive integer
        :param batch_size: batch size for batch method
        :type drop_last: bool
        :param drop_last: set ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch_size, then the last
            batch will be smaller. (default: ``False``)
        :type num_samples: positive integer
        :param num_samples: number of samples assigned to one rank
        :type world_size: positive integer
        :param world_size: number of ranks
        :type rank: non-negative integer within 0 and world_size
        :param rank: rank id, non-negative integer within 0 and ``world_size``
        :type seed: non-negative integer
        :param seed: seed for random operators
        """
        # bool is a subclass of int, so it must be rejected explicitly
        batch_size_ok = (
            isinstance(batch_size, int)
            and not isinstance(batch_size, bool)
            and batch_size > 0
        )
        if not batch_size_ok:
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={}".format(batch_size)
            )
        if not isinstance(drop_last, bool):
            raise ValueError(
                "drop_last should be a boolean value, but got "
                "drop_last={}".format(drop_last)
            )
        if num_samples is not None:
            num_samples_ok = (
                isinstance(num_samples, int)
                and not isinstance(num_samples, bool)
                and num_samples > 0
            )
            if not num_samples_ok:
                raise ValueError(
                    "num_samples should be a positive integer "
                    "value, but got num_samples={}".format(num_samples)
                )
        self.batch_size = batch_size
        self.dataset = dataset
        self.drop_last = drop_last
        # fall back to the distributed group's size/rank only when not given
        if world_size is None:
            world_size = dist.get_world_size() if dist.is_distributed() else 1
        self.world_size = world_size
        if rank is None:
            rank = dist.get_rank() if dist.is_distributed() else 0
        self.rank = rank
        if num_samples is None:
            num_samples = len(self.dataset)
        # per-rank sample count, rounded up so every rank gets the same amount
        self.num_samples = int(math.ceil(num_samples / self.world_size))
        # Make sure seeds are the same at each rank
        if seed is None and self.world_size > 1:
            seed = 0
        self.rng = np.random.RandomState(seed)

    def __iter__(self) -> Union[Generator, Iterator]:
        return self.batch()

    def __len__(self) -> int:
        # number of batches this rank will yield
        full, rem = divmod(self.num_samples, self.batch_size)
        return full if self.drop_last or rem == 0 else full + 1

    def sample(self):
        """
        Return a list containing all sample indices.
        """
        raise NotImplementedError

    def scatter(self, indices) -> List:
        r"""
        Split ``indices`` evenly into per-rank subsets and return this rank's
        share. Override this method for a customized assignment strategy.
        """
        total_size = self.num_samples * self.world_size
        # pad with the leading indices so the total is evenly divisible
        indices += indices[: (total_size - len(indices))]
        assert len(indices) == total_size
        # round-robin subsample: rank r takes positions r, r+ws, r+2*ws, ...
        part = indices[self.rank : total_size : self.world_size]
        assert len(part) == self.num_samples
        return part

    def batch(self) -> Iterator[List[Any]]:
        r"""
        Return an iterator over lists of indices, one list per batch.
        """
        all_indices = list(self.sample())
        # user might pass the world_size parameter without dist,
        # so dist.is_distributed() should not be used
        if self.world_size > 1:
            all_indices = self.scatter(all_indices)
        step = self.batch_size
        chunks = [all_indices[i : i + step] for i in range(0, len(all_indices), step)]
        if self.drop_last and len(chunks[-1]) < self.batch_size:
            chunks.pop()
        return iter(chunks)
class SequentialSampler(Sampler):
    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        indices=None,
        world_size=None,
        rank=None,
    ):
        r"""
        Sample elements in their original (sequential) order.
        """
        super().__init__(dataset, batch_size, drop_last, None, world_size, rank)
        if indices is not None:
            if not isinstance(indices, collections.abc.Sequence):
                raise ValueError(
                    "indices should be None or a sequence, "
                    "but got indices={}".format(indices)
                )
        self.indices = indices

    def sample(self) -> Iterator[Any]:
        r"""
        Return the candidate indices: the user-supplied sequence when given,
        otherwise an iterator over the whole dataset range.
        """
        if self.indices is not None:
            return self.indices
        return iter(range(len(self.dataset)))
class RandomSampler(Sampler):
    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        indices=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        r"""
        Sample elements randomly without replacement.
        """
        super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed)
        if indices is not None:
            if not isinstance(indices, collections.abc.Sequence):
                raise ValueError(
                    "indices should be None or a sequence, "
                    "but got indices={}".format(indices)
                )
        self.indices = indices

    def sample(self) -> List:
        """Return a random permutation of the candidate indices."""
        # rng.permutation accepts either a length (int) or a sequence
        pool = len(self.dataset) if self.indices is None else self.indices
        return self.rng.permutation(pool).tolist()
class ReplacementSampler(Sampler):
    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        num_samples=None,
        weights=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        r"""
        Sample elements randomly with replacement.

        :type weights: List
        :param weights: weights for sampling indices, it could be unnormalized weights
        """
        super().__init__(
            dataset, batch_size, drop_last, num_samples, world_size, rank, seed
        )
        if weights is not None:
            if not isinstance(weights, collections.abc.Sequence):
                raise ValueError(
                    "weights should be None or a sequence, "
                    "but got weights={}".format(weights)
                )
            if len(weights) != len(dataset):
                raise ValueError(
                    "len(dataset)={} should be equal to"
                    "len(weights)={}".format(len(dataset), len(weights))
                )
            # normalize once; rng.choice requires probabilities summing to 1
            weights = np.array(weights) / sum(weights)
        self.weights = weights

    def sample(self) -> List:
        """Return ``num_samples`` dataset indices drawn with replacement.

        With ``weights`` set, indices are drawn according to the normalized
        weights; otherwise they are drawn uniformly.
        """
        n = len(self.dataset)
        if self.weights is None:
            return self.rng.randint(n, size=self.num_samples).tolist()
        # BUG FIX: the previous code used rng.multinomial(n, weights, num_samples),
        # which returns per-category draw *counts* (a 2-D array), not indices.
        # rng.choice draws actual index values according to the weights.
        return self.rng.choice(
            n, size=self.num_samples, replace=True, p=self.weights
        ).tolist()
class Infinite(Sampler):
    r"""Wrapper that repeats a basic sampler forever: when the wrapped sampler
    is exhausted it is transparently restarted."""

    def sample(self):
        raise NotImplementedError("sample method not supported in Infinite")

    def __init__(self, sampler):
        self.sampler = sampler
        self.sampler_iter = iter(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.sampler_iter)
        except StopIteration:
            # wrapped sampler exhausted: restart it and continue
            self.sampler_iter = iter(self.sampler)
            return next(self.sampler_iter)

    def __len__(self):
        # "infinite" length, represented by the largest int64 value
        return np.iinfo(np.int64).max
@@ -1,10 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .meta_transform import PseudoTransform, Transform | |||
from .vision import * |
@@ -1,31 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from abc import ABC, abstractmethod | |||
from typing import Sequence, Tuple | |||
class Transform(ABC):
    """
    Abstract base class for data transforms; subclasses implement
    :meth:`apply` for a single sample.
    """

    def apply_batch(self, inputs: Sequence[Tuple]):
        """Apply the transform to every element of ``inputs``; returns a tuple."""
        return tuple(map(self.apply, inputs))

    @abstractmethod
    def apply(self, input: Tuple):
        pass

    def __repr__(self):
        return self.__class__.__name__
class PseudoTransform(Transform):
    """Identity transform: returns its input unchanged."""

    def apply(self, input: Tuple):
        return input
@@ -1,9 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .transform import * |
@@ -1,111 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections.abc | |||
import functools | |||
import random | |||
import cv2 | |||
import numpy as np | |||
def wrap_keepdims(func):
    """Decorator that restores a dropped channel axis.

    The wrapped function must receive an (H, W, C) image; when it returns a
    2-D array, a trailing channel axis is re-added so the result is (H, W, 1).
    """

    @functools.wraps(func)
    def wrapper(image, *args, **kwargs):
        ndim = len(image.shape)
        if ndim != 3:
            raise ValueError(
                "image must have 3 dims, but got {} dims".format(ndim)
            )
        result = func(image, *args, **kwargs)
        if len(result.shape) == 2:
            result = result[:, :, np.newaxis]
        return result

    return wrapper
@wrap_keepdims
def to_gray(image):
    r"""
    Change BGR format image's color space to gray

    :param image: Input BGR format image, with (H, W, C) shape
    :return: Gray format image, with (H, W, C) shape
    """
    # cv2 returns an (H, W) array; wrap_keepdims restores the channel axis
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
@wrap_keepdims
def to_bgr(image):
    r"""
    Change gray format image's color space to BGR

    :param image: input Gray format image, with (H, W, C) shape
    :return: BGR format image, with (H, W, C) shape
    """
    # gray channel is replicated into three identical B, G, R channels
    return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
@wrap_keepdims
def pad(input, size, value):
    r"""
    Pad ``input`` with *value* by the given *size*.

    :param input: input data, with (H, W, C) shape
    :param size: padding size; an int pads all four sides by that amount,
        a 2-sequence pads (bottom, right), and a 4-sequence pads
        (top, bottom, left, right).
    :param value: padding value, a sequence of int or float; a float value
        also casts the image dtype to float32.
    :return: padded image
    """
    if isinstance(size, int):
        borders = (size,) * 4
    elif isinstance(size, collections.abc.Sequence) and len(size) == 2:
        borders = (0, size[0], 0, size[1])
    else:
        borders = size
    if np.array(value).dtype == float:
        # float padding value implies a float image
        input = input.astype(np.float32)
    return cv2.copyMakeBorder(input, *borders, cv2.BORDER_CONSTANT, value=value)
@wrap_keepdims
def flip(image, flipCode):
    r"""
    According to the flipCode (the type of flip), flip the input image

    :param image: Input image, with (H, W, C) shape
    :param flipCode: code that indicates the type of flip.
        1 : Flip horizontally
        0 : Flip vertically
        -1 : Flip horizontally and vertically
    :return: BGR format image, with (H, W, C) shape
    """
    return cv2.flip(image, flipCode=flipCode)
@wrap_keepdims
def resize(input, size, interpolation=cv2.INTER_LINEAR):
    r"""
    Resize ``input`` to the given ``size``.

    :param input: input data, image or masks, with (H, W, C) shape
    :param size: target size, given as (height, width)
    :param interpolation: interpolation method, or a sequence of methods from
        which one is chosen at random
    :return: resized data, with (H, W, C) shape
    """
    if len(size) != 2:
        raise ValueError("resize needs (h, w), but got {}".format(size))
    chosen = interpolation
    if isinstance(chosen, collections.abc.Sequence):
        chosen = random.choice(chosen)
    # cv2.resize expects (width, height), hence the reversal
    return cv2.resize(input, size[::-1], interpolation=chosen)
@@ -1,33 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .functional import ( | |||
all_gather, | |||
all_reduce_max, | |||
all_reduce_min, | |||
all_reduce_sum, | |||
all_to_all, | |||
bcast_param, | |||
broadcast, | |||
gather, | |||
reduce_scatter_sum, | |||
reduce_sum, | |||
scatter, | |||
) | |||
from .util import ( | |||
get_backend, | |||
get_free_ports, | |||
get_master_ip, | |||
get_master_port, | |||
get_rank, | |||
get_world_size, | |||
group_barrier, | |||
init_process_group, | |||
is_distributed, | |||
synchronized, | |||
) |
@@ -1,302 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Optional, Union | |||
import megengine._internal as mgb | |||
from megengine._internal.opr_param_defs import CollectiveComm as Param | |||
from ..core import Buffer, Parameter, Tensor, wrap_io_tensor | |||
from ..functional import add_update | |||
from .helper import collective_comm_symvar | |||
from .util import get_rank, is_distributed | |||
@wrap_io_tensor
def _collective_comm(*args, **kwargs):
    """Tensor-level wrapper around :func:`collective_comm_symvar`."""
    return collective_comm_symvar(*args, **kwargs)
def _group_check(*args): | |||
"""Return True when arguments are all None or all not None | |||
""" | |||
l = [val is None for val in args] | |||
return len(set(l)) <= 1 | |||
def reduce_sum(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> Tensor:
    """Build a REDUCE_SUM collective communication operator.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    """
    assert _group_check(
        key, nr_ranks, is_root
    ), "key, nr_ranks, is_root should be set at the same time"
    mode = Param.Mode.REDUCE_SUM
    return _collective_comm(tensor, key, mode, nr_ranks, is_root, device=tensor.device)
def gather(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
    rank: Optional[int] = None,
) -> Tensor:
    """Build a GATHER collective communication operator.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    :param rank: rank of this node
    """
    assert _group_check(
        key, nr_ranks, is_root, rank
    ), "key, nr_ranks, is_root, rank should be set at the same time"
    mode = Param.Mode.GATHER
    return _collective_comm(
        tensor, key, mode, nr_ranks, is_root, rank, device=tensor.device
    )
def broadcast(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> Tensor:
    """Build a BROADCAST collective communication operator.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    """
    assert _group_check(
        key, nr_ranks, is_root
    ), "key, nr_ranks, is_root should be set at the same time"
    if is_root is None:
        is_root = get_rank() == 0
    # non-root ranks only receive; they pass the graph instead of a tensor
    inp = tensor if is_root else tensor._symvar.owner_graph
    return _collective_comm(
        inp,
        key,
        Param.Mode.BROADCAST,
        nr_ranks,
        is_root,
        dtype=tensor.dtype,
        device=tensor.device,
    )
def scatter(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
    rank: Optional[int] = None,
) -> Tensor:
    """Build a SCATTER collective communication operator.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    :param rank: rank of this node
    """
    assert _group_check(
        key, nr_ranks, is_root, rank
    ), "key, nr_ranks, is_root, rank should be set at the same time"
    if key is None:
        # default key: the symbolic variable's own name
        key = tensor._symvar.name
    if is_root is None:
        is_root = get_rank() == 0
    # non-root ranks only receive; they pass the graph instead of a tensor
    inp = tensor if is_root else tensor._symvar.owner_graph
    return _collective_comm(
        inp,
        key,
        Param.Mode.SCATTER,
        nr_ranks,
        is_root,
        rank,
        dtype=tensor.dtype,
        device=tensor.device,
    )
def all_to_all(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Build an ALL_TO_ALL collective communication operator.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of this node
    :param local_grad: whether use local grad
    """
    assert _group_check(
        key, nr_ranks, rank
    ), "key, nr_ranks, rank should be set at the same time"
    mode = Param.Mode.ALL_TO_ALL
    return _collective_comm(
        tensor, key, mode, nr_ranks, rank=rank, local_grad=local_grad
    )
def all_gather(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Build an ALL_GATHER collective communication operator.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of this node
    :param local_grad: whether use local grad
    """
    assert _group_check(
        key, nr_ranks, rank
    ), "key, nr_ranks, rank should be set at the same time"
    mode = Param.Mode.ALL_GATHER
    return _collective_comm(
        tensor, key, mode, nr_ranks, rank=rank, local_grad=local_grad
    )
def reduce_scatter_sum(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Build a REDUCE_SCATTER_SUM collective communication operator.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of this node
    :param local_grad: whether use local grad
    """
    assert _group_check(
        key, nr_ranks, rank
    ), "key, nr_ranks, rank should be set at the same time"
    mode = Param.Mode.REDUCE_SCATTER_SUM
    return _collective_comm(
        tensor, key, mode, nr_ranks, rank=rank, local_grad=local_grad
    )
def all_reduce_sum(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Build an ALL_REDUCE_SUM collective communication operator.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param local_grad: whether use local grad
    """
    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
    mode = Param.Mode.ALL_REDUCE_SUM
    return _collective_comm(tensor, key, mode, nr_ranks, local_grad=local_grad)
def all_reduce_max(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Build an ALL_REDUCE_MAX collective communication operator.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param local_grad: whether use local grad
    """
    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
    mode = Param.Mode.ALL_REDUCE_MAX
    return _collective_comm(tensor, key, mode, nr_ranks, local_grad=local_grad)
def all_reduce_min(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Build an ALL_REDUCE_MIN collective communication operator.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param local_grad: whether use local grad
    """
    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
    mode = Param.Mode.ALL_REDUCE_MIN
    return _collective_comm(tensor, key, mode, nr_ranks, local_grad=local_grad)
def bcast_param(
    inp: Union[Buffer, Parameter],
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> None:
    """Synchronize a parameter across devices via broadcast.

    :param inp: input Buffer or Parameter to be synchronized
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    """
    if not is_distributed():
        # nothing to synchronize outside a distributed group
        return
    assert _group_check(
        key, nr_ranks, is_root
    ), "key, nr_ranks, is_root should be set at the same time"
    assert isinstance(inp, (Buffer, Parameter))
    # alpha=0 makes add_update overwrite inp with the broadcast result
    bcast_res = broadcast(inp, key, nr_ranks, is_root)
    add_update(inp, bcast_res, alpha=0)
@@ -1,63 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Optional, Union | |||
import megengine._internal as mgb | |||
from megengine._internal.opr_param_defs import CollectiveComm as CollParam | |||
from .util import ( | |||
get_backend, | |||
get_group_id, | |||
get_master_ip, | |||
get_master_port, | |||
get_rank, | |||
get_world_size, | |||
) | |||
def collective_comm_symvar(
    inp: Union[mgb.SymbolVar, mgb.CompGraph],
    key: Optional[str] = None,
    op: CollParam.Mode = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
    rank: Optional[int] = None,
    local_grad: Optional[bool] = False,
    dtype: Optional[type] = None,
    device: Optional[mgb.CompNode] = None,
    comp_graph: Optional[mgb.CompGraph] = None,
) -> mgb.SymbolVar:
    """Helper for building collective_comm operators.

    :param inp: tensor or comp_graph
    :param key: unique identifier for collective communication
    :param op: mode of collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this node is root node
    :param rank: rank of this node
    :param local_grad: whether use local grad
    :param dtype: output data type, use dtype of inp as default
    :param device: output comp node, use comp node of inp as default
    :param comp_graph: output comp graph, use comp graph of inp as default
    """
    # resolve defaults up front; get_group_id() bumps a global counter, so it
    # is only called when no explicit key was supplied
    if key is None:
        key = "collective_comm_" + str(get_group_id())
    if nr_ranks is None:
        nr_ranks = get_world_size()
    if is_root is None:
        is_root = get_rank() == 0
    if rank is None:
        rank = get_rank()
    return mgb.opr.collective_comm(
        inp,
        key=key,
        nr_devices=nr_ranks,
        is_root=is_root,
        rank=rank,
        local_grad=local_grad,
        server_addr=get_master_ip(),
        port=get_master_port(),
        param=CollParam(mode=op),
        dtype=dtype,
        backend=get_backend(),
        comp_node=device,
        comp_graph=comp_graph,
    )
@@ -1,146 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import functools | |||
import socket | |||
from typing import Callable, List, Optional | |||
import megengine._internal as mgb | |||
from ..core import set_default_device | |||
# Process-group state for the current process; written by init_process_group()
# and read by the accessor functions in this module.
_master_ip = None
_master_port = 0
_world_size = 0
_rank = 0
_backend = None
_group_id = 0
def init_process_group(
    master_ip: str,
    master_port: int,
    world_size: int,
    rank: int,
    dev: int,
    backend: Optional[str] = "nccl",
) -> None:
    """Initialize the distributed process group, and also specify the device used in the current process.

    :param master_ip: IP address of the master node.
    :param master_port: Port available for all processes to communicate.
    :param world_size: Total number of processes participating in the job.
    :param rank: Rank of the current process.
    :param dev: The GPU device id to bind this process to.
    :param backend: Communicator backend, currently support 'nccl' and 'ucx'
    """
    global _master_ip  # pylint: disable=global-statement
    global _master_port  # pylint: disable=global-statement
    global _world_size  # pylint: disable=global-statement
    global _rank  # pylint: disable=global-statement
    global _backend  # pylint: disable=global-statement
    global _group_id  # pylint: disable=global-statement
    # table-driven type validation; messages match the originals exactly
    for value, typ in (
        (master_ip, str),
        (master_port, int),
        (world_size, int),
        (rank, int),
        (backend, str),
    ):
        if not isinstance(value, typ):
            raise TypeError(
                "Expect type {} but got {}".format(typ.__name__, type(value))
            )
    _master_ip = master_ip
    _master_port = master_port
    _world_size = world_size
    _rank = rank
    _backend = backend
    _group_id = 0
    set_default_device(mgb.comp_node("gpu" + str(dev)))
    if rank == 0:
        # master starts the memory-management server; the returned port is
        # the one actually bound (relevant when master_port is 0)
        _master_port = mgb.config.create_mm_server("0.0.0.0", master_port)
        if _master_port == -1:
            raise Exception("Failed to start server on port {}".format(master_port))
    else:
        assert master_port > 0, "master_port must be specified for non-zero rank"
def is_distributed() -> bool:
    """Return True if the distributed process group has been initialized"""
    if _world_size is None:
        return False
    return _world_size > 1
def get_master_ip() -> str:
    """Get the IP address of the master node"""
    # _master_ip may still be None when init_process_group() has not run;
    # str() then yields "None" rather than raising
    return str(_master_ip)
def get_master_port() -> int:
    """Get the port of the rpc server on the master node"""
    # on rank 0 this is the port actually bound by create_mm_server()
    return _master_port
def get_world_size() -> int:
    """Get the total number of processes participating in the job"""
    # 0 until init_process_group() is called
    return _world_size
def get_rank() -> int:
    """Get the rank of the current process"""
    # 0 until init_process_group() is called
    return _rank
def get_backend() -> str:
    """Get the backend str"""
    # str() turns an uninitialized None into "None" instead of raising
    return str(_backend)
def get_group_id() -> int:
    """Get group id for collective communication"""
    # each call yields a fresh id by bumping the module-level counter;
    # all ranks must call this in the same order to agree on ids
    global _group_id
    _group_id += 1
    return _group_id
def group_barrier() -> None:
    """Block until all ranks in the group reach this barrier"""
    # delegates to the mm server identified by the module-level group state
    mgb.config.group_barrier(_master_ip, _master_port, _world_size, _rank)
def synchronized(func: Callable):
    """Decorator. Decorated function will synchronize when finished.
    Specifically, we use this to prevent data race during hub.load"""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # outside a distributed group there is no barrier to wait on
        if not is_distributed():
            return func(*args, **kwargs)
        ret = func(*args, **kwargs)
        # wait for every rank so no process observes a half-finished result
        group_barrier()
        return ret
    return wrapper
def get_free_ports(num: int) -> List[int]:
    """Get one or more free ports.

    Each socket binds to port 0 so the OS picks an unused port; all sockets
    stay open until every port is collected, so the returned ports are
    pairwise distinct.

    :param num: number of free ports to find
    :return: list of free port numbers
    """
    socks, ports = [], []
    try:
        for _ in range(num):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # register before bind so the socket is closed even if bind fails
            socks.append(sock)
            sock.bind(("", 0))
            ports.append(sock.getsockname()[1])
    finally:
        # always release the sockets, even on a mid-loop failure
        for sock in socks:
            sock.close()
    return ports
@@ -1,118 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# pylint: disable=redefined-builtin | |||
from .elemwise import ( | |||
abs, | |||
add, | |||
arccos, | |||
arcsin, | |||
ceil, | |||
clamp, | |||
cos, | |||
divide, | |||
equal, | |||
exp, | |||
floor, | |||
greater, | |||
greater_equal, | |||
isinf, | |||
isnan, | |||
less, | |||
less_equal, | |||
log, | |||
maximum, | |||
minimum, | |||
mod, | |||
multiply, | |||
power, | |||
relu, | |||
round, | |||
sigmoid, | |||
sin, | |||
subtract, | |||
tanh, | |||
) | |||
from .graph import add_extra_vardep, add_update, grad | |||
from .loss import ( | |||
binary_cross_entropy, | |||
cross_entropy, | |||
cross_entropy_with_softmax, | |||
hinge_loss, | |||
l1_loss, | |||
nll_loss, | |||
smooth_l1_loss, | |||
square_loss, | |||
triplet_margin_loss, | |||
) | |||
from .math import ( | |||
argmax, | |||
argmin, | |||
logsumexp, | |||
max, | |||
mean, | |||
min, | |||
norm, | |||
normalize, | |||
prod, | |||
sqrt, | |||
sum, | |||
) | |||
from .nn import ( | |||
assert_equal, | |||
avg_pool2d, | |||
batch_norm2d, | |||
batched_matrix_mul, | |||
conv2d, | |||
conv_transpose2d, | |||
dropout, | |||
embedding, | |||
eye, | |||
flatten, | |||
identity, | |||
indexing_one_hot, | |||
interpolate, | |||
leaky_relu, | |||
linear, | |||
local_conv2d, | |||
matrix_mul, | |||
max_pool2d, | |||
one_hot, | |||
prelu, | |||
remap, | |||
roi_align, | |||
roi_pooling, | |||
softmax, | |||
softplus, | |||
sync_batch_norm, | |||
warp_perspective, | |||
) | |||
from .quantized import conv_bias_activation | |||
from .sort import argsort, sort, top_k | |||
from .tensor import ( | |||
add_axis, | |||
arange, | |||
broadcast_to, | |||
concat, | |||
cond_take, | |||
dimshuffle, | |||
gather, | |||
linspace, | |||
remove_axis, | |||
reshape, | |||
scatter, | |||
shapeof, | |||
transpose, | |||
where, | |||
zeros_like, | |||
) | |||
from .utils import accuracy, zero_grad | |||
# delete namespace | |||
# pylint: disable=undefined-variable | |||
del elemwise, graph, loss, math, nn, tensor # type: ignore[name-defined] |
@@ -1,49 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os | |||
_conv_execution_strategy = os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY", "HEURISTIC") | |||
def get_conv_execution_strategy() -> str:
    """Returns the execution strategy of :class:`~.Conv2d`.

    See :func:`~.set_conv_execution_strategy` for possible return values.
    """
    return _conv_execution_strategy
def set_conv_execution_strategy(option: str):
    """Sets the execution strategy of :class:`~.Conv2d`.

    :param option: Decides how :class:`~.Conv2d` algorithm is chosen.
        Available values:

        * 'HEURISTIC' uses heuristic to choose the fastest algorithm.
        * 'PROFILE' runs possible algorithms on real device to find the best.
        * 'PROFILE_HEURISTIC' uses profile result and heuristic to choose the fastest algorithm.
        * 'PROFILE_REPRODUCIBLE' uses the fastest of profile result that is also reproducible.
        * 'HEURISTIC_REPRODUCIBLE' uses heuristic to choose the fastest algorithm that is also reproducible.

        The default strategy is 'HEURISTIC'. It can also be set through the
        environment variable 'MEGENGINE_CONV_EXECUTION_STRATEGY'.
    :raises ValueError: if ``option`` is not one of the valid strategies.
    """
    valid_option = (
        "HEURISTIC",
        "PROFILE",
        "PROFILE_HEURISTIC",
        "PROFILE_REPRODUCIBLE",
        "HEURISTIC_REPRODUCIBLE",
    )
    # idiomatic membership test (was: `if not option in valid_option`)
    if option not in valid_option:
        raise ValueError("Valid option can only be one of {}".format(valid_option))
    global _conv_execution_strategy  # pylint: disable=global-statement
    _conv_execution_strategy = option
@@ -1,299 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | |||
import functools | |||
import megengine._internal as mgb | |||
from ..core.graph import _use_default_if_none | |||
from ..core.tensor import Tensor, wrap_io_tensor | |||
__all__ = [ | |||
"abs", | |||
"arccos", | |||
"add", | |||
"arcsin", | |||
"clamp", | |||
"ceil", | |||
"cos", | |||
"divide", | |||
"equal", | |||
"exp", | |||
"greater", | |||
"greater_equal", | |||
"floor", | |||
"isinf", | |||
"isnan", | |||
"less", | |||
"less_equal", | |||
"log", | |||
"maximum", | |||
"minimum", | |||
"mod", | |||
"multiply", | |||
"power", | |||
"relu", | |||
"round", | |||
"sigmoid", | |||
"sin", | |||
"subtract", | |||
"tanh", | |||
] | |||
def _elemwise(mode):  # DONT export
    """Decorator factory wrapping a megbrain element-wise opr.

    ``mode`` is the megbrain elemwise mode string (e.g. "ADD", "SIGMOID");
    the decorated function's body is ignored — only its name/docstring are
    kept via functools.wraps.
    """

    def elemwise_decorator(func):
        @functools.wraps(func)
        @wrap_io_tensor
        def elemwise_func(*inputs) -> Tensor:
            if all(isinstance(i, (int, float)) for i in inputs):
                # All-scalar inputs: build the opr on the default device/graph
                # and return the statically inferred scalar value.
                device, comp_graph = _use_default_if_none(None, None)
                ret = mgb.opr.elemwise(
                    *inputs, mode=mode, comp_node=device, comp_graph=comp_graph
                )
                return ret.inferred_value[0]
            return mgb.opr.elemwise(*inputs, mode=mode)

        return elemwise_func

    return elemwise_decorator
@_elemwise("ABS")
def abs(x):
    """Calculate the absolute value element-wise."""


@_elemwise("ACOS")
def arccos(x):
    """Inverse cosine, element-wise."""


@_elemwise("ADD")
def add(x, y):
    """Element-wise addition."""


@_elemwise("ASIN")
def arcsin(x):
    """Inverse sine, element-wise."""


@_elemwise("CEIL")
def ceil(x):
    """Return the ceil of the input, element-wise."""


@_elemwise("COS")
def cos(x):
    """Cosine, element-wise."""


@_elemwise("TRUE_DIV")
def divide(x, y):
    """Return (x / y) element-wise."""


@_elemwise("EQ")
def equal(x, y):
    """Return (x == y) element-wise."""


@_elemwise("EXP")
def exp(x):
    """Calculate the exponential element-wise"""


@_elemwise("FLOOR")
def floor(x):
    """Return the floor of the input, element-wise"""


def greater(x, y):
    """Return (x > y) element-wise."""
    # Implemented via less() with swapped operands.
    return less(y, x)


def greater_equal(x, y):
    """Return (x >= y) element-wise"""
    # Implemented via less_equal() with swapped operands.
    return less_equal(y, x)


@_elemwise("LT")
def less(x, y):
    """Return (x < y) element-wise."""


@_elemwise("LEQ")
def less_equal(x, y):
    """Return (x <= y) element-wise."""


@_elemwise("LOG")
def log(x):
    """Natural logarithm (base `e`), element-wise."""


@_elemwise("MAX")
def maximum(x, y):
    """Element-wise maximum of array elements."""


@_elemwise("MIN")
def minimum(x, y):
    """Element-wise minimum of array elements."""


@_elemwise("MOD")
def mod(x, y):
    """Return element-wise remainder of division."""


@_elemwise("MUL")
def multiply(x, y):
    """Element-wise multiplication."""


@_elemwise("POW")
def power(x, y):
    """First tensor elements raised to powers from second tensor (x ** y), element-wise."""


@_elemwise("RELU")
def relu(x):
    """Return `max(x, 0)` element-wise."""


@_elemwise("ROUND")
def round(x):
    """Round tensor to int element-wise."""


@_elemwise("SIGMOID")
def sigmoid(x):
    """Return 1 / ( 1 + exp( -x ) ) element-wise."""


@_elemwise("SIN")
def sin(x):
    """Sine, element-wise."""


@_elemwise("SUB")
def subtract(x, y):
    """Subtract arguments element-wise"""


@_elemwise("TANH")
def tanh(x):
    """Compute hyperbolic tangent element-wise."""
@wrap_io_tensor
def clamp(inp: Tensor, lower=None, upper=None) -> Tensor:
    r"""
    Clamp all elements in :attr:`inp` into the range `[` :attr:`lower`, :attr:`upper` `]` and return
    a resulting tensor:

    .. math::
        y_i = \begin{cases}
            \text{lower} & \text{if } x_i < \text{lower} \\
            x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\
            \text{upper} & \text{if } x_i > \text{upper}
        \end{cases}

    :param inp: the input tensor.
    :param lower: lower-bound of the range to be clamped to
    :param upper: upper-bound of the range to be clamped to

    Example:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        a = tensor(np.arange(5).astype(np.int32))
        print(F.clamp(a, 2, 4).numpy())
        print(F.clamp(a, lower=3).numpy())
        print(F.clamp(a, upper=3).numpy())

    .. testoutput::

        [2 2 2 3 4]
        [3 3 3 3 4]
        [0 1 2 3 3]
    """
    assert (
        lower is not None or upper is not None
    ), "At least one of 'lower' or 'upper' must not be None"
    if lower is not None:
        if upper is not None:
            # fixed grammar in the assertion message ("bigger that" -> "than")
            assert lower <= upper, "clamp lower bound is bigger than upper bound"
            return minimum(maximum(inp, lower), upper)
        else:
            return maximum(inp, lower)
    else:
        return minimum(inp, upper)
def isnan(inp: "Tensor") -> "Tensor":
    r"""Element-wise NaN test.

    :param inp: input tensor.
    :return: a uint8 tensor that is 1 where the corresponding element of
        :attr:`inp` is NaN and 0 elsewhere.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, float("nan"), 0])
        print(F.isnan(x))

    .. testoutput::

        Tensor([0 1 0], dtype=uint8)
    """
    # NaN is the only value that does not compare equal to itself.
    nan_mask = inp != inp
    return nan_mask.astype("uint8")
def isinf(inp: "Tensor") -> "Tensor":
    r"""Element-wise infinity test.

    :param inp: input tensor.
    :return: a uint8 tensor that is 1 where the corresponding element of
        :attr:`inp` is +/-Inf and 0 elsewhere.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, float("inf"), 0])
        print(F.isinf(x))

    .. testoutput::

        Tensor([0 1 0], dtype=uint8)
    """
    # Compare magnitudes so both +inf and -inf are caught.
    magnitude = abs(inp)
    return (magnitude == float("inf")).astype("uint8")
@@ -1,65 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=too-many-lines | |||
from typing import List | |||
import megengine._internal as mgb | |||
from ..core import Tensor, wrap_io_tensor | |||
@wrap_io_tensor
def cambricon_subgraph(
    inputs: List[Tensor], data: bytes, symbol: str, tensor_dim_mutable: bool,
) -> List[Tensor]:
    """Load a serialized Cambricon subgraph (i.e. cnrtModel_t) and
    execute the operations defined in the subgraph.

    :param inputs: List of input tensors of the subgraph.
    :param data: The serialized subgraph.
    :param symbol: The name of the function in the subgraph.
        The function is corresponding to a cnmlFusionOp
        which is added to the cnmlModel_t/cnrtModel_t.
    :param tensor_dim_mutable: Whether the input tensors' shapes are mutable
        in cnrtModel_t.
    """
    # Unwrap the Tensor objects to their underlying symbolic variables.
    symvars = tuple(inp._symvar for inp in inputs)
    return mgb.opr.cambricon_runtime(data, symbol, symvars, tensor_dim_mutable)
@wrap_io_tensor
def atlas_subgraph(inputs: List[Tensor], data: bytes) -> List[Tensor]:
    """Load a serialized Atlas subgraph (i.e. om model) and
    execute the operations defined in the subgraph.

    :param inputs: List of input tensors of the subgraph.
    :param data: The serialized subgraph.
    """
    # Unwrap the Tensor objects to their underlying symbolic variables.
    symvars = tuple(inp._symvar for inp in inputs)
    return mgb.opr.atlas_runtime(symvars, data)
@wrap_io_tensor
def extern_opr_subgraph(
    inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes,
) -> List[Tensor]:
    """Load a serialized extern opr subgraph and fake execute the operator.

    :param inputs: Tensor or list of input tensors.
    :param output_shapes: The output shapes.
    :param dump_name: The serialized subgraph name.
    :param dump_data: The serialized subgraph.
    :return: List of tensors.
    """
    # Normalize a single tensor argument to a one-element list.
    tensor_inputs = inputs if isinstance(inputs, list) else [inputs]
    return mgb.opr.extern_c_opr_placeholder(
        tensor_inputs, output_shapes, dump_name=dump_name, dump_data=dump_data,
    )
@@ -1,125 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import collections.abc
from typing import Iterable, Optional, Union
import megengine._internal as mgb | |||
from ..core.graph import get_default_graph | |||
from ..core.tensor import Tensor, wrap_io_tensor | |||
from ..jit import barrier, mark_impure, trace | |||
@wrap_io_tensor
def grad(
    target: Tensor,
    wrt: Union[Tensor, Iterable[Tensor]],
    warn_mid_wrt: bool = True,
    use_virtual_grad: bool = None,
    return_zero_for_nodep: bool = True,
) -> Union[Tensor, Iterable[Optional[Tensor]], None]:
    r"""Compute the symbolic gradient of ``target`` with respect to ``wrt``.

    ``wrt`` can either be a single tensor or a sequence of tensors.

    :param target: ``grad`` target tensor
    :param wrt: with respect to which to compute the gradient
    :param warn_mid_wrt: whether to give warning if ``wrt`` is not endpoint
    :param use_virtual_grad: whether to use virtual ``grad`` opr, so fwd graph can
        be optimized before applying ``grad``; if ``None`` is given, then virtual
        ``grad`` would be used if ``graph_opt_level >= 2``
    :param return_zero_for_nodep: if ``target`` does not depend on ``wrt``, set to True to return
        a zero-valued :class:`~.Tensor` rather than ``None``; can't be set to False when using
        virtual ``grad`` opr.
    :return: :math:`\partial\text{target} / \partial\text{wrt}`
    """
    if not isinstance(wrt, mgb.SymbolVar):
        # collections.Iterable was removed in Python 3.10; the abc module
        # has been the correct home for this ABC since Python 3.3.
        assert isinstance(wrt, collections.abc.Iterable)
        wrt = [w._symvar for w in wrt]
    return mgb.grad(target, wrt, warn_mid_wrt, use_virtual_grad, return_zero_for_nodep)
# Cache of (alpha, beta, bias SharedScalars, opr config) keyed by
# (graph id, dest var id, delta var id) so repeated add_update calls on the
# same pair reuse one opr instead of rebuilding it.
_add_update_cache = {}  # type: dict
# Placeholder SharedScalar handed to the underlying add_update opr.
_dummy = mgb.SharedScalar(0)
def add_update(
    dest: Tensor,
    delta: Tensor,
    *,
    alpha: Union[Tensor, float, int] = 1.0,
    beta: Union[Tensor, float, int] = 1.0,
    bias: Union[Tensor, float, int] = 0.0
):
    r"""Inplace modify ``dest`` as follows:

    .. math::
        dest = alpha * dest + beta * delta + bias

    :param dest: input data that will be inplace modified.
    :param delta: update value that will be added to ``dest``.
    :param alpha: weight ratio of ``dest``. Default: 1.0
    :param beta: weight ratio of ``delta``. Default: 1.0
    :param bias: bias value appended to the result. Default: 0.0
    """
    # Tensor-valued coefficients cannot be stored in SharedScalars, so they
    # are folded into delta first and replaced by the neutral scalar value.
    if isinstance(beta, Tensor) or isinstance(alpha, Tensor):
        delta *= beta
        beta = 1.0
    if isinstance(alpha, Tensor):
        delta += (alpha - 1.0) * dest
        alpha = 1.0
    if isinstance(bias, Tensor):
        delta += bias
        bias = 0.0
    comp_graph = dest._comp_graph or get_default_graph()
    comp_node = dest._comp_node
    if not isinstance(delta, Tensor):
        # Plain numbers become an immutable constant on dest's device/graph.
        _delta = mgb.make_immutable(
            value=delta, comp_node=comp_node, comp_graph=comp_graph
        )
    else:
        _delta = delta._attach(comp_graph)
    _dest = dest._attach(comp_graph)
    # use (dest, delta) as the key, so we could not add the same delta to dest in static graph
    key = (comp_graph._id(), _dest.id, _delta.id)
    if key in _add_update_cache:
        # Reuse the existing opr; just refresh its scalar coefficients.
        _alpha, _beta, _bias, config = _add_update_cache[key]
        mgb.mgb._mgb.SharedScalar__set(_alpha, alpha)
        mgb.mgb._mgb.SharedScalar__set(_beta, beta)
        mgb.mgb._mgb.SharedScalar__set(_bias, bias)
    else:
        _alpha = mgb.SharedScalar(alpha)
        _beta = mgb.SharedScalar(beta)
        _bias = mgb.SharedScalar(bias)
        config = mgb.helper.gen_config(None, comp_node, None)
        _add_update_cache[key] = (_alpha, _beta, _bias, config)
    u = mgb.mgb._Opr.add_update(
        _dest, barrier(_delta), _alpha, _beta, _bias, _dummy, config
    )
    # The opr mutates state, so it must not be pruned as dead code.
    mark_impure(u)
    if trace._active_instance:
        dest._override_symvar_during_trace(trace._active_instance, u)
    return Tensor(u)
@wrap_io_tensor
def add_extra_vardep(oup: Tensor, dep: Tensor):
    r"""Explicitly set the dependency that tensor ``oup`` depends on tensor ``dep``,
    forcing ``dep`` to be computed whenever ``oup`` is.
    """
    return mgb.config.add_extra_vardep(oup, dep)
@@ -1,391 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import megengine._internal as mgb | |||
from ..core.tensor import Tensor | |||
from .elemwise import abs, equal, log, maximum, power, relu | |||
from .nn import assert_equal, indexing_one_hot | |||
from .tensor import where | |||
from .utils import zero_grad | |||
def l1_loss(pred: "Tensor", label: "Tensor") -> "Tensor":
    r"""
    Calculates the mean absolute error (MAE) between
    each element in the pred :math:`x` and label :math:`y`.

    .. math:: \ell(x,y) = mean\left(L \right),\quad
        L = \{l_1,\dots,l_N\},\quad l_n = \left| x_n - y_n \right|

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32))
        tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32))
        loss = F.l1_loss(ipt,tgt)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [2.75]

    """
    return abs(pred - label).mean()
def square_loss(pred: "Tensor", label: "Tensor") -> "Tensor":
    r"""
    Calculates the mean squared error (squared L2 norm) between
    each element in the pred :math:`x` and label :math:`y`.

    .. math:: \ell(x, y) = mean\left( L \right),\quad
        L = \{l_1,\dots,l_N\},\quad l_n = \left( x_n - y_n \right)^2

    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
    of :math:`N` elements each. :math:`N` is the batch size.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Shape:
        - pred: :math:`(N, *)` where :math:`*` means any number of additional
          dimensions
        - label: :math:`(N, *)`. Same shape as ``pred``
    """
    return ((pred - label) ** 2).mean()
def cross_entropy(
    inp: Tensor, target: Tensor, axis: int = 1, ignore_index: int = -1
) -> Tensor:
    r"""
    Returns the cross entropy loss in a classification problem.

    .. math:: \textrm{CrossEntropy}(x, y) = - \sum_{i} y_i\log(x_i)

    :param inp: The input tensor representing the predicted probability.
    :param label: The input tensor representing the classification label.
    :param axis: An axis along which cross_entropy will be applied. Default: 1
    :param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. Default: -1

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data_shape = (1, 2)
        label_shape = (1, )
        pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape))
        label = tensor(np.ones(label_shape, dtype=np.int32))
        loss = F.cross_entropy(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.69]

    """
    n0, n1 = inp.ndim, target.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )
    # Fast path: no label masking required.
    if ignore_index == -1:
        return -log(indexing_one_hot(inp, target, axis)).mean()
    # Zero out ignored positions and average over the remaining ones.
    mask = 1 - equal(target, ignore_index)
    target = target * mask
    loss = -log(indexing_one_hot(inp, target, axis)) * mask
    return loss.sum() / maximum(mask.sum(), 1.0)
def cross_entropy_with_softmax(
    pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0
) -> Tensor:
    r"""
    Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`.

    It has better numerical stability compared with sequential calls to
    :func:`~.softmax` and :func:`~.cross_entropy`.

    When using label smoothing, the label distribution is as follows:

    .. math:: y^{LS}_{k}=y_{k}\left(1-\alpha\right)+\alpha/K

    where :math:`y^{LS}` and :math:`y` are new label distribution and origin
    label distribution respectively. k is the index of label distribution.
    :math:`\alpha` is label_smooth and :math:`K` is the number of classes.

    :param pred: The input tensor representing the predicted probability.
    :param label: The input tensor representing the classification label.
    :param axis: An axis along which softmax will be applied. Default: 1.
    :param label_smooth: A label smoothing of parameter that can re-distribute target distribution. Default: 0.
    """
    n0 = pred.ndim
    n1 = label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )
    num_classes = pred.shapeof(axis)
    # Subtract the per-row max before exponentiating for numerical stability;
    # zero_grad keeps the constant offset out of the backward pass.
    offset = zero_grad(pred.max(axis=axis, keepdims=True))
    pred = pred - offset
    # Denominator of the softmax
    down = mgb.opr.elem.exp(pred).sum(axis=axis, keepdims=True)
    # Numerator: the (shifted) logit of the target class.
    up = indexing_one_hot(pred, label, axis)
    if label_smooth != 0:
        factor = label_smooth / num_classes
        up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor
    # -log(softmax) of the target = log(down) - up; return its mean.
    return (log(down) - up).mean()
def triplet_margin_loss(
    anchor: Tensor, positive: Tensor, negative: Tensor, margin: float = 1.0, p: int = 2
) -> Tensor:
    r"""
    Creates a criterion that measures the triplet loss given an input tensors.

    .. math::
        L(a, p, n) = max\left\{d\left(a_{i},p_{i}\right)-d\left(a_{i}, n_{i}\right)+margin, 0\right\},\
        d\left(x_{i},y_{i}\right)=\left\|x_{i}-y_{i}\right\|_{p}

    :param anchor: The input tensor representing the anchor samples.
    :param positive: The input tensor representing the positive samples.
    :param negative: The input tensor representing the negative samples.
    :param margin: Default: 1.0
    :param p: The norm degree for pairwise distance. Default: 2.0
    """
    s0 = anchor.shapeof()
    s1 = positive.shapeof()
    s2 = negative.shapeof()
    assert_equal(s0, s1)
    assert_equal(s1, s2)
    n0 = anchor.ndim
    n1 = positive.ndim
    n2 = negative.ndim
    assert n0 == 2 and n1 == 2 and n2 == 2, (
        "anchor ndim, positive ndim, and negative ndim must be 2; "
        "anchor_ndim={} positive_ndim={} negative_ndim={}".format(n0, n1, n2)
    )
    # Fixed assertion message: the check is on the norm degree p, not on the
    # margin (the original message claimed "a margin with a value greater than 0").
    assert p > 0, "norm degree p must be greater than 0; p={}".format(p)
    diff0 = abs(anchor - positive)
    diff1 = abs(anchor - negative)
    d1 = power(power(diff0, p).sum(axis=1, keepdims=True), 1 / p)
    d2 = power(power(diff1, p).sum(axis=1, keepdims=True), 1 / p)
    loss = maximum(d1 - d2 + margin, 0)
    return loss.mean()
def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
    r"""Function that measures the Binary Cross Entropy between the target and the prediction.

    :param pred: (N,*) where * means, any number of additional dimensions.
    :param label: (N,*), same shape as the input.
    """
    # Shapes must agree element-for-element.
    assert_equal(pred.shapeof(), label.shapeof())
    ce = label * log(pred) + (1.0 - label) * log(1 - pred)
    return -1.0 * ce.mean()
def nll_loss(
    pred: Tensor, label: Tensor, axis: int = 1, ignore_index: int = -1
) -> Tensor:
    r"""
    The negative log likelihood loss.

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data_shape = (2, 2)
        label_shape = (2, )
        data = tensor(
            np.array([[1, 0.5], [0.3, 1.2]], dtype=np.float32).reshape(data_shape),
        )
        label = tensor(
            np.ones(label_shape, dtype=np.int32)
        )
        pred = F.log(F.softmax(data))
        loss1 = F.nll_loss(pred, label)
        loss2 = F.cross_entropy_with_softmax(data, label)
        print(loss1.numpy(), loss2.numpy())

    Outputs:

    .. testoutput::

        [0.6576154] [0.6576154]

    """
    n0, n1 = pred.ndim, label.ndim
    assert n0 == n1 + 1, (
        "target ndim must be one less than input ndim; input_ndim={} "
        "target_ndim={}".format(n0, n1)
    )
    # Zero out ignored positions, then average over the remaining ones.
    mask = 1.0 - equal(label, ignore_index)
    masked_label = label * mask
    loss = indexing_one_hot(pred, masked_label, axis) * mask
    return -1.0 * loss.sum() / maximum(mask.sum(), 1.0)
def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
    r"""
    Caculate the hinge loss which is often used in SVMs.

    .. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j(max(0, 1 - x_{ij}*y_{ij}))

    :param pred: The input tensor representing the predicted probability, shape is (N, C).
    :param label: The input tensor representing the binary classification label, shape is (N, C).
    :param norm: Specify the norm to caculate the loss, should be "L1" or "L2".

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
        label = tensor([[1, -1, -1], [-1, 1, 1]])
        loss = F.hinge_loss(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [1.5]

    """
    assert norm in ["L1", "L2"], "norm must be L1 or L2"
    # Per-element hinge: max(0, 1 - pred * label); labels are -1/1.
    loss = relu(1.0 - pred * label)
    if norm == "L2":
        loss = loss ** 2
    return loss.sum(axis=1).mean()
def smooth_l1_loss(pred: Tensor, label: Tensor) -> Tensor:
    r"""
    Caculate the smooth l1 loss proposed in `Fast R-CNN paper by Ross Girshick`.

    .. math::
        \text{loss}(x, y) = \frac{1}{n} \sum_{i} l_{i},\quad
        l_{i} =
        \begin{cases}
        0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\
        |x_i - y_i| - 0.5, & \text{otherwise }
        \end{cases}

    :param pred: The predicted result from model.
    :param label: The ground truth to compare.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
        label = tensor([[0.4, 1.5, 1.2], [0., 0.1, 2.2]])
        loss = F.smooth_l1_loss(pred, label)
        print(loss.numpy())

    Outputs:

    .. testoutput::

        [0.5608334]

    """
    diff = abs(pred - label)
    # Quadratic inside the unit band, linear outside it.
    quadratic = 0.5 * (diff ** 2)
    linear = diff - 0.5
    return where(diff < 1, quadratic, linear).mean()
@@ -1,333 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math | |||
import numbers | |||
from typing import Optional, Sequence, Union | |||
import megengine._internal as mgb | |||
from ..core import Tensor, wrap_io_tensor | |||
from .elemwise import clamp, exp, isinf, log | |||
from .tensor import remove_axis, where, zeros_like | |||
@wrap_io_tensor | |||
def sum(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor: | |||
r"""Returns the sum of each row of the ``inp`` tensor in the given ``axis``. | |||
:param inp: The input tensor. | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. | |||
Default: None | |||
:param keepdims: Whether the output tensor has ``axis`` retained or not. | |||
Default: False | |||
:return: The output tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||
out = F.sum(data) | |||
print(out.numpy()) | |||
.. testoutput:: | |||
[21] | |||
""" | |||
return mgb.opr.reduce_(inp, "SUM", axis, keepdims) | |||
@wrap_io_tensor | |||
def prod(inp: Tensor, axis: Optional[int] = None, keepdims=False) -> Tensor: | |||
r""" | |||
Returns the element product of input tensor along given *axis*. | |||
:param inp: The input tensor | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None`` | |||
:param keepdims: Whether the output tensor has *axis* retained or not. Default: ``False`` | |||
:return: The output tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||
out = F.prod(data) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[720] | |||
""" | |||
return mgb.opr.reduce_(inp, "PRODUCT", axis, keepdims) | |||
@wrap_io_tensor
def mean(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor:
    """Returns the mean value of the ``inp`` tensor along
    the given ``axis``. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.mean(data)
        print(out.numpy())

    .. testoutput::

        [3.5]

    """
    # Unlike the other reductions, mean has a dedicated graph operator.
    return mgb.opr.mean(inp, axis, keepdims)
@wrap_io_tensor
def min(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor:
    r"""
    Returns the min value of the input tensor along the given *axis*.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
        y = F.min(x)
        print(y.numpy())

    Outputs:

    .. testoutput::

        [1]

    """
    # Delegate to the generic graph reduce operator with mode "MIN".
    return mgb.opr.reduce_(inp, "MIN", axis, keepdims)
@wrap_io_tensor
def max(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor:
    r"""Returns the max value of the input tensor along the given *axis*.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
        y = F.max(x)
        print(y.numpy())

    .. testoutput::

        [6]

    """
    # Delegate to the generic graph reduce operator with mode "MAX".
    return mgb.opr.reduce_(inp, "MAX", axis, keepdims)
@wrap_io_tensor
def sqrt(inp: Tensor) -> Tensor:
    """
    Returns a new tensor with the element-wise square root of ``inp``.

    :param inp: The input tensor
    :return: The computed tensor

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        out = F.sqrt(data)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[0.     1.     1.4142]
         [1.7321 2.     2.2361 ]]

    """
    return mgb.opr.sqrt(inp)
def norm(inp: Tensor, p: int = 2, axis: Optional[int] = None, keepdims=False):
    """Calculate the ``p``-norm of the input tensor along a certain axis.

    :param inp: The input tensor
    :param p: power of value ``p`` applied to ``inp``. Default: 2
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False
    :return: The output tensor
    """
    # NOTE(review): the implementation raises elements to the p-th power
    # without taking an absolute value first, so for odd ``p`` and negative
    # entries this is not the mathematical p-norm -- confirm intended.
    target = inp.reshape(-1) if axis is None else inp
    powered = target ** p
    reduced = powered.sum(axis=axis, keepdims=keepdims)
    return reduced ** (1.0 / p)
@wrap_io_tensor
def argmin(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor:
    r"""Returns the indices of the minimum values along an axis.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor of indices (int32)

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
        y = F.argmin(x)
        print(y.numpy())

    .. testoutput::

        [0]

    """
    return mgb.opr.argmin(inp, axis, keepdims)
@wrap_io_tensor
def argmax(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor:
    r"""Returns the indices of the maximum values along an axis.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor of indices (int32)

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
        y = F.argmax(x)
        print(y.numpy())

    .. testoutput::

        [5]

    """
    return mgb.opr.argmax(inp, axis, keepdims)
def normalize(
    inp: Tensor, p: int = 2, axis: Optional[int] = None, eps: float = 1e-12
) -> Tensor:
    r"""Perform :math:`L_p` normalization of input tensor along certain axis.

    For a tensor :attr:`inp` of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
    :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:

    .. math::

        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.

    :param inp: the input tensor
    :param p: power of value ``p`` applied to ``inp``. Default: 2
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced
        to calculate the norm. Default: None
    :param eps: a small value to avoid division by zero. Default: 1e-12
    :return: the normalized output tensor
    """
    # The norm is clamped from below by eps so we never divide by zero; for
    # the per-axis case keepdims=True keeps the divisor broadcastable.
    if axis is None:
        divisor = clamp(norm(inp, p), lower=eps)
    else:
        divisor = clamp(norm(inp, p, axis, keepdims=True), lower=eps)
    return inp / divisor
def logsumexp(inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False):
    r"""
    Compute the log of the sum of exponentials of inputs along the given :attr:`axis`. The computation is numerically stabilized.

    .. math::

        \mathsf{logsumexp}(x_1, \dots, x_n) = \log(\exp(x_1) + \cdots + \exp(x_n))

    :param inp: The input tensor.
    :param axis: Axis over which the sum is taken. It can be a single axis or a list of axes.
    :param keepdims: whether to retain :attr:`axis` or not for the output tensor.
    """
    # Normalize a single axis to a 1-tuple so the loops below handle both forms.
    if isinstance(axis, numbers.Integral):
        axis = (axis,)
    # Per-axis running maximum, kept with dims so it broadcasts against inp.
    max_value = inp
    for dim in axis:
        max_value = max_value.max(axis=dim, keepdims=True)
    # Replace infinite maxima with zero; subtracting an inf would turn the
    # shifted exponentials into NaNs.
    max_value = where(
        isinf(max_value).astype("int32"), zeros_like(max_value), max_value
    )
    # Classic stabilization: exp(inp - max), sum, then add the max back in
    # log space.
    x = exp(inp - max_value)
    for dim in axis:
        x = x.sum(axis=dim, keepdims=True)
    x = max_value + log(x)
    if not keepdims:
        # Drop reduced axes from the back first so earlier indices stay valid.
        axis = sorted(axis, reverse=True)
        for i in axis:
            x = remove_axis(x, axis=i)
    return x
@@ -1,80 +0,0 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=too-many-lines | |||
from typing import Tuple, Union | |||
from .. import _internal as mgb | |||
from ..core import Tensor, wrap_io_tensor | |||
from ..utils.types import _pair, _pair_nonzero | |||
from .debug_param import get_conv_execution_strategy | |||
@wrap_io_tensor
def conv_bias_activation(
    inp: Tensor,
    weight: Tensor,
    bias: Tensor,
    dtype=None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    nonlinear_mode="IDENTITY",
    conv_mode="CROSS_CORRELATION",
    compute_mode="DEFAULT",
) -> Tensor:
    """Fused convolution + bias + activation operation, only for inference.

    :param inp: The feature map of the convolution operation
    :param weight: The convolution kernel
    :param bias: The bias added to the result of convolution
    :param stride: Stride of the 2D convolution operation. Default: 1
    :param padding: Size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: Dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and the shape of weight should be ``(groups, out_channel // groups,
        in_channels // groups, height, width)``.
    :param nonlinear_mode: activation fused after the bias add, e.g.
        ``'IDENTITY'``. Default: ``'IDENTITY'``
    :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
    :param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
        'CROSS_CORRELATION'.
    :param dtype: Support for np.dtype, Default:
        np.int8.
    :type compute_mode: string or
        :class:`mgb.opr_param_defs.Convolution.ComputeMode`
    :param compute_mode: When set to 'DEFAULT', no special requirements will be
        placed on the precision of intermediate results. When set to 'FLOAT32',
        Float32 would be used for accumulator and intermediate result, but only
        effective when input and output are of Float16 dtype.
    """
    # Normalize scalar-or-pair arguments; stride/dilation must be non-zero.
    ph, pw = _pair(padding)
    sh, sw = _pair_nonzero(stride)
    dh, dw = _pair_nonzero(dilation)
    # groups > 1 selects the grouped-convolution kernel layout.
    sparse_type = "DENSE" if groups == 1 else "GROUP"
    # NOTE: "nonlineMode" (not "nonlinearMode") is the spelling expected by
    # the underlying graph operator.
    res = mgb.opr.conv_bias_activation(
        inp,
        weight,
        bias,
        compute_mode=compute_mode,
        dtype=dtype,
        strategy=get_conv_execution_strategy(),
        nonlineMode=nonlinear_mode,
        sparse=sparse_type,
        format="NCHW",
        pad_h=ph,
        pad_w=pw,
        stride_h=sh,
        stride_w=sw,
        dilate_h=dh,
        dilate_w=dw,
        mode=conv_mode,
    )
    return res
@@ -1,123 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools | |||
from typing import Optional, Tuple, Union | |||
import megengine._internal as mgb | |||
from ..core.tensor import Tensor, wrap_io_tensor | |||
__all__ = ["argsort", "sort", "top_k"] | |||
@wrap_io_tensor
def argsort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
    r"""
    Sort the target 2d matrix by row, returning both the sorted tensor and
    the indices.

    :param inp: The input tensor, if 2d, each row will be sorted
    :param descending: Sort in descending order, where the largest comes first. Default: ``False``
    :return: Tuple of two tensors (sorted_tensor, indices_of_int32)

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.array([1,2], dtype=np.float32))
        sorted, indices = F.argsort(data)
        print(sorted.numpy(), indices.numpy())

    Outputs:

    .. testoutput::

        [1. 2.] [0 1]

    """
    assert len(inp.imm_shape) <= 2, "Input should be 1d or 2d"
    Order = mgb.opr_param_defs.Argsort.Order
    order = Order.DESCENDING if descending else Order.ASCENDING
    if len(inp.imm_shape) == 1:
        # The kernel only accepts 2d input: lift to a single row, sort, and
        # strip the synthetic leading dimension off both outputs.
        sorted_rows, index_rows = mgb.opr.argsort(inp.reshape(1, -1), order=order)
        return sorted_rows[0], index_rows[0]
    return mgb.opr.argsort(inp, order=order)
@functools.wraps(argsort)
def sort(*args, **kwargs):
    # Alias of :func:`argsort`; functools.wraps copies its signature and docs.
    return argsort(*args, **kwargs)
@wrap_io_tensor
def top_k(
    inp: Tensor,
    k: int,
    descending: bool = False,
    kth_only: bool = False,
    no_sort: bool = False,
) -> Tuple[Tensor, Tensor]:
    r"""
    Select the Top-K (by default) smallest elements of 2d matrix by row.

    :param inp: The input tensor, if 2d, each row will be sorted
    :param k: The number of elements needed
    :param descending: If true, return the largest elements instead. Default: ``False``
    :param kth_only: If true, only the k-th element will be returned. Default: ``False``
    :param no_sort: If true, the returned elements can be unordered. Default: ``False``
    :return: Tuple of two tensors (topk_tensor, indices_of_int32)

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
        top, indices = F.top_k(data, 5)
        print(top.numpy(), indices.numpy())

    Outputs:

    .. testoutput::

        [1. 2. 3. 4. 5.] [7 0 6 1 5]

    """
    assert len(inp.imm_shape) <= 2, "Input should be 1d or 2d"
    if kth_only:
        # KTH_ONLY mode is not wired up: iterating the resulting SymbolVar is
        # unsupported, so fail fast before building any graph nodes.
        raise NotImplementedError(
            "TODO: would encounter:"
            "NotImplementedError: SymbolVar var could not be itered"
        )
    if descending:
        # Negate so the smallest-k kernel effectively picks the largest.
        inp = -inp
    Mode = mgb.opr_param_defs.TopK.Mode
    # Mode.KTH_ONLY is unreachable here because kth_only raised above, so
    # only the two value+index modes remain.
    if no_sort:
        mode = Mode.VALUE_IDX_NOSORT
    else:
        mode = Mode.VALUE_IDX_SORTED
    if len(inp.imm_shape) == 1:
        # The kernel only accepts 2d input: lift to one row and strip the
        # synthetic leading dimension from both outputs.
        inp = inp.reshape(1, -1)
        tns, ind = mgb.opr.top_k(inp, k, mode=mode)
        tns = tns[0]
        ind = ind[0]
    else:
        tns, ind = mgb.opr.top_k(inp, k, mode=mode)
    if descending:
        # Undo the negation applied to the values (indices are unaffected).
        tns = -tns
    return tns, ind
@@ -1,667 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools | |||
from typing import Iterable, List, Optional, Union | |||
import numpy as np | |||
import megengine._internal as mgb | |||
from megengine._internal import CompGraph, CompNode | |||
from ..core import zeros | |||
from ..core.graph import _use_default_if_none | |||
from ..core.tensor import Tensor, wrap_io_tensor | |||
from .elemwise import ceil | |||
from .utils import _decide_comp_node_and_comp_graph | |||
@wrap_io_tensor
def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
    """
    Broadcast a tensor to ``shape``.

    :param inp: The input tensor
    :param shape: The target shape
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        out = F.broadcast_to(data, (4, 2, 3))
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[[0. 1. 2.]
          [3. 4. 5.]]
         [[0. 1. 2.]
          [3. 4. 5.]]
         [[0. 1. 2.]
          [3. 4. 5.]]
         [[0. 1. 2.]
          [3. 4. 5.]]]

    """
    # Accept a bare int as a 1-element shape for convenience.
    if isinstance(shape, int):
        shape = (shape,)
    return mgb.opr.broadcast(inp, shape)
def _get_idx(index, axis):
    # Build a full advanced-indexing tuple from ``index``: along ``axis`` the
    # flattened values of ``index`` are used directly; along every other
    # dimension a positional arange is broadcast to index.shape and
    # flattened, so each index entry addresses exactly one element.
    index_dims = len(index.imm_shape)
    idx = []
    comp_node, comp_graph = _decide_comp_node_and_comp_graph(index)
    for i in range(index_dims):
        if i != axis:
            # 0..size-1 along dimension i, shaped (1,..,size,..,1) so it
            # broadcasts across index.shape before flattening.
            shape = [1] * index_dims
            shape[i] = index.axis_shape(i)
            arange = mgb.opr.linspace(
                0,
                index.axis_shape(i) - 1,
                index.axis_shape(i),
                comp_node=comp_node,
                comp_graph=comp_graph,
            )
            arange = (
                arange.reshape(*shape)
                .broadcast(index.shape)
                .reshape(-1)
                .astype(np.int32)  # linspace yields float; indices need int32
            )
            idx.append(arange)
        else:
            idx.append(index.reshape(-1))
    return tuple(idx)
@wrap_io_tensor
def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
    r"""
    Gather data from :attr:`inp` on :attr:`axis` using :attr:`index`.

    For a 3-D tensor, the output is specified by::

        out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0
        out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1
        out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2

    if :attr:`inp` is an n-dimensional tensor with size
    :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i,
    then :attr:`index` must be an n-dimensional tensor with size
    :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and
    output will have the same size as :attr:`index`.

    :param inp: the source tensor
    :param axis: the axis along which to index
    :param index: the indices of elements to gather
    :raises ValueError: if ``index`` has a different rank than ``inp``, if
        ``axis`` is out of range, or if a non-``axis`` dimension of ``index``
        differs in size from ``inp``

    Examples:

    .. testcode::

        import megengine.functional as F
        from megengine.core import tensor

        inp = tensor([
            [1,2], [3,4], [5,6],
        ])
        index = tensor([[0,2], [1,0]])
        oup = F.gather(inp, 0, index)
        print(oup.numpy())

    Outputs:

    .. testoutput::

        [[1 6]
         [3 2]]

    """
    input_shape = inp.imm_shape
    index_shape = index.imm_shape
    input_dims = len(input_shape)
    index_dims = len(index_shape)
    if input_dims != index_dims:
        raise ValueError(
            "The index tensor must have same dimensions as input tensor, "
            "But the input dims:{}, the index dims:{}".format(input_dims, index_dims)
        )
    if axis < 0 or axis >= input_dims:
        # fixed message: previously read "output of bounds"
        raise ValueError(
            "Index axis {} is out of bounds, should in range [0 {})".format(
                axis, input_dims
            )
        )
    for i in range(input_dims):
        if i != axis and input_shape[i] != index_shape[i]:
            raise ValueError(
                "The input {} and index {} must have the same size apart from axis {}".format(
                    input_shape, index_shape, axis
                )
            )
    # Translate ``index`` into a full advanced-indexing tuple, then reshape
    # the flat gather result back to index.shape.
    idx = _get_idx(index, axis)
    return mgb.opr.advanced_indexing(inp)[idx].reshape(
        index.shape
    )  # pylint: disable=no-member
@wrap_io_tensor
def concat(
    inps: Iterable[Tensor],
    axis: int = 0,
    device: Optional[CompNode] = None,
    comp_graph: Optional[CompGraph] = None,
) -> Tensor:
    r"""
    Concatenate some tensors along the given axis.

    :param inps: Input tensors to concat
    :param axis: the dimension over which the tensors are concatenated. Default: 0
    :param device: The comp node output on. Default: None
    :param comp_graph: The graph in which output is. Default: None
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
        data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
        out = F.concat([data1, data2])
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[ 0.  1.  2.]
         [ 3.  4.  5.]
         [ 6.  7.  8.]
         [ 9. 10. 11.]]

    """
    # Output buffer not supported
    return mgb.opr.concat(
        *list(inps), axis=axis, comp_node=device, comp_graph=comp_graph
    )
@wrap_io_tensor
def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
    r"""
    Writes all values from the tensor :attr:`source` into :attr:`inp` at the indices specified in the :attr:`index` tensor.

    For each value in :attr:`source`, its output index is specified by its index
    in :attr:`source` for ``axis != dimension`` and by the corresponding value in
    :attr:`index` for ``axis = dimension``.

    For a 3-D tensor, :attr:`inp` is updated as::

        inp[index[i][j][k]][j][k] = source[i][j][k]  # if axis == 0
        inp[i][index[i][j][k]][k] = source[i][j][k]  # if axis == 1
        inp[i][j][index[i][j][k]] = source[i][j][k]  # if axis == 2

    :attr:`inp`, :attr:`index` and :attr:`source` should have same number of dimensions.

    It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)``
    for all dimensions ``d``.

    Moreover, the values of :attr:`index` must be between ``0`` and ``inp.shape(axis) - 1`` inclusive.

    .. note::
        Please notice that, due to performance issues, the result is uncertain on the GPU device
        if scatter difference positions from source to the same destination position
        regard to index tensor.

        Show the case using the following examples, the oup[0][2] is maybe
        from source[0][2] which value is 0.2256 or source[1][2] which value is 0.5339
        if set the index[1][2] from 1 to 0.

    :param inp: the inp tensor which to be scattered
    :param axis: the axis along which to index
    :param index: the indices of elements to scatter
    :param source: the source element(s) to scatter
    :raises ValueError: if the three tensors differ in rank, if ``axis`` is out
        of range, or if the shape constraints above are violated

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F
        from megengine.core import tensor

        inp = tensor(np.zeros(shape=(3,5),dtype=np.float32))
        source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
        index = tensor([[0,2,0,2,1],[2,0,1,1,2]])
        oup = F.scatter(inp, 0, index,source)
        print(oup.numpy())

    Outputs:

    .. testoutput::

        [[0.9935 0.0718 0.2256 0.     0.    ]
         [0.     0.     0.5939 0.357  0.4396]
         [0.7723 0.9465 0.     0.8926 0.4576]]

    """
    input_shape = inp.imm_shape
    index_shape = index.imm_shape
    source_shape = source.imm_shape
    input_dims = len(input_shape)
    index_dims = len(index_shape)
    source_dims = len(source_shape)
    if input_dims != index_dims or input_dims != source_dims:
        raise ValueError("The input, source and index tensor must have same dimensions")
    if axis < 0 or axis >= input_dims:
        # fixed message: previously read "output of bounds"
        raise ValueError(
            "Index axis {} is out of bounds, should in range [0 {})".format(
                axis, input_dims
            )
        )
    for i in range(source_dims):
        if source_shape[i] > input_shape[i]:
            raise ValueError(
                "The each shape size for source {} must be less than or equal to input {} ".format(
                    source_shape, input_shape
                )
            )
    for i in range(index_dims):
        if index_shape[i] != source_shape[i]:
            raise ValueError(
                "The each shape size for index {} must be equal to source {} ".format(
                    index_shape, source_shape
                )
            )
    for i in range(index_dims):
        if i != axis and index_shape[i] > input_shape[i]:
            raise ValueError(
                "The index {} must be less than or equal to input {} size apart from axis {}".format(
                    index_shape, input_shape, axis
                )
            )
    # Translate ``index`` into a full advanced-indexing tuple and write the
    # flattened source values through it.
    idx = _get_idx(index, axis)
    return mgb.opr.set_advanced_indexing(inp, source.flatten())[idx]
@wrap_io_tensor
def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
    r"""
    Select elements either from Tensor x or Tensor y, according to mask.

    .. math::

        \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i

    :param mask: a mask used for choosing x or y
    :param x: the first choice
    :param y: the second choice

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32))
        x = tensor(np.array([[1, np.inf], [np.nan, 4]],
            dtype=np.float32))
        y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
        out = F.where(mask, x, y)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[1. 6.]
         [7. 4.]]

    """
    # Take values of x where mask == 1, and of y where mask == 0; each
    # cond_take also yields the flat positions of the taken elements.
    v0, index0 = mgb.opr.cond_take(
        x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=1
    )
    v1, index1 = mgb.opr.cond_take(
        y, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=0
    )
    # Scatter both value sets back into a flat buffer at their original
    # positions, then restore x's shape.
    out = x.flatten()
    index = mgb.opr.concat(index0, index1, axis=0)
    v = mgb.opr.concat(v0, v1, axis=0)
    out = mgb.opr.set_advanced_indexing(out, v)[index]
    out = out.reshape(x.shape)
    return out
@wrap_io_tensor
def cond_take(mask: Tensor, x: Tensor, val=1) -> Tensor:
    r"""
    Take elements from data if specific condition is satisfied on mask. This operator has two outputs: the first is the elements taken, and the second is the indices corresponding to those elements; they are both 1-dimensional. High-dimension input would first be flattened.

    :param mask: condition param; must be the same shape with data
    :param x: input tensor from which to take elements
    :param val: value to be compared to by mode
    :return: tuple of (values_taken, flat_indices_of_int32)

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32))
        x = tensor(np.array([[1, np.inf], [np.nan, 4]],
            dtype=np.float32))
        v, index = F.cond_take(mask, x, 1)
        print(v, index)

    Outputs:

    .. testoutput::

        Tensor([1. 4.]) Tensor([0 3], dtype=int32)

    """
    # EQ mode: keep elements where mask equals ``val``.
    v, index = mgb.opr.cond_take(
        x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=val
    )
    return v, index
def shapeof(x: Tensor, axis=None):
    r"""
    The shape of the input tensor; when ``axis`` is given, only the size of
    that dimension.
    """
    return x.shapeof(axis=axis)
@wrap_io_tensor
def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor:
    r"""
    Swap shapes and strides according to given pattern.

    :param inp: Input tensor
    :param pattern: a list of integers including 0, 1, ... , ``ndim``-1, and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples:

        * (``'x'``) -> make a 0d (scalar) into a 1d vector
        * (0, 1) -> identity for 2d vectors
        * (1, 0) -> inverts the first and second dimensions
        * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN)
        * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1)
        * (2, 0, 1) -> AxBxC to CxAxB
        * (0, ``'x'``, 1) -> AxB to Ax1xB
        * (1, ``'x'``, 0) -> AxB to Bx1xA
        * (1,) -> This removes dimension 0. It must be a broadcastable dimension (1xA to A)

    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32))
        out = F.dimshuffle(x, (1, 0))
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[1 0]
         [1 0]]

    """
    return mgb.opr.dimshuffle(inp, pattern)
@wrap_io_tensor
def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
    r"""
    Reshape a tensor to the given target shape; the total number of logical
    elements must remain unchanged.

    :param inp: Input tensor
    :param target_shape: target shape, the components would be concatenated to form the
        target shape, and it can contain an element of -1 representing unspec_axis.
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(12, dtype=np.int32))
        out = F.reshape(x, (3, 2, 2))
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[[ 0  1]
          [ 2  3]]
         [[ 4  5]
          [ 6  7]]
         [[ 8  9]
          [10 11]]]

    """
    return mgb.opr.reshape(inp, target_shape)
def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
    r"""Equivalent to :func:`dimshuffle`.
    """
    return dimshuffle(inp, pattern)
@wrap_io_tensor
def add_axis(inp: Tensor, axis: int) -> Tensor:
    r"""
    Add a dimension of size 1 before the given axis.

    :param inp: Input tensor
    :param axis: Place of new axes
    :return: The output tensor
    :raises ValueError: if ``axis`` is not an int

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, 2])
        out = F.add_axis(x, 0)
        print(out.shape)

    Outputs:

    .. testoutput::

        (1, 2)

    """
    # Only a plain int axis is supported by the underlying operator.
    if isinstance(axis, int):
        return mgb.opr.add_axis(inp, axis)
    raise ValueError("axis must be int, but got type:{}".format(type(axis)))
@wrap_io_tensor
def remove_axis(inp: Tensor, axis: int) -> Tensor:
    r"""
    Remove a dimension of shape 1.

    :param inp: Input tensor
    :param axis: Place of axis to be removed
    :return: The output tensor
    :raises ValueError: if ``axis`` is not an int

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
        out = F.remove_axis(x, 3)
        print(out.shape)

    Outputs:

    .. testoutput::

        (1, 1, 2)

    """
    # Only a plain int axis is supported by the underlying operator.
    if isinstance(axis, int):
        return mgb.opr.remove_axis(inp, axis)
    raise ValueError("axis must be int, but got type:{}".format(type(axis)))
def linspace(
    start: Union[int, float, Tensor],
    stop: Union[int, float, Tensor],
    num: Union[int, Tensor],
    dtype=np.float32,
    device: Optional[CompNode] = None,
    comp_graph: Optional[CompGraph] = None,
) -> Tensor:
    r"""
    Return equally spaced numbers over a specified interval.

    :param start: starting value of the sequence, should be scalar
    :param stop: the last value of the sequence, should be scalar
    :param num: number of values to generate
    :param dtype: result data type; only ``np.float32`` is supported
    :return: The generated tensor
    :raises ValueError: if ``dtype`` is not ``np.float32``

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F

        a = F.linspace(3,10,5)
        print(a.numpy())

    .. testoutput::

        [ 3.    4.75  6.5   8.25 10.  ]

    """
    if dtype is not np.float32:
        raise ValueError("linspace is only implemented for float32")
    device, comp_graph = _use_default_if_none(device, comp_graph)
    ret = Tensor(
        mgb.opr.linspace(start, stop, num, comp_node=device, comp_graph=comp_graph)
    )
    # The astype is a no-op today (dtype is always float32) but keeps the
    # call site explicit about the output type.
    return ret.astype(dtype)
def arange(
    start: Union[int, float, Tensor],
    end: Union[int, float, Tensor],
    step: Union[int, float, Tensor] = 1,
    dtype=np.float32,
    device: Optional[CompNode] = None,
    comp_graph: Optional[CompGraph] = None,
) -> Tensor:
    r"""
    Returns a Tensor with values from `start` to `end` with adjacent interval `step`.

    :param start: starting value of the sequence, should be scalar
    :param end: ending value of the sequence (exclusive), should be scalar
    :param step: the gap between each pair of adjacent values. Default 1
    :param dtype: result data type; only ``np.float32`` is supported
    :return: The generated tensor
    :raises ValueError: if ``dtype`` is not ``np.float32``

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F

        a = F.arange(1, 5, 1)
        print(a.numpy())

    .. testoutput::

        [1. 2. 3. 4.]

    """
    if dtype is not np.float32:
        raise ValueError("arange is only implemented for float32")
    # Express the range as an equivalent linspace: ``num`` evenly spaced
    # points ending at the last value strictly below ``end``.
    num = ceil((end - start) / step)
    stop = start + step * (num - 1)
    ret = linspace(start, stop, num, device=device, comp_graph=comp_graph)
    return ret
def zeros_like(inp: Tensor) -> Tensor:
    r"""
    Returns a zero tensor with the same shape and dtype as the input tensor.

    :param inp: input tensor
    :return: the zero-filled tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
        out = F.zeros_like(inp)
        print(out.numpy())

    .. testoutput::

        [[0 0 0]
         [0 0 0]]

    """
    # zeros() produces a float tensor by default, so cast to inp's dtype.
    return zeros(inp.shapeof()).astype(inp.dtype)
@@ -1,81 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, Union | |||
import megengine._internal as mgb | |||
from ..core.graph import _use_default_if_none | |||
from ..core.tensor import Tensor, wrap_io_tensor | |||
from .elemwise import equal | |||
from .sort import top_k | |||
def _decide_comp_node_and_comp_graph(*args: mgb.SymbolVar):
    # Pick the computing node / graph from the first SymbolVar among the
    # arguments; fall back to the framework defaults when none is found.
    for candidate in args:
        if isinstance(candidate, mgb.SymbolVar):
            return candidate.comp_node, candidate.owner_graph
    return _use_default_if_none(None, None)
def accuracy(
    logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1
) -> Union[Tensor, Iterable[Tensor]]:
    r"""
    Calculates the classification accuracy of predictions against ground-truth
    labels, for one or several top-k values.

    :param logits: model predictions of shape [batch_size, num_classes],
        representing the probability (likelihood) of each class.
    :param target: ground-truth labels, 1d tensor of int32
    :param topk: specifies the topk values, could be an int or tuple of ints. Default: 1
    :return: tensor(s) of classification accuracy between 0.0 and 1.0

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        logits = tensor(np.arange(80, dtype=np.int32).reshape(8,10))
        target = tensor(np.arange(8, dtype=np.int32))
        top1, top5 = F.accuracy(logits, target, (1, 5))
        print(top1.numpy(), top5.numpy())

    Outputs:

    .. testoutput::

        [0.] [0.375]

    """
    ks = (topk,) if isinstance(topk, int) else topk
    # only the largest k predictions per sample are needed
    _, pred = top_k(logits, k=max(ks), descending=True)
    results = []
    for k in ks:
        # compare the k best predictions against the broadcast target labels
        broadcast_target = target.dimshuffle(0, "x").broadcast(target.shapeof(0), k)
        hits = equal(pred[:, :k], broadcast_target)
        results.append(hits.sum() / target.shapeof(0))
    if len(ks) == 1:  # type: ignore[arg-type]
        return results[0]
    return results
@wrap_io_tensor
def zero_grad(inp: Tensor) -> Tensor:
    r"""
    Returns a tensor which is treated as constant during backward gradient calculation,
    i.e. its gradient is zero.

    :param inp: Input tensor.
    :return: a tensor that forwards ``inp`` but blocks gradient flow.

    See implementation of :func:`~.softmax` for example.
    """
    # delegate to the graph-level zero_grad operator
    return mgb.opr.zero_grad(inp)
@@ -1,16 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .hub import ( | |||
help, | |||
import_module, | |||
list, | |||
load, | |||
load_serialized_obj_from_url, | |||
pretrained, | |||
) |
@@ -1,17 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# Git branch used when repo_info carries no explicit ":tag/:branch" suffix.
DEFAULT_BRANCH_NAME = "master"
# Entry-point file expected at the root of every hub repository.
HUBCONF = "hubconf.py"
# Attribute of hubconf.py listing python modules the repo depends on.
HUBDEPENDENCY = "dependencies"
# Git host used when the caller does not specify one.
DEFAULT_GIT_HOST = "github.com"
# Environment variable overriding the MegEngine cache root directory.
ENV_MGE_HOME = "MGE_HOME"
# XDG base-directory variable consulted when MGE_HOME is not set.
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
# Default cache location when neither environment variable is set.
DEFAULT_CACHE_DIR = "~/.cache"
# Protocol used by hub.load/hub.list when not specified ("HTTPS" or "SSH").
DEFAULT_PROTOCOL = "HTTPS"
# Read timeout (seconds) for HTTP downloads.
HTTP_READ_TIMEOUT = 120
@@ -1,30 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
class FetcherError(Exception):
    """Base class for all hub fetch related errors."""
class InvalidRepo(FetcherError):
    """The repo info string provided was somehow invalid."""
class InvalidGitHost(FetcherError):
    """The git host provided was somehow invalid."""
class GitPullError(FetcherError):
    """A git pull error occurred."""
class GitCheckoutError(FetcherError):
    """A git checkout error occurred."""
class InvalidProtocol(FetcherError):
    """The protocol provided was somehow invalid."""
@@ -1,300 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import hashlib | |||
import os | |||
import re | |||
import shutil | |||
import subprocess | |||
from tempfile import NamedTemporaryFile | |||
from typing import Tuple | |||
from zipfile import ZipFile | |||
import requests | |||
from tqdm import tqdm | |||
from megengine.utils.http_download import ( | |||
CHUNK_SIZE, | |||
HTTP_CONNECTION_TIMEOUT, | |||
HTTPDownloadError, | |||
) | |||
from ..distributed.util import is_distributed, synchronized | |||
from ..logger import get_logger | |||
from .const import DEFAULT_BRANCH_NAME, HTTP_READ_TIMEOUT | |||
from .exceptions import GitCheckoutError, GitPullError, InvalidGitHost, InvalidRepo | |||
from .tools import cd | |||
# Module-level logger shared by the fetchers below.
logger = get_logger(__name__)
# (connection timeout, read timeout) tuple, in seconds, passed to ``requests``.
HTTP_TIMEOUT = (HTTP_CONNECTION_TIMEOUT, HTTP_READ_TIMEOUT)
# Validates a DNS hostname such as "github.com"; dotted IPv4 addresses are
# handled separately by RepoFetcherBase._is_valid_host.
pattern = re.compile(
    r"^(?:[a-z0-9]"  # First character of the domain
    r"(?:[a-z0-9-_]{0,61}[a-z0-9])?\.)"  # Sub domain + hostname
    r"+[a-z0-9][a-z0-9-_]{0,61}"  # First 61 characters of the gTLD
    r"[a-z]$"  # Last character of the gTLD
)
class RepoFetcherBase:
    """Shared helpers for git repo fetchers: repo-info parsing, host
    validation, and cache-directory naming. Subclasses implement
    :meth:`fetch`."""

    @classmethod
    def fetch(
        cls,
        git_host: str,
        repo_info: str,
        use_cache: bool = False,
        commit: str = None,
        silent: bool = True,
    ) -> str:
        # Must be overridden by concrete fetchers (SSH / HTTPS).
        raise NotImplementedError()

    @classmethod
    def _parse_repo_info(cls, repo_info: str) -> Tuple[str, str, str]:
        # Split "owner/name[:branch]" into its three components; the branch
        # defaults to DEFAULT_BRANCH_NAME when no ":" suffix is present.
        try:
            if ":" in repo_info:
                prefix_info, branch_info = repo_info.split(":")
            else:
                prefix_info, branch_info = repo_info, DEFAULT_BRANCH_NAME
            repo_owner, repo_name = prefix_info.split("/")
            return repo_owner, repo_name, branch_info
        except ValueError:
            raise InvalidRepo("repo_info: '{}' is invalid.".format(repo_info))

    @classmethod
    def _check_git_host(cls, git_host):
        # Accept either a DNS hostname or a dotted IPv4 address.
        return cls._is_valid_domain(git_host) or cls._is_valid_host(git_host)

    @classmethod
    def _is_valid_domain(cls, s):
        try:
            ascii_form = s.encode("idna").decode("ascii")
        except UnicodeError:
            return False
        return pattern.match(ascii_form)

    @classmethod
    def _is_valid_host(cls, s):
        # A valid host here is exactly four dot-separated numbers in [0, 256).
        parts = s.split(".")
        if len(parts) != 4:
            return False
        if not all(part.isdigit() for part in parts):
            return False
        return all(0 <= int(part) < 256 for part in parts)

    @classmethod
    def _gen_repo_dir(cls, repo_dir: str) -> str:
        # Short, filesystem-safe cache directory name derived from the raw id.
        return hashlib.sha1(repo_dir.encode()).hexdigest()[:16]
class GitSSHFetcher(RepoFetcherBase):
    # Fetches a repository with `git clone` over SSH.
    @classmethod
    @synchronized
    def fetch(
        cls,
        git_host: str,
        repo_info: str,
        use_cache: bool = False,
        commit: str = None,
        silent: bool = True,
    ) -> str:
        """
        Fetches git repo by SSH protocol

        :param git_host:
            host address of git repo.
            example: github.com
        :param repo_info:
            a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified.
            example: ``"brain_sdk/MegBrain[:hub]"``
        :param use_cache:
            whether to use locally fetched code or completely re-fetch
        :param commit:
            commit id on github or gitlab
        :param silent:
            whether to accept the stdout and stderr of the subprocess with PIPE, instead of
            displaying on the screen
        :return:
            directory where the repo code is stored
        """
        if not cls._check_git_host(git_host):
            raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host))
        repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info)
        # branch names may contain "/"; flatten for use in a directory name
        normalized_branch_info = branch_info.replace("/", "_")
        repo_dir_raw = "{}_{}_{}".format(
            repo_owner, repo_name, normalized_branch_info
        ) + ("_{}".format(commit) if commit else "")
        # cache directory name: 16-char sha1 of owner/name/branch(/commit)
        repo_dir = cls._gen_repo_dir(repo_dir_raw)
        git_url = "git@{}:{}/{}.git".format(git_host, repo_owner, repo_name)
        if use_cache and os.path.exists(repo_dir):  # use cache
            logger.debug("Cache Found in %s", repo_dir)
            return repo_dir
        if is_distributed():
            logger.warning(
                "When using `hub.load` or `hub.list` to fetch git repositories\n"
                "    in DISTRIBUTED mode for the first time, processes are synchronized to\n"
                "    ensure that target repository is ready to use for each process.\n"
                "    Users are expected to see this warning no more than ONCE, otherwise\n"
                "    (very little chance) you may need to remove corrupt cache\n"
                "    `%s` and fetch again.",
                repo_dir,
            )
        shutil.rmtree(repo_dir, ignore_errors=True)  # ignore and clear cache
        logger.debug(
            "Git Clone from Repo:%s Branch: %s to %s",
            git_url,
            normalized_branch_info,
            repo_dir,
        )
        # capture subprocess output when silent, otherwise inherit the console
        kwargs = (
            {"stderr": subprocess.PIPE, "stdout": subprocess.PIPE} if silent else {}
        )
        if commit is None:
            # shallow clone repo by branch/tag
            p = subprocess.Popen(
                [
                    "git",
                    "clone",
                    "-b",
                    normalized_branch_info,
                    git_url,
                    repo_dir,
                    "--depth=1",
                ],
                **kwargs,
            )
            cls._check_clone_pipe(p)
        else:
            # clone repo and checkout to commit_id
            p = subprocess.Popen(["git", "clone", git_url, repo_dir], **kwargs)
            cls._check_clone_pipe(p)
            with cd(repo_dir):
                logger.debug("git checkout to %s", commit)
                p = subprocess.Popen(["git", "checkout", commit], **kwargs)
                _, err = p.communicate()
                if p.returncode:
                    # remove the half-initialized clone before reporting
                    shutil.rmtree(repo_dir, ignore_errors=True)
                    raise GitCheckoutError(
                        "Git checkout error, please check the commit id.\n"
                        + err.decode()
                    )
        with cd(repo_dir):
            # drop git metadata; the cache only needs the working tree
            shutil.rmtree(".git")
        return repo_dir

    @classmethod
    def _check_clone_pipe(cls, p):
        # Waits for `git clone` to finish and surfaces stderr on failure.
        _, err = p.communicate()
        if p.returncode:
            raise GitPullError(
                "Repo pull error, please check repo info.\n" + err.decode()
            )
class GitHTTPSFetcher(RepoFetcherBase):
    # Fetches a repository as a zip archive over HTTPS (no git binary needed).
    @classmethod
    @synchronized
    def fetch(
        cls,
        git_host: str,
        repo_info: str,
        use_cache: bool = False,
        commit: str = None,
        silent: bool = True,
    ) -> str:
        """
        Fetches git repo by HTTPS protocol

        :param git_host:
            host address of git repo
            example: github.com
        :param repo_info:
            a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified.
            example: ``"brain_sdk/MegBrain[:hub]"``
        :param use_cache:
            whether to use locally cached code or completely re-fetch
        :param commit:
            commit id on github or gitlab
        :param silent:
            whether to accept the stdout and stderr of the subprocess with PIPE, instead of
            displaying on the screen
        :return:
            directory where the repo code is stored
        """
        if not cls._check_git_host(git_host):
            raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host))
        repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info)
        # branch names may contain "/"; flatten for use in a directory name
        normalized_branch_info = branch_info.replace("/", "_")
        repo_dir_raw = "{}_{}_{}".format(
            repo_owner, repo_name, normalized_branch_info
        ) + ("_{}".format(commit) if commit else "")
        # cache directory name: 16-char sha1 of owner/name/branch(/commit)
        repo_dir = cls._gen_repo_dir(repo_dir_raw)
        archive_url = cls._git_archive_link(
            git_host, repo_owner, repo_name, branch_info, commit
        )
        if use_cache and os.path.exists(repo_dir):  # use cache
            logger.debug("Cache Found in %s", repo_dir)
            return repo_dir
        if is_distributed():
            # BUGFIX: the message contains a %s placeholder but no argument was
            # passed, which made the logging module emit a formatting error
            # instead of this warning; supply repo_dir explicitly.
            logger.warning(
                "When using `hub.load` or `hub.list` to fetch git repositories "
                "in DISTRIBUTED mode for the first time, processes are synchronized to "
                "ensure that target repository is ready to use for each process.\n"
                "Users are expected to see this warning no more than ONCE, otherwise"
                "(very little chance) you may need to remove corrupt hub cache %s and fetch again.",
                repo_dir,
            )
        shutil.rmtree(repo_dir, ignore_errors=True)  # ignore and clear cache
        logger.debug("Downloading from %s to %s", archive_url, repo_dir)
        cls._download_zip_and_extract(archive_url, repo_dir)
        return repo_dir

    @classmethod
    def _download_zip_and_extract(cls, url, target_dir):
        # Streams the zip archive at `url` into a temp file, then extracts it
        # and renames the archive's single top-level directory to `target_dir`.
        resp = requests.get(url, timeout=HTTP_TIMEOUT, stream=True)
        if resp.status_code != 200:
            raise HTTPDownloadError(
                "An error occured when downloading from {}".format(url)
            )
        total_size = int(resp.headers.get("Content-Length", 0))
        _bar = tqdm(total=total_size, unit="iB", unit_scale=True)
        with NamedTemporaryFile("w+b") as f:
            for chunk in resp.iter_content(CHUNK_SIZE):
                if not chunk:
                    break
                _bar.update(len(chunk))
                f.write(chunk)
            _bar.close()
            f.seek(0)
            with ZipFile(f) as temp_zip_f:
                # archives contain one top-level directory; extract, then move
                zip_dir_name = temp_zip_f.namelist()[0].split("/")[0]
                temp_zip_f.extractall(".")
            shutil.move(zip_dir_name, target_dir)

    @classmethod
    def _git_archive_link(cls, git_host, repo_owner, repo_name, branch_info, commit):
        # Builds the /archive/<ref>.zip download URL, preferring `commit`
        # over the branch/tag name when both are given.
        archive_link = "https://{}/{}/{}/archive/{}.zip".format(
            git_host, repo_owner, repo_name, commit or branch_info
        )
        return archive_link
@@ -1,333 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools | |||
import hashlib | |||
import os | |||
import sys | |||
import types | |||
from typing import Any, List | |||
from urllib.parse import urlparse | |||
from megengine.utils.http_download import download_from_url | |||
from ..core.serialization import load as _mge_load_serialized | |||
from ..distributed import is_distributed | |||
from ..logger import get_logger | |||
from .const import ( | |||
DEFAULT_CACHE_DIR, | |||
DEFAULT_GIT_HOST, | |||
DEFAULT_PROTOCOL, | |||
ENV_MGE_HOME, | |||
ENV_XDG_CACHE_HOME, | |||
HTTP_READ_TIMEOUT, | |||
HUBCONF, | |||
HUBDEPENDENCY, | |||
) | |||
from .exceptions import InvalidProtocol | |||
from .fetcher import GitHTTPSFetcher, GitSSHFetcher | |||
from .tools import cd, check_module_exists, load_module | |||
# Module-level logger for hub operations.
logger = get_logger(__name__)
# Mapping from protocol name to the fetcher class implementing it.
PROTOCOLS = {
    "HTTPS": GitHTTPSFetcher,
    "SSH": GitSSHFetcher,
}
def _get_megengine_home() -> str:
    """Resolves the MegEngine cache root. MGE_HOME takes precedence;
    otherwise the path complies with the XDG Base Directory Specification.
    """
    xdg_cache = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(xdg_cache, "megengine")
    return os.path.expanduser(os.getenv(ENV_MGE_HOME, fallback_home))
def _get_repo(
    git_host: str,
    repo_info: str,
    use_cache: bool = False,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    """Fetches the repo with the requested protocol and returns the absolute
    path of the cached checkout.

    :raises InvalidProtocol: when ``protocol`` is not a known fetcher name.
    """
    if protocol not in PROTOCOLS:
        valid_names = ", ".join(PROTOCOLS.keys())
        raise InvalidProtocol(
            "Invalid protocol, the value should be one of {}.".format(valid_names)
        )
    cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    # fetchers create their cache directory relative to the hub cache root
    with cd(cache_dir):
        repo_dir = PROTOCOLS[protocol].fetch(git_host, repo_info, use_cache, commit)
    return os.path.join(cache_dir, repo_dir)
def _check_dependencies(module: types.ModuleType) -> None:
    """Raises RuntimeError when any python module named in the hubconf's
    ``dependencies`` attribute cannot be imported; no-op when the attribute
    is absent or empty.
    """
    dependencies = getattr(module, HUBDEPENDENCY, None)
    if not dependencies:
        return
    missing = [dep for dep in dependencies if not check_module_exists(dep)]
    if missing:
        raise RuntimeError("Missing dependencies: {}".format(", ".join(missing)))
def _init_hub(
    repo_info: str,
    git_host: str,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
):
    """Imports hubmodule like python import

    :param repo_info:
        a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param git_host:
        host address of git repo
        Example: github.com
    :param use_cache:
        whether to use locally cached code or completely re-fetch
    :param commit:
        commit id on github or gitlab
    :param protocol:
        which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :return:
        hubconf.py as a python module
    """
    cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    os.makedirs(cache_dir, exist_ok=True)
    absolute_repo_dir = _get_repo(
        git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
    )
    # BUGFIX: restore sys.path even if importing hubconf.py raises; otherwise
    # a failed load would permanently leave the repo dir on the import path.
    sys.path.insert(0, absolute_repo_dir)
    try:
        hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
    finally:
        sys.path.remove(absolute_repo_dir)
    return hubmodule
# Public alias of _init_hub; functools.wraps copies its signature/docstring.
@functools.wraps(_init_hub)
def import_module(*args, **kwargs):
    return _init_hub(*args, **kwargs)
def list(
    repo_info: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> List[str]:
    """Lists all entrypoints available in repo hubconf

    :param repo_info:
        a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param git_host:
        host address of git repo
        Example: github.com
    :param use_cache:
        whether to use locally cached code or completely re-fetch
    :param commit:
        commit id on github or gitlab
    :param protocol:
        which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :return:
        all entrypoint names of the model
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
    entrypoints = []
    # an entrypoint is any public callable attribute of hubconf.py
    for attr_name in dir(hubmodule):
        if attr_name.startswith("__"):
            continue
        if callable(getattr(hubmodule, attr_name)):
            entrypoints.append(attr_name)
    return entrypoints
def load(
    repo_info: str,
    entry: str,
    *args,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
    **kwargs
) -> Any:
    """Loads model from github or gitlab repo, with pretrained weights.

    :param repo_info:
        a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param entry:
        an entrypoint defined in hubconf
    :param git_host:
        host address of git repo
        Example: github.com
    :param use_cache:
        whether to use locally cached code or completely re-fetch
    :param commit:
        commit id on github or gitlab
    :param protocol:
        which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :return:
        a single model with corresponding pretrained weights.
    :raises RuntimeError: when ``entry`` is not a callable in hubconf.py,
        or a declared dependency is missing.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
    entry_func = getattr(hubmodule, entry, None)
    if not callable(entry_func):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
    _check_dependencies(hubmodule)
    # extra positional/keyword arguments are forwarded to the entrypoint
    return entry_func(*args, **kwargs)
def help(
    repo_info: str,
    entry: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    """This function returns docstring of entrypoint ``entry`` by following steps:

    1. Pull the repo code specified by git and repo_info
    2. Load the entry defined in repo's hubconf.py
    3. Return docstring of function entry

    :param repo_info:
        a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param entry:
        an entrypoint defined in hubconf.py
    :param git_host:
        host address of git repo
        Example: github.com
    :param use_cache:
        whether to use locally cached code or completely re-fetch
    :param commit:
        commit id on github or gitlab
    :param protocol:
        which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :return:
        docstring of entrypoint ``entry``
    :raises RuntimeError: when ``entry`` is not a callable in hubconf.py
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
    entry_func = getattr(hubmodule, entry, None)
    if not callable(entry_func):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
    return entry_func.__doc__
def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
    """Loads MegEngine serialized object from the given URL.

    If the object is already present in ``model_dir``, it's deserialized and
    returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``.

    :param url: url to serialized object
    :param model_dir: dir to cache target serialized file
    :return: loaded object
    """
    if model_dir is None:
        model_dir = os.path.join(_get_megengine_home(), "serialized")
    os.makedirs(model_dir, exist_ok=True)
    base_name = os.path.basename(urlparse(url).path)
    # use hash as prefix to avoid filename conflict from different urls
    digest = hashlib.sha256(url.encode()).hexdigest()[:6]
    cached_file = os.path.join(model_dir, digest + "_" + base_name)
    logger.info(
        "load_serialized_obj_from_url: download to or using cached %s", cached_file
    )
    if not os.path.exists(cached_file):
        if is_distributed():
            logger.warning(
                "Downloading serialized object in DISTRIBUTED mode\n"
                "    File may be downloaded multiple times. We recommend\n"
                "    users to download in single process first."
            )
        download_from_url(url, cached_file, HTTP_READ_TIMEOUT)
    return _mge_load_serialized(cached_file)
class pretrained:
    r"""
    Decorator which helps to download pretrained weights from the given url.

    For example, we can decorate a resnet18 function as follows

    .. code-block::

        @hub.pretrained("https://url/to/pretrained_resnet18.pkl")
        def resnet18(**kwargs):
            return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)

    When decorated function is called with ``pretrained=True``, MegEngine will automatically
    download and fill the returned model with pretrained weights.
    """

    def __init__(self, url):
        # Location of the serialized pretrained weights.
        self.url = url

    def __call__(self, func):
        @functools.wraps(func)
        def pretrained_model_func(
            pretrained=False, **kwargs
        ):  # pylint: disable=redefined-outer-name
            model = func(**kwargs)
            if not pretrained:
                return model
            # fetch (or reuse cached) weights and load them into the model
            weights = load_serialized_obj_from_url(self.url)
            model.load_state_dict(weights)
            return model

        return pretrained_model_func
# Names re-exported as the public API of ``megengine.hub``.
__all__ = [
    "list",
    "load",
    "help",
    "load_serialized_obj_from_url",
    "pretrained",
    "import_module",
]
@@ -1,48 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import importlib.util | |||
import os | |||
import types | |||
from contextlib import contextmanager | |||
from typing import Iterator | |||
def load_module(name: str, path: str) -> types.ModuleType:
    """
    Loads and executes a python module from an explicit file path.

    :param name: module name
    :param path: module path
    :return: the executed module object
    """
    spec = importlib.util.spec_from_file_location(name, path)
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    return loaded
def check_module_exists(module: str) -> bool:
    """Checks whether a python module can be found without importing it.

    :param module: name of module
    """
    spec = importlib.util.find_spec(module)
    return spec is not None
@contextmanager
def cd(target: str) -> Iterator[None]:
    """Temporarily changes the current working directory to ``target``,
    restoring the previous directory on exit (even on exception).

    :param target: target directory
    """
    saved_cwd = os.getcwd()
    os.chdir(os.path.expanduser(target))
    try:
        yield
    finally:
        os.chdir(saved_cwd)
@@ -1,570 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib | |||
import functools | |||
import itertools | |||
import os | |||
from typing import Callable, Tuple, Union | |||
import numpy as np | |||
import megengine._internal as mgb | |||
from megengine._internal.plugin import CompGraphProfiler | |||
from ..core import Tensor, graph, tensor | |||
from .sublinear_memory_config import SublinearMemoryConfig | |||
def sideeffect(f):
    """Wraps a side-effect-only function (must return None) so it replays
    correctly under an active trace.
    """
    # during eager tracing, wrapped function is called with proxy inputs
    # during static tracing, wrapped function will not be called at all
    @functools.wraps(f)
    def wrapper(*args, **kwargs):  # pylint: disable=inconsistent-return-statements
        if not trace._active_instance:
            return f(*args, **kwargs)
        # collect Tensor arguments, keyed by positional index or keyword name
        tensors = {}
        for i, x in itertools.chain(enumerate(args), kwargs.items()):
            if isinstance(x, Tensor):
                tensors[i] = x
        if tensors:
            _keys, tensors = zip(*tensors.items())
        else:
            _keys, tensors = (), ()

        def callback(*tensors, f=f, keys=_keys, args=args, kwargs=kwargs):
            # substitute traced tensor values back into the original call
            replace = dict(zip(keys, tensors))
            args = tuple(replace.get(i, x) for i, x in enumerate(args))
            kwargs = {i: replace.get(i, x) for i, x in kwargs.items()}
            if f(*args, **kwargs) is not None:
                raise TypeError("a sideeffect function should return None")

        # TODO: clear memory
        trace._active_instance._register_callback(callback, tensors)

    return wrapper
def mark_impure(x):
    """Registers ``x`` as an impure output of the active trace; returns
    ``x`` unchanged when no trace is running.
    """
    active_trace = trace._active_instance
    if not active_trace:
        return x
    return active_trace._mark_impure(x)
def barrier(x):
    """Inserts a dependency barrier on ``x`` via the active trace; returns
    ``x`` unchanged when no trace is running.
    """
    active_trace = trace._active_instance
    if not active_trace:
        return x
    return active_trace._insert_barrier(x)
def _dummy():
    # Zero-valued immutable node on the default (comp_node, comp_graph);
    # used as an anchor var when injecting callbacks into the graph.
    return mgb.make_immutable(*graph._use_default_if_none(None, None), 0)
class unset:
    # Sentinel class marking trace outputs that have not been produced yet.
    pass
class trace: | |||
""" | |||
Wrap a callable and provide: | |||
* tracing via :meth:`.trace` and :meth:`.dump` | |||
* accelerated evalutaion via :meth:`.__call__` | |||
:param func: Positional only argument. | |||
:param symbolic: Whether to use symbolic tensor. Default: False | |||
:param opt_level: Optimization level for compiling trace. | |||
:param log_level: Log level. | |||
:param sublinear_memory_config: Configuration for sublinear memory optimization. | |||
If not None, it enables sublinear memory optimization with given setting. | |||
:param allreduce_pack_max_size: Maximum size of an allreduce pack in MB. | |||
If not None, multiple gradients will be packed and synchronized together | |||
:param profiling: Whether to profile compiled trace. Default: False | |||
""" | |||
_active_instance = None | |||
enabled = not os.getenv("MGE_DISABLE_TRACE") | |||
_UNSTARTED = "unstarted" | |||
_STARTED = "started" | |||
_FINISHED = "finished" | |||
def __new__(cls, *args, **kwargs): | |||
if not args: | |||
return functools.partial(cls, **kwargs) | |||
return super().__new__(cls) | |||
def __init__( | |||
self, | |||
func: Callable[..., Union[None, Tensor, Tuple[Tensor]]], | |||
*, | |||
symbolic: bool = False, | |||
opt_level: int = None, | |||
log_level: int = None, | |||
sublinear_memory_config: SublinearMemoryConfig = None, | |||
allreduce_pack_max_size: int = None, | |||
profiling: bool = False | |||
): | |||
self.__wrapped__ = func | |||
self._symbolic = symbolic | |||
self._graph_opt_level = opt_level | |||
self._log_level = log_level | |||
self._sublinear_memory_config = sublinear_memory_config | |||
self._allreduce_pack_max_size = allreduce_pack_max_size | |||
self._status = self._UNSTARTED | |||
self._args = None | |||
self._kwargs = None | |||
self._outputs = unset | |||
self._sym_outputs = unset | |||
self._outspec = None | |||
self._checkpoint = None | |||
self._compiled_func = None | |||
self._profiling = profiling | |||
self._profiler = None | |||
@property | |||
def _active(self): | |||
c1 = self._status == self._STARTED | |||
c2 = type(self)._active_instance is self | |||
assert c1 == c2 | |||
return c1 | |||
def _register_callback(self, f, args=()): | |||
assert self._active | |||
assert isinstance(args, (tuple, list)) | |||
proxies = self._make_proxies(args) | |||
self._forward(args, proxies, checkpoint=True) | |||
# NOTE: under eager graph callback will fire immediately | |||
job = mgb.opr.callback_injector( | |||
self._insert_barrier(_dummy()), lambda _: f(*proxies) | |||
) | |||
self._insert_checkpoint(job) | |||
self._outspec.append(job) | |||
def _insert_barrier(self, x): | |||
assert self._active | |||
if self._checkpoint is None: | |||
return x | |||
if isinstance(x, Tensor): | |||
x = x._symvar | |||
wrap = True | |||
else: | |||
wrap = False | |||
if not isinstance(x, mgb.SymbolVar): | |||
raise TypeError | |||
x = mgb.opr.virtual_dep([x, self._checkpoint]) | |||
if wrap: | |||
x = Tensor(x) | |||
return x | |||
    def _insert_checkpoint(self, *args, no_barrier=False):
        """Fold ``args`` into a single checkpoint symvar that later barriers
        depend on; chains onto the previous checkpoint unless ``no_barrier``.
        """
        assert self._active
        if not args:
            return
        args = tuple(x._symvar if isinstance(x, Tensor) else x for x in args)
        for x in args:
            if not isinstance(x, mgb.SymbolVar):
                raise TypeError
        if not no_barrier and self._checkpoint is not None:
            # normally no need to _insert_barrier here, but if
            # someone forget to call _insert_barrier beforehand,
            # this can make things less broken
            args += (self._checkpoint,)
        if len(args) == 1:
            self._checkpoint = args[0]
        else:
            self._checkpoint = mgb.opr.virtual_dep(args)
def _mark_impure(self, x): | |||
assert self._active | |||
ret = x | |||
if isinstance(x, Tensor): | |||
x = x._symvar | |||
if not isinstance(x, mgb.SymbolVar): | |||
raise TypeError | |||
self._outspec.append(x) | |||
self._insert_checkpoint(x) | |||
return ret | |||
def _make_proxies(self, args): | |||
assert isinstance(args, (tuple, list)) | |||
for x in args: | |||
assert isinstance(x, Tensor) | |||
return tuple(tensor(dtype=x.dtype, device=x.device) for x in args) | |||
    def _forward(self, srcs, dests, checkpoint=True):
        """Copy values from ``srcs`` into proxy tensors ``dests``.

        Outside of tracing this is an immediate copy; during tracing it
        installs callback jobs that perform the copy at graph-execution time.
        """
        # pseudo-op: does not run under static graph; traced
        # TODO: use shared memory
        assert len(srcs) == len(dests)
        if not self._active:
            for s, d in zip(srcs, dests):
                d.set_value(s, share=False)
            return
        jobs = []
        for s, d in zip(srcs, dests):

            def callback(value, dest=d):
                # `dest=d` binds the current loop target as a default,
                # avoiding the late-binding closure pitfall
                dest.set_value(value, share=False)

            s = self._insert_barrier(s._symvar)
            # NOTE: callback immediately fire in eager graph
            jobs.append(mgb.opr.callback_injector(s, callback))
        self._outspec.extend(jobs)
        if checkpoint:
            self._insert_checkpoint(*jobs, no_barrier=True)
def _forward_inputs(self, *args, **kwargs): | |||
if self._kwargs is None: | |||
self._kwargs = kwargs | |||
elif self._kwargs != kwargs: | |||
raise ValueError("kwargs must not change between invocations") | |||
if self._args is None: | |||
self._args = [] | |||
for i in args: | |||
if isinstance(i, Tensor): | |||
self._args.append(tensor(dtype=i.dtype, device=i.device)) | |||
self._args[-1].set_value(i, share=False) | |||
else: | |||
self._args.append(tensor(i)) | |||
else: | |||
if not len(args) == len(self._args): | |||
raise TypeError | |||
for i, proxy in zip(args, self._args): | |||
proxy.set_value(i, share=False) | |||
# XXX: sync? | |||
def _make_outputs(self, outputs): | |||
if outputs is None: | |||
self._outputs = None | |||
return | |||
if isinstance(outputs, Tensor): | |||
# no one is able to call barrier after this, so no need to checkpoint | |||
# but checkpoint do little harm anyway | |||
(self._outputs,) = self._make_proxies([outputs]) | |||
return | |||
if not isinstance(outputs, (tuple, list)): | |||
raise TypeError("should return (tuple of) tensor") | |||
for i in outputs: | |||
if not isinstance(i, Tensor): | |||
raise TypeError("should return (tuple of) tensor") | |||
self._outputs = self._make_proxies(outputs) | |||
    def _foward_outputs(self, outputs):
        """Validate ``outputs`` against the first-seen structure and forward
        the values into the cached output proxies."""
        # pseudo-op: does not run under static graph; traced
        # NOTE(review): method name has a typo (should be "_forward_outputs");
        # kept as-is because callers in this file use the current spelling.
        if self._outputs is unset:
            # first call: fix the output structure (None / tensor / tuple)
            self._make_outputs(outputs)
        if self._outputs is None:
            if outputs is not None:
                raise TypeError("should return None")
        elif isinstance(self._outputs, Tensor):
            if not isinstance(outputs, Tensor):
                raise TypeError("should return a tensor")
            self._forward([outputs], [self._outputs])
        else:
            assert isinstance(self._outputs, tuple)

            def check():
                # structural match: tuple/list of tensors, same length
                if not isinstance(outputs, (tuple, list)):
                    return False
                if len(self._outputs) != len(outputs):
                    return False
                for x in outputs:
                    if not isinstance(x, Tensor):
                        return False
                return True

            if not check():
                raise TypeError(
                    "should return tuple of %d tensors" % len(self._outputs)
                )
            self._forward(outputs, self._outputs)
def _apply_graph_options(self, cg): | |||
# graph opt level | |||
if self._graph_opt_level is not None: | |||
cg.set_option("graph_opt_level", self._graph_opt_level) | |||
# log level | |||
if self._log_level is not None: | |||
cg.set_option("log_level", self._log_level) | |||
# sublinear | |||
if self._sublinear_memory_config is not None: | |||
cg.set_option("enable_sublinear_memory_opt", True) | |||
cg.set_option( | |||
"sublinear_mem_config.lb_memory", | |||
self._sublinear_memory_config.lb_memory, | |||
) | |||
cg.set_option( | |||
"sublinear_mem_config.genetic_nr_iter", | |||
self._sublinear_memory_config.genetic_nr_iter, | |||
) | |||
cg.set_option( | |||
"sublinear_mem_config.genetic_pool_size", | |||
self._sublinear_memory_config.genetic_pool_size, | |||
) | |||
cg.set_option( | |||
"sublinear_mem_config.thresh_nr_try", | |||
self._sublinear_memory_config.thresh_nr_try, | |||
) | |||
cg.set_option( | |||
"sublinear_mem_config.num_worker", | |||
self._sublinear_memory_config.num_worker, | |||
) | |||
# pack allreduce | |||
if self._allreduce_pack_max_size is not None: | |||
cg.set_option("allreduce_pack_max_size", self._allreduce_pack_max_size) | |||
# profile | |||
if self._profiling: | |||
self._profiler = CompGraphProfiler(cg) | |||
def _get_graph(self, eager): | |||
if eager: | |||
if not hasattr(self, "_eager_graph"): | |||
# pylint: disable=attribute-defined-outside-init | |||
self._eager_graph = graph.Graph(eager_evaluation=True) | |||
self._apply_graph_options(self._eager_graph) | |||
return self._eager_graph | |||
else: | |||
if not hasattr(self, "_static_graph"): | |||
# pylint: disable=attribute-defined-outside-init | |||
self._static_graph = graph.Graph(eager_evaluation=False) | |||
self._apply_graph_options(self._static_graph) | |||
return self._static_graph | |||
    @contextlib.contextmanager
    def _prepare(self, args, kwargs, enable):
        """Forward inputs and select the graph (eager / static / none) that
        the body should run under; releases device memory afterwards."""
        # prepare for execution
        self._forward_inputs(*args, **kwargs)
        if not enable:
            # XXX: use our own graph here?
            cg = None
        elif self._status == self._FINISHED:
            # already compiled: the compiled func carries its own graph
            cg = None
        elif self._symbolic:
            cg = self._get_graph(eager=False)
        else:
            cg = self._get_graph(eager=True)
        try:
            # NOTE: always trace in a new graph, so capturing an undetached tensor
            # will never work (would work if tracing in default graph)
            if cg is None:
                yield
            else:
                with cg:
                    yield
        finally:
            # XXX: properly release memory
            if cg:
                cg.clear_device_memory()
    @contextlib.contextmanager
    def _activate(self):
        """Enter tracing mode for the duration of the block.

        Rejects re-tracing and nested traces; always transitions to
        ``_FINISHED`` and clears the active-instance pointer on exit.
        """
        # prepare for tracing
        if self._status != self._UNSTARTED:
            raise RuntimeError("cannot trace a second time")
        if type(self)._active_instance is not None:
            raise RuntimeError("nested trace is unsupported")
        self._status = self._STARTED
        type(self)._active_instance = self
        self._user_cache = {}
        try:
            yield
        finally:
            # even on error the trace is considered finished (one-shot)
            self._status = self._FINISHED
            self._user_cache = None
            type(self)._active_instance = None
def _run_wrapped(self): | |||
outputs = self.__wrapped__(*self._args, **self._kwargs) | |||
self._foward_outputs(outputs) | |||
return outputs | |||
    def _do_trace(self):
        """Run the wrapped function once in tracing mode and compile the
        recorded output spec into ``self._compiled_func``."""
        with self._activate():
            self._outspec = []
            outputs = self._run_wrapped()
            if outputs is None:
                self._sym_outputs = None
            else:
                if isinstance(outputs, Tensor):
                    outputs = [outputs]
                # _run_wrapped has checked validity of outputs
                self._sym_outputs = tuple(i._symvar for i in outputs)
            # stabilize execution order of the recorded jobs before compiling
            mgb.comp_graph_tools.set_priority_to_id(self._outspec)
            self._compiled_func = graph.get_default_graph().compile(None, self._outspec)
    def trace(self, *args: Tensor, **kwargs):
        """
        Trace wrapped callable with provided arguments.

        :return: this trace instance, allowing chained calls.
        """
        with self._prepare(args, kwargs, enable=True):
            self._do_trace()
            return self
    def __call__(self, *args: Tensor, **kwargs):
        """
        Evaluate on provided arguments, using compiled trace
        instead of the original callable if applicable.

        :return: ``None`` or :class:`~.Tensor` or tuple of :class:`~.Tensor`, depending on the
            return value of wrapped callable.
        """
        with self._prepare(args, kwargs, enable=self.enabled):
            if not self.enabled:
                # tracing globally disabled: just run the python function
                self._run_wrapped()
            elif self._status == self._FINISHED:
                # already traced and compiled: replay the compiled function
                self._compiled_func()
            else:
                if self._status == self._UNSTARTED:
                    self._do_trace()
                if self._symbolic:
                    # static graph: tracing only built the graph; run it now
                    self._compiled_func()
            return self._outputs
def dump( | |||
self, | |||
fpath, | |||
*, | |||
arg_names=None, | |||
append=False, | |||
optimize_for_inference=False, | |||
output_names=None, | |||
**kwargs | |||
): | |||
""" | |||
Serialize trace to file system. | |||
:param fpath: positional only argument. Path of output file. | |||
:param arg_names: names of the input tensors in the traced function. | |||
:param append: whether output is appended to ``fpath``. | |||
:param optimize_for_inference: whether to enable optimize_for_inference | |||
pass before dump. | |||
:param output_names: names of the output tensors in the traced function, | |||
will use the default name if does not specify. | |||
:param enable_io16xc32: whether to use float16 for I/O between oprs and use | |||
float32 as internal computation precision. Note the output var would be | |||
changed to float16. | |||
:param enable_ioc16: whether to use float16 for both I/O and computation | |||
precision. | |||
:param enable_hwcd4: whether to use NHWCD4 data layout. This is faster on some | |||
OpenCL backend. | |||
:param enable_nchw88: whether to use NCHW88 data layout. it currently | |||
used in X86 AVX backend. | |||
:param enable_nchw44: whether to use NCHW44 data layout. it currently | |||
used in arm backend. | |||
:param enable_nchw44_dot: whether to use NCHW44_dot data layout. it currently | |||
used in armv8.2+dotprod backend. | |||
:param enable_nchw4: whether to use NCHW4 data layout. it currently | |||
used in nvidia backend(based on cudnn). | |||
:param enable_nchw32: whether to use NCHW32 data layout. it currently | |||
used in nvidia backend with tensorcore(based on cudnn). | |||
:param enable_chwn4: whether to use CHWN4 data layout. it currently | |||
used in nvidia backend with tensorcore. | |||
:param enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty | |||
into one opr. | |||
:param enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z | |||
input for inference on nvidia backend(this optimization pass will | |||
result in mismatch of the precision of output of training and | |||
inference) | |||
""" | |||
if self._status != self._FINISHED: | |||
raise ValueError("not traced") | |||
assert isinstance(self._sym_outputs, (tuple, type(None))) | |||
if not self._sym_outputs: | |||
raise ValueError("not outputs") | |||
if arg_names is None: | |||
arg_names = ["arg_%d" % i for i in range(len(self._args))] | |||
elif len(arg_names) != len(self._args): | |||
raise ValueError( | |||
"len(arg_names) should be {}, got {}".format( | |||
len(self._args), len(arg_names) | |||
) | |||
) | |||
if isinstance(output_names, str): | |||
output_names = [output_names] | |||
if output_names is None: | |||
output_names = [var.name for var in self._sym_outputs] | |||
elif len(output_names) != len(self._sym_outputs): | |||
raise ValueError( | |||
"len(output_names) should be {}, got {}".format( | |||
len(self._sym_outputs), len(output_names) | |||
) | |||
) | |||
optimize_for_inference_args_map = { | |||
"enable_io16xc32": "f16_io_f32_comp", | |||
"enable_ioc16": "f16_io_comp", | |||
"enable_hwcd4": "use_nhwcd4", | |||
"enable_nchw4": "use_nchw4", | |||
"enable_nchw88": "use_nchw88", | |||
"enable_nchw32": "use_nchw32", | |||
"enable_nchw44": "use_nchw44", | |||
"enable_nchw44_dot": "use_nchw44_dot", | |||
"enable_chwn4": "use_chwn4", | |||
"enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity", | |||
"enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z", | |||
} | |||
if optimize_for_inference: | |||
optimize_for_inference_kwargs = {} | |||
for k, v in optimize_for_inference_args_map.items(): | |||
if kwargs.pop(k, False): | |||
optimize_for_inference_kwargs[v] = True | |||
else: | |||
for k in optimize_for_inference_args_map: | |||
if kwargs.get(k, False): | |||
raise ValueError( | |||
"cannot set %s when optimize_for_inference is not set" % k | |||
) | |||
if kwargs: | |||
raise ValueError("unknown options: %s" % list(kwargs)) | |||
cg = self._sym_outputs[0].owner_graph | |||
replace = {} | |||
for t, name in zip(self._args, arg_names): | |||
# relies on symvar dedup | |||
s = t.__mgb_symvar__(comp_graph=cg) | |||
replace[s] = mgb.make_arg( | |||
t.device, cg, dtype=t.dtype, shape=t.shape, name=name | |||
) | |||
# Convert VolatileSharedDeviceTensor to SharedDeviceTensor, | |||
# otherwise some optimizations would not work. The conversion is | |||
# safe because there simply is no way (using builtin ops) to make | |||
# a VolatileSharedDeviceTensor actually volatile. | |||
for s in mgb.cgtools.get_dep_vars( | |||
self._sym_outputs, "VolatileSharedDeviceTensor" | |||
): | |||
if s in replace: | |||
continue # is an input | |||
replace[s] = mgb.SharedND._from_symvar(s).symvar( | |||
cg, name=s.name, volatile=False | |||
) | |||
sym_outputs = mgb.cgtools.replace_vars(self._sym_outputs, replace) | |||
sym_outputs = list(sym_outputs) | |||
if optimize_for_inference: | |||
sym_outputs = mgb.optimize_for_inference( | |||
sym_outputs, **optimize_for_inference_kwargs | |||
) | |||
for var, name in zip(sym_outputs, output_names): | |||
var.rename(name) | |||
mgb.serialize_comp_graph_to_file(fpath, sym_outputs, append=append) | |||
def get_profile(self): | |||
""" | |||
Get profiling result for compiled trace. | |||
:return: a json compatible object. | |||
""" | |||
if not self._profiler: | |||
raise RuntimeError("trace is not set with profiling=True") | |||
return self._profiler.get() |
@@ -1,56 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..core.device import get_device_count | |||
class SublinearMemoryConfig:
    r"""
    Configuration for sublinear memory optimization.

    :param thresh_nr_try: number of samples both for searching in linear space
        and around current thresh in sublinear memory optimization. Default: 10.
        It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY'.
    :param genetic_nr_iter: number of iterations to find the best checkpoints in genetic algorithm.
        Default: 0.
        It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER'.
    :param genetic_pool_size: number of samples for the crossover random selection
        during genetic optimization. Default: 20.
        It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE'.
    :param lb_memory: memory lower bound of bottleneck size in MB for sublinear memory optimization.
        It can be used to perform manual tradeoff between memory and speed. Default: 0.
        It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB'.
    :param num_worker: number of thread workers to search the optimum checkpoints
        in sublinear memory optimization. Default: half of cpu number in the system.
        Note: the value must be greater or equal to one.
        It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_WORKERS'.

    Note that the environmental variable MGB_COMP_GRAPH_OPT must be set to 'enable_sublinear_memory_opt=1'
    in order for the above environmental variable to be effective.
    """

    def __init__(
        self,
        thresh_nr_try: int = 10,
        genetic_nr_iter: int = 0,
        genetic_pool_size: int = 20,
        lb_memory: int = 0,
        num_worker: int = max(1, get_device_count("cpu") // 2),
    ):
        assert thresh_nr_try >= 0, "thresh_nr_try must be greater or equal to zero"
        self.thresh_nr_try = thresh_nr_try
        assert genetic_nr_iter >= 0, "genetic_nr_iter must be greater or equal to zero"
        self.genetic_nr_iter = genetic_nr_iter
        assert (
            genetic_pool_size >= 0
        ), "genetic_pool_size must be greater or equal to zero"
        self.genetic_pool_size = genetic_pool_size
        # consistency fix: validate lb_memory like the other non-negative fields
        assert lb_memory >= 0, "lb_memory must be greater or equal to zero"
        self.lb_memory = lb_memory
        assert num_worker > 0, "num_worker must be greater or equal to one"
        self.num_worker = num_worker
@@ -1,231 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import contextlib | |||
import logging | |||
import os | |||
import sys | |||
# every logger handed out by get_logger(); set_log_level() walks this list
_all_loggers = []
# default level from MEGENGINE_LOGGING_LEVEL env var, falling back to INFO
_default_level_name = os.getenv("MEGENGINE_LOGGING_LEVEL", "INFO")
_default_level = logging.getLevelName(_default_level_name.upper())
def set_log_file(fout, mode="a"):
    r"""Sets log output file.

    :type fout: str or file-like
    :param fout: file-like object that supports write and flush, or string for
        the filename
    :type mode: str
    :param mode: specify the mode to open log file if *fout* is a string
    """
    sink = open(fout, mode) if isinstance(fout, str) else fout
    MegEngineLogFormatter.log_fout = sink
class MegEngineLogFormatter(logging.Formatter):
    """Colorized console formatter that can additionally mirror each record
    (uncolored, with full source location) into a log file.

    Assign a file-like object to :attr:`log_fout` (see :func:`set_log_file`)
    to enable the file mirror.
    """

    # file-like sink for the uncolored mirror; None disables it
    log_fout = None
    # format used for the file mirror (includes source location)
    date_full = "[%(asctime)s %(lineno)d@%(filename)s:%(name)s] "
    # compact format used for the console
    date = "%(asctime)s "
    msg = "%(message)s"
    # records longer than this many lines are wrapped (file) / truncated (console)
    max_lines = 256

    def _color_exc(self, msg):
        r"""Sets the color of message as the execution type."""
        return "\x1b[34m{}\x1b[0m".format(msg)

    def _color_dbg(self, msg):
        r"""Sets the color of message as the debugging type."""
        return "\x1b[36m{}\x1b[0m".format(msg)

    def _color_warn(self, msg):
        r"""Sets the color of message as the warning type."""
        return "\x1b[1;31m{}\x1b[0m".format(msg)

    def _color_err(self, msg):
        r"""Sets the color of message as the error type."""
        return "\x1b[1;4;31m{}\x1b[0m".format(msg)

    def _color_omitted(self, msg):
        r"""Sets the color of message as the omitted type."""
        return "\x1b[35m{}\x1b[0m".format(msg)

    def _color_normal(self, msg):
        r"""Sets the color of message as the normal type."""
        return msg

    def _color_date(self, msg):
        r"""Sets the color of message the same as date."""
        return "\x1b[32m{}\x1b[0m".format(msg)

    def format(self, record):
        # choose color function + 3-letter tag from the record level
        if record.levelno == logging.DEBUG:
            mcl, mtxt = self._color_dbg, "DBG"
        elif record.levelno == logging.WARNING:
            mcl, mtxt = self._color_warn, "WRN"
        elif record.levelno == logging.ERROR:
            mcl, mtxt = self._color_err, "ERR"
        else:
            mcl, mtxt = self._color_normal, ""

        if mtxt:
            mtxt += " "

        # mirror the record (uncolored, full location) into the log file
        if self.log_fout:
            self.__set_fmt(self.date_full + mtxt + self.msg)
            formatted = super(MegEngineLogFormatter, self).format(record)
            nr_line = formatted.count("\n") + 1
            if nr_line >= self.max_lines:
                # wrap long entries in BEGIN/END markers instead of truncating
                head, body = formatted.split("\n", 1)
                formatted = "\n".join(
                    [
                        head,
                        "BEGIN_LONG_LOG_{}_LINES{{".format(nr_line - 1),
                        body,
                        "}}END_LONG_LOG_{}_LINES".format(nr_line - 1),
                    ]
                )
            self.log_fout.write(formatted)
            self.log_fout.write("\n")
            self.log_fout.flush()

        # console output: colored, compact date format
        self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
        formatted = super(MegEngineLogFormatter, self).format(record)

        if record.exc_text or record.exc_info:
            # handle exception format
            b = formatted.find("Traceback ")
            if b != -1:
                s = formatted[b:]
                s = self._color_exc(" " + s.replace("\n", "\n "))
                formatted = formatted[:b] + s

        # truncate long console output, keeping the head and tail halves
        nr_line = formatted.count("\n") + 1
        if nr_line >= self.max_lines:
            lines = formatted.split("\n")
            remain = self.max_lines // 2
            removed = len(lines) - remain * 2
            if removed > 0:
                mid_msg = self._color_omitted(
                    "[{} log lines omitted (would be written to output file "
                    "if set_log_file() has been called;\n"
                    " the threshold can be set at "
                    "MegEngineLogFormatter.max_lines)]".format(removed)
                )
                formatted = "\n".join(lines[:remain] + [mid_msg] + lines[-remain:])

        return formatted

    # logging.Formatter stores its format string differently on py2 vs py3
    if sys.version_info.major < 3:

        def __set_fmt(self, fmt):
            self._fmt = fmt

    else:

        def __set_fmt(self, fmt):
            self._style._fmt = fmt
def get_logger(name=None, formatter=MegEngineLogFormatter):
    r"""Gets megengine logger with given name.

    Loggers are initialized only once; subsequent calls with the same name
    return the already-configured logger untouched.
    """
    logger = logging.getLogger(name)
    if getattr(logger, "_init_done__", None):
        return logger
    logger._init_done__ = True
    logger.propagate = False
    logger.setLevel(_default_level)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter(datefmt="%d %H:%M:%S"))
    stream_handler.setLevel(0)
    del logger.handlers[:]
    logger.addHandler(stream_handler)
    _all_loggers.append(logger)
    return logger
def set_log_level(level, update_existing=True):
    """Sets default logging level.

    :type level: int e.g. logging.INFO
    :param level: logging level given by python :mod:`logging` module
    :param update_existing: whether to update existing loggers
    """
    global _default_level  # pylint: disable=global-statement
    _default_level = level
    if update_existing:
        for existing_logger in _all_loggers:
            existing_logger.setLevel(level)
# module-level logger for megengine itself
_logger = get_logger(__name__)

try:
    # megbrain native logging is only wired up on python 3
    if sys.version_info.major < 3:
        raise ImportError()

    from megengine._internal.logconf import set_logger as _set_mgb_logger

    class MegBrainLogFormatter(MegEngineLogFormatter):
        # tag megbrain (native) log lines and give the date a distinct color
        date = "%(asctime)s[mgb] "

        def _color_date(self, msg):
            return "\x1b[33m{}\x1b[0m".format(msg)

    _megbrain_logger = get_logger("megbrain", MegBrainLogFormatter)
    _set_mgb_logger(_megbrain_logger)

    def set_mgb_log_level(level):
        r"""Sets megbrain log level

        :type level: int e.g. logging.INFO
        :param level: new log level
        :return: original log level
        """
        logger = _megbrain_logger
        rst = logger.getEffectiveLevel()
        logger.setLevel(level)
        return rst


except ImportError as exc:
    # native bindings unavailable: keep the same public name as a stub

    def set_mgb_log_level(level):
        raise NotImplementedError("megbrain has not been imported")
@contextlib.contextmanager
def replace_mgb_log_level(level):
    r"""Context manager that temporarily overrides the megbrain log level and
    restores the previous level on exit.

    :type level: int e.g. logging.INFO
    :param level: new log level
    """
    previous = set_mgb_log_level(level)
    try:
        yield
    finally:
        set_mgb_log_level(previous)
def enable_debug_log():
    r"""Sets logging level to debug for all components.

    Covers both the Python-side loggers and (when available) the megbrain
    native logger.
    """
    set_log_level(logging.DEBUG)
    set_mgb_log_level(logging.DEBUG)
@@ -1,23 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | |||
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | |||
from .concat import Concat | |||
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | |||
from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
from .dropout import Dropout | |||
from .elemwise import Elemwise | |||
from .embedding import Embedding | |||
from .identity import Identity | |||
from .linear import Linear | |||
from .module import Module | |||
from .parampack import ParamPack | |||
from .pooling import AvgPool2d, MaxPool2d | |||
from .quant_dequant import DequantStub, QuantStub | |||
from .sequential import Sequential |
@@ -1,231 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
from ..core import Parameter | |||
from ..functional import leaky_relu, prelu, relu, sigmoid, softmax | |||
from .module import Module | |||
class Softmax(Module):
    r"""
    Applies a softmax function. Softmax is defined as:

    .. math::
            \text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)}

    It is applied to an n-dimensional input Tensor and rescaling them so that the elements of the
    n-dimensional output Tensor lie in the range of `[0, 1]` and sum to 1.

    :param axis: An axis along which softmax will be applied. By default,
        softmax will apply along the highest ranked axis.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        data = mge.tensor(np.array([-2,-1,0,1,2]).astype(np.float32))
        softmax = M.Softmax()
        output = softmax(data)
        with np.printoptions(precision=6):
            print(output.numpy())

    Outputs:

    .. testoutput::

        [0.011656 0.031685 0.086129 0.234122 0.636409]

    """

    def __init__(self, axis=None):
        super().__init__()
        # None defers axis choice to functional.softmax's default
        self.axis = axis

    def forward(self, inputs):
        return softmax(inputs, self.axis)
class Sigmoid(Module):
    r"""
    Applies the element-wise function:

    .. math::
            \text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32))
        sigmoid = M.Sigmoid()
        output = sigmoid(data)
        with np.printoptions(precision=6):
            print(output.numpy())

    Outputs:

    .. testoutput::

        [0.119203 0.268941 0.5      0.731059 0.880797]

    """

    def forward(self, inputs):
        # stateless: delegates directly to functional.sigmoid
        return sigmoid(inputs)
class ReLU(Module):
    r"""
    Applies the element-wise function:

    .. math::
            \text{ReLU}(x) = \max(x, 0)

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32))
        relu = M.ReLU()
        output = relu(data)
        with np.printoptions(precision=6):
            print(output.numpy())

    Outputs:

    .. testoutput::

        [0. 0. 0. 1. 2.]

    """

    def forward(self, x):
        # stateless: delegates directly to functional.relu
        return relu(x)
class PReLU(Module):
    r"""
    Applies the element-wise function:

    .. math::
            \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    or

    .. math::
            \text{PReLU}(x) =
            \begin{cases}
            x, & \text{ if } x \geq 0 \\
            ax, & \text{ otherwise }
            \end{cases}

    Here :math:`a` is a learnable parameter. When called without arguments, `PReLU()` uses
    a single parameter :math:`a` across all input channels. If called with `PReLU(num_of_channels)`,
    a separate :math:`a` is used for each input channel.

    :param num_parameters: number of :math:`a` to learn; only two
        values are legitimate: 1, or the number of channels at input. Default: 1
    :param init: the initial value of :math:`a`. Default: 0.25

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        data = mge.tensor(np.array([-1.2, -3.7, 2.7]).astype(np.float32))
        prelu = M.PReLU()
        output = prelu(data)
        print(output.numpy())

    Outputs:

    .. testoutput::

        [-0.3   -0.925  2.7  ]

    """

    def __init__(self, num_parameters: int = 1, init: float = 0.25):
        super().__init__()
        self.num_parameters = num_parameters
        if num_parameters > 1:
            # Assume format is NCHW: one slope per channel, broadcastable shape
            self.weight = Parameter(
                value=np.full((1, num_parameters, 1, 1), init, dtype=np.float32)
            )
        else:
            # single shared slope for all channels
            self.weight = Parameter(value=[init])

    def forward(self, inputs):
        # weight must be a single scalar or match the channel dim of inputs
        assert self.weight.shape == (1,) or self.weight.shape == (
            1,
            int(inputs.shape[1]),
            1,
            1,
        ), "invalid weight's shape"
        return prelu(inputs, self.weight)
class LeakyReLU(Module):
    r"""
    Applies the element-wise function:

    .. math::
            \text{LeakyReLU}(x) = \max(0,x) + negative\_slope \times \min(0,x)

    or

    .. math::
            \text{LeakyReLU}(x) =
            \begin{cases}
            x, & \text{ if } x \geq 0 \\
            negative\_slope \times x, & \text{ otherwise }
            \end{cases}

    :param negative_slope: slope applied to negative inputs. Default: 0.01

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        data = mge.tensor(np.array([-8, -12, 6, 10]).astype(np.float32))
        leakyrelu = M.LeakyReLU(0.01)
        output = leakyrelu(data)
        print(output.numpy())

    Outputs:

    .. testoutput::

        [-0.08 -0.12  6.   10.  ]

    """

    def __init__(self, negative_slope: float = 0.01):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, inputs):
        return leaky_relu(inputs, self.negative_slope)
@@ -1,257 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
from ..core import Buffer, Parameter | |||
from ..core.device import get_default_device | |||
from ..functional import batch_norm2d, sync_batch_norm | |||
from . import init | |||
from .module import Module | |||
class _BatchNorm(Module):
    """Shared implementation for batch-norm modules.

    Holds the learnable affine parameters (``weight``/``bias``) when
    ``affine`` is set, and the running statistics when
    ``track_running_stats`` is set.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
    ):
        super(_BatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(np.ones(num_features, dtype=np.float32))
            self.bias = Parameter(np.zeros(num_features, dtype=np.float32))
        else:
            self.weight = None
            self.bias = None

        # running stats are kept in NCHW-broadcastable shape
        tshape = (1, self.num_features, 1, 1)

        if self.track_running_stats:
            self.running_mean = Buffer(np.zeros(tshape, dtype=np.float32))
            self.running_var = Buffer(np.ones(tshape, dtype=np.float32))
        else:
            self.running_mean = None
            self.running_var = None

    def reset_running_stats(self) -> None:
        # re-initialize running statistics to mean=0, var=1
        if self.track_running_stats:
            init.zeros_(self.running_mean)
            init.ones_(self.running_var)

    def reset_parameters(self) -> None:
        # reset both running statistics and the affine parameters
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_ndim(self, inp):
        # subclasses restrict the allowed input rank here
        raise NotImplementedError

    def forward(self, inp):
        self._check_input_ndim(inp)

        _ndims = len(inp.shape)
        # 2D/3D inputs are reshaped to 4D NCHW so one kernel covers all;
        # _check_input_ndim has already validated the rank
        if _ndims != 4:
            origin_shape = inp.shapeof()
            if _ndims == 2:
                n, c = inp.shapeof(0), inp.shapeof(1)
                new_shape = (n, c, 1, 1)
            elif _ndims == 3:
                n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2)
                new_shape = (n, c, h, 1)

            inp = inp.reshape(new_shape)

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # useless

        # FIXME currently rocm does not support real bn opr so we just use
        # sync_batch_norm(as implemented by elemwise) here,
        # we will fix it in the next version
        if get_default_device() == "rocmx":
            output = sync_batch_norm(
                inp,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                self.training or not self.track_running_stats,
                exponential_average_factor,
                self.eps,
            )
        else:
            output = batch_norm2d(
                inp,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                self.training or not self.track_running_stats,
                exponential_average_factor,
                self.eps,
            )

        if _ndims != 4:
            # restore the caller's original shape
            output = output.reshape(origin_shape)

        return output
class SyncBatchNorm(_BatchNorm):
    r"""
    Applies Synchronization Batch Normalization.
    """

    def _check_input_ndim(self, inp):
        # Accepts 2D (N, C), 3D (N, C, L) or 4D (N, C, H, W) inputs.
        if len(inp.shape) not in {2, 3, 4}:
            raise ValueError(
                "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
            )

    def forward(self, inp):
        # NOTE(review): this largely duplicates _BatchNorm.forward, except it
        # always dispatches to sync_batch_norm regardless of device.
        self._check_input_ndim(inp)
        _ndims = len(inp.shape)
        # Pad 2D/3D inputs up to 4D NCHW; restore the original shape after.
        if _ndims != 4:
            origin_shape = inp.shapeof()
            if _ndims == 2:
                n, c = inp.shapeof(0), inp.shapeof(1)
                new_shape = (n, c, 1, 1)
            elif _ndims == 3:
                n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2)
                new_shape = (n, c, h, 1)
            inp = inp.reshape(new_shape)
        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # useless
        output = sync_batch_norm(
            inp,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            # use batch statistics when training or when not tracking stats
            self.training or not self.track_running_stats,
            exponential_average_factor,
            self.eps,
        )
        if _ndims != 4:
            output = output.reshape(origin_shape)
        return output
class BatchNorm1d(_BatchNorm):
    r"""
    Applies Batch Normalization over a 2D/3D tensor.
    Refer to :class:`~.BatchNorm2d` for more information.
    """

    def _check_input_ndim(self, inp):
        # Only rank-2 (N, C) and rank-3 (N, C, L) inputs are supported.
        ndim = len(inp.shape)
        if ndim != 2 and ndim != 3:
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(ndim)
            )
class BatchNorm2d(_BatchNorm):
    r"""
    Applies Batch Normalization over a 4D tensor.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable
    parameter vectors.

    By default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.9.

    If :attr:`track_running_stats` is set to ``False``, this layer will not
    keep running estimates, and batch statistics are instead used during
    evaluation time.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1 - \text{momentum}) \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing
    statistics on `(N, H, W)` slices, it's common terminology to call this
    Spatial Batch Normalization.

    :type num_features: int
    :param num_features: usually the :math:`C` from an input of size
        :math:`(N, C, H, W)` or the highest ranked dimension of an input with
        less than 4D.
    :type eps: float
    :param eps: a value added to the denominator for numerical stability.
        Default: 1e-5.
    :type momentum: float
    :param momentum: the value used for the `running_mean` and `running_var`
        computation.
        Default: 0.9
    :type affine: bool
    :param affine: a boolean value that when set to ``True``, this module has
        learnable affine parameters. Default: ``True``
    :type track_running_stats: bool
    :param track_running_stats: when set to ``True``, this module tracks the
        running mean and variance. When set to ``False``, this module does not
        track such statistics and always uses batch statistics in both training
        and eval modes. Default: ``True``.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        # With Learnable Parameters
        m = M.BatchNorm2d(4)
        inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
        oup = m(inp)
        print(m.weight, m.bias)
        # Without Learnable Parameters
        m = M.BatchNorm2d(4, affine=False)
        oup = m(inp)
        print(m.weight, m.bias)

    .. testoutput::

        Tensor([1. 1. 1. 1.]) Tensor([0. 0. 0. 0.])
        None None
    """

    def _check_input_ndim(self, inp):
        # Strictly rank-4 (N, C, H, W) input; lower ranks belong to BatchNorm1d.
        if len(inp.shape) != 4:
            raise ValueError("expected 4D input (got {}D input)".format(len(inp.shape)))
@@ -1,22 +0,0 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable | |||
from .. import functional as F | |||
from ..core.tensor import Tensor | |||
from .module import Module | |||
class Concat(Module):
    r"""
    A :class:`~.Module` wrapper around functional concat. Could be replaced with
    the :class:`~.QATModule` version :class:`~.qat.concat.Concat` via
    :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inps: Iterable[Tensor], axis: int = 0):
        # Delegate straight to the functional implementation.
        joined = F.concat(inps, axis)
        return joined
@@ -1,392 +0,0 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod | |||
from typing import Tuple, Union | |||
import numpy as np | |||
import megengine._internal as mgb | |||
from .. import functional as F | |||
from ..core import Parameter | |||
from ..utils.types import _pair, _pair_nonzero | |||
from . import init | |||
from .module import Module | |||
class _ConvNd(Module):
    """Base class for convolution modules, including transposed convolution."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        padding: Union[int, Tuple[int, int]],
        dilation: Union[int, Tuple[int, int]],
        groups: int,
        bias: bool = True,
    ):
        super().__init__()
        # Grouped convolution requires both channel counts to split evenly.
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        # Weight/bias shapes are decided by the concrete subclass.
        self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32))
        self.bias = (
            Parameter(np.zeros(self._infer_bias_shape(), dtype=np.float32))
            if bias
            else None
        )
        self.reset_parameters()

    @abstractmethod
    def _get_fanin(self):
        # Fan-in used to scale the weight initialization.
        pass

    def reset_parameters(self) -> None:
        # Normal init with std = sqrt(1 / fan_in); bias starts at zero.
        std = np.sqrt(1 / self._get_fanin())
        init.normal_(self.weight, 0.0, std)
        if self.bias is not None:
            init.zeros_(self.bias)

    @abstractmethod
    def _infer_weight_shape(self):
        pass

    @abstractmethod
    def _infer_bias_shape(self):
        pass
class Conv2d(_ConvNd):
    r"""Applies a 2D convolution over an input tensor.

    For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`,
    this layer generates an output of the size
    :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
    process described as below:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 2D cross-correlation operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`H` is a height of input planes in pixels, and :math:`W` is
    width in pixels.

    When ``groups == in_channels`` and ``out_channels == K * in_channels``,
    where `K` is a positive integer, this operation is also known as depthwise
    convolution.

    In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
    a depthwise convolution with a depthwise multiplier `K`, can be constructed
    by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
        an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``. Default: 1
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be ``(groups,
        out_channel // groups, in_channels // groups, *kernel_size)``.
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
        `CROSS_CORRELATION`.
    :param compute_mode: When set to `DEFAULT`, no special requirements will be
        placed on the precision of intermediate results. When set to `FLOAT32`,
        float32 would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
    """

    # Enum types for converting the string arguments in __init__.
    _conv_mode_type = mgb.opr_param_defs.Convolution.Mode
    _compute_mode_type = mgb.opr_param_defs.Convolution.ComputeMode

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
    ):
        # Normalize scalar arguments to (h, w) pairs; kernel/stride/dilation
        # must be non-zero.
        kernel_size = _pair_nonzero(kernel_size)
        stride = _pair_nonzero(stride)
        padding = _pair(padding)
        dilation = _pair_nonzero(dilation)
        # Convert string modes to mgb enum values; assigned before the
        # superclass __init__ runs.
        self.conv_mode = self._conv_mode_type.convert(conv_mode)
        self.compute_mode = self._compute_mode_type.convert(compute_mode)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def _get_fanin(self):
        # Fan-in of one output element: kernel area times input channels.
        kh, kw = self.kernel_size
        ic = self.in_channels
        return kh * kw * ic

    def _infer_weight_shape(self):
        group = self.groups
        ichl = self.in_channels
        ochl = self.out_channels
        kh, kw = self.kernel_size
        if group == 1:
            # Assume format is NCHW
            return (ochl, ichl, kh, kw)
        assert (
            ichl % group == 0 and ochl % group == 0
        ), "invalid config: input_channels={} output_channels={} group={}".format(
            ichl, ochl, group
        )
        # Assume format is NCHW; grouped weights carry a leading group dim.
        return (group, ochl // group, ichl // group, kh, kw)

    def _infer_bias_shape(self):
        # Assume format is NCHW; bias broadcasts over N, H, W.
        return (1, self.out_channels, 1, 1)

    def calc_conv(self, inp, weight, bias):
        # Shared convolution entry point, also reused by fused subclasses.
        return F.conv2d(
            inp,
            weight,
            bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
            self.compute_mode,
        )

    def forward(self, inp):
        return self.calc_conv(inp, self.weight, self.bias)
class ConvTranspose2d(_ConvNd):
    r"""Applies a 2D transposed convolution over an input tensor.

    This module is also known as a deconvolution or a fractionally-strided convolution.
    :class:`ConvTranspose2d` can be seen as the gradient of :class:`Conv2d` operation
    with respect to its input.

    Convolution usually reduces the size of input, while transposed convolution works
    the opposite way, transforming a smaller input to a larger output while preserving the
    connectivity pattern.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
        an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``. Default: 1
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be ``(groups,
        out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
        `CROSS_CORRELATION`.
    :param compute_mode: When set to `DEFAULT`, no special requirements will be
        placed on the precision of intermediate results. When set to `FLOAT32`,
        float32 would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
    """

    # Enum types for converting the string arguments in __init__.
    _conv_mode_type = mgb.opr_param_defs.Convolution.Mode
    _compute_mode_type = mgb.opr_param_defs.Convolution.ComputeMode

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
    ):
        # Normalize scalar arguments to (h, w) pairs; kernel/stride/dilation
        # must be non-zero.
        kernel_size = _pair_nonzero(kernel_size)
        stride = _pair_nonzero(stride)
        padding = _pair(padding)
        dilation = _pair_nonzero(dilation)
        # Convert string modes to mgb enum values; assigned before the
        # superclass __init__ runs.
        self.conv_mode = self._conv_mode_type.convert(conv_mode)
        self.compute_mode = self._compute_mode_type.convert(compute_mode)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def _get_fanin(self):
        # For transposed conv the fan-in is kernel area times OUTPUT channels.
        kh, kw = self.kernel_size
        oc = self.out_channels
        return kh * kw * oc

    def _infer_weight_shape(self):
        group = self.groups
        ichl = self.in_channels
        ochl = self.out_channels
        kh, kw = self.kernel_size
        if group == 1:
            # Assume format is NCHW; note (in, out) order vs. Conv2d's (out, in).
            return (ichl, ochl, kh, kw)
        assert (
            ichl % group == 0 and ochl % group == 0
        ), "invalid config: input_channels={} output_channels={} group={}".format(
            ichl, ochl, group
        )
        # Assume format is NCHW
        return (group, ichl // group, ochl // group, kh, kw)

    def _infer_bias_shape(self):
        # Assume format is NCHW; bias broadcasts over N, H, W.
        return (1, self.out_channels, 1, 1)

    def forward(self, inp):
        return F.conv_transpose2d(
            inp,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
            self.compute_mode,
        )
class LocalConv2d(Conv2d):
    r"""Applies a spatial convolution with untied kernels over an input 4D tensor.

    It is also known as the locally connected layer.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param input_height: the height of the input images.
    :param input_width: the width of the input images.
    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
        an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``. Default: 1
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``.
        The shape of weight is ``(groups, output_height, output_width,
        in_channels // groups, *kernel_size, out_channels // groups)``.
    """

    _conv_mode_type = mgb.opr_param_defs.Convolution.Mode

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        input_height: int,
        input_width: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        conv_mode: str = "CROSS_CORRELATION",
    ):
        # Must be set before super().__init__ because _infer_weight_shape
        # (called from _ConvNd.__init__) reads them.
        self.input_height = input_height
        self.input_width = input_width
        # BUGFIX: conv_mode was previously accepted but never forwarded, so a
        # user-supplied "CONVOLUTION" mode was silently ignored and the default
        # was always used. Forward it to Conv2d.__init__ for conversion.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias=False,
            conv_mode=conv_mode,
        )

    def _infer_weight_shape(self):
        group = self.groups
        # NOTE(review): the output-size formula ignores `dilation` even though
        # it is accepted and forwarded to F.local_conv2d — confirm intended.
        output_height = (
            self.input_height + self.padding[0] * 2 - self.kernel_size[0]
        ) // self.stride[0] + 1
        output_width = (
            self.input_width + self.padding[1] * 2 - self.kernel_size[1]
        ) // self.stride[1] + 1
        # Assume format is NCHW; one kernel per output spatial location.
        return (
            group,
            output_height,
            output_width,
            self.in_channels // group,
            self.kernel_size[0],
            self.kernel_size[1],
            self.out_channels // group,
        )

    def forward(self, inp):
        return F.local_conv2d(
            inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode
        )
class ConvRelu2d(Conv2d):
    r"""
    A fused :class:`~.Module` including Conv2d and relu. Could be replaced
    with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using
    :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        # Convolution followed by a ReLU activation.
        conv_out = self.calc_conv(inp, self.weight, self.bias)
        return F.relu(conv_out)
@@ -1,69 +0,0 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union | |||
from ..functional import relu | |||
from .batchnorm import BatchNorm2d | |||
from .conv import Conv2d | |||
from .module import Module | |||
class _ConvBnActivation2d(Module):
    """Shared container that builds a Conv2d followed by a BatchNorm2d."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
    ):
        super().__init__()
        # Convolution sub-module; constructor arguments forwarded unchanged.
        self.conv = Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            conv_mode=conv_mode,
            compute_mode=compute_mode,
        )
        # Normalization over the convolution's output channels.
        self.bn = BatchNorm2d(
            out_channels,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )
class ConvBn2d(_ConvBnActivation2d):
    r"""
    A fused :class:`~.Module` including Conv2d, BatchNorm2d. Could be replaced
    with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBn2d` using
    :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        # Convolution first, then batch normalization.
        conv_out = self.conv(inp)
        return self.bn(conv_out)
class ConvBnRelu2d(_ConvBnActivation2d):
    r"""
    A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced
    with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBnRelu2d` using
    :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        # Conv -> BN -> ReLU pipeline.
        normalized = self.bn(self.conv(inp))
        return relu(normalized)
@@ -1,29 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..functional import dropout | |||
from .module import Module | |||
class Dropout(Module):
    r"""Randomly set input elements to zeros with the probability :math:`drop\_prob` during training. Commonly used in large networks to prevent overfitting.
    Note that we perform dropout only during training, we also rescale(multiply) the output tensor
    by :math:`\frac{1}{1 - drop\_prob}`. During inference :class:`~.Dropout` is equal to :class:`~.Identity`.

    :param drop_prob: The probability to drop (set to zero) each single element
    """

    def __init__(self, drop_prob=0.0):
        super().__init__()
        # Probability of zeroing each element while in training mode.
        self.drop_prob = drop_prob

    def forward(self, inputs):
        # Inference: identity. Training: drop elements and rescale by 1/(1-p).
        if not self.training:
            return inputs
        return dropout(inputs, self.drop_prob, rescale=True)
@@ -1,90 +0,0 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import _internal as mgb | |||
from ..core import Tensor, wrap_io_tensor | |||
from ..core.graph import _use_default_if_none | |||
from .module import Module | |||
@wrap_io_tensor
def _elemwise_func(mode, *inputs, **kwargs) -> Tensor:
    """Apply the mgb elemwise operator of the given ``mode`` to ``inputs``.

    When every input is a plain Python scalar, the op is built on the default
    device/graph and its statically inferred value is returned instead of a
    symbolic tensor.
    """
    if all(isinstance(i, (int, float)) for i in inputs):
        # All-scalar fast path: evaluate on the default device/graph and
        # return the first element of the inferred constant.
        device, comp_graph = _use_default_if_none(None, None)
        ret = mgb.opr.elemwise(
            *inputs, mode=mode, comp_node=device, comp_graph=comp_graph, **kwargs
        )
        return ret.inferred_value[0]
    return mgb.opr.elemwise(*inputs, mode=mode, **kwargs)
class Elemwise(Module):
    r"""
    A :class:`~.Module` to do elemwise operator. Could be replaced with :class:`~.QATModule`
    version :class:`~.qat.elemwise.Elemwise` using :func:`~.quantize.quantize_qat`.

    :param method: the elemwise method, support the following string.
        It will do the normal elemwise operator for float.

        * "ADD": a + b
        * "FUSE_ADD_RELU": max(x+y, 0)
        * "MUL": x * y
        * "MIN": min(x, y)
        * "MAX": max(x, y)
        * "SUB": x - y
        * "TRUE_DIV": x / y
        * "FUSE_ADD_SIGMOID": sigmoid(x + y)
        * "FUSE_ADD_TANH": tanh(x + y)
        * "RELU": x > 0 ? x : 0
        * "ABS": x > 0 ? x : -x
        * "SIGMOID": sigmoid(x)
        * "EXP": exp(x)
        * "TANH": tanh(x)
        * "FUSE_MUL_ADD3": x * y + z
        * "FAST_TANH": fast_tanh(x)
        * "NEGATE": -x
        * "ACOS": acos(x)
        * "ASIN": asin(x)
        * "CEIL": ceil(x)
        * "COS": cos(x)
        * "EXPM1": expm1(x)
        * "FLOOR": floor(x)
        * "LOG": log(x)
        * "LOG1P": log1p(x)
        * "SIN": sin(x)
        * "ROUND": round(x)
        * "ERF": erf(x)
        * "ERFINV": erfinv(x)
        * "ERFC": erfc(x)
        * "ERFCINV": erfcinv(x)
        * "ABS_GRAD": abs_grad
        * "FLOOR_DIV": floor_div
        * "MOD": mod
        * "SIGMOID_GRAD": sigmoid_grad
        * "SWITCH_GT0": switch_gt0
        * "TANH_GRAD": tanh_grad
        * "LT": lt
        * "LEQ": leq
        * "EQ": eq
        * "POW": pow
        * "LOG_SUM_EXP": log_sum_exp
        * "FAST_TANH_GRAD": fast_tanh_grad
        * "ATAN2": atan2
        * "COND_LEQ_MOV": cond_leq_mov
        * "H_SWISH": h_swish
        * "FUSE_ADD_H_SWISH": h_swish(x+y)
        * "H_SWISH_GRAD": h_swish_grad
    """

    # Enum type used to convert the string `method` argument.
    _elemwise_mode_type = mgb.opr_param_defs.Elemwise.Mode

    def __init__(self, method):
        super().__init__()
        # Convert the method name to its mgb enum value once at build time.
        self.method = self._elemwise_mode_type.convert(method)

    def forward(self, *inps):
        # Apply the configured elemwise op to all inputs.
        return _elemwise_func(self.method, *inps)
@@ -1,171 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional | |||
import numpy as np | |||
from ..core import Parameter | |||
from ..functional import embedding as embedding_func | |||
from . import init | |||
from .module import Module | |||
class Embedding(Module):
    r"""
    A simple lookup table that stores embeddings of a fixed dictionary and size.

    This module is often used to store word embeddings and retrieve them using indices.
    The input to the module is a list of indices, and the output is the corresponding word embeddings.
    The indices should be less than num_embeddings.

    :param num_embeddings: size of embedding dictionary.
    :param embedding_dim: size of each embedding vector.
    :param padding_idx: should be set to None, not supported now.
    :param max_norm: should be set to None, not supported now.
    :param norm_type: should be set to None, not supported now.
    :param initial_weight: the learnable weights of the module of shape (num_embeddings, embedding_dim).

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32))
        data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32))

        embedding = M.Embedding(2, 5, initial_weight=weight)
        output = embedding(data)
        with np.printoptions(precision=6):
            print(output.numpy())

    Outputs:

    .. testoutput::

        [[[1.2 2.3 3.4 4.5 5.6]
          [0.1 1.1 2.1 3.1 4.1]
          [0.1 1.1 2.1 3.1 4.1]]

         [[0.1 1.1 2.1 3.1 4.1]
          [1.2 2.3 3.4 4.5 5.6]
          [0.1 1.1 2.1 3.1 4.1]]

         [[1.2 2.3 3.4 4.5 5.6]
          [1.2 2.3 3.4 4.5 5.6]
          [0.1 1.1 2.1 3.1 4.1]]]
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: Optional[float] = None,
        initial_weight: Optional[Parameter] = None,
    ):
        super().__init__()
        # padding_idx / max_norm / norm_type are reserved for API parity and
        # must be left as None for now.
        if padding_idx is not None:
            raise ValueError("Not support padding index now.")
        if max_norm is not None or norm_type is not None:
            raise ValueError("Not support weight normalize now.")
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if initial_weight is None:
            # Random init, then overwritten by reset_parameters (normal init).
            self.weight = Parameter(
                np.random.uniform(
                    size=(self.num_embeddings, self.embedding_dim)
                ).astype(np.float32)
            )
            self.reset_parameters()
        else:
            if initial_weight.shape != (num_embeddings, embedding_dim):
                raise ValueError(
                    "The weight shape should match num_embeddings and embedding_dim"
                )
            # Copy the provided weight into a fresh Parameter.
            self.weight = Parameter(initial_weight.numpy())

    def reset_parameters(self) -> None:
        # Re-initialize the embedding table with a normal distribution.
        init.normal_(self.weight)

    def forward(self, inputs):
        # Row lookup: inputs are integer indices into the weight table.
        return embedding_func(inputs, self.weight)

    @classmethod
    def from_pretrained(
        cls,
        embeddings: Parameter,
        freeze: Optional[bool] = True,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: Optional[float] = None,
    ):
        r"""
        Creates Embedding instance from given 2-dimensional FloatTensor.

        :param embeddings: Tensor contained weight for the embedding.
        :param freeze: If ``True``, the weight does not get updated during the learning process. Default: ``True``.
        :param padding_idx: should be set to None, not supported now.
        :param max_norm: should be set to None, not supported now.
        :param norm_type: should be set to None, not supported now.

        Examples:

        .. testcode::

            import numpy as np
            import megengine as mge
            import megengine.module as M

            weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32))
            data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32))

            embedding = M.Embedding.from_pretrained(weight, freeze=False)
            output = embedding(data)
            print(output.numpy())

        Outputs:

        .. testoutput::

            [[[1.2 2.3 3.4 4.5 5.6]
              [0.1 1.1 2.1 3.1 4.1]
              [0.1 1.1 2.1 3.1 4.1]]

             [[0.1 1.1 2.1 3.1 4.1]
              [1.2 2.3 3.4 4.5 5.6]
              [0.1 1.1 2.1 3.1 4.1]]

             [[1.2 2.3 3.4 4.5 5.6]
              [1.2 2.3 3.4 4.5 5.6]
              [0.1 1.1 2.1 3.1 4.1]]]
        """
        embeddings_shape = embeddings.shape
        embeddings_dim = len(embeddings_shape)
        if embeddings_dim != 2:
            raise ValueError("Embeddings parameter is expected to be 2-dimensional")
        rows = embeddings_shape[0]
        cols = embeddings_shape[1]
        embedding = cls(
            num_embeddings=rows,
            embedding_dim=cols,
            initial_weight=embeddings,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
        )
        # Freezing is implemented by disabling gradient flow to the weight.
        embedding.weight.requires_grad = not freeze
        return embedding
@@ -1,83 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
from ..functional.external import ( | |||
atlas_subgraph, | |||
cambricon_subgraph, | |||
extern_opr_subgraph, | |||
) | |||
from .module import Module | |||
class CambriconSubgraph(Module):
    r"""Module wrapper around a pre-serialized Cambricon subgraph.

    See :func:`~.cambricon_subgraph` for more details.
    """

    def __init__(
        self, data, symbol, tensor_dim_mutable,
    ):
        super().__init__()
        # Serialized graph payload; exposed via the ``data`` property below.
        self._data = data
        self.symbol = symbol
        self.tensor_dim_mutable = tensor_dim_mutable

    @property
    def data(self):
        # Return the serialized payload as raw bytes.
        return self._data.tobytes()

    @data.setter
    def data(self, val):
        # Accept a bytes-like object and store it as a uint8 array.
        self._data = np.frombuffer(val, dtype=np.uint8)

    def forward(self, inputs):
        return cambricon_subgraph(
            inputs, self._data, self.symbol, self.tensor_dim_mutable,
        )
class AtlasSubgraph(Module):
    r"""Module wrapper around a pre-serialized Atlas subgraph.

    See :func:`~.atlas_subgraph` for more details.
    """

    def __init__(self, data):
        super().__init__()
        # Serialized graph payload; exposed via the ``data`` property below.
        self._data = data

    @property
    def data(self):
        # Return the serialized payload as raw bytes.
        return self._data.tobytes()

    @data.setter
    def data(self, val):
        # Accept a bytes-like object and store it as a uint8 array.
        self._data = np.frombuffer(val, dtype=np.uint8)

    def forward(self, inputs):
        return atlas_subgraph(inputs, self._data)
class ExternOprSubgraph(Module):
    r"""Module wrapper around a serialized extern-opr subgraph.

    See :func:`~.extern_opr_subgraph` for more details.
    """

    def __init__(self, data, name, output_shapes):
        super().__init__()
        # NOTE: unlike the Cambricon/Atlas wrappers above, ``data`` is stored
        # as a plain attribute, not behind a bytes-converting property.
        self.data = data
        self.name = name
        self.output_shapes = output_shapes

    def forward(self, inputs):
        return extern_opr_subgraph(inputs, self.output_shapes, self.name, self.data,)
@@ -1,17 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..functional import identity | |||
from .module import Module | |||
class Identity(Module):
    r"""A placeholder identity operator that will ignore any argument."""

    def forward(self, x):
        # Delegates to the functional identity op so the pass-through is
        # recorded in the computing graph rather than returning ``x`` directly.
        return identity(x)
@@ -1,264 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import math | |||
from functools import reduce | |||
from typing import Optional, Tuple, Union | |||
import numpy as np | |||
from ..core import Graph, Tensor | |||
from ..random import gaussian, uniform | |||
def fill_(tensor: Tensor, val: Union[float, int]) -> None:
    """Fill the given ``tensor`` in place with the value ``val``.

    :param tensor: An n-dimensional tensor to be initialized
    :param val: The value to be filled throughout the tensor
    """
    filled = np.full(tensor.shape, val, tensor.dtype)
    tensor.set_value(filled)
def zeros_(tensor: Tensor) -> None:
    """Fill the given ``tensor`` with scalar value `0`.

    :param tensor: An n-dimensional tensor to be initialized
    """
    fill_(tensor, 0)
def ones_(tensor: Tensor) -> None:
    """Fill the given ``tensor`` with the scalar value `1`.

    :param tensor: An n-dimensional tensor to be initialized
    """
    fill_(tensor, 1)
def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None:
    r"""Fill the given ``tensor`` in place with samples drawn from the uniform
    distribution :math:`\mathcal{U}(\text{a}, \text{b})`.

    :param tensor: An n-dimensional tensor to be initialized
    :param a: Lower bound of the sampling interval
    :param b: Upper bound of the sampling interval
    """
    # Evaluate eagerly so the sampled values are materialized immediately.
    with Graph(eager_evaluation=True):
        span = b - a
        tensor.set_value(span * uniform(tensor.shape) + a)
def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
    r"""Fill the given ``tensor`` with random value sampled from normal distribution
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    :param tensor: An n-dimensional tensor to be initialized
    :param mean: The mean of the normal distribution
    :param std: The standard deviation of the normal distribution
    """
    # Evaluate eagerly so the sampled values are materialized immediately.
    with Graph(eager_evaluation=True):
        tensor.set_value(gaussian(tensor.shape, mean=mean, std=std))
def calculate_gain(
    nonlinearity: str, param: Optional[Union[int, float]] = None
) -> float:
    r"""Return a recommended gain value (see the table below) for the given nonlinearity
    function.

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    :param nonlinearity: Name of the non-linear function
    :param param: Optional parameter for leaky_relu. Only effective when
        ``nonlinearity`` is "leaky_relu".
    """
    # All of these behave as gain-1 linear maps for initialization purposes.
    unit_gain_fns = (
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
        "sigmoid",
    )
    if nonlinearity in unit_gain_fns:
        return 1
    if nonlinearity == "tanh":
        return 5.0 / 3
    if nonlinearity == "relu":
        return math.sqrt(2.0)
    if nonlinearity == "leaky_relu":
        if param is None:
            negative_slope = 0.01
        elif isinstance(param, float) or (
            isinstance(param, int) and not isinstance(param, bool)
        ):
            # bool is a subclass of int, so it must be rejected explicitly.
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
    """
    Calculate fan_in / fan_out value for given weight tensor. This function assumes
    input tensor is stored in NCHW format.

    :param tensor: Weight tensor in NCHW format
    """
    shape = tensor.shape
    if len(shape) < 2:
        raise ValueError(
            "fan_in and fan_out can not be computed for tensor with fewer than 2 "
            "dimensions"
        )
    if len(shape) == 2:
        # 2-D weight (Linear layout): rows are outputs, columns are inputs.
        return shape[1], shape[0]
    # N-D weight (Conv layout): scale the channel counts by the product of
    # the remaining (spatial) dimensions, i.e. the receptive field size.
    receptive_field_size = 1
    for extent in shape[2:]:
        receptive_field_size *= extent
    return shape[1] * receptive_field_size, shape[0] * receptive_field_size
def calculate_correct_fan(tensor: Tensor, mode: str) -> float:
    """
    Calculate fan_in or fan_out value for given weight tensor, depending on given
    ``mode``. See :func:`calculate_fan_in_and_fan_out` for details.

    :param tensor: Weight tensor in NCHW format
    :param mode: ``'fan_in'`` or ``'fan_out'``
    """
    mode = mode.lower()
    valid_modes = ["fan_in", "fan_out"]
    if mode not in valid_modes:
        raise ValueError(
            "Mode {} not supported, please use one of {}".format(mode, valid_modes)
        )
    fan_in, fan_out = calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        return fan_in
    return fan_out
def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
    r"""Fill ``tensor`` with random values sampled from :math:`\mathcal{U}(-a, a)`
    where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}}

    Also known as Glorot initialization. Detailed information can be retrieved from
    `"Understanding the difficulty of training deep feedforward neural networks" <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_.

    :param tensor: An n-dimensional tensor to be initialized
    :param gain: Scaling factor for :math:`a`.
    """
    fan_in, fan_out = calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    # sqrt(3)*std converts the target std into the half-width of U(-a, a).
    bound = math.sqrt(3.0) * std
    uniform_(tensor, -bound, bound)
def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None:
    r"""Fill ``tensor`` with random values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}}

    Also known as Glorot initialization. Detailed information can be retrieved from
    `"Understanding the difficulty of training deep feedforward neural networks" <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_.

    :param tensor: An n-dimensional tensor to be initialized
    :param gain: Scaling factor for :math:`std`.
    """
    fan_in, fan_out = calculate_fan_in_and_fan_out(tensor)
    normal_(tensor, 0.0, gain * math.sqrt(2.0 / float(fan_in + fan_out)))
def msra_uniform_(
    tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu"
) -> None:
    r"""Fill ``tensor`` with random values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}}

    Detailed information can be retrieved from
    `"Delving deep into rectifiers: Surpassing human-level performance on ImageNet
    classification" <https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_.

    :param tensor: An n-dimensional tensor to be initialized
    :param a: Optional parameter for calculating gain for leaky_relu. See
        :func:`calculate_gain` for details.
    :param mode: ``'fan_in'`` or ``'fan_out'``, used to calculate :math:`gain`, the
        scaling factor for :math:`bound`. See :func:`calculate_fan_in_and_fan_out` for
        details.
    :param nonlinearity: Name of the non-linear function used to calculate :math:`gain`.
        See :func:`calculate_gain` for details.
    """
    fan = calculate_correct_fan(tensor, mode)
    std = calculate_gain(nonlinearity, a) / math.sqrt(fan)
    # sqrt(3)*std converts the target std into the half-width of the interval.
    bound = math.sqrt(3.0) * std
    uniform_(tensor, -bound, bound)
def msra_normal_(
    tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu"
) -> None:
    r"""Fill ``tensor`` with random values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}}

    Detailed information can be retrieved from
    `"Delving deep into rectifiers: Surpassing human-level performance on ImageNet
    classification" <https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_.

    :param tensor: An n-dimensional tensor to be initialized
    :param a: Optional parameter for calculating gain for leaky_relu. See
        :func:`calculate_gain` for details.
    :param mode: ``'fan_in'`` or ``'fan_out'``, used to calculate :math:`gain`, the
        scaling factor for :math:`gain`. See :func:`calculate_fan_in_and_fan_out` for
        details.
    :param nonlinearity: Name of the non-linear function used to calculate :math:`gain`.
        See :func:`calculate_gain` for details.
    """
    fan = calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    normal_(tensor, 0, gain / math.sqrt(fan))
@@ -1,61 +0,0 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
from .. import functional as F | |||
from ..core import Parameter | |||
from . import init | |||
from .module import Module | |||
class Linear(Module):
    r"""Applies a linear transformation to the input. For an input x, the
    output y is:

    .. math::

            y = xW^T + b

    where :math:`y_i= \sum_j W_{ij} x_j + b_i`

    :param in_features: size of each input sample.
    :param out_features: size of each output sample.
    :param bias: If set to ``False``, the layer will not learn an additive bias.
        Default: ``True``
    """

    def __init__(
        self, in_features: int, out_features: int, bias: bool = True, **kwargs
    ):
        super().__init__(**kwargs)
        self.out_features = out_features
        self.in_features = in_features
        # Weight layout is (out_features, in_features) to match y = xW^T + b.
        self.weight = Parameter(
            np.zeros((out_features, in_features), dtype=np.float32)
        )
        self.bias = None
        if bias:
            self.bias = Parameter(np.zeros((out_features,), dtype=np.float32))
        self.reset_parameters()

    def _get_fanin(self):
        # Fan-in of a fully-connected layer is its input width.
        return self.in_features

    def reset_parameters(self) -> None:
        """Re-initialize weight from N(0, 1/fan_in) and zero the bias."""
        std = np.sqrt(1 / self._get_fanin())
        init.normal_(self.weight, 0.0, std)
        if self.bias is not None:
            init.zeros_(self.bias)

    def _calc_linear(self, x, weight, bias):
        # Kept as a separate method — presumably so subclasses can override
        # the actual computation; verify against quantized variants.
        return F.linear(x, weight, bias)

    def forward(self, x):
        return self._calc_linear(x, self.weight, self.bias)
@@ -1,507 +0,0 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from abc import ABCMeta, abstractmethod | |||
from collections import OrderedDict | |||
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||
import numpy as np | |||
from .._internal.dtype import is_quantize | |||
from ..core import Buffer, Parameter, Tensor | |||
from ..logger import get_logger | |||
from ..utils.hook import HookHandler | |||
logger = get_logger(__name__) | |||
def _expand_structure(key, obj):
    # Flatten nested list/tuple/dict containers under ``key`` into a list of
    # ("dotted.path", leaf) pairs, where leaves are Tensor or Module
    # instances. Non-container, non-leaf objects produce no entries.
    if isinstance(obj, (Tensor, Module)):
        return [(key, obj)]
    if not isinstance(obj, (list, tuple, dict)):
        return []
    if isinstance(obj, dict):
        # Dict keys are used as-is (and must be str when they lead to leaves).
        targets = ((k, obj[k]) for k in sorted(obj))
    else:
        # Sequence positions become string path components.
        targets = ((str(k), v) for k, v in enumerate(obj))
    ret = []
    for k, o in targets:
        sub_ret = _expand_structure(k, o)
        if sub_ret and not isinstance(k, str):
            raise AssertionError(
                "keys for Tensor and Module must be str, error key: {}".format(k)
            )
        for kt, vt in sub_ret:
            ret.append((key + "." + kt, vt))
    return ret
def _is_parameter(obj):
    # Predicate used by Module._flatten to select Parameter leaves.
    return isinstance(obj, Parameter)
def _is_buffer(obj):
    # Predicate used by Module._flatten to select Buffer leaves.
    return isinstance(obj, Buffer)
def _is_module(obj):
    # Predicate used by Module._flatten to select submodules.
    return isinstance(obj, Module)
class Module(metaclass=ABCMeta):
    """Base Module class.

    A ``Module`` holds parameters, buffers and (possibly nested) submodules as
    attributes, and provides traversal (``parameters``, ``buffers``,
    ``modules`` and their ``named_*`` variants), train/eval switching,
    forward hooks, and ``state_dict`` serialization.
    """

    def __init__(self):
        # runtime attributes
        self.training = True
        self.quantize_disabled = False

        # hooks
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()

    @abstractmethod
    def forward(self, inputs):
        # Subclasses must implement the actual computation.
        pass

    def register_forward_pre_hook(self, hook: Callable) -> HookHandler:
        """Register a hook to handle forward inputs. `hook` should be a function.

        Note that keyword arguments passed to :meth:`__call__` are NOT given to
        the hook; only the positional inputs tuple is.

        :param hook: a function that receive `module` and `inputs`, then return
            a modified `inputs` or `None`.
        :return: a handler with :meth:`~.HookHandler.remove` interface to delete the hook.
        """
        return HookHandler(self._forward_pre_hooks, hook)

    def register_forward_hook(self, hook: Callable) -> HookHandler:
        """Register a hook to handle forward results. `hook` should be a function that
        receive `module`, `inputs` and `outputs`, then return a modified `outputs` or `None`.

        This method return a handler with :meth:`~.HookHandler.remove` interface to delete the hook.
        """
        return HookHandler(self._forward_hooks, hook)

    def __call__(self, *inputs, **kwargs):
        # Run pre-hooks (which may rewrite the inputs), then forward, then
        # post-hooks (which may rewrite the outputs).
        for hook in self._forward_pre_hooks.values():
            modified_inputs = hook(self, inputs)
            if modified_inputs is not None:
                # A hook returning a single value is treated as a 1-tuple.
                if not isinstance(modified_inputs, tuple):
                    modified_inputs = (modified_inputs,)
                inputs = modified_inputs
        outputs = self.forward(*inputs, **kwargs)
        for hook in self._forward_hooks.values():
            modified_outputs = hook(self, inputs, outputs)
            if modified_outputs is not None:
                outputs = modified_outputs
        return outputs

    def _flatten(
        self,
        *,
        recursive: bool = True,
        with_key: bool = False,
        with_parent: bool = False,
        prefix: Optional[str] = None,
        predicate: Callable[[Any], bool] = lambda _: True,
        seen: Optional[Set[int]] = None
    ) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]:
        """Scans the module object and returns an iterable for the :class:`~.Tensor`
        and :class:`~.Module` attributes that agree with the ``predicate``. For multiple
        calls of this function with same arguments, the order of objects within the
        returned iterable is guaranteed to be identical, as long as all the involved
        module objects' ``__dict__`` does not change thoughout those calls.

        :param recursive: Whether to recursively scan all the submodules.
        :param with_key: Whether to yield keys along with yielded objects.
        :param with_parent: Whether to yield ``self`` along with yielded objects.
        :param prefix: The prefix appended to the yielded keys.
        :param predicate: The predicate function applied to scanned objects.
        :param seen: A set of object ids that records which objects have already
            been yielded, so shared objects are visited only once.
        """
        if seen is None:
            seen = set([id(self)])
        module_dict = vars(self)
        _prefix = "" if prefix is None else prefix + "."

        # Sorted iteration over __dict__ keys gives the deterministic order
        # promised in the docstring.
        for key in sorted(module_dict):
            for expanded_key, leaf in _expand_structure(key, module_dict[key]):
                leaf_id = id(leaf)
                if leaf_id in seen:
                    continue
                seen.add(leaf_id)

                if predicate(leaf):
                    if with_key and with_parent:
                        yield _prefix + expanded_key, leaf, self
                    elif with_key:
                        yield _prefix + expanded_key, leaf
                    elif with_parent:
                        yield leaf, self
                    else:
                        yield leaf

                if recursive and isinstance(leaf, Module):
                    # Recurse with the shared ``seen`` set so diamond-shaped
                    # module graphs are traversed once per object.
                    yield from leaf._flatten(
                        recursive=recursive,
                        with_key=with_key,
                        with_parent=with_parent,
                        prefix=_prefix + expanded_key if with_key else None,
                        predicate=predicate,
                        seen=seen,
                    )

    def parameters(
        self, requires_grad: Optional[bool] = None, recursive: bool = True, **kwargs
    ) -> Iterable[Parameter]:
        r"""Returns an iterable for the :class:`~.Parameter` of the module.

        :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
            attribute of returned :class:`.Parameter`. ``None`` for no limitation.
        :param recursive: If ``True``, returns all :class:`~.Parameter` within this
            module, else only returns :class:`~.Parameter` that are direct attributes
            of this module.
        """

        def predicate(obj) -> bool:
            return _is_parameter(obj) and (
                requires_grad is None or obj.requires_grad == requires_grad
            )

        yield from self._flatten(
            with_key=False, predicate=predicate, recursive=recursive, **kwargs
        )

    def named_parameters(
        self,
        requires_grad: Optional[bool] = None,
        prefix: Optional[str] = None,
        recursive: bool = True,
        **kwargs
    ) -> Iterable[Tuple[str, Parameter]]:
        """Returns an iterable for key :class:`~.Parameter` pairs of the module, where
        ``key`` is the dotted path from this module to the :class:`~.Parameter` .

        :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
            attribute of returned :class:`~.Parameter` . ``None`` for no limitation.
        :param prefix: The prefix prepended to the keys.
        :param recursive: If ``True``, returns all :class:`~.Parameter` within this
            module, else only returns :class:`~.Parameter` that are direct attributes
            of this module.
        """

        def predicate(obj) -> bool:
            return _is_parameter(obj) and (
                requires_grad is None or obj.requires_grad == requires_grad
            )

        yield from self._flatten(
            with_key=True,
            prefix=prefix,
            predicate=predicate,
            recursive=recursive,
            **kwargs,
        )

    def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Buffer]:
        """Returns an iterable for the :class:`~.Buffer` of the module.

        :param recursive: If ``True``, returns all :class:`~.Buffer` within this
            module, else only returns :class:`~.Buffer` that are direct attributes
            of this module.
        """
        yield from self._flatten(
            with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs
        )

    def named_buffers(
        self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
    ) -> Iterable[Tuple[str, Buffer]]:
        """Returns an iterable for key :class:`~.Buffer` pairs of the module, where
        ``key`` is the dotted path from this module to the :class:`~.Buffer` .

        :param prefix: The prefix prepended to the keys.
        :param recursive: If ``True``, returns all :class:`~.Buffer` within this
            module, else only returns :class:`~.Buffer` that are direct attributes
            of this module.
        """
        yield from self._flatten(
            with_key=True,
            prefix=prefix,
            predicate=_is_buffer,
            recursive=recursive,
            **kwargs,
        )

    def children(self, **kwargs) -> "Iterable[Module]":
        """Returns an iterable for all the submodules that are direct attributes of this
        module.
        """
        yield from self._flatten(
            with_key=False, predicate=_is_module, recursive=False, **kwargs
        )

    def named_children(self, **kwargs) -> "Iterable[Tuple[str, Module]]":
        """Returns an iterable of key-submodule pairs for all the submodules that are
        direct attributes of this module, where 'key' is the attribute name of
        submodules.
        """
        yield from self._flatten(
            with_key=True, predicate=_is_module, recursive=False, **kwargs
        )

    def modules(self, **kwargs) -> "Iterable[Module]":
        """Returns an iterable for all the modules within this module, including itself.
        """
        # ``self`` is yielded first; its parent is None in with_parent mode.
        if "with_parent" in kwargs and kwargs["with_parent"]:
            yield self, None
        else:
            yield self
        yield from self._flatten(with_key=False, predicate=_is_module, **kwargs)

    def named_modules(
        self, prefix: Optional[str] = None, **kwargs
    ) -> "Iterable[Tuple[str, Module]]":
        """Returns an iterable of key-module pairs for all the modules within this
        module, including itself, where 'key' is the dotted path from this module to the
        submodules.

        :param prefix: The prefix prepended to the path.
        """
        # ``self`` is yielded first under the bare prefix (or "" if none).
        if "with_parent" in kwargs and kwargs["with_parent"]:
            yield ("" if prefix is None else prefix), self, None
        else:
            yield ("" if prefix is None else prefix), self
        yield from self._flatten(
            with_key=True, prefix=prefix, predicate=_is_module, **kwargs
        )

    def apply(self, fn: "Callable[[Module], Any]") -> None:
        """Apply function ``fn`` to all the modules within this module, including
        itself.

        :param fn: The function to be applied on modules.
        """
        for it in self.modules():
            fn(it)

    def zero_grad(self) -> None:
        """Set all parameters' grads to zero
        """
        for param in self.parameters():
            if param.grad is not None:
                param.grad.reset_zero()

    def train(self, mode: bool = True, recursive: bool = True) -> None:
        """Set training mode of all the modules within this module (including itself) to
        ``mode``. This effectively sets the ``training`` attributes of those modules
        to ``mode``, but only has effect on certain modules (e.g.
        :class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`)

        :param mode: the training mode to be set on modules.
        :param recursive: whether to recursively call submodules' ``train()``.
        """
        if not recursive:
            self.training = mode
            return

        def fn(module: Module) -> None:
            # Each module's own (possibly overridden) train() is invoked
            # non-recursively; apply() handles the traversal.
            module.train(mode, recursive=False)

        self.apply(fn)

    def eval(self) -> None:
        """Set training mode of all the modules within this module (including itself) to
        ``False``. See :meth:`~.Module.train` for details.
        """
        self.train(False)

    def disable_quantize(self, value=True):
        r"""
        Set ``module``'s ``quantize_disabled`` attribute and return ``module``.
        Could be used as a decorator.
        """

        def fn(module: Module) -> None:
            module.quantize_disabled = value

        self.apply(fn)

    def replace_param(
        self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
    ):
        """Replace module's parameters with `params`, used by :class:`~.ParamPack` to
        speedup multimachine training.

        Parameters are visited in the same sorted-attribute order as
        :meth:`state_dict`; ``params`` maps flat positions (offset by
        ``start_pos``) to replacement parameters. Returns the number of
        parameters visited in this subtree.
        """
        offset = 0
        if seen is None:
            seen = set([id(self)])
        module_dict = vars(self)
        for key in sorted(module_dict):
            hash_id = id(module_dict[key])
            if hash_id in seen:
                continue
            seen.add(hash_id)
            if isinstance(module_dict[key], Parameter):
                if start_pos + offset in params:
                    assert module_dict[key].shape == params[start_pos + offset].shape
                    module_dict[key] = params[start_pos + offset]
                offset += 1
            if isinstance(module_dict[key], Module):
                offset += module_dict[key].replace_param(
                    params, start_pos + offset, seen
                )
        return offset

    def state_dict(self, rst=None, prefix="", keep_var=False):
        r"""Returns a dictionary containing whole states of the module.

        Keys are dotted paths; values are numpy arrays unless ``keep_var`` is
        ``True``, in which case the Tensor objects themselves are stored.
        """

        def is_state(obj):
            return _is_parameter(obj) or _is_buffer(obj)

        if rst is None:
            rst = OrderedDict()

        for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state):
            assert prefix + k not in rst, "duplicated state: {}".format(k)
            if keep_var:
                rst[prefix + k] = v
            else:
                rst[prefix + k] = v.numpy()

        for k, submodule in self._flatten(
            recursive=False,
            with_key=True,
            predicate=lambda obj: isinstance(obj, Module),
        ):
            submodule.state_dict(rst, prefix + k + ".", keep_var)

        return rst

    def load_state_dict(
        self,
        state_dict: Union[dict, Callable[[str, Tensor], Optional[np.ndarray]]],
        strict=True,
    ):
        r"""Load a given dictionary created by :func:`state_dict` into this module.
        If ``strict`` is ``True``, the keys of :func:`state_dict` must exactly match the keys
        returned by :func:`state_dict`.

        Users can also pass a closure: `Function[key: str, var: Tensor] -> Optional[np.ndarray]`
        as a `state_dict`, in order to handle complex situations. For example, load everything
        except for the final linear classifier:

        .. code-block::

            state_dict = {...}  #  Dict[str, np.ndarray]
            model.load_state_dict({
                k: None if k.startswith('fc') else v
                for k, v in state_dict.items()
            }, strict=False)

        Here returning `None` means skipping parameter `k`.

        To prevent shape mismatch (e.g. load PyTorch weights), we can reshape before loading:

        .. code-block::

            state_dict = {...}
            def reshape_accordingly(k, v):
                return state_dict[k].reshape(v.shape)
            model.load_state_dict(reshape_accordingly)

        We can also perform inplace re-initialization or pruning:

        .. code-block::

            def reinit_and_pruning(k, v):
                if 'bias' in k:
                    M.init.zeros_(v)
                if 'conv' in k:
                    return v.numpy() * (np.abs(v.numpy()) > 1e-3).astype("float32")
            model.load_state_dict(reinit_and_pruning, strict=False)
        """
        unused = []
        if isinstance(state_dict, dict):
            unused = state_dict.keys()

            def closure(k, _):  # var unused
                return state_dict[k] if k in state_dict else None

        elif callable(state_dict):
            closure = state_dict
        else:
            raise ValueError(
                "`state_dict` must load a dict or callable, got {}".format(
                    type(state_dict)
                )
            )

        loaded, skipped = self._load_state_dict_with_closure(closure)
        unused = set(unused) - loaded

        # Keys present in the given state_dict but not in this module.
        if len(unused) != 0:
            if strict:
                raise KeyError(
                    "Unused params violate `strict=True`, unused={}".format(unused)
                )
            else:
                logger.warning(
                    "Unused params in `strict=False` mode, unused={}".format(unused)
                )

        # Keys present in this module but not loaded from the state_dict.
        if len(skipped) != 0:
            if strict:
                raise KeyError(
                    "Missing params violate `strict=True`, missing={}".format(skipped)
                )
            else:
                logger.warning(
                    "Missing params in `strict=False` mode, missing={}".format(skipped)
                )

    def _load_state_dict_with_closure(self, closure):
        """Advance state_dict load through callable `closure` whose signature is
        `closure(key: str, var: Tensor) -> Union[np.ndarray, None]`

        Returns ``(loaded, skipped)`` as sets of keys.
        """
        assert callable(closure), "closure must be a function"

        loaded = []
        skipped = []

        local_state_dict = self.state_dict(keep_var=True)
        for k, var in local_state_dict.items():
            to_be_load = closure(k, var)
            if to_be_load is None:
                skipped.append(k)
                continue
            assert isinstance(
                to_be_load, np.ndarray
            ), "closure should return a `np.ndarray`, now `{}` get {}".format(
                k, to_be_load
            )
            assert (
                var.shape == to_be_load.shape
            ), "param `{}` shape mismatch, should be {}, get {}".format(
                k, var.shape, to_be_load.shape
            )
            # For quantized dtype, the initialized dtype
            # scale/zero_points maybe invalid, use pretrained dtype instead.
            if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
                var.dtype = to_be_load.dtype
            var.set_value(to_be_load)
            loaded.append(k)

        return set(loaded), set(skipped)
@@ -1,157 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
from typing import Callable, Iterable, Optional, Tuple | |||
import numpy as np | |||
from .._internal.opr import param_pack_split | |||
from ..core import Parameter, Tensor | |||
from .module import Module | |||
class ParamPack(Module):
    r"""Pack module's parameters by gathering their memory to continuous address.
    Using (device, dtype, requires_grad) as key, for example ('gpu0', float32, True),
    parameters with same key will be packed together.
    It helps a lot for multimachine training by speeding up allreduce gradients.

    :param model: the module you want to pack parameters.
    :param nr_ignore_first: how many parameters will be unpacked at first.
    :param max_size_per_group: upper bound of packed parameters' size in MB.
    :param max_nr_params_per_group: upper bound of the number of parameters of each group.
    :param group_func: maps ``(name, param)`` to a hashable group key; only
        params whose keys compare equal may be packed together.
    """

    def __init__(
        self,
        model: Module,
        nr_ignore_first: int = 8,
        max_size_per_group: int = 10,
        max_nr_params_per_group: int = 100,
        group_func: Callable = lambda name, param: 0,
    ):
        super().__init__()
        self._model = model
        self._nr_ignore_first = nr_ignore_first
        self._max_size_per_group = max_size_per_group
        self._max_nr_params_per_group = max_nr_params_per_group
        self._group_func = group_func
        # _grouped_params[i] records the layout (shape + original param id) of
        # everything stored in _packed_params[i]; both lists stay parallel.
        self._grouped_params = []
        self._packed_params = []

        params = model.named_parameters()
        self._pack_params(params)

    def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]:
        # Expose only the packed parameters, optionally filtered by requires_grad;
        # the originals are substituted back into the model at forward time.
        for param in self._packed_params:
            if requires_grad is None or param.requires_grad == requires_grad:
                yield param

    def named_parameters(
        self, requires_grad: Optional[bool] = None
    ) -> Iterable[Tuple[str, Parameter]]:
        # Original names do not survive packing, so synthesize stable names
        # from the pack index instead.
        for idx, param in enumerate(self._packed_params):
            if requires_grad is None or param.requires_grad == requires_grad:
                yield "packed_param_" + str(idx), param

    def _pack_params(self, params: Iterable[Tuple[str, Parameter]]):
        # Partition params into groups keyed by (dtype, device, requires_grad,
        # group_key); the first `_nr_ignore_first` params stay unpacked.
        groups = collections.defaultdict(list)
        ignored = 0
        param_id = 0
        for name, param in params:
            if self._nr_ignore_first > ignored:
                ignored += 1
                self._grouped_params.append([{"shape": param.shape, "id": param_id}])
                param.pack_group_key = self._group_func(name, param)
                self._packed_params.append(param)
            else:
                key = (
                    param.dtype,
                    param.device,
                    param.requires_grad,
                    self._group_func(name, param),
                )
                groups[key].append({"tensor": param, "id": param_id})
            param_id += 1
        for (dtype, device, requires_grad, group_key) in groups.keys():
            # Convert the device's byte alignment into an element count.
            # NOTE(review): the padding math below assumes `align` is a power
            # of two — confirm device.mem_align guarantees this.
            dtype_sz = np.dtype(dtype).itemsize
            align = device.mem_align
            if align < dtype_sz:
                align = 1
            else:
                assert align % dtype_sz == 0
                align //= dtype_sz

            group = groups[(dtype, device, requires_grad, group_key)]
            while group:
                # Take a prefix of `group` until either the size budget or the
                # parameter-count budget of one pack is exhausted.
                aligned_pos = []
                offset = 0
                params = []
                idx = 0
                while idx < len(group):
                    param = group[idx]
                    assert param["tensor"].device == device
                    padding = (align - (offset & (align - 1))) & (align - 1)
                    offset += padding
                    aligned_pos.append(offset)
                    params.append(param)
                    offset += int(np.prod(param["tensor"].shape))
                    idx += 1
                    if (
                        offset * dtype_sz >= self._max_size_per_group * 1024 * 1024
                        or idx >= self._max_nr_params_per_group
                    ):
                        break
                group = group[idx:]
                if idx == 1:
                    # ignore param packs with only one item
                    params[0]["tensor"].pack_group_key = group_key
                    self._packed_params.append(params[0]["tensor"])
                    self._grouped_params.append(
                        [{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}]
                    )
                    continue
                # Copy every member into one flat buffer at its aligned
                # position and expose the buffer as a single new Parameter.
                # NOTE(review): forward() later splits the pack by shapes only
                # (see param_pack_split) — presumably that op applies the same
                # alignment internally; confirm it matches `aligned_pos`.
                packed_value = np.zeros((offset,), dtype=dtype)
                for param, pos in zip(params, aligned_pos):
                    val = param["tensor"].numpy()
                    packed_value[pos : pos + val.size] = val.flatten()
                new_param = Parameter(
                    value=packed_value,
                    device=device,
                    dtype=dtype,
                    requires_grad=requires_grad,
                )
                new_param.pack_group_key = group_key
                self._packed_params.append(new_param)
                self._grouped_params.append(
                    [{"shape": i["tensor"].shape, "id": i["id"]} for i in params]
                )

    def forward(self, *args, **kwargs):
        # Split each multi-param pack back into per-parameter tensors and
        # substitute them into the wrapped model before delegating the call.
        replace_param = dict()
        for i in range(len(self._packed_params)):
            packed_param = self._packed_params[i]
            grouped_params = self._grouped_params[i]
            if len(grouped_params) == 1:
                continue
            split = param_pack_split(
                packed_param._symvar, [i["shape"] for i in grouped_params]
            )
            split = [
                Parameter(Tensor(i, requires_grad=packed_param.requires_grad))
                for i in split
            ]
            for j in range(len(split)):
                replace_param[grouped_params[j]["id"]] = split[j]
        self._model.replace_param(replace_param, 0)
        return self._model.forward(*args, **kwargs)
@@ -1,80 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from abc import abstractmethod | |||
from typing import Tuple, Union | |||
from ..functional import avg_pool2d, max_pool2d | |||
from .module import Module | |||
class _PoolNd(Module):
    """Abstract base class for 2D spatial pooling modules.

    :param kernel_size: size of the pooling window.
    :param stride: stride of the window; ``None`` (the default) means use
        ``kernel_size`` as the stride.
    :param padding: implicit zero padding added on both sides of the input.
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int]],
        # Fixed annotation: the default is None, so None must be admitted.
        stride: Union[int, Tuple[int, int], None] = None,
        padding: Union[int, Tuple[int, int]] = 0,
    ):
        super(_PoolNd, self).__init__()
        self.kernel_size = kernel_size
        # A falsy stride (None) falls back to kernel_size, giving the usual
        # non-overlapping pooling behavior by default.
        self.stride = stride or kernel_size
        self.padding = padding

    @abstractmethod
    def forward(self, inp):
        pass
class MaxPool2d(_PoolNd):
    r"""2D max pooling layer.

    Given an input of size :math:`(N, C, H, W)` and :attr:`kernel_size`
    :math:`(kH, kW)`, this layer produces an output of size
    :math:`(N, C, H_{out}, W_{out})` computed as:

    .. math::

        \begin{aligned}
        out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1}
        \text{input}(N_i, C_j, \text{stride[0]} \times h + m,
        \text{stride[1]} \times w + n)
        \end{aligned}

    With a non-zero :attr:`padding`, the input is implicitly zero-padded by
    :attr:`padding` points on both sides before pooling.

    :param kernel_size: the size of the window to take a max over.
    :param stride: the stride of the window. Default value is ``kernel_size``.
    :param padding: implicit zero padding to be added on both sides.
    """

    def forward(self, inp):
        # Delegate to the functional implementation using this layer's config.
        pool_cfg = (self.kernel_size, self.stride, self.padding)
        return max_pool2d(inp, *pool_cfg)
class AvgPool2d(_PoolNd):
    r"""2D average pooling layer.

    Given an input of size :math:`(N, C, H, W)` and :attr:`kernel_size`
    :math:`(kH, kW)`, this layer produces an output of size
    :math:`(N, C, H_{out}, W_{out})` computed as:

    .. math::

        out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
        input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n)

    With a non-zero :attr:`padding`, the input is implicitly zero-padded by
    :attr:`padding` points on both sides before pooling.

    :param kernel_size: the size of the window.
    :param stride: the stride of the window. Default value is ``kernel_size``.
    :param padding: implicit zero padding to be added on both sides.
    """

    def forward(self, inp):
        # Delegate to the functional implementation using this layer's config.
        pool_cfg = (self.kernel_size, self.stride, self.padding)
        return avg_pool2d(inp, *pool_cfg)
@@ -1,9 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .pytorch import PyTorchModule |
@@ -1,451 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections
import collections.abc
import copy
import functools
import os
from typing import Any, Callable, List, Optional, Tuple

import torch
from torch.utils.cpp_extension import load as load_torch_extension

import megengine._internal as mgb
from megengine._internal import CompGraph
from megengine._internal.mgb import CompGraphCallbackValueProxy

from ...core import Parameter, Tensor, get_default_device
from ..module import Module
from .utils import device_to_torch_device, torch_dtype_to_numpy_dtype
# Global registry used while copying forward/grad operator pairs: it maps an
# original craniotome opr to its copy so that when the peer opr is copied
# later it can re-link to the new instance (see the `copy()` methods below,
# which clear this dict once both members of the pair are re-linked).
_copy_dict = {}
@functools.lru_cache(None)
def _get_torch_mem_fwd_lib():
    # JIT-compile and load (once per process, memoized by lru_cache) the torch
    # C++ extension that forwards memory between MegBrain and torch tensors.
    source_file = os.path.join(os.path.dirname(__file__), "torch_mem_fwd.cpp")
    return load_torch_extension(
        "torch_mem_fwd",
        [source_file],
        extra_include_paths=[mgb.config.get_include_path()],
    )
def inp_mem_fwd(pubapi_dev_tensor_ptr: int) -> torch.Tensor:
    """Forward a MegBrain tensor to torch tensor

    :param pubapi_dev_tensor_ptr: pointer to MegBrain tensor
    :return: a torch tensor wrapping the MegBrain tensor's device memory
        (created by the C++ extension, see torch_mem_fwd.cpp).
    """
    # Delegates to the lazily JIT-compiled C++ extension.
    return _get_torch_mem_fwd_lib().inp_mem_fwd(pubapi_dev_tensor_ptr)
def oup_mem_fwd(
    pubapi_dev_tensor_ptr: int, tensor: torch.Tensor, keep_data_ptr: bool = True
) -> None:
    """Forward a torch tensor to a contiguous MegBrain tensor

    :param pubapi_dev_tensor_ptr: Pointer to the MegBrain tensor
    :param tensor: The input torch tensor
    :param keep_data_ptr: if True, memory copy is not allowed here,
        thus the input torch tensor must be contiguous also.
        defaults to True
    """
    # Delegates to the lazily JIT-compiled C++ extension.
    _get_torch_mem_fwd_lib().oup_mem_fwd(pubapi_dev_tensor_ptr, tensor, keep_data_ptr)
def torch_param_to_mge(
    name: str, param: torch.nn.Parameter, device, comp_graph: CompGraph
) -> Parameter:
    """Convert a torch parameter to a megengine parameter

    The returned megengine parameter shares the torch parameter's device
    memory (no copy): the torch storage is forwarded in place.

    :param name: parameter name
        (NOTE(review): not used by the current implementation)
    :param device: the device on which the megengine parameter is,
        should be physically the same as the one on torch parameter
        (NOTE(review): not used by the current implementation)
    :param param: torch parameter
    :param comp_graph: the owner graph of megengine parameter
        (NOTE(review): not used by the current implementation)
    :return: megengine parameter
    """
    assert isinstance(param, torch.nn.Parameter)
    dtype = torch_dtype_to_numpy_dtype(param.dtype)
    mge_param = Parameter(None, dtype=dtype)
    # Grab the underlying SharedND of the fresh Parameter and make it adopt
    # the torch tensor's storage in place (keep_data_ptr=True: no copy).
    shared_nd = mge_param._Tensor__val
    oup_mem_fwd(shared_nd.pubapi_dev_tensor_ptr, param.data, True)
    return mge_param
class _PyTorchSubgraphGradOpr(mgb.craniotome.CraniotomeBase):
    """Craniotome operator that computes the gradients of a
    :class:`PyTorchSubgraphImplOpr` by running ``torch.autograd.grad`` on the
    torch tensors retained by the forward operator's last execution.
    """

    __nr_inputs__ = None
    __nr_outputs__ = None
    __allow_duplicate__ = False
    __disable_sys_mem_alloc__ = True
    __is_dynamic_output_shape__ = True

    _forward_opr = None  # type: PyTorchSubgraphImplOpr
    _shape_infer_func = None
    # For each forward output: the index (into this opr's inputs) of its
    # gradient var, or None when that output received no gradient.
    _condensed_out_grad_idx = None  # type: List[Optional[int]]
    _forward_input_cnt = None
    _forward_output_cnt = None
    _output_grad_cnt = None
    _param_cnt = None

    def setup(
        self, forward_opr, condensed_out_grad_idx: List[Optional[int]], infer_shape=None
    ):
        # Inputs: forward inputs + params + forward outputs + present out-grads.
        # Outputs: one gradient per forward input and per param.
        self._forward_opr = forward_opr
        self._forward_input_cnt = forward_opr.input_cnt
        self._forward_output_cnt = forward_opr.output_cnt
        self._param_cnt = forward_opr.param_cnt
        self._output_grad_cnt = sum([idx is not None for idx in condensed_out_grad_idx])
        self.__nr_inputs__ = (
            self._forward_input_cnt
            + self._param_cnt
            + self._forward_output_cnt
            + self._output_grad_cnt
        )
        self.__nr_outputs__ = self._forward_input_cnt + self._param_cnt
        self._forward_opr = forward_opr
        self._condensed_out_grad_idx = condensed_out_grad_idx
        self._shape_infer_func = infer_shape
        if infer_shape is not None:
            # NOTE(review): this mutates the class attribute, so it affects
            # every instance of this opr type — confirm intended.
            type(self).__is_dynamic_output_shape__ = False

    def execute(
        self,
        inputs: Tuple[CompGraphCallbackValueProxy, ...],
        outputs: Tuple[mgb.SharedND, ...],
    ):
        # Requires the forward opr to have run in this iteration so its torch
        # input/output tensors (and their autograd graph) are still alive.
        assert self._forward_opr._last_forward_inputs is not None
        assert self._forward_opr._last_forward_outputs is not None
        # NOTE(review): this branch is unreachable — the assert directly above
        # already guarantees _last_forward_outputs is not None.
        if self._forward_opr._last_forward_outputs is None:
            self._forward_opr.execute(inputs[: self.__nr_outputs__], None)
        # NOTE(review): `if idx` treats index 0 like None; the indices
        # produced by PyTorchSubgraphImplOpr.grad start at
        # __nr_inputs__ + len(outputs) so 0 should not occur, but
        # `idx is not None` would be safer.
        out_grads = [
            inp_mem_fwd(inputs[idx].pubapi_dev_tensor_ptr) if idx else None
            for idx in self._condensed_out_grad_idx
        ]
        grads = torch.autograd.grad(
            self._forward_opr._last_forward_outputs,
            self._forward_opr._last_forward_inputs
            + self._forward_opr._last_forward_params,
            out_grads,  # type: ignore
            only_inputs=True,
            allow_unused=True,
        )
        # Forward each computed torch gradient into its output SharedND.
        for ovar, oten in zip(outputs, grads):
            oup_mem_fwd(ovar.pubapi_dev_tensor_ptr, oten)

    def grad(self, wrt_idx, inputs, outputs, out_grad):
        # Second-order gradients are not supported.
        raise NotImplementedError("Apply grad to a grad opr is not supported")

    def infer_shape(self, inp_shapes):
        if callable(self._shape_infer_func):
            return self._shape_infer_func(inp_shapes)
        raise NotImplementedError(
            "No shape inference function specified on PyTorchSubgraphImplOpr"
        )

    def copy(self):
        # Custom copy keeping the forward/grad opr pair consistent: whichever
        # opr of the pair is copied first registers itself in the global
        # _copy_dict; when the peer is copied later it re-links to the new
        # instance and clears the registry.
        ret = type(self)()
        d0 = self.__dict__.copy()
        d0.pop("this")
        # The forward opr is shared (or re-linked), never deep-copied here.
        d0.pop("_forward_opr")
        later_copy = self._forward_opr in _copy_dict
        if later_copy:
            assert len(_copy_dict) == 1
            forward_opr_copy = _copy_dict[self._forward_opr]
        else:
            forward_opr_copy = self._forward_opr
        ret.__dict__["_forward_opr"] = forward_opr_copy
        ret.__dict__.update(copy.deepcopy(d0))
        _copy_dict[self] = ret
        if later_copy:
            forward_opr_copy._grad_opr = ret
            _copy_dict.clear()
        return ret
class PyTorchSubgraphImplOpr(mgb.craniotome.CraniotomeBase):
    # pylint: disable=abstract-method
    """This is a pytorch module wrapper to operator"""

    __nr_inputs__ = None  # type: int
    __nr_outputs__ = None  # type: int
    __allow_duplicate__ = False
    __disable_sys_mem_alloc__ = True
    __is_dynamic_output_shape__ = True

    _grad_opr = None
    _func = None  # type: Callable[[Any], Any]
    input_cnt = None  # type: int
    output_cnt = None  # type: int
    param_cnt = None  # type: int
    _shape_infer_func = None
    # Torch tensors retained from the latest forward execution; consumed by
    # _PyTorchSubgraphGradOpr.execute to run torch.autograd.grad.
    _last_forward_inputs = None
    _last_forward_outputs = None  # type: List[torch.Tensor]
    _last_forward_params = None  # type: List[torch.Tensor]

    def setup(self, *, input_cnt, output_cnt, func, params, infer_shape=None):
        """Setup the operator by accepted kwargs

        :param input_cnt: input count of torch module
        :param output_cnt: output count of torch module
        :param func: a callable object accept inputs and returns outputs
            usually a torch module itself
        :param params: parameters of the torch module
        :param infer_shape: a callable infers output shapes from input shapes,
            defaults to None
        """
        param_cnt = len(params)
        self.input_cnt = input_cnt
        self.output_cnt = output_cnt
        self.param_cnt = param_cnt
        # Params are appended after the real inputs in the mgb var list.
        self.__nr_inputs__ = input_cnt + param_cnt
        self.__nr_outputs__ = output_cnt
        self._func = func
        self._shape_infer_func = infer_shape
        if infer_shape is not None:
            # NOTE(review): mutates the class attribute, so it affects every
            # instance of this opr type — confirm intended.
            type(self).__is_dynamic_output_shape__ = False
        self._last_forward_params = params

    def execute(
        self,
        inputs: Tuple[CompGraphCallbackValueProxy, ...],
        outputs: Optional[Tuple[mgb.SharedND, ...]],
    ):
        """execute the operator, read values from *inputs*,
        forward them to torch tensor and do execution by self.func
        and forward results to outputs

        :param inputs: values for each input var
        :param outputs: values for each output var
        """
        input_value_proxys = inputs[: self.input_cnt]
        # Zero-copy forward each mgb value into a torch tensor; requires_grad
        # is enabled so torch records the graph needed for the backward opr.
        input_torch_tensors = [
            inp_mem_fwd(ivar.pubapi_dev_tensor_ptr).requires_grad_()
            for ivar in input_value_proxys
        ]
        output_torch_tensors = self._func(*input_torch_tensors)
        # Normalize a single-tensor return to a list.
        if isinstance(output_torch_tensors, torch.Tensor):
            output_torch_tensors = [output_torch_tensors]
        # `execute` may be called in _PyTorchSubgraphGradOp with None as outputs
        if outputs:
            for ovar, oten in zip(outputs, output_torch_tensors):
                oup_mem_fwd(ovar.pubapi_dev_tensor_ptr, oten)
        # Retain input / output tensors for backward
        self._last_forward_inputs = input_torch_tensors
        self._last_forward_outputs = output_torch_tensors

    def grad(
        self,
        wrt_idx,
        inputs: Tuple[mgb.SymbolVar, ...],
        outputs: Tuple[mgb.SymbolVar, ...],
        out_grads: Tuple[mgb.SymbolVar, ...],
    ):
        """generate a grad opr which calculates grad by torch.autograd.grad and cache it

        :param wrt_idx: the input var with respect to which the gradient should
            be computed
        :param inputs: operator inputs
        :param outputs: operator outputs
        :param out_grads: gradients of each output var
        :return: an initialized grad opr
        """
        # A single grad opr serves all wrt indices, so build it once and cache.
        if self._grad_opr is None:
            # Drop missing output grads but record, for each forward output,
            # where its grad (if present) lives in the grad opr's input list.
            condensed_out_grad = []
            condensed_out_grad_idx = []  # type: List[Optional[int]]
            idx = self.__nr_inputs__ + len(outputs)
            for out_grad in out_grads:
                if out_grad is None:
                    condensed_out_grad_idx.append(None)
                else:
                    condensed_out_grad.append(out_grad)
                    condensed_out_grad_idx.append(idx)
                    idx += 1
            self._grad_opr = _PyTorchSubgraphGradOpr.make(
                *(inputs + outputs + tuple(condensed_out_grad)),
                forward_opr=self,
                condensed_out_grad_idx=condensed_out_grad_idx,
            )
        return self._grad_opr

    def infer_shape(self, inp_shapes):
        """infer output shape from input shapes

        :param inp_shapes: input shapes as tuple
        :return: output shapes
        """
        if callable(self._shape_infer_func):
            return self._shape_infer_func(inp_shapes)
        raise NotImplementedError(
            "No shape inference function specified on PyTorchSubgraphImplOpr"
        )

    def copy(self):
        # Custom copy keeping the forward/grad opr pair consistent via the
        # global _copy_dict (see _PyTorchSubgraphGradOpr.copy for the peer).
        ret = type(self)()
        d0 = self.__dict__.copy()
        d0.pop("this")
        # Torch tensors and the wrapped callable are shared, not deep-copied.
        ret.__dict__["_last_forward_inputs"] = d0.pop("_last_forward_inputs")
        ret.__dict__["_last_forward_outputs"] = d0.pop("_last_forward_outputs")
        ret.__dict__["_last_forward_params"] = d0.pop("_last_forward_params")
        ret.__dict__["_func"] = d0.pop("_func")
        d0.pop("_grad_opr")
        later_copy = self._grad_opr in _copy_dict
        if later_copy:
            assert len(_copy_dict) == 1
            grad_opr_copy = _copy_dict[self._grad_opr]
        else:
            grad_opr_copy = self._grad_opr
        ret.__dict__["_grad_opr"] = grad_opr_copy
        ret.__dict__.update(copy.deepcopy(d0))
        _copy_dict[self] = ret
        if later_copy:
            grad_opr_copy._forward_opr = ret
            _copy_dict.clear()
        return ret
class PyTorchModule(Module):
    """Wrap a pytorch module as megengine module

    :param torch_module: torch module to be wrapped
    :param device: target device this module would be in
    :param output_cnt: output count of this module
    :param infer_shape: a callable inferring output shapes from input shapes
        (keyword-only), defaults to None
    :param comp_graph: target comp_graph on which this module would be in
    """

    __torch_module = None  # type: torch.nn.Module
    __output_cnt = None
    __infer_shape = None
    __comp_graph = None
    __device = None
    # Parallel lists kept in sync by init_params():
    # torch params, their megengine mirrors, and (name, mge_param) pairs.
    _torch_params = None
    _param_inputs = None
    _name_param_list = None  # type: List[Tuple[str, Parameter]]

    def __init__(
        self,
        torch_module,
        device=None,
        output_cnt=1,
        *,
        infer_shape=None,
        comp_graph=None
    ):
        super().__init__()
        if not isinstance(torch_module, torch.nn.Module):
            raise TypeError(
                "torch_module should either be an instance of torch.nn.Module "
                "or its subclass"
            )
        self.__torch_module = torch_module
        if not isinstance(output_cnt, int):
            raise TypeError("output_cnt must be int")
        if output_cnt <= 0:
            raise ValueError("output_cnt must be greater than zero")
        self.__output_cnt = output_cnt
        if infer_shape and not callable(infer_shape):
            raise TypeError("infer_shape should either be None or a callable object")
        self.__infer_shape = infer_shape
        if comp_graph and not isinstance(comp_graph, mgb.CompGraph):
            raise TypeError("comp_graph should either be None or a mgb.CompGraph")
        self.__comp_graph = comp_graph
        self._torch_params = []
        self._param_inputs = []
        self._name_param_list = []
        if device is None:
            device = get_default_device()
        if isinstance(device, str):
            device = mgb.comp_node(device)
        # Property assignment: moves the torch module onto `device` and calls
        # init_params() via set_device().
        self.device = device

    def init_params(self):
        """forward torch parameters to megengine parameters and store,
        would be called in constructor and setter of device
        """
        self._torch_params = []
        self._param_inputs = []
        self._name_param_list = []
        for name, torch_param in self.__torch_module.named_parameters(recurse=True):
            # Prefix with the wrapped module's id to avoid name clashes when
            # several torch modules are wrapped in the same graph.
            formatted_name = "_torch_{}_{}".format(id(self.__torch_module), name)
            mge_param = torch_param_to_mge(
                formatted_name, torch_param, self.device, self.__comp_graph
            )
            self._param_inputs.append(mge_param)
            self._torch_params.append(torch_param)
            self._name_param_list.append((name, mge_param))

    def get_param_by_name(self, param_name: str) -> Parameter:
        """find parameter by its name

        :param param_name: name of parameter
        :return: the parameter
        :raises KeyError: if no parameter with the given name exists
        """
        for name, param in self._name_param_list:
            if param_name == name:
                return param
        raise KeyError("Cannot find param: {}".format(param_name))

    def forward(self, *inputs):
        """apply the module on given inputs

        :return: output vars
        """
        # Params are appended after the real inputs; the opr's input_cnt
        # tells them apart.
        param_inputs = [param._symvar for param in self._param_inputs]
        inputs = [tensor._symvar for tensor in list(inputs)] + param_inputs
        out = PyTorchSubgraphImplOpr.make(
            *inputs,
            input_cnt=len(inputs) - len(param_inputs),
            output_cnt=self.__output_cnt,
            func=self.__torch_module.forward,
            params=self._torch_params,
            infer_shape=self.__infer_shape,
        )
        if isinstance(out, mgb.SymbolVar):
            return Tensor(out)
        # collections.abc.Iterable: the bare collections.Iterable alias was
        # deprecated in Python 3.3 and removed in 3.10.
        assert isinstance(out, collections.abc.Iterable)
        return [Tensor(sym) for sym in out]

    def get_device(self):
        """get the device this module belongs to"""
        return self.__device

    def set_device(self, device: mgb.CompNode):
        """set the device and move torch module to corresponding device"""
        touch_device = device_to_torch_device(device)
        self.__torch_module.to(device=touch_device)
        self.__device = device
        # Re-forward params: their storage moved with the torch module.
        self.init_params()

    device = property(get_device, set_device)
@@ -1,148 +0,0 @@ | |||
/** | |||
* \file python_module/megengine/module/pytorch/torch_mem_fwd.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "torch/extension.h" | |||
#include "megbrain_pubapi.h" | |||
using MGBTensor = mgb::pubapi::DeviceTensor; | |||
torch::Tensor mgb_to_torch(const MGBTensor *src) {
    // Borrow the MegBrain tensor's storage: forward_to hands out the raw
    // pointer plus a one-shot callback that releases the reference.
    mgb::pubapi::CallbackOnce deleter;
    void* tensor_raw_ptr;
    src->forward_to(&tensor_raw_ptr, &deleter);
    // `mutable` because consume() is non-const on the captured copy; the
    // callback fires when the torch tensor is destroyed.
    auto deleter_wrap = [deleter](void*) mutable {
        deleter.consume();
    };
    // TODO: support non-contiguous layout
    std::vector<int64_t> sizes;
    for (size_t i = 0; i < src->desc.ndim; ++ i) {
        sizes.push_back(src->desc.shape[i]);
    }
    // Translate the MegBrain dtype into the matching torch TypeMeta.
    torch::TensorOptions options;
    switch (src->desc.dtype) {
#define map_dtype(mgb_dtype, torch_dtype) \
    case MGBTensor::DataType::mgb_dtype: \
        options = options.dtype(caffe2::TypeMeta::Make<torch_dtype>()); \
        break;
        map_dtype(FLOAT32, float);
        map_dtype(FLOAT16, torch::Half);
        map_dtype(INT32, int);
        map_dtype(INT16, int16_t);
        map_dtype(INT8, int8_t);
        map_dtype(UINT8, uint8_t);
#undef map_dtype
        default:
            throw std::runtime_error("bad case for data type.");
    }
    // Place the torch tensor on the same physical device as the source.
    // TODO: Maybe we should impl copy on different devices?
    switch (src->desc.type) {
        case MGBTensor::Type::CUDA: {
            int device_id = src->desc.cuda_ctx.device;
            if (device_id >= 0) {
                options = options.device(torch::DeviceType::CUDA, device_id);
            } else {
                throw std::runtime_error("bad case for device(cuda) id.");
            }
            // TODO: consider cuda synchronization here
            // Maybe all tasks issued on cuda_ctx(device, stream) should be done?
            break;
        }
        case MGBTensor::Type::CPU:
            options = options.device(torch::DeviceType::CPU);
            // Torch's API are all synchronous.
            src->sync();
            break;
        default:
            throw std::runtime_error("bad case for device type.");
    }
    // Wrap the borrowed storage without copying; deleter_wrap releases the
    // MegBrain reference once torch frees the tensor.
    auto tensor = torch::from_blob(tensor_raw_ptr, sizes, deleter_wrap, options);
    return tensor;
}
// Forward the storage of a contiguous torch tensor into the MegBrain tensor
// *dst* without copying. A heap-allocated copy of `src` keeps the torch
// storage alive until MegBrain invokes the deleter callback.
void torch_to_mgb(MGBTensor* dst, torch::Tensor src) {
    MGBTensor::Desc desc;
    desc.dev_ptr = src.data_ptr();
    // src is contiguous torch tensor here, so no strides needed
    std::vector<size_t> shape;
    // desc.shape is the pointer to a size array used to construct
    // an inner-mgb tensor, which should be valid until calling of
    // forward_other_memory return
    for (auto &&i : src.sizes()) {
        shape.push_back(i);
    }
    desc.shape = shape.data();
    desc.ndim = shape.size();
    // Translate the torch scalar type into the MegBrain dtype enum.
    switch (src.scalar_type()) {
#define map_dtype(mgb_dtype, torch_dtype) \
    case torch::ScalarType::torch_dtype: \
        desc.dtype = MGBTensor::DataType::mgb_dtype; \
        break;
        map_dtype(FLOAT32, Float);
        map_dtype(FLOAT16, Half);
        map_dtype(INT32, Int);
        map_dtype(INT16, Short);
        map_dtype(INT8, Char);
        map_dtype(UINT8, Byte);
#undef map_dtype
        default:
            throw std::runtime_error("bad case for data type.");
    }
    // TODO: cuda setting and synchronization like mgb_to_torch
    if (src.device().type() == torch::DeviceType::CUDA) {
        desc.type = MGBTensor::Type::CUDA;
        desc.cuda_ctx.device = src.get_device();
        desc.cuda_ctx.stream = nullptr;
    } else {
        assert(src.device().type() == torch::DeviceType::CPU);
        // Bug fix: this branch previously set Type::CUDA, mislabeling every
        // CPU tensor handed back to MegBrain.
        desc.type = MGBTensor::Type::CPU;
    }
    mgb::pubapi::CallbackOnce deleter;
    // The deleter owns a copy of `src`, pinning the torch storage until
    // MegBrain is done with the forwarded memory.
    deleter.user_data = new torch::Tensor(src);
    deleter.fptr = [](void* ptr) {
        delete static_cast<torch::Tensor*>(ptr);
    };
    dst->forward_other_memory(desc, deleter);
}
torch::Tensor inp_mem_fwd(uintptr_t dv_ptr) {
    // construct torch Tensor from mgb DeviceTensor stored in dv_ptr.
    // \p dv_ptr: address of a mgb::pubapi::DeviceTensor, passed from Python
    // as a plain integer (pubapi_dev_tensor_ptr).
    return mgb_to_torch(reinterpret_cast<MGBTensor*>(dv_ptr));
}
void oup_mem_fwd(uintptr_t dv_ptr, torch::Tensor src,
                 bool keep_data_ptr=false) {
    // forward storage in torch Tensor to mgb DeviceTensor
    // keep_data_ptr: set to True to ensure forwarding data_ptr under \p src
    // to megbrain, or it maybe copy src to a new contiguous tensor storage.
    // contiguous() returns src itself when the tensor is already contiguous,
    // in which case the data pointer is unchanged.
    auto src_contig = src.contiguous();
    if (keep_data_ptr && src_contig.data_ptr() != src.data_ptr()) {
        throw std::runtime_error("should keep tensor data ptr, but it changed");
    }
    torch_to_mgb(reinterpret_cast<MGBTensor*>(dv_ptr), src_contig);
}
// Python bindings for the two memory-forwarding entry points.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("inp_mem_fwd", &inp_mem_fwd, "Forward mgb DeviceTensor ptr into torch Tensor as network input.");
    // NOTE(review): keep_data_ptr defaults to false here, while the Python
    // wrapper oup_mem_fwd declares a default of True (it always passes the
    // flag explicitly, so behavior matches) — confirm intended.
    m.def("oup_mem_fwd", &oup_mem_fwd, "Forward torch network Tensor to corresponding mgb VarNode.",
          py::arg("dv_ptr"), py::arg("src"), py::arg("keep_data_ptr") = false);
}
@@ -1,67 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
import torch | |||
import megengine._internal as mgb | |||
# Mapping from torch dtypes to the numpy dtypes MegEngine understands.
# torch.uint8 added for consistency: the torch_mem_fwd.cpp extension already
# supports the UINT8 case in both conversion directions.
_TORCH_NUMPY_MAPPING = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.int8: np.int8,
    torch.int16: np.int16,
    torch.int32: np.int32,
    torch.uint8: np.uint8,
}


def torch_dtype_to_numpy_dtype(torch_dtype: torch.dtype):
    """map torch dtype to numpy dtype

    :param torch_dtype: torch dtype
    :return: numpy dtype
    :raises TypeError: if `torch_dtype` is not a torch.dtype instance
    :raises ValueError: if the dtype has no known numpy equivalent
    """
    if not isinstance(torch_dtype, torch.dtype):
        raise TypeError("Argument `torch_dtype` should be an instance of torch.dtype")
    if torch_dtype not in _TORCH_NUMPY_MAPPING:
        raise ValueError("Unknown PyTorch dtype: {}".format(torch_dtype))
    return _TORCH_NUMPY_MAPPING[torch_dtype]
def torch_device_to_device(device: torch.device):
    """Map a torch device to a megengine device string.

    :param device: torch device
    :return: device string such as ``"cpu0"`` or ``"gpu1"``; an unset device
        index is rendered as ``"x"`` (e.g. ``"cpux"``).
    :raises TypeError: if `device` is not a torch.device
    :raises ValueError: if the device type is neither cpu nor cuda
    """
    if not isinstance(device, torch.device):
        raise TypeError("Argument `device` should be an instance of torch.device")
    suffix = "x" if device.index is None else device.index
    prefix_by_type = {"cpu": "cpu", "cuda": "gpu"}
    if device.type in prefix_by_type:
        return "{}{}".format(prefix_by_type[device.type], suffix)
    raise ValueError("Unknown PyTorch device: {}".format(device))
def device_to_torch_device(device: mgb.CompNode):
    """Map a megbrain compute node to the corresponding torch device.

    :param device: megbrain compute node
    :return: corresponding torch device
    :raises Exception: if the physical device type is neither CUDA nor CPU
    """
    dev_type, dev_index, _ = device.locator_physical
    torch_type_by_mgb = {"CUDA": "cuda", "CPU": "cpu"}
    if dev_type not in torch_type_by_mgb:
        raise Exception("Unsupported device type: {}".format(dev_type))
    return torch.device(torch_type_by_mgb[dev_type], dev_index)
@@ -1,14 +0,0 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .concat import Concat | |||
from .conv import Conv2d, ConvRelu2d | |||
from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
from .elemwise import Elemwise | |||
from .linear import Linear | |||
from .module import QATModule | |||
from .quant_dequant import DequantStub, QuantStub |