GitOrigin-RevId: 11411b6964
tags/v1.0.0-rc1
@@ -247,10 +247,6 @@ if(MGE_BUILD_IMPERATIVE_RT) | |||
set(CMAKE_CXX_STANDARD 17) | |||
endif() | |||
if(MGE_BUILD_IMPERATIVE_RT) | |||
set(MGE_BUILD_SDK OFF) | |||
endif() | |||
if(NOT MGE_WITH_CUDA) | |||
message("-- Disable distributed support, as CUDA is not enabled.") | |||
set(MGE_WITH_DISTRIBUTED OFF) | |||
@@ -697,9 +693,7 @@ if(MGE_WITH_PYTHON_MODULE) | |||
endif() | |||
if(MGE_WITH_TEST AND MGE_ENABLE_RTTI) | |||
if(NOT MGE_BUILD_IMPERATIVE_RT) | |||
add_subdirectory(test) | |||
endif() | |||
add_subdirectory(test) | |||
endif() | |||
if(TARGET mgb) | |||
@@ -66,9 +66,7 @@ if(MGE_WITH_CUDA) | |||
endif() | |||
if(MGE_WITH_TEST) | |||
if(NOT MGE_BUILD_IMPERATIVE_RT) | |||
add_subdirectory(test) | |||
endif() | |||
add_subdirectory(test) | |||
endif() | |||
add_subdirectory(src) | |||
@@ -0,0 +1,5 @@ | |||
Makefile | |||
/test/imperative_test | |||
*.so | |||
/python/megengine/core/ops/_internal/generated_ops.py | |||
/python/megengine/core/ops/_internal/param_defs.py |
@@ -0,0 +1,110 @@ | |||
find_package(NumPy REQUIRED) | |||
set(PACKAGE_NAME megengine) | |||
set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE) | |||
set(MODULE_NAME _imperative_rt) | |||
set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) | |||
file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") | |||
file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl") | |||
file(GLOB_RECURSE PYTHON_SRCS python/${PACKAGE_NAME}/*.py) | |||
list(REMOVE_ITEM PYTHON_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/generated_ops.py ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/param_defs.py) | |||
file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h | |||
${PROJECT_SOURCE_DIR}/src/core/include/* | |||
${PROJECT_SOURCE_DIR}/src/opr/include/* | |||
${PROJECT_SOURCE_DIR}/src/serialization/include/* | |||
${PROJECT_SOURCE_DIR}/src/plugin/include/* | |||
${PROJECT_SOURCE_DIR}/dnn/include/*) | |||
set(MEGENGINE_DIR ${CMAKE_CURRENT_BINARY_DIR}/python/) | |||
set(GEN_OPS_DIR ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal) | |||
file(MAKE_DIRECTORY ${GEN_OPS_DIR}) | |||
set(GEN_OPS_FILE ${GEN_OPS_DIR}/generated_ops.py) | |||
set(GEN_OP_PARAMS_FILE ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal/param_defs.py) | |||
set(GEN_OP_PARAMS_TEMPLATE ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/ops.tpl.py) | |||
##################### generate python opr_param_defs.py ############## | |||
file(COPY ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) | |||
file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS) | |||
file(APPEND ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${CONTENTS}) | |||
add_custom_command( | |||
OUTPUT ${GEN_OPS_FILE} | |||
COMMAND ${CMAKE_COMMAND} -E touch ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE} | |||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/${PACKAGE_NAME} ${MEGENGINE_DIR}/${PACKAGE_NAME} | |||
COMMAND ${CMAKE_COMMAND} -E remove -f ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE} | |||
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/gen_ops.py ${OPR_DECL_SRCS} -o ${GEN_OPS_FILE} | |||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/test ${MEGENGINE_DIR}/${PACKAGE_NAME}/test | |||
COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_param_defs.py -t py --imperative ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${GEN_OP_PARAMS_FILE} | |||
DEPENDS ${OPR_DECL_SRCS} ${PYTHON_SRCS} ${ALL_HEADERS} ${GEN_OP_PARAMS_TEMPLATE} | |||
VERBATIM | |||
) | |||
add_custom_target(gen_opr_py DEPENDS ${GEN_OPS_FILE}) | |||
##################### generate opdef c header and python binding ############## | |||
set(OP_DEF_HEADER_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/src/include) | |||
file(MAKE_DIRECTORY ${OP_DEF_HEADER_OUT_DIR}/megbrain/imperative/opdef) | |||
set(OP_DEF_HEADER ${OP_DEF_HEADER_OUT_DIR}/megbrain/imperative/opdef/all.h) | |||
set(OP_DEF_PYTHON_BINDING_OUT_DIR ${MEGENGINE_DIR}/${PACKAGE_NAME}/src) | |||
file(MAKE_DIRECTORY ${OP_DEF_PYTHON_BINDING_OUT_DIR}) | |||
set(OP_DEF_PYTHON_BINDING ${OP_DEF_PYTHON_BINDING_OUT_DIR}/opdef.inl) | |||
set(OP_PARAM_DEF ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py) | |||
set(GEN_OP_DEF_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/gen_op_defs.py) | |||
add_custom_command( | |||
OUTPUT ${OP_DEF_HEADER} ${OP_DEF_PYTHON_BINDING} | |||
COMMAND ${PYTHON_EXECUTABLE} ${GEN_OP_DEF_SCRIPT} ${OP_PARAM_DEF} ${OP_DEF_HEADER} | |||
COMMAND ${PYTHON_EXECUTABLE} ${GEN_OP_DEF_SCRIPT} -t py ${OP_PARAM_DEF} ${OP_DEF_PYTHON_BINDING} | |||
DEPENDS ${GEN_OP_DEF_SCRIPT} ${OP_PARAM_DEF} | |||
VERBATIM | |||
) | |||
add_custom_target(gen_op_def_internal DEPENDS ${OP_DEF_HEADER} ${OP_DEF_PYTHON_BINDING}) | |||
add_library(gen_op_def INTERFACE) | |||
target_include_directories(gen_op_def INTERFACE ${OP_DEF_HEADER_OUT_DIR} ${OP_DEF_PYTHON_BINDING_OUT_DIR}) | |||
add_dependencies(gen_op_def gen_op_def_internal) | |||
##################### end of opdef generation ######################### | |||
set(VERSION_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/version.ld) | |||
add_custom_target(_version_ld SOURCES ${VERSION_SCRIPT}) | |||
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/pybind11 ${PROJECT_BINARY_DIR}/third_party/pybind11) | |||
pybind11_add_module(${MODULE_NAME} NO_EXTRAS ${SRCS}) | |||
target_link_libraries(${MODULE_NAME} PRIVATE gen_op_def megbrain megdnn -Wl,--version-script=${VERSION_SCRIPT}) | |||
if (MGE_WITH_DISTRIBUTED) | |||
message("Imperative configured to link megray") | |||
target_link_libraries(${MODULE_NAME} PRIVATE megray) | |||
endif() | |||
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | |||
target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) | |||
target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) | |||
if(CXX_SUPPORT_WCLASS_MEMACCESS) | |||
target_compile_options(${MODULE_NAME} PRIVATE "-Wno-class-memaccess") | |||
endif() | |||
set_target_properties(${MODULE_NAME} PROPERTIES | |||
SUFFIX ${CMAKE_SHARED_LIBRARY_SUFFIX} | |||
LIBRARY_OUTPUT_DIRECTORY ${MEGENGINE_DIR}/${PACKAGE_NAME}/core | |||
) | |||
add_dependencies(${MODULE_NAME} gen_opr_py _version_ld) | |||
if(MGE_WITH_TEST AND MGE_ENABLE_RTTI) | |||
add_subdirectory(test) | |||
endif() | |||
add_custom_command( | |||
TARGET ${MODULE_NAME} POST_BUILD | |||
COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/LICENSE ${PROJECT_SOURCE_DIR}/ACKNOWLEDGMENTS ${PROJECT_BINARY_DIR} | |||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine ${CMAKE_CURRENT_BINARY_DIR}/python/megengine | |||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/test ${CMAKE_CURRENT_BINARY_DIR}/python/test | |||
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/setup.py ${CMAKE_CURRENT_BINARY_DIR}/python/setup.py | |||
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/requires.txt ${CMAKE_CURRENT_BINARY_DIR}/python/requires.txt | |||
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/requires-style.txt ${CMAKE_CURRENT_BINARY_DIR}/python/requires-style.txt | |||
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/requires-test.txt ${CMAKE_CURRENT_BINARY_DIR}/python/requires-test.txt | |||
) | |||
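The net effect of the rules above is a staged Python package under CMAKE_CURRENT_BINARY_DIR/python, with the pybind11 extension placed at megengine/core/_imperative_rt.so and the generated sources under megengine/core/ops/_internal. As a sanity-check sketch (import paths taken from the Python sources later in this diff; it assumes a completed build and that the staged python/ directory is on sys.path):

```python
# Run from the staged python/ directory after a successful build.
from megengine.core import _imperative_rt                 # the pybind11 module
from megengine.core.ops._internal import generated_ops    # emitted by gen_ops.py
from megengine.core.ops._internal import param_defs       # emitted by gen_param_defs.py
```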
@@ -0,0 +1,25 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import os | |||
import sys | |||
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func | |||
from .device import * | |||
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level | |||
from .serialization import load, save | |||
from .tensor import Tensor, tensor | |||
from .tensor_nn import Buffer, Parameter | |||
from .version import __version__ | |||
_set_fork_exec_path_for_timed_func( | |||
sys.executable, | |||
os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | |||
) | |||
del _set_fork_exec_path_for_timed_func |
@@ -0,0 +1,12 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import os | |||
import sys | |||
from .tensor import Tensor |
@@ -0,0 +1,46 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
from ._imperative_rt import CompNode | |||
class Device: | |||
def __init__(self, device=None): | |||
if device is None: | |||
self._cn = CompNode() | |||
elif isinstance(device, Device): | |||
self._cn = device._cn | |||
elif isinstance(device, CompNode): | |||
self._cn = device | |||
else: | |||
self._cn = CompNode(device) | |||
def to_c(self): | |||
return self._cn | |||
def __repr__(self): | |||
return "{}({})".format(type(self).__qualname__, self) | |||
def __str__(self): | |||
return str(self._cn) | |||
def __hash__(self): | |||
return hash(str(self._cn)) | |||
def __eq__(self, rhs): | |||
if not isinstance(rhs, Device): | |||
rhs = Device(rhs) | |||
return str(self._cn) == str(rhs._cn) | |||
def device(obj): | |||
if isinstance(obj, Device): | |||
return obj | |||
return Device(obj) |
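A brief usage sketch of the Device wrapper above. The string "cpu0" is an assumed comp-node description (not fixed by this diff), and running this requires the built _imperative_rt extension; the point is the coercion, equality, and hashing behavior defined above:

```python
# Hypothetical usage of Device / device(); "cpu0" is an assumed comp-node string.
d0 = Device("cpu0")           # wraps CompNode("cpu0")
d1 = device(d0)               # device() returns Device instances unchanged
assert d1 is d0
assert d0 == Device("cpu0")   # __eq__ compares the underlying comp nodes as strings
table = {d0: "params"}        # __hash__ makes Device usable as a dict key
```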
@@ -0,0 +1,8 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
@@ -0,0 +1,134 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import functools | |||
import itertools | |||
import numpy as np | |||
from .._imperative_rt import TensorAttr, imperative | |||
from ..ops.builtin import Elemwise, GetVarShape, OpDef, OprAttr, Reduce, Reshape | |||
from ..tensor.core import apply | |||
from ..tensor.function import Function | |||
@functools.singledispatch | |||
def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad): | |||
assert 0 | |||
_elemwise_add_param = Elemwise(mode="add").to_c().param | |||
@builtin_op_get_backward_fn.register(OpDef) | |||
def _(op: OpDef, inputs, outputs, input_requires_grad): | |||
if ( | |||
isinstance(op, OprAttr) | |||
and op.type == "Elemwise" | |||
and op.param == _elemwise_add_param | |||
): | |||
grad_fn = elemwise_grad_fn | |||
elif isinstance(op, OprAttr) and op.type == Reshape.name: | |||
grad_fn = reshape_grad_fn | |||
else: | |||
grad_fn = default_grad_fn | |||
return grad_fn(op, inputs, outputs, input_requires_grad) | |||
@builtin_op_get_backward_fn.register(Function) | |||
def _(op: Function, inputs, outputs, input_requires_grad): | |||
return op.get_backward_fn(), [True,] * len(outputs) | |||
def default_grad_fn(op, inputs, outputs, input_requires_grad): | |||
def get_tensor_attr(x): | |||
attr = TensorAttr() | |||
attr.dtype = x.dtype | |||
attr.comp_node = x.device.to_c() | |||
return attr | |||
output_has_grads = [True,] * len(outputs) | |||
result = imperative.make_backward_graph( | |||
op, list(map(get_tensor_attr, inputs)), input_requires_grad, output_has_grads | |||
) | |||
if result is None: | |||
nr_inputs = len(inputs) | |||
nr_outputs = len(outputs) | |||
def backward(*args): | |||
return nr_inputs * [ | |||
None, | |||
] | |||
return backward, nr_outputs * [False,] | |||
backward_graph, save_for_backward_mask, input_has_grad = result | |||
input_output_mask = save_for_backward_mask[: len(inputs + outputs)] | |||
output_grad_mask = save_for_backward_mask[len(inputs + outputs) :] | |||
save_for_backward = tuple( | |||
val for val, mask in zip(inputs + outputs, input_output_mask) if mask | |||
) | |||
del inputs | |||
del outputs | |||
def backward(*args): | |||
output_grads = tuple(val for val, mask in zip(args, output_grad_mask) if mask) | |||
assert None not in output_grads | |||
ret = iter(apply(backward_graph, *(save_for_backward + output_grads))) | |||
return tuple(next(ret) if mask else None for mask in input_has_grad) | |||
return backward, output_grad_mask | |||
# override for elemwise | |||
def elemwise_grad_fn(op, inputs, outputs, input_requires_grad): | |||
assert len(inputs) == len(input_requires_grad) == 2 | |||
def get_shape(x): | |||
(s,) = apply(GetVarShape(), x) | |||
return s | |||
input_shapes = [ | |||
get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs) | |||
] | |||
def reduce_to(x, s): | |||
(y,) = apply(Reduce(), x, s) | |||
return y | |||
def backward(dy): | |||
return tuple( | |||
reduce_to(dy, s) if i else None | |||
for i, s in zip(input_requires_grad, input_shapes) | |||
) | |||
return backward, [True] | |||
def reshape_grad_fn(op, inputs, outputs, input_requires_grad): | |||
assert len(inputs) == len(input_requires_grad) == 2 | |||
def get_shape(x): | |||
(s,) = apply(GetVarShape(), x) | |||
return s | |||
input_shapes = [ | |||
get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs) | |||
] | |||
def reshape_to(dy, s): | |||
(dx,) = apply(Reshape(), dy, s) | |||
return dx | |||
def backward(dy): | |||
return tuple( | |||
reshape_to(dy, s) if i else None | |||
for i, s in zip(input_requires_grad, input_shapes) | |||
) | |||
return backward, [True] |
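The two overrides above record input shapes during forward and, in backward, map the output gradient back to each input's shape: Reshape's gradient is a reshape, and elemwise add's gradient must be reduced over any broadcast axes. A plain NumPy sketch of that reduction rule (illustrative only; the real code expresses it as apply(Reduce(), dy, shape)):

```python
import numpy as np

def reduce_to(dy, shape):
    # Sum dy over axes that were broadcast in the forward add, so the
    # gradient matches the input's original shape.
    dy = np.asarray(dy)
    while dy.ndim > len(shape):        # collapse extra leading dims
        dy = dy.sum(axis=0)
    for axis, n in enumerate(shape):   # collapse dims that were size 1
        if n == 1 and dy.shape[axis] != 1:
            dy = dy.sum(axis=axis, keepdims=True)
    return dy

a, b = np.ones((3, 4)), np.ones((1, 4))
dy = np.ones_like(a + b)               # gradient of the broadcast sum
assert reduce_to(dy, a.shape).shape == (3, 4)
assert reduce_to(dy, b.shape).shape == (1, 4)
```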
@@ -0,0 +1,390 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import functools | |||
import heapq | |||
import itertools | |||
import typing | |||
import weakref | |||
import numpy as np | |||
from ..ops.builtin import Elemwise, OpDef | |||
from ..ops.special import Const | |||
from ..tensor.core import TensorBase, TensorWrapperBase, apply | |||
from ..tensor.function import Function | |||
from ..tensor.tensor import Tensor, get_context | |||
from . import builtin_op_utils | |||
""" Some notes: | |||
1. Initialize the optimizer: | |||
for each trainable parameter: | |||
call wrt(param, callback) | |||
Each parameter tensor will be assciated with a Tracer object saved in Tensor._extra_data | |||
2. Tracer has one member: node, which is a VariableNode | |||
3. VariableNode has a OpNode member: opnode | |||
4. OpNode has four members: | |||
a. id | |||
b. inputs, which is made of VariableNode | |||
c. outputs, which are weakref's to VariableNode | |||
d. backward: call back function | |||
e. has_grad_fn: call has_grad_fn(opnode, reached) to check grad exist | |||
f. backward_allow_noinput: whether backward allow noinput | |||
""" | |||
_grad_count = 0 | |||
_grad_manager_dict = weakref.WeakValueDictionary() | |||
def get_grad_managers(): | |||
return [_grad_manager_dict[key] for key in _grad_manager_dict] | |||
def add(a, b): | |||
(c,) = apply(Elemwise(mode="add"), a, b) | |||
return c | |||
def get_tensor(x): | |||
# recursively unwrap __wrapped__ until a Tensor is reached | |||
if isinstance(x, Tensor): | |||
return x | |||
try: | |||
x = x.__wrapped__ | |||
except AttributeError: | |||
raise TypeError(type(x)) | |||
return get_tensor(x) | |||
class Grad: | |||
def __init__(self, name=None): | |||
if name is None: | |||
global _grad_count | |||
self._name = "grad_" + str(_grad_count) | |||
_grad_count += 1 | |||
else: | |||
self._name = name | |||
assert self._name not in _grad_manager_dict, "grad manager name duplicated" | |||
_grad_manager_dict[self._name] = self | |||
# list of all x in partial(y) / partial(x) | |||
self.xs = [] | |||
# contains weak references to all OpNodes created during forward | |||
# each OpNode holds its inputs, outputs and its backward function | |||
# together these OpNodes form the computational graph | |||
self.ops = [] | |||
self._enabled = True | |||
@property | |||
def name(self): | |||
return self._name | |||
def wrt(self, *args: Tensor, callback=None): | |||
""" Indicates the loss is a function of the input tensors (usually the net trainable parameters), | |||
i.e., d (loss) / d (Tensor) != 0 | |||
callback is used to perform additional operations after gradient is obtained in backward. | |||
e.g., copy the grad to a particular place | |||
A VariableNode will be created and saved in the tensor/s _extra_data slot. | |||
""" | |||
for x in map(get_tensor, args): | |||
v = self._new_variable(x, callback=callback) | |||
assert self not in x._extra_data | |||
x._extra_data[self] = Tracer(v) | |||
self.xs.append(v) | |||
return self | |||
def _new_variable(self, owner, opnode=None, callback=None): | |||
return VariableNode(self, owner, opnode=opnode, callback=callback) | |||
def _new_opnode(self, inputs, outputs): | |||
inputs = tuple(inputs) | |||
for i in inputs: | |||
assert i is None or isinstance(i, VariableNode) | |||
o = OpNode() | |||
o.inputs = inputs | |||
o.outputs = [] | |||
tracers = [] | |||
for i in outputs: | |||
assert isinstance(i, Tensor) | |||
v = self._new_variable(i, o) | |||
o.outputs.append(weakref.ref(v)) | |||
tracers.append(Tracer(v)) | |||
self.ops.append(weakref.ref(o)) | |||
return o, tracers | |||
def copy(self): | |||
raise NotImplementedError | |||
def __enter__(self): | |||
return self | |||
def __exit__(self, *_): | |||
"""clear all resources""" | |||
self._enabled = False | |||
for o in self.ops: | |||
o = o() | |||
if o: | |||
o.clear() | |||
def __call__(self, ys, dys): | |||
""" Defines Grad(). | |||
:param ys: outputs of forward operators, e.g., the loss tensor | |||
:type ys: list of Tensor or TensorWrapperBase | |||
:param dys: delta of outputs, physically equivalent to sensitivity of outputs to the loss, | |||
e.g., one for the loss itself | |||
:type dys: list of Tensor or TensorWrapperBase | |||
""" | |||
assert self._enabled | |||
self._enabled = False | |||
def check_wrapper(): | |||
if isinstance(dys, TensorWrapperBase): | |||
return type(dys) | |||
if isinstance(dys, TensorBase): | |||
return | |||
assert isinstance(dys, (tuple, list)) | |||
for i in dys: | |||
if isinstance(i, TensorWrapperBase): | |||
return type(i) | |||
Wrapper = check_wrapper() | |||
def aslist(x): | |||
if isinstance(x, (Tensor, TensorWrapperBase)): | |||
x = [x] | |||
else: | |||
x = list(x) | |||
x = [i.__wrapped__ if isinstance(i, TensorWrapperBase) else i for i in x] | |||
for i in x: | |||
assert isinstance(i, Tensor) | |||
return x | |||
ys = aslist(ys) | |||
dys = aslist(dys) | |||
assert len(ys) == len(dys) | |||
# ys is changed to a list of VariableNode which contains more information | |||
# such as OpNode, callback, etc. | |||
ys = [i._extra_data[self].node for i in ys] | |||
# NOTE: callback is called only if grad is not None | |||
# the OpNode sequence in backward | |||
op_seq = [] | |||
# VariableNode -> (i, j), where i is the time stamp in backward and j is the index of the input | |||
last_written_to = {} | |||
def schedule(): | |||
reached = set(ys) | |||
# i is the time stamp in backward | |||
i = 0 | |||
for o in self.ops[::-1]: | |||
o = o() | |||
if o is None: | |||
continue | |||
if not o.has_grad_fn(o, reached): | |||
continue | |||
op_seq.append(o) | |||
for j, v in enumerate(o.inputs): | |||
reached.add(v) | |||
last_written_to[v] = i, j | |||
i += 1 | |||
schedule() | |||
# VariableNode -> Tensor | |||
cache = {} | |||
def initialize(): | |||
for y, dy in zip(ys, dys): | |||
cache[y] = dy | |||
if y not in last_written_to and y.callback: | |||
y.callback(y.owner(), dy) | |||
initialize() | |||
# NOTE: None is used to mark that a node has been consumed | |||
for seqno, opnode in enumerate(op_seq): | |||
input_nodes = opnode.inputs | |||
output_nodes = [i() for i in opnode.outputs] | |||
backward = opnode.backward | |||
backward_allow_noinput = opnode.backward_allow_noinput | |||
opnode.clear() | |||
output_grads = [] | |||
for i in output_nodes: | |||
if i is not None: | |||
if i in cache: | |||
assert cache[i] is not None | |||
output_grads.append(cache[i]) | |||
else: | |||
output_grads.append(None) | |||
# read by backward, mark consumed | |||
cache[i] = None | |||
else: | |||
output_grads.append(None) | |||
if ( | |||
any([grad is not None for grad in output_grads]) | |||
or backward_allow_noinput | |||
): | |||
input_grads = backward(*output_grads) | |||
else: | |||
input_grads = [None] * len(input_nodes) | |||
assert len(input_nodes) == len(input_grads) | |||
for i, (v, g) in enumerate(zip(input_nodes, input_grads)): | |||
if v is None: | |||
continue | |||
if v in cache: | |||
assert cache[v] | |||
if g is not None: | |||
cache[v] = add(cache[v], g) | |||
elif g is not None: | |||
cache[v] = g | |||
if last_written_to[v] == (seqno, i): | |||
if v.callback: | |||
v.callback( | |||
v.owner(), Wrapper(cache[v]) if Wrapper else cache[v] | |||
) | |||
if v.opnode is None: | |||
# won't be read by backward, mark consumed | |||
cache[v] = None | |||
for v in cache.values(): | |||
assert v is None | |||
class clearable: | |||
__cleared = False | |||
def __bool__(self): | |||
return not self.__cleared | |||
def clear(self): | |||
self.__dict__.clear() | |||
self.__cleared = True | |||
class OpNode(clearable): | |||
""" OpNode saves all the information to form the computational graph. | |||
""" | |||
def __init__(self): | |||
self.id = None | |||
self.inputs = None # Could be VariableNode | |||
self.outputs = None # Could be VariableNode | |||
self.backward = None | |||
self.has_grad_fn = None | |||
self.backward_allow_noinput = False | |||
class VariableNode(clearable): | |||
""" VariableNode saves OpNode and callback. | |||
FIXME!!! Explain manager and owner | |||
""" | |||
def __init__(self, manager, owner, opnode=None, callback=None): | |||
# manager is Grad type | |||
self.manager = weakref.ref(manager) | |||
# owner is Tensor type | |||
self.owner = weakref.ref(owner) | |||
self.opnode = opnode | |||
self.callback = callback | |||
class Tracer(clearable, TensorBase): | |||
def __init__(self, node=None): | |||
""" type(node) is VariableNode | |||
""" | |||
self.node = node | |||
@functools.singledispatch | |||
def check_backward_allow_noinput(op: OpDef): | |||
return False | |||
@functools.singledispatch | |||
def get_op_has_grad_fn(op: OpDef): | |||
assert 0 | |||
@get_op_has_grad_fn.register(OpDef) | |||
def _(op: OpDef): | |||
return default_has_grad_fn | |||
@get_op_has_grad_fn.register(Function) | |||
def _(op: Function): | |||
return default_has_grad_fn | |||
def default_has_grad_fn(opnode, reached): | |||
for v in opnode.outputs: | |||
if v() in reached: | |||
return True | |||
return False | |||
@apply.add | |||
def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]): | |||
args = tuple(i if isinstance(i, Tracer) else None for i in args) | |||
input_requires_grad = list(map(bool, args)) | |||
if not any(input_requires_grad): | |||
return | |||
ctx = get_context() | |||
manager = None | |||
assert len(ctx.inputs) == len(args) | |||
for i, j in zip(ctx.inputs, args): | |||
if j: | |||
j = j.node | |||
assert i is j.owner() | |||
if manager is None: | |||
manager = j.manager() | |||
assert manager | |||
else: | |||
assert manager is j.manager() | |||
if not manager._enabled: | |||
return | |||
opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs) | |||
# register backward method | |||
# tuple of backward functions corresponding to dy / dx_i | |||
# None means y is not a function of x_i | |||
opnode.backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn( | |||
op, ctx.inputs, ctx.outputs, input_requires_grad | |||
) | |||
assert len(outputs) == len(output_need_grad) | |||
outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)] | |||
opnode.backward_allow_noinput = check_backward_allow_noinput(op) | |||
opnode.has_grad_fn = get_op_has_grad_fn(op) | |||
return tuple(outputs) | |||
@apply.add | |||
def _(op: Const, *_: typing.Optional[Tracer]): | |||
return None |
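Putting the pieces of this file together, a minimal hypothetical use of the Grad manager looks like the sketch below; x, forward, and dy stand in for real imperative tensors and ops, so it is not runnable as-is. The callback passed to wrt() fires once the accumulated gradient for x is final:

```python
# A sketch under the assumptions above; placeholders: x, forward, dy.
grads = {}

def record(param, dparam):
    # callback passed to wrt(): called when d(loss)/d(param) is complete
    grads[param] = dparam

with Grad() as grad:
    grad.wrt(x, callback=record)   # declare d(loss)/d(x) != 0
    y = forward(x)                 # apply() records OpNodes into grad.ops
    grad(y, dy)                    # schedule and run the backward pass
# leaving the block clears all recorded OpNodes
```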
@@ -0,0 +1,8 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
@@ -0,0 +1,8 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
@@ -0,0 +1,10 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .generated_ops import * | |||
from .misc_ops import * |
@@ -0,0 +1,929 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import sys | |||
from functools import reduce | |||
from operator import or_ as _or_ | |||
from types import DynamicClassAttribute, MappingProxyType | |||
# try _collections first to reduce startup cost | |||
try: | |||
from _collections import OrderedDict | |||
except ImportError: | |||
from collections import OrderedDict | |||
__all__ = [ | |||
"EnumMeta", | |||
"Enum", | |||
"IntEnum", | |||
"Flag", | |||
"IntFlag", | |||
"auto", | |||
"unique", | |||
] | |||
def _is_descriptor(obj): | |||
"""Returns True if obj is a descriptor, False otherwise.""" | |||
return ( | |||
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") | |||
) | |||
def _is_dunder(name): | |||
"""Returns True if a __dunder__ name, False otherwise.""" | |||
return ( | |||
name[:2] == name[-2:] == "__" | |||
and name[2:3] != "_" | |||
and name[-3:-2] != "_" | |||
and len(name) > 4 | |||
) | |||
def _is_sunder(name): | |||
"""Returns True if a _sunder_ name, False otherwise.""" | |||
return ( | |||
name[0] == name[-1] == "_" | |||
and name[1:2] != "_" | |||
and name[-2:-1] != "_" | |||
and len(name) > 2 | |||
) | |||
def _make_class_unpicklable(cls): | |||
"""Make the given class un-picklable.""" | |||
def _break_on_call_reduce(self, proto): | |||
raise TypeError("%r cannot be pickled" % self) | |||
cls.__reduce_ex__ = _break_on_call_reduce | |||
cls.__module__ = "<unknown>" | |||
_auto_null = object() | |||
class auto: | |||
""" | |||
Instances are replaced with an appropriate value in Enum class suites. | |||
""" | |||
value = _auto_null | |||
class _EnumDict(dict): | |||
"""Track enum member order and ensure member names are not reused. | |||
EnumMeta will use the names found in self._member_names as the | |||
enumeration member names. | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
self._member_names = [] | |||
self._last_values = [] | |||
def __setitem__(self, key, value): | |||
"""Changes anything not dundered or not a descriptor. | |||
If an enum member name is used twice, an error is raised; duplicate | |||
values are not checked for. | |||
Single underscore (sunder) names are reserved. | |||
""" | |||
if _is_sunder(key): | |||
if key not in ( | |||
"_order_", | |||
"_create_pseudo_member_", | |||
"_generate_next_value_", | |||
"_missing_", | |||
): | |||
raise ValueError("_names_ are reserved for future Enum use") | |||
if key == "_generate_next_value_": | |||
setattr(self, "_generate_next_value", value) | |||
elif _is_dunder(key): | |||
if key == "__order__": | |||
key = "_order_" | |||
elif key in self._member_names: | |||
# descriptor overwriting an enum? | |||
raise TypeError("Attempted to reuse key: %r" % key) | |||
elif not _is_descriptor(value): | |||
if key in self: | |||
# enum overwriting a descriptor? | |||
raise TypeError("%r already defined as: %r" % (key, self[key])) | |||
if isinstance(value, auto): | |||
if value.value == _auto_null: | |||
value.value = self._generate_next_value( | |||
key, 1, len(self._member_names), self._last_values[:] | |||
) | |||
value = value.value | |||
self._member_names.append(key) | |||
self._last_values.append(value) | |||
super().__setitem__(key, value) | |||
# Dummy value for Enum as EnumMeta explicitly checks for it, but of course | |||
# until EnumMeta finishes running the first time the Enum class doesn't exist. | |||
# This is also why there are checks in EnumMeta like `if Enum is not None` | |||
Enum = None | |||
class EnumMeta(type): | |||
"""Metaclass for Enum""" | |||
@classmethod | |||
def __prepare__(metacls, cls, bases): | |||
# create the namespace dict | |||
enum_dict = _EnumDict() | |||
# inherit previous flags and _generate_next_value_ function | |||
member_type, first_enum = metacls._get_mixins_(bases) | |||
if first_enum is not None: | |||
enum_dict["_generate_next_value_"] = getattr( | |||
first_enum, "_generate_next_value_", None | |||
) | |||
return enum_dict | |||
def __new__(metacls, cls, bases, classdict): | |||
# an Enum class is final once enumeration items have been defined; it | |||
# cannot be mixed with other types (int, float, etc.) if it has an | |||
# inherited __new__ unless a new __new__ is defined (or the resulting | |||
# class will fail). | |||
member_type, first_enum = metacls._get_mixins_(bases) | |||
__new__, save_new, use_args = metacls._find_new_( | |||
classdict, member_type, first_enum | |||
) | |||
# save enum items into separate mapping so they don't get baked into | |||
# the new class | |||
enum_members = {k: classdict[k] for k in classdict._member_names} | |||
for name in classdict._member_names: | |||
del classdict[name] | |||
# adjust the sunders | |||
_order_ = classdict.pop("_order_", None) | |||
# check for illegal enum names (any others?) | |||
invalid_names = set(enum_members) & { | |||
"mro", | |||
} | |||
if invalid_names: | |||
raise ValueError( | |||
"Invalid enum member name: {0}".format(",".join(invalid_names)) | |||
) | |||
# create a default docstring if one has not been provided | |||
if "__doc__" not in classdict: | |||
classdict["__doc__"] = "An enumeration." | |||
# create our new Enum type | |||
enum_class = super().__new__(metacls, cls, bases, classdict) | |||
enum_class._member_names_ = [] # names in definition order | |||
enum_class._member_map_ = OrderedDict() # name->value map | |||
enum_class._member_type_ = member_type | |||
# save attributes from super classes so we know if we can take | |||
# the shortcut of storing members in the class dict | |||
base_attributes = {a for b in enum_class.mro() for a in b.__dict__} | |||
# Reverse value->name map for hashable values. | |||
enum_class._value2member_map_ = {} | |||
# If a custom type is mixed into the Enum, and it does not know how | |||
# to pickle itself, pickle.dumps will succeed but pickle.loads will | |||
# fail. Rather than have the error show up later and possibly far | |||
# from the source, sabotage the pickle protocol for this class so | |||
# that pickle.dumps also fails. | |||
# | |||
# However, if the new class implements its own __reduce_ex__, do not | |||
# sabotage -- it's on them to make sure it works correctly. We use | |||
# __reduce_ex__ instead of any of the others as it is preferred by | |||
# pickle over __reduce__, and it handles all pickle protocols. | |||
if "__reduce_ex__" not in classdict: | |||
if member_type is not object: | |||
methods = ( | |||
"__getnewargs_ex__", | |||
"__getnewargs__", | |||
"__reduce_ex__", | |||
"__reduce__", | |||
) | |||
if not any(m in member_type.__dict__ for m in methods): | |||
_make_class_unpicklable(enum_class) | |||
# instantiate them, checking for duplicates as we go | |||
# we instantiate first instead of checking for duplicates first in case | |||
# a custom __new__ is doing something funky with the values -- such as | |||
# auto-numbering ;) | |||
for member_name in classdict._member_names: | |||
value = enum_members[member_name] | |||
if not isinstance(value, tuple): | |||
args = (value,) | |||
else: | |||
args = value | |||
if member_type is tuple: # special case for tuple enums | |||
args = (args,) # wrap it one more time | |||
if not use_args: | |||
enum_member = __new__(enum_class) | |||
if not hasattr(enum_member, "_value_"): | |||
enum_member._value_ = value | |||
else: | |||
enum_member = __new__(enum_class, *args) | |||
if not hasattr(enum_member, "_value_"): | |||
if member_type is object: | |||
enum_member._value_ = value | |||
else: | |||
enum_member._value_ = member_type(*args) | |||
value = enum_member._value_ | |||
enum_member._name_ = member_name | |||
enum_member.__objclass__ = enum_class | |||
enum_member.__init__(*args) | |||
# If another member with the same value was already defined, the | |||
# new member becomes an alias to the existing one. | |||
for name, canonical_member in enum_class._member_map_.items(): | |||
if canonical_member._value_ == enum_member._value_: | |||
enum_member = canonical_member | |||
break | |||
else: | |||
# Aliases don't appear in member names (only in __members__). | |||
enum_class._member_names_.append(member_name) | |||
# performance boost for any member that would not shadow | |||
# a DynamicClassAttribute | |||
if member_name not in base_attributes: | |||
setattr(enum_class, member_name, enum_member) | |||
# now add to _member_map_ | |||
enum_class._member_map_[member_name] = enum_member | |||
try: | |||
# This may fail if value is not hashable. We can't add the value | |||
# to the map, and by-value lookups for this value will be | |||
# linear. | |||
enum_class._value2member_map_[value] = enum_member | |||
except TypeError: | |||
pass | |||
# double check that repr and friends are not the mixin's or various | |||
# things break (such as pickle) | |||
for name in ("__repr__", "__str__", "__format__", "__reduce_ex__"): | |||
class_method = getattr(enum_class, name) | |||
obj_method = getattr(member_type, name, None) | |||
enum_method = getattr(first_enum, name, None) | |||
if obj_method is not None and obj_method is class_method: | |||
setattr(enum_class, name, enum_method) | |||
# replace any other __new__ with our own (as long as Enum is not None, | |||
# anyway) -- again, this is to support pickle | |||
if Enum is not None: | |||
# if the user defined their own __new__, save it before it gets | |||
# clobbered in case they subclass later | |||
if save_new: | |||
enum_class.__new_member__ = __new__ | |||
enum_class.__new__ = Enum.__new__ | |||
# py3 support for definition order (helps keep py2/py3 code in sync) | |||
if _order_ is not None: | |||
if isinstance(_order_, str): | |||
_order_ = _order_.replace(",", " ").split() | |||
if _order_ != enum_class._member_names_: | |||
raise TypeError("member order does not match _order_") | |||
return enum_class | |||
def __bool__(self): | |||
""" | |||
classes/types should always be True. | |||
""" | |||
return True | |||
def __call__( | |||
cls, value, names=None, *, module=None, qualname=None, type=None, start=1 | |||
): | |||
"""Either returns an existing member, or creates a new enum class. | |||
This method is used both when an enum class is given a value to match | |||
to an enumeration member (i.e. Color(3)) and for the functional API | |||
(i.e. Color = Enum('Color', names='RED GREEN BLUE')). | |||
When used for the functional API: | |||
`value` will be the name of the new class. | |||
`names` should be either a string of white-space/comma delimited names | |||
(values will start at `start`), or an iterator/mapping of name, value pairs. | |||
`module` should be set to the module this class is being created in; | |||
if it is not set, an attempt to find that module will be made, but if | |||
it fails the class will not be picklable. | |||
`qualname` should be set to the actual location this class can be found | |||
at in its module; by default it is set to the global scope. If this is | |||
not correct, unpickling will fail in some circumstances. | |||
`type`, if set, will be mixed in as the first base class. | |||
""" | |||
if names is None: # simple value lookup | |||
return cls.__new__(cls, value) | |||
# otherwise, functional API: we're creating a new Enum type | |||
return cls._create_( | |||
value, names, module=module, qualname=qualname, type=type, start=start | |||
) | |||
def __contains__(cls, member): | |||
return isinstance(member, cls) and member._name_ in cls._member_map_ | |||
def __delattr__(cls, attr): | |||
# nicer error message when someone tries to delete an attribute | |||
# (see issue19025). | |||
if attr in cls._member_map_: | |||
raise AttributeError("%s: cannot delete Enum member." % cls.__name__) | |||
super().__delattr__(attr) | |||
def __dir__(self): | |||
return [ | |||
"__class__", | |||
"__doc__", | |||
"__members__", | |||
"__module__", | |||
] + self._member_names_ | |||
def __getattr__(cls, name): | |||
"""Return the enum member matching `name` | |||
We use __getattr__ instead of descriptors or inserting into the enum | |||
class' __dict__ in order to support `name` and `value` being both | |||
properties for enum members (which live in the class' __dict__) and | |||
enum members themselves. | |||
""" | |||
if _is_dunder(name): | |||
raise AttributeError(name) | |||
try: | |||
return cls._member_map_[name] | |||
except KeyError: | |||
raise AttributeError(name) from None | |||
def __getitem__(cls, name): | |||
return cls._member_map_[name] | |||
def __iter__(cls): | |||
return (cls._member_map_[name] for name in cls._member_names_) | |||
def __len__(cls): | |||
return len(cls._member_names_) | |||
@property | |||
def __members__(cls): | |||
"""Returns a mapping of member name->value. | |||
This mapping lists all enum members, including aliases. Note that this | |||
is a read-only view of the internal mapping. | |||
""" | |||
return MappingProxyType(cls._member_map_) | |||
def __repr__(cls): | |||
return "<enum %r>" % cls.__name__ | |||
def __reversed__(cls): | |||
return (cls._member_map_[name] for name in reversed(cls._member_names_)) | |||
def __setattr__(cls, name, value): | |||
"""Block attempts to reassign Enum members. | |||
A simple assignment to the class namespace only changes one of the | |||
several possible ways to get an Enum member from the Enum class, | |||
resulting in an inconsistent Enumeration. | |||
""" | |||
member_map = cls.__dict__.get("_member_map_", {}) | |||
if name in member_map: | |||
raise AttributeError("Cannot reassign members.") | |||
super().__setattr__(name, value) | |||
def _create_( | |||
cls, class_name, names=None, *, module=None, qualname=None, type=None, start=1 | |||
): | |||
"""Convenience method to create a new Enum class. | |||
`names` can be: | |||
* A string containing member names, separated either with spaces or | |||
commas. Values are incremented by 1 from `start`. | |||
* An iterable of member names. Values are incremented by 1 from `start`. | |||
* An iterable of (member name, value) pairs. | |||
* A mapping of member name -> value pairs. | |||
""" | |||
metacls = cls.__class__ | |||
bases = (cls,) if type is None else (type, cls) | |||
_, first_enum = cls._get_mixins_(bases) | |||
classdict = metacls.__prepare__(class_name, bases) | |||
# special processing needed for names? | |||
if isinstance(names, str): | |||
names = names.replace(",", " ").split() | |||
if isinstance(names, (tuple, list)) and names and isinstance(names[0], str): | |||
original_names, names = names, [] | |||
last_values = [] | |||
for count, name in enumerate(original_names): | |||
value = first_enum._generate_next_value_( | |||
name, start, count, last_values[:] | |||
) | |||
last_values.append(value) | |||
names.append((name, value)) | |||
# Here, names is either an iterable of (name, value) or a mapping. | |||
for item in names: | |||
if isinstance(item, str): | |||
member_name, member_value = item, names[item] | |||
else: | |||
member_name, member_value = item | |||
classdict[member_name] = member_value | |||
enum_class = metacls.__new__(metacls, class_name, bases, classdict) | |||
# TODO: replace the frame hack if a blessed way to know the calling | |||
# module is ever developed | |||
if module is None: | |||
try: | |||
module = sys._getframe(2).f_globals["__name__"] | |||
except (AttributeError, ValueError) as exc: | |||
pass | |||
if module is None: | |||
_make_class_unpicklable(enum_class) | |||
else: | |||
enum_class.__module__ = module | |||
if qualname is not None: | |||
enum_class.__qualname__ = qualname | |||
return enum_class | |||
@staticmethod | |||
def _get_mixins_(bases): | |||
"""Returns the type for creating enum members, and the first inherited | |||
enum class. | |||
bases: the tuple of bases that was given to __new__ | |||
""" | |||
if not bases: | |||
return object, Enum | |||
# double check that we are not subclassing a class with existing | |||
# enumeration members; while we're at it, see if any other data | |||
# type has been mixed in so we can use the correct __new__ | |||
member_type = first_enum = None | |||
for base in bases: | |||
if base is not Enum and issubclass(base, Enum) and base._member_names_: | |||
raise TypeError("Cannot extend enumerations") | |||
# base is now the last base in bases | |||
if not issubclass(base, Enum): | |||
raise TypeError( | |||
"new enumerations must be created as " | |||
"`ClassName([mixin_type,] enum_type)`" | |||
) | |||
# get correct mix-in type (either mix-in type of Enum subclass, or | |||
# first base if last base is Enum) | |||
if not issubclass(bases[0], Enum): | |||
member_type = bases[0] # first data type | |||
first_enum = bases[-1] # enum type | |||
else: | |||
for base in bases[0].__mro__: | |||
# most common: (IntEnum, int, Enum, object) | |||
# possible: (<Enum 'AutoIntEnum'>, <Enum 'IntEnum'>, | |||
# <class 'int'>, <Enum 'Enum'>, | |||
# <class 'object'>) | |||
if issubclass(base, Enum): | |||
if first_enum is None: | |||
first_enum = base | |||
else: | |||
if member_type is None: | |||
member_type = base | |||
return member_type, first_enum | |||
@staticmethod | |||
def _find_new_(classdict, member_type, first_enum): | |||
"""Returns the __new__ to be used for creating the enum members. | |||
classdict: the class dictionary given to __new__ | |||
member_type: the data type whose __new__ will be used by default | |||
first_enum: enumeration to check for an overriding __new__ | |||
""" | |||
# now find the correct __new__, checking to see of one was defined | |||
# by the user; also check earlier enum classes in case a __new__ was | |||
# saved as __new_member__ | |||
__new__ = classdict.get("__new__", None) | |||
# should __new__ be saved as __new_member__ later? | |||
save_new = __new__ is not None | |||
if __new__ is None: | |||
# check all possibles for __new_member__ before falling back to | |||
# __new__ | |||
for method in ("__new_member__", "__new__"): | |||
for possible in (member_type, first_enum): | |||
target = getattr(possible, method, None) | |||
if target not in { | |||
None, | |||
None.__new__, | |||
object.__new__, | |||
Enum.__new__, | |||
}: | |||
__new__ = target | |||
break | |||
if __new__ is not None: | |||
break | |||
else: | |||
__new__ = object.__new__ | |||
# if a non-object.__new__ is used then whatever value/tuple was | |||
# assigned to the enum member name will be passed to __new__ and to the | |||
# new enum member's __init__ | |||
if __new__ is object.__new__: | |||
use_args = False | |||
else: | |||
use_args = True | |||
return __new__, save_new, use_args | |||
class Enum(metaclass=EnumMeta): | |||
"""Generic enumeration. | |||
Derive from this class to define new enumerations. | |||
""" | |||
def __new__(cls, value): | |||
# all enum instances are actually created during class construction | |||
# without calling this method; this method is called by the metaclass' | |||
# __call__ (i.e. Color(3) ), and by pickle | |||
if type(value) is cls: | |||
# For lookups like Color(Color.RED) | |||
return value | |||
# by-value search for a matching enum member | |||
# see if it's in the reverse mapping (for hashable values) | |||
try: | |||
if value in cls._value2member_map_: | |||
return cls._value2member_map_[value] | |||
except TypeError: | |||
# not there, now do long search -- O(n) behavior | |||
for member in cls._member_map_.values(): | |||
if member._value_ == value: | |||
return member | |||
# still not found -- try _missing_ hook | |||
return cls._missing_(value) | |||
def _generate_next_value_(name, start, count, last_values): | |||
for last_value in reversed(last_values): | |||
try: | |||
return last_value + 1 | |||
except TypeError: | |||
pass | |||
else: | |||
return start | |||
@classmethod | |||
def _missing_(cls, value): | |||
raise ValueError("%r is not a valid %s" % (value, cls.__name__)) | |||
def __repr__(self): | |||
return "<%s.%s: %r>" % (self.__class__.__name__, self._name_, self._value_) | |||
def __str__(self): | |||
return "%s.%s" % (self.__class__.__name__, self._name_) | |||
def __dir__(self): | |||
added_behavior = [ | |||
m | |||
for cls in self.__class__.mro() | |||
for m in cls.__dict__ | |||
if m[0] != "_" and m not in self._member_map_ | |||
] | |||
return ["__class__", "__doc__", "__module__"] + added_behavior | |||
def __format__(self, format_spec): | |||
# mixed-in Enums should use the mixed-in type's __format__, otherwise | |||
# we can get strange results with the Enum name showing up instead of | |||
# the value | |||
# pure Enum branch | |||
if self._member_type_ is object: | |||
cls = str | |||
val = str(self) | |||
# mix-in branch | |||
else: | |||
cls = self._member_type_ | |||
val = self._value_ | |||
return cls.__format__(val, format_spec) | |||
def __hash__(self): | |||
return hash(self._name_) | |||
def __reduce_ex__(self, proto): | |||
return self.__class__, (self._value_,) | |||
# DynamicClassAttribute is used to provide access to the `name` and | |||
# `value` properties of enum members while keeping some measure of | |||
# protection from modification, while still allowing for an enumeration | |||
# to have members named `name` and `value`. This works because enumeration | |||
# members are not set directly on the enum class -- __getattr__ is | |||
# used to look them up. | |||
@DynamicClassAttribute | |||
def name(self): | |||
"""The name of the Enum member.""" | |||
return self._name_ | |||
@DynamicClassAttribute | |||
def value(self): | |||
"""The value of the Enum member.""" | |||
return self._value_ | |||
@classmethod | |||
def _convert(cls, name, module, filter, source=None): | |||
""" | |||
Create a new Enum subclass that replaces a collection of global constants | |||
""" | |||
# convert all constants from source (or module) that pass filter() to | |||
# a new Enum called name, and export the enum and its members back to | |||
# module; | |||
# also, replace the __reduce_ex__ method so unpickling works in | |||
# previous Python versions | |||
module_globals = vars(sys.modules[module]) | |||
if source: | |||
source = vars(source) | |||
else: | |||
source = module_globals | |||
# We use an OrderedDict of sorted source keys so that the | |||
# _value2member_map is populated in the same order every time | |||
# for a consistent reverse mapping of number to name when there | |||
# are multiple names for the same number rather than varying | |||
# between runs due to hash randomization of the module dictionary. | |||
members = [(name, source[name]) for name in source.keys() if filter(name)] | |||
try: | |||
# sort by value | |||
members.sort(key=lambda t: (t[1], t[0])) | |||
except TypeError: | |||
# unless some values aren't comparable, in which case sort by name | |||
members.sort(key=lambda t: t[0]) | |||
cls = cls(name, members, module=module) | |||
cls.__reduce_ex__ = _reduce_ex_by_name | |||
module_globals.update(cls.__members__) | |||
module_globals[name] = cls | |||
return cls | |||
class IntEnum(int, Enum): | |||
"""Enum where members are also (and must be) ints""" | |||
def _reduce_ex_by_name(self, proto): | |||
return self.name | |||
class Flag(Enum): | |||
"""Support for flags""" | |||
def _generate_next_value_(name, start, count, last_values): | |||
""" | |||
Generate the next value when not given. | |||
name: the name of the member | |||
start: the initial start value or None | |||
count: the number of existing members | |||
last_values: the list of values assigned so far | |||
""" | |||
if not count: | |||
return start if start is not None else 1 | |||
for last_value in reversed(last_values): | |||
try: | |||
high_bit = _high_bit(last_value) | |||
break | |||
except Exception: | |||
raise TypeError("Invalid Flag value: %r" % last_value) from None | |||
return 2 ** (high_bit + 1) | |||
@classmethod | |||
def _missing_(cls, value): | |||
original_value = value | |||
if value < 0: | |||
value = ~value | |||
possible_member = cls._create_pseudo_member_(value) | |||
if original_value < 0: | |||
possible_member = ~possible_member | |||
return possible_member | |||
@classmethod | |||
def _create_pseudo_member_(cls, value): | |||
""" | |||
Create a composite member iff value contains only members. | |||
""" | |||
pseudo_member = cls._value2member_map_.get(value, None) | |||
if pseudo_member is None: | |||
# verify all bits are accounted for | |||
_, extra_flags = _decompose(cls, value) | |||
if extra_flags: | |||
raise ValueError("%r is not a valid %s" % (value, cls.__name__)) | |||
# construct a singleton enum pseudo-member | |||
pseudo_member = object.__new__(cls) | |||
pseudo_member._name_ = None | |||
pseudo_member._value_ = value | |||
# use setdefault in case another thread already created a composite | |||
# with this value | |||
pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) | |||
return pseudo_member | |||
def __contains__(self, other): | |||
if not isinstance(other, self.__class__): | |||
return NotImplemented | |||
return other._value_ & self._value_ == other._value_ | |||
def __repr__(self): | |||
cls = self.__class__ | |||
if self._name_ is not None: | |||
return "<%s.%s: %r>" % (cls.__name__, self._name_, self._value_) | |||
members, uncovered = _decompose(cls, self._value_) | |||
return "<%s.%s: %r>" % ( | |||
cls.__name__, | |||
"|".join([str(m._name_ or m._value_) for m in members]), | |||
self._value_, | |||
) | |||
def __str__(self): | |||
cls = self.__class__ | |||
if self._name_ is not None: | |||
return "%s.%s" % (cls.__name__, self._name_) | |||
members, uncovered = _decompose(cls, self._value_) | |||
if len(members) == 1 and members[0]._name_ is None: | |||
return "%s.%r" % (cls.__name__, members[0]._value_) | |||
else: | |||
return "%s.%s" % ( | |||
cls.__name__, | |||
"|".join([str(m._name_ or m._value_) for m in members]), | |||
) | |||
def __bool__(self): | |||
return bool(self._value_) | |||
def __or__(self, other): | |||
if not isinstance(other, self.__class__): | |||
return NotImplemented | |||
return self.__class__(self._value_ | other._value_) | |||
def __and__(self, other): | |||
if not isinstance(other, self.__class__): | |||
return NotImplemented | |||
return self.__class__(self._value_ & other._value_) | |||
def __xor__(self, other): | |||
if not isinstance(other, self.__class__): | |||
return NotImplemented | |||
return self.__class__(self._value_ ^ other._value_) | |||
def __invert__(self): | |||
members, uncovered = _decompose(self.__class__, self._value_) | |||
inverted_members = [ | |||
m | |||
for m in self.__class__ | |||
if m not in members and not m._value_ & self._value_ | |||
] | |||
inverted = reduce(_or_, inverted_members, self.__class__(0)) | |||
return self.__class__(inverted) | |||
class IntFlag(int, Flag): | |||
"""Support for integer-based Flags""" | |||
@classmethod | |||
def _missing_(cls, value): | |||
if not isinstance(value, int): | |||
raise ValueError("%r is not a valid %s" % (value, cls.__name__)) | |||
new_member = cls._create_pseudo_member_(value) | |||
return new_member | |||
@classmethod | |||
def _create_pseudo_member_(cls, value): | |||
pseudo_member = cls._value2member_map_.get(value, None) | |||
if pseudo_member is None: | |||
need_to_create = [value] | |||
# get unaccounted for bits | |||
_, extra_flags = _decompose(cls, value) | |||
# timer = 10 | |||
while extra_flags: | |||
# timer -= 1 | |||
bit = _high_bit(extra_flags) | |||
flag_value = 2 ** bit | |||
if ( | |||
flag_value not in cls._value2member_map_ | |||
and flag_value not in need_to_create | |||
): | |||
need_to_create.append(flag_value) | |||
if extra_flags == -flag_value: | |||
extra_flags = 0 | |||
else: | |||
extra_flags ^= flag_value | |||
for value in reversed(need_to_create): | |||
# construct singleton pseudo-members | |||
pseudo_member = int.__new__(cls, value) | |||
pseudo_member._name_ = None | |||
pseudo_member._value_ = value | |||
# use setdefault in case another thread already created a composite | |||
# with this value | |||
pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) | |||
return pseudo_member | |||
def __or__(self, other): | |||
if not isinstance(other, (self.__class__, int)): | |||
return NotImplemented | |||
result = self.__class__(self._value_ | self.__class__(other)._value_) | |||
return result | |||
def __and__(self, other): | |||
if not isinstance(other, (self.__class__, int)): | |||
return NotImplemented | |||
return self.__class__(self._value_ & self.__class__(other)._value_) | |||
def __xor__(self, other): | |||
if not isinstance(other, (self.__class__, int)): | |||
return NotImplemented | |||
return self.__class__(self._value_ ^ self.__class__(other)._value_) | |||
__ror__ = __or__ | |||
__rand__ = __and__ | |||
__rxor__ = __xor__ | |||
def __invert__(self): | |||
result = self.__class__(~self._value_) | |||
return result | |||
def _high_bit(value): | |||
"""returns index of highest bit, or -1 if value is zero or negative""" | |||
return value.bit_length() - 1 | |||
def unique(enumeration): | |||
"""Class decorator for enumerations ensuring unique member values.""" | |||
duplicates = [] | |||
for name, member in enumeration.__members__.items(): | |||
if name != member.name: | |||
duplicates.append((name, member.name)) | |||
if duplicates: | |||
alias_details = ", ".join( | |||
["%s -> %s" % (alias, name) for (alias, name) in duplicates] | |||
) | |||
raise ValueError( | |||
"duplicate values found in %r: %s" % (enumeration, alias_details) | |||
) | |||
return enumeration | |||
def _decompose(flag, value): | |||
"""Extract all members from the value.""" | |||
# _decompose is only called if the value is not named | |||
not_covered = value | |||
negative = value < 0 | |||
# issue29167: wrap accesses to _value2member_map_ in a list to avoid race | |||
# conditions between iterating over it and having more pseudo- | |||
# members added to it | |||
if negative: | |||
# only check for named flags | |||
flags_to_check = [ | |||
(m, v) | |||
for v, m in list(flag._value2member_map_.items()) | |||
if m.name is not None | |||
] | |||
else: | |||
# check for named flags and powers-of-two flags | |||
flags_to_check = [ | |||
(m, v) | |||
for v, m in list(flag._value2member_map_.items()) | |||
if m.name is not None or _power_of_two(v) | |||
] | |||
members = [] | |||
for member, member_value in flags_to_check: | |||
if member_value and member_value & value == member_value: | |||
members.append(member) | |||
not_covered &= ~member_value | |||
if not members and value in flag._value2member_map_: | |||
members.append(flag._value2member_map_[value]) | |||
members.sort(key=lambda m: m._value_, reverse=True) | |||
if len(members) > 1 and members[0].value == value: | |||
# we have the breakdown, don't need the value member itself | |||
members.pop(0) | |||
return members, not_covered | |||
def _power_of_two(value): | |||
if value < 1: | |||
return False | |||
return value == 2 ** _high_bit(value) |
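# --- usage sketch (illustrative only; ``Color`` is a hypothetical flag): | |||
# composite values become cached pseudo-members, and _decompose recovers | |||
# the named parts of a composite value. | |||
# | |||
#     class Color(IntFlag): | |||
#         RED = 1 | |||
#         GREEN = 2 | |||
#         BLUE = 4 | |||
# | |||
#     assert (Color.RED | Color.BLUE)._value_ == 5 | |||
#     assert _decompose(Color, 5) == ([Color.BLUE, Color.RED], 0) | |||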
@@ -0,0 +1,94 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import warnings | |||
from ..._imperative_rt.ops import OprAttr | |||
from . import param_defs | |||
def make_param(param, ptype, kwargs): | |||
if param is not None: | |||
if isinstance(param, ptype): | |||
return param | |||
param = [param] | |||
assert len(param) == len( | |||
ptype.__slots__ | |||
), "{} needs {} params, but {} are provided".format( | |||
ptype, len(ptype.__slots__), len(param) | |||
) | |||
return ptype(*param) | |||
ckw = {} | |||
for i in ptype.__slots__: | |||
val = kwargs.pop(i, ckw) | |||
if val is not ckw: | |||
ckw[i] = val | |||
return ptype(**ckw) | |||
class PodOpVisitor: | |||
__name2subclass = {} | |||
__c = None | |||
name = None | |||
param_names = [] | |||
config = None | |||
def __init__(self, config, **params): | |||
self.config = config | |||
assert set(params) == set(self.param_names) | |||
self.__dict__.update(params) | |||
def __init_subclass__(cls, **kwargs): | |||
super().__init_subclass__(**kwargs)  # __init_subclass__ requires Python 3.6+ | |||
name = cls.name | |||
if name in cls.__name2subclass: | |||
if not issubclass(cls, cls.__name2subclass[name]): | |||
warnings.warn("Multiple subclasses for builtin op: %s" % name) | |||
cls.__name2subclass[name] = cls | |||
def to_c(self): | |||
if self.__c: | |||
return self.__c | |||
op = OprAttr() | |||
op.type = self.name | |||
if self.config is not None: | |||
op.config = self.config | |||
# the serialized param starts with a 4-byte TAG, which has to be stripped for now | |||
op.param = b"".join(self.__dict__[k].serialize()[4:] for k in self.param_names) | |||
self.__c = op | |||
return op | |||
def __eq__(self, rhs): | |||
return self.to_c() == rhs.to_c() | |||
def __repr__(self): | |||
name = self.__class__.__name__ | |||
if self.__c: | |||
return "{}(<binary data>)".format(name) | |||
kwargs = {} | |||
for i in self.param_names: | |||
p = self.__dict__[i] | |||
if isinstance(p, param_defs._ParamDefBase): | |||
for k in p.__slots__: | |||
v = getattr(p, k) | |||
if isinstance(v, param_defs._EnumBase): | |||
v = v.name | |||
kwargs[k] = repr(v) | |||
else: | |||
kwargs[i] = repr(p) | |||
if self.config: | |||
if len(self.config.comp_node_arr) == 1: | |||
kwargs["device"] = "'%s'" % self.config.comp_node | |||
return "{}({})".format( | |||
name, ", ".join("{}={}".format(k, v) for k, v in kwargs.items()) | |||
) |
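# --- illustrative sketch (hypothetical subclass, not part of this file): a | |||
# concrete op only declares ``name``/``param_names`` and builds its params | |||
# through ``make_param``, which accepts a ready-made param struct, a single | |||
# positional value, or loose keyword arguments: | |||
# | |||
#     class Elemwise(PodOpVisitor): | |||
#         name = "Elemwise" | |||
#         param_names = ("param",) | |||
# | |||
#         def __init__(self, *, param=None, config=None, **kwargs): | |||
#             self.config = config | |||
#             self.param = make_param(param, param_defs.Elemwise, kwargs) | |||
#             assert not kwargs, "extra kwargs: {}".format(kwargs) | |||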
@@ -0,0 +1,194 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections.abc | |||
import ctypes | |||
from ..._imperative_rt import OperatorNodeConfig as Config | |||
from . import param_defs | |||
from .helper import PodOpVisitor, make_param | |||
__all__ = ["ConvolutionBackwardData", "Dimshuffle", "Reshape", "AxisAddRemove"] | |||
class TensorShape: | |||
MAX_NDIM = 7 | |||
class ConvolutionBackwardData(PodOpVisitor): | |||
param_names = ( | |||
"param", | |||
"execution_policy", | |||
) | |||
name = "ConvolutionBackwardDataV1" | |||
def __init__( | |||
self, | |||
*, | |||
param=None, | |||
execution_policy=None, | |||
name=None, | |||
comp_node=None, | |||
config=None, | |||
dtype=None, | |||
**kwargs | |||
): | |||
config = config or Config() | |||
if name: | |||
config.name = name | |||
if comp_node: | |||
config.comp_node = comp_node | |||
if dtype: | |||
config.dtype = dtype | |||
self.config = config | |||
self.param = make_param(param, param_defs.Convolution, kwargs) | |||
self.execution_policy = make_param( | |||
execution_policy, param_defs.ExecutionPolicy, kwargs | |||
) | |||
assert not kwargs, "extra kwargs: {}".format(kwargs) | |||
class Dimshuffle(PodOpVisitor): | |||
name = "Dimshuffle" | |||
param_names = ("pattern",) | |||
class Pattern(ctypes.Structure): | |||
Pattern_Array = ctypes.c_int32 * TensorShape.MAX_NDIM | |||
_fields_ = [ | |||
("length", ctypes.c_uint32), | |||
("pattern", Pattern_Array), | |||
("ndim", ctypes.c_uint32), | |||
] | |||
def serialize(self): | |||
return bytes(ctypes.c_uint32(0)) + bytes(self) | |||
def __init__(self, pattern, ndim=0): | |||
assert isinstance(pattern, collections.abc.Iterable) | |||
assert len(pattern) <= TensorShape.MAX_NDIM | |||
pattern_array = Dimshuffle.Pattern.Pattern_Array() | |||
for idx, v in enumerate(pattern): | |||
pattern_array[idx] = ctypes.c_int32(-1 if v == "x" else int(v)) | |||
self.pattern = Dimshuffle.Pattern(len(pattern), pattern_array, ndim) | |||
class Reshape(PodOpVisitor): | |||
name = "ReshapeV1" | |||
param_names = ("unspec_axis",) | |||
def __init__(self, unspec_axis=None): | |||
if unspec_axis is None: | |||
self.unspec_axis = param_defs.OptionalAxisV1() | |||
else: | |||
self.unspec_axis = param_defs.OptionalAxisV1(unspec_axis) | |||
class AxisNum(ctypes.Structure): | |||
_fields_ = [ | |||
("m_num", ctypes.c_int), | |||
] | |||
class AxisDesc(ctypes.Structure): | |||
class Method(ctypes.c_int): | |||
ADD_1 = 0 | |||
REMOVE = 1 | |||
_fields_ = [ | |||
("method", Method), | |||
("axis", AxisNum), | |||
] | |||
@classmethod | |||
def make_add(cls, axis): | |||
return cls(cls.Method.ADD_1, AxisNum(axis)) | |||
@classmethod | |||
def make_remove(cls, axis): | |||
return cls(cls.Method.REMOVE, AxisNum(axis)) | |||
class AxisAddRemove(PodOpVisitor): | |||
name = "AxisAddRemove" | |||
param_names = ("param",) | |||
AxisDesc = AxisDesc | |||
class Param(ctypes.Structure): | |||
MAX_DESC_SIZE = TensorShape.MAX_NDIM * 2 | |||
_fields_ = [("nr_desc", ctypes.c_uint32), ("desc", AxisDesc * MAX_DESC_SIZE)] | |||
def __init__(self, *args): | |||
super().__init__() | |||
self.nr_desc = len(args) | |||
for i, a in enumerate(args): | |||
self.desc[i] = a | |||
def serialize(self): | |||
return bytes(ctypes.c_uint32(0)) + bytes(self) | |||
def __init__(self, param): | |||
assert isinstance(param, self.Param) | |||
self.param = param | |||
del AxisDesc | |||
class IndexingOpBase(PodOpVisitor): | |||
param_names = ("index_desc",) | |||
class IndexDescMaskDump(ctypes.Structure): | |||
class Item(ctypes.Structure): | |||
_fields_ = [ | |||
("axis", ctypes.c_int8), | |||
("begin", ctypes.c_bool), | |||
("end", ctypes.c_bool), | |||
("step", ctypes.c_bool), | |||
("idx", ctypes.c_bool), | |||
] | |||
Item_Array = Item * TensorShape.MAX_NDIM | |||
_fields_ = [("nr_item", ctypes.c_uint8), ("items", Item_Array)] | |||
def serialize(self): | |||
return bytes(ctypes.c_uint32(0)) + bytes(self) | |||
def __init__(self, items): | |||
nr_item = len(items) | |||
assert nr_item <= TensorShape.MAX_NDIM | |||
item_array = IndexingOpBase.IndexDescMaskDump.Item_Array() | |||
for idx, item in enumerate(items): | |||
assert isinstance(item, (tuple, list)) and len(item) == 5 | |||
item_array[idx] = IndexingOpBase.IndexDescMaskDump.Item(*item) | |||
self.index_desc = IndexingOpBase.IndexDescMaskDump(nr_item, item_array) | |||
def _gen_indexing_defs(*names): | |||
for name in names: | |||
globals()[name] = type(name, (IndexingOpBase,), dict(name=name)) | |||
__all__.append(name) | |||
_gen_indexing_defs( | |||
"Subtensor", | |||
"SetSubtensor", | |||
"IncrSubtensor", | |||
"IndexingMultiAxisVec", | |||
"IndexingSetMultiAxisVec", | |||
"IndexingIncrMultiAxisVec", | |||
"MeshIndexing", | |||
"IncrMeshIndexing", | |||
"SetMeshIndexing", | |||
"BatchedMeshIndexing", | |||
"BatchedIncrMeshIndexing", | |||
"BatchedSetMeshIndexing", | |||
) |
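# --- usage sketch: each generated class takes a single ``items`` argument, | |||
# a list of (axis, begin, end, step, idx) flag tuples matching | |||
# IndexDescMaskDump.Item; this mirrors how getitem/setitem build these ops | |||
# elsewhere in the package (applying them requires the imperative runtime): | |||
# | |||
#     op = Subtensor(items=[(0, True, True, False, False)])  # x[begin:end] on axis 0 | |||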
@@ -0,0 +1,37 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import warnings | |||
from typing import Union | |||
from ..._imperative_rt import OpDef, ops | |||
from ...tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||
from .._internal import all_ops | |||
from .._internal.helper import PodOpVisitor | |||
# register OpDef as a "virtual subclass" of OpBase, so that the registered | |||
# apply(OpBase, ...) rules also work on OpDef | |||
OpBase.register(OpDef) | |||
# forward to apply(OpDef, ...) | |||
@apply.add | |||
def _(op: PodOpVisitor, *args: Union[TensorBase, TensorWrapperBase]): | |||
return apply(op.to_c(), *args) | |||
__all__ = ["OpDef", "PodOpVisitor"] | |||
for k, v in all_ops.__dict__.items(): | |||
if isinstance(v, type) and issubclass(v, PodOpVisitor): | |||
globals()[k] = v | |||
__all__.append(k) | |||
for k, v in ops.__dict__.items(): | |||
if isinstance(v, type) and issubclass(v, OpDef): | |||
globals()[k] = v | |||
__all__.append(k) |
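# --- usage sketch: after the loops above, C++-defined OpDefs and Python | |||
# PodOpVisitor ops share one namespace and one ``apply`` entry point, e.g. | |||
# (``x`` is a hypothetical tensor): | |||
# | |||
#     from megengine.core.ops import builtin | |||
#     (y,) = apply(builtin.Dimshuffle((1, 0)), x) | |||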
@@ -0,0 +1,16 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..tensor.core import OpBase, TensorBase, apply | |||
class Const(OpBase): | |||
def __init__(self, value=None, *, dtype=None, device=None): | |||
self.value = value | |||
self.dtype = dtype | |||
self.device = device |
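# --- usage sketch: Const carries its payload on the op itself; applying it | |||
# to a reference tensor (used only for device/graph placement) yields the | |||
# constant, e.g. (``x`` is a hypothetical tensor): | |||
# | |||
#     (c,) = Const(1.0, dtype="float32", device=x.device)(x) | |||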
@@ -0,0 +1,9 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .tensor_wrapper import TensorWrapper as Tensor |
@@ -0,0 +1,115 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
import functools | |||
import inspect | |||
import sys | |||
import typing | |||
from abc import ABC | |||
import multipledispatch | |||
class OpBase(ABC): | |||
def __call__(self, *args): | |||
return apply(self, *args) | |||
class TensorBase: | |||
pass | |||
class TensorWrapperBase: | |||
pass | |||
class Dispatcher(multipledispatch.Dispatcher): | |||
def add(self, f, g=None): | |||
if g is None: | |||
super().add(get_signature(f), f) | |||
else: | |||
super().add(f, g) | |||
return f | |||
def __get__(self, instance, owner=None): | |||
if instance is None: | |||
return self | |||
return functools.partial(self, instance) | |||
if sys.version_info < (3, 6): | |||
def parse_union(ann): | |||
if type(ann) is not typing.UnionMeta: | |||
return | |||
return ann.__union_params__ | |||
elif sys.version_info < (3, 7): | |||
def parse_union(ann): | |||
if type(ann) is not typing._Union: | |||
return | |||
return ann.__args__ | |||
elif sys.version_info < (3, 8): | |||
def parse_union(ann): | |||
if type(ann) is not typing._GenericAlias: | |||
if type(ann) is not typing.Union: | |||
return | |||
else: | |||
if ann.__origin__ is not typing.Union: | |||
return | |||
return ann.__args__ | |||
else: | |||
def parse_union(ann): | |||
if typing.get_origin(ann) is not typing.Union: | |||
return | |||
return typing.get_args(ann) | |||
def get_signature(function, op_type=None): | |||
sig = inspect.signature(function) | |||
types = [] | |||
for p in sig.parameters.values(): | |||
ann = p.annotation | |||
ann = parse_union(ann) or ann | |||
if p.kind in ( | |||
inspect.Parameter.POSITIONAL_ONLY, | |||
inspect.Parameter.POSITIONAL_OR_KEYWORD, | |||
): | |||
types.append(ann) | |||
if p.kind == inspect.Parameter.VAR_POSITIONAL: | |||
types.append([ann]) | |||
return tuple(types) | |||
apply = Dispatcher("apply") | |||
OpBase.apply = apply | |||
@apply.add | |||
def _(op: OpBase, *args: TensorBase): | |||
raise NotImplementedError | |||
@apply.add | |||
def _(op: OpBase, *args: TensorWrapperBase): | |||
assert args | |||
Wrapper = type(args[0]) | |||
outputs = apply(op, *(i.__wrapped__ for i in args)) | |||
assert isinstance(outputs, tuple) | |||
return tuple(map(Wrapper, outputs)) |
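# --- usage sketch: ``Dispatcher.add`` infers the multipledispatch signature | |||
# from annotations via get_signature/parse_union, so new rules register just | |||
# like the two above; ``MyOp`` is a hypothetical OpBase subclass (Union | |||
# imported from typing): | |||
# | |||
#     @apply.add | |||
#     def _(op: MyOp, *args: Union[TensorBase, TensorWrapperBase]): | |||
#         ... | |||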
@@ -0,0 +1,289 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
from typing import Union | |||
import numpy as np | |||
# normal dtype related | |||
from .._imperative_rt import bfloat16, intb1, intb2, intb4 | |||
def is_lowbit(dtype): | |||
return (dtype is intb1) or (dtype is intb2) or (dtype is intb4) | |||
def is_bfloat16(dtype): | |||
return dtype is bfloat16 | |||
# quantization dtype related | |||
_QuantDtypeMetadata = collections.namedtuple( | |||
"QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",] | |||
) | |||
_metadata_dict = { | |||
"quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255), | |||
"qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127), | |||
"quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15), | |||
"qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7), | |||
"qint32": _QuantDtypeMetadata( | |||
"QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1, | |||
), | |||
# NOTE: int2 is not supported for model dump yet | |||
"quint2": _QuantDtypeMetadata(None, "uint8", True, 0, 3), | |||
"qint2": _QuantDtypeMetadata(None, "int8", False, -2, 1), | |||
} | |||
def is_quantize(dtype): | |||
return ( | |||
hasattr(dtype, "metadata") | |||
and dtype.metadata is not None | |||
and "mgb_dtype" in dtype.metadata | |||
) | |||
def get_scale(dtype): | |||
assert is_quantize(dtype) | |||
return dtype.metadata["mgb_dtype"]["scale"] | |||
def get_zero_point(dtype): | |||
assert is_quantize(dtype) | |||
metadata = dtype.metadata["mgb_dtype"] | |||
assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm") | |||
return metadata["zero_point"] | |||
def _check_zero_point(zp: int, dtype_str: str): | |||
qmin = _metadata_dict[dtype_str].qmin | |||
qmax = _metadata_dict[dtype_str].qmax | |||
if zp < qmin or zp > qmax: | |||
raise ValueError( | |||
"zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str) | |||
) | |||
def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]): | |||
r""" | |||
Get quantized dtype with metadata attribute according to _metadata_dict. | |||
Note that unsigned dtype must have ``zero_point`` and signed dtype must | |||
not have ``zero_point``, to be consistent with tensor generated by calling | |||
compiled function from `CompGraph.compile(inputs, outspec)`. | |||
:param dtype_str: a string indicating which dtype to return | |||
:param scale: a number for scale to store in dtype's metadata | |||
:param zp: a number for zero_point to store in dtype's metadata | |||
""" | |||
metadata = _metadata_dict[dtype_str] | |||
np_dtype_str = metadata.np_dtype_str | |||
is_unsigned = metadata.is_unsigned | |||
if is_unsigned: | |||
if zp is None or int(zp) != zp: | |||
raise ValueError("zero_point should be an integer") | |||
zp = int(zp) | |||
_check_zero_point(zp, dtype_str) | |||
return np.dtype( | |||
np_dtype_str, | |||
metadata={ | |||
"mgb_dtype": { | |||
"name": metadata.name, | |||
"scale": float(scale), | |||
"zero_point": zp, | |||
} | |||
}, | |||
) | |||
else: | |||
return np.dtype( | |||
np_dtype_str, | |||
metadata={"mgb_dtype": {"name": metadata.name, "scale": float(scale)}}, | |||
) | |||
def quint8(scale, zero_point): | |||
""" | |||
Construct a quantized unsigned int8 data type with ``scale`` (float) and | |||
``zero_point`` (uint8). The real value represented by a quint8 data type is | |||
float_val = scale * (uint8_val - zero_point) | |||
""" | |||
return get_quantized_dtype("quint8", scale, zero_point) | |||
def qint8(scale): | |||
""" | |||
Construct a quantized int8 data type with ``scale`` (float). The real value | |||
represented by a qint8 data type is float_val = scale * int8_val | |||
""" | |||
return get_quantized_dtype("qint8", scale, None) | |||
def qint32(scale): | |||
""" | |||
Construct a quantized int32 data type with ``scale`` (float). The real value | |||
represented by a qint32 data type is float_val = scale * int32_val | |||
""" | |||
return get_quantized_dtype("qint32", scale, None) | |||
def quint4(scale, zero_point): | |||
""" | |||
Construct a quantized unsigned int4 data type with ``scale`` (float) and | |||
``zero_point`` (uint8). The real value represented by a quint4 data type is | |||
float_val = scale * (uint4_val - zero_point) | |||
""" | |||
return get_quantized_dtype("quint4", scale, zero_point) | |||
def qint4(scale): | |||
""" | |||
Construct a quantized int4 data type with ``scale`` (float). The real value | |||
represented by a qint4 data type is float_val = scale * int4_val | |||
""" | |||
return get_quantized_dtype("qint4", scale, None) | |||
def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): | |||
metadata = _metadata_dict[dtype_str] | |||
arr_metadata = dtype.metadata["mgb_dtype"] | |||
if not isinstance(arr, np.ndarray): | |||
raise ValueError("arr parameter should be instance of np.ndarray") | |||
if not is_quantize(dtype) or arr_metadata["name"] != metadata.name: | |||
raise ValueError("dtype parameter should be a {} dtype".format(dtype_str)) | |||
is_unsigned = metadata.is_unsigned | |||
if is_unsigned: | |||
scale, zp = ( | |||
arr_metadata["scale"], | |||
arr_metadata["zero_point"], | |||
) | |||
return ( | |||
(np.round(arr / scale) + zp) | |||
.clip(metadata.qmin, metadata.qmax) | |||
.astype(dtype) | |||
) | |||
else: | |||
# do not fold this into the is_unsigned branch above; see ``get_quantized_dtype`` | |||
scale = arr_metadata["scale"] | |||
return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype) | |||
def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str): | |||
metadata = _metadata_dict[dtype_str] | |||
arr_metadata = arr.dtype.metadata["mgb_dtype"] | |||
if not isinstance(arr, np.ndarray): | |||
raise ValueError("arr parameter should be instance of np.ndarray") | |||
if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata.name: | |||
raise ValueError("arr's dtype should be a {} dtype".format(dtype_str)) | |||
is_unsigned = metadata.is_unsigned | |||
if is_unsigned: | |||
scale, zp = ( | |||
arr_metadata["scale"], | |||
arr_metadata["zero_point"], | |||
) | |||
return (arr.astype(np.float32) - zp) * scale | |||
else: | |||
# do not fold this into the is_unsigned branch above; see ``get_quantized_dtype`` | |||
scale = arr_metadata["scale"] | |||
return (arr.astype(np.float32)) * scale | |||
def convert_to_quint8(arr: np.ndarray, q: np.dtype): | |||
""" | |||
Quantize a float NumPy ndarray into a quint8 one with specified params. | |||
:param arr: Input ndarray. | |||
:param q: Target data type, should be a quint8. | |||
""" | |||
return _convert_to_quantized_dtype(arr, q, "quint8") | |||
def convert_from_quint8(arr: np.ndarray): | |||
""" | |||
Dequantize a quint8 NumPy ndarray into a float one. | |||
:param arr: Input ndarray. | |||
""" | |||
return _convert_from_quantized_dtype(arr, "quint8") | |||
def convert_to_qint8(arr: np.ndarray, q: np.dtype): | |||
""" | |||
Quantize a float NumPy ndarray into a qint8 one with specified params. | |||
:param arr: Input ndarray. | |||
:param q: Target data type, should be a qint8. | |||
""" | |||
return _convert_to_quantized_dtype(arr, q, "qint8") | |||
def convert_from_qint8(arr: np.ndarray): | |||
""" | |||
Dequantize a qint8 NumPy ndarray into a float one. | |||
:param arr: Input ndarray. | |||
""" | |||
return _convert_from_quantized_dtype(arr, "qint8") | |||
def convert_to_qint32(arr: np.ndarray, q: np.dtype): | |||
""" | |||
Quantize a float NumPy ndarray into a qint32 one with specified params. | |||
:param arr: Input ndarray. | |||
:param q: Target data type, should be a qint32. | |||
""" | |||
return _convert_to_quantized_dtype(arr, q, "qint32") | |||
def convert_from_qint32(arr): | |||
""" | |||
Dequantize a qint32 NumPy ndarray into a float one. | |||
:param arr: Input ndarray. | |||
""" | |||
return _convert_from_quantized_dtype(arr, "qint32") | |||
def convert_to_quint4(arr: np.ndarray, q: np.dtype): | |||
""" | |||
Quantize a float NumPy ndarray into a quint4 one with specified params. | |||
:param arr: Input ndarray. | |||
:param q: Target data type, should be a quint4. | |||
""" | |||
return _convert_to_quantized_dtype(arr, q, "quint4") | |||
def convert_from_quint4(arr: np.ndarray): | |||
""" | |||
Dequantize a quint4 NumPy ndarray into a float one. | |||
:param arr: Input ndarray. | |||
""" | |||
return _convert_from_quantized_dtype(arr, "quint4") | |||
def convert_to_qint4(arr: np.ndarray, q: np.dtype): | |||
""" | |||
Quantize a float NumPy ndarray into a qint4 one with specified params. | |||
:param arr: Input ndarray. | |||
:param q: Target data type, should be a qint4. | |||
""" | |||
return _convert_to_quantized_dtype(arr, q, "qint4") | |||
def convert_from_qint4(arr: np.ndarray): | |||
""" | |||
Dequantize a qint4 NumPy ndarray into a float one. | |||
:param arr: Input ndarray. | |||
""" | |||
return _convert_from_quantized_dtype(arr, "qint4") |
@@ -0,0 +1,158 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..ops.builtin import OpDef | |||
from .core import TensorBase, TensorWrapperBase, apply | |||
from .raw_tensor import RawTensor | |||
from .tensor import Tensor, push_context | |||
from .tensor_wrapper import TensorWrapper | |||
class Function: | |||
""" | |||
Defines a block of operations with customizable differentiation. | |||
The computation should be defined in ``forward`` method, with gradient | |||
computation defined in ``backward`` method. | |||
Each instance of ``Function`` should be used only once during a forward pass. | |||
Examples: | |||
.. testcode:: | |||
class Sigmoid(Function): | |||
def forward(self, x): | |||
y = 1 / (1 + F.exp(-x)) | |||
self.y = y | |||
return y | |||
def backward(self, output_grads): | |||
y = self.y | |||
return output_grads * y * (1-y) | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
pass | |||
def __call__(self, *args): | |||
ret = apply(self, *args) | |||
if type(ret) == tuple and len(ret) == 1: | |||
return ret[0] | |||
return ret | |||
def forward(self, *args, **kwargs): | |||
""" | |||
Applies operations to ``inputs`` and returns results. It must be overridden by all subclasses. | |||
:param args: input tensors. | |||
:return: A tuple of Tensor or a single Tensor. | |||
.. note:: | |||
This method should return a tuple of Tensor or a single Tensor representing the output | |||
of the function. | |||
""" | |||
raise NotImplementedError | |||
def backward(self, *output_grads): | |||
""" | |||
Compute the gradient of the forward function. It must be overridden by all subclasses. | |||
:param output_grads: gradients of outputs that are returned by :meth:`~.function.Function.forward` | |||
.. note:: | |||
If some of the outputs are not involved in computing the loss, the corresponding | |||
values in ``output_grads`` will be ``None``. | |||
.. note:: | |||
This method should return a tuple containing the gradients of all inputs, in the same order | |||
as the ``inputs`` argument of :meth:`~.function.Function.forward` . A ``Tensor`` could be returned | |||
instead if there is only one input. If users want to stop the propagation of some gradients, | |||
the corresponding returned values should be set ``None`` . | |||
""" | |||
raise NotImplementedError | |||
def get_backward_fn(self): | |||
if self.backward is None: | |||
return None | |||
def _backward(*output_grads): | |||
if type(output_grads) is tuple: | |||
_output_grads = map(TensorWrapper, output_grads) | |||
else: | |||
_output_grads = (TensorWrapper(output_grads),) | |||
ret = self.backward(*_output_grads) | |||
if type(ret) is not tuple: | |||
ret = (ret,) | |||
ret = tuple([i.__wrapped__ for i in ret]) | |||
return ret | |||
return _backward | |||
Function.apply = Function.__call__ | |||
@apply.add | |||
def _(op: Function, *args: TensorWrapperBase): | |||
assert args | |||
Wrapper = type(args[0]) | |||
# compute the forward results of the user-defined function | |||
extra_data_dic = {} | |||
for arg in args: | |||
extra_data_dic[arg.__wrapped__] = arg.__wrapped__._extra_data | |||
arg.__wrapped__._extra_data = {} | |||
rets = op.forward(*args) | |||
for arg in args: | |||
arg.__wrapped__._extra_data = extra_data_dic[arg.__wrapped__] | |||
# update the gradient information of the user-defined function | |||
inputs = tuple(map(lambda i: i.__wrapped__, args)) | |||
outputs = ( | |||
tuple(map(lambda i: i.__wrapped__, rets)) | |||
if type(rets) is tuple | |||
else (rets.__wrapped__,) | |||
) | |||
for output in outputs: | |||
output._extra_data = {} | |||
with push_context() as ctx: | |||
ctx.inputs = inputs | |||
ctx.outputs = outputs | |||
for k in set().union(*(i._extra_data for i in inputs if isinstance(i, Tensor))): | |||
ctx.key = k | |||
data = tuple( | |||
i._extra_data.get(k) if isinstance(i, Tensor) else i for i in inputs | |||
) | |||
# data are instances of Tracer | |||
# dispatched to apply.add@grad.py | |||
rets = apply(op, *data) | |||
if rets is not None: | |||
assert len(outputs) == len(rets) | |||
for t, i in zip(outputs, rets): | |||
t._extra_data[k] = i | |||
return tuple(map(Wrapper, outputs)) | |||
@apply.add | |||
def _(op: Function, *args: Tensor): | |||
raise NotImplementedError | |||
@apply.add | |||
def _(op: Function, *args: RawTensor): | |||
raise NotImplementedError |
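# --- usage sketch: with the Sigmoid example from the class docstring, a | |||
# fresh instance is applied once per forward pass; the TensorWrapperBase | |||
# rule above runs ``forward`` and wires up gradient tracking: | |||
# | |||
#     y = Sigmoid()(x)   # x: a hypothetical wrapped tensor | |||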
@@ -0,0 +1,251 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
from ..ops import builtin | |||
from ..ops.special import Const | |||
from .core import TensorBase, TensorWrapperBase, apply | |||
def remove_ellipsis(tensor, tuple_val): | |||
ndim_sum = tensor.ndim | |||
cur_sum = 0 | |||
pos = -1 | |||
for i_idx, i in enumerate(tuple_val): | |||
if i is Ellipsis: | |||
for j in tuple_val[:i_idx:-1]: | |||
if j is Ellipsis: | |||
raise IndexError("only one ellipsis is allowed") | |||
pos = i_idx | |||
else: | |||
cur_sum += i.ndim if hasattr(i, "ndim") else 1 | |||
if pos == -1: | |||
return tuple_val | |||
else: | |||
return ( | |||
tuple_val[:pos] | |||
+ (slice(None, None, None),) * (ndim_sum - cur_sum) | |||
+ tuple_val[pos + 1 :] | |||
) | |||
def check_bool_index(tensor, tuple_val): | |||
cur_shape = tensor.shape | |||
new_tuple_val = [] | |||
offset = 0 | |||
tdim = 0 | |||
for idx, i in enumerate(tuple_val): | |||
if hasattr(i, "dtype") and i.dtype == np.bool_: | |||
if i.ndim > 1: | |||
tot = i.ndim | |||
for j in range(i.ndim): | |||
if cur_shape[tdim + j - offset] != i.shape[j]: | |||
raise IndexError( | |||
"boolean index did not match tensor along dimension {}; dimension is {} but corresponding boolean dimension is {}".format( | |||
tdim + j, cur_shape[tdim + j - offset], i.shape[j] | |||
) | |||
) | |||
i = i.reshape(-1) | |||
cur_shape = ( | |||
cur_shape[:idx] + (i.shape[0],) + cur_shape[tdim + tot - offset :] | |||
) | |||
offset += 1 | |||
tensor = tensor.reshape(cur_shape) | |||
tdim += tot | |||
new_tuple_val.append(i) | |||
else: | |||
new_tuple_val.append(i) | |||
tdim += 1 | |||
return tensor, new_tuple_val | |||
def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
if not isinstance(tuple_val, tuple): | |||
tuple_val = (tuple_val,) | |||
ndim_indexed = 0 | |||
for i in tuple_val: | |||
if i is not Ellipsis: | |||
ndim_indexed += 1 if not hasattr(i, "ndim") else i.ndim | |||
if ndim_indexed > inp.ndim: | |||
raise IndexError( | |||
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format( | |||
inp.ndim, ndim_indexed | |||
) | |||
) | |||
tuple_val = remove_ellipsis(inp, tuple_val) | |||
use_subtensor = True | |||
inp, tuple_val = check_bool_index(inp, tuple_val) | |||
def is_scalar(d): | |||
if isinstance(d, int): | |||
return True | |||
if type(d).__module__ == np.__name__: | |||
return np.isscalar(d) | |||
# if isinstance(d, (TensorBase, TensorWrapperBase)): | |||
# return d.shape == (1,) | |||
return False | |||
new_axes = [] | |||
tensors = [] | |||
items = [] | |||
cur_axis = -1 | |||
for i_idx, i in enumerate(tuple_val): | |||
cur_axis += 1 | |||
if i is np.newaxis: | |||
if cur_axis >= 0: | |||
new_axes.append(cur_axis) | |||
continue | |||
if i is Ellipsis: | |||
cur_axis = -1 | |||
for j in tuple_val[:i_idx:-1]: | |||
if j is Ellipsis: | |||
raise IndexError("only one ellipsis is allowed") | |||
if j is np.newaxis: | |||
new_axes.append(cur_axis) | |||
cur_axis -= 1 | |||
continue | |||
if ( | |||
not is_scalar(i) | |||
and i is not np.newaxis | |||
and i is not Ellipsis | |||
and not isinstance(i, slice) | |||
): | |||
use_subtensor = False | |||
item = [ | |||
cur_axis, | |||
] | |||
def is_bool_list(x): | |||
if not isinstance(x, list): | |||
return False | |||
for i in x: | |||
if not isinstance(i, bool): | |||
return False | |||
return True | |||
def get_index(i): | |||
if not isinstance(i, (TensorBase, TensorWrapperBase)): | |||
if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | |||
(i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) | |||
else: | |||
(i,) = Const(i, dtype=np.int32, device=inp.device)(inp) | |||
return i | |||
assert isinstance(i, (TensorBase, TensorWrapperBase)) | |||
if i.dtype != np.bool_: | |||
return i | |||
_, ind = apply(builtin.CondTake(), i, i) | |||
return ind | |||
def push(v, item, tensors): | |||
if v is None: | |||
item.append(False) | |||
else: | |||
item.append(True) | |||
v = get_index(v) | |||
assert np.issubdtype(v.dtype, np.integer) or np.issubdtype( | |||
v.dtype, np.bool_ | |||
), "var type in the subscript must be int or bool" | |||
tensors.append(v) | |||
if isinstance(i, slice): | |||
if i.start is None and i.stop is None and i.step is None: | |||
continue | |||
push(i.start, item, tensors) | |||
push(i.stop, item, tensors) | |||
push(i.step, item, tensors) | |||
item.append(False) # idx | |||
else: | |||
item += [False,] * 3  # begin, end, step | |||
push(i, item, tensors) | |||
assert len(item) == 5 | |||
items.append(item) | |||
if new_axes: | |||
raise IndexError("newaxis is not allowed here") | |||
return inp, tensors, items, use_subtensor | |||
def try_condtake(tensor, index): | |||
if not hasattr(index, "dtype") or not hasattr(index, "shape"): | |||
return [] | |||
if index.dtype != np.bool_ or index.shape != tensor.shape: | |||
return [] | |||
if isinstance(index, np.ndarray): | |||
(index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) | |||
assert isinstance(index, (TensorBase, TensorWrapperBase)) | |||
if not isinstance(tensor, (TensorWrapperBase, TensorBase)): | |||
raise TypeError("input must be a tensor") | |||
if tensor.device != index.device: | |||
raise ValueError( | |||
"ambiguous device: {} vs {}".format(tensor.device, index.device) | |||
) | |||
return apply(builtin.CondTake(), tensor, index) | |||
def getitem(tensor, index): | |||
try_result = try_condtake(tensor, index) | |||
if len(try_result) == 2: | |||
return try_result[0] | |||
tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index) | |||
for v in tensors: | |||
if v.shape[0] == 0: | |||
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)( | |||
tensor | |||
) | |||
return empty_tensor | |||
if use_subtensor: | |||
op = builtin.Subtensor(items=items) | |||
else: | |||
op = builtin.IndexingMultiAxisVec(items=items) | |||
(result,) = apply(op, tensor, *tensors) | |||
return result | |||
def setitem(tensor, index, value): | |||
org_shape = tensor.shape | |||
try_result = try_condtake(tensor, index) | |||
if len(try_result) == 2: | |||
index = try_result[1] | |||
if index.shape[0] == 0: | |||
return tensor | |||
tensor = tensor.reshape(-1) | |||
if not isinstance(value, (TensorBase, TensorWrapperBase)): | |||
op = Const(value, dtype=tensor.dtype, device=tensor.device) | |||
(value,) = op(tensor) | |||
tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index) | |||
for v in tensors: | |||
if v.shape[0] == 0: | |||
return tensor | |||
if use_subtensor: | |||
op = builtin.Subtensor(items=items) | |||
else: | |||
op = builtin.IndexingMultiAxisVec(items=items) | |||
(tmp_result,) = apply(op, tensor, *tensors) | |||
if value.shape != tmp_result.shape: | |||
for i in range(min(len(value.shape), len(tmp_result.shape))): | |||
if ( | |||
value.shape[-i - 1] != 1 | |||
and value.shape[-i - 1] != tmp_result.shape[-i - 1] | |||
): | |||
raise ValueError( | |||
"cannot copy tensor with shape {} to subtensor with shape {}".format( | |||
value.shape, tmp_result.shape | |||
) | |||
) | |||
value = value.broadcast(tmp_result.shape) | |||
if use_subtensor: | |||
op = builtin.SetSubtensor(items=items) | |||
else: | |||
op = builtin.IndexingSetMultiAxisVec(items=items) | |||
(result,) = apply(op, tensor, value, *tensors) | |||
result = result.reshape(org_shape) | |||
return result |
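# --- worked example (sketch): ``x[1:3]`` stays on the Subtensor fast path | |||
# (slices and scalars only), while ``x[y]`` with an int32 tensor ``y`` | |||
# clears use_subtensor in unpack_getitem, so getitem falls back to | |||
# IndexingMultiAxisVec: | |||
# | |||
#     _, tensors, items, use_subtensor = unpack_getitem(x, slice(1, 3)) | |||
#     assert use_subtensor and items == [[0, True, True, False, False]] | |||
#     _, _, _, use_subtensor = unpack_getitem(x, y) | |||
#     assert not use_subtensor | |||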
@@ -0,0 +1,196 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections.abc | |||
import threading | |||
import weakref | |||
from concurrent.futures import Future, ThreadPoolExecutor | |||
from .. import _imperative_rt | |||
from .._wrap import device as as_device | |||
from ..ops.builtin import OpDef | |||
from .core import OpBase, TensorBase, apply | |||
class CompiledFunction: | |||
def __init__(self, graph, function): | |||
self._graph = graph | |||
self._function = function | |||
self._future = None | |||
def execute(self, *args): | |||
assert self._future is None | |||
self._future = self._graph._executor.submit(self._function.execute, *args) | |||
def wait(self): | |||
assert self._future is not None | |||
self._future.exception() | |||
self._function.wait() | |||
try: | |||
return self._future.result() | |||
finally: | |||
self._future = None | |||
def __call__(self, *args): | |||
self.execute(*args) | |||
return self.wait() | |||
class Graph(_imperative_rt.ComputingGraph): | |||
def __init__(self): | |||
super().__init__() | |||
self._var_cache = weakref.WeakKeyDictionary() | |||
self._op_cache = weakref.WeakKeyDictionary() | |||
self._executor = ThreadPoolExecutor(1) | |||
def _wrap(self, obj): | |||
if type(obj) is _imperative_rt.VarNode: | |||
wrapper, cache = VarNode, self._var_cache | |||
elif type(obj) is _imperative_rt.OperatorNode: | |||
wrapper, cache = OpNode, self._op_cache | |||
if obj not in cache: | |||
cache[obj] = wrapper(obj) | |||
return cache[obj] | |||
def compile(self, *args): | |||
return CompiledFunction(self, super().compile(_unwrap(args))) | |||
class VarNode(TensorBase): | |||
def __init__(self, node: _imperative_rt.VarNode): | |||
self._node = node | |||
@property | |||
def graph(self) -> Graph: | |||
return self._node.graph | |||
@property | |||
def op(self): | |||
return self.graph._wrap(self._node.owner) | |||
@property | |||
def dtype(self): | |||
return self._node.dtype | |||
@property | |||
def device(self): | |||
return as_device(self._node.comp_node) | |||
class OpNode: | |||
def __init__(self, node: _imperative_rt.OperatorNode): | |||
self._node = node | |||
@property | |||
def graph(self) -> Graph: | |||
return self._node.graph | |||
@property | |||
def inputs(self): | |||
return tuple(map(self.graph._wrap, self._node.inputs)) | |||
@property | |||
def outputs(self): | |||
return tuple(map(self.graph._wrap, self._node.outputs)) | |||
def _wrap(x): | |||
if isinstance(x, collections.abc.Sequence): | |||
return type(x)(map(_wrap, x)) | |||
return x.graph._wrap(x) | |||
def _unwrap(x): | |||
if isinstance(x, collections.abc.Sequence): | |||
return type(x)(map(_unwrap, x)) | |||
return x._node | |||
@apply.add | |||
def _(op: OpDef, *args: VarNode): | |||
outputs = _imperative_rt.invoke_op(op, _unwrap(args)) | |||
return _wrap(outputs) | |||
def input_callback(callback, *args, device=None, dtype=None, graph=None): | |||
outputs = _imperative_rt.input_callback( | |||
callback, as_device(device).to_c(), dtype, _unwrap(args), graph=graph | |||
) | |||
value, dummy = _wrap(outputs) | |||
return value, dummy | |||
class InputNode(OpNode): | |||
def __init__(self, *args: VarNode, device=None, dtype=None, graph=None): | |||
r = _imperative_rt.DeviceTensorNDRendezvous() | |||
if device is not None: | |||
device = as_device(device).to_c() | |||
outputs = _imperative_rt.input_callback( | |||
r, device, dtype, _unwrap(args), graph=graph | |||
) | |||
super().__init__(outputs[0].owner) | |||
self._rendezvous = r | |||
def set_value(self, value): | |||
assert isinstance(value, _imperative_rt.DeviceTensorND) | |||
self._rendezvous.set(value) | |||
def reset(self): | |||
self._rendezvous.reset() | |||
@property | |||
def device(self): | |||
return self.outputs[0].device | |||
@property | |||
def dtype(self): | |||
return self.outputs[0].dtype | |||
def output_callback(callback, var, *args): | |||
args = (var,) + args | |||
dummy = _imperative_rt.output_callback(callback, _unwrap(args)) | |||
return _wrap(dummy) | |||
class OutputNode(OpNode): | |||
def __init__(self, var, *args): | |||
args = (var,) + args | |||
r = _imperative_rt.DeviceTensorNDRendezvous() | |||
dummy = _imperative_rt.output_callback(r, _unwrap(args)) | |||
super().__init__(dummy.owner) | |||
self._rendezvous = r | |||
def get_value(self): | |||
return self._rendezvous.get() | |||
def reset(self): | |||
self._rendezvous.reset() | |||
class TensorAttr: | |||
def __init__(self, shape, dtype, device): | |||
self.shape = shape | |||
self.dtype = dtype | |||
self.device = device | |||
class AttrOutputNode(OpNode): | |||
def __init__(self, var, *args): | |||
args = (var,) + args | |||
r = _imperative_rt.TensorAttrRendezvous() | |||
dummy = _imperative_rt.attr_output_callback(r, _unwrap(args)) | |||
super().__init__(dummy.owner) | |||
self._rendezvous = r | |||
def get_value(self): | |||
attr = self._rendezvous.get() | |||
return TensorAttr(attr.shape, attr.dtype, as_device(attr.comp_node)) | |||
def reset(self): | |||
self._rendezvous.reset() |
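# --- usage sketch (hypothetical names; requires the imperative runtime): | |||
# build a graph through the rendezvous-based nodes, compile, then feed and | |||
# fetch values: | |||
# | |||
#     g = Graph() | |||
#     inp = InputNode(device="xpux", dtype="float32", graph=g) | |||
#     (y,) = apply(some_opdef, inp.outputs[0]) | |||
#     out = OutputNode(y) | |||
#     fn = g.compile(out.outputs[0]) | |||
#     inp.set_value(dev_tensor)   # a DeviceTensorND | |||
#     fn() | |||
#     result = out.get_value() | |||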
@@ -0,0 +1,108 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import functools | |||
import numpy as np | |||
from ..._imperative_rt import CompNode, DeviceTensorND | |||
from ..._imperative_rt.imperative import ( | |||
_get_dev_tensor, | |||
apply_op, | |||
delete, | |||
get_device, | |||
get_dtype, | |||
get_shape, | |||
get_value, | |||
put, | |||
) | |||
from ..._wrap import device as as_device | |||
from ...ops.builtin import Copy, OpDef, TypeCvt | |||
from ...ops.special import Const | |||
from ..core import OpBase, TensorBase, apply | |||
class RawTensor(TensorBase): | |||
_init_cb = None | |||
_del_cb = None | |||
def __init__(self, handle): | |||
self._handle = handle | |||
if self._init_cb: | |||
self._init_cb() | |||
@property | |||
def dtype(self): | |||
return get_dtype(self._handle) | |||
@property | |||
def device(self): | |||
return as_device(get_device(self._handle)) | |||
@property | |||
def shape(self): | |||
return get_shape(self._handle) | |||
def numpy(self): | |||
return get_value(self._handle) | |||
def _dev_tensor(self): | |||
return _get_dev_tensor(self._handle) | |||
def __repr__(self): | |||
return "{}({}, device='{}')".format( | |||
type(self).__qualname__, repr(self.numpy()), self.device | |||
) | |||
def __del__(self): | |||
if self._del_cb: | |||
self._del_cb() | |||
delete(self._handle) | |||
@apply.add | |||
def _(op: OpDef, *args: RawTensor): | |||
outputs = apply_op(op, tuple(i._handle for i in args)) | |||
return tuple(map(RawTensor, outputs)) | |||
@apply.add | |||
def _(op: Const, *args: RawTensor): | |||
dtype = op.dtype | |||
device = as_device(op.device).to_c() | |||
return (as_raw_tensor(op.value, dtype=dtype, device=device),) | |||
@functools.singledispatch | |||
def as_raw_tensor(obj, dtype=None, device=None): | |||
obj = np.asarray(obj, dtype=dtype) | |||
if obj.dtype == np.float64: | |||
obj = obj.astype(np.float32) | |||
if obj.dtype == np.int64: | |||
obj = obj.astype(np.int32) | |||
return as_raw_tensor(obj, device=device) | |||
@as_raw_tensor.register(np.ndarray) | |||
def _(array: np.ndarray, dtype=None, device=None): | |||
device = None if device is None else as_device(device).to_c() | |||
return RawTensor(put(array, dtype=dtype, device=device)) | |||
@as_raw_tensor.register(RawTensor) | |||
def _(tensor: RawTensor, dtype=None, device=None): | |||
if dtype is not None: | |||
dtype = np.dtype(dtype) | |||
if dtype != tensor.dtype: | |||
(tensor,) = apply(TypeCvt(dtype=dtype), tensor) | |||
if device is not None: | |||
device = as_device(device) | |||
if device != tensor.device: | |||
(tensor,) = apply(Copy(comp_node=device.to_c()), tensor) | |||
return tensor |
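# --- usage sketch: ``as_raw_tensor`` narrows float64/int64 host data before | |||
# upload, and re-dispatches on RawTensor to insert TypeCvt/Copy only when | |||
# needed (requires the imperative runtime): | |||
# | |||
#     t = as_raw_tensor([1.0, 2.0])          # uploaded as float32 via put() | |||
#     t32 = as_raw_tensor(t, dtype="int32")  # goes through TypeCvt | |||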
@@ -0,0 +1,251 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import functools | |||
import io | |||
import weakref | |||
class partial(functools.partial): | |||
def __get__(self, instance, owner=None): | |||
if instance is None: | |||
return self | |||
return functools.partial(self, instance) | |||
def hook(f): | |||
def decorator(impl): | |||
return functools.update_wrapper(partial(f, impl), impl) | |||
return decorator | |||
def on_input(impl, value): | |||
tensor = impl(value) | |||
trace = get_trace() | |||
if trace: | |||
var = trace.get_var(tensor) | |||
event = InputEvent(var) | |||
trace.append(event) | |||
return tensor | |||
def on_read_dtype(impl, self): | |||
trace = get_trace() | |||
if trace: | |||
var = trace.get_var(self) | |||
event = ReadDtypeEvent(var) | |||
trace.append(event) | |||
return impl(self) | |||
def on_read_device(impl, self): | |||
trace = get_trace() | |||
if trace: | |||
var = trace.get_var(self) | |||
event = ReadDeviceEvent(var) | |||
trace.append(event) | |||
return impl(self) | |||
def on_read_shape(impl, self): | |||
trace = get_trace() | |||
if trace: | |||
var = trace.get_var(self) | |||
event = ReadShapeEvent(var) | |||
trace.append(event) | |||
return impl(self) | |||
def on_read_value(impl, self): | |||
trace = get_trace() | |||
if trace: | |||
var = trace.get_var(self) | |||
event = ReadValueEvent(var) | |||
trace.append(event) | |||
return impl(self) | |||
def on_builtin_op(impl, op, *args): | |||
outputs = impl(op, *args) | |||
trace = get_trace() | |||
if trace: | |||
input_vars = tuple(map(trace.get_var, args)) | |||
output_vars = outputs and tuple(map(trace.get_var, outputs)) | |||
event = OpEvent(op, input_vars, output_vars) | |||
trace.append(event) | |||
return outputs | |||
def on_del(impl, self): | |||
trace = get_trace() | |||
if trace: | |||
var = trace.get_var(self) | |||
event = DelEvent(var) | |||
trace.append(event) | |||
return impl(self) | |||
class Trace(list): | |||
def __init__(self): | |||
self._var_id = 1 | |||
self._t2v = weakref.WeakKeyDictionary() | |||
self._v2t = weakref.WeakValueDictionary() | |||
def get_var(self, x): | |||
v = self._t2v.get(x) | |||
if v: | |||
return v | |||
v = self._var_id | |||
self._var_id += 1 | |||
self._t2v[x] = v | |||
self._v2t[v] = x | |||
return v | |||
def __bool__(self): | |||
# an empty Trace must still be truthy while recording | |||
return True | |||
def __enter__(self): | |||
global _current_trace | |||
if hasattr(self, "_prev_trace"): | |||
raise RuntimeError | |||
self._prev_trace = _current_trace | |||
_current_trace = self | |||
return self | |||
def __exit__(self, *_): | |||
global _current_trace | |||
if _current_trace is not self: | |||
raise RuntimeError | |||
_current_trace = self._prev_trace | |||
del self._prev_trace | |||
class Event: | |||
pass | |||
class InputEvent(Event): | |||
def __init__(self, var): | |||
self.var = var | |||
class ReadEvent(Event): | |||
def __init__(self, var): | |||
self.var = var | |||
class ReadDtypeEvent(ReadEvent): | |||
pass | |||
class ReadDeviceEvent(ReadEvent): | |||
pass | |||
class ReadShapeEvent(ReadEvent): | |||
pass | |||
class ReadValueEvent(ReadEvent): | |||
pass | |||
class OpEvent(Event): | |||
def __init__(self, op, inputs, outputs): | |||
self.op = op | |||
self.inputs = inputs | |||
self.outputs = outputs | |||
class DelEvent(Event): | |||
def __init__(self, var): | |||
self.var = var | |||
_current_trace = None | |||
def get_trace() -> Trace: | |||
global _current_trace | |||
return _current_trace | |||
def format_trace(trace): | |||
buf = io.StringIO() | |||
active_vars = set() | |||
def write(fmt, *args, **kwargs): | |||
print(fmt.format(*args, **kwargs), file=buf) | |||
def init_vars(*args): | |||
for i in args: | |||
if i in active_vars: | |||
continue | |||
active_vars.add(i) | |||
write("_{} = input()", i) | |||
for event in trace: | |||
if isinstance(event, InputEvent): | |||
init_vars(event.var) | |||
elif isinstance(event, ReadDtypeEvent): | |||
init_vars(event.var) | |||
write("output(_{}.dtype)", event.var) | |||
elif isinstance(event, ReadDeviceEvent): | |||
init_vars(event.var) | |||
write("output(_{}.device)", event.var) | |||
elif isinstance(event, ReadShapeEvent): | |||
init_vars(event.var) | |||
write("output(_{}.shape)", event.var) | |||
elif isinstance(event, ReadValueEvent): | |||
init_vars(event.var) | |||
write("output(_{}.value)", event.var) | |||
elif isinstance(event, OpEvent): | |||
init_vars(*event.inputs) | |||
active_vars.update(event.outputs) | |||
ovars = ", ".join(map("_{}".format, event.outputs)) | |||
ivars = ", ".join(map("_{}".format, event.inputs)) | |||
if ovars: | |||
write("{} = {}({})", ovars, repr(event.op), ivars) | |||
else: | |||
write("{}({})", repr(event.op), ivars) | |||
elif isinstance(event, DelEvent): | |||
init_vars(event.var) | |||
write("del _{}", event.var) | |||
else: | |||
raise TypeError(type(event)) | |||
return buf.getvalue() | |||
def compile_trace(trace): | |||
trace = list(trace) | |||
def static_function(f): | |||
trace = None | |||
@functools.wraps(f) | |||
def wrapper(*args, **kwargs): | |||
nonlocal trace | |||
if trace is None: | |||
with Trace() as trace: | |||
return f(*args, **kwargs) | |||
return f(*args, **kwargs) | |||
return wrapper |
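# --- usage sketch: record events under a Trace and pretty-print them as | |||
# pseudo-code: | |||
# | |||
#     with Trace() as t: | |||
#         ...                    # tensor code running under the hooks above | |||
#     print(format_trace(t))     # e.g. "_1 = input()" then "output(_1.shape)" | |||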
@@ -0,0 +1,263 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import functools | |||
import weakref | |||
# Concepts | |||
# | |||
# * Internal tensor | |||
# Tensor produced by the static sequence | |||
# | |||
# * External tensor | |||
# Tensor not produced, but used as input, by the static sequence | |||
# | |||
# * Irrelevant tensor | |||
# Tensor not present in input/output of any op | |||
# | |||
# * Escape | |||
# An internal tensor is said to escape if it is still alive | |||
# at the end of the sequence | |||
# JIT-ed execution | |||
# | |||
# 1. read attr (dtype, device, shape) | |||
# a. internal tensor | |||
# read out as soon as tensor is produced | |||
# b. external or irrelevant tensor | |||
# fallback | |||
# | |||
# 2. apply op | |||
# bind external tensors in input | |||
# | |||
# 3. del | |||
class Action: | |||
pass | |||
class ReadAttrAction(Action): | |||
def __init__(self, var, name, getter): | |||
self.var = var | |||
self.name = name | |||
self.getter = getter | |||
class ReadValueAction(Action): | |||
def __init__(self, var, getter): | |||
self.var = var | |||
self.getter = getter | |||
class GetTensorAction(Action): | |||
def __init__(self, var, getter): | |||
self.var = var | |||
self.getter = getter | |||
class OpAction(Action): | |||
def __init__(self, op, inputs, outputs, input_receivers): | |||
self.op = op | |||
self.inputs = inputs | |||
self.outputs = outputs | |||
self.input_receivers = input_receivers | |||
class TensorAttr: | |||
def __init__(self): | |||
self.shape = None | |||
self.dtype = None | |||
self.device = None | |||
class Bailout(Exception): | |||
pass | |||
class Fallback(Exception): | |||
pass | |||
def handle_bailout_fallback_finalize(f): | |||
@functools.wraps(f) | |||
def wrapper(self, impl, *args, **kwargs): | |||
try: | |||
return f(self, *args, **kwargs) | |||
except Bailout: | |||
self.bailout() | |||
except Fallback: | |||
pass | |||
finally: | |||
if self.pc == len(self): | |||
self.finalize() | |||
# fall back to the original implementation on Bailout/Fallback | |||
return impl(*args, **kwargs) | |||
return wrapper | |||
class ExecTrajectory(list): | |||
def __init__(self): | |||
super().__init__() | |||
self.reset() | |||
def __bool__(self): | |||
return True | |||
def __enter__(self): | |||
global _current_trajectory | |||
if hasattr(self, "_prev_trajectory"): | |||
raise RuntimeError | |||
self._prev_trajectory = _current_trajectory | |||
_current_trajectory = self | |||
self._exited = False | |||
return self | |||
def __exit__(self, *exc_info): | |||
# cleanup should be done at completion, | |||
# which is before exiting context manager | |||
assert self._exited == (exc_info == (None, None, None)) | |||
if not self._exited: | |||
assert self.pc < len(self) | |||
self.bailout() | |||
def _exit(self): | |||
# clean up self and the global variable | |||
assert not self._exited | |||
self.reset() | |||
global _current_trajectory | |||
if _current_trajectory is not self: | |||
raise RuntimeError | |||
_current_trajectory = self._prev_trajectory | |||
del self._prev_trajectory | |||
def reset(self): | |||
self._exited = True | |||
self.pc = 0 | |||
self.attr_cache = weakref.WeakKeyDictionary() | |||
### Internal and External Tensor ### | |||
# internal tensors are those produced by us | |||
# external tensors are those received from outside | |||
# during JIT-ed execution, internal tensors are just placeholders. | |||
# var_to_tensor is the binding table for all tensors | |||
self.var_to_tensor = {} # var -> weakref[tensor] | |||
# tensor_to_var is the reverse binding table for internal tensors | |||
# note that external tensors could map to >1 vars. | |||
self.tensor_to_var = weakref.WeakKeyDictionary() | |||
# internal tensor will be materialized if its .data is accessed from outside | |||
# after being materialized, an internal tensor is much like an external tensor | |||
def finalize(self): | |||
assert self.pc == len(self) | |||
self._exit() | |||
def bailout(self): | |||
self._exit() | |||
raise NotImplementedError | |||
def next_action(self): | |||
assert not self._exited | |||
assert self.pc < len(self) | |||
return self[self.pc] | |||
@handle_bailout_fallback_finalize | |||
def read_attr(self, tensor, name): | |||
attrs = self.attr_cache.setdefault(tensor, TensorAttr()) | |||
value = getattr(attrs, name, None) | |||
if value is None: | |||
action = self.next_action() | |||
if not isinstance(action, ReadAttrAction): | |||
raise Bailout | |||
if name != action.name: | |||
raise Bailout | |||
value = action.getter() | |||
setattr(attrs, name, value) | |||
return value | |||
@handle_bailout_fallback_finalize | |||
def read_value(self, tensor): | |||
# possibilities: | |||
# 1. internal tensor | |||
# 2. external tensor | |||
# 3. irrelevant tensor (not an input / output of any op) | |||
if tensor not in self.tensor_to_var: | |||
raise Fallback | |||
assert tensor._data is None | |||
action = self.next_action() | |||
if not isinstance(action, ReadValueAction): | |||
raise Bailout | |||
return action.getter() | |||
@handle_bailout_fallback_finalize | |||
def apply_op(self, op, *args): | |||
from . import RawTensor | |||
action = self.next_action() | |||
if not isinstance(action, OpAction): | |||
raise Bailout | |||
if len(args) != len(action.inputs): | |||
raise Bailout | |||
assert len(action.inputs) == len(action.input_receivers) | |||
for v, t, r in zip(action.inputs, args, action.input_receivers): | |||
if v in self.var_to_tensor: | |||
assert r is None | |||
if t is not self.var_to_tensor[v](): | |||
raise Bailout | |||
else: | |||
# NOTE: not checking for aliasing (>=2 vars map to 1 tensor) | |||
# the static execution backend must handle this | |||
self.var_to_tensor[v] = weakref.ref(t) | |||
r(t) | |||
outputs = [] | |||
for v in action.outputs: | |||
assert v not in self.var_to_tensor | |||
t = RawTensor() | |||
t._data_getter = functools.partial(self.get_data, v) | |||
outputs.append(t) | |||
self.var_to_tensor[v] = weakref.ref(t) | |||
return tuple(outputs) | |||
def get_data(self, var): | |||
tensor = self.var_to_tensor[var]() | |||
assert tensor is not None | |||
assert tensor._data is None | |||
assert tensor in self.tensor_to_var | |||
action = self.next_action() | |||
if not isinstance(action, GetTensorAction): | |||
self.bailout() | |||
elif action.var != var: | |||
self.bailout() | |||
else: | |||
tensor._data = action.getter() | |||
del tensor._data_getter | |||
del self.tensor_to_var[tensor] | |||
assert "_data_getter" not in tensor.__dict__ | |||
        return tensor._data  # materialized above; the instance-level getter has been removed | |||
_current_trajectory = None | |||
def get_trajectory(): | |||
return _current_trajectory | |||
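# Illustrative sketch (an assumption, not part of the original file): replaying | |||
# a compiled trajectory. `fn` is a hypothetical stand-in for the traced | |||
# function; reads and ops performed inside the `with` block are checked against | |||
# the recorded actions, finalize() asserts all of them were consumed, and any | |||
# mismatch raises Bailout to fall back to normal execution. | |||
def _replay_example(traj, fn, *args): | |||
    with traj:  # installs traj as the current trajectory | |||
        outputs = fn(*args) | |||
        traj.finalize() | |||
    return outputs | |||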
def compile_trace(trace): | |||
from .jit import ReadDTypeEvent, ReadDeviceEvent, ReadShapeEvent, OpEvent, DelEvent | |||
traj = ExecutionTrajectory() | |||
active_vars = set() | |||
for event in trace: | |||
if isinstance(event, ReadDTypeEvent): | |||
traj.append(ReadAttrAction()) |
@@ -0,0 +1,106 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import contextlib | |||
import copy | |||
from .core import Dispatcher, OpBase, TensorBase, apply | |||
class Tensor(TensorBase): | |||
def __init__(self, data: TensorBase): | |||
self._data = data | |||
# _extra_data is set up in Grad.wrt | |||
self._extra_data = {} | |||
self._user_data = {} | |||
def __getattr__(self, name): | |||
if name in self._user_data: | |||
return self._user_data[name] | |||
raise AttributeError(name) | |||
def reset(self, other): | |||
assert isinstance(other, __class__) | |||
self.__dict__.clear() | |||
        self._data = other._data | |||
self._extra_data = other._extra_data.copy() | |||
self._user_data = other._user_data.copy() | |||
def copy(self): | |||
other = object.__new__(type(self)) | |||
other.reset(self) | |||
return other | |||
# tensor interface | |||
@property | |||
def shape(self): | |||
return self._data.shape | |||
@property | |||
def dtype(self): | |||
return self._data.dtype | |||
@property | |||
def device(self): | |||
return self._data.device | |||
def numpy(self): | |||
return self._data.numpy() | |||
class ApplyContext: | |||
def __init__(self): | |||
self.inputs = None | |||
self.outputs = None | |||
self.key = None | |||
_context = None | |||
@contextlib.contextmanager | |||
def push_context(): | |||
global _context | |||
backup = _context | |||
try: | |||
_context = ApplyContext() | |||
yield _context | |||
finally: | |||
_context = backup | |||
def get_context(): | |||
return _context | |||
@apply.add | |||
def tensor_apply(op: OpBase, *args: Tensor): | |||
data = tuple(i._data if isinstance(i, Tensor) else i for i in args) | |||
# type(Tensor._data) is RawTensor | |||
    # dispatched to apply.add@RawTensor.py if passed Tensor args | |||
outputs = apply(op, *data) | |||
ret = tuple(map(Tensor, outputs)) | |||
with push_context() as ctx: | |||
ctx.inputs = args | |||
ctx.outputs = ret | |||
for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))): | |||
ctx.key = k | |||
data = tuple( | |||
i._extra_data.get(k) if isinstance(i, Tensor) else i for i in args | |||
) | |||
# data are instances of Tracer | |||
# dispatched to apply.add@grad.py | |||
outputs = apply(op, *data) | |||
if outputs is not None: | |||
assert len(outputs) == len(ret) | |||
for t, i in zip(ret, outputs): | |||
t._extra_data[k] = i | |||
return ret |
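# Illustrative sketch (an assumption, not part of the original file): Tensor is | |||
# a thin wrapper, so applying an op unwraps to RawTensor, re-dispatches, and | |||
# re-wraps the result; any _extra_data entries (e.g. grad Tracers) are | |||
# propagated through the second dispatch above. | |||
def _tensor_apply_example(op, raw_a, raw_b): | |||
    # raw_a / raw_b are assumed to be RawTensor instances | |||
    a, b = Tensor(raw_a), Tensor(raw_b) | |||
    outputs = apply(op, a, b)  # routed to tensor_apply above | |||
    return [t._data for t in outputs]  # the underlying RawTensor results | |||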
@@ -0,0 +1,367 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import abc | |||
import collections.abc | |||
import numpy as np | |||
from ..ops import builtin | |||
from ..ops.special import Const | |||
from . import utils | |||
from .core import OpBase, TensorBase, TensorWrapperBase, apply | |||
from .indexing import getitem as _getitem | |||
from .indexing import setitem as _setitem | |||
from .raw_tensor import RawTensor, as_raw_tensor | |||
from .tensor import Tensor | |||
def _elwise(*args, mode): | |||
op = builtin.Elemwise(mode=mode) | |||
args = utils.convert_inputs(*args) | |||
(result,) = apply(op, *args) | |||
return result | |||
def _matmul(inp1, inp2): | |||
op = builtin.MatrixMul( | |||
transposeA=False, transposeB=False, compute_mode="DEFAULT", format="DEFAULT" | |||
) | |||
inp1, inp2 = utils.convert_inputs(inp1, inp2) | |||
(result,) = apply(op, inp1, inp2) | |||
return result | |||
def _transpose(data, axes): | |||
op = builtin.Dimshuffle(axes) | |||
(data,) = utils.convert_inputs(data) | |||
(result,) = apply(op, data) | |||
return result | |||
def _broadcast(inp, shape): | |||
shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) | |||
(result,) = apply(builtin.Broadcast(), inp, shape) | |||
return result | |||
def _reshape(x, shape): | |||
if isinstance(shape, (TensorBase, TensorWrapperBase)): | |||
shape = shape.numpy() | |||
shape = tuple(map(int, shape)) | |||
unspec_axis = None | |||
for i, s in enumerate(shape): | |||
if s < 0: | |||
if s != -1: | |||
raise ValueError("expect shape[{}] >= -1, got {}".format(i, s)) | |||
if unspec_axis is not None: | |||
raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) | |||
unspec_axis = i | |||
# TODO: device should be None (cpu) | |||
(shape,) = Const(shape, dtype=np.int32, device=x.device)(x) | |||
if unspec_axis is None: | |||
op = builtin.Reshape() | |||
else: | |||
op = builtin.Reshape(unspec_axis=unspec_axis) | |||
(x,) = apply(op, x, shape) | |||
return x | |||
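# Illustrative sketch of the -1 resolution above in plain numpy (an assumption: | |||
# Reshape(unspec_axis=...) is taken to mirror numpy semantics; this helper is | |||
# hypothetical and only documents the arithmetic). | |||
def _resolve_unspec_axis(in_shape, out_shape): | |||
    known = int(np.prod([s for s in out_shape if s != -1])) | |||
    total = int(np.prod(in_shape)) | |||
    assert total % known == 0, "shape mismatch" | |||
    return tuple(total // known if s == -1 else s for s in out_shape) | |||
# _resolve_unspec_axis((2, 3, 4), (2, -1)) == (2, 12) | |||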
def _unary_elwise(mode): | |||
def f(self): | |||
return _elwise(self, mode=mode) | |||
return f | |||
def _binary_elwise(mode, rev=False): | |||
if not rev: | |||
def f(self, value): | |||
return _elwise(self, value, mode=mode) | |||
else: | |||
def f(self, value): | |||
return _elwise(value, self, mode=mode) | |||
return f | |||
def _logical_unary_elwise(mode, rev=False): | |||
def f(self): | |||
if self.dtype != np.bool_: | |||
raise TypeError("{} requires a bool tensor".format(mode)) | |||
return _elwise(self, mode=mode) | |||
return f | |||
def _logical_binary_elwise(mode, rev=False): | |||
if not rev: | |||
def f(self, value): | |||
if self.dtype != np.bool_ or value.dtype != np.bool_: | |||
raise TypeError("{} requires 2 bool tensors".format(mode)) | |||
return _elwise(self, value, mode=mode) | |||
else: | |||
def f(self, value): | |||
if self.dtype != np.bool_ or value.dtype != np.bool_: | |||
raise TypeError("{} requires 2 bool tensors".format(mode)) | |||
return _elwise(value, self, mode=mode) | |||
return f | |||
def _reduce(mode): | |||
def f(self, axis=None): | |||
inp = self | |||
if axis is None: | |||
inp = self.flatten() | |||
axis = 0 | |||
op = builtin.Reduce(mode=mode, axis=axis) | |||
(result,) = utils.convert_inputs(inp) | |||
(result,) = apply(op, result) | |||
return result | |||
return f | |||
def _inplace(f): | |||
def g(self, value): | |||
result = f(self, value) | |||
if result is NotImplemented: | |||
raise NotImplementedError | |||
self._reset(result) | |||
return self | |||
return g | |||
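# Illustrative sketch (the _Box class is hypothetical, not part of the original | |||
# file): _inplace turns an out-of-place operator into an augmented-assignment | |||
# method that keeps object identity by swapping the wrapper's payload via _reset. | |||
def _inplace_example(): | |||
    class _Box: | |||
        def __init__(self, v): | |||
            self.v = v | |||
        def _reset(self, other): | |||
            self.v = other.v | |||
        __add__ = lambda self, o: _Box(self.v + o) | |||
        __iadd__ = _inplace(__add__) | |||
    b = _Box(1) | |||
    alias = b | |||
    b += 2 | |||
    assert alias is b and b.v == 3  # identity preserved, value updated | |||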
def _todo(*_): | |||
raise NotImplementedError | |||
class ArrayMethodMixin(abc.ABC): | |||
__array_priority__ = 233333 | |||
@abc.abstractmethod | |||
def _reset(self, other): | |||
pass | |||
@abc.abstractproperty | |||
def dtype(self) -> np.dtype: | |||
pass | |||
@abc.abstractproperty | |||
def shape(self) -> tuple: | |||
pass | |||
@abc.abstractmethod | |||
def numpy(self) -> np.ndarray: | |||
pass | |||
    __hash__ = None  # because __eq__ deviates from the Python convention | |||
__lt__ = lambda self, value: _elwise(self, value, mode="LT").astype("bool") | |||
__le__ = lambda self, value: _elwise(self, value, mode="LEQ").astype("bool") | |||
__gt__ = lambda self, value: _elwise(value, self, mode="LT").astype("bool") | |||
__ge__ = lambda self, value: _elwise(value, self, mode="LEQ").astype("bool") | |||
__eq__ = lambda self, value: _elwise(self, value, mode="EQ").astype("bool") | |||
__ne__ = lambda self, value: _elwise( | |||
_elwise(self, value, mode="EQ").astype("bool"), mode="NOT" | |||
) | |||
__neg__ = _unary_elwise("NEGATE") | |||
__pos__ = lambda self: self | |||
__abs__ = _unary_elwise("ABS") | |||
__invert__ = _logical_unary_elwise("NOT") | |||
__round__ = _unary_elwise("ROUND") | |||
__trunc__ = _todo | |||
__floor__ = _unary_elwise("FLOOR") | |||
__ceil__ = _unary_elwise("CEIL") | |||
__add__ = _binary_elwise("ADD") | |||
__sub__ = _binary_elwise("SUB") | |||
__mul__ = _binary_elwise("MUL") | |||
__matmul__ = lambda self, other: _matmul(self, other) | |||
__truediv__ = _binary_elwise("TRUE_DIV") | |||
__floordiv__ = _binary_elwise("FLOOR_DIV") | |||
__mod__ = _binary_elwise("MOD") | |||
    # __divmod__ | |||
__pow__ = _binary_elwise("POW") | |||
__lshift__ = _binary_elwise("SHL") | |||
__rshift__ = _binary_elwise("SHR") | |||
__and__ = _logical_binary_elwise("AND") | |||
__or__ = _logical_binary_elwise("OR") | |||
__xor__ = _logical_binary_elwise("XOR") | |||
__radd__ = _binary_elwise("ADD", rev=1) | |||
__rsub__ = _binary_elwise("SUB", rev=1) | |||
__rmul__ = _binary_elwise("MUL", rev=1) | |||
__rmatmul__ = lambda self, other: _matmul(other, self) | |||
__rtruediv__ = _binary_elwise("TRUE_DIV", rev=1) | |||
__rfloordiv__ = _binary_elwise("FLOOR_DIV", rev=1) | |||
__rmod__ = _binary_elwise("MOD", rev=1) | |||
    # __rdivmod__ | |||
__rpow__ = _binary_elwise("POW", rev=1) | |||
__rlshift__ = _binary_elwise("SHL", rev=1) | |||
__rrshift__ = _binary_elwise("SHR", rev=1) | |||
__rand__ = _logical_binary_elwise("AND", rev=1) | |||
__ror__ = _logical_binary_elwise("OR", rev=1) | |||
__rxor__ = _logical_binary_elwise("XOR", rev=1) | |||
__iadd__ = _inplace(__add__) | |||
__isub__ = _inplace(__sub__) | |||
__imul__ = _inplace(__mul__) | |||
__imatmul__ = _inplace(__matmul__) | |||
__itruediv__ = _inplace(__truediv__) | |||
__ifloordiv__ = _inplace(__floordiv__) | |||
__imod__ = _inplace(__mod__) | |||
__ipow__ = _inplace(__pow__) | |||
__ilshift__ = _inplace(__lshift__) | |||
__irshift__ = _inplace(__rshift__) | |||
__iand__ = _inplace(__and__) | |||
__ior__ = _inplace(__or__) | |||
__ixor__ = _inplace(__xor__) | |||
__index__ = lambda self: self.item().__index__() | |||
__bool__ = lambda self: bool(self.item()) | |||
__int__ = lambda self: int(self.item()) | |||
__float__ = lambda self: float(self.item()) | |||
__complex__ = lambda self: complex(self.item()) | |||
def __len__(self): | |||
shape = self.shape | |||
if shape: | |||
return int(shape[0]) | |||
raise TypeError("ndim is 0") | |||
def __iter__(self): | |||
for i in range(len(self)): | |||
yield self[i] | |||
def __getitem__(self, index): | |||
return _getitem(self, index) | |||
def __setitem__(self, index, value): | |||
if index is not Ellipsis: | |||
value = _setitem(self, index, value) | |||
self._reset(value) | |||
__contains__ = _todo | |||
@property | |||
def ndim(self): | |||
return len(self.shape) | |||
@property | |||
def size(self): | |||
return np.prod(self.shape).item() | |||
@property | |||
def T(self): | |||
return self.transpose() | |||
def item(self, *args): | |||
if not args: | |||
assert self.size == 1 | |||
return self.numpy().item() | |||
return self[args].item() | |||
def tolist(self): | |||
return self.numpy().tolist() | |||
def astype(self, dtype): | |||
return utils.astype(self, dtype) | |||
def reshape(self, *args): | |||
if len(args) == 1: | |||
            if isinstance(args[0], collections.abc.Sequence): | |||
args = args[0] | |||
return _reshape(self, args) | |||
def broadcast(self, *args): | |||
if len(args) == 1: | |||
            if isinstance(args[0], collections.abc.Sequence): | |||
args = args[0] | |||
return _broadcast(self, args) | |||
def transpose(self, *args): | |||
if not args: | |||
args = reversed(range(self.ndim)) | |||
elif len(args) == 1: | |||
            if isinstance(args[0], collections.abc.Sequence): | |||
args = args[0] | |||
return _transpose(self, args) | |||
def flatten(self): | |||
return self.reshape(-1) | |||
sum = _reduce("SUM") | |||
prod = _reduce("PRODUCT") | |||
min = _reduce("MIN") | |||
max = _reduce("MAX") | |||
mean = _reduce("MEAN") | |||
class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase): | |||
def __init__(self, data): | |||
self.__wrapped__ = data | |||
def _reset(self, other): | |||
if not isinstance(other, __class__): | |||
raise TypeError(type(other)) | |||
self.__wrapped__ = other.__wrapped__ | |||
return self | |||
@property | |||
def dtype(self): | |||
return self.__wrapped__.dtype | |||
@property | |||
def shape(self): | |||
return self.__wrapped__.shape | |||
@property | |||
def device(self): | |||
return self.__wrapped__.device | |||
def numpy(self): | |||
return self.__wrapped__.numpy() | |||
class TensorWrapper(GenericTensorWrapper): | |||
def __init__(self, data, dtype=None, device=None): | |||
if isinstance(data, TensorWrapperBase): | |||
data = data.__wrapped__ | |||
elif not isinstance(data, TensorBase): | |||
assert data is not None, "Cannot init a tensor with data as None" | |||
data = Tensor(as_raw_tensor(data, dtype=dtype, device=device)) | |||
super().__init__(data) | |||
def _reset(self, other): | |||
if isinstance(other, TensorWrapperBase): | |||
self.__wrapped__ = other.__wrapped__ | |||
elif isinstance(other, TensorBase): | |||
self.__wrapped__ = other | |||
else: | |||
self._reset(type(self)(other, dtype=self.dtype, device=self.device)) | |||
def __repr__(self): | |||
piece = "Tensor(" | |||
with np.printoptions(precision=4, suppress=True): | |||
piece += "{}".format(str(self.numpy())) | |||
if self.dtype != np.float32: | |||
piece += ", dtype={}".format(np.dtype(self.dtype).name) | |||
piece += ", device={}".format(self.device) + ")" | |||
return piece |
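# Illustrative usage sketch (an assumption: as_raw_tensor picks a default | |||
# device when none is given): | |||
def _wrapper_example(): | |||
    a = TensorWrapper(np.arange(6, dtype="float32").reshape(2, 3)) | |||
    b = (a + 1).transpose()  # Elemwise ADD, then Dimshuffle | |||
    return b.numpy()  # a (3, 2) float32 ndarray | |||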
@@ -0,0 +1,154 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections.abc | |||
from typing import Iterable, Union | |||
import numpy as np | |||
from ..ops import builtin | |||
from ..ops.special import Const | |||
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||
def dtype_promotion(raw_inputs): | |||
def add_dtype(i): | |||
if type(i) == int: | |||
return np.array(i, dtype=np.int32) | |||
if type(i) == float: | |||
return np.array(i, dtype=np.float32) | |||
if type(i) == bool: | |||
return np.array(i, dtype=np.bool_) | |||
return None | |||
scalar_inputs = [ | |||
        add_dtype(i) for i in raw_inputs if not hasattr(i, "dtype") and add_dtype(i) is not None | |||
] | |||
inputs = [i for i in raw_inputs if hasattr(i, "dtype")] | |||
assert len(scalar_inputs + inputs) > 0 | |||
dtype = np.result_type(*inputs) | |||
dtype_all = np.result_type(*(inputs + scalar_inputs)) | |||
assert ( | |||
dtype != np.float64 and dtype != np.int64 | |||
), "unsupport dtype {} by dtype_promotion, please use explict type convert".format( | |||
dtype | |||
) | |||
if dtype_all == np.bool_: | |||
for i in raw_inputs: | |||
if not hasattr(i, "dtype") or i.dtype != np.bool_: | |||
raise TypeError( | |||
"bool dtype can not be operated with an element without bool dtype" | |||
) | |||
if dtype_all == np.float64: | |||
dtype_all = np.float32 | |||
return dtype_all | |||
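# Illustrative sketch of the rules above (hypothetical checks; the int case | |||
# relies on numpy 1.x value-based promotion): | |||
def _promotion_example(): | |||
    x = np.zeros(3, dtype=np.float32) | |||
    assert dtype_promotion([x, 1]) == np.float32  # python int becomes an int32 scalar | |||
    assert dtype_promotion([x, 1.5]) == np.float32  # python float is mapped to float32 up front | |||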
def get_device(inputs): | |||
device = None | |||
for i in inputs: | |||
if isinstance(i, (TensorWrapperBase, TensorBase)): | |||
if device is None: | |||
device = i.device | |||
elif device != i.device: | |||
raise ValueError("ambiguous device: {} vs {}".format(device, i.device)) | |||
assert device is not None | |||
return device | |||
def concatenate(inputs, axis=0, *, device=None): | |||
dtype = dtype_promotion(inputs) | |||
device = get_device(inputs) | |||
def convert(x): | |||
return convert_single_value(x, inputs, dtype=dtype) | |||
inputs = tuple(map(convert, inputs)) | |||
(result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inputs) | |||
return result | |||
def astype(x, dtype): | |||
dtype = np.dtype(dtype) | |||
if x.dtype != dtype: | |||
(x,) = apply(builtin.TypeCvt(param=dtype), x) | |||
return x | |||
def convert_single_value(v, inputs, *, dtype=None, device=None): | |||
tensors = [i for i in inputs if isinstance(i, (TensorBase, TensorWrapperBase))] | |||
assert len(tensors) > 0 | |||
if isinstance(v, (TensorWrapperBase, TensorBase)): | |||
v = astype(v, dtype) | |||
else: | |||
(v,) = Const(v, dtype=dtype, device=device)(*tensors) | |||
return v | |||
def convert_inputs(*args: TensorBase): | |||
dtype = dtype_promotion(args) | |||
device = get_device(args) | |||
def convert(value): | |||
if value is None: | |||
return value | |||
return convert_single_value(value, args, dtype=dtype, device=device) | |||
return tuple(map(convert, args)) | |||
def result_type(*args): | |||
dtypes = [] | |||
for i in args: | |||
if isinstance(i, (TensorWrapperBase, TensorBase)): | |||
dtypes.append(i.dtype) | |||
continue | |||
try: | |||
dtypes.append(np.dtype(i)) | |||
except TypeError: | |||
pass | |||
return np.result_type(*dtypes) | |||
def isscalar(x): | |||
try: | |||
return x.ndim == 0 | |||
    except AttributeError: | |||
pass | |||
return np.isscalar(x) | |||
def astensor1d(x, *reference, dtype=None, device=None): | |||
""" | |||
    Convert something to a 1D tensor. The following types are supported: | |||
* sequence of scalar literal / tensor | |||
* numpy array | |||
* tensor (returned as is, regardless of dtype and device) | |||
""" | |||
try: | |||
ndim = x.ndim | |||
except AttributeError: | |||
pass | |||
else: | |||
if ndim != 1: | |||
raise ValueError("ndim != 1: %d" % ndim) | |||
if not isinstance(x, (TensorBase, TensorWrapperBase)): | |||
(x,) = Const(x, dtype=dtype, device=device)(*reference) | |||
return x | |||
    if not isinstance(x, collections.abc.Sequence): | |||
raise TypeError | |||
if any(isinstance(i, (TensorBase, TensorWrapperBase)) for i in x): | |||
x = concatenate(x, device=device) | |||
if dtype is not None: | |||
x = astype(x, dtype) | |||
return x | |||
(x,) = Const(x, dtype=dtype, device=device)(*reference) | |||
return x |
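# Illustrative usage sketch (`ref` is a hypothetical existing tensor, used only | |||
# to infer the target device for the constant): | |||
def _astensor1d_example(ref): | |||
    return astensor1d((2, 3), ref, dtype="int32")  # python sequence -> 1D int32 tensor | |||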
@@ -0,0 +1,17 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .collator import Collator | |||
from .dataloader import DataLoader | |||
from .sampler import ( | |||
Infinite, | |||
RandomSampler, | |||
ReplacementSampler, | |||
Sampler, | |||
SequentialSampler, | |||
) |
@@ -0,0 +1,139 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import binascii | |||
import os | |||
import queue | |||
import subprocess | |||
from multiprocessing import Queue | |||
import pyarrow | |||
import pyarrow.plasma as plasma | |||
MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000)) # 4GB | |||
# Each process only needs to start one plasma store, so we set it as a global variable. | |||
# TODO: how to share between different processes? | |||
MGE_PLASMA_STORE_MANAGER = None | |||
def _clear_plasma_store(): | |||
    # `_PlasmaStoreManager.__del__` will not be called automatically in subprocesses, | |||
# so this function should be called explicitly | |||
global MGE_PLASMA_STORE_MANAGER | |||
if MGE_PLASMA_STORE_MANAGER is not None: | |||
del MGE_PLASMA_STORE_MANAGER | |||
MGE_PLASMA_STORE_MANAGER = None | |||
class _PlasmaStoreManager: | |||
__initialized = False | |||
def __init__(self): | |||
self.socket_name = "/tmp/mge_plasma_{}".format( | |||
binascii.hexlify(os.urandom(8)).decode() | |||
) | |||
debug_flag = bool(os.environ.get("MGE_DATALOADER_PLASMA_DEBUG", 0)) | |||
        # NOTE: this is a hack. Using `plasma_store` directly may make it hard | |||
        # for the subprocess to handle exceptions raised in `plasma-store-server`, | |||
        # because `plasma_store` is just a wrapper that calls the executable | |||
        # `plasma-store-server` via `os.execv`. | |||
cmd_path = os.path.join(pyarrow.__path__[0], "plasma-store-server") | |||
self.plasma_store = subprocess.Popen( | |||
[cmd_path, "-s", self.socket_name, "-m", str(MGE_PLASMA_MEMORY),], | |||
stdout=None if debug_flag else subprocess.DEVNULL, | |||
stderr=None if debug_flag else subprocess.DEVNULL, | |||
) | |||
self.__initialized = True | |||
def __del__(self): | |||
if self.__initialized and self.plasma_store.returncode is None: | |||
self.plasma_store.kill() | |||
class PlasmaShmQueue: | |||
def __init__(self, maxsize: int = 0): | |||
r"""Use pyarrow in-memory plasma store to implement shared memory queue. | |||
        Compared to the native `multiprocessing.Queue`, `PlasmaShmQueue` avoids | |||
        pickle/unpickle and communication overhead, leading to better performance | |||
        in multi-process applications. | |||
:type maxsize: int | |||
        :param maxsize: maximum size of the queue; ``0`` means no limit. (default: ``0``) | |||
""" | |||
# Lazy start the plasma store manager | |||
global MGE_PLASMA_STORE_MANAGER | |||
if MGE_PLASMA_STORE_MANAGER is None: | |||
try: | |||
MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager() | |||
except Exception as e: | |||
err_info = ( | |||
"Please make sure pyarrow installed correctly!\n" | |||
"You can try reinstall pyarrow and see if you can run " | |||
"`plasma_store -s /tmp/mge_plasma_xxx -m 1000` normally." | |||
) | |||
raise RuntimeError( | |||
"Exception happened in starting plasma_store: {}\n" | |||
"Tips: {}".format(str(e), err_info) | |||
) | |||
self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name | |||
# TODO: how to catch the exception happened in `plasma.connect`? | |||
self.client = None | |||
        # Used to store the headers (ObjectIDs) for the data. | |||
self.queue = Queue(maxsize) # type: Queue | |||
def put(self, data, block=True, timeout=None): | |||
if self.client is None: | |||
self.client = plasma.connect(self.socket_name) | |||
try: | |||
object_id = self.client.put(data) | |||
except plasma.PlasmaStoreFull: | |||
raise RuntimeError("plasma store out of memory!") | |||
try: | |||
self.queue.put(object_id, block, timeout) | |||
except queue.Full: | |||
self.client.delete([object_id]) | |||
raise queue.Full | |||
def get(self, block=True, timeout=None): | |||
if self.client is None: | |||
self.client = plasma.connect(self.socket_name) | |||
object_id = self.queue.get(block, timeout) | |||
if not self.client.contains(object_id): | |||
raise RuntimeError( | |||
"ObjectID: {} not found in plasma store".format(object_id) | |||
) | |||
data = self.client.get(object_id) | |||
self.client.delete([object_id]) | |||
return data | |||
def qsize(self): | |||
return self.queue.qsize() | |||
def empty(self): | |||
return self.queue.empty() | |||
def join(self): | |||
self.queue.join() | |||
def disconnect_client(self): | |||
if self.client is not None: | |||
self.client.disconnect() | |||
def close(self): | |||
self.queue.close() | |||
self.disconnect_client() | |||
_clear_plasma_store() | |||
def cancel_join_thread(self): | |||
self.queue.cancel_join_thread() |
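# Illustrative usage sketch (assumes a working pyarrow plasma installation): | |||
def _queue_example(): | |||
    q = PlasmaShmQueue(maxsize=2) | |||
    q.put([0, 1, 2, 3])  # payload lives in the plasma shared-memory store | |||
    out = q.get()  # only the ObjectID travels through the mp.Queue | |||
    q.close() | |||
    return out | |||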
@@ -0,0 +1,76 @@ | |||
# -*- coding: utf-8 -*- | |||
# Copyright (c) 2016- Facebook, Inc (Adam Paszke) | |||
# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) | |||
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) | |||
# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) | |||
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) | |||
# Copyright (c) 2011-2013 NYU (Clement Farabet) | |||
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) | |||
# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) | |||
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) | |||
# --------------------------------------------------------------------- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# | |||
# This file has been modified by Megvii ("Megvii Modifications"). | |||
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
# ---------------------------------------------------------------------- | |||
import collections.abc | |||
import re | |||
import numpy as np | |||
np_str_obj_array_pattern = re.compile(r"[aO]") | |||
default_collate_err_msg_format = ( | |||
"default_collator: inputs must contain numpy arrays, numbers, " | |||
"Unicode strings, bytes, dicts or lists; found {}" | |||
) | |||
class Collator: | |||
r""" | |||
    Used to merge a list of samples into a mini-batch of Tensor(s), when doing batched loading from a dataset. | |||
modified from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py | |||
""" | |||
def apply(self, inputs): | |||
""" | |||
input : sequence_N(tuple(CHW, C, CK)) | |||
output : tuple(NCHW, NC, NCK) | |||
""" | |||
elem = inputs[0] | |||
elem_type = type(elem) | |||
if ( | |||
elem_type.__module__ == "numpy" | |||
and elem_type.__name__ != "str_" | |||
and elem_type.__name__ != "string_" | |||
): | |||
if elem_type.__name__ == "ndarray": | |||
# array of string classes and object | |||
if np_str_obj_array_pattern.search(elem.dtype.str) is not None: | |||
raise TypeError(default_collate_err_msg_format.format(elem.dtype)) | |||
return np.ascontiguousarray(np.stack(inputs)) | |||
elif elem.shape == (): # scalars | |||
return np.array(inputs) | |||
elif isinstance(elem, float): | |||
return np.array(inputs, dtype=np.float64) | |||
elif isinstance(elem, int): | |||
return np.array(inputs) | |||
elif isinstance(elem, (str, bytes)): | |||
return inputs | |||
elif isinstance(elem, collections.abc.Mapping): | |||
return {key: self.apply([d[key] for d in inputs]) for key in elem} | |||
elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple | |||
return elem_type(*(self.apply(samples) for samples in zip(*inputs))) | |||
elif isinstance(elem, collections.abc.Sequence): | |||
transposed = zip(*inputs) | |||
return [self.apply(samples) for samples in transposed] | |||
raise TypeError(default_collate_err_msg_format.format(elem_type)) |
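# Illustrative sketch: stacking two (image, label) samples into a batch. | |||
def _collator_example(): | |||
    samples = [(np.zeros((3, 4, 4)), 0), (np.ones((3, 4, 4)), 1)] | |||
    images, labels = Collator().apply(samples) | |||
    return images.shape, labels  # (2, 3, 4, 4) and array([0, 1]) | |||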
@@ -0,0 +1,500 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
import math | |||
import multiprocessing | |||
import queue | |||
import random | |||
import time | |||
import numpy as np | |||
from ..logger import get_logger | |||
from ..random.rng import _random_seed_generator | |||
from .collator import Collator | |||
from .dataset import Dataset | |||
from .sampler import Sampler, SequentialSampler | |||
from .transform import PseudoTransform, Transform | |||
logger = get_logger(__name__) | |||
MP_QUEUE_GET_TIMEOUT = 5 | |||
class DataLoader: | |||
__initialized = False | |||
def __init__( | |||
self, | |||
dataset: Dataset, | |||
sampler: Sampler = None, | |||
transform: Transform = None, | |||
collator: Collator = None, | |||
num_workers: int = 0, | |||
timeout: int = 0, | |||
divide: bool = False, | |||
): | |||
r"""Provides a convenient way to iterate on a given dataset. | |||
`DataLoader` combines a dataset with sampler, transform and collator, | |||
        making it flexible to load minibatches continually from a dataset. | |||
:type dataset: Dataset | |||
:param dataset: dataset from which to load the minibatch. | |||
:type sampler: Sampler | |||
        :param sampler: defines the strategy to sample data from the dataset. | |||
            If ``None``, a :class:`SequentialSampler` with batch size 1 will be used. (default: ``None``) | |||
:type transform: Transform | |||
        :param transform: defines the transform strategy for a sampled batch. | |||
(default: ``None``) | |||
:type collator: Collator | |||
        :param collator: defines the merging strategy for a transformed batch. | |||
(default: ``None``) | |||
:type num_workers: int | |||
:param num_workers: the number of sub-process to load, transform and collate | |||
the batch. ``0`` means using single-process. (default: ``0``) | |||
:type timeout: int | |||
:param timeout: if positive, means the timeout value(second) for collecting a | |||
batch from workers. (default: 0) | |||
:type divide: bool | |||
        :param divide: defines the parallelization strategy in multi-processing mode. | |||
            ``True`` means one batch is divided into :attr:`num_workers` pieces, and | |||
            the workers process these pieces in parallel. ``False`` means | |||
            different sub-processes process different batches. (default: ``False``) | |||
""" | |||
if num_workers < 0: | |||
raise ValueError("num_workers should not be negative") | |||
if timeout < 0: | |||
raise ValueError("timeout should not be negative") | |||
if divide and num_workers <= 1: | |||
raise ValueError("divide should not be set to True when num_workers <= 1") | |||
self.dataset = dataset | |||
self.num_workers = num_workers | |||
self.timeout = timeout | |||
self.divide = divide | |||
if sampler is None: | |||
self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False) | |||
else: | |||
self.sampler = sampler | |||
if divide: | |||
if self.sampler.batch_size <= self.num_workers: | |||
raise ValueError( | |||
"batch size must not smaller than num_workers in divide mode." | |||
) | |||
elif self.sampler.batch_size % self.num_workers: | |||
logger.warning( | |||
"batch size is not divisible by num_workers, may lose performance in divide mode." | |||
) | |||
if transform is None: | |||
self.transform = PseudoTransform() | |||
else: | |||
self.transform = transform | |||
if collator is None: | |||
self.collator = Collator() | |||
else: | |||
self.collator = collator | |||
self.__initialized = True | |||
def __iter__(self): | |||
if self.num_workers == 0: | |||
return _SerialDataLoaderIter(self) | |||
else: | |||
return _ParallelDataLoaderIter(self) | |||
def __len__(self): | |||
return len(self.sampler) | |||
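# Illustrative usage sketch (assumes the ArrayDataset and RandomSampler | |||
# shipped elsewhere in this package, with the signatures used below): | |||
def _dataloader_example(): | |||
    from .dataset import ArrayDataset | |||
    from .sampler import RandomSampler | |||
    data = np.random.rand(100, 3, 8, 8).astype("float32") | |||
    labels = np.random.randint(0, 10, 100).astype("int32") | |||
    dataset = ArrayDataset(data, labels) | |||
    sampler = RandomSampler(dataset, batch_size=16) | |||
    loader = DataLoader(dataset, sampler=sampler, num_workers=0) | |||
    for images, batch_labels in loader: | |||
        assert images.shape == (16, 3, 8, 8)  # first batch is always full | |||
        break | |||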
class _BaseDataLoaderIter: | |||
def __init__(self, loader): | |||
self.dataset = loader.dataset | |||
self.sampler = loader.sampler | |||
self.seed = _random_seed_generator().__next__() | |||
self.transform = loader.transform | |||
self.collator = loader.collator | |||
self.num_workers = loader.num_workers | |||
self.timeout = loader.timeout | |||
self.divide = loader.divide | |||
self.num_processed = 0 | |||
def _get_next_batch(self): | |||
raise NotImplementedError | |||
def __len__(self): | |||
return len(self.sampler) | |||
def __iter__(self): | |||
return self | |||
def __next__(self): | |||
if self.num_processed >= len(self): | |||
raise StopIteration | |||
minibatch = self._get_next_batch() | |||
self.num_processed += 1 | |||
return minibatch | |||
class _SerialDataLoaderIter(_BaseDataLoaderIter): | |||
def __init__(self, loader): | |||
super(_SerialDataLoaderIter, self).__init__(loader) | |||
self.indices_iter = iter(self.sampler) | |||
def _get_next_batch(self): | |||
indices = next(self.indices_iter) | |||
items = [self.dataset[idx] for idx in indices] | |||
trans_items = self.transform.apply_batch(items) | |||
return self.collator.apply(trans_items) | |||
class _ParallelDataLoaderIter(_BaseDataLoaderIter): | |||
__initialized = False | |||
def __init__(self, loader): | |||
super(_ParallelDataLoaderIter, self).__init__(loader) | |||
self.task_queues = [ | |||
multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers) | |||
] | |||
self.feed_batch_idx = multiprocessing.Value("i", 0) | |||
self.target_batch_idx = multiprocessing.Value("i", 0) | |||
self.shutdown_flag = multiprocessing.Value("i", 0) | |||
self.trans_data_queues = [ | |||
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) | |||
] | |||
# use shared-memory queue implemented by pyarrow plasma store. | |||
from ._queue import PlasmaShmQueue | |||
self.batch_queue = PlasmaShmQueue(maxsize=2) | |||
self.task_feeding_worker = multiprocessing.Process( | |||
target=_task_feeding_loop, | |||
args=( | |||
iter(self.sampler), | |||
self.task_queues, | |||
self.num_workers, | |||
self.divide, | |||
self.shutdown_flag, | |||
self.feed_batch_idx, | |||
), | |||
daemon=True, | |||
) | |||
self.task_feeding_worker.start() | |||
self.workers = [] | |||
for worker_id in range(self.num_workers): | |||
worker = multiprocessing.Process( | |||
target=_worker_loop, | |||
args=( | |||
self.dataset, | |||
self.task_queues[worker_id], | |||
self.trans_data_queues[worker_id], | |||
self.transform, | |||
self.seed + worker_id + 1, | |||
self.shutdown_flag, | |||
), | |||
daemon=True, | |||
) | |||
worker.start() | |||
self.workers.append(worker) | |||
if self.divide: | |||
self.data_collecting_worker = multiprocessing.Process( | |||
target=_data_gathering_loop, | |||
args=( | |||
self.trans_data_queues, | |||
self.batch_queue, | |||
self.collator, | |||
len(self), | |||
self.num_workers, | |||
self.shutdown_flag, | |||
self.target_batch_idx, | |||
), | |||
daemon=True, | |||
) | |||
else: | |||
self.data_collecting_worker = multiprocessing.Process( | |||
target=_data_selecting_loop, | |||
args=( | |||
self.trans_data_queues, | |||
self.batch_queue, | |||
self.collator, | |||
len(self), | |||
self.num_workers, | |||
self.shutdown_flag, | |||
self.target_batch_idx, | |||
), | |||
daemon=True, | |||
) | |||
self.data_collecting_worker.start() | |||
self.__initialized = True | |||
def _check_workers(self): | |||
# Check the status of each worker. | |||
if not self.data_collecting_worker.is_alive(): | |||
            exitcode = self.data_collecting_worker.exitcode | |||
if exitcode != 0: | |||
raise RuntimeError("data collecting worker died. {}".format(exitcode)) | |||
if not self.task_feeding_worker.is_alive(): | |||
exitcode = self.task_feeding_worker.exitcode | |||
if exitcode != 0: | |||
raise RuntimeError("task feeding worker died. {}".format(exitcode)) | |||
for worker_id, worker in enumerate(self.workers): | |||
if not worker.is_alive(): | |||
exitcode = worker.exitcode | |||
if exitcode != 0: | |||
raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode)) | |||
logger.debug("all workers are alive.") | |||
def _try_get_next_batch(self): | |||
start_time = time.time() | |||
while True: | |||
self._check_workers() | |||
try: | |||
return self.batch_queue.get(timeout=1) | |||
except queue.Empty: | |||
logger.debug("batch queue empty!") | |||
waited_time = time.time() - start_time | |||
if self.timeout > 0: | |||
if waited_time > self.timeout: | |||
raise RuntimeError("get_next_batch timeout!") | |||
def _get_next_batch(self): | |||
batch_data = self._try_get_next_batch() | |||
return batch_data | |||
def _shutdown(self): | |||
with self.shutdown_flag.get_lock(): | |||
self.shutdown_flag.value = 1 | |||
if self.task_feeding_worker.is_alive(): | |||
self.task_feeding_worker.terminate() | |||
self.task_feeding_worker.join() | |||
if self.data_collecting_worker.is_alive(): | |||
self.data_collecting_worker.terminate() | |||
self.data_collecting_worker.join() | |||
for worker in self.workers: | |||
if worker.is_alive(): | |||
worker.terminate() | |||
worker.join() | |||
for q in self.trans_data_queues: | |||
q.cancel_join_thread() | |||
q.close() | |||
for q in self.task_queues: | |||
q.cancel_join_thread() | |||
q.close() | |||
self.batch_queue.cancel_join_thread() | |||
self.batch_queue.close() | |||
def __del__(self): | |||
if self.__initialized: | |||
self._shutdown() | |||
def _task_feeding_loop( | |||
indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx | |||
): | |||
# Feed the indices into the task queues | |||
while True: | |||
if shutdown_flag.value == 1: | |||
break | |||
batch_idx = feed_batch_idx.value | |||
try: | |||
indices = next(indices_iter) | |||
except StopIteration: | |||
break | |||
if divide: | |||
            # make sure all task_queues are ready for put | |||
while any([q.full() for q in task_queues]): | |||
if shutdown_flag.value == 1: | |||
return | |||
# divide into small pieces, feed to different workers. | |||
sub_num = math.ceil(len(indices) / num_workers) | |||
for worker_id in range(num_workers): | |||
sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num] | |||
task_queues[worker_id].put((batch_idx, sub_indices)) | |||
else: | |||
# distribute tasks to different workers uniformly. | |||
target_id = batch_idx % num_workers | |||
while task_queues[target_id].full(): | |||
if shutdown_flag.value == 1: | |||
return | |||
task_queues[target_id].put((batch_idx, indices)) | |||
with feed_batch_idx.get_lock(): | |||
feed_batch_idx.value += 1 | |||
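# Illustrative sketch of the divide-mode split above: one batch of indices is | |||
# chopped into ceil(len / num_workers) contiguous pieces, one per worker. | |||
def _divide_example(): | |||
    indices, num_workers = list(range(10)), 3 | |||
    sub_num = math.ceil(len(indices) / num_workers) | |||
    parts = [indices[i * sub_num : (i + 1) * sub_num] for i in range(num_workers)] | |||
    assert parts == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] | |||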
def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag): | |||
# Get dataset items and do the transform | |||
random.seed(seed) | |||
np.random.seed(seed) | |||
while True: | |||
if shutdown_flag.value == 1: | |||
break | |||
try: | |||
batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT) | |||
except queue.Empty: | |||
continue | |||
if len(indices) > 0: | |||
items = [dataset[idx] for idx in indices] | |||
trans_items = transform.apply_batch(items) | |||
else: | |||
# in case of incomplete last batch | |||
trans_items = () | |||
while True: | |||
try: | |||
trans_data_queue.put((batch_idx, trans_items), timeout=1) | |||
break | |||
except queue.Full: | |||
if shutdown_flag.value == 1: | |||
break | |||
logger.debug("batch part queue is full!") | |||
def _data_gathering_loop( | |||
trans_data_queues, | |||
batch_queue, | |||
collator, | |||
length, | |||
num_workers, | |||
shutdown_flag, | |||
target_idx, | |||
): | |||
# Gathering the small pieces of batch data into full batch data | |||
while True: | |||
if shutdown_flag.value == 1: | |||
break | |||
target_batch_idx = target_idx.value | |||
if target_batch_idx >= length: | |||
break | |||
full_trans_items = [] | |||
for worker_id in range(num_workers): | |||
while True: | |||
try: | |||
batch_idx, trans_items = trans_data_queues[worker_id].get( | |||
timeout=MP_QUEUE_GET_TIMEOUT | |||
) | |||
break | |||
except queue.Empty: | |||
if shutdown_flag.value == 1: | |||
break | |||
logger.debug( | |||
"worker:{} data queue get timeout! target batch idx:{}".format( | |||
worker_id, target_batch_idx | |||
) | |||
) | |||
if batch_idx != target_batch_idx: | |||
raise RuntimeError( | |||
"Unexperted batch_idx in data gathering loop. worker_id:{}.".format( | |||
worker_id | |||
) | |||
) | |||
else: | |||
full_trans_items.extend(trans_items) | |||
# Merge different parts into a batch. | |||
full_batch = collator.apply(full_trans_items) | |||
while True: | |||
try: | |||
batch_queue.put(full_batch, timeout=1) | |||
break | |||
except queue.Full: | |||
if shutdown_flag.value == 1: | |||
break | |||
logger.debug("batch queue is full!") | |||
with target_idx.get_lock(): | |||
target_idx.value += 1 | |||
batch_queue.disconnect_client() | |||
def _data_selecting_loop( | |||
trans_data_queues, | |||
batch_queue, | |||
collator, | |||
length, | |||
num_workers, | |||
shutdown_flag, | |||
target_idx, | |||
): | |||
# Make sure that batch is generated exactly with the same order as generated indices | |||
while True: | |||
if shutdown_flag.value == 1: | |||
break | |||
target_batch_idx = target_idx.value | |||
if target_batch_idx >= length: | |||
break | |||
target_worker_id = target_batch_idx % num_workers | |||
while True: | |||
try: | |||
batch_idx, trans_items = trans_data_queues[target_worker_id].get( | |||
timeout=MP_QUEUE_GET_TIMEOUT | |||
) | |||
batch_data = collator.apply(trans_items) | |||
break | |||
except queue.Empty: | |||
if shutdown_flag.value == 1: | |||
break | |||
logger.debug( | |||
"worker:{} data queue get timeout! target batch idx:{}".format( | |||
target_worker_id, target_batch_idx | |||
) | |||
) | |||
if batch_idx != target_batch_idx: | |||
raise RuntimeError( | |||
"batch_idx {} mismatch the target_batch_idx {}".format( | |||
batch_idx, target_batch_idx | |||
) | |||
) | |||
while True: | |||
try: | |||
batch_queue.put(batch_data, timeout=1) | |||
break | |||
except queue.Full: | |||
if shutdown_flag.value == 1: | |||
break | |||
logger.debug("batch queue is full!") | |||
with target_idx.get_lock(): | |||
target_idx.value += 1 | |||
batch_queue.disconnect_client() |
@@ -0,0 +1,10 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .meta_dataset import ArrayDataset, Dataset, MapDataset, StreamDataset | |||
from .vision import * |
@@ -0,0 +1,73 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from abc import ABC, abstractmethod | |||
from typing import Tuple | |||
class Dataset(ABC): | |||
r""" | |||
An abstract class for all Datasets | |||
""" | |||
@abstractmethod | |||
def __init__(self): | |||
pass | |||
class MapDataset(Dataset): | |||
r""" | |||
    An abstract class for map-style data. | |||
    __getitem__ and __len__ methods are additionally needed. | |||
""" | |||
@abstractmethod | |||
def __init__(self): | |||
pass | |||
@abstractmethod | |||
def __getitem__(self, index): | |||
pass | |||
@abstractmethod | |||
def __len__(self): | |||
pass | |||
class StreamDataset(Dataset): | |||
r""" | |||
An abstract class for stream data | |||
    __iter__ method is additionally needed. | |||
""" | |||
@abstractmethod | |||
def __init__(self): | |||
pass | |||
@abstractmethod | |||
def __iter__(self): | |||
pass | |||
class ArrayDataset(MapDataset): | |||
def __init__(self, *arrays): | |||
r""" | |||
        ArrayDataset is a dataset for numpy array data; one or more numpy arrays | |||
        are needed to initialize the dataset, and the dimension representing the | |||
        sample number is expected to be the same across all arrays. | |||
""" | |||
super().__init__() | |||
if not all(len(arrays[0]) == len(array) for array in arrays): | |||
raise ValueError("lengths of input arrays are inconsistent") | |||
self.arrays = arrays | |||
def __getitem__(self, index: int) -> Tuple: | |||
return tuple(array[index] for array in self.arrays) | |||
def __len__(self) -> int: | |||
return len(self.arrays[0]) |
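# Illustrative usage sketch: | |||
def _array_dataset_example(): | |||
    import numpy as np | |||
    data = np.random.rand(10, 3).astype("float32") | |||
    labels = np.arange(10, dtype="int32") | |||
    ds = ArrayDataset(data, labels) | |||
    sample, label = ds[0]  # one slice from each array | |||
    assert len(ds) == 10 | |||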
@@ -0,0 +1,17 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .cifar import CIFAR10, CIFAR100 | |||
from .cityscapes import Cityscapes | |||
from .coco import COCO | |||
from .folder import ImageFolder | |||
from .imagenet import ImageNet | |||
from .meta_vision import VisionDataset | |||
from .mnist import MNIST | |||
from .objects365 import Objects365 | |||
from .voc import PascalVOC |
@@ -0,0 +1,171 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import os | |||
import pickle | |||
import tarfile | |||
from typing import Tuple | |||
import numpy as np | |||
from ....logger import get_logger | |||
from .meta_vision import VisionDataset | |||
from .utils import _default_dataset_root, load_raw_data_from_url | |||
logger = get_logger(__name__) | |||
class CIFAR10(VisionDataset): | |||
r""" ``Dataset`` for CIFAR10 meta data | |||
""" | |||
url_path = "http://www.cs.utoronto.ca/~kriz/" | |||
raw_file_name = "cifar-10-python.tar.gz" | |||
raw_file_md5 = "c58f30108f718f92721af3b95e74349a" | |||
raw_file_dir = "cifar-10-batches-py" | |||
train_batch = [ | |||
"data_batch_1", | |||
"data_batch_2", | |||
"data_batch_3", | |||
"data_batch_4", | |||
"data_batch_5", | |||
] | |||
test_batch = ["test_batch"] | |||
meta_info = {"name": "batches.meta"} | |||
def __init__( | |||
self, | |||
root: str = None, | |||
train: bool = True, | |||
download: bool = True, | |||
timeout: int = 500, | |||
): | |||
super().__init__(root, order=("image", "image_category")) | |||
self.timeout = timeout | |||
# process the root path | |||
if root is None: | |||
self.root = self._default_root | |||
if not os.path.exists(self.root): | |||
os.makedirs(self.root) | |||
else: | |||
self.root = root | |||
if not os.path.exists(self.root): | |||
if download: | |||
logger.debug( | |||
"dir %s does not exist, will be automatically created", | |||
self.root, | |||
) | |||
os.makedirs(self.root) | |||
else: | |||
raise ValueError("dir %s does not exist" % self.root) | |||
self.target_file = os.path.join(self.root, self.raw_file_dir) | |||
        # check the existence of the target pickle dir; if it exists, load the | |||
        # pickle files regardless of how `download` is set | |||
if os.path.exists(self.target_file): | |||
if train: | |||
self.arrays = self.bytes2array(self.train_batch) | |||
else: | |||
self.arrays = self.bytes2array(self.test_batch) | |||
else: | |||
if download: | |||
self.download() | |||
if train: | |||
self.arrays = self.bytes2array(self.train_batch) | |||
else: | |||
self.arrays = self.bytes2array(self.test_batch) | |||
else: | |||
raise ValueError( | |||
"dir does not contain target file %s, please set download=True" | |||
% (self.target_file) | |||
) | |||
def __getitem__(self, index: int) -> Tuple: | |||
return tuple(array[index] for array in self.arrays) | |||
def __len__(self) -> int: | |||
return len(self.arrays[0]) | |||
@property | |||
def _default_root(self): | |||
return os.path.join(_default_dataset_root(), self.__class__.__name__) | |||
@property | |||
def meta(self): | |||
meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"]) | |||
with open(meta_path, "rb") as f: | |||
meta = pickle.load(f, encoding="bytes") | |||
return meta | |||
def download(self): | |||
url = self.url_path + self.raw_file_name | |||
load_raw_data_from_url( | |||
url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout | |||
) | |||
self.process() | |||
def untar(self, file_path, dirs): | |||
assert file_path.endswith(".tar.gz") | |||
logger.debug("untar file %s to %s", file_path, dirs) | |||
t = tarfile.open(file_path) | |||
t.extractall(path=dirs) | |||
def bytes2array(self, filenames): | |||
data = [] | |||
label = [] | |||
for filename in filenames: | |||
path = os.path.join(self.root, self.raw_file_dir, filename) | |||
logger.debug("unpickle file %s", path) | |||
with open(path, "rb") as fo: | |||
dic = pickle.load(fo, encoding="bytes") | |||
batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1)) | |||
data.extend(list(batch_data[..., [2, 1, 0]])) | |||
label.extend(dic[b"labels"]) | |||
label = np.array(label, dtype=np.int32) | |||
return (data, label) | |||
def process(self): | |||
logger.info("process raw data ...") | |||
self.untar(os.path.join(self.root, self.raw_file_name), self.root) | |||
class CIFAR100(CIFAR10): | |||
url_path = "http://www.cs.utoronto.ca/~kriz/" | |||
raw_file_name = "cifar-100-python.tar.gz" | |||
raw_file_md5 = "eb9058c3a382ffc7106e4002c42a8d85" | |||
raw_file_dir = "cifar-100-python" | |||
train_batch = ["train"] | |||
test_batch = ["test"] | |||
meta_info = {"name": "meta"} | |||
@property | |||
def meta(self): | |||
meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"]) | |||
with open(meta_path, "rb") as f: | |||
meta = pickle.load(f, encoding="bytes") | |||
return meta | |||
def bytes2array(self, filenames): | |||
data = [] | |||
fine_label = [] | |||
coarse_label = [] | |||
for filename in filenames: | |||
path = os.path.join(self.root, self.raw_file_dir, filename) | |||
logger.debug("unpickle file %s", path) | |||
with open(path, "rb") as fo: | |||
dic = pickle.load(fo, encoding="bytes") | |||
batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1)) | |||
data.extend(list(batch_data[..., [2, 1, 0]])) | |||
fine_label.extend(dic[b"fine_labels"]) | |||
coarse_label.extend(dic[b"coarse_labels"]) | |||
fine_label = np.array(fine_label, dtype=np.int32) | |||
coarse_label = np.array(coarse_label, dtype=np.int32) | |||
return data, fine_label, coarse_label |
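# Illustrative usage sketch (downloads the archive on first use; images come | |||
# back as 32x32x3 uint8 arrays in BGR channel order, as produced above): | |||
def _cifar_example(): | |||
    train_set = CIFAR10(root=None, train=True, download=True) | |||
    image, label = train_set[0] | |||
    assert image.shape == (32, 32, 3) | |||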
@@ -0,0 +1,151 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# --------------------------------------------------------------------- | |||
# Part of the following code in this file refs to torchvision | |||
# BSD 3-Clause License | |||
# | |||
# Copyright (c) Soumith Chintala 2016, | |||
# All rights reserved. | |||
# --------------------------------------------------------------------- | |||
import json | |||
import os | |||
import cv2 | |||
import numpy as np | |||
from .meta_vision import VisionDataset | |||
class Cityscapes(VisionDataset): | |||
r"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset. | |||
""" | |||
supported_order = ( | |||
"image", | |||
"mask", | |||
"info", | |||
) | |||
def __init__(self, root, image_set, mode, *, order=None): | |||
super().__init__(root, order=order, supported_order=self.supported_order) | |||
city_root = self.root | |||
if not os.path.isdir(city_root): | |||
raise RuntimeError("Dataset not found or corrupted.") | |||
self.mode = mode | |||
self.images_dir = os.path.join(city_root, "leftImg8bit", image_set) | |||
self.masks_dir = os.path.join(city_root, self.mode, image_set) | |||
self.images, self.masks = [], [] | |||
# self.target_type = ["instance", "semantic", "polygon", "color"] | |||
# for semantic segmentation | |||
if mode == "gtFine": | |||
valid_modes = ("train", "test", "val") | |||
else: | |||
valid_modes = ("train", "train_extra", "val") | |||
for city in os.listdir(self.images_dir): | |||
img_dir = os.path.join(self.images_dir, city) | |||
mask_dir = os.path.join(self.masks_dir, city) | |||
for file_name in os.listdir(img_dir): | |||
mask_name = "{}_{}".format( | |||
file_name.split("_leftImg8bit")[0], | |||
self._get_target_suffix(self.mode, "semantic"), | |||
) | |||
self.images.append(os.path.join(img_dir, file_name)) | |||
self.masks.append(os.path.join(mask_dir, mask_name)) | |||
    def __getitem__(self, index): | |||
        image = None  # "info" may need it even when "image" is not in self.order | |||
        target = [] | |||
for k in self.order: | |||
if k == "image": | |||
image = cv2.imread(self.images[index], cv2.IMREAD_COLOR) | |||
target.append(image) | |||
elif k == "mask": | |||
mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE) | |||
mask = self._trans_mask(mask) | |||
mask = mask[:, :, np.newaxis] | |||
target.append(mask) | |||
elif k == "info": | |||
if image is None: | |||
image = cv2.imread(self.images[index], cv2.IMREAD_COLOR) | |||
info = [image.shape[0], image.shape[1], self.images[index]] | |||
target.append(info) | |||
else: | |||
raise NotImplementedError | |||
return tuple(target) | |||
def __len__(self): | |||
return len(self.images) | |||
def _trans_mask(self, mask): | |||
trans_labels = [ | |||
7, | |||
8, | |||
11, | |||
12, | |||
13, | |||
17, | |||
19, | |||
20, | |||
21, | |||
22, | |||
23, | |||
24, | |||
25, | |||
26, | |||
27, | |||
28, | |||
31, | |||
32, | |||
33, | |||
] | |||
label = np.ones(mask.shape) * 255 | |||
for i, tl in enumerate(trans_labels): | |||
label[mask == tl] = i | |||
return label.astype(np.uint8) | |||
def _get_target_suffix(self, mode, target_type): | |||
if target_type == "instance": | |||
return "{}_instanceIds.png".format(mode) | |||
elif target_type == "semantic": | |||
return "{}_labelIds.png".format(mode) | |||
elif target_type == "color": | |||
return "{}_color.png".format(mode) | |||
else: | |||
return "{}_polygons.json".format(mode) | |||
def _load_json(self, path): | |||
with open(path, "r") as file: | |||
data = json.load(file) | |||
return data | |||
class_names = ( | |||
"road", | |||
"sidewalk", | |||
"building", | |||
"wall", | |||
"fence", | |||
"pole", | |||
"traffic light", | |||
"traffic sign", | |||
"vegetation", | |||
"terrain", | |||
"sky", | |||
"person", | |||
"rider", | |||
"car", | |||
"truck", | |||
"bus", | |||
"train", | |||
"motorcycle", | |||
"bicycle", | |||
) |
@@ -0,0 +1,366 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# --------------------------------------------------------------------- | |||
# Part of the following code in this file refs to maskrcnn-benchmark | |||
# MIT License | |||
# | |||
# Copyright (c) 2018 Facebook | |||
# --------------------------------------------------------------------- | |||
import json | |||
import os | |||
from collections import defaultdict | |||
import cv2 | |||
import numpy as np | |||
from .meta_vision import VisionDataset | |||
min_keypoints_per_image = 10 | |||
def _count_visible_keypoints(anno): | |||
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) | |||
def has_valid_annotation(anno, order): | |||
# if it"s empty, there is no annotation | |||
if len(anno) == 0: | |||
return False | |||
if "boxes" in order or "boxes_category" in order: | |||
if "bbox" not in anno[0]: | |||
return False | |||
if "keypoints" in order: | |||
if "keypoints" not in anno[0]: | |||
return False | |||
# for keypoint detection tasks, only consider valid images those | |||
# containing at least min_keypoints_per_image | |||
if _count_visible_keypoints(anno) < min_keypoints_per_image: | |||
return False | |||
return True | |||
class COCO(VisionDataset): | |||
r"""`MS COCO <http://cocodataset.org/#home>`_ Dataset. | |||
""" | |||
supported_order = ( | |||
"image", | |||
"boxes", | |||
"boxes_category", | |||
"keypoints", | |||
# TODO: need to check | |||
# "polygons", | |||
"info", | |||
) | |||
def __init__( | |||
self, root, ann_file, remove_images_without_annotations=False, *, order=None | |||
): | |||
super().__init__(root, order=order, supported_order=self.supported_order) | |||
with open(ann_file, "r") as f: | |||
dataset = json.load(f) | |||
self.imgs = dict() | |||
for img in dataset["images"]: | |||
# for saving memory | |||
if "license" in img: | |||
del img["license"] | |||
if "coco_url" in img: | |||
del img["coco_url"] | |||
if "date_captured" in img: | |||
del img["date_captured"] | |||
if "flickr_url" in img: | |||
del img["flickr_url"] | |||
self.imgs[img["id"]] = img | |||
self.img_to_anns = defaultdict(list) | |||
for ann in dataset["annotations"]: | |||
# for saving memory | |||
if ( | |||
"boxes" not in self.order | |||
and "boxes_category" not in self.order | |||
and "bbox" in ann | |||
): | |||
del ann["bbox"] | |||
if "polygons" not in self.order and "segmentation" in ann: | |||
del ann["segmentation"] | |||
self.img_to_anns[ann["image_id"]].append(ann) | |||
self.cats = dict() | |||
for cat in dataset["categories"]: | |||
self.cats[cat["id"]] = cat | |||
self.ids = list(sorted(self.imgs.keys())) | |||
# filter images without detection annotations | |||
if remove_images_without_annotations: | |||
ids = [] | |||
for img_id in self.ids: | |||
anno = self.img_to_anns[img_id] | |||
# filter crowd annotations | |||
anno = [obj for obj in anno if obj["iscrowd"] == 0] | |||
anno = [ | |||
obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0 | |||
] | |||
                if has_valid_annotation(anno, self.order): | |||
ids.append(img_id) | |||
self.img_to_anns[img_id] = anno | |||
else: | |||
del self.imgs[img_id] | |||
del self.img_to_anns[img_id] | |||
self.ids = ids | |||
self.json_category_id_to_contiguous_id = { | |||
v: i + 1 for i, v in enumerate(self.cats.keys()) | |||
} | |||
self.contiguous_category_id_to_json_id = { | |||
v: k for k, v in self.json_category_id_to_contiguous_id.items() | |||
} | |||
def __getitem__(self, index): | |||
img_id = self.ids[index] | |||
anno = self.img_to_anns[img_id] | |||
target = [] | |||
for k in self.order: | |||
if k == "image": | |||
file_name = self.imgs[img_id]["file_name"] | |||
path = os.path.join(self.root, file_name) | |||
image = cv2.imread(path, cv2.IMREAD_COLOR) | |||
target.append(image) | |||
elif k == "boxes": | |||
boxes = [obj["bbox"] for obj in anno] | |||
boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) | |||
                # convert boxes from xywh to xyxy | |||
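                # e.g. [x, y, w, h] = [10, 20, 30, 40] -> [10, 20, 40, 60] | |||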
boxes[:, 2:] += boxes[:, :2] | |||
target.append(boxes) | |||
elif k == "boxes_category": | |||
boxes_category = [obj["category_id"] for obj in anno] | |||
boxes_category = [ | |||
self.json_category_id_to_contiguous_id[c] for c in boxes_category | |||
] | |||
boxes_category = np.array(boxes_category, dtype=np.int32) | |||
target.append(boxes_category) | |||
elif k == "keypoints": | |||
keypoints = [obj["keypoints"] for obj in anno] | |||
keypoints = np.array(keypoints, dtype=np.float32).reshape( | |||
-1, len(self.keypoint_names), 3 | |||
) | |||
target.append(keypoints) | |||
elif k == "polygons": | |||
polygons = [obj["segmentation"] for obj in anno] | |||
polygons = [ | |||
[np.array(p, dtype=np.float32).reshape(-1, 2) for p in ps] | |||
for ps in polygons | |||
] | |||
target.append(polygons) | |||
elif k == "info": | |||
info = self.imgs[img_id] | |||
info = [info["height"], info["width"], info["file_name"]] | |||
target.append(info) | |||
else: | |||
raise NotImplementedError | |||
return tuple(target) | |||
def __len__(self): | |||
return len(self.ids) | |||
def get_img_info(self, index): | |||
img_id = self.ids[index] | |||
img_info = self.imgs[img_id] | |||
return img_info | |||
class_names = ( | |||
"person", | |||
"bicycle", | |||
"car", | |||
"motorcycle", | |||
"airplane", | |||
"bus", | |||
"train", | |||
"truck", | |||
"boat", | |||
"traffic light", | |||
"fire hydrant", | |||
"stop sign", | |||
"parking meter", | |||
"bench", | |||
"bird", | |||
"cat", | |||
"dog", | |||
"horse", | |||
"sheep", | |||
"cow", | |||
"elephant", | |||
"bear", | |||
"zebra", | |||
"giraffe", | |||
"backpack", | |||
"umbrella", | |||
"handbag", | |||
"tie", | |||
"suitcase", | |||
"frisbee", | |||
"skis", | |||
"snowboard", | |||
"sports ball", | |||
"kite", | |||
"baseball bat", | |||
"baseball glove", | |||
"skateboard", | |||
"surfboard", | |||
"tennis racket", | |||
"bottle", | |||
"wine glass", | |||
"cup", | |||
"fork", | |||
"knife", | |||
"spoon", | |||
"bowl", | |||
"banana", | |||
"apple", | |||
"sandwich", | |||
"orange", | |||
"broccoli", | |||
"carrot", | |||
"hot dog", | |||
"pizza", | |||
"donut", | |||
"cake", | |||
"chair", | |||
"couch", | |||
"potted plant", | |||
"bed", | |||
"dining table", | |||
"toilet", | |||
"tv", | |||
"laptop", | |||
"mouse", | |||
"remote", | |||
"keyboard", | |||
"cell phone", | |||
"microwave", | |||
"oven", | |||
"toaster", | |||
"sink", | |||
"refrigerator", | |||
"book", | |||
"clock", | |||
"vase", | |||
"scissors", | |||
"teddy bear", | |||
"hair drier", | |||
"toothbrush", | |||
) | |||
classes_originID = { | |||
"person": 1, | |||
"bicycle": 2, | |||
"car": 3, | |||
"motorcycle": 4, | |||
"airplane": 5, | |||
"bus": 6, | |||
"train": 7, | |||
"truck": 8, | |||
"boat": 9, | |||
"traffic light": 10, | |||
"fire hydrant": 11, | |||
"stop sign": 13, | |||
"parking meter": 14, | |||
"bench": 15, | |||
"bird": 16, | |||
"cat": 17, | |||
"dog": 18, | |||
"horse": 19, | |||
"sheep": 20, | |||
"cow": 21, | |||
"elephant": 22, | |||
"bear": 23, | |||
"zebra": 24, | |||
"giraffe": 25, | |||
"backpack": 27, | |||
"umbrella": 28, | |||
"handbag": 31, | |||
"tie": 32, | |||
"suitcase": 33, | |||
"frisbee": 34, | |||
"skis": 35, | |||
"snowboard": 36, | |||
"sports ball": 37, | |||
"kite": 38, | |||
"baseball bat": 39, | |||
"baseball glove": 40, | |||
"skateboard": 41, | |||
"surfboard": 42, | |||
"tennis racket": 43, | |||
"bottle": 44, | |||
"wine glass": 46, | |||
"cup": 47, | |||
"fork": 48, | |||
"knife": 49, | |||
"spoon": 50, | |||
"bowl": 51, | |||
"banana": 52, | |||
"apple": 53, | |||
"sandwich": 54, | |||
"orange": 55, | |||
"broccoli": 56, | |||
"carrot": 57, | |||
"hot dog": 58, | |||
"pizza": 59, | |||
"donut": 60, | |||
"cake": 61, | |||
"chair": 62, | |||
"couch": 63, | |||
"potted plant": 64, | |||
"bed": 65, | |||
"dining table": 67, | |||
"toilet": 70, | |||
"tv": 72, | |||
"laptop": 73, | |||
"mouse": 74, | |||
"remote": 75, | |||
"keyboard": 76, | |||
"cell phone": 77, | |||
"microwave": 78, | |||
"oven": 79, | |||
"toaster": 80, | |||
"sink": 81, | |||
"refrigerator": 82, | |||
"book": 84, | |||
"clock": 85, | |||
"vase": 86, | |||
"scissors": 87, | |||
"teddy bear": 88, | |||
"hair drier": 89, | |||
"toothbrush": 90, | |||
} | |||
keypoint_names = ( | |||
"nose", | |||
"left_eye", | |||
"right_eye", | |||
"left_ear", | |||
"right_ear", | |||
"left_shoulder", | |||
"right_shoulder", | |||
"left_elbow", | |||
"right_elbow", | |||
"left_wrist", | |||
"right_wrist", | |||
"left_hip", | |||
"right_hip", | |||
"left_knee", | |||
"right_knee", | |||
"left_ankle", | |||
"right_ankle", | |||
) |
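# --------------------------------------------------------------------- | |||
# A hypothetical usage sketch, not part of the original diff; the paths | |||
# below are placeholders for a standard COCO layout on disk. | |||
if __name__ == "__main__": | |||
    dataset = COCO( | |||
        "/data/coco/train2017",                             # placeholder | |||
        "/data/coco/annotations/instances_train2017.json",  # placeholder | |||
        remove_images_without_annotations=True, | |||
        order=("image", "boxes", "boxes_category", "info"), | |||
    ) | |||
    image, boxes, boxes_category, info = dataset[0] | |||
    # image: HWC BGR uint8; boxes: (N, 4) float32 xyxy; info: [height, width, file_name] | |||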
@@ -0,0 +1,90 @@ | |||
# -*- coding: utf-8 -*- | |||
# BSD 3-Clause License | |||
# Copyright (c) Soumith Chintala 2016, | |||
# All rights reserved. | |||
# --------------------------------------------------------------------- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# | |||
# This file has been modified by Megvii ("Megvii Modifications"). | |||
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
# --------------------------------------------------------------------- | |||
import os | |||
from typing import Dict, List, Tuple | |||
import cv2 | |||
import numpy as np | |||
from .meta_vision import VisionDataset | |||
from .utils import is_img | |||
class ImageFolder(VisionDataset): | |||
def __init__(self, root: str, check_valid_func=None, class_name: bool = False): | |||
r""" | |||
        ImageFolder is a class for loading image data and labels from an organized folder. | |||
        The folder is expected to be organized as follows: | |||
        root/cls/xxx.img_ext | |||
        Labels are indices of the sorted classes in the root directory. | |||
        :param root: root directory of an image folder | |||
        :param check_valid_func: a function used to check whether files in the folder | |||
            are expected image files; if ``None``, a default function that checks | |||
            file extensions will be used | |||
        :param class_name: if ``True``, return the class name instead of the class index | |||
""" | |||
super().__init__(root, order=("image", "image_category")) | |||
self.root = root | |||
if check_valid_func is not None: | |||
self.check_valid = check_valid_func | |||
else: | |||
self.check_valid = is_img | |||
self.class_name = class_name | |||
self.class_dict = self.collect_class() | |||
self.samples = self.collect_samples() | |||
def collect_samples(self) -> List: | |||
samples = [] | |||
directory = os.path.expanduser(self.root) | |||
for key in sorted(self.class_dict.keys()): | |||
d = os.path.join(directory, key) | |||
if not os.path.isdir(d): | |||
continue | |||
for r, _, filename in sorted(os.walk(d, followlinks=True)): | |||
for name in sorted(filename): | |||
path = os.path.join(r, name) | |||
if self.check_valid(path): | |||
if self.class_name: | |||
samples.append((path, key)) | |||
else: | |||
samples.append((path, self.class_dict[key])) | |||
return samples | |||
def collect_class(self) -> Dict: | |||
classes = [d.name for d in os.scandir(self.root) if d.is_dir()] | |||
classes.sort() | |||
return {classes[i]: np.int32(i) for i in range(len(classes))} | |||
def __getitem__(self, index: int) -> Tuple: | |||
path, label = self.samples[index] | |||
img = cv2.imread(path, cv2.IMREAD_COLOR) | |||
return img, label | |||
def __len__(self): | |||
return len(self.samples) |
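# --------------------------------------------------------------------- | |||
# A hypothetical usage sketch, not part of the original diff; the folder | |||
# path is a placeholder for a root/<class>/<image> layout. | |||
if __name__ == "__main__": | |||
    dataset = ImageFolder("/data/flowers")  # placeholder path | |||
    img, label = dataset[0]                 # img: HWC BGR uint8, label: np.int32 | |||
    print(len(dataset), dataset.class_dict) | |||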
@@ -0,0 +1,248 @@ | |||
# -*- coding: utf-8 -*- | |||
# BSD 3-Clause License | |||
# | |||
# Copyright (c) Soumith Chintala 2016, | |||
# All rights reserved. | |||
# --------------------------------------------------------------------- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# | |||
# This file has been modified by Megvii ("Megvii Modifications"). | |||
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
# --------------------------------------------------------------------- | |||
import os | |||
import shutil | |||
from tqdm import tqdm | |||
from ....distributed.group import is_distributed | |||
from ....logger import get_logger | |||
from ....serialization import load, save | |||
from .folder import ImageFolder | |||
from .utils import _default_dataset_root, calculate_md5, untar, untargz | |||
logger = get_logger(__name__) | |||
class ImageNet(ImageFolder): | |||
r""" | |||
    Load ImageNet from raw files or a folder; the expected folder layout is | |||
.. code-block:: bash | |||
${root}/ | |||
| [REQUIRED TAR FILES] | |||
|- ILSVRC2012_img_train.tar | |||
|- ILSVRC2012_img_val.tar | |||
|- ILSVRC2012_devkit_t12.tar.gz | |||
| [OPTIONAL IMAGE FOLDERS] | |||
|- train/cls/xxx.${img_ext} | |||
|- val/cls/xxx.${img_ext} | |||
|- ILSVRC2012_devkit_t12/data/meta.mat | |||
|- ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt | |||
    If the image folders do not exist, the raw tar files are required; they will be extracted and processed. | |||
""" | |||
raw_file_meta = { | |||
"train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"), | |||
"val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"), | |||
"devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"), | |||
} # ImageNet raw files | |||
default_train_dir = "train" | |||
default_val_dir = "val" | |||
default_devkit_dir = "ILSVRC2012_devkit_t12" | |||
def __init__(self, root: str = None, train: bool = True, **kwargs): | |||
r""" | |||
initialization: | |||
        * if ``root`` contains ``self.target_folder``, which depends on ``train``: | |||
* initialize ImageFolder with target_folder | |||
* else: | |||
* if all raw files are in ``root``: | |||
* parse ``self.target_folder`` from raw files | |||
* initialize ImageFolder with ``self.target_folder`` | |||
* else: | |||
* raise error | |||
        :param root: root directory of ImageNet data; if ``root`` is ``None``, the default dataset root is used | |||
:param train: if ``True``, load the train split, otherwise load the validation split | |||
""" | |||
# process the root path | |||
if root is None: | |||
self.root = self._default_root | |||
else: | |||
self.root = root | |||
if not os.path.exists(self.root): | |||
raise FileNotFoundError("dir %s does not exist" % self.root) | |||
self.devkit_dir = os.path.join(self.root, self.default_devkit_dir) | |||
if not os.path.exists(self.devkit_dir): | |||
logger.warning("devkit directory %s does not exists", self.devkit_dir) | |||
self._prepare_devkit() | |||
self.train = train | |||
if train: | |||
self.target_folder = os.path.join(self.root, self.default_train_dir) | |||
else: | |||
self.target_folder = os.path.join(self.root, self.default_val_dir) | |||
if not os.path.exists(self.target_folder): | |||
logger.warning( | |||
"expected image folder %s does not exist, try to load from raw file", | |||
self.target_folder, | |||
) | |||
if not self.check_raw_file(): | |||
raise FileNotFoundError( | |||
"expected image folder %s does not exist, and raw files do not exist in %s" | |||
% (self.target_folder, self.root) | |||
) | |||
elif is_distributed(): | |||
raise RuntimeError( | |||
"extracting raw file shouldn't be done in distributed mode, use single process instead" | |||
) | |||
elif train: | |||
self._prepare_train() | |||
else: | |||
self._prepare_val() | |||
super().__init__(self.target_folder, **kwargs) | |||
@property | |||
def _default_root(self): | |||
return os.path.join(_default_dataset_root(), self.__class__.__name__) | |||
@property | |||
def valid_ground_truth(self): | |||
        ground_truth_path = os.path.join( | |||
            self.devkit_dir, "data", "ILSVRC2012_validation_ground_truth.txt" | |||
        ) | |||
        if os.path.exists(ground_truth_path): | |||
            with open(ground_truth_path, "r") as f: | |||
                val_labels = f.readlines() | |||
            return [int(val_label) for val_label in val_labels] | |||
        else: | |||
            raise FileNotFoundError( | |||
                "validation ground truth file %s does not exist" % ground_truth_path | |||
) | |||
@property | |||
def meta(self): | |||
try: | |||
return load(os.path.join(self.devkit_dir, "meta.pkl")) | |||
except FileNotFoundError: | |||
import scipy.io | |||
meta_path = os.path.join(self.devkit_dir, "data", "meta.mat") | |||
if not os.path.exists(meta_path): | |||
raise FileNotFoundError("meta file %s does not exist" % meta_path) | |||
meta = scipy.io.loadmat(meta_path, squeeze_me=True)["synsets"] | |||
nums_children = list(zip(*meta))[4] | |||
meta = [ | |||
meta[idx] | |||
for idx, num_children in enumerate(nums_children) | |||
if num_children == 0 | |||
] | |||
idcs, wnids, classes = list(zip(*meta))[:3] | |||
classes = [tuple(clss.split(", ")) for clss in classes] | |||
idx_to_wnid = dict(zip(idcs, wnids)) | |||
wnid_to_classes = dict(zip(wnids, classes)) | |||
logger.info( | |||
"saving cached meta file to %s", | |||
os.path.join(self.devkit_dir, "meta.pkl"), | |||
) | |||
save( | |||
(idx_to_wnid, wnid_to_classes), | |||
os.path.join(self.devkit_dir, "meta.pkl"), | |||
) | |||
return idx_to_wnid, wnid_to_classes | |||
def check_raw_file(self) -> bool: | |||
return all( | |||
[ | |||
os.path.exists(os.path.join(self.root, value[0])) | |||
for _, value in self.raw_file_meta.items() | |||
] | |||
) | |||
def _organize_val_data(self): | |||
id2wnid = self.meta[0] | |||
val_idcs = self.valid_ground_truth | |||
val_wnids = [id2wnid[idx] for idx in val_idcs] | |||
val_images = sorted( | |||
[ | |||
os.path.join(self.target_folder, image) | |||
for image in os.listdir(self.target_folder) | |||
] | |||
) | |||
logger.debug("mkdir for val set wnids") | |||
for wnid in set(val_wnids): | |||
os.makedirs(os.path.join(self.root, self.default_val_dir, wnid)) | |||
logger.debug("mv val images into wnids dir") | |||
for wnid, img_file in tqdm(zip(val_wnids, val_images)): | |||
shutil.move( | |||
img_file, | |||
os.path.join( | |||
self.root, self.default_val_dir, wnid, os.path.basename(img_file) | |||
), | |||
) | |||
def _prepare_val(self): | |||
assert not self.train | |||
raw_filename, checksum = self.raw_file_meta["val"] | |||
raw_file = os.path.join(self.root, raw_filename) | |||
logger.info("checksum valid tar file %s ...", raw_file) | |||
assert ( | |||
calculate_md5(raw_file) == checksum | |||
), "checksum mismatch, {} may be damaged".format(raw_file) | |||
logger.info("extract valid tar file... this may take 10-20 minutes") | |||
untar(os.path.join(self.root, raw_file), self.target_folder) | |||
self._organize_val_data() | |||
def _prepare_train(self): | |||
assert self.train | |||
raw_filename, checksum = self.raw_file_meta["train"] | |||
raw_file = os.path.join(self.root, raw_filename) | |||
logger.info("checksum train tar file %s ...", raw_file) | |||
assert ( | |||
calculate_md5(raw_file) == checksum | |||
), "checksum mismatch, {} may be damaged".format(raw_file) | |||
logger.info("extract train tar file.. this may take several hours") | |||
untar( | |||
os.path.join(self.root, raw_file), self.target_folder, | |||
) | |||
paths = [ | |||
os.path.join(self.target_folder, child_dir) | |||
for child_dir in os.listdir(self.target_folder) | |||
] | |||
for path in tqdm(paths): | |||
untar(path, os.path.splitext(path)[0], remove=True) | |||
def _prepare_devkit(self): | |||
raw_filename, checksum = self.raw_file_meta["devkit"] | |||
raw_file = os.path.join(self.root, raw_filename) | |||
logger.info("checksum devkit tar file %s ...", raw_file) | |||
assert ( | |||
calculate_md5(raw_file) == checksum | |||
), "checksum mismatch, {} may be damaged".format(raw_file) | |||
logger.info("extract devkit file..") | |||
        untargz(raw_file) | |||
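# --------------------------------------------------------------------- | |||
# A hypothetical usage sketch, not part of the original diff; the root is a | |||
# placeholder, and the three raw tar files must already sit there on first use. | |||
if __name__ == "__main__": | |||
    train_set = ImageNet("/data/imagenet", train=True)  # placeholder root | |||
    img, label = train_set[0]                           # img: HWC BGR uint8 | |||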
@@ -0,0 +1,41 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections.abc | |||
import os | |||
from ..meta_dataset import MapDataset | |||
class VisionDataset(MapDataset): | |||
_repr_indent = 4 | |||
def __init__(self, root, *, order=None, supported_order=None): | |||
if isinstance(root, (str, bytes)): | |||
root = os.path.expanduser(root) | |||
self.root = root | |||
if order is None: | |||
order = ("image",) | |||
if not isinstance(order, collections.abc.Sequence): | |||
raise ValueError( | |||
"order should be a sequence, but got order={}".format(order) | |||
) | |||
if supported_order is not None: | |||
assert isinstance(supported_order, collections.abc.Sequence) | |||
for k in order: | |||
if k not in supported_order: | |||
raise NotImplementedError("{} is unsupported data type".format(k)) | |||
self.order = order | |||
def __getitem__(self, index): | |||
raise NotImplementedError | |||
def __len__(self): | |||
raise NotImplementedError |
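# --------------------------------------------------------------------- | |||
# A minimal, hypothetical subclass sketch, not part of the original diff: | |||
# declare a supported_order, let VisionDataset validate the requested order, | |||
# then serve sample tuples in that order. | |||
class _ToyDataset(VisionDataset): | |||
    _samples = ["pixels-placeholder"]  # stands in for real image data | |||
    def __init__(self, root, *, order=None): | |||
        super().__init__(root, order=order, supported_order=("image", "info")) | |||
    def __getitem__(self, index): | |||
        target = [] | |||
        for k in self.order: | |||
            if k == "image": | |||
                target.append(self._samples[index]) | |||
            elif k == "info": | |||
                target.append([8, 8, "toy.png"]) | |||
        return tuple(target) | |||
    def __len__(self): | |||
        return len(self._samples) | |||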
@@ -0,0 +1,197 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import gzip | |||
import os | |||
import struct | |||
from typing import Tuple | |||
import numpy as np | |||
from tqdm import tqdm | |||
from ....logger import get_logger | |||
from .meta_vision import VisionDataset | |||
from .utils import _default_dataset_root, load_raw_data_from_url | |||
logger = get_logger(__name__) | |||
class MNIST(VisionDataset): | |||
r""" ``Dataset`` for MNIST meta data | |||
""" | |||
url_path = "http://yann.lecun.com/exdb/mnist/" | |||
""" | |||
url prefix for downloading raw file | |||
""" | |||
raw_file_name = [ | |||
"train-images-idx3-ubyte.gz", | |||
"train-labels-idx1-ubyte.gz", | |||
"t10k-images-idx3-ubyte.gz", | |||
"t10k-labels-idx1-ubyte.gz", | |||
] | |||
""" | |||
raw file names of both training set and test set (10k) | |||
""" | |||
raw_file_md5 = [ | |||
"f68b3c2dcbeaaa9fbdd348bbdeb94873", | |||
"d53e105ee54ea40749a09fcbcd1e9432", | |||
"9fb629c4189551a2d022fa330f9573f3", | |||
"ec29112dd5afa0611ce80d1b7f02629c", | |||
] | |||
""" | |||
md5 for checking raw files | |||
""" | |||
def __init__( | |||
self, | |||
root: str = None, | |||
train: bool = True, | |||
download: bool = True, | |||
timeout: int = 500, | |||
): | |||
r""" | |||
        :param root: path for downloading or loading the MNIST dataset; if ``None``, | |||
            ``root`` is set to ``_default_root`` | |||
        :param train: if ``True``, load the training set, otherwise load the test set | |||
        :param download: if the raw files do not exist and ``download`` is set to ``True``, | |||
            download and process the raw files, otherwise raise ValueError; default is ``True`` | |||
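        :param timeout: HTTP read timeout used when downloading the raw files | |||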
""" | |||
super().__init__(root, order=("image", "image_category")) | |||
self.timeout = timeout | |||
# process the root path | |||
if root is None: | |||
self.root = self._default_root | |||
if not os.path.exists(self.root): | |||
os.makedirs(self.root) | |||
else: | |||
self.root = root | |||
if not os.path.exists(self.root): | |||
if download: | |||
logger.debug( | |||
"dir %s does not exist, will be automatically created", | |||
self.root, | |||
) | |||
os.makedirs(self.root) | |||
else: | |||
raise ValueError("dir %s does not exist" % self.root) | |||
if self._check_raw_files(): | |||
self.process(train) | |||
elif download: | |||
self.download() | |||
self.process(train) | |||
else: | |||
raise ValueError( | |||
"root does not contain valid raw files, please set download=True" | |||
) | |||
def __getitem__(self, index: int) -> Tuple: | |||
return tuple(array[index] for array in self.arrays) | |||
def __len__(self) -> int: | |||
return len(self.arrays[0]) | |||
@property | |||
def _default_root(self): | |||
return os.path.join(_default_dataset_root(), self.__class__.__name__) | |||
@property | |||
def meta(self): | |||
return self._meta_data | |||
def _check_raw_files(self): | |||
return all( | |||
[ | |||
os.path.exists(os.path.join(self.root, path)) | |||
for path in self.raw_file_name | |||
] | |||
) | |||
def download(self): | |||
for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5): | |||
url = self.url_path + file_name | |||
load_raw_data_from_url(url, file_name, md5, self.root, self.timeout) | |||
def process(self, train): | |||
        # load raw files and transform them into meta data and a dataset tuple of np.ndarray | |||
logger.info("process the raw files of %s set...", "train" if train else "test") | |||
if train: | |||
meta_data_images, images = parse_idx3( | |||
os.path.join(self.root, self.raw_file_name[0]) | |||
) | |||
meta_data_labels, labels = parse_idx1( | |||
os.path.join(self.root, self.raw_file_name[1]) | |||
) | |||
else: | |||
meta_data_images, images = parse_idx3( | |||
os.path.join(self.root, self.raw_file_name[2]) | |||
) | |||
meta_data_labels, labels = parse_idx1( | |||
os.path.join(self.root, self.raw_file_name[3]) | |||
) | |||
self._meta_data = { | |||
"images": meta_data_images, | |||
"labels": meta_data_labels, | |||
} | |||
self.arrays = (images, labels.astype(np.int32)) | |||
def parse_idx3(idx3_file): | |||
# parse idx3 file to meta data and data in numpy array (images) | |||
logger.debug("parse idx3 file %s ...", idx3_file) | |||
assert idx3_file.endswith(".gz") | |||
with gzip.open(idx3_file, "rb") as f: | |||
bin_data = f.read() | |||
# parse meta data | |||
offset = 0 | |||
fmt_header = ">iiii" | |||
magic, imgs, height, width = struct.unpack_from(fmt_header, bin_data, offset) | |||
meta_data = {"magic": magic, "imgs": imgs, "height": height, "width": width} | |||
# parse images | |||
image_size = height * width | |||
offset += struct.calcsize(fmt_header) | |||
fmt_image = ">" + str(image_size) + "B" | |||
images = [] | |||
bar = tqdm(total=meta_data["imgs"], ncols=80) | |||
for image in struct.iter_unpack(fmt_image, bin_data[offset:]): | |||
images.append(np.array(image, dtype=np.uint8).reshape((height, width, 1))) | |||
bar.update() | |||
bar.close() | |||
return meta_data, images | |||
def parse_idx1(idx1_file): | |||
# parse idx1 file to meta data and data in numpy array (labels) | |||
logger.debug("parse idx1 file %s ...", idx1_file) | |||
assert idx1_file.endswith(".gz") | |||
with gzip.open(idx1_file, "rb") as f: | |||
bin_data = f.read() | |||
# parse meta data | |||
offset = 0 | |||
fmt_header = ">ii" | |||
magic, imgs = struct.unpack_from(fmt_header, bin_data, offset) | |||
meta_data = {"magic": magic, "imgs": imgs} | |||
# parse labels | |||
offset += struct.calcsize(fmt_header) | |||
fmt_image = ">B" | |||
labels = np.empty(imgs, dtype=int) | |||
bar = tqdm(total=meta_data["imgs"], ncols=80) | |||
for i, label in enumerate(struct.iter_unpack(fmt_image, bin_data[offset:])): | |||
labels[i] = label[0] | |||
bar.update() | |||
bar.close() | |||
return meta_data, labels |
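# --------------------------------------------------------------------- | |||
# A hypothetical usage sketch, not part of the original diff; with root=None | |||
# the raw files are cached under the default dataset root. | |||
if __name__ == "__main__": | |||
    train_set = MNIST(root=None, train=True, download=True) | |||
    image, label = train_set[0]  # image: (28, 28, 1) uint8, label: int32 | |||
    print(len(train_set))        # 60000 for the training split | |||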
@@ -0,0 +1,498 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# --------------------------------------------------------------------- | |||
# Part of the following code in this file refs to maskrcnn-benchmark | |||
# MIT License | |||
# | |||
# Copyright (c) 2018 Facebook | |||
# --------------------------------------------------------------------- | |||
import json | |||
import os | |||
from collections import defaultdict | |||
import cv2 | |||
import numpy as np | |||
from .meta_vision import VisionDataset | |||
class Objects365(VisionDataset): | |||
r"""`Objects365 <https://www.objects365.org/overview.html>`_ Dataset. | |||
""" | |||
supported_order = ( | |||
"image", | |||
"boxes", | |||
"boxes_category", | |||
"info", | |||
) | |||
def __init__( | |||
self, root, ann_file, remove_images_without_annotations=False, *, order=None | |||
): | |||
super().__init__(root, order=order, supported_order=self.supported_order) | |||
with open(ann_file, "r") as f: | |||
dataset = json.load(f) | |||
self.imgs = dict() | |||
for img in dataset["images"]: | |||
self.imgs[img["id"]] = img | |||
self.img_to_anns = defaultdict(list) | |||
for ann in dataset["annotations"]: | |||
# for saving memory | |||
if ( | |||
"boxes" not in self.order | |||
and "boxes_category" not in self.order | |||
and "bbox" in ann | |||
): | |||
del ann["bbox"] | |||
self.img_to_anns[ann["image_id"]].append(ann) | |||
self.cats = dict() | |||
for cat in dataset["categories"]: | |||
self.cats[cat["id"]] = cat | |||
self.ids = list(sorted(self.imgs.keys())) | |||
# filter images without detection annotations | |||
if remove_images_without_annotations: | |||
ids = [] | |||
for img_id in self.ids: | |||
anno = self.img_to_anns[img_id] | |||
# filter crowd annotations | |||
anno = [obj for obj in anno if obj["iscrowd"] == 0] | |||
anno = [ | |||
obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0 | |||
] | |||
if len(anno) > 0: | |||
ids.append(img_id) | |||
self.img_to_anns[img_id] = anno | |||
else: | |||
del self.imgs[img_id] | |||
del self.img_to_anns[img_id] | |||
self.ids = ids | |||
self.json_category_id_to_contiguous_id = { | |||
v: i + 1 for i, v in enumerate(self.cats.keys()) | |||
} | |||
self.contiguous_category_id_to_json_id = { | |||
v: k for k, v in self.json_category_id_to_contiguous_id.items() | |||
} | |||
def __getitem__(self, index): | |||
img_id = self.ids[index] | |||
anno = self.img_to_anns[img_id] | |||
target = [] | |||
for k in self.order: | |||
if k == "image": | |||
file_name = self.imgs[img_id]["file_name"] | |||
path = os.path.join(self.root, file_name) | |||
image = cv2.imread(path, cv2.IMREAD_COLOR) | |||
target.append(image) | |||
elif k == "boxes": | |||
boxes = [obj["bbox"] for obj in anno] | |||
boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) | |||
                # convert boxes from xywh to xyxy | |||
boxes[:, 2:] += boxes[:, :2] | |||
target.append(boxes) | |||
elif k == "boxes_category": | |||
boxes_category = [obj["category_id"] for obj in anno] | |||
boxes_category = [ | |||
self.json_category_id_to_contiguous_id[c] for c in boxes_category | |||
] | |||
boxes_category = np.array(boxes_category, dtype=np.int32) | |||
target.append(boxes_category) | |||
elif k == "info": | |||
info = self.imgs[img_id] | |||
info = [info["height"], info["width"], info["file_name"]] | |||
target.append(info) | |||
else: | |||
raise NotImplementedError | |||
return tuple(target) | |||
def __len__(self): | |||
return len(self.ids) | |||
def get_img_info(self, index): | |||
img_id = self.ids[index] | |||
img_info = self.imgs[img_id] | |||
return img_info | |||
class_names = ( | |||
"person", | |||
"sneakers", | |||
"chair", | |||
"hat", | |||
"lamp", | |||
"bottle", | |||
"cabinet/shelf", | |||
"cup", | |||
"car", | |||
"glasses", | |||
"picture/frame", | |||
"desk", | |||
"handbag", | |||
"street lights", | |||
"book", | |||
"plate", | |||
"helmet", | |||
"leather shoes", | |||
"pillow", | |||
"glove", | |||
"potted plant", | |||
"bracelet", | |||
"flower", | |||
"tv", | |||
"storage box", | |||
"vase", | |||
"bench", | |||
"wine glass", | |||
"boots", | |||
"bowl", | |||
"dining table", | |||
"umbrella", | |||
"boat", | |||
"flag", | |||
"speaker", | |||
"trash bin/can", | |||
"stool", | |||
"backpack", | |||
"couch", | |||
"belt", | |||
"carpet", | |||
"basket", | |||
"towel/napkin", | |||
"slippers", | |||
"barrel/bucket", | |||
"coffee table", | |||
"suv", | |||
"toy", | |||
"tie", | |||
"bed", | |||
"traffic light", | |||
"pen/pencil", | |||
"microphone", | |||
"sandals", | |||
"canned", | |||
"necklace", | |||
"mirror", | |||
"faucet", | |||
"bicycle", | |||
"bread", | |||
"high heels", | |||
"ring", | |||
"van", | |||
"watch", | |||
"sink", | |||
"horse", | |||
"fish", | |||
"apple", | |||
"camera", | |||
"candle", | |||
"teddy bear", | |||
"cake", | |||
"motorcycle", | |||
"wild bird", | |||
"laptop", | |||
"knife", | |||
"traffic sign", | |||
"cell phone", | |||
"paddle", | |||
"truck", | |||
"cow", | |||
"power outlet", | |||
"clock", | |||
"drum", | |||
"fork", | |||
"bus", | |||
"hanger", | |||
"nightstand", | |||
"pot/pan", | |||
"sheep", | |||
"guitar", | |||
"traffic cone", | |||
"tea pot", | |||
"keyboard", | |||
"tripod", | |||
"hockey", | |||
"fan", | |||
"dog", | |||
"spoon", | |||
"blackboard/whiteboard", | |||
"balloon", | |||
"air conditioner", | |||
"cymbal", | |||
"mouse", | |||
"telephone", | |||
"pickup truck", | |||
"orange", | |||
"banana", | |||
"airplane", | |||
"luggage", | |||
"skis", | |||
"soccer", | |||
"trolley", | |||
"oven", | |||
"remote", | |||
"baseball glove", | |||
"paper towel", | |||
"refrigerator", | |||
"train", | |||
"tomato", | |||
"machinery vehicle", | |||
"tent", | |||
"shampoo/shower gel", | |||
"head phone", | |||
"lantern", | |||
"donut", | |||
"cleaning products", | |||
"sailboat", | |||
"tangerine", | |||
"pizza", | |||
"kite", | |||
"computer box", | |||
"elephant", | |||
"toiletries", | |||
"gas stove", | |||
"broccoli", | |||
"toilet", | |||
"stroller", | |||
"shovel", | |||
"baseball bat", | |||
"microwave", | |||
"skateboard", | |||
"surfboard", | |||
"surveillance camera", | |||
"gun", | |||
"life saver", | |||
"cat", | |||
"lemon", | |||
"liquid soap", | |||
"zebra", | |||
"duck", | |||
"sports car", | |||
"giraffe", | |||
"pumpkin", | |||
"piano", | |||
"stop sign", | |||
"radiator", | |||
"converter", | |||
"tissue ", | |||
"carrot", | |||
"washing machine", | |||
"vent", | |||
"cookies", | |||
"cutting/chopping board", | |||
"tennis racket", | |||
"candy", | |||
"skating and skiing shoes", | |||
"scissors", | |||
"folder", | |||
"baseball", | |||
"strawberry", | |||
"bow tie", | |||
"pigeon", | |||
"pepper", | |||
"coffee machine", | |||
"bathtub", | |||
"snowboard", | |||
"suitcase", | |||
"grapes", | |||
"ladder", | |||
"pear", | |||
"american football", | |||
"basketball", | |||
"potato", | |||
"paint brush", | |||
"printer", | |||
"billiards", | |||
"fire hydrant", | |||
"goose", | |||
"projector", | |||
"sausage", | |||
"fire extinguisher", | |||
"extension cord", | |||
"facial mask", | |||
"tennis ball", | |||
"chopsticks", | |||
"electronic stove and gas stove", | |||
"pie", | |||
"frisbee", | |||
"kettle", | |||
"hamburger", | |||
"golf club", | |||
"cucumber", | |||
"clutch", | |||
"blender", | |||
"tong", | |||
"slide", | |||
"hot dog", | |||
"toothbrush", | |||
"facial cleanser", | |||
"mango", | |||
"deer", | |||
"egg", | |||
"violin", | |||
"marker", | |||
"ship", | |||
"chicken", | |||
"onion", | |||
"ice cream", | |||
"tape", | |||
"wheelchair", | |||
"plum", | |||
"bar soap", | |||
"scale", | |||
"watermelon", | |||
"cabbage", | |||
"router/modem", | |||
"golf ball", | |||
"pine apple", | |||
"crane", | |||
"fire truck", | |||
"peach", | |||
"cello", | |||
"notepaper", | |||
"tricycle", | |||
"toaster", | |||
"helicopter", | |||
"green beans", | |||
"brush", | |||
"carriage", | |||
"cigar", | |||
"earphone", | |||
"penguin", | |||
"hurdle", | |||
"swing", | |||
"radio", | |||
"CD", | |||
"parking meter", | |||
"swan", | |||
"garlic", | |||
"french fries", | |||
"horn", | |||
"avocado", | |||
"saxophone", | |||
"trumpet", | |||
"sandwich", | |||
"cue", | |||
"kiwi fruit", | |||
"bear", | |||
"fishing rod", | |||
"cherry", | |||
"tablet", | |||
"green vegetables", | |||
"nuts", | |||
"corn", | |||
"key", | |||
"screwdriver", | |||
"globe", | |||
"broom", | |||
"pliers", | |||
"volleyball", | |||
"hammer", | |||
"eggplant", | |||
"trophy", | |||
"dates", | |||
"board eraser", | |||
"rice", | |||
"tape measure/ruler", | |||
"dumbbell", | |||
"hamimelon", | |||
"stapler", | |||
"camel", | |||
"lettuce", | |||
"goldfish", | |||
"meat balls", | |||
"medal", | |||
"toothpaste", | |||
"antelope", | |||
"shrimp", | |||
"rickshaw", | |||
"trombone", | |||
"pomegranate", | |||
"coconut", | |||
"jellyfish", | |||
"mushroom", | |||
"calculator", | |||
"treadmill", | |||
"butterfly", | |||
"egg tart", | |||
"cheese", | |||
"pig", | |||
"pomelo", | |||
"race car", | |||
"rice cooker", | |||
"tuba", | |||
"crosswalk sign", | |||
"papaya", | |||
"hair drier", | |||
"green onion", | |||
"chips", | |||
"dolphin", | |||
"sushi", | |||
"urinal", | |||
"donkey", | |||
"electric drill", | |||
"spring rolls", | |||
"tortoise/turtle", | |||
"parrot", | |||
"flute", | |||
"measuring cup", | |||
"shark", | |||
"steak", | |||
"poker card", | |||
"binoculars", | |||
"llama", | |||
"radish", | |||
"noodles", | |||
"yak", | |||
"mop", | |||
"crab", | |||
"microscope", | |||
"barbell", | |||
"bread/bun", | |||
"baozi", | |||
"lion", | |||
"red cabbage", | |||
"polar bear", | |||
"lighter", | |||
"seal", | |||
"mangosteen", | |||
"comb", | |||
"eraser", | |||
"pitaya", | |||
"scallop", | |||
"pencil case", | |||
"saw", | |||
"table tennis paddle", | |||
"okra", | |||
"starfish", | |||
"eagle", | |||
"monkey", | |||
"durian", | |||
"game board", | |||
"rabbit", | |||
"french horn", | |||
"ambulance", | |||
"asparagus", | |||
"hoverboard", | |||
"pasta", | |||
"target", | |||
"hotair balloon", | |||
"chainsaw", | |||
"lobster", | |||
"iron", | |||
"flashlight", | |||
) |
@@ -0,0 +1,89 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import hashlib | |||
import os | |||
import tarfile | |||
from ....distributed.group import is_distributed | |||
from ....logger import get_logger | |||
from ....utils.http_download import download_from_url | |||
IMG_EXT = (".jpg", ".png", ".jpeg", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") | |||
logger = get_logger(__name__) | |||
def _default_dataset_root(): | |||
default_dataset_root = os.path.expanduser( | |||
os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "megengine") | |||
) | |||
return default_dataset_root | |||
def load_raw_data_from_url( | |||
url: str, filename: str, target_md5: str, raw_data_dir: str, timeout: int | |||
): | |||
cached_file = os.path.join(raw_data_dir, filename) | |||
logger.debug( | |||
"load_raw_data_from_url: downloading to or using cached %s ...", cached_file | |||
) | |||
if not os.path.exists(cached_file): | |||
if is_distributed(): | |||
logger.warning( | |||
"Downloading raw data in DISTRIBUTED mode\n" | |||
" File may be downloaded multiple times. We recommend\n" | |||
" users to download in single process first." | |||
) | |||
md5 = download_from_url(url, cached_file, http_read_timeout=timeout) | |||
else: | |||
md5 = calculate_md5(cached_file) | |||
if target_md5 == md5: | |||
logger.debug("%s exists with correct md5: %s", filename, target_md5) | |||
else: | |||
os.remove(cached_file) | |||
raise RuntimeError("{} exists but fail to match md5".format(filename)) | |||
def calculate_md5(filename): | |||
m = hashlib.md5() | |||
with open(filename, "rb") as f: | |||
while True: | |||
data = f.read(4096) | |||
if not data: | |||
break | |||
m.update(data) | |||
return m.hexdigest() | |||
def is_img(filename): | |||
return filename.lower().endswith(IMG_EXT) | |||
def untar(path, to=None, remove=False): | |||
if to is None: | |||
to = os.path.dirname(path) | |||
with tarfile.open(path, "r") as tar: | |||
tar.extractall(path=to) | |||
if remove: | |||
os.remove(path) | |||
def untargz(path, to=None, remove=False): | |||
if path.endswith(".tar.gz"): | |||
if to is None: | |||
to = os.path.dirname(path) | |||
with tarfile.open(path, "r:gz") as tar: | |||
tar.extractall(path=to) | |||
else: | |||
raise ValueError("path %s does not end with .tar" % path) | |||
if remove: | |||
os.remove(path) |
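# --------------------------------------------------------------------- | |||
# A hypothetical sketch, not part of the original diff, of how these helpers | |||
# compose; the path and digest below are placeholders. | |||
if __name__ == "__main__": | |||
    archive = "/tmp/dataset.tar"                   # placeholder | |||
    expected = "0123456789abcdef0123456789abcdef"  # placeholder md5 | |||
    if calculate_md5(archive) == expected: | |||
        untar(archive, to="/tmp/dataset") | |||
    else: | |||
        raise RuntimeError("md5 mismatch for %s" % archive) | |||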
@@ -0,0 +1,195 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# --------------------------------------------------------------------- | |||
# Part of the following code in this file refs to torchvision | |||
# BSD 3-Clause License | |||
# | |||
# Copyright (c) Soumith Chintala 2016, | |||
# All rights reserved. | |||
# --------------------------------------------------------------------- | |||
import collections.abc | |||
import os | |||
import xml.etree.ElementTree as ET | |||
import cv2 | |||
import numpy as np | |||
from .meta_vision import VisionDataset | |||
class PascalVOC(VisionDataset): | |||
r"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset. | |||
""" | |||
supported_order = ( | |||
"image", | |||
"boxes", | |||
"boxes_category", | |||
"mask", | |||
"info", | |||
) | |||
def __init__(self, root, image_set, *, order=None): | |||
if ("boxes" in order or "boxes_category" in order) and "mask" in order: | |||
raise ValueError( | |||
"PascalVOC only supports boxes & boxes_category or mask, not both." | |||
) | |||
super().__init__(root, order=order, supported_order=self.supported_order) | |||
if not os.path.isdir(self.root): | |||
raise RuntimeError("Dataset not found or corrupted.") | |||
self.image_set = image_set | |||
image_dir = os.path.join(self.root, "JPEGImages") | |||
if "boxes" in order or "boxes_category" in order: | |||
annotation_dir = os.path.join(self.root, "Annotations") | |||
splitdet_dir = os.path.join(self.root, "ImageSets/Main") | |||
split_f = os.path.join(splitdet_dir, image_set.rstrip("\n") + ".txt") | |||
            with open(split_f, "r") as f: | |||
self.file_names = [x.strip() for x in f.readlines()] | |||
self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names] | |||
self.annotations = [ | |||
os.path.join(annotation_dir, x + ".xml") for x in self.file_names | |||
] | |||
assert len(self.images) == len(self.annotations) | |||
elif "mask" in order: | |||
if "aug" in image_set: | |||
mask_dir = os.path.join(self.root, "SegmentationClass_aug") | |||
else: | |||
mask_dir = os.path.join(self.root, "SegmentationClass") | |||
splitmask_dir = os.path.join(self.root, "ImageSets/Segmentation") | |||
split_f = os.path.join(splitmask_dir, image_set.rstrip("\n") + ".txt") | |||
            with open(split_f, "r") as f: | |||
self.file_names = [x.strip() for x in f.readlines()] | |||
self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names] | |||
self.masks = [os.path.join(mask_dir, x + ".png") for x in self.file_names] | |||
assert len(self.images) == len(self.masks) | |||
else: | |||
raise NotImplementedError | |||
def __getitem__(self, index): | |||
        target = [] | |||
        image = None  # may be needed by "info" even when "image" is not requested | |||
for k in self.order: | |||
if k == "image": | |||
image = cv2.imread(self.images[index], cv2.IMREAD_COLOR) | |||
target.append(image) | |||
elif k == "boxes": | |||
anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot()) | |||
boxes = [obj["bndbox"] for obj in anno["annotation"]["object"]] | |||
# boxes type xyxy | |||
boxes = [ | |||
(bb["xmin"], bb["ymin"], bb["xmax"], bb["ymax"]) for bb in boxes | |||
] | |||
boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) | |||
target.append(boxes) | |||
elif k == "boxes_category": | |||
anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot()) | |||
boxes_category = [obj["name"] for obj in anno["annotation"]["object"]] | |||
boxes_category = [ | |||
self.class_names.index(bc) + 1 for bc in boxes_category | |||
] | |||
boxes_category = np.array(boxes_category, dtype=np.int32) | |||
target.append(boxes_category) | |||
elif k == "mask": | |||
if "aug" in self.image_set: | |||
mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE) | |||
else: | |||
mask = cv2.imread(self.masks[index], cv2.IMREAD_COLOR) | |||
mask = self._trans_mask(mask) | |||
mask = mask[:, :, np.newaxis] | |||
target.append(mask) | |||
elif k == "info": | |||
if image is None: | |||
image = cv2.imread(self.images[index], cv2.IMREAD_COLOR) | |||
info = [image.shape[0], image.shape[1], self.file_names[index]] | |||
target.append(info) | |||
else: | |||
raise NotImplementedError | |||
return tuple(target) | |||
def __len__(self): | |||
return len(self.images) | |||
def _trans_mask(self, mask): | |||
label = np.ones(mask.shape[:2]) * 255 | |||
for i in range(len(self.class_colors)): | |||
b, g, r = self.class_colors[i] | |||
label[ | |||
(mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r) | |||
] = i | |||
return label.astype(np.uint8) | |||
def parse_voc_xml(self, node): | |||
voc_dict = {} | |||
children = list(node) | |||
if children: | |||
def_dic = collections.defaultdict(list) | |||
for dc in map(self.parse_voc_xml, children): | |||
for ind, v in dc.items(): | |||
def_dic[ind].append(v) | |||
if node.tag == "annotation": | |||
def_dic["object"] = [def_dic["object"]] | |||
voc_dict = { | |||
node.tag: { | |||
ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items() | |||
} | |||
} | |||
if node.text: | |||
text = node.text.strip() | |||
if not children: | |||
voc_dict[node.tag] = text | |||
return voc_dict | |||
class_names = ( | |||
"aeroplane", | |||
"bicycle", | |||
"bird", | |||
"boat", | |||
"bottle", | |||
"bus", | |||
"car", | |||
"cat", | |||
"chair", | |||
"cow", | |||
"diningtable", | |||
"dog", | |||
"horse", | |||
"motorbike", | |||
"person", | |||
"pottedplant", | |||
"sheep", | |||
"sofa", | |||
"train", | |||
"tvmonitor", | |||
) | |||
class_colors = [ | |||
[0, 0, 128], | |||
[0, 128, 0], | |||
[0, 128, 128], | |||
[128, 0, 0], | |||
[128, 0, 128], | |||
[128, 128, 0], | |||
[128, 128, 128], | |||
[0, 0, 64], | |||
[0, 0, 192], | |||
[0, 128, 64], | |||
[0, 128, 192], | |||
[128, 0, 64], | |||
[128, 0, 192], | |||
[128, 128, 64], | |||
[128, 128, 192], | |||
[0, 64, 0], | |||
[0, 64, 128], | |||
[0, 192, 0], | |||
[0, 192, 128], | |||
[128, 64, 0], | |||
] |
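# --------------------------------------------------------------------- | |||
# A hypothetical usage sketch, not part of the original diff; the VOC root | |||
# and split name below are placeholders. | |||
if __name__ == "__main__": | |||
    dataset = PascalVOC( | |||
        "/data/VOCdevkit/VOC2012", "train",  # placeholders | |||
        order=("image", "boxes", "boxes_category"), | |||
    ) | |||
    image, boxes, boxes_category = dataset[0] | |||
    # boxes: (N, 4) float32 xyxy; boxes_category: (N,) int32, 1-based labels | |||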
@@ -0,0 +1,274 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections.abc | |||
import math | |||
from abc import ABC | |||
from typing import Any, Generator, Iterator, List, Union | |||
import numpy as np | |||
import megengine.distributed as dist | |||
class Sampler(ABC): | |||
def __init__( | |||
self, | |||
dataset, | |||
batch_size=1, | |||
drop_last=False, | |||
num_samples=None, | |||
world_size=None, | |||
rank=None, | |||
seed=None, | |||
): | |||
r""" | |||
        An abstract class for all samplers | |||
:type dataset: `dataset` | |||
:param dataset: dataset to sample from | |||
:type batch_size: positive integer | |||
:param batch_size: batch size for batch method | |||
:type drop_last: bool | |||
:param drop_last: set ``True`` to drop the last incomplete batch, | |||
if the dataset size is not divisible by the batch size. If ``False`` and | |||
the size of dataset is not divisible by the batch_size, then the last batch will | |||
be smaller. (default: ``False``) | |||
:type num_samples: positive integer | |||
:param num_samples: number of samples assigned to one rank | |||
:type world_size: positive integer | |||
:param world_size: number of ranks | |||
        :type rank: non-negative integer less than ``world_size`` | |||
        :param rank: rank id, a non-negative integer within ``[0, world_size)`` | |||
:type seed: non-negative integer | |||
:param seed: seed for random operators | |||
""" | |||
if ( | |||
not isinstance(batch_size, int) | |||
or isinstance(batch_size, bool) | |||
or batch_size <= 0 | |||
): | |||
raise ValueError( | |||
"batch_size should be a positive integer value, " | |||
"but got batch_size={}".format(batch_size) | |||
) | |||
if not isinstance(drop_last, bool): | |||
raise ValueError( | |||
"drop_last should be a boolean value, but got " | |||
"drop_last={}".format(drop_last) | |||
) | |||
if num_samples is not None and ( | |||
not isinstance(num_samples, int) | |||
or isinstance(num_samples, bool) | |||
or num_samples <= 0 | |||
): | |||
raise ValueError( | |||
"num_samples should be a positive integer " | |||
"value, but got num_samples={}".format(num_samples) | |||
) | |||
self.batch_size = batch_size | |||
self.dataset = dataset | |||
self.drop_last = drop_last | |||
if world_size is None: | |||
world_size = dist.get_world_size() if dist.is_distributed() else 1 | |||
self.world_size = world_size | |||
if rank is None: | |||
rank = dist.get_rank() if dist.is_distributed() else 0 | |||
self.rank = rank | |||
if num_samples is None: | |||
num_samples = len(self.dataset) | |||
self.num_samples = int(math.ceil(num_samples / self.world_size)) | |||
# Make sure seeds are the same at each rank | |||
if seed is None and self.world_size > 1: | |||
seed = 0 | |||
self.rng = np.random.RandomState(seed) | |||
def __iter__(self) -> Union[Generator, Iterator]: | |||
return self.batch() | |||
def __len__(self) -> int: | |||
if self.drop_last: | |||
return self.num_samples // self.batch_size | |||
else: | |||
return int(math.ceil(self.num_samples / self.batch_size)) | |||
def sample(self): | |||
""" | |||
        Return a list containing all sample indices. | |||
""" | |||
raise NotImplementedError | |||
def scatter(self, indices) -> List: | |||
r""" | |||
        scatter method is used for splitting indices into subsets; each subset | |||
        will be assigned to a rank. Indices are split evenly by default. | |||
        If a customized indices assignment method is needed, please override this method. | |||
""" | |||
total_size = self.num_samples * self.world_size | |||
# add extra indices to make it evenly divisible | |||
indices += indices[: (total_size - len(indices))] | |||
assert len(indices) == total_size | |||
# subsample | |||
indices = indices[self.rank : total_size : self.world_size] | |||
assert len(indices) == self.num_samples | |||
return indices | |||
def batch(self) -> Iterator[List[Any]]: | |||
r""" | |||
        batch method provides an iterator over batches of indices | |||
""" | |||
indices = list(self.sample()) | |||
# user might pass the world_size parameter without dist, | |||
# so dist.is_distributed() should not be used | |||
if self.world_size > 1: | |||
indices = self.scatter(indices) | |||
step, length = self.batch_size, len(indices) | |||
batch_index = [indices[i : i + step] for i in range(0, length, step)] | |||
if self.drop_last and len(batch_index[-1]) < self.batch_size: | |||
batch_index.pop() | |||
return iter(batch_index) | |||
class SequentialSampler(Sampler): | |||
def __init__( | |||
self, | |||
dataset, | |||
batch_size=1, | |||
drop_last=False, | |||
indices=None, | |||
world_size=None, | |||
rank=None, | |||
): | |||
r""" | |||
Sample elements sequentially | |||
""" | |||
super().__init__(dataset, batch_size, drop_last, None, world_size, rank) | |||
if indices is not None and not isinstance(indices, collections.abc.Sequence): | |||
raise ValueError( | |||
"indices should be None or a sequence, " | |||
"but got indices={}".format(indices) | |||
) | |||
self.indices = indices | |||
def sample(self) -> Iterator[Any]: | |||
r""" | |||
        Return sample indices in sequential order. | |||
""" | |||
if self.indices is None: | |||
return iter(range(len(self.dataset))) | |||
else: | |||
return self.indices | |||
class RandomSampler(Sampler): | |||
def __init__( | |||
self, | |||
dataset, | |||
batch_size=1, | |||
drop_last=False, | |||
indices=None, | |||
world_size=None, | |||
rank=None, | |||
seed=None, | |||
): | |||
r""" | |||
Sample elements randomly without replacement | |||
""" | |||
super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed) | |||
if indices is not None and not isinstance(indices, collections.abc.Sequence): | |||
raise ValueError( | |||
"indices should be None or a sequence, " | |||
"but got indices={}".format(indices) | |||
) | |||
self.indices = indices | |||
def sample(self) -> List: | |||
if self.indices is None: | |||
return self.rng.permutation(len(self.dataset)).tolist() | |||
else: | |||
return self.rng.permutation(self.indices).tolist() | |||
class ReplacementSampler(Sampler): | |||
def __init__( | |||
self, | |||
dataset, | |||
batch_size=1, | |||
drop_last=False, | |||
num_samples=None, | |||
weights=None, | |||
world_size=None, | |||
rank=None, | |||
seed=None, | |||
): | |||
r""" | |||
Sample elements randomly with replacement | |||
:type weights: List | |||
:param weights: weights for sampling indices, it could be unnormalized weights | |||
""" | |||
super().__init__( | |||
dataset, batch_size, drop_last, num_samples, world_size, rank, seed | |||
) | |||
if weights is not None: | |||
if not isinstance(weights, collections.abc.Sequence): | |||
raise ValueError( | |||
"weights should be None or a sequence, " | |||
"but got weights={}".format(weights) | |||
) | |||
if len(weights) != len(dataset): | |||
raise ValueError( | |||
"len(dataset)={} should be equal to" | |||
"len(weights)={}".format(len(dataset), len(weights)) | |||
) | |||
self.weights = weights | |||
if self.weights is not None: | |||
self.weights = np.array(weights) / sum(weights) | |||
def sample(self) -> List: | |||
n = len(self.dataset) | |||
if self.weights is None: | |||
return self.rng.randint(n, size=self.num_samples).tolist() | |||
else: | |||
            # draw indices according to the normalized weights (with replacement) | |||
            return self.rng.choice(n, size=self.num_samples, p=self.weights).tolist() | |||
class Infinite(Sampler): | |||
r"""Infinite Sampler warper for basic sampler""" | |||
def sample(self): | |||
raise NotImplementedError("sample method not supported in Infinite") | |||
def __init__(self, sampler): | |||
self.sampler = sampler | |||
self.sampler_iter = iter(self.sampler) | |||
def __iter__(self): | |||
return self | |||
def __next__(self): | |||
try: | |||
index = next(self.sampler_iter) | |||
except StopIteration: | |||
self.sampler_iter = iter(self.sampler) | |||
index = next(self.sampler_iter) | |||
return index | |||
def __len__(self): | |||
return np.iinfo(np.int64).max |
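# --------------------------------------------------------------------- | |||
# A hypothetical usage sketch, not part of the original diff; Sampler only | |||
# needs len(dataset), so a plain list stands in for a real dataset here. | |||
if __name__ == "__main__": | |||
    dataset = list(range(10)) | |||
    sampler = RandomSampler(dataset, batch_size=4, drop_last=True, seed=0) | |||
    for batch_indices in sampler:  # two batches of 4 random indices | |||
        print(batch_indices) | |||
    endless = Infinite(SequentialSampler(dataset, batch_size=4)) | |||
    print(next(iter(endless)))     # wraps around instead of raising StopIteration | |||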
@@ -0,0 +1,10 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .meta_transform import PseudoTransform, Transform | |||
from .vision import * |
@@ -0,0 +1,31 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from abc import ABC, abstractmethod | |||
from typing import Sequence, Tuple | |||
class Transform(ABC): | |||
""" | |||
    Override the apply method in a subclass. | |||
""" | |||
def apply_batch(self, inputs: Sequence[Tuple]): | |||
return tuple(self.apply(input) for input in inputs) | |||
@abstractmethod | |||
def apply(self, input: Tuple): | |||
pass | |||
def __repr__(self): | |||
return self.__class__.__name__ | |||
class PseudoTransform(Transform): | |||
def apply(self, input: Tuple): | |||
return input |
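# --------------------------------------------------------------------- | |||
# A minimal, hypothetical subclass, not part of the original diff, showing | |||
# the intended override point. | |||
class _ScaleFirst(Transform): | |||
    """Scale the first element of each sample tuple by a constant factor.""" | |||
    def __init__(self, factor=2): | |||
        self.factor = factor | |||
    def apply(self, input: Tuple): | |||
        first, *rest = input | |||
        return (first * self.factor, *rest) | |||
if __name__ == "__main__": | |||
    t = _ScaleFirst(3) | |||
    print(t.apply((2, "label")))                # (6, 'label') | |||
    print(t.apply_batch([(1, "a"), (2, "b")]))  # ((3, 'a'), (6, 'b')) | |||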
@@ -0,0 +1,9 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .transform import * |
@@ -0,0 +1,111 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections.abc | |||
import functools | |||
import random | |||
import cv2 | |||
import numpy as np | |||
def wrap_keepdims(func): | |||
"""Wraper to keep the dimension of input images unchanged""" | |||
@functools.wraps(func) | |||
def wrapper(image, *args, **kwargs): | |||
if len(image.shape) != 3: | |||
raise ValueError( | |||
"image must have 3 dims, but got {} dims".format(len(image.shape)) | |||
) | |||
ret = func(image, *args, **kwargs) | |||
if len(ret.shape) == 2: | |||
ret = ret[:, :, np.newaxis] | |||
return ret | |||
return wrapper | |||
@wrap_keepdims | |||
def to_gray(image): | |||
r""" | |||
    Convert a BGR format image to grayscale.
:param image: Input BGR format image, with (H, W, C) shape | |||
:return: Gray format image, with (H, W, C) shape | |||
""" | |||
return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) | |||
@wrap_keepdims | |||
def to_bgr(image): | |||
r""" | |||
    Convert a grayscale image to BGR format.
    :param image: Input grayscale image, with (H, W, C) shape
:return: BGR format image, with (H, W, C) shape | |||
""" | |||
return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) | |||
@wrap_keepdims | |||
def pad(input, size, value): | |||
r""" | |||
    Pad input data with *value* and the given *size*
    :param input: Input data, with (H, W, C) shape
    :param size: Padding size of the input data; it can be an integer or a sequence.
        If it is an integer, the input data will be padded by that amount in all four directions.
        If it is a sequence of two integers, the bottom and right sides
        of the input data will be padded.
        If it is a sequence of four integers, the top, bottom, left, and right
        sides of the input data will be padded with the given sizes.
    :param value: Padding value; it can be a number or a sequence of int or float.
        If a float value is given, the dtype of the image will also be cast to float32.
:return: Padded image | |||
""" | |||
if isinstance(size, int): | |||
size = (size, size, size, size) | |||
elif isinstance(size, collections.abc.Sequence) and len(size) == 2: | |||
size = (0, size[0], 0, size[1]) | |||
if np.array(value).dtype == float: | |||
input = input.astype(np.float32) | |||
return cv2.copyMakeBorder(input, *size, cv2.BORDER_CONSTANT, value=value) | |||
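A shape-level sketch of how ``size`` is interpreted (assuming a 3-channel uint8 image):
import numpy as np
img = np.zeros((4, 4, 3), dtype=np.uint8)
pad(img, 2, 0).shape       # (8, 8, 3): 2 pixels added on all four sides
pad(img, (1, 2), 0).shape  # (5, 6, 3): bottom padded by 1, right by 2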
@wrap_keepdims | |||
def flip(image, flipCode): | |||
r""" | |||
    Flip the input image according to ``flipCode`` (the type of flip)
:param image: Input image, with (H, W, C) shape | |||
:param flipCode: code that indicates the type of flip. | |||
1 : Flip horizontally | |||
0 : Flip vertically | |||
-1 : Flip horizontally and vertically | |||
    :return: Flipped image, with (H, W, C) shape
""" | |||
return cv2.flip(image, flipCode=flipCode) | |||
@wrap_keepdims | |||
def resize(input, size, interpolation=cv2.INTER_LINEAR): | |||
r""" | |||
    Resize the input data to the given size.
:param input: Input data, could be image or masks, with (H, W, C) shape | |||
:param size: Target size of input data, with (height, width) shape. | |||
:param interpolation: Interpolation method. | |||
:return: Resized data, with (H, W, C) shape | |||
""" | |||
if len(size) != 2: | |||
raise ValueError("resize needs (h, w), but got {}".format(size)) | |||
if isinstance(interpolation, collections.abc.Sequence): | |||
interpolation = random.choice(interpolation) | |||
    # cv2.resize expects (width, height), so the (height, width) size is reversed
    return cv2.resize(input, size[::-1], interpolation=interpolation)
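A quick shape check for ``resize`` (recall that ``interpolation`` may also be a sequence, from which one method is picked at random):
import numpy as np
img = np.zeros((4, 8, 3), dtype=np.uint8)
resize(img, (2, 4)).shape  # (2, 4, 3): the target size is given as (height, width)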
@@ -0,0 +1,89 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import os | |||
from .core._imperative_rt.common import CompNode, DeviceType | |||
__all__ = [ | |||
"is_cuda_available", | |||
"get_device_count", | |||
"get_default_device", | |||
"set_default_device", | |||
] | |||
_default_device = os.getenv("MGE_DEFAULT_DEVICE", "xpux") | |||
def _valid_device(inp): | |||
    # note: only 4-character names such as "cpu0", "gpu1", or "xpux" pass this check
    if isinstance(inp, str) and len(inp) == 4:
if inp[0] in {"x", "c", "g"} and inp[1:3] == "pu": | |||
if inp[3] == "x" or inp[3].isdigit(): | |||
return True | |||
return False | |||
def _str2device_type(type_str: str, allow_unspec: bool = True): | |||
type_str = type_str.upper() | |||
if type_str == "CPU": | |||
return DeviceType.CPU | |||
elif type_str == "GPU" or type_str == "CUDA": | |||
return DeviceType.CUDA | |||
else: | |||
assert allow_unspec and str == "XPU", "bad device type" | |||
return DeviceType.UNSPEC | |||
def get_device_count(device_type: str) -> int: | |||
"""Gets number of devices installed on this system. | |||
:param device_type: device type, one of 'gpu' or 'cpu' | |||
""" | |||
device_type_set = ("cpu", "gpu") | |||
assert device_type in device_type_set, "device must be one of {}".format( | |||
device_type_set | |||
) | |||
device_type = _str2device_type(device_type) | |||
return CompNode._get_device_count(device_type, False) | |||
def is_cuda_available() -> bool: | |||
"""Returns whether cuda device is available on this system. | |||
""" | |||
t = _str2device_type("gpu") | |||
return CompNode._get_device_count(t, False) > 0 | |||
def set_default_device(device: str = "xpux"): | |||
r"""Sets default computing node. | |||
    :param device: default device type. The type can be 'cpu0', 'cpu1', etc.,
        or 'gpu0', 'gpu1', etc., to specify the particular CPU or GPU to use.
        'cpux' and 'gpux' can also be used to specify any cpu or gpu device.
        The 'multithread' device type is available for inference; it implements
        multi-threading parallelism at the operator level. For example,
        'multithread4' will compute with 4 threads.
        The default value is 'xpux', which specifies any available device.
        It can also be set by the environment variable `MGE_DEFAULT_DEVICE`.
""" | |||
global _default_device # pylint: disable=global-statement | |||
assert _valid_device(device), "Invalid device name {}".format(device) | |||
_default_device = device | |||
def get_default_device() -> str: | |||
r"""Gets default computing node. | |||
It returns the value set by :func:`~.set_default_device`. | |||
""" | |||
return _default_device |
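A short usage sketch (assuming these helpers are re-exported at the package top level; device availability depends on the machine):
from megengine.device import get_default_device, is_cuda_available, set_default_device
set_default_device("cpu0")   # pin computation to the first CPU device
print(get_default_device())  # -> "cpu0"
print(is_cuda_available())   # True only if a CUDA device is present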
@@ -0,0 +1,25 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .group import ( | |||
WORLD, | |||
get_backend, | |||
get_client, | |||
get_mm_server_addr, | |||
get_py_server_addr, | |||
get_rank, | |||
get_world_size, | |||
group_barrier, | |||
init_process_group, | |||
is_distributed, | |||
new_group, | |||
) | |||
from .helper import synchronized | |||
from .launcher import launcher | |||
from .server import Client, Server | |||
from .util import get_free_ports |
@@ -0,0 +1,176 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import List, Optional, Tuple | |||
from ..device import set_default_device | |||
from .server import Client, Server | |||
class StaticData: | |||
server = None | |||
client = None | |||
master_ip = None | |||
py_server_port = None | |||
mm_server_port = None | |||
world_size = None | |||
proc_rank = None | |||
device = None | |||
backend = None | |||
next_stream = None | |||
_sd = None | |||
class Group: | |||
def __init__(self, proc_ranks): | |||
if len(proc_ranks) == 0: # empty group | |||
self.proc_ranks = None | |||
self.stream = None | |||
else: | |||
self.reset(proc_ranks) | |||
def reset(self, proc_ranks): | |||
self.check(proc_ranks) | |||
self.proc_ranks = proc_ranks | |||
self.stream = _sd.next_stream | |||
_sd.next_stream += 1 | |||
def check(self, proc_ranks): | |||
assert _sd is not None, "please call init_process_group first" | |||
for rank in proc_ranks: | |||
assert isinstance(rank, int) | |||
assert rank >= 0 and rank < _sd.world_size | |||
assert _sd.proc_rank in proc_ranks | |||
@property | |||
def size(self): | |||
assert len(self.proc_ranks) > 0, "invalid group" | |||
return len(self.proc_ranks) | |||
@property | |||
def key(self): | |||
assert len(self.proc_ranks) > 0, "invalid group" | |||
return ",".join(map(str, self.proc_ranks)) | |||
@property | |||
def rank(self): | |||
assert len(self.proc_ranks) > 0, "invalid group" | |||
return self.proc_ranks.index(_sd.proc_rank) | |||
@property | |||
def comp_node(self): | |||
assert len(self.proc_ranks) > 0, "invalid group" | |||
return "gpu{}:{}".format(_sd.device, self.stream) | |||
WORLD = Group([]) | |||
def init_process_group( | |||
master_ip: str, | |||
port: int, | |||
world_size: int, | |||
rank: int, | |||
device: int, | |||
backend: Optional[str] = "nccl", | |||
) -> None: | |||
"""Initialize the distributed process group and specify the device used in the current process | |||
:param master_ip: IP address of the master node | |||
:param port: Port available for all processes to communicate | |||
:param world_size: Total number of processes participating in the job | |||
:param rank: Rank of the current process | |||
:param device: The GPU device id to bind this process to | |||
    :param backend: Communicator backend, currently supports 'nccl' and 'ucx'
""" | |||
if not isinstance(master_ip, str): | |||
raise TypeError("Expect type str but got {}".format(type(master_ip))) | |||
if not isinstance(port, int): | |||
raise TypeError("Expect type int but got {}".format(type(port))) | |||
if not isinstance(world_size, int): | |||
raise TypeError("Expect type int but got {}".format(type(world_size))) | |||
if not isinstance(rank, int): | |||
raise TypeError("Expect type int but got {}".format(type(rank))) | |||
if not isinstance(device, int): | |||
raise TypeError("Expect type int but got {}".format(type(backend))) | |||
if not isinstance(backend, str): | |||
raise TypeError("Expect type str but got {}".format(type(backend))) | |||
global _sd | |||
assert _sd is None, "init_process_group should be called only once" | |||
_sd = StaticData() | |||
assert world_size > 1 | |||
assert rank >= 0 and rank < world_size | |||
assert port > 0 | |||
_sd.client = Client(master_ip, port) | |||
_sd.master_ip = master_ip | |||
_sd.py_server_port = port | |||
_sd.mm_server_port = _sd.client.get_mm_server_port() | |||
_sd.world_size = world_size | |||
_sd.proc_rank = rank | |||
_sd.device = device | |||
_sd.backend = backend | |||
_sd.next_stream = 1 | |||
WORLD.reset(list(range(world_size))) | |||
set_default_device("gpu{}".format(device)) | |||
def is_distributed() -> bool: | |||
"""Return True if the distributed process group has been initialized""" | |||
return _sd is not None | |||
def get_rank() -> int: | |||
"""Get the rank of the current process""" | |||
return _sd.proc_rank if _sd is not None else 0 | |||
def get_world_size() -> int: | |||
"""Get the total number of processes participating in the job""" | |||
return _sd.world_size if _sd is not None else 1 | |||
def get_backend() -> str: | |||
"""Get the backend str""" | |||
assert _sd is not None, "please call init_process_group first" | |||
    return _sd.backend
def get_py_server_addr() -> Tuple[str, int]: | |||
"""Get master_ip and port of python XML RPC server""" | |||
assert _sd is not None, "please call init_process_group first" | |||
return _sd.master_ip, _sd.py_server_port | |||
def get_mm_server_addr() -> Tuple[str, int]: | |||
"""Get master_ip and port of C++ mm_server""" | |||
assert _sd is not None, "please call init_process_group first" | |||
return _sd.master_ip, _sd.mm_server_port | |||
def get_client() -> Client: | |||
"""Get client of python XML RPC server""" | |||
assert _sd is not None, "please call init_process_group first" | |||
return _sd.client | |||
def new_group(proc_ranks: List[int]) -> Group: | |||
"""Build a subgroup containing certain ranks""" | |||
return Group(proc_ranks) | |||
def group_barrier(group: Optional[Group] = WORLD) -> None: | |||
"""Block until all ranks in the group reach this barrier""" | |||
assert isinstance(group, Group) | |||
_sd.client.group_barrier(group.key, group.size) |
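A sketch of how these helpers compose inside one worker process (assuming a Server is already listening on the port; addresses, port, and ranks are illustrative):
init_process_group(master_ip="localhost", port=23456, world_size=2, rank=0, device=0)
assert is_distributed()
group = new_group([0, 1])     # subgroup containing both ranks
print(group.key, group.size)  # -> "0,1" 2
group_barrier()               # block until every rank in WORLD arrives here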
@@ -0,0 +1,28 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import functools | |||
from typing import Callable | |||
from .group import group_barrier, is_distributed | |||
def synchronized(func: Callable): | |||
"""Decorator. Decorated function will synchronize when finished. | |||
Specifically, we use this to prevent data race during hub.load""" | |||
@functools.wraps(func) | |||
def wrapper(*args, **kwargs): | |||
if not is_distributed(): | |||
return func(*args, **kwargs) | |||
ret = func(*args, **kwargs) | |||
group_barrier() | |||
return ret | |||
return wrapper |
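A usage sketch (the decorated function and its body are illustrative):
@synchronized
def prepare_data():
    ...  # runs on every rank; all ranks then meet at a group barrier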
@@ -0,0 +1,68 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import multiprocessing as mp | |||
from ..device import get_device_count | |||
from .group import init_process_group | |||
from .server import Server | |||
from .util import get_free_ports | |||
def _get_device_count(): | |||
"""use subprocess to avoid cuda environment initialization in the main process""" | |||
def run(q): | |||
count = get_device_count("gpu") | |||
q.put(count) | |||
q = mp.Queue() | |||
p = mp.Process(target=run, args=(q,)) | |||
p.start() | |||
p.join() | |||
return q.get() | |||
def _run_wrapped(func, master_ip, port, world_size, rank, dev, args, kwargs): | |||
"""init distributed process group and run wrapped function""" | |||
init_process_group( | |||
master_ip=master_ip, port=port, world_size=world_size, rank=rank, device=dev | |||
) | |||
func(*args, **kwargs) | |||
def launcher(n_gpus): | |||
"""decorator for launching multiple processes in single-machine multi-gpu training""" | |||
count = _get_device_count() | |||
assert isinstance(n_gpus, int) and n_gpus > 1, "invalid n_gpus" | |||
assert n_gpus <= count, "{} gpus required, {} gpus provided".format(n_gpus, count) | |||
def decorator(func): | |||
def wrapper(*args, **kwargs): | |||
master_ip = "localhost" | |||
port = get_free_ports(1)[0] | |||
server = Server(port) | |||
procs = [] | |||
for rank in range(n_gpus): | |||
p = mp.Process( | |||
target=_run_wrapped, | |||
args=(func, master_ip, port, n_gpus, rank, rank, args, kwargs), | |||
) | |||
p.start() | |||
procs.append(p) | |||
for rank in range(n_gpus): | |||
procs[rank].join() | |||
code = procs[rank].exitcode | |||
assert code == 0, "subprocess {} exit with code {}".format(rank, code) | |||
return wrapper | |||
return decorator |
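A launch sketch (assuming the machine has at least 2 GPUs; the training function and its body are illustrative):
@launcher(n_gpus=2)
def train(steps):
    ...  # runs in a worker process with the process group already initialized
train(steps=100)  # forks 2 processes, binds rank i to GPU i, and joins them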
@@ -0,0 +1,170 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import multiprocessing as mp | |||
import threading | |||
import time | |||
from collections import defaultdict | |||
from functools import partial | |||
from socketserver import ThreadingMixIn | |||
from xmlrpc.client import ServerProxy | |||
from xmlrpc.server import SimpleXMLRPCServer | |||
from ..core._imperative_rt.utils import create_mm_server | |||
from .util import get_free_ports | |||
class Future: | |||
def __init__(self, ack=True): | |||
self.ready = threading.Event() | |||
self.ack = threading.Event() if ack else None | |||
def set(self, value): | |||
self.value = value | |||
self.ready.set() | |||
if self.ack: | |||
self.ack.wait() | |||
def get(self): | |||
self.ready.wait() | |||
if self.ack: | |||
self.ack.set() | |||
return self.value | |||
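A minimal illustration of the Future handshake (with ack=True, ``set`` blocks until the reader acknowledges via ``get``):
import threading
f = Future(ack=True)
t = threading.Thread(target=lambda: print(f.get()))  # blocks until set()
t.start()
f.set(42)  # publish the value, then wait for the reader's ack
t.join()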
class Methods: | |||
def __init__(self, mm_server_port): | |||
self.lock = threading.Lock() | |||
self.mm_server_port = mm_server_port | |||
self.dict_is_grad = defaultdict(partial(Future, True)) | |||
self.dict_remote_tracer = defaultdict(partial(Future, True)) | |||
self.dict_pack_list = defaultdict(partial(Future, False)) | |||
self.dict_barrier_counter = defaultdict(int) | |||
self.dict_barrier_event = defaultdict(threading.Event) | |||
def connect(self): | |||
return True | |||
def get_mm_server_port(self): | |||
return self.mm_server_port | |||
def set_is_grad(self, rank_peer, is_grad): | |||
with self.lock: | |||
future = self.dict_is_grad[rank_peer] | |||
future.set(is_grad) | |||
return True | |||
def check_is_grad(self, rank_peer): | |||
with self.lock: | |||
future = self.dict_is_grad[rank_peer] | |||
ret = future.get() | |||
with self.lock: | |||
del self.dict_is_grad[rank_peer] | |||
return ret | |||
def set_remote_tracer(self, rank_peer, tracer_set): | |||
with self.lock: | |||
future = self.dict_remote_tracer[rank_peer] | |||
future.set(tracer_set) | |||
return True | |||
def check_remote_tracer(self, rank_peer): | |||
with self.lock: | |||
future = self.dict_remote_tracer[rank_peer] | |||
ret = future.get() | |||
with self.lock: | |||
del self.dict_remote_tracer[rank_peer] | |||
return ret | |||
def set_pack_list(self, key, pack_list): | |||
with self.lock: | |||
future = self.dict_pack_list[key] | |||
future.set(pack_list) | |||
return True | |||
def get_pack_list(self, key): | |||
with self.lock: | |||
future = self.dict_pack_list[key] | |||
return future.get() | |||
def group_barrier(self, key, size): | |||
with self.lock: | |||
self.dict_barrier_counter[key] += 1 | |||
counter = self.dict_barrier_counter[key] | |||
event = self.dict_barrier_event[key] | |||
if counter == size: | |||
del self.dict_barrier_counter[key] | |||
del self.dict_barrier_event[key] | |||
event.set() | |||
else: | |||
event.wait() | |||
return True | |||
class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): | |||
pass | |||
def start_server(py_server_port, mm_server_port): | |||
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) | |||
server.register_instance(Methods(mm_server_port)) | |||
server.serve_forever() | |||
class Server: | |||
def __init__(self, port): | |||
self.py_server_port = get_free_ports(1)[0] if port == 0 else port | |||
self.mm_server_port = create_mm_server("0.0.0.0", 0) | |||
self.proc = mp.Process( | |||
target=start_server, | |||
args=(self.py_server_port, self.mm_server_port), | |||
daemon=True, | |||
) | |||
self.proc.start() | |||
class Client: | |||
def __init__(self, master_ip, port): | |||
self.master_ip = master_ip | |||
self.port = port | |||
self.connect() | |||
def connect(self): | |||
while True: | |||
try: | |||
self.proxy = ServerProxy( | |||
"http://{}:{}".format(self.master_ip, self.port) | |||
) | |||
if self.proxy.connect(): | |||
break | |||
            except Exception:
                # the server may not be up yet; retry after a short pause
                time.sleep(1)
def get_mm_server_port(self): | |||
return self.proxy.get_mm_server_port() | |||
def set_is_grad(self, rank_peer, is_grad): | |||
self.proxy.set_is_grad(rank_peer, is_grad) | |||
def check_is_grad(self, rank_peer): | |||
return self.proxy.check_is_grad(rank_peer) | |||
def set_remote_tracer(self, rank_peer, tracer_set): | |||
self.proxy.set_remote_tracer(rank_peer, tracer_set) | |||
def check_remote_tracer(self, rank_peer): | |||
return self.proxy.check_remote_tracer(rank_peer) | |||
def set_pack_list(self, key, pack_list): | |||
self.proxy.set_pack_list(key, pack_list) | |||
def get_pack_list(self, key): | |||
return self.proxy.get_pack_list(key) | |||
def group_barrier(self, key, size): | |||
self.proxy.group_barrier(key, size) |
@@ -0,0 +1,25 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import socket | |||
from typing import List | |||
def get_free_ports(num: int) -> List[int]: | |||
"""Get one or more free ports. | |||
""" | |||
socks, ports = [], [] | |||
for i in range(num): | |||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |||
sock.bind(("", 0)) | |||
socks.append(sock) | |||
ports.append(sock.getsockname()[1]) | |||
for sock in socks: | |||
sock.close() | |||
return ports |
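Usage sketch (actual port numbers vary from run to run):
ports = get_free_ports(2)
print(ports)  # e.g. [43891, 43892]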
@@ -0,0 +1,32 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# pylint: disable=redefined-builtin | |||
from . import distributed | |||
from .elemwise import * | |||
from .graph import add_update | |||
from .loss import ( | |||
binary_cross_entropy, | |||
cross_entropy, | |||
cross_entropy_with_softmax, | |||
hinge_loss, | |||
l1_loss, | |||
nll_loss, | |||
smooth_l1_loss, | |||
square_loss, | |||
triplet_margin_loss, | |||
) | |||
from .math import * | |||
from .nn import * | |||
from .quantized import conv_bias_activation | |||
from .tensor import * | |||
from .utils import accuracy, zero_grad | |||
# delete namespace | |||
# pylint: disable=undefined-variable | |||
# del elemwise, graph, loss, math, nn, tensor # type: ignore[name-defined] |
@@ -0,0 +1,49 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import os | |||
_conv_execution_strategy = os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY", "HEURISTIC") | |||
def get_conv_execution_strategy() -> str: | |||
"""Returns the execuation strategy of :class:`~.Conv2d`. | |||
See :func:`~.set_conv_execution_strategy` for possible return values | |||
""" | |||
return _conv_execution_strategy | |||
def set_conv_execution_strategy(option: str): | |||
"""Sets the execuation strategy of :class:`~.Conv2d`. | |||
:param option: Decides how :class:`~.Conv2d` algorithm is chosen. | |||
Available values: | |||
* 'HEURISTIC' uses heuristic to choose the fastest algorithm. | |||
* 'PROFILE' runs possible algorithms on real device to find the best. | |||
* 'PROFILE_HEURISTIC' uses profile result and heuristic to choose the fastest algorithm. | |||
* 'PROFILE_REPRODUCIBLE' uses the fastest of profile result that is also reproducible. | |||
* 'HEURISTIC_REPRODUCIBLE' uses heuristic to choose the fastest algorithm that is also reproducible. | |||
The default strategy is 'HEURISTIC'. | |||
It can also be set through the environmental variable 'MEGENGINE_CONV_EXECUTION_STRATEGY'. | |||
""" | |||
valid_option = ( | |||
"HEURISTIC", | |||
"PROFILE", | |||
"PROFILE_HEURISTIC", | |||
"PROFILE_REPRODUCIBLE", | |||
"HEURISTIC_REPRODUCIBLE", | |||
) | |||
    if option not in valid_option:
raise ValueError("Valid option can only be one of {}".format(valid_option)) | |||
global _conv_execution_strategy # pylint: disable=global-statement | |||
_conv_execution_strategy = option |
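Usage sketch:
set_conv_execution_strategy("PROFILE")
assert get_conv_execution_strategy() == "PROFILE"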
@@ -0,0 +1,299 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Optional, Tuple | |||
from ..core._imperative_rt.ops import CollectiveCommDefModeEnum | |||
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn | |||
from ..core.autodiff.grad import ( | |||
Tracer, | |||
check_backward_allow_noinput, | |||
get_grad_managers, | |||
get_op_has_grad_fn, | |||
tracer_apply, | |||
) | |||
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||
from ..core.tensor.core import apply | |||
from ..core.tensor.tensor import Tensor, tensor_apply | |||
from ..distributed.group import ( | |||
WORLD, | |||
Group, | |||
get_backend, | |||
get_client, | |||
get_mm_server_addr, | |||
get_rank, | |||
) | |||
from ..tensor import tensor | |||
__all__ = [ | |||
"reduce_sum", | |||
"broadcast", | |||
"all_gather", | |||
"reduce_scatter_sum", | |||
"all_reduce_sum", | |||
"all_reduce_max", | |||
"all_reduce_min", | |||
"gather", | |||
"scatter", | |||
"all_to_all", | |||
"remote_send", | |||
"remote_recv", | |||
] | |||
@apply.add | |||
def _(op: RemoteSend, *args: Tensor): | |||
ret = tensor_apply(op, *args) | |||
# set extra information | |||
tracer_set = dict() | |||
for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))): | |||
tracer_set[k.name] = True | |||
# check tracer_set in remote_recv | |||
get_client().set_remote_tracer(op.key, tracer_set) | |||
return ret | |||
@builtin_op_get_backward_fn.register(RemoteSend) | |||
def _(op: RemoteSend, inputs, outputs, input_requires_grad): | |||
def backward(*args): | |||
return [ | |||
remote_recv( | |||
op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device) | |||
) | |||
] | |||
return backward, [True] | |||
@get_op_has_grad_fn.register(RemoteSend) | |||
def _(op: RemoteSend): | |||
def has_grad(opnode, reached): | |||
return get_client().check_is_grad(op.key) | |||
return has_grad | |||
@check_backward_allow_noinput.register(RemoteSend) | |||
def _(op: RemoteSend): | |||
return True | |||
@builtin_op_get_backward_fn.register(RemoteRecv) | |||
def _(op: RemoteRecv, inputs, outputs, input_requires_grad): | |||
def backward(*output_grads): | |||
return [remote_send(output_grads[0], op.rank_from)] | |||
return backward, [True] | |||
@get_op_has_grad_fn.register(RemoteRecv) | |||
def _(op: RemoteRecv): | |||
def has_grad(opnode, reached): | |||
ret = False | |||
for v in opnode.outputs: | |||
if v() in reached: | |||
ret = True | |||
break | |||
get_client().set_is_grad(op.key, ret) | |||
return ret | |||
return has_grad | |||
def collective_comm(inp, mode, group, device): | |||
"""Helper function for applying collective communication functions""" | |||
    # handle the None group before the isinstance check; otherwise the early return is unreachable
    if group is None:
        return inp
    assert isinstance(group, Group)
op = CollectiveComm() | |||
op.key = group.key | |||
op.nr_devices = group.size | |||
op.rank = group.rank | |||
op.is_root = op.rank == 0 | |||
op.local_grad = False | |||
op.addr, op.port = get_mm_server_addr() | |||
op.mode = mode | |||
op.dtype = inp.dtype | |||
op.backend = get_backend() | |||
op.comp_node = device | |||
return apply(op, inp)[0] | |||
def reduce_sum( | |||
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||
) -> Tensor: | |||
"""Create reduce_sum operator for collective communication | |||
:param inp: input tensor | |||
:param group: communication group | |||
:param device: execute placement | |||
""" | |||
mode = CollectiveCommDefModeEnum.REDUCE_SUM | |||
return collective_comm(inp, mode, group, device) | |||
def broadcast( | |||
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||
) -> Tensor: | |||
"""Create broadcast operator for collective communication | |||
:param inp: input tensor | |||
:param group: communication group | |||
:param device: execute placement | |||
""" | |||
mode = CollectiveCommDefModeEnum.BROADCAST | |||
return collective_comm(inp, mode, group, device) | |||
def all_gather( | |||
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||
) -> Tensor: | |||
"""Create all_gather operator for collective communication | |||
:param inp: input tensor | |||
:param group: communication group | |||
:param device: execute placement | |||
""" | |||
mode = CollectiveCommDefModeEnum.ALL_GATHER | |||
return collective_comm(inp, mode, group, device) | |||
def reduce_scatter_sum( | |||
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||
) -> Tensor: | |||
"""Create reduce_scatter_sum operator for collective communication | |||
:param inp: input tensor | |||
:param group: communication group | |||
:param device: execute placement | |||
""" | |||
mode = CollectiveCommDefModeEnum.REDUCE_SCATTER_SUM | |||
return collective_comm(inp, mode, group, device) | |||
def all_reduce_sum( | |||
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||
) -> Tensor: | |||
"""Create all_reduce_sum operator for collective communication | |||
:param inp: input tensor | |||
:param group: communication group | |||
:param device: execute placement | |||
""" | |||
mode = CollectiveCommDefModeEnum.ALL_REDUCE_SUM | |||
return collective_comm(inp, mode, group, device) | |||
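A value-level sketch of all_reduce_sum under a 2-rank group (assuming the process group is already initialized; values are illustrative):
local = tensor([1.0, 2.0]) if get_rank() == 0 else tensor([3.0, 4.0])
out = all_reduce_sum(local)  # both ranks now hold [4.0, 6.0]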
def all_reduce_max( | |||
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||
) -> Tensor: | |||
"""Create all_reduce_max operator for collective communication | |||
:param inp: input tensor | |||
:param group: communication group | |||
:param device: execute placement | |||
""" | |||
mode = CollectiveCommDefModeEnum.ALL_REDUCE_MAX | |||
return collective_comm(inp, mode, group, device) | |||
def all_reduce_min( | |||
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||
) -> Tensor: | |||
"""Create all_reduce_min operator for collective communication | |||
:param inp: input tensor | |||
:param group: communication group | |||
:param device: execute placement | |||
""" | |||
mode = CollectiveCommDefModeEnum.ALL_REDUCE_MIN | |||
return collective_comm(inp, mode, group, device) | |||
def gather( | |||
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||
) -> Tensor: | |||
"""Create gather operator for collective communication | |||
:param inp: input tensor | |||
:param group: communication group | |||
:param device: execute placement | |||
""" | |||
mode = CollectiveCommDefModeEnum.GATHER | |||
return collective_comm(inp, mode, group, device) | |||
def scatter( | |||
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||
) -> Tensor: | |||
"""Create scatter operator for collective communication | |||
:param inp: input tensor | |||
:param group: communication group | |||
:param device: execute placement | |||
""" | |||
mode = CollectiveCommDefModeEnum.SCATTER | |||
return collective_comm(inp, mode, group, device) | |||
def all_to_all( | |||
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" | |||
) -> Tensor: | |||
"""Create all_to_all operator for collective communication | |||
:param inp: input tensor | |||
:param group: communication group | |||
:param device: execute placement | |||
""" | |||
mode = CollectiveCommDefModeEnum.ALL_TO_ALL | |||
return collective_comm(inp, mode, group, device) | |||
def remote_send(inp: Tensor, dest_rank: int) -> Tensor: | |||
"""Send a Tensor to a remote process | |||
:param inp: tensor to send | |||
:param dest_rank: destination process rank | |||
""" | |||
op = RemoteSend() | |||
op.key = "{}->{}".format(get_rank(), dest_rank) | |||
op.addr, op.port = get_mm_server_addr() | |||
op.rank_to = dest_rank | |||
return apply(op, inp)[0] | |||
def remote_recv( | |||
src_rank: int, shape: Tuple[int], dtype: type, cn: Optional[str] = "gpu0" | |||
) -> Tensor: | |||
"""Receive a Tensor from a remote process | |||
:param src_rank: source process rank | |||
:param shape: the shape of the tensor to receive | |||
:param dtype: the data type of the tensor to receive | |||
:param cn: the comp node to place the received tensor | |||
""" | |||
key = "{}->{}".format(src_rank, get_rank()) | |||
    # dummy input
inp = tensor([0]) | |||
tracer_set = get_client().check_remote_tracer(key) | |||
for grad_manager in get_grad_managers(): | |||
if grad_manager.name in tracer_set: | |||
grad_manager.wrt(inp) | |||
op = RemoteRecv() | |||
op.key = key | |||
op.cn = cn | |||
op.shape = shape | |||
op.dtype = dtype | |||
op.addr, op.port = get_mm_server_addr() | |||
op.rank_from = src_rank | |||
return apply(op, inp)[0] |
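A sketch of how remote_send and remote_recv pair up across two ranks (shape, dtype, and comp node are illustrative):
import numpy as np
if get_rank() == 0:
    x = tensor([1.0, 2.0])
    remote_send(x, dest_rank=1)
else:
    y = remote_recv(src_rank=0, shape=(2,), dtype=np.float32, cn="gpu1")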
@@ -0,0 +1,481 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | |||
import functools | |||
from ..core.ops import builtin | |||
from ..core.tensor import utils | |||
from ..core.tensor.core import apply | |||
from ..tensor import Tensor | |||
__all__ = [ | |||
"abs", | |||
"add", | |||
"acos", | |||
"asin", | |||
"atan", | |||
"atan2", | |||
"asinh", | |||
"acosh", | |||
"atanh", | |||
"bitwise_and", # TODO | |||
"bitwise_not", # TODO | |||
"bitwise_or", # TODO | |||
"bitwise_xor", # TODO | |||
"ceil", | |||
"clamp", | |||
"cos", | |||
"cosh", | |||
"div", | |||
"eq", | |||
"exp", | |||
"expm1", | |||
"floor", | |||
"floor_div", | |||
"gt", | |||
"ge", | |||
"hswish", | |||
"hsigmoid", | |||
"left_shift", | |||
"lt", | |||
"le", | |||
"log", | |||
"log1p", | |||
"logical_and", | |||
"logical_not", | |||
"logical_or", | |||
"logical_xor", | |||
"maximum", | |||
"minimum", | |||
"mod", | |||
"mul", | |||
"neg", | |||
"ne", | |||
"pow", | |||
"relu", | |||
"relu6", | |||
"right_shift", | |||
"round", | |||
"sigmoid", | |||
"sin", | |||
"sinh", | |||
"sqrt", | |||
"square", | |||
"sub", | |||
"tan", | |||
"tanh", | |||
"fast_tanh", | |||
] | |||
def _elwise(*args, mode): | |||
op = builtin.Elemwise(mode=mode) | |||
args = utils.convert_inputs(*args) | |||
(result,) = apply(op, *args) | |||
return result | |||
def _logical(*args, mode): | |||
op = builtin.CondExecPredLogical(mode=mode) | |||
args = utils.convert_inputs(*args) | |||
(result,) = apply(op, *args) | |||
return result | |||
def _elemwise_multi_type(*args, mode, **kwargs): | |||
op = builtin.ElemwiseMultiType(mode=mode, **kwargs) | |||
args = utils.convert_inputs(*args) | |||
(result,) = apply(op, *args) | |||
return result | |||
# math operations | |||
def add(x, y): | |||
"""Element-wise addition. | |||
    At least one operand should be a tensor; the same applies to
    sub/mul/div/floor_div/pow/mod/atan2/eq/ne/lt/le/gt/ge/maximum/minimum.
""" | |||
return _elwise(x, y, mode="add") | |||
def sub(x, y): | |||
"""Element-wise subtract.""" | |||
return _elwise(x, y, mode="sub") | |||
def mul(x, y): | |||
"""Element-wise multiplication.""" | |||
return _elwise(x, y, mode="mul") | |||
def div(x, y): | |||
"""Element-wise (x / y).""" | |||
return _elwise(x, y, mode="true_div") | |||
def floor_div(x, y): | |||
"""Element-wise floor(x / y).""" | |||
return _elwise(x, y, mode="floor_divide") | |||
def neg(x): | |||
"""Element-wise negation.""" | |||
return _elwise(x, mode="negate") | |||
def pow(x, y): | |||
"""Element-wise power.""" | |||
return _elwise(x, y, mode="pow") | |||
def mod(x, y): | |||
"""Element-wise remainder of division.""" | |||
return _elwise(x, y, mode="mod") | |||
def abs(x): | |||
"""Element-wise absolute value.""" | |||
return _elwise(x, mode="abs") | |||
def exp(x): | |||
"""Element-wise exponential.""" | |||
return _elwise(x, mode="exp") | |||
def expm1(x): | |||
"""Element-wise exp(x)-1.""" | |||
return _elwise(x, mode="expm1") | |||
def log(x): | |||
"""Element-wise logarithm (base `e`).""" | |||
return _elwise(x, mode="log") | |||
def log1p(x): | |||
"""Element-wise log(x+1) (base `e`).""" | |||
return _elwise(x, mode="log1p") | |||
def sqrt(inp: Tensor) -> Tensor: | |||
""" | |||
Return a new tensor with the square-root of the elements of ``inp``. | |||
For negative value, return nan. | |||
:param inp: The input tensor | |||
:return: The computed tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.functional as F | |||
data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||
out = F.sqrt(data) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[0. 1. 1.4142] | |||
[1.7321 2. 2.2361 ]] | |||
""" | |||
return inp ** 0.5 | |||
def square(inp: Tensor) -> Tensor: | |||
""" | |||
Return a new tensor with the square of the elements of ``inp`` | |||
:param inp: The input tensor | |||
:return: The computed tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.functional as F | |||
data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||
out = F.square(data) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[0. 1. 4.] | |||
[9. 16. 25.]] | |||
""" | |||
return inp ** 2 | |||
def round(x): | |||
"""Round tensor to int element-wise.""" | |||
return _elwise(x, mode="round") | |||
def ceil(x): | |||
"""Return the ceil of the input, element-wise.""" | |||
return _elwise(x, mode="ceil") | |||
def floor(x): | |||
"""Calculate the floor element-wise""" | |||
return _elwise(x, mode="floor") | |||
# trigonometric functions | |||
def cos(x): | |||
"""Cosine, element-wise.""" | |||
return _elwise(x, mode="cos") | |||
def sin(x): | |||
"""Sine, element-wise.""" | |||
return _elwise(x, mode="sin") | |||
def tan(x): | |||
    """Tangent, element-wise, computed as sin(x) / cos(x)."""
    return sin(x) / cos(x)
def acos(x): | |||
"""Inverse cosine, element-wise.""" | |||
return _elwise(x, mode="acos") | |||
def asin(x): | |||
"""Inverse sine, element-wise.""" | |||
return _elwise(x, mode="asin") | |||
def atan(x): | |||
return _elwise(x, 1, mode="atan2") | |||
def atan2(y, x): | |||
return _elwise(y, x, mode="atan2") | |||
def cosh(x): | |||
r"""Compute element-wise hyperbolic cosine.""" | |||
return 0.5 * (exp(x) + exp(-x)) | |||
def sinh(x): | |||
r"""Compute element-wise hyperbolic sine.""" | |||
u = expm1(x) | |||
return 0.5 * u / (u + 1) * (u + 2) | |||
def tanh(x): | |||
r"""Compute element-wise hyperbolic tangent.""" | |||
return _elwise(x, mode="tanh") | |||
def asinh(x): | |||
r"""Compute element-wise inverse hyperbolic sine.""" | |||
return log(x + (x ** 2 + 1) ** 0.5) | |||
def acosh(x): | |||
r"""Compute element-wise inverse hyperbolic cosine.""" | |||
return log(x + (x ** 2 - 1) ** 0.5) | |||
def atanh(x): | |||
r"""Compute element-wise inverse hyperbolic tangent.""" | |||
return log1p(2 * x / (1 - x)) / 2 | |||
def fast_tanh(x): | |||
r"""Compute element-wise fast tanh; this is an approximation: | |||
.. math:: | |||
\text{fast_tanh}(x) = x * (27. + x * x) / (27. + 9. * x * x) | |||
""" | |||
return _elwise(x, mode="fast_tanh") | |||
# bit-twiddling functions | |||
def left_shift(x, y):
    """Element-wise bitwise left shift."""
    return _elwise(x, y, mode="shl")
def right_shift(x, y):
    """Element-wise bitwise right shift."""
    return _elwise(x, y, mode="shr")
def bitwise_and(x, y): | |||
raise NotImplementedError | |||
def bitwise_not(x): | |||
raise NotImplementedError | |||
def bitwise_or(x, y): | |||
raise NotImplementedError | |||
def bitwise_xor(x, y): | |||
raise NotImplementedError | |||
# logical functions | |||
def logical_and(x, y): | |||
return _elwise(x, y, mode="AND") | |||
def logical_not(x): | |||
return _elwise(x, mode="NOT") | |||
def logical_or(x, y): | |||
return _elwise(x, y, mode="OR") | |||
def logical_xor(x, y): | |||
return _elwise(x, y, mode="XOR") | |||
# comparison functions | |||
def eq(x, y): | |||
"""Return (x == y) element-wise.""" | |||
return _elwise(x, y, mode="eq") | |||
def ne(x, y): | |||
    """Return (x != y) element-wise."""
    return x != y
def lt(x, y): | |||
"""Return (x < y) element-wise.""" | |||
return _elwise(x, y, mode="lt") | |||
def le(x, y): | |||
"""Return (x =< y) element-wise.""" | |||
return _elwise(x, y, mode="leq") | |||
def gt(x, y): | |||
"""Return (x > y) element-wise.""" | |||
return _elwise(y, x, mode="lt") | |||
def ge(x, y): | |||
"""Return (x >= y) element-wise""" | |||
return _elwise(y, x, mode="leq") | |||
def hswish(x): | |||
"""Return x * relu6(x + 3) / 6 element-wise""" | |||
return _elwise(x, mode="h_swish") | |||
def hsigmoid(x): | |||
"""Return relu6(x + 3) / 6 element-wise""" | |||
return relu6(x + 3) / 6 | |||
def relu(x): | |||
"""Return `max(x, 0)` element-wise.""" | |||
return _elwise(x, mode="relu") | |||
def relu6(x): | |||
"""Return min(max(x, 0), 6) element-wise.""" | |||
return minimum(maximum(x, 0), 6) | |||
def sigmoid(x): | |||
"""Return 1 / ( 1 + exp( -x ) ) element-wise.""" | |||
return _elwise(x, mode="sigmoid") | |||
def maximum(x, y): | |||
"""Element-wise maximum of array elements.""" | |||
return _elwise(x, y, mode="max") | |||
def minimum(x, y): | |||
"""Element-wise minimum of array elements.""" | |||
return _elwise(x, y, mode="min") | |||
def clamp(inp: Tensor, lower=None, upper=None) -> Tensor: | |||
r""" | |||
Clamp all elements in :attr:`inp` into the range `[` :attr:`lower`, :attr:`upper` `]` and return | |||
a resulting tensor: | |||
.. math:: | |||
y_i = \begin{cases} | |||
\text{lower} & \text{if } x_i < \text{lower} \\ | |||
x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\ | |||
\text{upper} & \text{if } x_i > \text{upper} | |||
\end{cases} | |||
:param inp: the input tensor. | |||
:param lower: lower-bound of the range to be clamped to | |||
:param upper: upper-bound of the range to be clamped to | |||
Example: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
a = tensor(np.arange(5).astype(np.int32)) | |||
print(F.clamp(a, 2, 4).numpy()) | |||
print(F.clamp(a, lower=3).numpy()) | |||
print(F.clamp(a, upper=3).numpy()) | |||
.. testoutput:: | |||
[2 2 2 3 4] | |||
[3 3 3 3 4] | |||
[0 1 2 3 3] | |||
""" | |||
assert ( | |||
lower is not None or upper is not None | |||
), "At least one of 'lower' or 'upper' must not be None" | |||
if lower is not None: | |||
if upper is not None: | |||
            assert lower <= upper, "clamp lower bound is bigger than upper bound"
return minimum(maximum(inp, lower), upper) | |||
else: | |||
return maximum(inp, lower) | |||
else: | |||
return minimum(inp, upper) |
@@ -0,0 +1,44 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# pylint: disable=too-many-lines | |||
from typing import List | |||
from ..core import Tensor | |||
def cambricon_subgraph( | |||
inputs: List[Tensor], data: bytes, symbol: str, tensor_dim_mutable: bool, | |||
) -> List[Tensor]: | |||
"""Load a serialized Cambricon subgraph (i.e. cnrtModel_t) and | |||
execute the operations defined in the subgraph. | |||
:param inputs: List of input tensors of the subgraph. | |||
:param data: The serialized subgraph. | |||
:param symbol: The name of the function in the subgraph. | |||
        The function corresponds to a cnmlFusionOp
        which is added to the cnmlModel_t/cnrtModel_t.
    :param tensor_dim_mutable: Whether the input tensors' shapes are mutable
        in cnrtModel_t
""" | |||
raise NotImplementedError | |||
def extern_opr_subgraph( | |||
inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, | |||
) -> List[Tensor]: | |||
"""Load a serialized extern opr subgraph and fake execute the operator | |||
:param inputs: Tensor or list of input tensors. | |||
:param output_shapes: The output shapes. | |||
:param dump_name: The serialized subgraph name. | |||
:param dump_data: The serialized subgraph. | |||
:return: List of tensors | |||
""" | |||
raise NotImplementedError |
@@ -0,0 +1,41 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
from typing import Iterable, Optional, Union | |||
from ..core.tensor import Tensor | |||
def add_update( | |||
dest: Tensor, | |||
delta: Tensor, | |||
*, | |||
alpha: Union[Tensor, float, int] = 1.0, | |||
beta: Union[Tensor, float, int] = 1.0, | |||
bias: Union[Tensor, float, int] = 0.0 | |||
): | |||
r"""Inplace modify ``dest`` as follows: | |||
.. math:: | |||
dest = alpha * dest + beta * delta + bias | |||
:param dest: input data that will be inplace modified. | |||
:param delta: update value that will be added to ``dest``. | |||
:param alpha: weight ratio of ``dest``. Default: 1.0 | |||
:param beta: weight ratio of ``delta``. Default: 1.0 | |||
:param bias: bias value appended to the result. Default: 0.0 | |||
""" | |||
if beta is not None and beta != 1.0: | |||
delta = delta * beta | |||
if bias is not None and bias != 0.0: | |||
delta = delta + bias | |||
if alpha is not None and alpha != 1.0: | |||
dest *= alpha | |||
dest += delta | |||
return dest |
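A numeric sketch of using add_update as an in-place SGD-style step (assuming the public ``megengine.tensor`` constructor; values are illustrative):
import numpy as np
from megengine import tensor
w = tensor(np.ones(3, dtype="float32"))
g = tensor(np.full(3, 0.5, dtype="float32"))
add_update(w, g, beta=-0.01)  # w <- w + (-0.01) * g
print(w.numpy())              # -> [0.995 0.995 0.995]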
@@ -0,0 +1,388 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
from ..tensor import Tensor | |||
from .elemwise import abs, eq, exp, log, maximum, pow, relu | |||
from .nn import assert_equal, indexing_one_hot | |||
from .tensor import where | |||
from .utils import zero_grad | |||
def l1_loss(pred: Tensor, label: Tensor) -> Tensor: | |||
r""" | |||
Calculates the mean absolute error (MAE) between | |||
each element in the pred :math:`x` and label :math:`y`. | |||
The mean absolute error can be described as: | |||
.. math:: \ell(x,y) = mean\left(L \right) | |||
where | |||
.. math:: | |||
L = \{l_1,\dots,l_N\}, \quad | |||
l_n = \left| x_n - y_n \right|, | |||
:math:`x` and :math:`y` are tensors of arbitrary shapes with a total | |||
of :math:`N` elements each. :math:`N` is the batch size. | |||
:param pred: The predicted result from model. | |||
:param label: The ground truth to compare. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.functional as F | |||
ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32)) | |||
tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32)) | |||
loss = F.l1_loss(ipt,tgt) | |||
print(loss.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[2.75] | |||
""" | |||
diff = pred - label | |||
return abs(diff).mean() | |||
def square_loss(pred: Tensor, label: Tensor) -> Tensor: | |||
r""" | |||
Calculates the mean squared error (squared L2 norm) between | |||
each element in the pred :math:`x` and label :math:`y`. | |||
The mean squared error can be described as: | |||
.. math:: \ell(x, y) = mean\left( L \right) | |||
where | |||
.. math:: | |||
L = \{l_1,\dots,l_N\}, \quad | |||
l_n = \left( x_n - y_n \right)^2, | |||
:math:`x` and :math:`y` are tensors of arbitrary shapes with a total | |||
of :math:`N` elements each. :math:`N` is the batch size. | |||
:param pred: The predicted result from model. | |||
:param label: The ground truth to compare. | |||
Shape: | |||
- pred: :math:`(N, *)` where :math:`*` means any number of additional | |||
dimensions | |||
- label: :math:`(N, *)`. Same shape as ``pred`` | |||
""" | |||
diff = pred - label | |||
return (diff ** 2).mean() | |||
def cross_entropy( | |||
inp: Tensor, target: Tensor, axis: int = 1, ignore_index: int = -1 | |||
) -> Tensor: | |||
r""" | |||
Returns the cross entropy loss in a classification problem. | |||
.. math:: \textrm{CrossEntropy}(x, y) = - \sum_{i} y_i\log(x_i) | |||
:param inp: The input tensor representing the predicted probability. | |||
    :param target: The input tensor representing the classification label.
:param axis: An axis along which cross_entropy will be applied. Default: 1 | |||
:param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. Default: -1 | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data_shape = (1, 2) | |||
label_shape = (1, ) | |||
pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape)) | |||
label = tensor(np.ones(label_shape, dtype=np.int32)) | |||
loss = F.cross_entropy(pred, label) | |||
print(loss.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[0.69] | |||
""" | |||
raise NotImplementedError | |||
# n0 = inp.ndim | |||
# n1 = target.ndim | |||
# assert n0 == n1 + 1, ( | |||
# "target ndim must be one less than input ndim; input_ndim={} " | |||
# "target_ndim={}".format(n0, n1) | |||
# ) | |||
# if ignore_index != -1: | |||
# mask = 1 - equal(target, ignore_index) | |||
# target = target * mask | |||
# loss = -log(indexing_one_hot(inp, target, axis)) * mask | |||
# return loss.sum() / maximum(mask.sum(), 1.0) | |||
# else: | |||
# return -log(indexing_one_hot(inp, target, axis)).mean() | |||
def cross_entropy_with_softmax( | |||
pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0 | |||
) -> Tensor: | |||
r""" | |||
Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`. | |||
It has better numerical stability compared with sequential calls to :func:`~.softmax` and :func:`~.cross_entropy`. | |||
When using label smoothing, the label distribution is as follows: | |||
.. math:: y^{LS}_{k}=y_{k}\left(1-\alpha\right)+\alpha/K | |||
where :math:`y^{LS}` and :math:`y` are new label distribution and origin label distribution respectively. | |||
k is the index of label distribution. :math:`\alpha` is label_smooth and :math:`K` is the number of classes. | |||
:param pred: The input tensor representing the predicted probability. | |||
:param label: The input tensor representing the classification label. | |||
:param axis: An axis along which softmax will be applied. Default: 1. | |||
    :param label_smooth: A label-smoothing parameter that re-distributes the target distribution. Default: 0.
""" | |||
n0 = pred.ndim | |||
n1 = label.ndim | |||
assert n0 == n1 + 1, ( | |||
"target ndim must be one less than input ndim; input_ndim={} " | |||
"target_ndim={}".format(n0, n1) | |||
) | |||
num_classes = pred.shape[axis] | |||
# Denominator of the softmax | |||
offset = pred.max(axis=axis).detach() | |||
pred = pred - offset | |||
down = exp(pred).sum(axis=axis) | |||
up = pred[np.arange(pred.shape[0]), label] | |||
if label_smooth != 0: | |||
factor = label_smooth / num_classes | |||
up = up * (1 - label_smooth) + pred.sum(axis=axis) * factor | |||
return (log(down) - up).mean() | |||
def triplet_margin_loss( | |||
anchor: Tensor, positive: Tensor, negative: Tensor, margin: float = 1.0, p: int = 2 | |||
) -> Tensor: | |||
r""" | |||
Creates a criterion that measures the triplet loss given an input tensors. | |||
.. math:: | |||
L(a, p, n) = max\left\{d\left(a_{i},p_{i}\right)-d\left(a_{i}, n_{i}\right)+margin, 0\right\},\ | |||
d\left(x_{i},y_{i}\right)=\left\|x_{i}-y_{i}\right\|_{p} | |||
:param anchor: The input tensor representing the anchor samples. | |||
:param positive: The input tensor representing the positive samples. | |||
:param negative: The input tensor representing the negative samples. | |||
    :param margin: The margin value. Default: 1.0
    :param p: The norm degree for pairwise distance. Default: 2
""" | |||
s0 = anchor.shapeof() | |||
s1 = positive.shapeof() | |||
s2 = negative.shapeof() | |||
assert_equal(s0, s1) | |||
assert_equal(s1, s2) | |||
n0 = anchor.ndim | |||
n1 = positive.ndim | |||
n2 = negative.ndim | |||
assert n0 == 2 and n1 == 2 and n2 == 2, ( | |||
"anchor ndim, positive ndim, and negative ndim must be 2; " | |||
"anchor_ndim={} positive_ndim={} negative_ndim={}".format(n0, n1, n2) | |||
) | |||
assert p > 0, "a margin with a value greater than 0; p={}".format(p) | |||
diff0 = abs(anchor - positive) | |||
diff1 = abs(anchor - negative) | |||
    d1 = pow(pow(diff0, p).sum(axis=1, keepdims=True), 1 / p)
    d2 = pow(pow(diff1, p).sum(axis=1, keepdims=True), 1 / p)
loss = maximum(d1 - d2 + margin, 0) | |||
return loss.mean() | |||
def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor: | |||
r"""Function that measures the Binary Cross Entropy between the target and the prediction. | |||
:param pred: (N,*) where * means, any number of additional dimensions. | |||
:param label: (N,*), same shape as the input. | |||
""" | |||
assert pred.shape == label.shape | |||
return -1.0 * (label * log(pred) + (1.0 - label) * log(1 - pred)).mean() | |||
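A quick numeric check (values illustrative; ``pred`` must lie strictly in (0, 1)):
import numpy as np
from megengine import tensor
pred = tensor(np.array([0.9, 0.1], dtype=np.float32))
label = tensor(np.array([1.0, 0.0], dtype=np.float32))
print(binary_cross_entropy(pred, label).numpy())  # ~0.1054, i.e. -log(0.9)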
def nll_loss( | |||
pred: Tensor, label: Tensor, axis: int = 1, ignore_index: int = -1 | |||
) -> Tensor: | |||
r""" | |||
The negative log likelihood loss. | |||
:param pred: The predicted result from model. | |||
:param label: The ground truth to compare. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data_shape = (2, 2) | |||
label_shape = (2, ) | |||
data = tensor( | |||
np.array([[1, 0.5], [0.3, 1.2]], dtype=np.float32).reshape(data_shape), | |||
) | |||
label = tensor( | |||
np.ones(label_shape, dtype=np.int32) | |||
) | |||
pred = F.log(F.softmax(data)) | |||
loss1 = F.nll_loss(pred, label) | |||
loss2 = F.cross_entropy_with_softmax(data, label) | |||
print(loss1.numpy(), loss2.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[0.6576154] [0.6576154] | |||
""" | |||
raise NotImplementedError | |||
# n0 = pred.ndim | |||
# n1 = label.ndim | |||
# assert n0 == n1 + 1, ( | |||
# "target ndim must be one less than input ndim; input_ndim={} " | |||
# "target_ndim={}".format(n0, n1) | |||
# ) | |||
# mask = 1.0 - equal(label, ignore_index) | |||
# label = label * mask | |||
# loss = indexing_one_hot(pred, label, axis) * mask | |||
# return -1.0 * loss.sum() / maximum(mask.sum(), 1.0) | |||
def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor: | |||
r""" | |||
    Calculate the hinge loss, which is often used in SVMs.
    The hinge loss can be described as:
    .. math:: loss(x, y) = \frac{1}{N}\sum_{i}\sum_{j}\max(0, 1 - x_{ij} \cdot y_{ij})
:param pred: The input tensor representing the predicted probability, shape is (N, C). | |||
:param label: The input tensor representing the binary classification label, shape is (N, C). | |||
    :param norm: Specify the norm used to calculate the loss; should be "L1" or "L2".
Examples: | |||
.. testcode:: | |||
from megengine import tensor | |||
import megengine.functional as F | |||
pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32") | |||
label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32") | |||
loss = F.hinge_loss(pred, label) | |||
print(loss.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[1.5] | |||
""" | |||
assert norm in ["L1", "L2"], "norm must be L1 or L2" | |||
    # labels are expected to be -1/1; the hinge term is max(0, 1 - pred * label)
loss = relu(1.0 - pred * label) | |||
if norm == "L1": | |||
return loss.sum(axis=1).mean() | |||
else: | |||
return (loss ** 2).sum(axis=1).mean() | |||
def smooth_l1_loss(pred: Tensor, label: Tensor) -> Tensor: | |||
r""" | |||
Calculate the smooth L1 loss proposed in the `Fast R-CNN paper by Ross Girshick`.
The smooth l1 loss can be described as: | |||
.. math:: | |||
\text{loss}(x, y) = \frac{1}{n} \sum_{i} l_{i} | |||
where :math:`l_{i}` is given by: | |||
.. math:: | |||
l_{i} = | |||
\begin{cases} | |||
0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\ | |||
|x_i - y_i| - 0.5, & \text{otherwise } | |||
\end{cases} | |||
:param pred: The predicted result from model. | |||
:param label: The ground truth to compare. | |||
Examples: | |||
.. testcode:: | |||
from megengine import tensor | |||
import megengine.functional as F | |||
pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]]) | |||
label = tensor([[0.4, 1.5, 1.2], [0., 0.1, 2.2]]) | |||
loss = F.smooth_l1_loss(pred, label) | |||
print(loss.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[0.5608334] | |||
""" | |||
raise NotImplementedError | |||
# diff = abs(pred - label) | |||
# l2_loss = 0.5 * (diff ** 2) | |||
# l1_loss = diff - 0.5 | |||
# mask = diff < 1 | |||
# loss = where(mask, l2_loss, l1_loss) | |||
# return loss.mean() |
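# A pure-NumPy reference matching the commented implementation above (an
# illustrative sketch; `_smooth_l1_ref` is a hypothetical local name):
#
#   import numpy as np
#   def _smooth_l1_ref(pred, label):
#       diff = np.abs(pred - label)
#       return np.where(diff < 1, 0.5 * diff ** 2, diff - 0.5).mean()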
@@ -0,0 +1,696 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections.abc
import functools | |||
import math | |||
import numbers | |||
from typing import Optional, Sequence, Tuple, Union | |||
from ..core.ops import builtin | |||
from ..core.ops._internal import param_defs as P | |||
from ..core.tensor import utils | |||
from ..core.tensor.core import apply | |||
from ..tensor import Tensor | |||
from .elemwise import clamp, exp, log, log1p | |||
from .tensor import remove_axis, reshape | |||
__all__ = [ | |||
"all", # TODO | |||
"all_close", # TODO | |||
"any", # TODO | |||
"argmax", | |||
"argmin", | |||
"argsort", | |||
"isinf", | |||
"isnan", # TODO | |||
"max", | |||
"mean", | |||
"median", # TODO | |||
"min", | |||
"norm", | |||
"normalize", | |||
"prod", | |||
"sign", # TODO | |||
"sort", | |||
"std", | |||
"sum", | |||
"topk", | |||
"unique", # TODO | |||
"var", | |||
] | |||
def all(inp): | |||
raise NotImplementedError | |||
def all_close(inp): | |||
raise NotImplementedError | |||
def any(inp): | |||
raise NotImplementedError | |||
def unique(inp): | |||
raise NotImplementedError | |||
def isnan(inp: Tensor) -> Tensor: | |||
r"""Returns a new tensor representing if each element is NaN or not. | |||
:param inp: input tensor.
:return: a new tensor representing if each element in :attr:`inp` is NaN or not. | |||
Examples: | |||
.. testcode:: | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor([1, float("nan"), 0]) | |||
print(F.isnan(x)) | |||
.. testoutput:: | |||
Tensor([0 1 0], dtype=uint8) | |||
""" | |||
raise NotImplementedError | |||
# return (inp != inp).astype("uint8") | |||
def isinf(inp: Tensor) -> Tensor: | |||
r"""Returns a new tensor representing if each element is Inf or not. | |||
:param inp: input tensor.
:return: a new tensor representing if each element in :attr:`inp` is Inf or not. | |||
Examples: | |||
.. testcode:: | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor([1, float("inf"), 0]) | |||
print(F.isinf(x)) | |||
.. testoutput:: | |||
Tensor([0 1 0], dtype=uint8) | |||
""" | |||
return (abs(inp).astype("float32") == float("inf")).astype("uint8") | |||
def sign(inp: Tensor): | |||
raise NotImplementedError | |||
def _reduce( | |||
data, | |||
*, | |||
mode, | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False | |||
): | |||
(data,) = utils.convert_inputs(data) | |||
if axis is None: | |||
data = data.reshape(-1) | |||
assert not keepdims, "can not set axis=None and keepdims=True" | |||
op = builtin.Reduce(mode=mode, axis=0) | |||
(result,) = apply(op, data) | |||
elif isinstance(axis, collections.abc.Iterable):
axis = list(axis) | |||
axis.sort(reverse=True) | |||
for ai in axis: | |||
op = builtin.Reduce(mode=mode, axis=ai) | |||
(data,) = apply(op, data) | |||
if not keepdims: | |||
data = remove_axis(data, ai) | |||
result = data | |||
else: | |||
op = builtin.Reduce(mode=mode, axis=axis) | |||
(result,) = apply(op, data) | |||
if not keepdims: | |||
result = remove_axis(result, axis) | |||
return result | |||
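# Note: for a sequence of axes, _reduce sorts them in descending order and
# reduces one axis at a time, so earlier removals never shift the indices of
# axes still to be reduced; e.g. _reduce(x, mode="SUM", axis=(0, 2)) reduces
# axis 2 first, then axis 0.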
def sum( | |||
inp: Tensor, | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
r"""Returns the sum of each row of the ``inp`` tensor in the given ``axis``. | |||
:param inp: The input tensor. | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. | |||
Default: None | |||
:param keepdims: Whether the output tensor has ``axis`` retained or not. | |||
Default: False | |||
:return: The output tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||
out = F.sum(data) | |||
print(out.numpy()) | |||
.. testoutput:: | |||
[21] | |||
""" | |||
return _reduce(inp, mode="SUM", axis=axis, keepdims=keepdims) | |||
def prod( | |||
inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False | |||
) -> Tensor: | |||
r""" | |||
Returns the element product of input tensor along given *axis*. | |||
:param inp: The input tensor | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None`` | |||
:param keepdims: Whether the output tensor has *axis* retained or not. Default: ``False`` | |||
:return: The output tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||
out = F.prod(data) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[720] | |||
""" | |||
return _reduce(inp, mode="PRODUCT", axis=axis, keepdims=keepdims) | |||
def mean( | |||
inp: Tensor, | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
"""Returns the mean value of each row of the ``inp`` tensor in | |||
the given ``axis``. If axis is a list of dimensions, | |||
reduce over all of them. | |||
:param inp: The input tensor | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||
out = F.mean(data) | |||
print(out.numpy()) | |||
.. testoutput:: | |||
[3.5] | |||
""" | |||
return _reduce(inp, mode="MEAN", axis=axis, keepdims=keepdims) | |||
def median( | |||
inp: Tensor, | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
raise NotImplementedError | |||
def var( | |||
inp: Tensor, | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
"""Returns the variance value of input tensor along | |||
given ``axis``. If axis is a list of dimensions, | |||
reduce over all of them. | |||
:param inp: The input tensor. | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``. | |||
:param keepdims: Whether the output tensor has ``axis`` retained or not. Default: ``False``. | |||
:return: The output tensor. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3)) | |||
out = F.var(data) | |||
print(out.numpy()) | |||
.. testoutput:: | |||
[2.9166667] | |||
""" | |||
if axis is None: | |||
m = mean(inp, axis=axis, keepdims=False) | |||
else: | |||
m = mean(inp, axis=axis, keepdims=True) | |||
v = inp - m | |||
return mean(v ** 2, axis=axis, keepdims=keepdims) | |||
def std( | |||
inp: Tensor, | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
"""Returns the standard deviation of input tensor along | |||
given ``axis``. If axis is a list of dimensions, | |||
reduce over all of them. | |||
:param inp: The input tensor. | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``. | |||
:param keepdims: Whether the output tensor has ``axis`` retained or not. Default: ``False``. | |||
:return: The output tensor. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3)) | |||
out = F.std(data, axis=1) | |||
print(out.numpy()) | |||
.. testoutput:: | |||
[0.8164966 0.8164966] | |||
""" | |||
return var(inp, axis=axis, keepdims=keepdims) ** 0.5 | |||
def min( | |||
inp: Tensor, | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
r""" | |||
Returns the min value of input tensor along given *axis*. | |||
:param inp: The input tensor | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||
:return: The output tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||
y = F.min(x) | |||
print(y.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[1] | |||
""" | |||
return _reduce(inp, mode="MIN", axis=axis, keepdims=keepdims) | |||
def max( | |||
inp: Tensor, | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
r"""Returns the max value of the input tensor along given *axis*. | |||
:param inp: The input tensor | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||
:return: The output tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||
y = F.max(x) | |||
print(y.numpy()) | |||
.. testoutput:: | |||
[6] | |||
""" | |||
return _reduce(inp, mode="MAX", axis=axis, keepdims=keepdims) | |||
def norm( | |||
inp: Tensor, | |||
p: int = 2, | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims=False, | |||
): | |||
"""Calculate ``p``-norm of input tensor along certain axis. | |||
:param inp: The input tensor | |||
:param p: power of value ``p`` applied to ``inp``. Default: 2 | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False | |||
:return: The output tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.arange(-3, 3, dtype=np.float32).reshape(2,3)) | |||
y = F.norm(x) | |||
print(y.numpy()) | |||
.. testoutput:: | |||
[4.358899] | |||
""" | |||
if p == 0: | |||
return sum(inp != 0, axis=axis, keepdims=keepdims) | |||
if p == math.inf:
return max(abs(inp), axis=axis, keepdims=keepdims)
if p == -math.inf:
return min(abs(inp), axis=axis, keepdims=keepdims)
return sum(abs(inp) ** p, axis=axis, keepdims=keepdims) ** (1.0 / p) | |||
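# Special cases above: p=0 counts the non-zero entries, while p=inf / -inf
# return the max / min of the absolute values along the reduced axes.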
def argmin( | |||
inp: Tensor, | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
r"""Returns the indices of the minimum values along an axis | |||
:param inp: The input tensor | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||
:return: The output tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||
y = F.argmin(x) | |||
print(y.numpy()) | |||
.. testoutput:: | |||
[0] | |||
""" | |||
if isinstance(axis, collections.abc.Iterable):
axis = list(axis) | |||
axis.sort(reverse=True) | |||
for ai in axis: | |||
op = builtin.Argmin(axis=ai) | |||
(inp,) = apply(op, inp) | |||
if not keepdims: | |||
inp = remove_axis(inp, ai) | |||
return inp | |||
if axis is None: | |||
assert not keepdims, "can not set axis=None and keepdims=True" | |||
inp = inp.flatten() | |||
axis = 0 | |||
op = builtin.Argmin(axis=axis) | |||
(result,) = apply(op, inp) | |||
if not keepdims: | |||
result = remove_axis(result, axis) | |||
return result | |||
def argmax( | |||
inp: Tensor, | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
r"""Returns the indices of the maximum values along an axis | |||
:param inp: The input tensor | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||
:return: The output tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||
y = F.argmax(x) | |||
print(y.numpy()) | |||
.. testoutput:: | |||
[5] | |||
""" | |||
if isinstance(axis, collections.abc.Iterable):
axis = list(axis) | |||
axis.sort(reverse=True) | |||
for ai in axis: | |||
op = builtin.Argmax(axis=ai) | |||
(inp,) = apply(op, inp) | |||
if not keepdims: | |||
inp = remove_axis(inp, ai) | |||
return inp | |||
if axis is None: | |||
assert not keepdims, "can not set axis=None and keepdims=True" | |||
inp = inp.flatten() | |||
axis = 0 | |||
op = builtin.Argmax(axis=axis) | |||
(result,) = apply(op, inp) | |||
if not keepdims: | |||
result = remove_axis(result, axis) | |||
return result | |||
def normalize( | |||
inp: Tensor, | |||
p: int = 2, | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
eps: float = 1e-12, | |||
) -> Tensor: | |||
r"""Perform :math:`L_p` normalization of input tensor along certain axis. | |||
For a tensor :attr:`inp` of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each | |||
:math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as: | |||
.. math:: | |||
v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}. | |||
:param inp: the input tensor | |||
:param p: power of value ``p`` applied to ``inp``. Default: 2 | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced | |||
to calculate the norm. Default: None | |||
:param eps: a small value to avoid division by zero. Default: 1e-12 | |||
:return: the normalized output tensor | |||
""" | |||
if axis is None: | |||
return inp / clamp(norm(inp, p, axis), lower=eps) | |||
else: | |||
return inp / clamp(norm(inp, p, axis, keepdims=True), lower=eps) | |||
def argsort(inp: Tensor, descending: bool = False) -> Tensor: | |||
r""" | |||
Sort the target 1d or 2d matrix by row and return the indices that would sort it.
:param inp: The input tensor, if 2d, each row will be sorted
:param descending: Sort in descending order, where the largest comes first. Default: ``False``
:return: The indices (int32) that sort each row.
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.array([1,2], dtype=np.float32)) | |||
indices = F.argsort(data) | |||
print(indices.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[0 1] | |||
""" | |||
assert len(inp.shape) <= 2, "Input should be 1d or 2d" | |||
if descending: | |||
order = P.Argsort.Order.DESCENDING | |||
else: | |||
order = P.Argsort.Order.ASCENDING | |||
op = builtin.Argsort(order=order) | |||
if len(inp.shape) == 1: | |||
inp = inp.reshape(1, -1) | |||
_, result = apply(op, inp) | |||
return result[0] | |||
_, result = apply(op, inp) | |||
return result | |||
def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]: | |||
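r"""
Sort a 1d or 2d tensor by row, returning both the sorted tensor and the indices.
:param inp: The input tensor; if 2d, each row will be sorted.
:param descending: Sort in descending order, where the largest comes first. Default: ``False``
:return: Tuple of two tensors (sorted_tensor, indices_of_int32)
"""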
assert len(inp.shape) <= 2, "Input should be 1d or 2d" | |||
if descending: | |||
order = P.Argsort.Order.DESCENDING | |||
else: | |||
order = P.Argsort.Order.ASCENDING | |||
op = builtin.Argsort(order=order) | |||
if len(inp.shape) == 1: | |||
inp = inp.reshape(1, -1) | |||
tns, ind = apply(op, inp) | |||
return tns[0], ind[0] | |||
tns, ind = apply(op, inp) | |||
return tns, ind | |||
def topk( | |||
inp: Tensor, | |||
k: int, | |||
descending: bool = False, | |||
kth_only: bool = False, | |||
no_sort: bool = False, | |||
) -> Tuple[Tensor, Tensor]: | |||
r""" | |||
Select the Top-K (by default) smallest elements of a 2d matrix by row.
:param inp: The input tensor, if 2d, each row will be sorted | |||
:param k: The number of elements needed | |||
:param descending: If true, return the largest elements instead. Default: ``False`` | |||
:param kth_only: If true, only the k-th element will be returned. Default: ``False`` | |||
:param no_sort: If true, the returned elements can be unordered. Default: ``False`` | |||
:return: Tuple of two tensors (topk_tensor, indices_of_int32); the indices are ``None`` when ``kth_only`` is true
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32)) | |||
top, indices = F.topk(data, 5) | |||
print(top.numpy(), indices.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[1. 2. 3. 4. 5.] [7 0 6 1 5] | |||
""" | |||
if descending: | |||
inp = -inp | |||
Mode = P.TopK.Mode | |||
if kth_only: | |||
mode = Mode.KTH_ONLY | |||
elif no_sort: | |||
mode = Mode.VALUE_IDX_NOSORT | |||
else: | |||
mode = Mode.VALUE_IDX_SORTED | |||
op = builtin.TopK(mode=mode) | |||
if len(inp.shape) == 1: | |||
inp = inp.reshape(1, -1) | |||
res = apply(op, inp, Tensor(k, dtype="int32")) | |||
if kth_only:
tns, ind = res[0], None  # KTH_ONLY mode produces no index output
else:
tns, ind = res[0][0], res[1][0]
else:
res = apply(op, inp, Tensor(k, dtype="int32"))
if kth_only:
tns, ind = res[0], None  # KTH_ONLY mode produces no index output
else:
tns, ind = res[0], res[1]
if descending: | |||
tns = -tns | |||
return tns, ind |
@@ -0,0 +1,83 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# pylint: disable=too-many-lines | |||
from typing import Tuple, Union | |||
from ..core.ops import builtin | |||
from ..core.tensor.core import apply | |||
from ..tensor import Tensor | |||
from .debug_param import get_conv_execution_strategy | |||
from .types import _pair, _pair_nonzero | |||
def conv_bias_activation( | |||
inp: Tensor, | |||
weight: Tensor, | |||
bias: Tensor, | |||
dtype=None, | |||
stride: Union[int, Tuple[int, int]] = 1, | |||
padding: Union[int, Tuple[int, int]] = 0, | |||
dilation: Union[int, Tuple[int, int]] = 1, | |||
groups: int = 1, | |||
format="NCHW", | |||
nonlinear_mode="IDENTITY", | |||
conv_mode="CROSS_CORRELATION", | |||
compute_mode="DEFAULT", | |||
) -> Tensor: | |||
""" convolution bias with activation operation, only for inference. | |||
:param inp: The feature map of the convolution operation | |||
:param weight: The convolution kernel | |||
:param bias: The bias added to the result of convolution | |||
:param stride: Stride of the 2D convolution operation. Default: 1 | |||
:param padding: Size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param dilation: Dilation of the 2D convolution operation. Default: 1 | |||
:param groups: number of groups to divide input and output channels into, | |||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and the shape of weight should be ``(groups, out_channels // groups,
in_channels // groups, height, width)``. | |||
:type conv_mode: string or :class:`P.Convolution.Mode` | |||
:param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
'CROSS_CORRELATION'. | |||
:param dtype: Supports np.dtype. Default: np.int8.
:type compute_mode: string or | |||
:class:`P.Convolution.ComputeMode` | |||
:param compute_mode: When set to 'DEFAULT', no special requirements will be | |||
placed on the precision of intermediate results. When set to 'FLOAT32', | |||
Float32 would be used for accumulator and intermediate result, but only | |||
effective when input and output are of Float16 dtype. | |||
""" | |||
ph, pw = _pair(padding) | |||
sh, sw = _pair_nonzero(stride) | |||
dh, dw = _pair_nonzero(dilation) | |||
sparse_type = "DENSE" if groups == 1 else "GROUP" | |||
op = builtin.ConvBiasForward( | |||
stride_h=sh, | |||
stride_w=sw, | |||
pad_h=ph, | |||
pad_w=pw, | |||
dilate_h=dh, | |||
dilate_w=dw, | |||
dtype=dtype, | |||
format=format, | |||
strategy=get_conv_execution_strategy(), | |||
nonlineMode=nonlinear_mode, | |||
mode=conv_mode, | |||
compute_mode=compute_mode, | |||
sparse=sparse_type, | |||
) | |||
(outputs,) = apply(op, inp, weight, bias) | |||
return outputs |
@@ -0,0 +1,934 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import functools | |||
import math | |||
from itertools import accumulate | |||
from typing import Iterable, List, Optional, Sequence, Tuple, Union | |||
import numpy as np | |||
from ..core._imperative_rt import CompNode | |||
from ..core.ops import builtin | |||
from ..core.ops._internal import param_defs as P | |||
from ..core.ops.special import Const | |||
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | |||
from ..core.tensor.utils import ( | |||
astensor1d, | |||
convert_inputs, | |||
convert_single_value, | |||
dtype_promotion, | |||
get_device, | |||
) | |||
from ..device import get_default_device | |||
from ..tensor import Tensor | |||
from .elemwise import ceil | |||
__all__ = [ | |||
"add_axis", # expand_dims | |||
"arange", | |||
"broadcast", | |||
"concat", | |||
"cond_take", | |||
"dimshuffle", # transpose, permute | |||
"expand_dims", | |||
"full", | |||
"full_like", | |||
"gather", | |||
"eye", | |||
"linspace", | |||
"ones", | |||
"ones_like", | |||
"remove_axis", # squeeze | |||
"split", | |||
"squeeze", | |||
"stack", | |||
"reshape", | |||
"scatter", | |||
"where", | |||
"zeros", | |||
"zeros_like", | |||
"param_pack_split", | |||
"param_pack_concat", | |||
] | |||
def eye(n: int, m: Optional[int] = None, *, dtype=None, device: Optional[CompNode] = None) -> Tensor:
""" | |||
Returns a 2D tensor with ones on the diagonal and zeros elsewhere. | |||
:param n: The number of rows | |||
:param m: The number of columns; if None, defaults to ``n``. Default: None
:param dtype: The data type. Default: None
:param device: Compute node of the matrix. Default: None
:return: The eye matrix | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine.functional as F | |||
data_shape = (4, 6) | |||
n, m = data_shape | |||
out = F.eye(n, m, dtype=np.float32) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[1. 0. 0. 0. 0. 0.] | |||
[0. 1. 0. 0. 0. 0.] | |||
[0. 0. 1. 0. 0. 0.] | |||
[0. 0. 0. 1. 0. 0.]] | |||
""" | |||
if m is None:
m = n
op = builtin.Eye(k=0, dtype=dtype, comp_node=device)
# assumes the Eye op reads the output shape from its int32 tensor input;
# passing (n, m) then yields the rectangular matrix shown in the example
(result,) = apply(op, Tensor([n, m], dtype="int32", device=device))
return result | |||
def full(shape, value, dtype="float32", device=None): | |||
if device is None: | |||
device = get_default_device() | |||
(x,) = Const(value, dtype=dtype, device=device)( | |||
Tensor(value, dtype=dtype, device=device) | |||
) | |||
return broadcast(x, shape) | |||
def ones(shape, dtype="float32", device=None): | |||
return full(shape, 1.0, dtype=dtype, device=device) | |||
def zeros(shape, dtype="float32", device=None): | |||
return full(shape, 0.0, dtype=dtype, device=device) | |||
def zeros_like(inp: Tensor) -> Tensor: | |||
r""" | |||
Returns a zero tensor with the same shape as input tensor | |||
:param inp: input tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||
out = F.zeros_like(inp) | |||
print(out.numpy()) | |||
.. testoutput:: | |||
[[0 0 0] | |||
[0 0 0]] | |||
""" | |||
return zeros(inp.shape, dtype=inp.dtype, device=inp.device) | |||
def ones_like(inp: Tensor) -> Tensor: | |||
r""" | |||
Returns a tensor filled with ones with the same shape as the input tensor
""" | |||
return ones(inp.shape, dtype=inp.dtype, device=inp.device) | |||
def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: | |||
r""" | |||
Returns a tensor filled with ``value`` with the same shape as the input tensor
""" | |||
return full(inp.shape, value, dtype=inp.dtype, device=inp.device) | |||
def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: | |||
""" | |||
Broadcast a tensor to ``shape`` | |||
:param inp: The input tensor | |||
:param shape: The target shape | |||
:return: The output tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||
out = F.broadcast(data, (4, 2, 3)) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[[0. 1. 2.] | |||
[3. 4. 5.]] | |||
[[0. 1. 2.] | |||
[3. 4. 5.]] | |||
[[0. 1. 2.] | |||
[3. 4. 5.]] | |||
[[0. 1. 2.] | |||
[3. 4. 5.]]] | |||
""" | |||
shape = astensor1d(shape, inp, dtype="int32", device=inp.device) | |||
(result,) = apply(builtin.Broadcast(), inp, shape) | |||
return result | |||
def concat( | |||
inps: Iterable[Tensor], axis: int = 0, device: Optional[CompNode] = None, | |||
) -> Tensor: | |||
r""" | |||
Concat some tensors | |||
:param inps: Input tensors to concat | |||
:param axis: the dimension over which the tensors are concatenated. Default: 0 | |||
:param device: The comp node on which the output will be placed. Default: None
:return: The output tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3))) | |||
data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3))) | |||
out = F.concat([data1, data2]) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[ 0. 1. 2.] | |||
[ 3. 4. 5.] | |||
[ 6. 7. 8.] | |||
[ 9. 10. 11.]] | |||
""" | |||
dtype = dtype_promotion(inps) | |||
device = get_device(inps) | |||
def convert(x): | |||
return convert_single_value(x, inps, dtype=dtype) | |||
inps = tuple(map(convert, inps)) | |||
(result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) | |||
return result | |||
def stack(inps, axis=0): | |||
"""Concats a sequence of tensors along a new axis. | |||
The input tensors must have the same shape. | |||
:param inps: The input tensors. | |||
:param axis: The axis along which the tensors are stacked. Default: 0
:return: The output concatenated tensor. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3))) | |||
x2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3))) | |||
out = F.stack([x1, x2], axis=0) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[[ 0. 1. 2.] | |||
[ 3. 4. 5.]] | |||
[[ 6. 7. 8.] | |||
[ 9. 10. 11.]]] | |||
""" | |||
shapes = {arr.shape for arr in inps} | |||
if len(shapes) != 1: | |||
raise ValueError("All input tensors must have the same shape") | |||
inps = [add_axis(inp, axis=axis) for inp in inps] | |||
return concat(inps, axis=axis) | |||
def split(inp, nsplits_or_sections, axis=0): | |||
"""Splits the input tensor into several smaller tensors. | |||
When nsplits_or_sections is int, the last tensor may be smaller than others. | |||
:param inp: The input tensor. | |||
:param nsplits_or_sections: Number of sub tensors, or a list of cumulative split points along ``axis``.
:param axis: The axis along which to split. Default: 0
:return: The output tensor list. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.random.random((2,3,4,5)), dtype=np.float32) | |||
out = F.split(x, 2, axis=3) | |||
print(out[0].shape, out[1].shape) | |||
Outputs: | |||
.. testoutput:: | |||
(2, 3, 4, 3) (2, 3, 4, 2) | |||
""" | |||
sub_tensors = [] | |||
sections = [] | |||
def swapaxis(inp, src, dst): | |||
if src == dst: | |||
return inp | |||
shape = [i for i in range(len(inp.shape))] | |||
shape[src] = dst | |||
shape[dst] = src | |||
return inp.transpose(shape) | |||
inp = swapaxis(inp, 0, axis) | |||
if isinstance(nsplits_or_sections, int): | |||
step = math.ceil(inp.shape[0] / nsplits_or_sections)
incr_step = step
while incr_step < inp.shape[0]:
sections.append(incr_step)
incr_step += step  # advance by the section size, not by the split count
else: | |||
sections = nsplits_or_sections | |||
st = 0 | |||
for se in sections: | |||
sub_tensors.append(swapaxis(inp[st:se], axis, 0)) | |||
st = se | |||
if st < inp.shape[0]: | |||
sub_tensors.append(swapaxis(inp[st:], axis, 0)) | |||
return sub_tensors | |||
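# Section lists hold cumulative split points along `axis`; e.g.
# split(x, [2, 5], axis=0) yields x[:2], x[2:5] and x[5:].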
def _get_idx(index, axis): | |||
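# Build an advanced-indexing tuple: every axis except `axis` keeps its own
# coordinates (a broadcast linspace ramp), while `axis` uses the flattened
# user-provided indices. Shared by gather and scatter below.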
index_dims = len(index.shape) | |||
idx = [] | |||
for i in range(index_dims): | |||
if i != axis: | |||
shape = [1] * index_dims | |||
shape[i] = index.shape[i] | |||
arange = linspace( | |||
0, index.shape[i] - 1, index.shape[i], device=index.device, | |||
) | |||
arange = ( | |||
arange.reshape(*shape) | |||
.broadcast(index.shape) | |||
.reshape(-1) | |||
.astype(np.int32) | |||
) | |||
idx.append(arange) | |||
else: | |||
idx.append(index.reshape(-1)) | |||
return tuple(idx) | |||
def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: | |||
r""" | |||
Gather data from :attr:`inp` on :attr:`axis` using :attr:`index`. | |||
For a 3-D tensor, the output is specified by:: | |||
out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0 | |||
out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1 | |||
out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2 | |||
if :attr:`inp` is an n-dimensional tensor with size | |||
:math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i, | |||
then :attr:`index` must be an n-dimensional tensor with size | |||
:math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and | |||
output will have the same size as :attr:`index`. | |||
:param inp: the source tensor | |||
:param axis: the axis along which to index | |||
:param index: the indices of elements to gather | |||
Examples: | |||
.. testcode:: | |||
import megengine.functional as F | |||
from megengine import tensor | |||
inp = tensor([ | |||
[1,2], [3,4], [5,6], | |||
]) | |||
index = tensor([[0,2], [1,0]]) | |||
oup = F.gather(inp, 0, index) | |||
print(oup.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[1 6] | |||
[3 2]] | |||
""" | |||
input_shape = inp.shape | |||
index_shape = index.shape | |||
input_dims = len(input_shape) | |||
index_dims = len(index_shape) | |||
if input_dims != index_dims: | |||
raise ValueError( | |||
"The index tensor must have same dimensions as input tensor, " | |||
"But the input dims:{}, the index dims:{}".format(input_dims, index_dims) | |||
) | |||
if axis < 0 or axis >= input_dims: | |||
raise ValueError( | |||
"Index axis {} is output of bounds, should in range [0 {})".format( | |||
axis, input_dims | |||
) | |||
) | |||
for i in range(input_dims): | |||
if i != axis and input_shape[i] != index_shape[i]: | |||
raise ValueError( | |||
"The input {} and index {} must have the same size apart from axis {}".format( | |||
input_shape, index_shape, axis | |||
) | |||
) | |||
idx = _get_idx(index, axis) | |||
return inp[idx].reshape(index.shape) # pylint: disable=no-member | |||
def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: | |||
r""" | |||
Writes all values from the tensor :attr:`source` into :attr:`inp` at the indices specified in the :attr:`index` tensor. | |||
For each value in :attr:`source`, its output index is specified by its index | |||
in :attr:`source` for ``axis != dimension`` and by the corresponding value in | |||
:attr:`index` for ``axis = dimension``. | |||
For a 3-D tensor, :attr:`inp` is updated as:: | |||
inp[index[i][j][k]][j][k] = source[i][j][k] # if axis == 0 | |||
inp[i][index[i][j][k]][k] = source[i][j][k] # if axis == 1 | |||
inp[i][j][index[i][j][k]] = source[i][j][k] # if axis == 2 | |||
:attr:`inp`, :attr:`index` and :attr:`source` should have same number of dimensions. | |||
It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)`` | |||
for all dimensions ``d``. | |||
Moreover, the values of :attr:`index` must be between ``0`` and ``inp.shape(axis) - 1`` inclusive. | |||
.. note:: | |||
Due to performance concerns, the result is undefined on the GPU device if
multiple positions in source are scattered to the same destination position
according to the index tensor.
For instance, if index[1][2] in the example below were changed from 1 to 0,
oup[0][2] could end up as either source[0][2] (value 0.2256) or
source[1][2] (value 0.5939).
:param inp: the inp tensor which to be scattered | |||
:param axis: the axis along which to index | |||
:param index: the indices of elements to scatter | |||
:param source: the source element(s) to scatter | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine.functional as F | |||
from megengine import tensor | |||
inp = tensor(np.zeros(shape=(3,5),dtype=np.float32)) | |||
source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]]) | |||
index = tensor([[0,2,0,2,1],[2,0,1,1,2]]) | |||
oup = F.scatter(inp, 0, index,source) | |||
print(oup.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[0.9935 0.0718 0.2256 0. 0. ] | |||
[0. 0. 0.5939 0.357 0.4396] | |||
[0.7723 0.9465 0. 0.8926 0.4576]] | |||
""" | |||
input_shape = inp.shape | |||
index_shape = index.shape | |||
source_shape = source.shape | |||
input_dims = len(input_shape) | |||
index_dims = len(index_shape) | |||
source_dims = len(source_shape) | |||
if input_dims != index_dims or input_dims != source_dims: | |||
raise ValueError("The input, source and index tensor must have same dimensions") | |||
if axis < 0 or axis >= input_dims: | |||
raise ValueError( | |||
"Index axis {} is output of bounds, should in range [0 {})".format( | |||
axis, input_dims | |||
) | |||
) | |||
for i in range(source_dims): | |||
if source_shape[i] > input_shape[i]: | |||
raise ValueError( | |||
"The each shape size for source {} must be less than or equal to input {} ".format( | |||
source_shape, input_shape | |||
) | |||
) | |||
for i in range(index_dims): | |||
if index_shape[i] != source_shape[i]: | |||
raise ValueError( | |||
"The each shape size for index {} must be equal to source {} ".format( | |||
index_shape, source_shape | |||
) | |||
) | |||
for i in range(index_dims): | |||
if i != axis and index_shape[i] > input_shape[i]: | |||
raise ValueError( | |||
"The index {} must be less than or equal to input {} size apart from axis {}".format( | |||
index_shape, input_shape, axis | |||
) | |||
) | |||
idx = _get_idx(index, axis) | |||
inp[idx] = source.flatten() | |||
return inp | |||
def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: | |||
r""" | |||
Select elements either from Tensor x or Tensor y, according to mask. | |||
.. math:: | |||
\textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i | |||
:param mask: a mask used for choosing x or y | |||
:param x: the first choice | |||
:param y: the second choice | |||
Examples: | |||
.. testcode:: | |||
import numpy as np
from megengine import tensor
import megengine.functional as F | |||
mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32)) | |||
x = tensor(np.array([[1, np.inf], [np.nan, 4]], | |||
dtype=np.float32)) | |||
y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32)) | |||
out = F.where(mask, x, y) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[1. 6.] | |||
[7. 4.]] | |||
""" | |||
raise NotImplementedError | |||
# v0, index0 = mgb.opr.cond_take( | |||
# x, mask, mode=P.CondTake.Mode.EQ, val=1 | |||
# ) | |||
# v1, index1 = mgb.opr.cond_take( | |||
# y, mask, mode=P.CondTake.Mode.EQ, val=0 | |||
# ) | |||
# out = x.flatten() | |||
# index = mgb.opr.concat(index0, index1, axis=0) | |||
# v = mgb.opr.concat(v0, v1, axis=0) | |||
# out = mgb.opr.set_advanced_indexing(out, v)[index] | |||
# out = out.reshape(x.shape) | |||
# return out | |||
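# Until the op is ported, an elementwise sketch with the same semantics for
# finite inputs (assuming a 0/1 mask broadcastable to x and y) is:
#   mask * x + (1 - mask) * y
# Note this differs from a true `where` when x or y contains inf/nan.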
def cond_take(mask: Tensor, x: Tensor) -> Tensor: | |||
r""" | |||
Take elements from data where the given condition on mask is satisfied. This operator has two outputs: the first is the elements taken, and the second is the corresponding indices; both are 1-dimensional. High-dimensional input is flattened first.
:param mask: condition param; must be the same shape with data | |||
:param x: input tensor from which to take elements | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) | |||
x = tensor(np.array([[1, np.inf], [np.nan, 4]], | |||
dtype=np.float32)) | |||
v, index = F.cond_take(mask, x) | |||
print(v.numpy(), index.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
Tensor([1. 4.]) Tensor([0 3], dtype=int32) | |||
""" | |||
if not isinstance(x, (TensorWrapperBase, TensorBase)): | |||
raise TypeError("input must be a tensor") | |||
if not isinstance(mask, (TensorWrapperBase, TensorBase)): | |||
raise TypeError("mask must be a tensor") | |||
if mask.dtype != np.bool_: | |||
raise ValueError("mask must be bool") | |||
if x.device != mask.device: | |||
raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device)) | |||
op = builtin.CondTake() | |||
v, index = apply(op, x, mask) | |||
return v, index | |||
def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor: | |||
r""" | |||
Swap shapes and strides according to given pattern | |||
:param inp: Input tensor | |||
:param pattern: a list of integers including 0, 1, ..., ``ndim``-1, and any number of ``'x'`` characters in dimensions where this tensor should be broadcast. For example:
* (``'x'``) -> make a 0d (scalar) into a 1d vector | |||
* (0, 1) -> identity for 2d vectors | |||
* (1, 0) -> inverts the first and second dimensions | |||
* (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN) | |||
* (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1) | |||
* (2, 0, 1) -> AxBxC to CxAxB | |||
* (0, ``'x'``, 1) -> AxB to Ax1xB | |||
* (1, ``'x'``, 0) -> AxB to Bx1xA | |||
* (1,) -> removes dimension 0; it must be a broadcastable dimension (1xA to A)
:return: The output tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32)) | |||
out = F.dimshuffle(x, (1, 0)) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[1 0] | |||
[1 0]] | |||
""" | |||
op = builtin.Dimshuffle(pattern) | |||
(inp,) = convert_inputs(inp) | |||
(result,) = apply(op, inp) | |||
return result | |||
def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: | |||
r""" | |||
Reshape a tensor to given target shape; total number of logical elements must | |||
remain unchanged | |||
:param inp: Input tensor | |||
:param target_shape: the target shape; the components are concatenated to form the
target shape, and it may contain at most one -1, marking the axis whose extent is inferred.
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.arange(12, dtype=np.int32)) | |||
out = F.reshape(x, (3, 2, 2)) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[[ 0 1] | |||
[ 2 3]] | |||
[[ 4 5] | |||
[ 6 7]] | |||
[[ 8 9] | |||
[10 11]]] | |||
""" | |||
if isinstance(target_shape, (TensorBase, TensorWrapperBase)): | |||
target_shape = target_shape.numpy() | |||
target_shape = tuple(map(int, target_shape)) | |||
unspec_axis = None | |||
for i, s in enumerate(target_shape): | |||
if s < 0: | |||
if s != -1: | |||
raise ValueError("expect shape[{}] >= -1, got {}".format(i, s)) | |||
if unspec_axis is not None: | |||
raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) | |||
unspec_axis = i | |||
# TODO: device should be None (cpu) | |||
(target_shape,) = Const(target_shape, dtype="int32", device=inp.device)(inp) | |||
if unspec_axis is None: | |||
op = builtin.Reshape() | |||
else: | |||
op = builtin.Reshape(unspec_axis=unspec_axis) | |||
(x,) = apply(op, inp, target_shape) | |||
return x | |||
transpose = dimshuffle | |||
AxisAddRemove = builtin.AxisAddRemove | |||
AxisDesc = AxisAddRemove.AxisDesc | |||
def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||
r""" | |||
Add a dimension before the given axis (or axes).
:param inp: Input tensor | |||
:param axis: Place of new axes | |||
:return: The output tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor([1, 2]) | |||
out = F.add_axis(x, 0) | |||
print(out.shape) | |||
Outputs: | |||
.. testoutput:: | |||
(1, 2) | |||
""" | |||
Param = AxisAddRemove.Param | |||
def get_axes(): | |||
try: | |||
return [int(axis)] | |||
except (TypeError, ValueError): | |||
pass | |||
return list(map(int, axis)) | |||
axis = get_axes() | |||
ndim = inp.ndim + len(axis) | |||
axis = sorted(i + ndim if i < 0 else i for i in axis) | |||
param = Param(*map(AxisDesc.make_add, axis)) | |||
op = AxisAddRemove(param=param) | |||
(result,) = apply(op, inp) | |||
return result | |||
expand_dims = add_axis | |||
def remove_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||
r""" | |||
Remove dimension(s) of shape 1.
:param inp: Input tensor | |||
:param axis: Place of axis to be removed | |||
:return: The output tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1)) | |||
out = F.remove_axis(x, 3) | |||
print(out.shape) | |||
Outputs: | |||
.. testoutput:: | |||
(1, 1, 2) | |||
""" | |||
Param = AxisAddRemove.Param | |||
def get_axes(): | |||
if axis is None: | |||
return [i for i, s in enumerate(inp.shape) if s == 1] | |||
try: | |||
return [int(axis)] | |||
except (TypeError, ValueError): | |||
pass | |||
return list(map(int, axis)) | |||
axis = get_axes() | |||
axis = sorted(i + inp.ndim if i < 0 else i for i in axis) | |||
axis = [a - i for i, a in enumerate(axis)] | |||
param = Param(*map(AxisDesc.make_remove, axis)) | |||
op = AxisAddRemove(param=param) | |||
(result,) = apply(op, inp) | |||
return result | |||
squeeze = remove_axis | |||
def linspace( | |||
start: Union[int, float, Tensor], | |||
stop: Union[int, float, Tensor], | |||
num: Union[int, Tensor], | |||
dtype="float32", | |||
device: Optional[CompNode] = None, | |||
) -> Tensor: | |||
r""" | |||
Return equally spaced numbers over a specified interval | |||
:param start: Starting value of the sequence, should be a scalar
:param stop: The last value of the sequence, should be a scalar
:param num: number of values to generate | |||
:param dtype: result data type | |||
:return: The generated tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine.functional as F | |||
a = F.linspace(3,10,5) | |||
print(a.numpy()) | |||
.. testoutput:: | |||
[ 3. 4.75 6.5 8.25 10. ] | |||
""" | |||
start = Tensor(start, device=device) | |||
stop = Tensor(stop, device=device) | |||
num = Tensor(num, device=device) | |||
device = device if device is None else device.to_c() | |||
op = builtin.Linspace(comp_node=device) | |||
(result,) = apply(op, start, stop, num) | |||
if np.dtype(dtype) == np.int32: | |||
return result.astype(dtype) | |||
return result | |||
def arange( | |||
start: Union[int, float, Tensor], | |||
end: Union[int, float, Tensor], | |||
step: Union[int, float, Tensor] = 1, | |||
dtype="float32", | |||
device: Optional[CompNode] = None, | |||
) -> Tensor: | |||
r""" | |||
Returns a Tensor with values from `start` to `end` with adjacent interval `step` | |||
:param start: starting value of the sequence, should be a scalar
:param end: ending value of the sequence, should be a scalar
:param step: the gap between each pair of adjacent values. Default 1 | |||
:param dtype: result data type | |||
:return: The generated tensor | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine.functional as F | |||
a = F.arange(1, 5, 1) | |||
print(a.numpy()) | |||
.. testoutput:: | |||
[1. 2. 3. 4.] | |||
""" | |||
if isinstance(start, Tensor): | |||
start = start.astype("float32") | |||
if isinstance(end, Tensor): | |||
end = end.astype("float32") | |||
if isinstance(step, Tensor): | |||
step = step.astype("float32") | |||
num = ceil(Tensor((end - start) / step, device=device)) | |||
stop = start + step * (num - 1) | |||
result = linspace(start, stop, num, device=device) | |||
if np.dtype(dtype) == np.int32: | |||
return result.astype(dtype) | |||
return result | |||
def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor: | |||
op = builtin.ParamPackSplit() | |||
op.offsets = offsets | |||
op.shapes = shapes | |||
return apply(op, inp) | |||
def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor: | |||
op = builtin.ParamPackConcat() | |||
op.offsets = offsets_val | |||
return apply(op, *inps, offsets)[0] |
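# param_pack_split and param_pack_concat are inverse operations: the former
# slices a packed 1-d buffer back into tensors of the given `shapes`, and the
# latter packs a list of tensors into one flat buffer (a description inferred
# from the ops' inputs; the exact `offsets` layout is defined by the ops).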
@@ -0,0 +1,37 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections.abc
import functools | |||
def get_ndtuple(value, *, n, allow_zero=True): | |||
r"""Converts possibly 1D tuple to nd tuple | |||
:type allow_zero: bool | |||
:param allow_zero: whether to allow zero tuple value""" | |||
if not isinstance(value, collections.abc.Iterable):
value = int(value) | |||
value = tuple([value for i in range(n)]) | |||
else: | |||
assert len(value) == n, "tuple len is not equal to n: {}".format(value) | |||
spatial_axis = map(int, value) | |||
value = tuple(spatial_axis) | |||
if allow_zero: | |||
minv = 0 | |||
else: | |||
minv = 1 | |||
assert min(value) >= minv, "invalid value: {}".format(value) | |||
return value | |||
_single = functools.partial(get_ndtuple, n=1, allow_zero=True) | |||
_pair = functools.partial(get_ndtuple, n=2, allow_zero=True) | |||
_pair_nonzero = functools.partial(get_ndtuple, n=2, allow_zero=False) | |||
_triple = functools.partial(get_ndtuple, n=3, allow_zero=True) | |||
_quadruple = functools.partial(get_ndtuple, n=4, allow_zero=True) |
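# e.g. _pair(3) -> (3, 3) and _pair((2, 1)) -> (2, 1); the `_pair_nonzero`
# variant additionally requires every entry to be at least 1.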
@@ -0,0 +1,80 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
from typing import Iterable, Union | |||
import numpy as np | |||
from ..core.ops.builtin import Copy | |||
from ..core.tensor import Tensor | |||
from ..core.tensor.core import apply | |||
from .math import topk as _topk | |||
from .tensor import dimshuffle as _dimshuffle | |||
def accuracy( | |||
logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1 | |||
) -> Union[Tensor, Iterable[Tensor]]: | |||
r""" | |||
Calculate the classification accuracy given predicted logits and ground-truth labels. | |||
:param logits: Model predictions of shape [batch_size, num_classes], | |||
representing the probability (likelihood) of each class.
:param target: Ground-truth labels, 1d tensor of int32 | |||
:param topk: Specifies the topk values, could be an int or tuple of ints. Default: 1 | |||
:return: Tensor(s) of classification accuracy between 0.0 and 1.0 | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
logits = tensor(np.arange(80, dtype=np.int32).reshape(8,10)) | |||
target = tensor(np.arange(8, dtype=np.int32)) | |||
top1, top5 = F.accuracy(logits, target, (1, 5)) | |||
print(top1.numpy(), top5.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[0.] [0.375] | |||
""" | |||
if isinstance(topk, int): | |||
topk = (topk,) | |||
_, pred = _topk(logits, k=max(topk), descending=True) | |||
accs = [] | |||
for k in topk: | |||
correct = pred[:, :k].detach() == _dimshuffle(target, (0, "x")).broadcast( | |||
target.shape[0], k | |||
) | |||
accs.append(correct.astype(np.float32).sum() / target.shape[0]) | |||
if len(topk) == 1: # type: ignore[arg-type] | |||
accs = accs[0] | |||
return accs | |||
def zero_grad(inp: Tensor) -> Tensor: | |||
r""" | |||
Returns a tensor which is treated as constant during backward gradient calculation,
i.e. its gradient is zero. | |||
:param inp: Input tensor. | |||
See implementation of :func:`~.softmax` for example. | |||
""" | |||
print("zero_grad is obsoleted, please use detach instead") | |||
raise NotImplementedError | |||
def copy(inp, cn): | |||
return apply(Copy(comp_node=cn), inp)[0] |
@@ -0,0 +1,16 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .hub import ( | |||
help, | |||
import_module, | |||
list, | |||
load, | |||
load_serialized_obj_from_url, | |||
pretrained, | |||
) |
@@ -0,0 +1,17 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
DEFAULT_BRANCH_NAME = "master" | |||
HUBCONF = "hubconf.py" | |||
HUBDEPENDENCY = "dependencies" | |||
DEFAULT_GIT_HOST = "github.com" | |||
ENV_MGE_HOME = "MGE_HOME" | |||
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" | |||
DEFAULT_CACHE_DIR = "~/.cache" | |||
DEFAULT_PROTOCOL = "HTTPS" | |||
HTTP_READ_TIMEOUT = 120 |
@@ -0,0 +1,30 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
class FetcherError(Exception): | |||
"""Base class for fetch related error.""" | |||
class InvalidRepo(FetcherError): | |||
"""The repo provided was somehow invalid.""" | |||
class InvalidGitHost(FetcherError): | |||
"""The git host provided was somehow invalid.""" | |||
class GitPullError(FetcherError):
"""A git pull error occurred."""
class GitCheckoutError(FetcherError):
"""A git checkout error occurred."""
class InvalidProtocol(FetcherError):
"""The protocol provided was somehow invalid."""
@@ -0,0 +1,300 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import hashlib | |||
import os | |||
import re | |||
import shutil | |||
import subprocess | |||
from tempfile import NamedTemporaryFile | |||
from typing import Tuple | |||
from zipfile import ZipFile | |||
import requests | |||
from tqdm import tqdm | |||
from megengine.utils.http_download import ( | |||
CHUNK_SIZE, | |||
HTTP_CONNECTION_TIMEOUT, | |||
HTTPDownloadError, | |||
) | |||
from ..distributed import is_distributed, synchronized | |||
from ..logger import get_logger | |||
from .const import DEFAULT_BRANCH_NAME, HTTP_READ_TIMEOUT | |||
from .exceptions import GitCheckoutError, GitPullError, InvalidGitHost, InvalidRepo | |||
from .tools import cd | |||
logger = get_logger(__name__) | |||
HTTP_TIMEOUT = (HTTP_CONNECTION_TIMEOUT, HTTP_READ_TIMEOUT) | |||
pattern = re.compile( | |||
r"^(?:[a-z0-9]" # First character of the domain | |||
r"(?:[a-z0-9-_]{0,61}[a-z0-9])?\.)" # Sub domain + hostname | |||
r"+[a-z0-9][a-z0-9-_]{0,61}" # First 61 characters of the gTLD | |||
r"[a-z]$" # Last character of the gTLD | |||
) | |||
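# e.g. "github.com" and "gitlab.example.com" match, while a bare "localhost"
# does not: the pattern requires at least one dot-separated label before the
# top-level domain.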
class RepoFetcherBase: | |||
@classmethod | |||
def fetch( | |||
cls, | |||
git_host: str, | |||
repo_info: str, | |||
use_cache: bool = False, | |||
commit: str = None, | |||
silent: bool = True, | |||
) -> str: | |||
raise NotImplementedError() | |||
@classmethod | |||
def _parse_repo_info(cls, repo_info: str) -> Tuple[str, str, str]: | |||
try: | |||
branch_info = DEFAULT_BRANCH_NAME | |||
if ":" in repo_info: | |||
prefix_info, branch_info = repo_info.split(":") | |||
else: | |||
prefix_info = repo_info | |||
repo_owner, repo_name = prefix_info.split("/") | |||
return repo_owner, repo_name, branch_info | |||
except ValueError: | |||
raise InvalidRepo("repo_info: '{}' is invalid.".format(repo_info)) | |||
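# e.g. _parse_repo_info("brain_sdk/MegBrain:hub") returns
# ("brain_sdk", "MegBrain", "hub"); without the ":<branch>" suffix the
# default branch "master" (DEFAULT_BRANCH_NAME) is used.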
@classmethod | |||
def _check_git_host(cls, git_host): | |||
return cls._is_valid_domain(git_host) or cls._is_valid_host(git_host) | |||
@classmethod | |||
def _is_valid_domain(cls, s): | |||
try: | |||
return pattern.match(s.encode("idna").decode("ascii")) | |||
except UnicodeError: | |||
return False | |||
@classmethod | |||
def _is_valid_host(cls, s): | |||
nums = s.split(".") | |||
if len(nums) != 4 or any(not _.isdigit() for _ in nums): | |||
return False | |||
return all(0 <= int(_) < 256 for _ in nums) | |||
@classmethod | |||
def _gen_repo_dir(cls, repo_dir: str) -> str: | |||
return hashlib.sha1(repo_dir.encode()).hexdigest()[:16] | |||
class GitSSHFetcher(RepoFetcherBase): | |||
@classmethod | |||
@synchronized | |||
def fetch( | |||
cls, | |||
git_host: str, | |||
repo_info: str, | |||
use_cache: bool = False, | |||
commit: str = None, | |||
silent: bool = True, | |||
) -> str: | |||
""" | |||
        Fetches git repo via SSH protocol.
:param git_host: | |||
host address of git repo. | |||
example: github.com | |||
:param repo_info: | |||
a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||
tag/branch. The default branch is ``master`` if not specified. | |||
example: ``"brain_sdk/MegBrain[:hub]"`` | |||
        :param use_cache:
            whether to use locally cached code or completely re-fetch
        :param commit:
            commit id on github or gitlab
        :param silent:
            whether to capture the stdout and stderr of the subprocess via PIPE,
            instead of displaying them on the screen
:return: | |||
directory where the repo code is stored | |||
""" | |||
if not cls._check_git_host(git_host): | |||
raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host)) | |||
repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info) | |||
normalized_branch_info = branch_info.replace("/", "_") | |||
repo_dir_raw = "{}_{}_{}".format( | |||
repo_owner, repo_name, normalized_branch_info | |||
) + ("_{}".format(commit) if commit else "") | |||
repo_dir = cls._gen_repo_dir(repo_dir_raw) | |||
git_url = "git@{}:{}/{}.git".format(git_host, repo_owner, repo_name) | |||
if use_cache and os.path.exists(repo_dir): # use cache | |||
logger.debug("Cache Found in %s", repo_dir) | |||
return repo_dir | |||
if is_distributed(): | |||
logger.warning( | |||
"When using `hub.load` or `hub.list` to fetch git repositories\n" | |||
" in DISTRIBUTED mode for the first time, processes are synchronized to\n" | |||
" ensure that target repository is ready to use for each process.\n" | |||
" Users are expected to see this warning no more than ONCE, otherwise\n" | |||
" (very little chance) you may need to remove corrupt cache\n" | |||
" `%s` and fetch again.", | |||
repo_dir, | |||
) | |||
shutil.rmtree(repo_dir, ignore_errors=True) # ignore and clear cache | |||
logger.debug( | |||
"Git Clone from Repo:%s Branch: %s to %s", | |||
git_url, | |||
normalized_branch_info, | |||
repo_dir, | |||
) | |||
kwargs = ( | |||
{"stderr": subprocess.PIPE, "stdout": subprocess.PIPE} if silent else {} | |||
) | |||
if commit is None: | |||
# shallow clone repo by branch/tag | |||
p = subprocess.Popen( | |||
[ | |||
"git", | |||
"clone", | |||
"-b", | |||
normalized_branch_info, | |||
git_url, | |||
repo_dir, | |||
"--depth=1", | |||
], | |||
**kwargs, | |||
) | |||
cls._check_clone_pipe(p) | |||
else: | |||
# clone repo and checkout to commit_id | |||
p = subprocess.Popen(["git", "clone", git_url, repo_dir], **kwargs) | |||
cls._check_clone_pipe(p) | |||
with cd(repo_dir): | |||
logger.debug("git checkout to %s", commit) | |||
p = subprocess.Popen(["git", "checkout", commit], **kwargs) | |||
_, err = p.communicate() | |||
if p.returncode: | |||
shutil.rmtree(repo_dir, ignore_errors=True) | |||
raise GitCheckoutError( | |||
"Git checkout error, please check the commit id.\n" | |||
+ err.decode() | |||
) | |||
with cd(repo_dir): | |||
shutil.rmtree(".git") | |||
return repo_dir | |||
@classmethod | |||
def _check_clone_pipe(cls, p): | |||
_, err = p.communicate() | |||
if p.returncode: | |||
raise GitPullError( | |||
"Repo pull error, please check repo info.\n" + err.decode() | |||
) | |||
class GitHTTPSFetcher(RepoFetcherBase): | |||
@classmethod | |||
@synchronized | |||
def fetch( | |||
cls, | |||
git_host: str, | |||
repo_info: str, | |||
use_cache: bool = False, | |||
commit: str = None, | |||
silent: bool = True, | |||
) -> str: | |||
""" | |||
        Fetches git repo via HTTPS protocol.
:param git_host: | |||
host address of git repo | |||
example: github.com | |||
:param repo_info: | |||
a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||
tag/branch. The default branch is ``master`` if not specified. | |||
example: ``"brain_sdk/MegBrain[:hub]"`` | |||
:param use_cache: | |||
whether to use locally cached code or completely re-fetch | |||
:param commit: | |||
commit id on github or gitlab | |||
:param silent: | |||
whether to accept the stdout and stderr of the subprocess with PIPE, instead of | |||
displaying on the screen | |||
:return: | |||
directory where the repo code is stored | |||
""" | |||
if not cls._check_git_host(git_host): | |||
raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host)) | |||
repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info) | |||
normalized_branch_info = branch_info.replace("/", "_") | |||
repo_dir_raw = "{}_{}_{}".format( | |||
repo_owner, repo_name, normalized_branch_info | |||
) + ("_{}".format(commit) if commit else "") | |||
repo_dir = cls._gen_repo_dir(repo_dir_raw) | |||
archive_url = cls._git_archive_link( | |||
git_host, repo_owner, repo_name, branch_info, commit | |||
) | |||
if use_cache and os.path.exists(repo_dir): # use cache | |||
logger.debug("Cache Found in %s", repo_dir) | |||
return repo_dir | |||
if is_distributed(): | |||
            logger.warning(
                "When using `hub.load` or `hub.list` to fetch git repositories "
                "in DISTRIBUTED mode for the first time, processes are synchronized to "
                "ensure that target repository is ready to use for each process.\n"
                "Users are expected to see this warning no more than ONCE, otherwise "
                "(very little chance) you may need to remove corrupt hub cache `%s` and fetch again.",
                repo_dir,
            )
shutil.rmtree(repo_dir, ignore_errors=True) # ignore and clear cache | |||
logger.debug("Downloading from %s to %s", archive_url, repo_dir) | |||
cls._download_zip_and_extract(archive_url, repo_dir) | |||
return repo_dir | |||
@classmethod | |||
def _download_zip_and_extract(cls, url, target_dir): | |||
resp = requests.get(url, timeout=HTTP_TIMEOUT, stream=True) | |||
if resp.status_code != 200: | |||
raise HTTPDownloadError( | |||
"An error occured when downloading from {}".format(url) | |||
) | |||
total_size = int(resp.headers.get("Content-Length", 0)) | |||
_bar = tqdm(total=total_size, unit="iB", unit_scale=True) | |||
with NamedTemporaryFile("w+b") as f: | |||
for chunk in resp.iter_content(CHUNK_SIZE): | |||
if not chunk: | |||
break | |||
_bar.update(len(chunk)) | |||
f.write(chunk) | |||
_bar.close() | |||
f.seek(0) | |||
with ZipFile(f) as temp_zip_f: | |||
zip_dir_name = temp_zip_f.namelist()[0].split("/")[0] | |||
temp_zip_f.extractall(".") | |||
shutil.move(zip_dir_name, target_dir) | |||
@classmethod | |||
def _git_archive_link(cls, git_host, repo_owner, repo_name, branch_info, commit): | |||
archive_link = "https://{}/{}/{}/archive/{}.zip".format( | |||
git_host, repo_owner, repo_name, commit or branch_info | |||
) | |||
return archive_link |
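        # For example (illustrative values only):
        #   _git_archive_link("github.com", "brain_sdk", "MegBrain", "hub", None)
        # returns "https://github.com/brain_sdk/MegBrain/archive/hub.zip".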
@@ -0,0 +1,333 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import functools | |||
import hashlib | |||
import os | |||
import sys | |||
import types | |||
from typing import Any, List | |||
from urllib.parse import urlparse | |||
from megengine.utils.http_download import download_from_url | |||
from ..distributed import is_distributed | |||
from ..logger import get_logger | |||
from ..serialization import load as _mge_load_serialized | |||
from .const import ( | |||
DEFAULT_CACHE_DIR, | |||
DEFAULT_GIT_HOST, | |||
DEFAULT_PROTOCOL, | |||
ENV_MGE_HOME, | |||
ENV_XDG_CACHE_HOME, | |||
HTTP_READ_TIMEOUT, | |||
HUBCONF, | |||
HUBDEPENDENCY, | |||
) | |||
from .exceptions import InvalidProtocol | |||
from .fetcher import GitHTTPSFetcher, GitSSHFetcher | |||
from .tools import cd, check_module_exists, load_module | |||
logger = get_logger(__name__) | |||
PROTOCOLS = { | |||
"HTTPS": GitHTTPSFetcher, | |||
"SSH": GitSSHFetcher, | |||
} | |||
def _get_megengine_home() -> str: | |||
"""MGE_HOME setting complies with the XDG Base Directory Specification | |||
""" | |||
megengine_home = os.path.expanduser( | |||
os.getenv( | |||
ENV_MGE_HOME, | |||
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "megengine"), | |||
) | |||
) | |||
return megengine_home | |||
def _get_repo( | |||
git_host: str, | |||
repo_info: str, | |||
use_cache: bool = False, | |||
commit: str = None, | |||
protocol: str = DEFAULT_PROTOCOL, | |||
) -> str: | |||
if protocol not in PROTOCOLS: | |||
raise InvalidProtocol( | |||
"Invalid protocol, the value should be one of {}.".format( | |||
", ".join(PROTOCOLS.keys()) | |||
) | |||
) | |||
cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub")) | |||
with cd(cache_dir): | |||
fetcher = PROTOCOLS[protocol] | |||
repo_dir = fetcher.fetch(git_host, repo_info, use_cache, commit) | |||
return os.path.join(cache_dir, repo_dir) | |||
def _check_dependencies(module: types.ModuleType) -> None: | |||
if not hasattr(module, HUBDEPENDENCY): | |||
return | |||
dependencies = getattr(module, HUBDEPENDENCY) | |||
if not dependencies: | |||
return | |||
missing_deps = [m for m in dependencies if not check_module_exists(m)] | |||
if len(missing_deps): | |||
raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps))) | |||
def _init_hub( | |||
repo_info: str, | |||
git_host: str, | |||
use_cache: bool = True, | |||
commit: str = None, | |||
protocol: str = DEFAULT_PROTOCOL, | |||
): | |||
"""Imports hubmodule like python import | |||
:param repo_info: | |||
a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||
tag/branch. The default branch is ``master`` if not specified. | |||
Example: ``"brain_sdk/MegBrain[:hub]"`` | |||
:param git_host: | |||
host address of git repo | |||
Example: github.com | |||
:param use_cache: | |||
whether to use locally cached code or completely re-fetch | |||
:param commit: | |||
commit id on github or gitlab | |||
:param protocol: | |||
which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. | |||
The value should be one of HTTPS, SSH. | |||
:return: | |||
hubconf.py as a python module | |||
""" | |||
cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub")) | |||
os.makedirs(cache_dir, exist_ok=True) | |||
absolute_repo_dir = _get_repo( | |||
git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol | |||
) | |||
sys.path.insert(0, absolute_repo_dir) | |||
hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF)) | |||
sys.path.remove(absolute_repo_dir) | |||
return hubmodule | |||
@functools.wraps(_init_hub) | |||
def import_module(*args, **kwargs): | |||
return _init_hub(*args, **kwargs) | |||
def list( | |||
repo_info: str, | |||
git_host: str = DEFAULT_GIT_HOST, | |||
use_cache: bool = True, | |||
commit: str = None, | |||
protocol: str = DEFAULT_PROTOCOL, | |||
) -> List[str]: | |||
"""Lists all entrypoints available in repo hubconf | |||
:param repo_info: | |||
a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||
tag/branch. The default branch is ``master`` if not specified. | |||
Example: ``"brain_sdk/MegBrain[:hub]"`` | |||
:param git_host: | |||
host address of git repo | |||
Example: github.com | |||
:param use_cache: | |||
whether to use locally cached code or completely re-fetch | |||
:param commit: | |||
commit id on github or gitlab | |||
:param protocol: | |||
which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. | |||
The value should be one of HTTPS, SSH. | |||
:return: | |||
all entrypoint names of the model | |||
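
    Example (a sketch; the repo name is illustrative):

    .. code-block::

        from megengine import hub
        entrypoints = hub.list("brain_sdk/MegBrain:hub")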
""" | |||
hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol) | |||
return [ | |||
_ | |||
for _ in dir(hubmodule) | |||
if not _.startswith("__") and callable(getattr(hubmodule, _)) | |||
] | |||
def load( | |||
repo_info: str, | |||
entry: str, | |||
*args, | |||
git_host: str = DEFAULT_GIT_HOST, | |||
use_cache: bool = True, | |||
commit: str = None, | |||
protocol: str = DEFAULT_PROTOCOL, | |||
**kwargs | |||
) -> Any: | |||
"""Loads model from github or gitlab repo, with pretrained weights. | |||
:param repo_info: | |||
a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||
tag/branch. The default branch is ``master`` if not specified. | |||
Example: ``"brain_sdk/MegBrain[:hub]"`` | |||
:param entry: | |||
an entrypoint defined in hubconf | |||
:param git_host: | |||
host address of git repo | |||
Example: github.com | |||
:param use_cache: | |||
whether to use locally cached code or completely re-fetch | |||
:param commit: | |||
commit id on github or gitlab | |||
:param protocol: | |||
which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. | |||
The value should be one of HTTPS, SSH. | |||
:return: | |||
a single model with corresponding pretrained weights. | |||
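
    Example (a sketch; the repo and entry names are illustrative):

    .. code-block::

        from megengine import hub
        model = hub.load("brain_sdk/MegBrain:hub", "resnet18", pretrained=True)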
""" | |||
hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol) | |||
if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)): | |||
raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry)) | |||
_check_dependencies(hubmodule) | |||
module = getattr(hubmodule, entry)(*args, **kwargs) | |||
return module | |||
def help( | |||
repo_info: str, | |||
entry: str, | |||
git_host: str = DEFAULT_GIT_HOST, | |||
use_cache: bool = True, | |||
commit: str = None, | |||
protocol: str = DEFAULT_PROTOCOL, | |||
) -> str: | |||
"""This function returns docstring of entrypoint ``entry`` by following steps: | |||
1. Pull the repo code specified by git and repo_info | |||
2. Load the entry defined in repo's hubconf.py | |||
3. Return docstring of function entry | |||
:param repo_info: | |||
a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional | |||
tag/branch. The default branch is ``master`` if not specified. | |||
Example: ``"brain_sdk/MegBrain[:hub]"`` | |||
:param entry: | |||
an entrypoint defined in hubconf.py | |||
:param git_host: | |||
host address of git repo | |||
Example: github.com | |||
:param use_cache: | |||
whether to use locally cached code or completely re-fetch | |||
:param commit: | |||
commit id on github or gitlab | |||
:param protocol: | |||
which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. | |||
The value should be one of HTTPS, SSH. | |||
:return: | |||
docstring of entrypoint ``entry`` | |||
""" | |||
hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol) | |||
if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)): | |||
raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry)) | |||
doc = getattr(hubmodule, entry).__doc__ | |||
return doc | |||
def load_serialized_obj_from_url(url: str, model_dir=None) -> Any: | |||
"""Loads MegEngine serialized object from the given URL. | |||
If the object is already present in ``model_dir``, it's deserialized and | |||
returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``. | |||
:param url: url to serialized object | |||
:param model_dir: dir to cache target serialized file | |||
:return: loaded object | |||
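
    Example (a sketch; the URL is illustrative):

    .. code-block::

        state_dict = load_serialized_obj_from_url("https://example.com/weights.pkl")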
""" | |||
if model_dir is None: | |||
model_dir = os.path.join(_get_megengine_home(), "serialized") | |||
os.makedirs(model_dir, exist_ok=True) | |||
parts = urlparse(url) | |||
filename = os.path.basename(parts.path) | |||
# use hash as prefix to avoid filename conflict from different urls | |||
sha256 = hashlib.sha256() | |||
sha256.update(url.encode()) | |||
digest = sha256.hexdigest()[:6] | |||
filename = digest + "_" + filename | |||
cached_file = os.path.join(model_dir, filename) | |||
logger.info( | |||
"load_serialized_obj_from_url: download to or using cached %s", cached_file | |||
) | |||
if not os.path.exists(cached_file): | |||
if is_distributed(): | |||
logger.warning( | |||
"Downloading serialized object in DISTRIBUTED mode\n" | |||
" File may be downloaded multiple times. We recommend\n" | |||
" users to download in single process first." | |||
) | |||
download_from_url(url, cached_file, HTTP_READ_TIMEOUT) | |||
state_dict = _mge_load_serialized(cached_file) | |||
return state_dict | |||
class pretrained: | |||
r""" | |||
Decorator which helps to download pretrained weights from the given url. | |||
For example, we can decorate a resnet18 function as follows | |||
.. code-block:: | |||
@hub.pretrained("https://url/to/pretrained_resnet18.pkl") | |||
def resnet18(**kwargs): | |||
return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) | |||
    When the decorated function is called with ``pretrained=True``, MegEngine automatically
    downloads the pretrained weights and fills the returned model with them.
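
    .. code-block::

        # a usage sketch: weights are fetched from the decorated url and loaded
        model = resnet18(pretrained=True)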
""" | |||
def __init__(self, url): | |||
self.url = url | |||
def __call__(self, func): | |||
@functools.wraps(func) | |||
def pretrained_model_func( | |||
pretrained=False, **kwargs | |||
): # pylint: disable=redefined-outer-name | |||
model = func(**kwargs) | |||
if pretrained: | |||
weights = load_serialized_obj_from_url(self.url) | |||
model.load_state_dict(weights) | |||
return model | |||
return pretrained_model_func | |||
__all__ = [ | |||
"list", | |||
"load", | |||
"help", | |||
"load_serialized_obj_from_url", | |||
"pretrained", | |||
"import_module", | |||
] |
@@ -0,0 +1,48 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import importlib.util | |||
import os | |||
import types | |||
from contextlib import contextmanager | |||
from typing import Iterator | |||
def load_module(name: str, path: str) -> types.ModuleType: | |||
""" | |||
    Loads the module specified by name and path.
:param name: module name | |||
:param path: module path | |||
""" | |||
spec = importlib.util.spec_from_file_location(name, path) | |||
module = importlib.util.module_from_spec(spec) | |||
spec.loader.exec_module(module) | |||
return module | |||
def check_module_exists(module: str) -> bool: | |||
"""Checks whether python module exists or not | |||
:param module: name of module | |||
""" | |||
return importlib.util.find_spec(module) is not None | |||
@contextmanager | |||
def cd(target: str) -> Iterator[None]: | |||
"""Changes current directory to target | |||
:param target: target directory | |||
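
    Example (a sketch):

    .. code-block::

        with cd("~/repos"):
            ...  # cwd is changed here and restored on exit, even on error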
""" | |||
prev = os.getcwd() | |||
os.chdir(os.path.expanduser(target)) | |||
try: | |||
yield | |||
finally: | |||
os.chdir(prev) |
@@ -0,0 +1,237 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import contextlib | |||
import logging | |||
import os | |||
import sys | |||
_all_loggers = [] | |||
_default_level_name = os.getenv("MEGENGINE_LOGGING_LEVEL", "ERROR") | |||
_default_level = logging.getLevelName(_default_level_name.upper()) | |||
def set_log_file(fout, mode="a"): | |||
r"""Sets log output file. | |||
:type fout: str or file-like | |||
:param fout: file-like object that supports write and flush, or string for | |||
the filename | |||
:type mode: str | |||
:param mode: specify the mode to open log file if *fout* is a string | |||
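
    Example (a sketch; the path is illustrative):

    .. code-block::

        set_log_file("/tmp/megengine.log")  # subsequent log records are also written here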
""" | |||
if isinstance(fout, str): | |||
fout = open(fout, mode) | |||
MegEngineLogFormatter.log_fout = fout | |||
class MegEngineLogFormatter(logging.Formatter): | |||
log_fout = None | |||
date_full = "[%(asctime)s %(lineno)d@%(filename)s:%(name)s] " | |||
date = "%(asctime)s " | |||
msg = "%(message)s" | |||
max_lines = 256 | |||
def _color_exc(self, msg): | |||
r"""Sets the color of message as the execution type. | |||
""" | |||
return "\x1b[34m{}\x1b[0m".format(msg) | |||
def _color_dbg(self, msg): | |||
r"""Sets the color of message as the debugging type. | |||
""" | |||
return "\x1b[36m{}\x1b[0m".format(msg) | |||
def _color_warn(self, msg): | |||
r"""Sets the color of message as the warning type. | |||
""" | |||
return "\x1b[1;31m{}\x1b[0m".format(msg) | |||
def _color_err(self, msg): | |||
r"""Sets the color of message as the error type. | |||
""" | |||
return "\x1b[1;4;31m{}\x1b[0m".format(msg) | |||
def _color_omitted(self, msg): | |||
r"""Sets the color of message as the omitted type. | |||
""" | |||
return "\x1b[35m{}\x1b[0m".format(msg) | |||
def _color_normal(self, msg): | |||
r"""Sets the color of message as the normal type. | |||
""" | |||
return msg | |||
def _color_date(self, msg): | |||
r"""Sets the color of message the same as date. | |||
""" | |||
return "\x1b[32m{}\x1b[0m".format(msg) | |||
def format(self, record): | |||
if record.levelno == logging.DEBUG: | |||
mcl, mtxt = self._color_dbg, "DBG" | |||
elif record.levelno == logging.WARNING: | |||
mcl, mtxt = self._color_warn, "WRN" | |||
elif record.levelno == logging.ERROR: | |||
mcl, mtxt = self._color_err, "ERR" | |||
else: | |||
mcl, mtxt = self._color_normal, "" | |||
if mtxt: | |||
mtxt += " " | |||
if self.log_fout: | |||
self.__set_fmt(self.date_full + mtxt + self.msg) | |||
formatted = super(MegEngineLogFormatter, self).format(record) | |||
nr_line = formatted.count("\n") + 1 | |||
if nr_line >= self.max_lines: | |||
head, body = formatted.split("\n", 1) | |||
formatted = "\n".join( | |||
[ | |||
head, | |||
"BEGIN_LONG_LOG_{}_LINES{{".format(nr_line - 1), | |||
body, | |||
"}}END_LONG_LOG_{}_LINES".format(nr_line - 1), | |||
] | |||
) | |||
self.log_fout.write(formatted) | |||
self.log_fout.write("\n") | |||
self.log_fout.flush() | |||
self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg)) | |||
formatted = super(MegEngineLogFormatter, self).format(record) | |||
if record.exc_text or record.exc_info: | |||
# handle exception format | |||
b = formatted.find("Traceback ") | |||
if b != -1: | |||
s = formatted[b:] | |||
s = self._color_exc(" " + s.replace("\n", "\n ")) | |||
formatted = formatted[:b] + s | |||
nr_line = formatted.count("\n") + 1 | |||
if nr_line >= self.max_lines: | |||
lines = formatted.split("\n") | |||
remain = self.max_lines // 2 | |||
removed = len(lines) - remain * 2 | |||
if removed > 0: | |||
mid_msg = self._color_omitted( | |||
"[{} log lines omitted (would be written to output file " | |||
"if set_log_file() has been called;\n" | |||
" the threshold can be set at " | |||
"MegEngineLogFormatter.max_lines)]".format(removed) | |||
) | |||
formatted = "\n".join(lines[:remain] + [mid_msg] + lines[-remain:]) | |||
return formatted | |||
if sys.version_info.major < 3: | |||
def __set_fmt(self, fmt): | |||
self._fmt = fmt | |||
else: | |||
def __set_fmt(self, fmt): | |||
self._style._fmt = fmt | |||
def get_logger(name=None, formatter=MegEngineLogFormatter): | |||
r"""Gets megengine logger with given name. | |||
""" | |||
logger = logging.getLogger(name) | |||
if getattr(logger, "_init_done__", None): | |||
return logger | |||
logger._init_done__ = True | |||
logger.propagate = False | |||
logger.setLevel(_default_level) | |||
handler = logging.StreamHandler() | |||
handler.setFormatter(formatter(datefmt="%d %H:%M:%S")) | |||
handler.setLevel(0) | |||
del logger.handlers[:] | |||
logger.addHandler(handler) | |||
_all_loggers.append(logger) | |||
return logger | |||
def set_log_level(level, update_existing=True): | |||
"""Sets default logging level. | |||
:type level: int e.g. logging.INFO | |||
    :param level: logging level given by python :mod:`logging` module
:param update_existing: whether to update existing loggers | |||
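
    Example (a sketch):

    .. code-block::

        import logging
        set_log_level(logging.INFO)  # new and existing megengine loggers emit INFO and above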
""" | |||
global _default_level # pylint: disable=global-statement | |||
_default_level = level | |||
if update_existing: | |||
for i in _all_loggers: | |||
i.setLevel(level) | |||
_logger = get_logger(__name__) | |||
try: | |||
if sys.version_info.major < 3: | |||
raise ImportError() | |||
from .core._imperative_rt.utils import Logger as _imperative_rt_logger | |||
class MegBrainLogFormatter(MegEngineLogFormatter): | |||
date = "%(asctime)s[mgb] " | |||
def _color_date(self, msg): | |||
return "\x1b[33m{}\x1b[0m".format(msg) | |||
_megbrain_logger = get_logger("megbrain", MegBrainLogFormatter) | |||
_imperative_rt_logger.set_log_handler(_megbrain_logger) | |||
if _default_level == logging.getLevelName("ERROR"): | |||
_imperative_rt_logger.set_log_level(_imperative_rt_logger.LogLevel.Error) | |||
elif _default_level == logging.getLevelName("INFO"): | |||
_imperative_rt_logger.set_log_level(_imperative_rt_logger.LogLevel.Info) | |||
else: | |||
_imperative_rt_logger.set_log_level(_imperative_rt_logger.LogLevel.Debug) | |||
def set_mgb_log_level(level): | |||
r"""Sets megbrain log level | |||
:type level: int e.g. logging.INFO | |||
:param level: new log level | |||
:return: original log level | |||
""" | |||
logger = _megbrain_logger | |||
rst = logger.getEffectiveLevel() | |||
logger.setLevel(level) | |||
return rst | |||
except ImportError as exc: | |||
def set_mgb_log_level(level): | |||
raise NotImplementedError("imperative_rt has not been imported") | |||
@contextlib.contextmanager | |||
def replace_mgb_log_level(level): | |||
r"""Replaces megbrain log level in a block and restore after exiting. | |||
:type level: int e.g. logging.INFO | |||
:param level: new log level | |||
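
    Example (a sketch):

    .. code-block::

        import logging
        with replace_mgb_log_level(logging.DEBUG):
            ...  # megbrain logs at DEBUG only inside this block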
""" | |||
old = set_mgb_log_level(level) | |||
try: | |||
yield | |||
finally: | |||
set_mgb_log_level(old) | |||
def enable_debug_log(): | |||
r"""Sets logging level to debug for all components. | |||
""" | |||
set_log_level(logging.DEBUG) | |||
set_mgb_log_level(logging.DEBUG) |
@@ -0,0 +1,24 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | |||
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | |||
from .concat import Concat | |||
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | |||
from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
from .dropout import Dropout | |||
from .elemwise import Elemwise | |||
from .embedding import Embedding | |||
from .identity import Identity | |||
from .linear import Linear | |||
from .module import Module | |||
from .parampack import ParamPack | |||
from .pooling import AvgPool2d, MaxPool2d | |||
from .quant_dequant import DequantStub, QuantStub | |||
from .sequential import Sequential |
@@ -0,0 +1,231 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
from ..functional import leaky_relu, prelu, relu, sigmoid, softmax | |||
from ..tensor_nn import Parameter | |||
from .module import Module | |||
class Softmax(Module): | |||
r""" | |||
Applies a softmax function. Softmax is defined as: | |||
.. math:: | |||
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
    It is applied to an n-dimensional input Tensor, rescaling it so that the elements of the
    n-dimensional output Tensor lie in the range `[0, 1]` and sum to 1.
    :param axis: An axis along which softmax will be applied. By default,
        softmax will be applied along the highest ranked axis.
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as M | |||
data = mge.tensor(np.array([-2,-1,0,1,2]).astype(np.float32)) | |||
softmax = M.Softmax() | |||
output = softmax(data) | |||
with np.printoptions(precision=6): | |||
print(output.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[0.011656 0.031685 0.086129 0.234122 0.636409] | |||
""" | |||
def __init__(self, axis=None): | |||
super().__init__() | |||
self.axis = axis | |||
def forward(self, inputs): | |||
return softmax(inputs, self.axis) | |||
class Sigmoid(Module): | |||
r""" | |||
Applies the element-wise function: | |||
.. math:: | |||
\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)} | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as M | |||
data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32)) | |||
sigmoid = M.Sigmoid() | |||
output = sigmoid(data) | |||
with np.printoptions(precision=6): | |||
print(output.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[0.119203 0.268941 0.5 0.731059 0.880797] | |||
""" | |||
def forward(self, inputs): | |||
return sigmoid(inputs) | |||
class ReLU(Module): | |||
r""" | |||
Applies the element-wise function: | |||
.. math:: | |||
\text{ReLU}(x) = \max(x, 0) | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as M | |||
data = mge.tensor(np.array([-2,-1,0,1,2,]).astype(np.float32)) | |||
relu = M.ReLU() | |||
output = relu(data) | |||
with np.printoptions(precision=6): | |||
print(output.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[0. 0. 0. 1. 2.] | |||
""" | |||
def forward(self, x): | |||
return relu(x) | |||
class PReLU(Module): | |||
r""" | |||
Applies the element-wise function: | |||
.. math:: | |||
\text{PReLU}(x) = \max(0,x) + a * \min(0,x) | |||
or | |||
.. math:: | |||
\text{PReLU}(x) = | |||
\begin{cases} | |||
x, & \text{ if } x \geq 0 \\ | |||
ax, & \text{ otherwise } | |||
\end{cases} | |||
    Here :math:`a` is a learnable parameter. When called without arguments, `PReLU()` uses
    a single parameter :math:`a` across all input channels. If called with `PReLU(num_of_channels)`,
    a separate :math:`a` is used for each input channel.
    :param num_parameters: number of :math:`a` to learn; only two
        values are legitimate: 1, or the number of channels at input. Default: 1
:param init: the initial value of :math:`a`. Default: 0.25 | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as M | |||
data = mge.tensor(np.array([-1.2, -3.7, 2.7]).astype(np.float32)) | |||
prelu = M.PReLU() | |||
output = prelu(data) | |||
print(output.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[-0.3 -0.925 2.7 ] | |||
""" | |||
def __init__(self, num_parameters: int = 1, init: float = 0.25): | |||
super().__init__() | |||
self.num_parameters = num_parameters | |||
if num_parameters > 1: | |||
# Assume format is NCHW | |||
self.weight = Parameter( | |||
data=np.full((1, num_parameters, 1, 1), init, dtype=np.float32) | |||
) | |||
else: | |||
self.weight = Parameter(data=[init]) | |||
def forward(self, inputs): | |||
assert self.weight.shape == (1,) or self.weight.shape == ( | |||
1, | |||
int(inputs.shape[1]), | |||
1, | |||
1, | |||
), "invalid weight's shape" | |||
return prelu(inputs, self.weight) | |||
class LeakyReLU(Module): | |||
r""" | |||
Applies the element-wise function: | |||
.. math:: | |||
\text{LeakyReLU}(x) = \max(0,x) + negative\_slope \times \min(0,x) | |||
or | |||
.. math:: | |||
\text{LeakyReLU}(x) = | |||
\begin{cases} | |||
x, & \text{ if } x \geq 0 \\ | |||
negative\_slope \times x, & \text{ otherwise } | |||
\end{cases} | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as M | |||
data = mge.tensor(np.array([-8, -12, 6, 10]).astype(np.float32)) | |||
leakyrelu = M.LeakyReLU(0.01) | |||
output = leakyrelu(data) | |||
print(output.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[-0.08 -0.12 6. 10. ] | |||
""" | |||
def __init__(self, negative_slope: float = 0.01): | |||
super().__init__() | |||
self.negative_slope = negative_slope | |||
def forward(self, inputs): | |||
return leaky_relu(inputs, self.negative_slope) |
@@ -0,0 +1,281 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Optional | |||
import numpy as np | |||
from ..distributed.group import WORLD, Group | |||
from ..functional import batch_norm2d, sync_batch_norm | |||
from ..tensor_nn import Buffer, Parameter | |||
from . import init | |||
from .module import Module | |||
class _BatchNorm(Module): | |||
def __init__( | |||
self, | |||
num_features, | |||
eps=1e-5, | |||
momentum=0.9, | |||
affine=True, | |||
track_running_stats=True, | |||
freeze=False, | |||
): | |||
super(_BatchNorm, self).__init__() | |||
self.num_features = num_features | |||
self.eps = eps | |||
self.momentum = momentum | |||
self.affine = affine | |||
self.track_running_stats = track_running_stats | |||
self._track_running_stats_saved = track_running_stats | |||
self.freeze = freeze | |||
if self.affine: | |||
self.weight = Parameter(np.ones(num_features, dtype=np.float32)) | |||
self.bias = Parameter(np.zeros(num_features, dtype=np.float32)) | |||
else: | |||
self.weight = None | |||
self.bias = None | |||
tshape = (1, self.num_features, 1, 1) | |||
if self.track_running_stats: | |||
self.running_mean = Buffer(np.zeros(tshape, dtype=np.float32)) | |||
self.running_var = Buffer(np.ones(tshape, dtype=np.float32)) | |||
else: | |||
self.running_mean = None | |||
self.running_var = None | |||
def reset_running_stats(self) -> None: | |||
if self.track_running_stats: | |||
init.zeros_(self.running_mean) | |||
init.ones_(self.running_var) | |||
def reset_parameters(self) -> None: | |||
self.reset_running_stats() | |||
if self.affine: | |||
init.ones_(self.weight) | |||
init.zeros_(self.bias) | |||
def _check_input_ndim(self, inp): | |||
raise NotImplementedError | |||
def forward(self, inp): | |||
self._check_input_ndim(inp) | |||
        if not self._track_running_stats_saved:
            assert (
                not self.track_running_stats
            ), "track_running_stats can not be initialized to False and changed to True later"
_ndims = len(inp.shape) | |||
if _ndims != 4: | |||
origin_shape = inp.shapeof() | |||
if _ndims == 2: | |||
n, c = inp.shapeof(0), inp.shapeof(1) | |||
new_shape = (n, c, 1, 1) | |||
elif _ndims == 3: | |||
n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2) | |||
new_shape = (n, c, h, 1) | |||
inp = inp.reshape(new_shape) | |||
if self.freeze and self.training and self._track_running_stats_saved: | |||
scale = self.weight.reshape(1, -1, 1, 1) * ( | |||
self.running_var + self.eps | |||
) ** (-0.5) | |||
bias = self.bias.reshape(1, -1, 1, 1) - self.running_mean * scale | |||
return inp * scale.detach() + bias.detach() | |||
if self.training and self.track_running_stats: | |||
exponential_average_factor = self.momentum | |||
else: | |||
exponential_average_factor = 0.0 # useless | |||
output = batch_norm2d( | |||
inp, | |||
self.running_mean if self.track_running_stats else None, | |||
self.running_var if self.track_running_stats else None, | |||
self.weight, | |||
self.bias, | |||
training=self.training | |||
or ((self.running_mean is None) and (self.running_var is None)), | |||
momentum=exponential_average_factor, | |||
eps=self.eps, | |||
) | |||
if _ndims != 4: | |||
output = output.reshape(origin_shape) | |||
return output | |||
class SyncBatchNorm(_BatchNorm): | |||
r""" | |||
    Applies Synchronized Batch Normalization.
""" | |||
def __init__( | |||
self, | |||
num_features, | |||
eps=1e-5, | |||
momentum=0.9, | |||
affine=True, | |||
track_running_stats=True, | |||
freeze=False, | |||
group: Optional[Group] = None, | |||
) -> None: | |||
super().__init__( | |||
num_features, eps, momentum, affine, track_running_stats, freeze | |||
) | |||
self.group = group | |||
def _check_input_ndim(self, inp): | |||
if len(inp.shape) not in {2, 3, 4}: | |||
raise ValueError( | |||
"expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape)) | |||
) | |||
def forward(self, inp): | |||
self._check_input_ndim(inp) | |||
_ndims = len(inp.shape) | |||
if _ndims != 4: | |||
origin_shape = inp.shapeof() | |||
if _ndims == 2: | |||
n, c = inp.shapeof(0), inp.shapeof(1) | |||
new_shape = (n, c, 1, 1) | |||
elif _ndims == 3: | |||
n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2) | |||
new_shape = (n, c, h, 1) | |||
inp = inp.reshape(new_shape) | |||
if self.training and self.track_running_stats: | |||
exponential_average_factor = self.momentum | |||
else: | |||
exponential_average_factor = 0.0 # useless | |||
output = sync_batch_norm( | |||
inp, | |||
self.running_mean, | |||
self.running_var, | |||
self.weight, | |||
self.bias, | |||
self.training or not self.track_running_stats, | |||
exponential_average_factor, | |||
self.eps, | |||
group=self.group, | |||
) | |||
if _ndims != 4: | |||
output = output.reshape(origin_shape) | |||
return output | |||
class BatchNorm1d(_BatchNorm): | |||
r""" | |||
Applies Batch Normalization over a 2D/3D tensor. | |||
Refer to :class:`~.BatchNorm2d` for more information. | |||
""" | |||
def _check_input_ndim(self, inp): | |||
if len(inp.shape) not in {2, 3}: | |||
raise ValueError( | |||
"expected 2D or 3D input (got {}D input)".format(len(inp.shape)) | |||
) | |||
class BatchNorm2d(_BatchNorm): | |||
r""" | |||
Applies Batch Normalization over a 4D tensor. | |||
.. math:: | |||
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta | |||
The mean and standard-deviation are calculated per-dimension over | |||
the mini-batches and :math:`\gamma` and :math:`\beta` are learnable | |||
parameter vectors. | |||
By default, during training this layer keeps running estimates of its | |||
computed mean and variance, which are then used for normalization during | |||
evaluation. The running estimates are kept with a default :attr:`momentum` | |||
of 0.9. | |||
If :attr:`track_running_stats` is set to ``False``, this layer will not | |||
keep running estimates, and batch statistics are instead used during | |||
evaluation time. | |||
.. note:: | |||
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and from the conventional notion of momentum. Mathematically, the
update rule for running statistics here is | |||
:math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1 - \text{momentum}) \times x_t`, | |||
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the | |||
new observed value. | |||
Because the Batch Normalization is done over the `C` dimension, computing | |||
statistics on `(N, H, W)` slices, it's common terminology to call this | |||
Spatial Batch Normalization. | |||
:type num_features: int | |||
:param num_features: usually the :math:`C` from an input of size | |||
:math:`(N, C, H, W)` or the highest ranked dimension of an input with | |||
less than 4D. | |||
:type eps: float | |||
:param eps: a value added to the denominator for numerical stability. | |||
Default: 1e-5. | |||
:type momentum: float | |||
:param momentum: the value used for the `running_mean` and `running_var` | |||
computation. | |||
Default: 0.9 | |||
:type affine: bool | |||
:param affine: a boolean value that when set to ``True``, this module has | |||
learnable affine parameters. Default: ``True`` | |||
:type track_running_stats: bool | |||
:param track_running_stats: when set to ``True``, this module tracks the | |||
running mean and variance. When set to ``False``, this module does not | |||
track such statistics and always uses batch statistics in both training | |||
and eval modes. Default: ``True``. | |||
:type freeze: bool | |||
:param freeze: when set to ``True``, this module does not update the | |||
running mean and variance, and uses the running mean and variance instead of | |||
the batch mean and batch variance to normalize the input. The parameter takes effect | |||
    only when the module is initialized with ``track_running_stats`` as ``True`` and
the module is in training mode. | |||
Default: ``False``. | |||
Examples: | |||
.. testcode:: | |||
        import numpy as np
        import megengine as mge
        import megengine.module as M
# With Learnable Parameters | |||
m = M.BatchNorm2d(4) | |||
inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32")) | |||
oup = m(inp) | |||
print(m.weight, m.bias) | |||
# Without Learnable Parameters | |||
m = M.BatchNorm2d(4, affine=False) | |||
oup = m(inp) | |||
print(m.weight, m.bias) | |||
.. testoutput:: | |||
Tensor([1. 1. 1. 1.]) Tensor([0. 0. 0. 0.]) | |||
None None | |||
""" | |||
def _check_input_ndim(self, inp): | |||
if len(inp.shape) != 4: | |||
raise ValueError("expected 4D input (got {}D input)".format(len(inp.shape))) |
@@ -0,0 +1,22 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Iterable | |||
from ..functional import concat | |||
from ..tensor import Tensor | |||
from .module import Module | |||
class Concat(Module): | |||
r""" | |||
A :class:`~.Module` to do functional concat. Could be replaced with :class:`~.QATModule` | |||
version :class:`~.qat.concat.Concat` using :func:`~.quantize.quantize_qat`. | |||
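
    Example (a sketch; ``x`` and ``y`` are illustrative tensors of matching shapes):

    .. code-block::

        cat = Concat()
        out = cat([x, y], axis=1)  # concatenate x and y along the channel axis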
""" | |||
def forward(self, inps: Iterable[Tensor], axis: int = 0): | |||
return concat(inps, axis) |
@@ -0,0 +1,391 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from abc import abstractmethod | |||
from typing import Tuple, Union | |||
import numpy as np | |||
from ..core.ops._internal import param_defs as P | |||
from ..functional import conv2d, conv_transpose2d, local_conv2d, relu | |||
from ..functional.types import _pair, _pair_nonzero | |||
from ..tensor_nn import Parameter | |||
from . import init | |||
from .module import Module | |||
class _ConvNd(Module): | |||
"""base class for convolution modules, including transposed conv""" | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
out_channels: int, | |||
kernel_size: Union[int, Tuple[int, int]], | |||
stride: Union[int, Tuple[int, int]], | |||
padding: Union[int, Tuple[int, int]], | |||
dilation: Union[int, Tuple[int, int]], | |||
groups: int, | |||
bias: bool = True, | |||
): | |||
super().__init__() | |||
if in_channels % groups != 0: | |||
raise ValueError("in_channels must be divisible by groups") | |||
if out_channels % groups != 0: | |||
raise ValueError("out_channels must be divisible by groups") | |||
self.in_channels = in_channels | |||
self.out_channels = out_channels | |||
self.kernel_size = kernel_size | |||
self.stride = stride | |||
self.padding = padding | |||
self.dilation = dilation | |||
self.groups = groups | |||
self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32)) | |||
self.bias = None | |||
if bias: | |||
self.bias = Parameter(np.zeros(self._infer_bias_shape(), dtype=np.float32)) | |||
self.reset_parameters() | |||
@abstractmethod | |||
def _get_fanin(self): | |||
pass | |||
def reset_parameters(self) -> None: | |||
fanin = self._get_fanin() | |||
std = np.sqrt(1 / fanin) | |||
init.normal_(self.weight, 0.0, std) | |||
if self.bias is not None: | |||
init.zeros_(self.bias) | |||
@abstractmethod | |||
def _infer_weight_shape(self): | |||
pass | |||
@abstractmethod | |||
def _infer_bias_shape(self): | |||
pass | |||
class Conv2d(_ConvNd): | |||
r"""Applies a 2D convolution over an input tensor. | |||
For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`, | |||
this layer generates an output of the size | |||
:math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the | |||
process described as below: | |||
.. math:: | |||
\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + | |||
\sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k) | |||
where :math:`\star` is the valid 2D cross-correlation operator, | |||
:math:`N` is a batch size, :math:`C` denotes a number of channels, | |||
:math:`H` is a height of input planes in pixels, and :math:`W` is | |||
width in pixels. | |||
When ``groups == in_channels`` and ``out_channels == K * in_channels``, | |||
where `K` is a positive integer, this operation is also known as depthwise | |||
convolution. | |||
In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`, | |||
    a depthwise convolution with a depthwise multiplier `K` can be constructed
by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`. | |||
:param in_channels: number of input channels. | |||
:param out_channels: number of output channels. | |||
:param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||
an :class:`int`, the actual kernel size would be | |||
``(kernel_size, kernel_size)``. Default: 1 | |||
:param stride: stride of the 2D convolution operation. Default: 1 | |||
:param padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param dilation: dilation of the 2D convolution operation. Default: 1 | |||
:param groups: number of groups to divide input and output channels into, | |||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and there would be an extra dimension at the beginning of the weight's | |||
shape. Specifically, the shape of weight would be ``(groups, | |||
        out_channels // groups, in_channels // groups, *kernel_size)``.
:param bias: whether to add a bias onto the result of convolution. Default: | |||
True | |||
:param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | |||
`CROSS_CORRELATION`. | |||
:param compute_mode: When set to `DEFAULT`, no special requirements will be | |||
placed on the precision of intermediate results. When set to `FLOAT32`, | |||
float32 would be used for accumulator and intermediate result, but only | |||
effective when input and output are of float16 dtype. | |||
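
    Example (a sketch of a depthwise convolution, where ``groups == in_channels``):

    .. code-block::

        import megengine.module as M

        m = M.Conv2d(in_channels=16, out_channels=16, kernel_size=3, groups=16)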
""" | |||
_conv_mode_type = P.Convolution.Mode | |||
_compute_mode_type = P.Convolution.ComputeMode | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
out_channels: int, | |||
kernel_size: Union[int, Tuple[int, int]], | |||
stride: Union[int, Tuple[int, int]] = 1, | |||
padding: Union[int, Tuple[int, int]] = 0, | |||
dilation: Union[int, Tuple[int, int]] = 1, | |||
groups: int = 1, | |||
bias: bool = True, | |||
conv_mode: str = "CROSS_CORRELATION", | |||
compute_mode: str = "DEFAULT", | |||
): | |||
kernel_size = _pair_nonzero(kernel_size) | |||
stride = _pair_nonzero(stride) | |||
padding = _pair(padding) | |||
dilation = _pair_nonzero(dilation) | |||
self.conv_mode = self._conv_mode_type.convert(conv_mode) | |||
self.compute_mode = self._compute_mode_type.convert(compute_mode) | |||
super().__init__( | |||
in_channels, | |||
out_channels, | |||
kernel_size, | |||
stride, | |||
padding, | |||
dilation, | |||
groups, | |||
bias, | |||
) | |||
def _get_fanin(self): | |||
kh, kw = self.kernel_size | |||
ic = self.in_channels | |||
return kh * kw * ic | |||
def _infer_weight_shape(self): | |||
group = self.groups | |||
ichl = self.in_channels | |||
ochl = self.out_channels | |||
kh, kw = self.kernel_size | |||
if group == 1: | |||
# Assume format is NCHW | |||
return (ochl, ichl, kh, kw) | |||
assert ( | |||
ichl % group == 0 and ochl % group == 0 | |||
), "invalid config: input_channels={} output_channels={} group={}".format( | |||
ichl, ochl, group | |||
) | |||
# Assume format is NCHW | |||
return (group, ochl // group, ichl // group, kh, kw) | |||
def _infer_bias_shape(self): | |||
# Assume format is NCHW | |||
return (1, self.out_channels, 1, 1) | |||
def calc_conv(self, inp, weight, bias): | |||
return conv2d( | |||
inp, | |||
weight, | |||
bias, | |||
self.stride, | |||
self.padding, | |||
self.dilation, | |||
self.groups, | |||
self.conv_mode, | |||
self.compute_mode, | |||
) | |||
def forward(self, inp): | |||
return self.calc_conv(inp, self.weight, self.bias) | |||
class ConvTranspose2d(_ConvNd): | |||
r"""Applies a 2D transposed convolution over an input tensor. | |||
This module is also known as a deconvolution or a fractionally-strided convolution. | |||
    :class:`ConvTranspose2d` can be seen as the gradient of the :class:`Conv2d` operation
with respect to its input. | |||
Convolution usually reduces the size of input, while transposed convolution works | |||
the opposite way, transforming a smaller input to a larger output while preserving the | |||
connectivity pattern. | |||
:param in_channels: number of input channels. | |||
:param out_channels: number of output channels. | |||
:param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||
an :class:`int`, the actual kernel size would be | |||
``(kernel_size, kernel_size)``. Default: 1 | |||
:param stride: stride of the 2D convolution operation. Default: 1 | |||
:param padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param dilation: dilation of the 2D convolution operation. Default: 1 | |||
:param groups: number of groups to divide input and output channels into, | |||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and there would be an extra dimension at the beginning of the weight's | |||
shape. Specifically, the shape of weight would be ``(groups, | |||
out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1 | |||
    :param bias: whether to add a bias onto the result of convolution. Default:
True | |||
:param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | |||
`CROSS_CORRELATION`. | |||
:param compute_mode: When set to `DEFAULT`, no special requirements will be | |||
placed on the precision of intermediate results. When set to `FLOAT32`, | |||
float32 would be used for accumulator and intermediate result, but only | |||
effective when input and output are of float16 dtype. | |||
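
    Example (a sketch): a stride-2 transposed convolution roughly doubles the
    spatial size, the inverse of a stride-2 convolution.

    .. code-block::

        import megengine.module as M

        m = M.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1)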
""" | |||
_conv_mode_type = P.Convolution.Mode | |||
_compute_mode_type = P.Convolution.ComputeMode | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
out_channels: int, | |||
kernel_size: Union[int, Tuple[int, int]], | |||
stride: Union[int, Tuple[int, int]] = 1, | |||
padding: Union[int, Tuple[int, int]] = 0, | |||
dilation: Union[int, Tuple[int, int]] = 1, | |||
groups: int = 1, | |||
bias: bool = True, | |||
conv_mode: str = "CROSS_CORRELATION", | |||
compute_mode: str = "DEFAULT", | |||
): | |||
kernel_size = _pair_nonzero(kernel_size) | |||
stride = _pair_nonzero(stride) | |||
padding = _pair(padding) | |||
dilation = _pair_nonzero(dilation) | |||
self.conv_mode = self._conv_mode_type.convert(conv_mode) | |||
self.compute_mode = self._compute_mode_type.convert(compute_mode) | |||
super().__init__( | |||
in_channels, | |||
out_channels, | |||
kernel_size, | |||
stride, | |||
padding, | |||
dilation, | |||
groups, | |||
bias, | |||
) | |||
def _get_fanin(self): | |||
kh, kw = self.kernel_size | |||
oc = self.out_channels | |||
return kh * kw * oc | |||
def _infer_weight_shape(self): | |||
group = self.groups | |||
ichl = self.in_channels | |||
ochl = self.out_channels | |||
kh, kw = self.kernel_size | |||
if group == 1: | |||
# Assume format is NCHW | |||
return (ichl, ochl, kh, kw) | |||
assert ( | |||
ichl % group == 0 and ochl % group == 0 | |||
), "invalid config: input_channels={} output_channels={} group={}".format( | |||
ichl, ochl, group | |||
) | |||
# Assume format is NCHW | |||
return (group, ichl // group, ochl // group, kh, kw) | |||
def _infer_bias_shape(self): | |||
# Assume format is NCHW | |||
return (1, self.out_channels, 1, 1) | |||
def forward(self, inp): | |||
return conv_transpose2d( | |||
inp, | |||
self.weight, | |||
self.bias, | |||
self.stride, | |||
self.padding, | |||
self.dilation, | |||
self.groups, | |||
self.conv_mode, | |||
self.compute_mode, | |||
) | |||
class LocalConv2d(Conv2d): | |||
r"""Applies a spatial convolution with untied kernels over an input 4D tensor. | |||
It is also known as the locally connected layer. | |||
:param in_channels: number of input channels. | |||
:param out_channels: number of output channels. | |||
:param input_height: the height of the input images. | |||
:param input_width: the width of the input images. | |||
:param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||
an :class:`int`, the actual kernel size would be | |||
``(kernel_size, kernel_size)``. Default: 1 | |||
:param stride: stride of the 2D convolution operation. Default: 1 | |||
:param padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param groups: number of groups to divide input and output channels into, | |||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``. | |||
The shape of weight is ``(groups, output_height, output_width, | |||
in_channels // groups, *kernel_size, out_channels // groups)``. | |||
""" | |||
_conv_mode_type = P.Convolution.Mode | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
out_channels: int, | |||
input_height: int, | |||
input_width: int, | |||
kernel_size: Union[int, Tuple[int, int]], | |||
stride: Union[int, Tuple[int, int]] = 1, | |||
padding: Union[int, Tuple[int, int]] = 0, | |||
dilation: Union[int, Tuple[int, int]] = 1, | |||
groups: int = 1, | |||
conv_mode: str = "CROSS_CORRELATION", | |||
): | |||
self.input_height = input_height | |||
self.input_width = input_width | |||
super().__init__( | |||
in_channels, | |||
out_channels, | |||
kernel_size, | |||
stride, | |||
padding, | |||
dilation, | |||
groups, | |||
bias=False, | |||
) | |||
def _infer_weight_shape(self): | |||
group = self.groups | |||
output_height = ( | |||
self.input_height + self.padding[0] * 2 - self.kernel_size[0] | |||
) // self.stride[0] + 1 | |||
output_width = ( | |||
self.input_width + self.padding[1] * 2 - self.kernel_size[1] | |||
) // self.stride[1] + 1 | |||
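        # e.g. with input_height=32, padding=(1, 1), kernel_size=(3, 3) and
        # stride=(1, 1): output_height = (32 + 1 * 2 - 3) // 1 + 1 = 32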
# Assume format is NCHW | |||
return ( | |||
group, | |||
output_height, | |||
output_width, | |||
self.in_channels // group, | |||
self.kernel_size[0], | |||
self.kernel_size[1], | |||
self.out_channels // group, | |||
) | |||
def forward(self, inp): | |||
return local_conv2d( | |||
inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode | |||
) | |||
class ConvRelu2d(Conv2d): | |||
r""" | |||
A fused :class:`~.Module` including Conv2d and relu. Could be replaced | |||
with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using | |||
:func:`~.quantize.quantize_qat`. | |||
""" | |||
def forward(self, inp): | |||
return relu(self.calc_conv(inp, self.weight, self.bias)) |
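A minimal usage sketch for the modules above (a sketch only, assuming NCHW float32 inputs; the shapes and the ``M``/``mge`` aliases are illustrative):

import numpy as np
import megengine as mge
import megengine.module as M

x = mge.tensor(np.random.randn(2, 3, 8, 8).astype("float32"))
conv_relu = M.ConvRelu2d(3, 4, kernel_size=3, padding=1)
y = conv_relu(x)  # fused conv + relu, output shape (2, 4, 8, 8)

# LocalConv2d keeps one kernel per output position ("untied" weights); with
# groups=1 its weight shape is (1, out_h, out_w, in_channels, kh, kw, out_channels)
local = M.LocalConv2d(3, 4, input_height=8, input_width=8, kernel_size=3, padding=1)
print(local.weight.shape)  # (1, 8, 8, 3, 3, 3, 4)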
@@ -0,0 +1,69 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Tuple, Union | |||
from ..functional import relu | |||
from .batchnorm import BatchNorm2d | |||
from .conv import Conv2d | |||
from .module import Module | |||
class _ConvBnActivation2d(Module): | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
out_channels: int, | |||
kernel_size: Union[int, Tuple[int, int]], | |||
stride: Union[int, Tuple[int, int]] = 1, | |||
padding: Union[int, Tuple[int, int]] = 0, | |||
dilation: Union[int, Tuple[int, int]] = 1, | |||
groups: int = 1, | |||
bias: bool = True, | |||
conv_mode: str = "CROSS_CORRELATION", | |||
compute_mode: str = "DEFAULT", | |||
eps=1e-5, | |||
momentum=0.9, | |||
affine=True, | |||
track_running_stats=True, | |||
): | |||
super().__init__() | |||
self.conv = Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size, | |||
stride, | |||
padding, | |||
dilation, | |||
groups, | |||
bias, | |||
conv_mode, | |||
compute_mode, | |||
) | |||
self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats) | |||
class ConvBn2d(_ConvBnActivation2d): | |||
r""" | |||
A fused :class:`~.Module` including Conv2d, BatchNorm2d. Could be replaced | |||
with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBn2d` using | |||
:func:`~.quantize.quantize_qat`. | |||
""" | |||
def forward(self, inp): | |||
return self.bn(self.conv(inp)) | |||
class ConvBnRelu2d(_ConvBnActivation2d): | |||
r""" | |||
A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced | |||
with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBnRelu2d` using | |||
:func:`~.quantize.quantize_qat`. | |||
""" | |||
def forward(self, inp): | |||
return relu(self.bn(self.conv(inp))) |
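A short sketch of the fused blocks in use (shapes illustrative; relies only on the classes defined above):

import numpy as np
import megengine as mge
import megengine.module as M

block = M.ConvBnRelu2d(3, 8, kernel_size=3, stride=1, padding=1)
x = mge.tensor(np.random.randn(4, 3, 16, 16).astype("float32"))
y = block(x)  # conv -> batchnorm -> relu, output shape (4, 8, 16, 16)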
@@ -0,0 +1,29 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..functional import dropout | |||
from .module import Module | |||
class Dropout(Module): | |||
r"""Randomly set input elements to zeros with the probability :math:`drop\_prob` during training. Commonly used in large networks to prevent overfitting. | |||
Note that we perform dropout only during training, we also rescale(multiply) the output tensor | |||
by :math:`\frac{1}{1 - drop\_prob}`. During inference :class:`~.Dropout` is equal to :class:`~.Identity`. | |||
:param drop_prob: The probability to drop (set to zero) each single element | |||
""" | |||
def __init__(self, drop_prob=0.0): | |||
super().__init__() | |||
self.drop_prob = drop_prob | |||
def forward(self, inputs): | |||
if self.training: | |||
return dropout(inputs, self.drop_prob, rescale=True) | |||
else: | |||
return inputs |
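A sketch of the train/eval behaviour described in the docstring (values illustrative):

import numpy as np
import megengine as mge
import megengine.module as M

drop = M.Dropout(0.2)
x = mge.tensor(np.ones((2, 4), dtype="float32"))
y_train = drop(x)  # training mode: ~20% of elements zeroed, survivors scaled by 1/0.8
drop.eval()        # sets training=False on the module tree
y_eval = drop(x)   # inference mode: identity, y_eval equals x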
@@ -0,0 +1,79 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..core.ops._internal import param_defs as P | |||
from ..functional.elemwise import _elwise | |||
from ..tensor import Tensor | |||
from .module import Module | |||
class Elemwise(Module): | |||
r""" | |||
A :class:`~.Module` that performs an elementwise operator. Could be replaced with :class:`~.QATModule` | |||
version :class:`~.qat.elemwise.Elemwise` using :func:`~.quantize.quantize_qat`. | |||
:param method: the elemwise method; the following strings are supported. | |||
For float tensors it performs the ordinary elementwise operator. | |||
* "ADD": x + y | |||
* "FUSE_ADD_RELU": max(x+y, 0) | |||
* "MUL": x * y | |||
* "MIN": min(x, y) | |||
* "MAX": max(x, y) | |||
* "SUB": x - y | |||
* "TRUE_DIV": x / y | |||
* "FUSE_ADD_SIGMOID": sigmoid(x + y) | |||
* "FUSE_ADD_TANH": tanh(x + y) | |||
* "RELU": x > 0 ? x : 0 | |||
* "ABS": x > 0 ? x : -x | |||
* "SIGMOID": sigmoid(x) | |||
* "EXP": exp(x) | |||
* "TANH": tanh(x) | |||
* "FUSE_MUL_ADD3": x * y + z | |||
* "FAST_TANH": fast_tanh(x) | |||
* "NEGATE": -x | |||
* "ACOS": acos(x) | |||
* "ASIN": asin(x) | |||
* "CEIL": ceil(x) | |||
* "COS": cos(x) | |||
* "EXPM1": expm1(x) | |||
* "FLOOR": floor(x) | |||
* "LOG": log(x) | |||
* "LOG1P": log1p(x) | |||
* "SIN": sin(x) | |||
* "ROUND": round(x) | |||
* "ERF": erf(x) | |||
* "ERFINV": erfinv(x) | |||
* "ERFC": erfc(x) | |||
* "ERFCINV": erfcinv(x) | |||
* "ABS_GRAD": abs_grad | |||
* "FLOOR_DIV": floor_div | |||
* "MOD": mod | |||
* "SIGMOID_GRAD": sigmoid_grad | |||
* "SWITCH_GT0": switch_gt0 | |||
* "TANH_GRAD": tanh_grad | |||
* "LT": lt | |||
* "LEQ": leq | |||
* "EQ": eq | |||
* "POW": pow | |||
* "LOG_SUM_EXP": log_sum_exp | |||
* "FAST_TANH_GRAD": fast_tanh_grad | |||
* "ATAN2": atan2 | |||
* "COND_LEQ_MOV": cond_leq_mov | |||
* "H_SWISH": h_swish | |||
* "FUSE_ADD_H_SWISH": h_swish(x+y) | |||
* "H_SWISH_GRAD": h_swish_grad | |||
""" | |||
_elemwise_mode_type = P.Elemwise.Mode | |||
def __init__(self, method): | |||
super().__init__() | |||
self.method = self._elemwise_mode_type.convert(method) | |||
def forward(self, *inps): | |||
return _elwise(*inps, mode=self.method) |
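For example, the fused add-relu mode (a sketch; the tensor values are illustrative):

import numpy as np
import megengine as mge
import megengine.module as M

add_relu = M.Elemwise("FUSE_ADD_RELU")
a = mge.tensor(np.array([-1.0, 2.0], dtype="float32"))
b = mge.tensor(np.array([0.5, 1.0], dtype="float32"))
print(add_relu(a, b).numpy())  # max(a + b, 0) -> [0. 3.]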
@@ -0,0 +1,171 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Optional | |||
import numpy as np | |||
from ..functional import embedding as embedding_func | |||
from ..tensor_nn import Parameter | |||
from . import init | |||
from .module import Module | |||
class Embedding(Module): | |||
r""" | |||
A simple lookup table that stores embeddings of a fixed dictionary and size. | |||
This module is often used to store word embeddings and retrieve them using indices. | |||
The input to the module is a list of indices, and the output is the corresponding word embeddings. | |||
The indices should be less than ``num_embeddings``. | |||
:param num_embeddings: size of embedding dictionary. | |||
:param embedding_dim: size of each embedding vector. | |||
:param padding_idx: should be set to None; not supported now. | |||
:param max_norm: should be set to None; not supported now. | |||
:param norm_type: should be set to None; not supported now. | |||
:param initial_weight: the learnable weights of the module of shape (num_embeddings, embedding_dim). | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as M | |||
weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32)) | |||
data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32)) | |||
embedding = M.Embedding(2, 5, initial_weight=weight) | |||
output = embedding(data) | |||
with np.printoptions(precision=6): | |||
print(output.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[[1.2 2.3 3.4 4.5 5.6] | |||
[0.1 1.1 2.1 3.1 4.1] | |||
[0.1 1.1 2.1 3.1 4.1]] | |||
[[0.1 1.1 2.1 3.1 4.1] | |||
[1.2 2.3 3.4 4.5 5.6] | |||
[0.1 1.1 2.1 3.1 4.1]] | |||
[[1.2 2.3 3.4 4.5 5.6] | |||
[1.2 2.3 3.4 4.5 5.6] | |||
[0.1 1.1 2.1 3.1 4.1]]] | |||
""" | |||
def __init__( | |||
self, | |||
num_embeddings: int, | |||
embedding_dim: int, | |||
padding_idx: Optional[int] = None, | |||
max_norm: Optional[float] = None, | |||
norm_type: Optional[float] = None, | |||
initial_weight: Parameter = None, | |||
): | |||
super().__init__() | |||
if padding_idx is not None: | |||
raise ValueError("Not support padding index now.") | |||
if max_norm is not None or norm_type is not None: | |||
raise ValueError("Not support weight normalize now.") | |||
self.padding_idx = padding_idx | |||
self.max_norm = max_norm | |||
self.norm_type = norm_type | |||
self.num_embeddings = num_embeddings | |||
self.embedding_dim = embedding_dim | |||
if initial_weight is None: | |||
self.weight = Parameter( | |||
np.random.uniform( | |||
size=(self.num_embeddings, self.embedding_dim) | |||
).astype(np.float32) | |||
) | |||
self.reset_parameters() | |||
else: | |||
if initial_weight.shape != (num_embeddings, embedding_dim): | |||
raise ValueError( | |||
"The weight shape should match num_embeddings and embedding_dim" | |||
) | |||
self.weight = Parameter(initial_weight.numpy()) | |||
def reset_parameters(self) -> None: | |||
init.normal_(self.weight) | |||
def forward(self, inputs): | |||
return embedding_func(inputs, self.weight) | |||
@classmethod | |||
def from_pretrained( | |||
cls, | |||
embeddings: Parameter, | |||
freeze: Optional[bool] = True, | |||
padding_idx: Optional[int] = None, | |||
max_norm: Optional[float] = None, | |||
norm_type: Optional[float] = None, | |||
): | |||
r""" | |||
Creates an Embedding instance from a given 2-dimensional FloatTensor. | |||
:param embeddings: Tensor contained weight for the embedding. | |||
:param freeze: If ``True``, the weight does not get updated during the learning process. Default: ``True``. | |||
:param padding_idx: should be set to None; not supported now. | |||
:param max_norm: should be set to None; not supported now. | |||
:param norm_type: should be set to None; not supported now. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as M | |||
weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32)) | |||
data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32)) | |||
embedding = M.Embedding.from_pretrained(weight, freeze=False) | |||
output = embedding(data) | |||
print(output.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[[1.2 2.3 3.4 4.5 5.6] | |||
[0.1 1.1 2.1 3.1 4.1] | |||
[0.1 1.1 2.1 3.1 4.1]] | |||
[[0.1 1.1 2.1 3.1 4.1] | |||
[1.2 2.3 3.4 4.5 5.6] | |||
[0.1 1.1 2.1 3.1 4.1]] | |||
[[1.2 2.3 3.4 4.5 5.6] | |||
[1.2 2.3 3.4 4.5 5.6] | |||
[0.1 1.1 2.1 3.1 4.1]]] | |||
""" | |||
embeddings_shape = embeddings.shape | |||
embeddings_dim = len(embeddings_shape) | |||
if embeddings_dim != 2: | |||
raise ValueError("Embeddings parameter is expected to be 2-dimensional") | |||
rows = embeddings_shape[0] | |||
cols = embeddings_shape[1] | |||
embedding = cls( | |||
num_embeddings=rows, | |||
embedding_dim=cols, | |||
initial_weight=embeddings, | |||
padding_idx=padding_idx, | |||
max_norm=max_norm, | |||
norm_type=norm_type, | |||
) | |||
embedding.weight.requires_grad = not freeze | |||
return embedding |
@@ -0,0 +1,56 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
from ..functional import cambricon_subgraph, extern_opr_subgraph | |||
from .module import Module | |||
class CambriconSubgraph(Module): | |||
r"""Load a serialized Cambricon subgraph. | |||
See :func:`~.cambricon_subgraph` for more details. | |||
""" | |||
def __init__( | |||
self, data, symbol, tensor_dim_mutable, | |||
): | |||
super(CambriconSubgraph, self).__init__() | |||
self._data = data | |||
self.symbol = symbol | |||
self.tensor_dim_mutable = tensor_dim_mutable | |||
@property | |||
def data(self): | |||
return self._data.tobytes() | |||
@data.setter | |||
def data(self, val): | |||
self._data = np.frombuffer(val, dtype=np.uint8) | |||
def forward(self, inputs): | |||
outputs = cambricon_subgraph( | |||
inputs, self._data, self.symbol, self.tensor_dim_mutable, | |||
) | |||
return outputs | |||
class ExternOprSubgraph(Module): | |||
r"""Load a serialized extern opr subgraph. | |||
""" | |||
def __init__(self, data, name, output_shapes): | |||
super(ExternOprSubgraph, self).__init__() | |||
self.data = data | |||
self.name = name | |||
self.output_shapes = output_shapes | |||
def forward(self, inputs): | |||
outputs = extern_opr_subgraph(inputs, self.output_shapes, self.name, self.data,) | |||
return outputs |
@@ -0,0 +1,17 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..functional import identity | |||
from .module import Module | |||
class Identity(Module): | |||
r"""A placeholder identity operator that will ignore any argument.""" | |||
def forward(self, x): | |||
return identity(x) |
@@ -0,0 +1,261 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import math | |||
from functools import reduce | |||
from typing import Optional, Tuple, Union | |||
import numpy as np | |||
from ..tensor import Tensor | |||
def fill_(tensor: Tensor, val: Union[float, int]) -> None: | |||
"""Fill the given ``tensor`` with value ``val``. | |||
:param tensor: An n-dimensional tensor to be initialized | |||
:param val: The value to be filled throughout the tensor | |||
""" | |||
tensor.set_value(np.full(tensor.shape, val, tensor.dtype)) | |||
def zeros_(tensor: Tensor) -> None: | |||
"""Fill the given ``tensor`` with scalar value `0`. | |||
:param tensor: An n-dimensional tensor to be initialized | |||
""" | |||
fill_(tensor, 0) | |||
def ones_(tensor: Tensor) -> None: | |||
"""Fill the given ``tensor`` with the scalar value `1`. | |||
:param tensor: An n-dimensional tensor to be initialized | |||
""" | |||
fill_(tensor, 1) | |||
def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None: | |||
r"""Fill the given ``tensor`` with random value sampled from uniform distribution | |||
:math:`\mathcal{U}(\text{a}, \text{b})`. | |||
:param tensor: An n-dimensional tensor to be initialized | |||
:param a: Lower bound of the sampling interval | |||
:param b: Upper bound of the sampling interval | |||
""" | |||
tensor.set_value(np.random.uniform(a, b, tensor.shape).astype(tensor.dtype)) | |||
def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: | |||
r"""Fill the given ``tensor`` with random value sampled from normal distribution | |||
:math:`\mathcal{N}(\text{mean}, \text{std}^2)`. | |||
:param tensor: An n-dimensional tensor to be initialized | |||
:param mean: The mean of the normal distribution | |||
:param std: The standard deviation of the normal distribution | |||
""" | |||
tensor.set_value(np.random.normal(mean, std, tensor.shape).astype(tensor.dtype)) | |||
def calculate_gain( | |||
nonlinearity: str, param: Optional[Union[int, float]] = None | |||
) -> float: | |||
r"""Return a recommended gain value (see the table below) for the given nonlinearity | |||
function. | |||
================= ==================================================== | |||
nonlinearity gain | |||
================= ==================================================== | |||
Linear / Identity :math:`1` | |||
Conv{1,2,3}D :math:`1` | |||
Sigmoid :math:`1` | |||
Tanh :math:`\frac{5}{3}` | |||
ReLU :math:`\sqrt{2}` | |||
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` | |||
================= ==================================================== | |||
:param nonlinearity: Name of the non-linear function | |||
:param param: Optional parameter for leaky_relu. Only effective when | |||
``nonlinearity`` is "leaky_relu". | |||
""" | |||
linear_fns = [ | |||
"linear", | |||
"conv1d", | |||
"conv2d", | |||
"conv3d", | |||
"conv_transpose1d", | |||
"conv_transpose2d", | |||
"conv_transpose3d", | |||
] | |||
if nonlinearity in linear_fns or nonlinearity == "sigmoid": | |||
return 1 | |||
if nonlinearity == "tanh": | |||
return 5.0 / 3 | |||
if nonlinearity == "relu": | |||
return math.sqrt(2.0) | |||
if nonlinearity == "leaky_relu": | |||
if param is None: | |||
negative_slope = 0.01 | |||
elif ( | |||
not isinstance(param, bool) | |||
and isinstance(param, int) | |||
or isinstance(param, float) | |||
): | |||
# True/False are instances of int, hence check above | |||
negative_slope = param | |||
else: | |||
raise ValueError("negative_slope {} not a valid number".format(param)) | |||
return math.sqrt(2.0 / (1 + negative_slope ** 2)) | |||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) | |||
def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]: | |||
""" | |||
Calculate fan_in / fan_out value for given weight tensor. This function assumes | |||
input tensor is stored in NCHW format. | |||
:param tensor: Weight tensor in NCHW format | |||
""" | |||
shape = tensor.shape | |||
ndim = len(shape) | |||
if ndim < 2: | |||
raise ValueError( | |||
"fan_in and fan_out can not be computed for tensor with fewer than 2 " | |||
"dimensions" | |||
) | |||
if ndim == 2: # Linear | |||
fan_in = shape[1] | |||
fan_out = shape[0] | |||
else: | |||
num_input_fmaps = shape[1] | |||
num_output_fmaps = shape[0] | |||
receptive_field_size = 1 | |||
if ndim > 2: | |||
receptive_field_size = reduce(lambda x, y: x * y, shape[2:], 1) | |||
fan_in = num_input_fmaps * receptive_field_size | |||
fan_out = num_output_fmaps * receptive_field_size | |||
return fan_in, fan_out | |||
def calculate_correct_fan(tensor: Tensor, mode: str) -> float: | |||
""" | |||
Calculate fan_in or fan_out value for given weight tensor, depending on given | |||
``mode``. | |||
See :func:`calculate_fan_in_and_fan_out` for details. | |||
:param tensor: Weight tensor in NCHW format | |||
:param mode: ``'fan_in'`` or ``'fan_out'`` | |||
""" | |||
mode = mode.lower() | |||
valid_modes = ["fan_in", "fan_out"] | |||
if mode not in valid_modes: | |||
raise ValueError( | |||
"Mode {} not supported, please use one of {}".format(mode, valid_modes) | |||
) | |||
fan_in, fan_out = calculate_fan_in_and_fan_out(tensor) | |||
return fan_in if mode == "fan_in" else fan_out | |||
def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None: | |||
r"""Fill ``tensor`` with random values sampled from :math:`\mathcal{U}(-a, a)` | |||
where | |||
.. math:: | |||
a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}} | |||
Also known as Glorot initialization. Detailed information can be retrieved from | |||
`Understanding the difficulty of training deep feedforward neural networks` - | |||
Glorot, X. & Bengio, Y. (2010). | |||
:param tensor: An n-dimensional tensor to be initialized | |||
:param gain: Scaling factor for :math:`a`. | |||
""" | |||
fan_in, fan_out = calculate_fan_in_and_fan_out(tensor) | |||
std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) | |||
a = math.sqrt(3.0) * std | |||
uniform_(tensor, -a, a) | |||
def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None: | |||
r"""Fill ``tensor`` with random values sampled from | |||
:math:`\mathcal{N}(0, \text{std}^2)` where | |||
.. math:: | |||
\text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}} | |||
Also known as Glorot initialization. Detailed information can be retrieved from | |||
`Understanding the difficulty of training deep feedforward neural networks` - | |||
Glorot, X. & Bengio, Y. (2010). | |||
:param tensor: An n-dimensional tensor to be initialized | |||
:param gain: Scaling factor for :math:`std`. | |||
""" | |||
fan_in, fan_out = calculate_fan_in_and_fan_out(tensor) | |||
std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) | |||
normal_(tensor, 0.0, std) | |||
def msra_uniform_( | |||
tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu" | |||
) -> None: | |||
r"""Fill ``tensor`` wilth random values sampled from | |||
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where | |||
.. math:: | |||
\text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}} | |||
Detailed information can be retrieved from | |||
`Delving deep into rectifiers: Surpassing human-level performance on ImageNet | |||
classification` - He, K. et al. (2015). | |||
:param tensor: An n-dimensional tensor to be initialized | |||
:param a: Optional parameter for calculating gain for leaky_relu. See | |||
:func:`calculate_gain` for details. | |||
:param mode: ``'fan_in'`` or ``'fan_out'``, used to select the fan value that | |||
scales :math:`bound`. See :func:`calculate_fan_in_and_fan_out` for | |||
details. | |||
:param nonlinearity: Name of the non-linear function used to calculate :math:`gain`. | |||
See :func:`calculate_gain` for details. | |||
""" | |||
fan = calculate_correct_fan(tensor, mode) | |||
gain = calculate_gain(nonlinearity, a) | |||
std = gain / math.sqrt(fan) | |||
bound = math.sqrt(3.0) * std | |||
uniform_(tensor, -bound, bound) | |||
def msra_normal_( | |||
tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu" | |||
) -> None: | |||
r"""Fill ``tensor`` wilth random values sampled from | |||
:math:`\mathcal{N}(0, \text{std}^2)` where | |||
.. math:: | |||
\text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}} | |||
Detailed information can be retrieved from | |||
`Delving deep into rectifiers: Surpassing human-level performance on ImageNet | |||
classification` - He, K. et al. (2015). | |||
:param tensor: An n-dimensional tensor to be initialized | |||
:param a: Optional parameter for calculating gain for leaky_relu. See | |||
:func:`calculate_gain` for details. | |||
:param mode: ``'fan_in'`` or ``'fan_out'``, used to select the fan value that | |||
scales :math:`std`. See :func:`calculate_fan_in_and_fan_out` for | |||
details. | |||
:param nonlinearity: Name of the non-linear function used to calculate :math:`gain`. | |||
See :func:`calculate_gain` for details. | |||
""" | |||
fan = calculate_correct_fan(tensor, mode) | |||
gain = calculate_gain(nonlinearity, a) | |||
std = gain / math.sqrt(fan) | |||
normal_(tensor, 0, std) |
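A sketch tying the fan calculation to the msra initializers above (assuming ``init`` is importable as a submodule of ``megengine.module``; the OIHW weight shape is illustrative):

import numpy as np
import megengine as mge
from megengine.module import init

w = mge.tensor(np.zeros((8, 3, 3, 3), dtype="float32"))  # conv weight in NCHW convention
fan_in, fan_out = init.calculate_fan_in_and_fan_out(w)   # 3*3*3 = 27, 8*3*3 = 72
init.msra_uniform_(w, mode="fan_in", nonlinearity="relu")
# bound = sqrt(3) * gain / sqrt(fan_in) = sqrt(3) * sqrt(2) / sqrt(27)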
@@ -0,0 +1,61 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
from ..functional import linear | |||
from ..tensor_nn import Parameter | |||
from . import init | |||
from .module import Module | |||
class Linear(Module): | |||
r"""Applies a linear transformation to the input. For instance, if input | |||
is x, then output y is: | |||
.. math:: | |||
y = xW^T + b | |||
where :math:`y_i= \sum_j W_{ij} x_j + b_i` | |||
:param in_features: size of each input sample. | |||
:param out_features: size of each output sample. | |||
:param bias: If set to ``False``, the layer will not learn an additive bias. | |||
Default: ``True`` | |||
""" | |||
def __init__( | |||
self, in_features: int, out_features: int, bias: bool = True, **kwargs | |||
): | |||
super().__init__(**kwargs) | |||
self.out_features = out_features | |||
self.in_features = in_features | |||
w_shape = (out_features, in_features) | |||
self.weight = Parameter(np.zeros(w_shape, dtype=np.float32)) | |||
self.bias = None | |||
if bias: | |||
b_shape = (out_features,) | |||
self.bias = Parameter(np.zeros(b_shape, dtype=np.float32)) | |||
self.reset_parameters() | |||
def _get_fanin(self): | |||
return self.in_features | |||
def reset_parameters(self) -> None: | |||
fanin = self._get_fanin() | |||
std = np.sqrt(1 / fanin) | |||
init.normal_(self.weight, 0.0, std) | |||
if self.bias is not None: | |||
init.zeros_(self.bias) | |||
def _calc_linear(self, x, weight, bias): | |||
return linear(x, weight, bias) | |||
def forward(self, x): | |||
return self._calc_linear(x, self.weight, self.bias) |
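A quick sketch of the transformation y = xW^T + b (shapes illustrative):

import numpy as np
import megengine as mge
import megengine.module as M

fc = M.Linear(4, 2)  # weight shape (2, 4), bias shape (2,)
x = mge.tensor(np.random.randn(3, 4).astype("float32"))
y = fc(x)            # output shape (3, 2)
# reset_parameters() draws W from N(0, 1/in_features) and zeros the bias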
@@ -0,0 +1,508 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from abc import ABCMeta, abstractmethod | |||
from collections import OrderedDict | |||
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||
import numpy as np | |||
from ..core.tensor.dtype import is_quantize | |||
from ..logger import get_logger | |||
from ..tensor import Tensor | |||
from ..tensor_nn import Buffer, Parameter | |||
from ..utils.hook import HookHandler | |||
logger = get_logger(__name__) | |||
def _expand_structure(key, obj): | |||
if isinstance(obj, (Tensor, Module)): | |||
return [(key, obj)] | |||
elif isinstance(obj, (list, tuple, dict)): | |||
ret = [] | |||
if isinstance(obj, dict): | |||
targets = ((k, obj[k]) for k in sorted(obj)) | |||
else: | |||
targets = ((str(k), v) for k, v in enumerate(obj)) | |||
for k, o in targets: | |||
sub_ret = _expand_structure(k, o) | |||
if sub_ret and not isinstance(k, str): | |||
raise AssertionError( | |||
"keys for Tensor and Module must be str, error key: {}".format(k) | |||
) | |||
for kt, vt in sub_ret: | |||
ret.extend([(key + "." + kt, vt)]) | |||
return ret | |||
else: | |||
return [] | |||
def _is_parameter(obj): | |||
return isinstance(obj, Parameter) | |||
def _is_buffer(obj): | |||
return isinstance(obj, Buffer) | |||
def _is_module(obj): | |||
return isinstance(obj, Module) | |||
class Module(metaclass=ABCMeta): | |||
"""Base Module class. | |||
""" | |||
def __init__(self): | |||
# runtime attributes | |||
self.training = True | |||
self.quantize_disabled = False | |||
# hooks | |||
self._forward_pre_hooks = OrderedDict() | |||
self._forward_hooks = OrderedDict() | |||
@abstractmethod | |||
def forward(self, inputs): | |||
pass | |||
def register_forward_pre_hook(self, hook: Callable) -> HookHandler: | |||
"""Register a hook to handle forward inputs. `hook` should be a function | |||
Note that `inputs` keyword inputs | |||
:param hook: a function that receive `module` and `inputs`, then return | |||
a modified `inputs` or `None`. | |||
:return: a handler with :meth:`~.HookHandler.remove` interface to delete the hook. | |||
""" | |||
return HookHandler(self._forward_pre_hooks, hook) | |||
def register_forward_hook(self, hook: Callable) -> HookHandler: | |||
"""Register a hook to handle forward results. `hook` should be a function that | |||
receive `module`, `inputs` and `outputs`, then return a modified `outputs` or `None`. | |||
This method return a handler with :meth:`~.HookHandler.remove` interface to delete the hook. | |||
""" | |||
return HookHandler(self._forward_hooks, hook) | |||
def __call__(self, *inputs, **kwargs): | |||
for hook in self._forward_pre_hooks.values(): | |||
modified_inputs = hook(self, inputs) | |||
if modified_inputs is not None: | |||
if not isinstance(modified_inputs, tuple): | |||
modified_inputs = (modified_inputs,) | |||
inputs = modified_inputs | |||
outputs = self.forward(*inputs, **kwargs) | |||
for hook in self._forward_hooks.values(): | |||
modified_outputs = hook(self, inputs, outputs) | |||
if modified_outputs is not None: | |||
outputs = modified_outputs | |||
return outputs | |||
def _flatten( | |||
self, | |||
*, | |||
recursive: bool = True, | |||
with_key: bool = False, | |||
with_parent: bool = False, | |||
prefix: Optional[str] = None, | |||
predicate: Callable[[Any], bool] = lambda _: True, | |||
seen: Optional[Set[int]] = None | |||
) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]: | |||
"""Scans the module object and returns an iterable for the :class:`~.Tensor` | |||
and :class:`~.Module` attributes that agree with the ``predicate``. For multiple | |||
calls of this function with the same arguments, the order of objects within the | |||
returned iterable is guaranteed to be identical, as long as all the involved | |||
module objects' ``__dict__`` does not change throughout those calls. | |||
:param recursive: Whether to recursively scan all the submodules. | |||
:param with_key: Whether to yield keys along with yielded objects. | |||
:param with_parent: Whether to yield ``self`` along with yielded objects. | |||
:param prefix: The prefix appended to the yielded keys. | |||
:param predicate: The predicate function applied to scanned objects. | |||
:param seen: A set of ids of the objects that have already been traversed. | |||
""" | |||
if seen is None: | |||
seen = set([id(self)]) | |||
module_dict = vars(self) | |||
_prefix = "" if prefix is None else prefix + "." | |||
for key in sorted(module_dict): | |||
for expanded_key, leaf in _expand_structure(key, module_dict[key]): | |||
leaf_id = id(leaf) | |||
if leaf_id in seen: | |||
continue | |||
seen.add(leaf_id) | |||
if predicate(leaf): | |||
if with_key and with_parent: | |||
yield _prefix + expanded_key, leaf, self | |||
elif with_key: | |||
yield _prefix + expanded_key, leaf | |||
elif with_parent: | |||
yield leaf, self | |||
else: | |||
yield leaf | |||
if recursive and isinstance(leaf, Module): | |||
yield from leaf._flatten( | |||
recursive=recursive, | |||
with_key=with_key, | |||
with_parent=with_parent, | |||
prefix=_prefix + expanded_key if with_key else None, | |||
predicate=predicate, | |||
seen=seen, | |||
) | |||
def parameters( | |||
self, requires_grad: Optional[bool] = None, recursive: bool = True, **kwargs | |||
) -> Iterable[Parameter]: | |||
r"""Returns an iterable for the :class:`~.Parameter` of the module. | |||
:param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` | |||
attribute of returned :class:`.Parameter`. ``None`` for no limitation. | |||
:param recursive: If ``True``, returns all :class:`~.Parameter` within this | |||
module, else only returns :class:`~.Parameter` that are direct attributes | |||
of this module. | |||
""" | |||
def predicate(obj) -> bool: | |||
return _is_parameter(obj) and ( | |||
requires_grad is None or obj.requires_grad == requires_grad | |||
) | |||
yield from self._flatten( | |||
with_key=False, predicate=predicate, recursive=recursive, **kwargs | |||
) | |||
def named_parameters( | |||
self, | |||
requires_grad: Optional[bool] = None, | |||
prefix: Optional[str] = None, | |||
recursive: bool = True, | |||
**kwargs | |||
) -> Iterable[Tuple[str, Parameter]]: | |||
"""Returns an iterable for key :class:`~.Parameter` pairs of the module, where | |||
``key`` is the dotted path from this module to the :class:`~.Parameter` . | |||
:param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` | |||
attribute of returned :class:`~.Parameter` . ``None`` for no limitation. | |||
:param prefix: The prefix prepended to the keys. | |||
:param recursive: If ``True``, returns all :class:`~.Parameter` within this | |||
module, else only returns :class:`~.Parameter` that are direct attributes | |||
of this module. | |||
""" | |||
def predicate(obj) -> bool: | |||
return _is_parameter(obj) and ( | |||
requires_grad is None or obj.requires_grad == requires_grad | |||
) | |||
yield from self._flatten( | |||
with_key=True, | |||
prefix=prefix, | |||
predicate=predicate, | |||
recursive=recursive, | |||
**kwargs, | |||
) | |||
def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Buffer]: | |||
"""Returns an iterable for the :class:`~.Buffer` of the module. | |||
:param recursive: If ``True``, returns all :class:`~.Buffer` within this | |||
module, else only returns :class:`~.Buffer` that are direct attributes | |||
of this module. | |||
""" | |||
yield from self._flatten( | |||
with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs | |||
) | |||
def named_buffers( | |||
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs | |||
) -> Iterable[Tuple[str, Buffer]]: | |||
"""Returns an iterable for key :class:`~.Buffer` pairs of the module, where | |||
``key`` is the dotted path from this module to the :class:`~.Buffer` . | |||
:param prefix: The prefix prepended to the keys. | |||
:param recursive: If ``True``, returns all :class:`~.Buffer` within this | |||
module, else only returns :class:`~.Buffer` that are direct attributes | |||
of this module. | |||
""" | |||
yield from self._flatten( | |||
with_key=True, | |||
prefix=prefix, | |||
predicate=_is_buffer, | |||
recursive=recursive, | |||
**kwargs, | |||
) | |||
def children(self, **kwargs) -> "Iterable[Module]": | |||
"""Returns an iterable for all the submodules that are direct attributes of this | |||
module. | |||
""" | |||
yield from self._flatten( | |||
with_key=False, predicate=_is_module, recursive=False, **kwargs | |||
) | |||
def named_children(self, **kwargs) -> "Iterable[Tuple[str, Module]]": | |||
"""Returns an iterable of key-submodule pairs for all the submodules that are | |||
direct attributes of this module, where 'key' is the attribute name of | |||
submodules. | |||
""" | |||
yield from self._flatten( | |||
with_key=True, predicate=_is_module, recursive=False, **kwargs | |||
) | |||
def modules(self, **kwargs) -> "Iterable[Module]": | |||
"""Returns an iterable for all the modules within this module, including itself. | |||
""" | |||
if "with_parent" in kwargs and kwargs["with_parent"]: | |||
yield self, None | |||
else: | |||
yield self | |||
yield from self._flatten(with_key=False, predicate=_is_module, **kwargs) | |||
def named_modules( | |||
self, prefix: Optional[str] = None, **kwargs | |||
) -> "Iterable[Tuple[str, Module]]": | |||
"""Returns an iterable of key-module pairs for all the modules within this | |||
module, including itself, where 'key' is the dotted path from this module to the | |||
submodules. | |||
:param prefix: The prefix prepended to the path. | |||
""" | |||
if "with_parent" in kwargs and kwargs["with_parent"]: | |||
yield ("" if prefix is None else prefix), self, None | |||
else: | |||
yield ("" if prefix is None else prefix), self | |||
yield from self._flatten( | |||
with_key=True, prefix=prefix, predicate=_is_module, **kwargs | |||
) | |||
def apply(self, fn: "Callable[[Module], Any]") -> None: | |||
"""Apply function ``fn`` to all the modules within this module, including | |||
itself. | |||
:param fn: The function to be applied on modules. | |||
""" | |||
for it in self.modules(): | |||
fn(it) | |||
def zero_grad(self) -> None: | |||
"""Set all parameters' grads to zero | |||
""" | |||
for param in self.parameters(): | |||
if param.grad is not None: | |||
param.grad.reset_zero() | |||
def train(self, mode: bool = True, recursive: bool = True) -> None: | |||
"""Set training mode of all the modules within this module (including itself) to | |||
``mode``. This effectively sets the ``training`` attributes of those modules | |||
to ``mode``, but only has effect on certain modules (e.g. | |||
:class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`) | |||
:param mode: the training mode to be set on modules. | |||
:param recursive: whether to recursively call submodules' ``train()``. | |||
""" | |||
if not recursive: | |||
self.training = mode | |||
return | |||
def fn(module: Module) -> None: | |||
module.train(mode, recursive=False) | |||
self.apply(fn) | |||
def eval(self) -> None: | |||
"""Set training mode of all the modules within this module (including itself) to | |||
``False``. See :meth:`~.Module.train` for details. | |||
""" | |||
self.train(False) | |||
def disable_quantize(self, value=True): | |||
r""" | |||
Set the ``quantize_disabled`` attribute to ``value`` on this module and all its submodules. | |||
""" | |||
def fn(module: Module) -> None: | |||
module.quantize_disabled = value | |||
self.apply(fn) | |||
def replace_param( | |||
self, params: dict, start_pos: int, seen: Optional[Set[int]] = None | |||
): | |||
"""Replace module's parameters with `params`, used by :class:`~.ParamPack` to | |||
speedup multimachine training. | |||
""" | |||
offset = 0 | |||
if seen is None: | |||
seen = set([id(self)]) | |||
module_dict = vars(self) | |||
for key in sorted(module_dict): | |||
hash_id = id(module_dict[key]) | |||
if hash_id in seen: | |||
continue | |||
seen.add(hash_id) | |||
if isinstance(module_dict[key], Parameter): | |||
if start_pos + offset in params: | |||
assert module_dict[key].shape == params[start_pos + offset].shape | |||
module_dict[key] = params[start_pos + offset] | |||
offset += 1 | |||
if isinstance(module_dict[key], Module): | |||
offset += module_dict[key].replace_param( | |||
params, start_pos + offset, seen | |||
) | |||
return offset | |||
def state_dict(self, rst=None, prefix="", keep_var=False): | |||
r"""Returns a dictionary containing whole states of the module. | |||
""" | |||
def is_state(obj): | |||
return _is_parameter(obj) or _is_buffer(obj) | |||
if rst is None: | |||
rst = OrderedDict() | |||
for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state): | |||
assert prefix + k not in rst, "duplicated state: {}".format(k) | |||
if keep_var: | |||
rst[prefix + k] = v | |||
else: | |||
rst[prefix + k] = v.numpy() | |||
for k, submodule in self._flatten( | |||
recursive=False, | |||
with_key=True, | |||
predicate=lambda obj: isinstance(obj, Module), | |||
): | |||
submodule.state_dict(rst, prefix + k + ".", keep_var) | |||
return rst | |||
def load_state_dict( | |||
self, | |||
state_dict: Union[dict, Callable[[str, Tensor], Optional[np.ndarray]]], | |||
strict=True, | |||
): | |||
r"""Load a given dictionary created by :func:`state_dict` into this module. | |||
If ``strict`` is ``True``, the keys in the given ``state_dict`` must exactly match the keys | |||
returned by this module's :func:`state_dict`. | |||
Users can also pass a closure: `Function[key: str, var: Tensor] -> Optional[np.ndarray]` | |||
as a `state_dict`, in order to handle complex situations. For example, load everything | |||
except for the final linear classifier: | |||
.. code-block:: | |||
state_dict = {...} # Dict[str, np.ndarray] | |||
model.load_state_dict({ | |||
k: None if k.startswith('fc') else v | |||
for k, v in state_dict.items() | |||
}, strict=False) | |||
Here returning `None` means skipping parameter `k`. | |||
To prevent shape mismatch (e.g. when loading PyTorch weights), we can reshape before loading: | |||
.. code-block:: | |||
state_dict = {...} | |||
def reshape_accordingly(k, v): | |||
return state_dict[k].reshape(v.shape) | |||
model.load_state_dict(reshape_accordingly) | |||
We can also perform inplace re-initialization or pruning: | |||
.. code-block:: | |||
def reinit_and_pruning(k, v): | |||
if 'bias' in k: | |||
M.init.zeros_(v) | |||
if 'conv' in k: | |||
return v.numpy() * (np.abs(v.numpy()) > 1e-3).astype("float32") | |||
model.load_state_dict(reinit_and_pruning, strict=False) | |||
""" | |||
unused = [] | |||
if isinstance(state_dict, dict): | |||
unused = state_dict.keys() | |||
def closure(k, _): # var unused | |||
return state_dict[k] if k in state_dict else None | |||
elif callable(state_dict): | |||
closure = state_dict | |||
else: | |||
raise ValueError( | |||
"`state_dict` must load a dict or callable, got {}".format( | |||
type(state_dict) | |||
) | |||
) | |||
loaded, skipped = self._load_state_dict_with_closure(closure) | |||
unused = set(unused) - loaded | |||
if len(unused) != 0: | |||
if strict: | |||
raise KeyError( | |||
"Unused params violate `strict=True`, unused={}".format(unused) | |||
) | |||
else: | |||
logger.warning( | |||
"Unused params in `strict=False` mode, unused={}".format(unused) | |||
) | |||
if len(skipped) != 0: | |||
if strict: | |||
raise KeyError( | |||
"Missing params violate `strict=True`, missing={}".format(skipped) | |||
) | |||
else: | |||
logger.warning( | |||
"Missing params in `strict=False` mode, missing={}".format(skipped) | |||
) | |||
def _load_state_dict_with_closure(self, closure): | |||
"""Advance state_dict load through callable `closure` whose signature is | |||
`closure(key: str, var: Tensor) -> Union[np.ndarry, None]` | |||
""" | |||
assert callable(closure), "closure must be a function" | |||
loaded = [] | |||
skipped = [] | |||
local_state_dict = self.state_dict(keep_var=True) | |||
for k, var in local_state_dict.items(): | |||
to_be_load = closure(k, var) | |||
if to_be_load is None: | |||
skipped.append(k) | |||
continue | |||
assert isinstance( | |||
to_be_load, np.ndarray | |||
), "closure should return a `np.ndarray`, now `{}` get {}".format( | |||
k, to_be_load | |||
) | |||
assert ( | |||
var.shape == to_be_load.shape | |||
), "param `{}` shape mismatch, should be {}, get {}".format( | |||
k, var.shape, to_be_load.shape | |||
) | |||
# For quantized dtype, the initialized dtype | |||
# scale/zero_points maybe invalid, use pretrained dtype instead. | |||
if is_quantize(to_be_load.dtype) and is_quantize(var.dtype): | |||
var = var.astype(to_be_load.dtype) | |||
var.set_value(to_be_load) | |||
loaded.append(k) | |||
return set(loaded), set(skipped) |
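A sketch of the hook machinery defined above (the hook itself is illustrative):

import numpy as np
import megengine as mge
import megengine.module as M

net = M.Linear(4, 4)

def double_inputs(module, inputs):
    # forward pre-hook: return a tuple (or a single value) to replace the inputs
    return tuple(i * 2 for i in inputs)

handle = net.register_forward_pre_hook(double_inputs)
y = net(mge.tensor(np.ones((1, 4), dtype="float32")))
handle.remove()  # detach the hook via the returned HookHandler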
@@ -0,0 +1,156 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
from typing import Callable, Iterable, Optional, Tuple | |||
import numpy as np | |||
from ..tensor_nn import Parameter, Tensor | |||
from .module import Module | |||
class ParamPack(Module): | |||
r"""Pack module's parameters by gathering their memory to continuous address. | |||
Using (device, dtype, requires_grad) as key, for example ('gpu0', float32, True), | |||
parameters with same key will be packed togather. | |||
It helps a lot for multimachine training by speeding up allreduce gradients. | |||
:param model: the module you want to pack parameters. | |||
:param nr_ignore_first: how many parameters will be unpacked at first. | |||
:param max_size_per_group: upper bound of packed parameters' size in MB. | |||
:param max_nr_params_per_group: upper bound of the number of parameters of each group. | |||
""" | |||
def __init__( | |||
self, | |||
model: Module, | |||
nr_ignore_first: int = 8, | |||
max_size_per_group: int = 10, | |||
max_nr_params_per_group: int = 100, | |||
group_func: Callable = lambda name, param: 0, | |||
): | |||
super().__init__() | |||
self._model = model | |||
self._nr_ignore_first = nr_ignore_first | |||
self._max_size_per_group = max_size_per_group | |||
self._max_nr_params_per_group = max_nr_params_per_group | |||
self._group_func = group_func | |||
self._grouped_params = [] | |||
self._packed_params = [] | |||
params = model.named_parameters() | |||
self._pack_params(params) | |||
def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]: | |||
for param in self._packed_params: | |||
if requires_grad is None or param.requires_grad == requires_grad: | |||
yield param | |||
def named_parameters( | |||
self, requires_grad: Optional[bool] = None | |||
) -> Iterable[Tuple[str, Parameter]]: | |||
for idx, param in enumerate(self._packed_params): | |||
if requires_grad is None or param.requires_grad == requires_grad: | |||
yield "packed_param_" + str(idx), param | |||
def _pack_params(self, params: Iterable[Tuple[str, Parameter]]): | |||
groups = collections.defaultdict(list) | |||
ignored = 0 | |||
param_id = 0 | |||
for name, param in params: | |||
if self._nr_ignore_first > ignored: | |||
ignored += 1 | |||
self._grouped_params.append([{"shape": param.shape, "id": param_id}]) | |||
param.pack_group_key = self._group_func(name, param) | |||
self._packed_params.append(param) | |||
else: | |||
key = ( | |||
param.dtype, | |||
param.device, | |||
param.requires_grad, | |||
self._group_func(name, param), | |||
) | |||
groups[key].append({"tensor": param, "id": param_id}) | |||
param_id += 1 | |||
for (dtype, device, requires_grad, group_key) in groups.keys(): | |||
dtype_sz = np.dtype(dtype).itemsize | |||
align = device.mem_align | |||
if align < dtype_sz: | |||
align = 1 | |||
else: | |||
assert align % dtype_sz == 0 | |||
align //= dtype_sz | |||
group = groups[(dtype, device, requires_grad, group_key)] | |||
while group: | |||
aligned_pos = [] | |||
offset = 0 | |||
params = [] | |||
idx = 0 | |||
while idx < len(group): | |||
param = group[idx] | |||
assert param["tensor"].device == device | |||
padding = (align - (offset & (align - 1))) & (align - 1)  # round offset up to alignment (bit trick assumes align is a power of two) | |||
offset += padding | |||
aligned_pos.append(offset) | |||
params.append(param) | |||
offset += int(np.prod(param["tensor"].shape)) | |||
idx += 1 | |||
if ( | |||
offset * dtype_sz >= self._max_size_per_group * 1024 * 1024 | |||
or idx >= self._max_nr_params_per_group | |||
): | |||
break | |||
group = group[idx:] | |||
if idx == 1: | |||
# ignore param packs with only one item | |||
params[0]["tensor"].pack_group_key = group_key | |||
self._packed_params.append(params[0]["tensor"]) | |||
self._grouped_params.append( | |||
[{"shape": params[0]["tensor"].shape, "id": params[0]["id"]}] | |||
) | |||
continue | |||
packed_value = np.zeros((offset,), dtype=dtype) | |||
for param, pos in zip(params, aligned_pos): | |||
val = param["tensor"].numpy() | |||
packed_value[pos : pos + val.size] = val.flatten() | |||
new_param = Parameter( | |||
value=packed_value, | |||
device=device, | |||
dtype=dtype, | |||
requires_grad=requires_grad, | |||
) | |||
new_param.pack_group_key = group_key | |||
self._packed_params.append(new_param) | |||
self._grouped_params.append( | |||
[{"shape": i["tensor"].shape, "id": i["id"]} for i in params] | |||
) | |||
def forward(self, *args, **kwargs): | |||
replace_param = dict() | |||
for i in range(len(self._packed_params)): | |||
packed_param = self._packed_params[i] | |||
grouped_params = self._grouped_params[i] | |||
if len(grouped_params) == 1: | |||
continue | |||
# split the packed symvar back into the original per-parameter shapes | |||
split = param_pack_split( | |||
packed_param._symvar, [i["shape"] for i in grouped_params] | |||
) | |||
split = [ | |||
Parameter(Tensor(i, requires_grad=packed_param.requires_grad)) | |||
for i in split | |||
] | |||
for j in range(len(split)): | |||
replace_param[grouped_params[j]["id"]] = split[j] | |||
self._model.replace_param(replace_param, 0) | |||
return self._model.forward(*args, **kwargs) |
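A sketch of wrapping a model (assuming ``ParamPack`` is exported from ``megengine.module``; the model and grouping function are illustrative, and packing details depend on the device's memory alignment):

import megengine.module as M
from megengine.module import ParamPack

model = M.Linear(4, 4)
packed = ParamPack(model, nr_ignore_first=0,
                   group_func=lambda name, param: "bias" in name)
for name, p in packed.named_parameters():
    print(name, p.shape)  # parameters fused per (dtype, device, requires_grad, group key)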
@@ -0,0 +1,80 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from abc import abstractmethod | |||
from typing import Tuple, Union | |||
from ..functional import avg_pool2d, max_pool2d | |||
from .module import Module | |||
class _PoolNd(Module): | |||
def __init__( | |||
self, | |||
kernel_size: Union[int, Tuple[int, int]], | |||
stride: Union[int, Tuple[int, int]] = None, | |||
padding: Union[int, Tuple[int, int]] = 0, | |||
): | |||
super(_PoolNd, self).__init__() | |||
self.kernel_size = kernel_size | |||
self.stride = stride or kernel_size | |||
self.padding = padding | |||
@abstractmethod | |||
def forward(self, inp): | |||
pass | |||
class MaxPool2d(_PoolNd): | |||
r"""Applies a 2D max pooling over an input. | |||
For instance, given an input of the size :math:`(N, C, H, W)` and | |||
:attr:`kernel_size` :math:`(kH, kW)`, this layer generates the output of | |||
the size :math:`(N, C, H_{out}, W_{out})` through a process described as: | |||
.. math:: | |||
\begin{aligned} | |||
out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} | |||
\text{input}(N_i, C_j, \text{stride[0]} \times h + m, | |||
\text{stride[1]} \times w + n) | |||
\end{aligned} | |||
If :attr:`padding` is non-zero, then the input is implicitly zero-padded on | |||
both sides for :attr:`padding` number of points. | |||
:param kernel_size: the size of the window to take a max over. | |||
:param stride: the stride of the window. Default value is ``kernel_size``. | |||
:param padding: implicit zero padding to be added on both sides. | |||
""" | |||
def forward(self, inp): | |||
return max_pool2d(inp, self.kernel_size, self.stride, self.padding) | |||
class AvgPool2d(_PoolNd): | |||
r"""Applies a 2D average pooling over an input. | |||
For instance, given an input of the size :math:`(N, C, H, W)` and | |||
:attr:`kernel_size` :math:`(kH, kW)`, this layer generates the output of | |||
the size :math:`(N, C, H_{out}, W_{out})` through a process described as: | |||
.. math:: | |||
out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} | |||
input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n) | |||
If :attr:`padding` is non-zero, then the input is implicitly zero-padded on | |||
both sides for :attr:`padding` number of points. | |||
:param kernel_size: the size of the window. | |||
:param stride: the stride of the window. Default value is ``kernel_size``. | |||
:param padding: implicit zero padding to be added on both sides. | |||
""" | |||
def forward(self, inp): | |||
return avg_pool2d(inp, self.kernel_size, self.stride, self.padding) |
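A sketch of the output-size arithmetic, H_out = (H + 2*padding - kernel_size) // stride + 1 (shapes illustrative):

import numpy as np
import megengine as mge
import megengine.module as M

pool = M.MaxPool2d(kernel_size=2)  # stride defaults to kernel_size
x = mge.tensor(np.random.randn(1, 1, 4, 4).astype("float32"))
y = pool(x)  # H_out = (4 + 0 - 2)//2 + 1 = 2, so y.shape == (1, 1, 2, 2)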
@@ -0,0 +1,14 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .concat import Concat | |||
from .conv import Conv2d, ConvRelu2d | |||
from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
from .elemwise import Elemwise | |||
from .linear import Linear | |||
from .module import QATModule | |||
from .quant_dequant import DequantStub, QuantStub |
@@ -0,0 +1,30 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Iterable | |||
from ...tensor import Tensor | |||
from .. import concat as Float | |||
from .module import QATModule | |||
class Concat(Float.Concat, QATModule): | |||
r""" | |||
A :class:`~.QATModule` that performs functional concat with QAT support. | |||
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||
""" | |||
def forward(self, inps: Iterable[Tensor], axis: int = 0): | |||
return self.apply_quant_activation(super().forward(inps, axis)) | |||
@classmethod | |||
def from_float_module(cls, float_module): | |||
r""" | |||
Return a :class:`~.QATModule` instance converted from | |||
a float :class:`~.Module` instance. | |||
""" | |||
return cls() |
@@ -0,0 +1,59 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ... import functional as F | |||
from ...quantization.utils import fake_quant_bias | |||
from .. import conv as Float | |||
from .module import QATModule | |||
class Conv2d(Float.Conv2d, QATModule): | |||
r""" | |||
A :class:`~.QATModule` Conv2d with QAT support. | |||
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||
""" | |||
def calc_conv_qat(self, inp): | |||
w_qat = self.apply_quant_weight(self.weight) | |||
b_qat = fake_quant_bias(self.bias, inp, w_qat) | |||
conv = self.calc_conv(inp, w_qat, b_qat) | |||
return conv | |||
@classmethod | |||
def from_float_module(cls, float_module: Float.Conv2d): | |||
r""" | |||
Return a :class:`~.QATModule` instance converted from | |||
a float :class:`~.Module` instance. | |||
""" | |||
qat_module = cls( | |||
float_module.in_channels, | |||
float_module.out_channels, | |||
float_module.kernel_size, | |||
float_module.stride, | |||
float_module.padding, | |||
float_module.dilation, | |||
float_module.groups, | |||
float_module.bias is not None, | |||
float_module.conv_mode.name, | |||
float_module.compute_mode.name, | |||
) | |||
qat_module.weight = float_module.weight | |||
qat_module.bias = float_module.bias | |||
return qat_module | |||
def forward(self, inp): | |||
return self.apply_quant_activation(self.calc_conv_qat(inp)) | |||
class ConvRelu2d(Conv2d): | |||
r""" | |||
A :class:`~.QATModule` that fuses Conv2d and ReLU with QAT support. | |||
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||
""" | |||
def forward(self, inp): | |||
return self.apply_quant_activation(F.relu(self.calc_conv_qat(inp))) |
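A sketch of converting a float module to its QAT counterpart via ``from_float_module`` (the ``QAT`` alias assumes the ``qat`` package ``__init__`` shown above re-exports ``Conv2d``):

import megengine.module as M
import megengine.module.qat as QAT

float_conv = M.Conv2d(3, 8, kernel_size=3, padding=1)
qat_conv = QAT.Conv2d.from_float_module(float_conv)  # shares weight/bias with the float module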
@@ -0,0 +1,193 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ...functional import add_update, ones, relu, sqrt, sum, zeros | |||
from ...quantization.utils import fake_quant_bias | |||
from .. import conv_bn as Float | |||
from .module import QATModule | |||
class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): | |||
def get_batch_mean_var(self, inp): | |||
def _sum_channel(inp, axis=0, keepdims=True): | |||
if isinstance(axis, int): | |||
out = sum(inp, axis=axis, keepdims=keepdims) | |||
elif isinstance(axis, tuple): | |||
for idx, elem in enumerate(axis): | |||
out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims) | |||
return out | |||
sum1 = _sum_channel(inp, (0, 2, 3)) | |||
sum2 = _sum_channel(inp ** 2, (0, 2, 3)) | |||
reduce_size = inp.size / inp.shape[1] | |||
batch_mean = sum1 / reduce_size | |||
batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size | |||
return batch_mean, batch_var | |||
def fold_weight_bias(self, bn_mean, bn_var): | |||
# get fold bn conv param | |||
# bn_istd = 1 / bn_std | |||
# w_fold = gamma / bn_std * W | |||
# b_fold = gamma * (b - bn_mean) / bn_std + beta | |||
gamma = self.bn.weight | |||
if gamma is None: | |||
gamma = ones((self.bn.num_features), dtype="float32") | |||
gamma = gamma.reshape(1, -1, 1, 1) | |||
beta = self.bn.bias | |||
if beta is None: | |||
beta = zeros((self.bn.num_features), dtype="float32") | |||
beta = beta.reshape(1, -1, 1, 1) | |||
if bn_mean is None: | |||
bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32") | |||
if bn_var is None: | |||
bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32") | |||
conv_bias = self.conv.bias | |||
if conv_bias is None: | |||
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32") | |||
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||
# bn_istd = 1 / bn_std | |||
# w_fold = gamma / bn_std * W | |||
scale_factor = gamma * bn_istd | |||
if self.conv.groups == 1: | |||
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1) | |||
else: | |||
w_fold = self.conv.weight * scale_factor.reshape( | |||
self.conv.groups, -1, 1, 1, 1 | |||
) | |||
w_fold = self.apply_quant_weight(w_fold) | |||
# b_fold = gamma * (b - bn_mean) / bn_std + beta | |||
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd | |||
return w_fold, b_fold | |||
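# Derivation of the folding above, restating the comments as one chain
# (with bn_istd = 1 / sqrt(bn_var + eps)):
#     BN(conv(x)) = gamma * (W * x + b - bn_mean) * bn_istd + beta
#                 = (gamma * bn_istd * W) * x
#                   + (gamma * bn_istd * (b - bn_mean) + beta)
#                 = w_fold * x + b_fold
# so a single convolution with (w_fold, b_fold) reproduces conv + BN.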
def update_running_mean_and_running_var( | |||
self, bn_mean, bn_var, num_elements_per_channel | |||
): | |||
# update the running mean and running var; gradients are detached and
# the batch variance is rescaled to its unbiased estimate
bn_mean = bn_mean.detach() | |||
bn_var = ( | |||
bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1) | |||
) | |||
exponential_average_factor = 1 - self.bn.momentum | |||
add_update( | |||
self.bn.running_mean, | |||
delta=bn_mean, | |||
alpha=1 - exponential_average_factor, | |||
beta=exponential_average_factor, | |||
) | |||
add_update( | |||
self.bn.running_var, | |||
delta=bn_var, | |||
alpha=1 - exponential_average_factor, | |||
beta=exponential_average_factor, | |||
) | |||
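# Note on the update above: add_update computes
#     running = alpha * running + beta * delta
# and here alpha = 1 - exponential_average_factor = self.bn.momentum and
# beta = exponential_average_factor = 1 - self.bn.momentum, so this is
# the usual exponential moving average
#     running = momentum * running + (1 - momentum) * batch_stat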
def calc_conv_bn_qat(self, inp, approx=True): | |||
if self.training and not approx: | |||
conv = self.conv(inp) | |||
bn_mean, bn_var = self.get_batch_mean_var(conv) | |||
num_elements_per_channel = conv.size / conv.shape[1] | |||
self.update_running_mean_and_running_var( | |||
bn_mean, bn_var, num_elements_per_channel | |||
) | |||
else: | |||
bn_mean, bn_var = self.bn.running_mean, self.bn.running_var | |||
# gamma and beta of the BatchNorm; default to ones/zeros when the
# affine parameters are absent
gamma = self.bn.weight | |||
if gamma is None: | |||
gamma = ones(self.bn.num_features, dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1) | |||
beta = self.bn.bias | |||
if beta is None: | |||
beta = zeros(self.bn.num_features, dtype="float32")
beta = beta.reshape(1, -1, 1, 1) | |||
# conv_bias | |||
conv_bias = self.conv.bias | |||
if conv_bias is None: | |||
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32") | |||
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||
# bn_istd = 1 / bn_std | |||
# w_fold = gamma / bn_std * W | |||
scale_factor = gamma * bn_istd | |||
if self.conv.groups == 1: | |||
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1) | |||
else: | |||
w_fold = self.conv.weight * scale_factor.reshape( | |||
self.conv.groups, -1, 1, 1, 1 | |||
) | |||
b_fold = None | |||
if not (self.training and approx): | |||
# b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta | |||
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd | |||
w_qat = self.apply_quant_weight(w_fold) | |||
b_qat = fake_quant_bias(b_fold, inp, w_qat) | |||
conv = self.conv.calc_conv(inp, w_qat, b_qat) | |||
if not (self.training and approx): | |||
return conv | |||
# rescale conv to get original conv output | |||
orig_conv = conv / scale_factor.reshape(1, -1, 1, 1) | |||
if self.conv.bias is not None: | |||
orig_conv = orig_conv + self.conv.bias | |||
# calculate batch norm | |||
bn_mean, bn_var = self.get_batch_mean_var(orig_conv) | |||
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||
conv = gamma * bn_istd * (orig_conv - bn_mean) + beta | |||
num_elements_per_channel = conv.size / conv.shape[1] | |||
self.update_running_mean_and_running_var( | |||
bn_mean, bn_var, num_elements_per_channel | |||
) | |||
return conv | |||
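# Summary of the two paths above (an informal reading of the code):
# * exact path (training with approx=False, or evaluation): fold the
#   BN statistics into the conv exactly as in fold_weight_bias,
#   fake-quantize (w_fold, b_fold) and run a single convolution.
# * approximate training path (approx=True, the default): fold with
#   the *running* statistics, run the quantized conv, divide by
#   scale_factor (and re-add the conv bias) to recover the raw conv
#   output, then compute *batch* statistics on it and apply a true
#   batch norm, so the training statistics stay exact while the
#   quantized weight still sees the folded scale.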
@classmethod | |||
def from_float_module(cls, float_module: Float._ConvBnActivation2d): | |||
r""" | |||
Return a :class:`~.QATModule` instance converted from | |||
a float :class:`~.Module` instance. | |||
""" | |||
qat_module = cls( | |||
float_module.conv.in_channels, | |||
float_module.conv.out_channels, | |||
float_module.conv.kernel_size, | |||
float_module.conv.stride, | |||
float_module.conv.padding, | |||
float_module.conv.dilation, | |||
float_module.conv.groups, | |||
float_module.conv.bias is not None, | |||
float_module.conv.conv_mode.name, | |||
float_module.conv.compute_mode.name, | |||
) | |||
qat_module.conv.weight = float_module.conv.weight | |||
qat_module.conv.bias = float_module.conv.bias | |||
qat_module.bn = float_module.bn | |||
return qat_module | |||
class ConvBn2d(_ConvBnActivation2d): | |||
r""" | |||
A fused :class:`~.QATModule` including Conv2d and BatchNorm2d with QAT support.
Can be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
""" | |||
def forward(self, inp): | |||
return self.apply_quant_activation(self.calc_conv_bn_qat(inp)) | |||
class ConvBnRelu2d(_ConvBnActivation2d): | |||
r""" | |||
A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu with QAT support. | |||
Can be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
""" | |||
def forward(self, inp): | |||
return self.apply_quant_activation(relu(self.calc_conv_bn_qat(inp))) |
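# Usage sketch for the fused modules (hedged: the import paths are
# assumed from the package layout and may differ in the real tree):
#
#     import megengine.module as Float
#     from megengine.module.qat import ConvBn2d
#
#     float_m = Float.ConvBn2d(3, 16, kernel_size=3)  # float conv + bn
#     qat_m = ConvBn2d.from_float_module(float_m)     # shares conv/bn params
#     # attach Observer / FakeQuantize (e.g. via the quantization
#     # helpers) before QAT training or calibration.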