
Merge pull request #435 from MegEngine/try-import

tags/v1.9.0
XindaH (GitHub) committed 3 years ago
parent commit ea91babbce
18 changed files with 235 additions and 39 deletions
  1. +3  -0   .gitattributes
  2. +0  -2   .github/workflows/ci.yml
  3. +1  -1   README.md
  4. +3  -3   README_CN.md
  5. +1  -2   ci/cmake.sh
  6. +54 -0   dnn/src/aarch64/relayout/opr_impl.cpp
  7. +3  -0   dnn/test/aarch64/relayout.cpp
  8. +3  -1   imperative/python/megengine/device.py
  9. +22 -18  imperative/python/megengine/functional/math.py
  10. +19 -10 imperative/python/megengine/module/init.py
  11. +28 -1  imperative/python/test/unit/module/test_init.py
  12. +15 -0  lite/include/lite/global.h
  13. +16 -1  lite/lite-c/include/lite-c/global_c.h
  14. +15 -0  lite/lite-c/src/global.cpp
  15. +20 -0  lite/pylite/megenginelite/global_setting.py
  16. +0  -0  lite/pylite/test/test_network_device.py (renamed from lite/pylite/test/test_network_cuda.py)
  17. +31 -0  lite/src/global.cpp
  18. +1  -0  lite/test/test_network.cpp

.gitattributes (+3 -0)

@@ -21,3 +21,6 @@ ci/resource/prof/model_with_err_assert.mdl filter=lfs diff=lfs merge=lfs -text
 ci/resource/prof/test_mge.mge filter=lfs diff=lfs merge=lfs -text
 lite/test/resource/lite/ax_models/64-58063ce2.axe filter=lfs diff=lfs merge=lfs -text
 imperative/python/test/unit/module/MagicMindRuntimeOprTest.GraphShapeMutable.mlu filter=lfs diff=lfs merge=lfs -text
+lite/test/resource/lite/ax_data_input.npy filter=lfs diff=lfs merge=lfs -text
+lite/test/resource/lite/ax_data_output.npy filter=lfs diff=lfs merge=lfs -text
+lite/test/resource/lite/ax_model.mge filter=lfs diff=lfs merge=lfs -text

.github/workflows/ci.yml (+0 -2)

@@ -29,7 +29,6 @@ jobs:
         uses: actions/checkout@v2
       - name: Checkout submodules
         run: |
-          apt update&&apt install ninja-build
           ./third_party/prepare.sh
           ./third_party/install-mkl.sh
       - name: Build MegEngine
@@ -58,7 +57,6 @@ jobs:
         uses: actions/checkout@v2
       - name: Checkout submodules
         run: |
-          apt update&&apt install ninja-build
           ./third_party/prepare.sh
           ./third_party/install-mkl.sh
       - name: Build MegEngine


README.md (+1 -1)

@@ -12,7 +12,7 @@ MegEngine is a fast, scalable and easy-to-use deep learning framework, with auto
 
 ## Installation
 
-**NOTE:** MegEngine now supports Python installation on Linux-64bit/Windows-64bit/MacOS(CPU-Only)-10.14+/Android 7+(CPU-Only) platforms with Python from 3.5 to 3.8. On Windows 10 you can either install the Linux distribution through [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl) or install the Windows distribution directly. Many other platforms are supported for inference.
+**NOTE:** MegEngine now supports Python installation on Linux-64bit/Windows-64bit/MacOS(CPU-Only)-10.14+ platforms with Python from 3.5 to 3.8. On Windows 10 you can either install the Linux distribution through [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl) or install the Windows distribution directly. Many other platforms are supported for inference.
 
 ### Binaries




README_CN.md (+3 -3)

@@ -13,7 +13,7 @@ MegEngine 是一个快速、可拓展、易于使用且支持自动求导的深
 
 ## 安装说明
 
-**注意:** MegEngine 现在支持在 Linux-64bit/Windows-64bit/macos-10.14/Android 7+ 及其以上 (MacOS/Android只支持cpu) 等平台上安装 Python 包,支持Python3.5 到 Python3.8。对于 Windows 10 用户,可以通过安装 [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl) 进行体验,同时我们也原生支持Windows。MegEngine 也支持在很多其它平台上进行推理运算。
+**注意:** MegEngine 现在支持在 Linux-64bit/Windows-64bit/macos-10.14及其以上 (MacOS只支持cpu) 等平台上安装 Python 包,支持Python3.5 到 Python3.8。对于 Windows 10 用户,可以通过安装 [Windows Subsystem for Linux (WSL)](https://docs.microsoft.com/en-us/windows/wsl) 进行体验,同时我们也原生支持Windows。MegEngine 也支持在很多其它平台上进行推理运算。
 
 ### 通过包管理器安装
 
@@ -26,8 +26,8 @@ python3 -m pip install megengine -f https://megengine.org.cn/whl/mge.html
 
 ## 通过源码编译安装
 
-* CMake 编译细节请参考 [BUILD_README.md](scripts/cmake-build/BUILD_README.md)
-* Python 绑定编译细节请参考 [BUILD_PYTHON_WHL_README.md](scripts/whl/BUILD_PYTHON_WHL_README.md)
+* CMake编译细节请参考 [BUILD_README.md](scripts/cmake-build/BUILD_README.md)
+* Python绑定编译细节请参考 [BUILD_PYTHON_WHL_README.md](scripts/whl/BUILD_PYTHON_WHL_README.md)
 
 ## 如何参与贡献




ci/cmake.sh (+1 -2)

@@ -27,8 +27,7 @@ function build() {
         -DMGE_WITH_DISTRIBUTED=${DMGE_WITH_DISTRIBUTED} \
         -DMGE_WITH_CUDA=${DMGE_WITH_CUDA} \
         -DMGE_WITH_TEST=ON \
-        -DCMAKE_BUILD_TYPE=RelWithDebInfo \
-        -DMGE_WITH_CUSTOM_OP=ON
+        -DCMAKE_BUILD_TYPE=RelWithDebInfo
     make -j$(($(nproc) * 2)) -I ${build_dir}
     make develop
     popd >/dev/null


dnn/src/aarch64/relayout/opr_impl.cpp (+54 -0)

@@ -363,6 +363,58 @@ static inline void trans_8x4_u16(
     vst1q_u16(dst_ptr + 3 * dst_step, row_3);
 }
 
+static inline void trans_8x3_u16(
+        const void* src, void* dst, const size_t src_step, const size_t dst_step) {
+    uint16_t* src_ptr = (uint16_t*)src;
+    uint16_t* dst_ptr = (uint16_t*)dst;
+    uint16x4_t src0 = vld1_u16(src_ptr + 0 * src_step);  // A0A1A2A3
+    uint16x4_t src1 = vld1_u16(src_ptr + 1 * src_step);  // B0B1B2B3
+    uint16x4_t src2 = vld1_u16(src_ptr + 2 * src_step);  // C0C1C2C3
+    uint16x4_t src3 = vld1_u16(src_ptr + 3 * src_step);  // D0D1D2D3
+    uint16x4_t src4 = vld1_u16(src_ptr + 4 * src_step);  // E0E1E2E3
+    uint16x4_t src5 = vld1_u16(src_ptr + 5 * src_step);  // F0F1F2F3
+    uint16x4_t src6 = vld1_u16(src_ptr + 6 * src_step);  // G0G1G2G3
+    // H0H1H2
+    uint16x4_t src7 =
+            vreinterpret_u16_u32(vld1_dup_u32((uint32_t*)(src_ptr + 7 * src_step)));
+    src7 = vld1_lane_u16(src_ptr + 7 * src_step + 2, src7, 2);
+
+    uint16x4_t ab_low = vzip1_u16(src0, src1);   // A0B0A1B1
+    uint16x4_t ab_high = vzip2_u16(src0, src1);  // A2B2A3B3
+    uint16x4_t cd_low = vzip1_u16(src2, src3);   // C0D0C1D1
+    uint16x4_t cd_high = vzip2_u16(src2, src3);  // C2D2C3D3
+    uint16x4_t ef_low = vzip1_u16(src4, src5);   // E0F0E1F1
+    uint16x4_t ef_high = vzip2_u16(src4, src5);  // E2F2E3F3
+    uint16x4_t gh_low = vzip1_u16(src6, src7);   // G0H0G1H1
+    uint16x4_t gh_high = vzip2_u16(src6, src7);  // G2H2G3
+
+    uint16x4_t abcd_0 = vreinterpret_u16_u32(vzip1_u32(
+            vreinterpret_u32_u16(ab_low),
+            vreinterpret_u32_u16(cd_low)));  // A0B0C0D0
+    uint16x4_t abcd_1 = vreinterpret_u16_u32(vzip2_u32(
+            vreinterpret_u32_u16(ab_low),
+            vreinterpret_u32_u16(cd_low)));  // A1B1C1D1
+    uint16x4_t abcd_2 = vreinterpret_u16_u32(vzip1_u32(
+            vreinterpret_u32_u16(ab_high),
+            vreinterpret_u32_u16(cd_high)));  // A2B2C2D2
+    uint16x4_t efgh_0 = vreinterpret_u16_u32(vzip1_u32(
+            vreinterpret_u32_u16(ef_low),
+            vreinterpret_u32_u16(gh_low)));  // E0F0G0H0
+    uint16x4_t efgh_1 = vreinterpret_u16_u32(vzip2_u32(
+            vreinterpret_u32_u16(ef_low),
+            vreinterpret_u32_u16(gh_low)));  // E1F1G1H1
+    uint16x4_t efgh_2 = vreinterpret_u16_u32(vzip1_u32(
+            vreinterpret_u32_u16(ef_high),
+            vreinterpret_u32_u16(gh_high)));  // E2F2G2H2
+
+    uint16x8_t row_0 = vcombine_u16(abcd_0, efgh_0);
+    uint16x8_t row_1 = vcombine_u16(abcd_1, efgh_1);
+    uint16x8_t row_2 = vcombine_u16(abcd_2, efgh_2);
+
+    vst1q_u16(dst_ptr + 0 * dst_step, row_0);
+    vst1q_u16(dst_ptr + 1 * dst_step, row_1);
+    vst1q_u16(dst_ptr + 2 * dst_step, row_2);
+}
 }  // anonymous namespace
 
 namespace megdnn {
@@ -410,6 +462,8 @@ void transpose_block<Transpose2Byte>(
         const size_t dst_stride, size_t block_h, size_t block_w) {
     if (block_h == 8 && block_w == 4) {
         trans_8x4_u16(src, dst, src_stride, dst_stride);
+    } else if (block_h == 8 && block_w == 3) {
+        trans_8x3_u16(src, dst, src_stride, dst_stride);
    } else {
         transpose_block_fallback(src, dst, src_stride, dst_stride, block_h, block_w);
     }
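For readers who don't speak NEON, the new kernel is a plain 8x3 block transpose: the zip/combine cascade moves column r of the 8x3 source block into row r of the destination. A NumPy sketch of the same semantics (the helper name and flat-buffer layout are illustrative, not part of the patch):

    import numpy as np

    def trans_8x3_u16_ref(src, dst, src_step, dst_step):
        # Gather the 8 source rows (A..H in the comments above); the
        # intrinsics load 4 lanes per row but only 3 are meaningful.
        block = np.stack([src[r * src_step : r * src_step + 3] for r in range(8)])
        # Row r of the output is A_r B_r C_r D_r E_r F_r G_r H_r.
        for r in range(3):
            dst[r * dst_step : r * dst_step + 8] = block[:, r]

    src = np.arange(32, dtype=np.uint16)  # 8 rows, src_step = 4
    dst = np.zeros(24, dtype=np.uint16)   # 3 rows, dst_step = 8
    trans_8x3_u16_ref(src, dst, src_step=4, dst_step=8)

Note the guarded load of row H: vld1_dup_u32 plus vld1_lane_u16 reads exactly three uint16 values, so the kernel never touches memory past H2.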


dnn/test/aarch64/relayout.cpp (+3 -0)

@@ -40,6 +40,9 @@ TEST_F(AARCH64, Relayout) {
         TensorLayout dst({1, 54, 112, 256}, {1548288, 28672, 256, 1}, dtype);
         checker.execl({src, dst});
     }
+    TensorLayout src_4_3({1, 3, 112, 256}, {3, 1, 1024, 4}, dtype::Uint16());
+    TensorLayout dst_4_3({1, 3, 112, 256}, {86016, 28672, 256, 1}, dtype::Uint16());
+    checker.execl({src_4_3, dst_4_3});
 }
 
 TEST_F(AARCH64, RelayoutNonContig) {
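The new case exercises the 8x3 path: the source is a strided uint16 view that relayout copies into a contiguous tensor. A NumPy sketch of the same copy, assuming MegDNN strides are in elements (the backing-buffer size here is only illustrative):

    import numpy as np
    from numpy.lib.stride_tricks import as_strided

    # Source layout: shape (1, 3, 112, 256), element strides (3, 1, 1024, 4).
    buf = np.zeros(112 * 1024 + 255 * 4 + 3, dtype=np.uint16)  # large enough backing store
    src = as_strided(buf, shape=(1, 3, 112, 256),
                     strides=tuple(2 * s for s in (3, 1, 1024, 4)))  # NumPy strides are in bytes
    dst = np.ascontiguousarray(src)
    print([s // 2 for s in dst.strides])  # [86016, 28672, 256, 1], as in dst_4_3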


imperative/python/megengine/device.py (+3 -1)

@@ -50,7 +50,9 @@ _sh = _stream_helper()
 
 
 def _valid_device(inp):
-    if isinstance(inp, str) and re.match("^([cxg]pu|rocm)(\d+|\d+:\d+|x)$", inp):
+    if isinstance(inp, str) and re.match(
+        "^([cxg]pu|rocm|multithread)(\d+|\d+:\d+|x)$", inp
+    ):
         return True
     return False
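A quick sketch of what the widened pattern accepts, with an illustrative set of device names:

    import re

    pattern = r"^([cxg]pu|rocm|multithread)(\d+|\d+:\d+|x)$"
    for name in ["cpu0", "gpu1:2", "xpux", "rocm0", "multithread4", "cuda0"]:
        print(name, bool(re.match(pattern, name)))
    # "multithread4" is newly accepted; "cuda0" still does not match.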




imperative/python/megengine/functional/math.py (+22 -18)

@@ -1153,35 +1153,39 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
 
 
 def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
-    r"""Computes the singular value decompositions of input matrix.
+    r"""Returns a singular value decomposition ``A = USVh`` of a matrix (or a stack of matrices) ``x``, where ``U`` is a matrix (or a stack of matrices) with orthonormal columns, ``S`` is a vector of non-negative numbers (or a stack of vectors), and ``Vh`` is a matrix (or a stack of matrices) with orthonormal rows.
 
     Args:
-        inp: input matrix, must has shape `[..., M, N]`.
+        x (Tensor): an input real tensor having the shape ``(..., M, N)`` with ``x.ndim >= 2``.
+        full_matrices (bool, optional): if ``False``, ``U`` and ``Vh`` have the shapes ``(..., M, K)`` and ``(..., K, N)``, respectively, where ``K = min(M, N)``. If ``True``, the shapes are ``(..., M, M)`` and ``(..., N, N)``, respectively. Default: ``False``.
+        compute_uv (bool, optional): whether or not to compute ``U`` and ``Vh`` in addition to ``S``. Default: ``True``.
+
+    Note:
+        * naive does not support ``full_matrices`` and ``compute_uv`` as ``True``.
 
     Returns:
-        output matrices, `(U, sigma, V)`.
+        a tuple (``U``, ``S``, ``Vh``) of the SVD factors of the input matrix ``x`` (``U`` and ``Vh`` are only returned when ``compute_uv`` is ``True``).
+        ``U`` contains matrices with orthonormal columns (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True``, the array has shape ``(..., M, M)``; if ``full_matrices`` is ``False``, the array has shape ``(..., M, K)``, where ``K = min(M, N)``.
 
     Examples:
-        .. testcode::
-
-            import numpy as np
-            from megengine import tensor
-            import megengine.functional as F
-
-            x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3))
-            _, y, _ = F.svd(x)
-            print(y.numpy().round(decimals=3))
+        >>> import numpy as np
+        >>> x = Tensor(np.random.randn(9, 6))
+        >>> y = Tensor(np.random.randn(2, 7, 8, 3))
 
-        Outputs:
-
-        .. testoutput::
+        Reconstruction based on reduced SVD, 2D case:
+        >>> U, S, Vh = F.svd(x, full_matrices=False)
+        >>> print(U._tuple_shape, S._tuple_shape, Vh._tuple_shape)
+        (9, 6) (6,) (6, 6)
 
-            [7.348 1. ]
+        Reconstruction based on reduced SVD, 4D case:
+        >>> u, s, vh = F.svd(y, full_matrices=False)
+        >>> print(u._tuple_shape, s._tuple_shape, vh._tuple_shape)
+        (2, 7, 8, 3) (2, 7, 3) (2, 7, 3, 3)
     """
     op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
-    U, sigma, V = apply(op, inp)
-    return U, sigma, V
+    U, S, Vh = apply(op, inp)
+    return U, S, Vh
 
 
 def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
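The documented shapes follow the same convention as NumPy's np.linalg.svd, so they can be sanity-checked without MegEngine (a sketch, not part of the patch):

    import numpy as np

    # With full_matrices=False: U is (..., M, K), S is (..., K), Vh is (..., K, N),
    # where K = min(M, N), and A is recovered as U @ diag(S) @ Vh.
    x = np.random.randn(9, 6)
    U, S, Vh = np.linalg.svd(x, full_matrices=False)
    print(U.shape, S.shape, Vh.shape)    # (9, 6) (6,) (6, 6)
    print(np.allclose(x, (U * S) @ Vh))  # True: reconstruction A = U S Vh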


imperative/python/megengine/module/init.py (+19 -10)

@@ -74,7 +74,7 @@
 ) -> float:
     r"""Returns a recommended gain value (see the table below) for the given nonlinearity
     function.
 
     ================= ====================================================
     nonlinearity      gain
     ================= ====================================================
@@ -126,6 +126,11 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
     r"""Calculates fan_in / fan_out value for given weight tensor. This function assumes
     input tensor is stored in ``NCHW`` format.
 
+    Note:
+        The group conv2d kernel shape in MegEngine is ``(G, O/G, I/G, K, K)``. This
+        function calculates ``fan_out = O/G * K * K`` by default, whereas PyTorch uses
+        ``fan_out = O * K * K``.
+
     Args:
         tensor: weight tensor in ``NCHW`` format.
     """
@@ -141,6 +146,10 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
         fan_in = shape[1]
         fan_out = shape[0]
     else:
+        if ndim >= 5:
+            # ignore the groups dimension of group conv2d and group conv3d
+            # FIXME: will be wrong for conv3d
+            shape = shape[1:]
         num_input_fmaps = shape[1]
         num_output_fmaps = shape[0]
         receptive_field_size = 1
@@ -154,7 +163,7 @@
 def calculate_correct_fan(tensor: Tensor, mode: str) -> float:
     r"""Calculates fan_in / fan_out value for given weight tensor, depending on given
     ``mode``.
 
     See :func:`calculate_fan_in_and_fan_out` for details.
 
     Args:
@@ -175,11 +184,11 @@
 def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
     r"""Fills tensor with random values sampled from :math:`\mathcal{U}(-a, a)`
     where
 
     .. math::
 
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}}
 
     Also known as Glorot initialization. Detailed information can be retrieved from
     `Understanding the difficulty of training deep feedforward neural networks` -
     Glorot, X. & Bengio, Y. (2010).
@@ -197,11 +206,11 @@
 def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None:
     r"""Fills tensor with random values sampled from
     :math:`\mathcal{N}(0, \text{std}^2)` where
 
     .. math::
 
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}}
 
     Also known as Glorot initialization. Detailed information can be retrieved from
     `Understanding the difficulty of training deep feedforward neural networks` -
     Glorot, X. & Bengio, Y. (2010).
@@ -220,11 +229,11 @@ def msra_uniform_(
 ) -> None:
     r"""Fills tensor with random values sampled from
     :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
 
     .. math::
 
        \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}}
 
     Detailed information can be retrieved from
     `Delving deep into rectifiers: Surpassing human-level performance on ImageNet
     classification`
@@ -251,11 +260,11 @@ def msra_normal_(
 ) -> None:
     r"""Fills tensor with random values sampled from
     :math:`\mathcal{N}(0, \text{std}^2)` where
 
     .. math::
 
        \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}}
 
     Detailed information can be retrieved from
     `Delving deep into rectifiers: Surpassing human-level performance on ImageNet
     classification`
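A minimal sketch of the fan computation the two hunks above implement; the helper name is hypothetical, and the group-conv weight shape follows the ``(G, O/G, I/G, K, K)`` convention from the Note:

    def fan_in_and_fan_out(shape):
        # For group conv kernels (ndim >= 5) the leading groups dimension is
        # dropped, so the fans come from the per-group (O/G, I/G, K, K).
        if len(shape) >= 5:
            shape = shape[1:]
        receptive_field_size = 1
        for s in shape[2:]:
            receptive_field_size *= s
        return shape[1] * receptive_field_size, shape[0] * receptive_field_size

    # Conv2d(in_channels=2, out_channels=4, kernel_size=(5, 7), groups=2)
    # has weight shape (2, 2, 1, 5, 7):
    print(fan_in_and_fan_out((2, 2, 1, 5, 7)))  # (35, 70), matching the new test below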


imperative/python/test/unit/module/test_init.py (+28 -1)

@@ -10,7 +10,7 @@ import numpy as np
 import pytest
 
 from megengine import tensor
-from megengine.module import Conv2d, Linear
+from megengine.module import Conv1d, Conv2d, Conv3d, Linear
 from megengine.module.init import calculate_fan_in_and_fan_out, fill_
 
 
@@ -32,7 +32,34 @@ def test_calculate_fan_in_and_fan_out():
     with pytest.raises(ValueError):
         calculate_fan_in_and_fan_out(l.bias)
 
+    l = Conv1d(in_channels=2, out_channels=3, kernel_size=5)
+    fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    assert fanin == 2 * 5
+    assert fanout == 3 * 5
+
+    # FIXME: will be wrong for group conv1d
+    # l = Conv1d(in_channels=2, out_channels=4, kernel_size=5, groups=2)
+    # fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    # assert fanin == 2 // 2 * 5
+    # assert fanout == 4 // 2 * 5
+
     l = Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7))
     fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
     assert fanin == 2 * 5 * 7
     assert fanout == 3 * 5 * 7
+
+    l = Conv2d(in_channels=2, out_channels=4, kernel_size=(5, 7), groups=2)
+    fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    assert fanin == 2 // 2 * 5 * 7
+    assert fanout == 4 // 2 * 5 * 7
+
+    # FIXME: will be wrong for conv3d
+    # l = Conv3d(in_channels=2, out_channels=3, kernel_size=(5, 7, 9))
+    # fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    # assert fanin == 2 * 5 * 7 * 9
+    # assert fanout == 3 * 5 * 7 * 9
+
+    l = Conv3d(in_channels=2, out_channels=4, kernel_size=(5, 7, 9), groups=2)
+    fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    assert fanin == 2 // 2 * 5 * 7 * 9
+    assert fanout == 4 // 2 * 5 * 7 * 9

lite/include/lite/global.h (+15 -0)

@@ -154,6 +154,21 @@ LITE_API void set_tensor_rt_cache(std::string tensorrt_cache_path);
 */
 LITE_API void dump_tensor_rt_cache();
 
+/**
+ * register the physical and virtual address pair to the mge, some devices
+ * need the map from physical to virtual.
+ */
+LITE_API bool register_memory_pair(
+        void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device,
+        LiteBackend backend = LiteBackend::LITE_DEFAULT);
+
+/**
+ * clear the physical and virtual address pair in mge.
+ */
+LITE_API bool clear_memory_pair(
+        void* vir_ptr, void* phy_ptr, LiteDeviceType device,
+        LiteBackend backend = LiteBackend::LITE_DEFAULT);
+
 }  // namespace lite
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

lite/lite-c/include/lite-c/global_c.h (+16 -1)

@@ -160,9 +160,24 @@ LITE_API int LITE_dump_persistent_cache(const char* cache_path);
 * \brief dump the tensorrt policy cache to file
 */
 LITE_API int LITE_dump_tensor_rt_cache();
+#endif
+
+/**
+ * register the physical and virtual address pair to the mge, some devices
+ * need the map from physical to virtual.
+ */
+LITE_API int LITE_register_memory_pair(
+        void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device,
+        LiteBackend backend);
+
+/**
+ * clear the physical and virtual address pair in mge.
+ */
+LITE_API int LITE_clear_memory_pair(
+        void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend);
+
 #ifdef __cplusplus
 }
 #endif
-#endif
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

lite/lite-c/src/global.cpp (+15 -0)

@@ -189,4 +189,19 @@ int LITE_dump_tensor_rt_cache() {
     LITE_CAPI_END();
 }
 
+int LITE_register_memory_pair(
+        void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device,
+        LiteBackend backend) {
+    LITE_CAPI_BEGIN();
+    lite::register_memory_pair(vir_ptr, phy_ptr, length, device, backend);
+    LITE_CAPI_END();
+}
+
+int LITE_clear_memory_pair(
+        void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend) {
+    LITE_CAPI_BEGIN();
+    lite::clear_memory_pair(vir_ptr, phy_ptr, device, backend);
+    LITE_CAPI_END();
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

lite/pylite/megenginelite/global_setting.py (+20 -0)

@@ -42,6 +42,8 @@ class _GlobalAPI(_LiteCObjBase):
         # ('LITE_set_tensor_rt_cache', [c_char_p]),
         ("LITE_dump_persistent_cache", [c_char_p]),
         ("LITE_dump_tensor_rt_cache", [c_char_p]),
+        ("LITE_register_memory_pair", [c_void_p, c_void_p, c_size_t, c_int, c_int]),
+        ("LITE_clear_memory_pair", [c_void_p, c_void_p, c_int, c_int]),
     ]
 
 
@@ -121,3 +123,21 @@ class LiteGlobal(object):
     @staticmethod
     def try_coalesce_all_free_memory():
         LiteGlobal._api.LITE_try_coalesce_all_free_memory()
+
+    @staticmethod
+    def register_memory_pair(
+        vir_ptr, phy_ptr, length, device, backend=LiteBackend.LITE_DEFAULT
+    ):
+        assert isinstance(vir_ptr, c_void_p) and isinstance(
+            phy_ptr, c_void_p
+        ), "register memory pair only accepts c_void_p type."
+        LiteGlobal._api.LITE_register_memory_pair(
+            vir_ptr, phy_ptr, length, device, backend
+        )
+
+    @staticmethod
+    def clear_memory_pair(vir_ptr, phy_ptr, device, backend=LiteBackend.LITE_DEFAULT):
+        assert isinstance(vir_ptr, c_void_p) and isinstance(
+            phy_ptr, c_void_p
+        ), "clear memory pair only accepts c_void_p type."
+        LiteGlobal._api.LITE_clear_memory_pair(vir_ptr, phy_ptr, device, backend)
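A hypothetical usage sketch of the new wrappers (the pointer values are placeholders; per lite/src/global.cpp below, the current backends still throw, so this only illustrates the intended call shape):

    from ctypes import c_void_p

    from megenginelite import LiteDeviceType, LiteGlobal

    vir = c_void_p(0x10000000)  # virtual address, placeholder value
    phy = c_void_p(0x20000000)  # physical address, placeholder value
    LiteGlobal.register_memory_pair(vir, phy, 4096, LiteDeviceType.LITE_CPU)
    LiteGlobal.clear_memory_pair(vir, phy, LiteDeviceType.LITE_CPU)

Both pointers must already be wrapped in c_void_p, as the assertions above enforce.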

lite/pylite/test/test_network_cuda.py → lite/pylite/test/test_network_device.py (renamed, +0 -0)


lite/src/global.cpp (+31 -0)

@@ -212,6 +212,26 @@ void lite::dump_tensor_rt_cache() {
 #endif
 }
 
+bool lite::register_memory_pair(
+        void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device,
+        LiteBackend backend) {
+    LITE_MARK_USED_VAR(vir_ptr);
+    LITE_MARK_USED_VAR(phy_ptr);
+    LITE_MARK_USED_VAR(length);
+    LITE_MARK_USED_VAR(device);
+    LITE_MARK_USED_VAR(backend);
+    LITE_THROW("register_memory_pair is not implemented yet!");
+}
+
+bool lite::clear_memory_pair(
+        void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend) {
+    LITE_MARK_USED_VAR(vir_ptr);
+    LITE_MARK_USED_VAR(phy_ptr);
+    LITE_MARK_USED_VAR(device);
+    LITE_MARK_USED_VAR(backend);
+    LITE_THROW("clear_memory_pair is not implemented yet!");
+}
+
 #else  // LITE_BUILD_WITH_MGE
 void lite::try_coalesce_all_free_memory() {}
 
@@ -235,6 +255,17 @@ void lite::set_tensor_rt_cache(std::string) {
 void lite::dump_tensor_rt_cache() {
     LITE_THROW("mge is disabled at build time, please build with mge");
 }
+
+bool lite::register_memory_pair(
+        void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device,
+        LiteBackend backend) {
+    LITE_THROW("register_memory_pair is not implemented yet!");
+}
+
+bool lite::clear_memory_pair(
+        void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend) {
+    LITE_THROW("clear_memory_pair is not implemented yet!");
+}
 #endif
 namespace lite {
 REGIST_DECRYPTION_METHOD(


lite/test/test_network.cpp (+1 -0)

@@ -1357,5 +1357,6 @@ TEST(TestNetWork, CambriconDeviceID) {
     load_device_id(LiteDeviceType::LITE_CAMBRICON, 0, "./model_magicmind.mgb");
 }
 #endif
+
 #endif
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
