
feat(mge): add bfloat16 support

GitOrigin-RevId: a942ce6791
tags/v0.5.0
Megvii Engine Team Xinran Xu 5 years ago
parent commit 0293d58ade
100 changed files with 4979 additions and 308 deletions
  1. +11 -0 dnn/include/megdnn/dtype.h
  2. +2965 -0 dnn/include/megdnn/dtype/bfloat16.hpp
  3. +2 -171 dnn/include/megdnn/dtype/half.hpp
  4. +48 -0 dnn/include/megdnn/dtype/half_common_epilogue.h
  5. +202 -0 dnn/include/megdnn/dtype/half_common_prologue.h
  6. +2 -2 dnn/scripts/gen_cond_take_kern_impls.py
  7. +2 -2 dnn/scripts/gen_elemwise_kern_impls.py
  8. +2 -2 dnn/scripts/gen_elemwise_special_kern_impls.py
  9. +2 -1 dnn/scripts/gen_elemwise_utils.py
  10. +6 -4 dnn/src/common/convolution.cpp
  11. +2 -1 dnn/src/common/elemwise/kern_defs.cuh
  12. +13 -12 dnn/src/common/matrix_mul.cpp
  13. +8 -0 dnn/src/common/rounding_converter.cuh
  14. +54 -0 dnn/src/common/utils.h
  15. +41 -36 dnn/src/common/warp_perspective.cpp
  16. +29 -0 dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu
  17. +7 -0 dnn/src/cuda/conv_bias/algo.cpp
  18. +25 -1 dnn/src/cuda/conv_bias/algo.h
  19. +120 -0 dnn/src/cuda/conv_bias/bfloat16.cpp
  20. +4 -0 dnn/src/cuda/conv_bias/chanwise.cpp
  21. +4 -0 dnn/src/cuda/conv_bias/chanwise_small.cpp
  22. +4 -0 dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
  23. +4 -0 dnn/src/cuda/conv_bias/group_conv.cpp
  24. +5 -0 dnn/src/cuda/conv_bias/helper.cpp
  25. +4 -0 dnn/src/cuda/conv_bias/matmul.cpp
  26. +20 -7 dnn/src/cuda/conv_bias/opr_impl.cpp
  27. +1 -0 dnn/src/cuda/conv_bias/opr_impl.h
  28. +15 -7 dnn/src/cuda/convolution/backward_data/algo.cpp
  29. +32 -9 dnn/src/cuda/convolution/backward_data/algo.h
  30. +115 -0 dnn/src/cuda/convolution/backward_data/bfloat16.cpp
  31. +4 -0 dnn/src/cuda/convolution/backward_data/chanwise.cpp
  32. +4 -0 dnn/src/cuda/convolution/backward_data/chanwise_small.cpp
  33. +4 -0 dnn/src/cuda/convolution/backward_data/group_conv.cpp
  34. +4 -0 dnn/src/cuda/convolution/backward_data/matmul.cpp
  35. +16 -11 dnn/src/cuda/convolution/backward_filter/algo.cpp
  36. +29 -6 dnn/src/cuda/convolution/backward_filter/algo.h
  37. +117 -0 dnn/src/cuda/convolution/backward_filter/bfloat16.cpp
  38. +4 -0 dnn/src/cuda/convolution/backward_filter/chanwise.cpp
  39. +4 -0 dnn/src/cuda/convolution/backward_filter/group_conv.cpp
  40. +4 -0 dnn/src/cuda/convolution/backward_filter/matmul.cpp
  41. +4 -0 dnn/src/cuda/convolution/helper.cpp
  42. +1 -0 dnn/src/cuda/convolution/helper.h
  43. +54 -28 dnn/src/cuda/convolution/opr_impl.cpp
  44. +9 -6 dnn/src/cuda/convolution/opr_impl.h
  45. +1 -1 dnn/src/cuda/convolution3d/forward/algo.h
  46. +17 -0 dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu
  47. +17 -0 dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu
  48. +17 -0 dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu
  49. +17 -0 dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu
  50. +17 -0 dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu
  51. +17 -0 dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu
  52. +17 -0 dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu
  53. +17 -0 dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu
  54. +17 -0 dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu
  55. +17 -0 dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu
  56. +17 -0 dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu
  57. +17 -0 dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu
  58. +17 -0 dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu
  59. +17 -0 dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu
  60. +17 -0 dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu
  61. +17 -0 dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu
  62. +17 -0 dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu
  63. +17 -0 dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu
  64. +17 -0 dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu
  65. +17 -0 dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu
  66. +17 -0 dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu
  67. +17 -0 dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu
  68. +17 -0 dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu
  69. +17 -0 dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu
  70. +17 -0 dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu
  71. +17 -0 dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu
  72. +17 -0 dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu
  73. +17 -0 dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu
  74. +17 -0 dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu
  75. +17 -0 dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu
  76. +17 -0 dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu
  77. +17 -0 dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu
  78. +17 -0 dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu
  79. +17 -0 dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu
  80. +17 -0 dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu
  81. +17 -0 dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu
  82. +17 -0 dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu
  83. +17 -0 dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu
  84. +17 -0 dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu
  85. +17 -0 dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu
  86. +17 -0 dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu
  87. +17 -0 dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu
  88. +17 -0 dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu
  89. +17 -0 dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu
  90. +17 -0 dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu
  91. +17 -0 dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu
  92. +17 -0 dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu
  93. +17 -0 dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu
  94. +18 -0 dnn/src/cuda/elemwise/special_kimpl/special_dt_bfloat16.cu
  95. +3 -0 dnn/src/cuda/elemwise_helper.cpp
  96. +12 -0 dnn/src/cuda/elemwise_helper.cuh
  97. +5 -0 dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu
  98. +4 -0 dnn/src/cuda/matrix_mul/algos.cpp
  99. +22 -1 dnn/src/cuda/matrix_mul/algos.h
  100. +91 -0 dnn/src/cuda/matrix_mul/bfloat16.cpp

+11 -0 dnn/include/megdnn/dtype.h

@@ -29,6 +29,7 @@
#define MEGDNN_FLOAT16_SELECT(_x, _y) _y
#else
#include "megdnn/dtype/half.hpp"
#include "megdnn/dtype/bfloat16.hpp"
#define MEGDNN_INC_FLOAT16(_x) _x
#define MEGDNN_FLOAT16_SELECT(_x, _y) _x
#endif
@@ -49,6 +50,7 @@ namespace megdnn {
cb(IntB4) \
cb(Byte) \
MEGDNN_INC_FLOAT16(cb(Float16)) \
MEGDNN_INC_FLOAT16(cb(BFloat16)) \
cb(UintB4) \

/*!
@@ -62,6 +64,7 @@ namespace megdnn {
cb(Int32) \
cb(Byte) \
MEGDNN_INC_FLOAT16(cb(Float16)) \
MEGDNN_INC_FLOAT16(cb(BFloat16)) \

/*!
* \brief iterate through each fractional byte dtype
@@ -101,6 +104,7 @@ namespace megdnn {
#define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \
cb(::megdnn::dtype::Float32) \
MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \
MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16)) \

/*!
* \brief iterate through each dtype object that can be involved in integer
@@ -345,6 +349,7 @@ typedef int16_t dt_int16;
typedef int8_t dt_int8;
typedef uint8_t dt_uint8;
MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)

#define MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE 100000
#if MEGDNN_CC_HOST
@@ -367,6 +372,9 @@ MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
Float16,
#endif
UintB4 = 10,
#if !MEGDNN_DISABLE_FLOAT16
BFloat16 = 11,
#endif

#define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE,
#define D(_name) _name,
@@ -702,6 +710,9 @@ MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX);
MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED,
std::numeric_limits<dt_float16>::lowest(),
std::numeric_limits<dt_float16>::max()));
MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(BFloat16, dt_bfloat16, FLOAT, SIGNED,
std::numeric_limits<dt_bfloat16>::lowest(),
std::numeric_limits<dt_bfloat16>::max()));

template <>
struct DTypeTrait<dtype::Byte> {
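For readers skimming the diff: bfloat16 keeps float32's sign bit and 8-bit exponent and truncates the mantissa from 23 bits to 7, so conversion is pure bit manipulation. A minimal standalone sketch of that conversion (the helper names here are made up; the vendored header below is the authoritative implementation, including NaN handling, which this sketch omits):

#include <cstdint>
#include <cstring>

// float32 -> bfloat16 bits with round-to-nearest-even (NaN handling omitted).
inline uint16_t float_to_bf16_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    uint32_t rounding_bias = 0x7fffu + ((u >> 16) & 1u);  // ties to even
    return static_cast<uint16_t>((u + rounding_bias) >> 16);
}

// bfloat16 bits -> float32 is exact: restore the low mantissa bits as zeros.
inline float bf16_bits_to_float(uint16_t b) {
    uint32_t u = static_cast<uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}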


+2965 -0 dnn/include/megdnn/dtype/bfloat16.hpp (file diff suppressed because it is too large)
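Since the new header's diff is suppressed, a rough sketch of its shape may help: like the half.hpp it is modeled on, it defines a trivially copyable class with 16-bit storage whose arithmetic round-trips through float. The following is an illustration under that assumption, reusing the bit helpers sketched above; it is not the real class, which additionally supplies rounding-mode control, numeric_limits, cmath overloads, user literals, and CUDA device qualifiers:

namespace half_bfloat16 {
class bfloat16 {
    uint16_t m_bits;  // sign(1) | exponent(8) | mantissa(7)
public:
    bfloat16() = default;
    explicit bfloat16(float f) : m_bits(float_to_bf16_bits(f)) {}
    operator float() const { return bf16_bits_to_float(m_bits); }
    bfloat16& operator+=(bfloat16 rhs) {
        return *this = bfloat16(static_cast<float>(*this) +
                                static_cast<float>(rhs));
    }
};
inline bfloat16 operator+(bfloat16 a, bfloat16 b) {
    // arithmetic promotes to float and rounds once on the way back
    return bfloat16(static_cast<float>(a) + static_cast<float>(b));
}
}  // namespace half_bfloat16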


+2 -171 dnn/include/megdnn/dtype/half.hpp

@@ -50,167 +50,7 @@
#include <hip/hip_fp16.h>
#endif

/// Combined gcc version number.
#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)

//check C++11 language features
#if defined(__clang__) //clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
/*#elif defined(__INTEL_COMPILER) //Intel C++
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ????????
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ????????
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ????????
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ????????
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif*/
#elif defined(__GNUC__) //gcc
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#endif
#elif defined(_MSC_VER) //Visual C++
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#define HALF_POP_WARNINGS 1
#pragma warning(push)
//! 4521 and 4522 are multiple copy/assignment operators specified
#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned
#endif

//check C++11 library features
#include <utility>
#if defined(_LIBCPP_VERSION) //libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#elif defined(__GLIBCXX__) //libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#endif
#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++
#if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#endif
#endif
#undef HALF_GNUC_VERSION

//support constexpr
#if HALF_ENABLE_CPP11_CONSTEXPR
#define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr
#else
#define HALF_CONSTEXPR
#define HALF_CONSTEXPR_CONST const
#endif

//support noexcept
#if HALF_ENABLE_CPP11_NOEXCEPT
#define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept
#else
#define HALF_NOEXCEPT
#define HALF_NOTHROW throw()
#endif

#include <algorithm>
#include <limits>
#include <climits>
#include <cmath>
#include <cstring>
#if HALF_ENABLE_CPP11_TYPE_TRAITS
#include <type_traits>
#endif
#if HALF_ENABLE_CPP11_CSTDINT
#include <cstdint>
#endif
#if HALF_ENABLE_CPP11_HASH
#include <functional>
#endif

#include "megdnn/dtype/half_common_prologue.h"

/// Default rounding mode.
/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s and `float`s as well as
@@ -3141,16 +2981,7 @@ namespace std
#endif
}


#undef HALF_CONSTEXPR
#undef HALF_CONSTEXPR_CONST
#undef HALF_NOEXCEPT
#undef HALF_NOTHROW
#ifdef HALF_POP_WARNINGS
#pragma warning(pop)
#undef HALF_POP_WARNINGS
#endif

#include "megdnn/dtype/half_common_epilogue.h"
#endif

// vim: syntax=cpp.doxygen

+48 -0 dnn/include/megdnn/dtype/half_common_epilogue.h

@@ -0,0 +1,48 @@
/**
* half - IEEE 754-based half-precision floating point library.
*
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* Version 1.11.0
* \file
* Main header file for half precision functionality.
*
* --------------------------------------------------------------------------
* \file include/megdnn/dtype/half_common_epilogue.h
*
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
* This file has been modified by Megvii ("Megvii Modifications").
* All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
*
* --------------------------------------------------------------------------
*/

#undef HALF_CONSTEXPR
#undef HALF_CONSTEXPR_CONST
#undef HALF_NOEXCEPT
#undef HALF_NOTHROW
#ifdef HALF_POP_WARNINGS
#pragma warning(pop)
#undef HALF_POP_WARNINGS
#endif

// vim: syntax=cpp.doxygen

+202 -0 dnn/include/megdnn/dtype/half_common_prologue.h

@@ -0,0 +1,202 @@
/**
* half - IEEE 754-based half-precision floating point library.
*
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* Version 1.11.0
* \file
* Main header file for half precision functionality.
*
* --------------------------------------------------------------------------
* \file dnn/include/megdnn/dtype/half_common_prologue.h
*
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
* This file has been modified by Megvii ("Megvii Modifications").
* All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
*
* --------------------------------------------------------------------------
*/

#include "megdnn/arch.h"

/// Combined gcc version number.
#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)

//check C++11 language features
#if defined(__clang__) //clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
/*#elif defined(__INTEL_COMPILER) //Intel C++
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ????????
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ????????
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ????????
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ????????
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif*/
#elif defined(__GNUC__) //gcc
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#endif
#elif defined(_MSC_VER) //Visual C++
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#define HALF_POP_WARNINGS 1
#pragma warning(push)
//! 4521 and 4522 are multiple copy/assignment operators specified
#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned
#endif

//check C++11 library features
#include <utility>
#if defined(_LIBCPP_VERSION) //libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#elif defined(__GLIBCXX__) //libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#endif
#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++
#if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#endif
#endif
#undef HALF_GNUC_VERSION

//support constexpr
#if HALF_ENABLE_CPP11_CONSTEXPR
#define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr
#else
#define HALF_CONSTEXPR
#define HALF_CONSTEXPR_CONST const
#endif

//support noexcept
#if HALF_ENABLE_CPP11_NOEXCEPT
#define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept
#else
#define HALF_NOEXCEPT
#define HALF_NOTHROW throw()
#endif

#include <algorithm>
#include <limits>
#include <climits>
#include <cmath>
#include <cstring>
#if HALF_ENABLE_CPP11_TYPE_TRAITS
#include <type_traits>
#endif
#if HALF_ENABLE_CPP11_CSTDINT
#include <cstdint>
#endif
#if HALF_ENABLE_CPP11_HASH
#include <functional>
#endif

// vim: syntax=cpp.doxygen

+2 -2 dnn/scripts/gen_cond_take_kern_impls.py

@@ -30,7 +30,7 @@ def main():
w('// generated by gen_cond_take_kern_impls.py')
w('#include "../kern.inl"')
w('')
if dtype == 'dt_float16':
if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
w('#if !MEGDNN_DISABLE_FLOAT16')
w('namespace megdnn {')
w('namespace cuda {')
@@ -48,7 +48,7 @@ def main():
w('} // cond_take')
w('} // cuda')
w('} // megdnn')
if dtype == 'dt_float16':
if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
w('#endif')

print('generated {}'.format(fname))


+2 -2 dnn/scripts/gen_elemwise_kern_impls.py

@@ -34,7 +34,7 @@ def main():
w = lambda s: print(s, file=fout)
w('// generated by gen_elemwise_kern_impls.py')

if ctype == 'dt_float16':
if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
w('#if !MEGDNN_DISABLE_FLOAT16')

w('#define KERN_IMPL_MODE(cb) {}'.format(formode))
@@ -42,7 +42,7 @@ def main():
w('#define KERN_IMPL_CTYPE {}'.format(ctype))
w('#include "../kern_impl.inl"')

if ctype == 'dt_float16':
if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
w('#endif')

print('generated {}'.format(fname))


+2 -2 dnn/scripts/gen_elemwise_special_kern_impls.py

@@ -30,14 +30,14 @@ def main():
w = lambda s: print(s, file=fout)

w('// generated by gen_elemwise_special_kern_impls.py')
if dtype == 'dt_float16':
if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
w('#if !MEGDNN_DISABLE_FLOAT16')
w('#include "../special_kerns.inl"')
w('INST(::megdnn::dtype::{})'.format(DTYPES[dtype][0]))
w('#undef INST')
w('}')
w('}')
if dtype == 'dt_float16':
if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
w('#endif')

print('generated {}'.format(fname))


+2 -1 dnn/scripts/gen_elemwise_utils.py

@@ -6,7 +6,8 @@ DTYPES = {'dt_int32': ('Int32', 'INT'),
'dt_int8': ('Int8', 'INT'),
'dt_int16': ('Int16', 'INT'),
'dt_float32': ('Float32', 'FLOAT'),
'dt_float16': ('Float16', 'FLOAT')
'dt_float16': ('Float16', 'FLOAT'),
'dt_bfloat16': ('BFloat16', 'FLOAT')
}

MODES = {


+6 -4 dnn/src/common/convolution.cpp

@@ -618,9 +618,10 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src,
megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32
#if !MEGDNN_DISABLE_FLOAT16
|| src.enumv() == DTypeEnum::Float16
|| src.enumv() == DTypeEnum::BFloat16
#endif
,
"ComputeMode::FLOAT32 is only available for Float16 "
,
"ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
"input / output.");
}

@@ -1036,9 +1037,10 @@ void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff,
megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32
#if !MEGDNN_DISABLE_FLOAT16
|| filter.enumv() == DTypeEnum::Float16
|| filter.enumv() == DTypeEnum::BFloat16
#endif
,
"ComputeMode::FLOAT32 is only available for Float16 "
,
"ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
"input / output.");
}



+2 -1 dnn/src/common/elemwise/kern_defs.cuh

@@ -87,7 +87,8 @@ namespace megdnn {
//! define kernel for all float types
#define DEF_KERN_FLOAT(_mode, _imp) \
DEF_KERN(dt_float32, _mode, _imp); \
MEGDNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);)
MEGDNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \
MEGDNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);)

//! define kernel for all int types
#define DEF_KERN_INT(_mode, _imp) \


+13 -12 dnn/src/common/matrix_mul.cpp

@@ -69,11 +69,11 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A,
C = TensorLayout(TensorShape({A0, B1}), C.dtype);
} else {
auto do_deduce = [&](size_t pack_size) {
megdnn_assert(
A.ndim == 4 && B.ndim == 3,
"matmul requires input dimension to be A(4), B(3); get: %s %s",
A.TensorShape::to_string().c_str(),
B.TensorShape::to_string().c_str());
megdnn_assert(A.ndim == 4 && B.ndim == 3,
"matmul requires input dimension to be A(4), B(3); "
"get: %s %s",
A.TensorShape::to_string().c_str(),
B.TensorShape::to_string().c_str());
A0 = A.shape[0];
A1 = A.shape[1];
B0 = B.shape[0];
@@ -82,11 +82,11 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A,
std::swap(A0, A1);
if (m_param.transposeB)
std::swap(B0, B1);
megdnn_assert(
A1 == B0,
"shape mismatch in matmal: (transposed) A is (%zu,%zu,4,4), "
"(transposed) B is (%zu,%zu,4)",
A0, A1, B0, B1);
megdnn_assert(A1 == B0,
"shape mismatch in matmal: (transposed) A is "
"(%zu,%zu,4,4), "
"(transposed) B is (%zu,%zu,4)",
A0, A1, B0, B1);
C = TensorLayout(TensorShape({A0, B1, pack_size}), C.dtype);
};
do_deduce(pack_size(param().format));
@@ -172,8 +172,9 @@ void MatrixMulForward::check_exec(const TensorLayout& A, const TensorLayout& B,
}
megdnn_assert(param().compute_mode !=
Param::ComputeMode::FLOAT32 MEGDNN_INC_FLOAT16(
|| A.dtype == dtype::Float16()),
"ComputeMode::FLOAT32 is only available for Float16 "
|| A.dtype == dtype::Float16() ||
A.dtype == dtype::BFloat16()),
"ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
"input / output.");
auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);


+8 -0 dnn/src/common/rounding_converter.cuh

@@ -46,6 +46,14 @@ struct RoundingConverter<half_float::half> {
}
};

template <>
struct RoundingConverter<half_bfloat16::bfloat16> {
__host__ __device__ __forceinline__ half_bfloat16::bfloat16 operator()(
float x) const {
return static_cast<half_bfloat16::bfloat16>(x);
}
};

#endif // #ifdef MEGDNN_DISABLE_FLOAT16

template <>
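The new specialization lets generic CUDA kernels accumulate in float and narrow to bfloat16 only on store. A hypothetical kernel showing the intended call pattern (the kernel itself is invented; only RoundingConverter comes from this file, whose enclosing namespace is assumed to be megdnn::rounding):

template <typename ctype>
__global__ void axpy_kern(const ctype* x, ctype* y, float alpha, int n) {
    megdnn::rounding::RoundingConverter<ctype> cvt;
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        // accumulate in float regardless of the storage dtype ...
        float acc = alpha * static_cast<float>(x[i]) + static_cast<float>(y[i]);
        // ... and round once on the way out (static_cast for bfloat16).
        y[i] = cvt(acc);
    }
}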


+54 -0 dnn/src/common/utils.h

@@ -16,6 +16,7 @@
#include "megdnn/dtype.h"
#include "megdnn/handle.h"
#include "megdnn/thin/small_vector.h"
#include "megdnn/oprs/general.h"

#include "src/common/hash_ct.h"
#include "src/common/utils.cuh"
@@ -548,6 +549,59 @@ public:
std::string to_string() const;
};

/*!
* \brief helpers for oprs using typecvt between comp_type and dst_type
* \tparam SrcType src type
* \tparam CompType compute type, such as fp32 for conv
* \tparam DstType dst type
*/
template <typename SrcType, typename CompType, typename DstType = SrcType>
struct CompTypeCvter {
std::unique_ptr<TypeCvt> m_cvt_opr;
WorkspaceBundle* m_workspace_bundle;
size_t m_workspace_idx;
CompTypeCvter(Handle* handle, WorkspaceBundle* bundle)
: m_workspace_bundle(bundle), m_workspace_idx(0) {
megdnn_assert(
(DTypeTrait<SrcType>::enumv != DTypeTrait<CompType>::enumv &&
DTypeTrait<DstType>::enumv != DTypeTrait<CompType>::enumv),
"SrcType(%s) == CompType(%s) or DstType(%s) == CompType(%s) is "
"not "
"supportted.",
SrcType().name(), CompType().name(), DstType().name(),
CompType().name());
m_cvt_opr = handle->create_operator<TypeCvt>();
}

//! Convert tensor dtype from SrcType to CompType.
CompTypeCvter& src_to_comp_type(const TensorND& src, TensorND& comp) {
if (src.layout.dtype.enumv() == DTypeTrait<SrcType>::enumv) {
if (!comp.layout.dtype.valid() ||
comp.layout.dtype.enumv() != DTypeTrait<CompType>::enumv) {
comp.layout.dtype = CompType();
comp.layout.init_contiguous_stride();
comp.raw_ptr = m_workspace_bundle->get(m_workspace_idx++);
if (src.layout.ndim) {
m_cvt_opr->exec(src, comp);
}
}
}
return *this;
}

//! Convert tensor dtype from CompType to DstType.
CompTypeCvter& comp_to_dst_type(const TensorND& comp, const TensorND& dst) {
megdnn_assert(comp.layout.dtype.enumv() == DTypeTrait<CompType>::enumv);
if (dst.layout.dtype.enumv() == DTypeTrait<DstType>::enumv) {
m_cvt_opr->exec(comp, dst);
}
return *this;
}

Workspace workspace() {
return m_workspace_bundle->get_workspace(m_workspace_idx);
}
};
} // namespace megdnn

// vim: syntax=cpp.doxygen
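CompTypeCvter is the workhorse of this commit: every CUDA bfloat16 algorithm below is a thin wrapper that converts inputs into a float32 staging area, delegates to an existing fp32 algorithm, and converts the result back. A condensed sketch of that round trip (function and variable names here are illustrative; the real call sites are in the bfloat16.cpp files below):

void exec_bf16_via_f32(const TensorND& src, const TensorND& dst,
                       Handle* handle, WorkspaceBundle* bundle) {
    TensorND fsrc = src, fdst = dst;  // retyped to Float32 by the cvter
    CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(handle, bundle);
    // bf16 -> fp32 copies placed into the workspace bundle:
    cvter.src_to_comp_type(src, fsrc).src_to_comp_type(dst, fdst);
    // ... run the fp32 implementation on fsrc/fdst here, passing
    //     cvter.workspace() as its scratch space ...
    cvter.comp_to_dst_type(fdst, dst);  // fp32 -> bf16 writeback
}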

+41 -36 dnn/src/common/warp_perspective.cpp

@@ -55,17 +55,19 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str());

if (param().format == param::WarpPerspective::Format::NCHW) {
megdnn_assert(src.dtype.enumv() == DTypeEnum::Float32 ||
MEGDNN_FLOAT16_SELECT(
src.dtype.enumv() == DTypeEnum::Float16,
false) ||
src.dtype.enumv() == DTypeEnum::Int8 ||
src.dtype.enumv() == DTypeEnum::Uint8 ||
(src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
"WarpPerspective NCHW input dtype should be "
"Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
"/Float16", "") ".");
megdnn_assert(
src.dtype.enumv() == DTypeEnum::Float32 ||
MEGDNN_FLOAT16_SELECT(
(src.dtype.enumv() == DTypeEnum::Float16 ||
src.dtype.enumv() == DTypeEnum::BFloat16),
false) ||
src.dtype.enumv() == DTypeEnum::Int8 ||
src.dtype.enumv() == DTypeEnum::Uint8 ||
(src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
"WarpPerspective NCHW input dtype should be "
"Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
"/Float16/BFloat16", "") ".");
megdnn_assert(
(src.dtype.category() == DTypeCategory::FLOAT &&
(src.dtype == mat.dtype ||
@@ -107,14 +109,17 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
param::WarpPerspective::BorderMode::ISOLATED);
} else {
megdnn_assert(param().format == param::WarpPerspective::Format::NHWCD4);
megdnn_assert(src.dtype == dtype::Float32() ||
MEGDNN_FLOAT16_SELECT(
src.dtype == dtype::Float16(), false) ||
src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
"WarpPerspective NHWCD4 input dtype should be "
"Float32" MEGDNN_FLOAT16_SELECT(
"/Float16", "") ",QunatizedS8, Quantized8Asymm.");
megdnn_assert(
src.dtype == dtype::Float32() ||
MEGDNN_FLOAT16_SELECT((src.dtype == dtype::Float16() ||
src.dtype == dtype::BFloat16()),
false) ||
src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
"WarpPerspective NHWCD4 input dtype should be "
"Float32" MEGDNN_FLOAT16_SELECT(
"/Float16/BFloat16",
"") ",QunatizedS8, Quantized8Asymm.");
megdnn_assert(
(src.dtype == mat.dtype || mat.dtype == dtype::Float32()),
"The input to WarpPerspective is in NHWCD4 format, in this "
@@ -253,30 +258,30 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
}
}

void WarpPerspectiveBackwardData::check_exec(const TensorLayout &mat,
const TensorLayout &diff,
const TensorLayout &grad,
size_t workspace_in_bytes)
{
void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_in_bytes) {
check_layout_fwd(grad, mat, diff);
megdnn_assert(grad.dtype == dtype::Float32(),
"Backward WarpPerspective only supports Float32.");
megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
|| grad.dtype == dtype::BFloat16()),
"Backward WarpPerspective only supports Float32/BFloat16.");
auto required_workspace_in_bytes = get_workspace_in_bytes(mat, diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

void WarpPerspectiveBackwardMat::check_exec(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &diff,
const TensorLayout &grad,
size_t workspace_in_bytes)
{
void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src,
const TensorLayout& mat,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_in_bytes) {
check_layout_fwd(src, mat, diff);
megdnn_assert_eq_layout(mat, grad);
megdnn_assert(grad.dtype == dtype::Float32(),
"Backward WarpPerspective only supports Float32.");
auto required_workspace_in_bytes = get_workspace_in_bytes(src,
mat, diff, grad);
megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
|| grad.dtype == dtype::BFloat16()),
"Backward WarpPerspective only supports Float32/BFloat16.");
auto required_workspace_in_bytes =
get_workspace_in_bytes(src, mat, diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}



+29 -0 dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu

@@ -0,0 +1,29 @@
/**
* \file dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_cond_take_kern_impls.py
#include "../kern.inl"

#if !MEGDNN_DISABLE_FLOAT16
namespace megdnn {
namespace cuda {
namespace cond_take {

inst_genidx(::megdnn::dtype::BFloat16)
#undef inst_genidx

inst_copy(::megdnn::dtype::BFloat16)
#undef inst_copy
#undef inst_copy_

} // cond_take
} // cuda
} // megdnn
#endif

+7 -0 dnn/src/cuda/conv_bias/algo.cpp

@@ -62,6 +62,13 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group batched_matmul
non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group 1x1

algo_size = all_algos.size();
for (size_t i = 0; i < algo_size; ++i) {
bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
all_algos.push_back(bfloat16_refhold.back().get());
bfloat16_algos.push_back(bfloat16_refhold.back().get());
}

size_t all_algo_size = all_algos.size();
#if CUDA_VERSION >= 10000
fill_imma_algos();


+25 -1 dnn/src/cuda/conv_bias/algo.h

@@ -499,6 +499,28 @@ private:
};
#endif

class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase {
public:
AlgoBFloat16(AlgoBase* impl);

bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;

const char* name() const override { return m_name.c_str(); }

bool is_reproducible() const override { return m_impl->is_reproducible(); }

private:
SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr,
TensorLayout& fsrc, TensorLayout& ffilter,
TensorLayout& fbias, TensorLayout& fz,
TensorLayout& fdst) const;
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
AlgoBase* m_impl;
std::string m_name;
};

class ConvBiasForwardImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
@@ -508,7 +530,8 @@ public:

std::vector<AlgoBase*> all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos;
non_cudnn_algos,
bfloat16_algos;
std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations;
std::vector<AlgoCUDNNConv> cudnn_convs;
AlgoChanwise chanwise;
@@ -531,6 +554,7 @@ public:
int8_chwn4_imma_unroll_width;
#endif
std::vector<std::unique_ptr<AlgoGroupConvGeneral>> gconv_refhold;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;

AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);


+120 -0 dnn/src/cuda/conv_bias/bfloat16.cpp

@@ -0,0 +1,120 @@
/**
* \file dnn/src/cuda/conv_bias/bfloat16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace conv_bias;

ConvBiasForwardImpl::AlgoBFloat16::AlgoBFloat16(
ConvBiasForwardImpl::AlgoBase* algorithm)
: m_impl(algorithm) {
megdnn_assert_internal(algorithm);
m_name = ssprintf("BFLOAT16:%s", m_impl->name());
}

ConvBiasForwardImpl::AlgoBase::SizeArgs
ConvBiasForwardImpl::AlgoBFloat16::float_args(
const SizeArgs& args, ConvBiasForwardImpl* opr, TensorLayout& fsrc,
TensorLayout& ffilter, TensorLayout& fbias, TensorLayout& fz,
TensorLayout& fdst) const {
fsrc = *args.src_layout;
ffilter = *args.filter_layout;
fbias = *args.bias_layout;
fz = *args.z_layout;
fdst = *args.dst_layout;
auto change_dtype = [](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32();
}
};
change_dtype(fsrc);
change_dtype(ffilter);
change_dtype(fbias);
change_dtype(fz);
change_dtype(fdst);
opr->param() = args.opr->param();
opr->param().compute_mode = Param::ComputeMode::DEFAULT;
opr->execution_policy() = {m_impl};
return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst);
}

bool ConvBiasForwardImpl::AlgoBFloat16::is_available(
const SizeArgs& args) const {
TensorLayout fsrc, ffilter, fbias, fz, fdst;
auto convbias_opr = args.handle->create_operator<ConvBias>();
SizeArgs fargs = float_args(
args, static_cast<ConvBiasForwardImpl*>(convbias_opr.get()), fsrc,
ffilter, fbias, fz, fdst);
return args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16() &&
m_impl->is_available(fargs);
}

WorkspaceBundle ConvBiasForwardImpl::AlgoBFloat16::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
TensorLayout fsrc, ffilter, fbias, fz, fdst;
auto convbias_opr = args.handle->create_operator<ConvBias>();
SizeArgs fargs = float_args(
args, static_cast<ConvBiasForwardImpl*>(convbias_opr.get()), fsrc,
ffilter, fbias, fz, fdst);
SmallVector<size_t> sizes;
auto get_workspace = [&sizes](const TensorLayout& src,
const TensorLayout& dst) {
if (src.dtype != dst.dtype) {
sizes.push_back(dst.span().dist_byte());
}
};
get_workspace(*args.src_layout, fsrc);
get_workspace(*args.filter_layout, ffilter);
get_workspace(*args.bias_layout, fbias);
get_workspace(*args.z_layout, fz);
get_workspace(*args.dst_layout, fdst);
sizes.push_back(m_impl->get_workspace_in_bytes(fargs));
return {ptr, std::move(sizes)};
}

size_t ConvBiasForwardImpl::AlgoBFloat16::get_workspace_in_bytes(
const SizeArgs& args) const {
return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}

void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
TensorND fsrc_tensor = *args.src_tensor;
TensorND ffilter_tensor = *args.filter_tensor;
TensorND fbias_tensor = *args.bias_tensor;
TensorND fz_tensor = *args.z_tensor;
TensorND fdst_tensor = *args.dst_tensor;
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle);
{
cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor)
.src_to_comp_type(*args.filter_tensor, ffilter_tensor)
.src_to_comp_type(*args.bias_tensor, fbias_tensor)
.src_to_comp_type(*args.z_tensor, fz_tensor)
.src_to_comp_type(*args.dst_tensor, fdst_tensor);
}
{
auto convbias_opr = args.handle->create_operator<ConvBias>();
convbias_opr->param() = args.opr->param();
convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
convbias_opr->execution_policy() = {m_impl};
convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor,
fdst_tensor, cvter.workspace());
}
{ cvter.comp_to_dst_type(fdst_tensor, *args.dst_tensor); }
}

// vim: syntax=cpp.doxygen
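A worked size check for the bundle logic above, under made-up shapes: each tensor whose dtype changes gets an fp32 staging buffer sized by the converted layout, plus one trailing slot for the wrapped algorithm's own workspace.

// Hypothetical example: a bf16 src of shape (32, 64, 56, 56).
size_t elems = 32 * 64 * 56 * 56;            // 6,422,528 elements
size_t f32_staging = elems * sizeof(float);  // 25,690,112 B (~24.5 MiB)
// The bf16 original occupies half that; one staging buffer is pushed
// into `sizes` per converted tensor, then m_impl's workspace is appended.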

+4 -0 dnn/src/cuda/conv_bias/chanwise.cpp

@@ -20,6 +20,10 @@ using namespace conv_bias;

bool ConvBiasForwardImpl::AlgoChanwise::is_available(
const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.z_layout->ndim > 0)
return false;



+4 -0 dnn/src/cuda/conv_bias/chanwise_small.cpp

@@ -30,6 +30,10 @@ inline bool is_available_small(const chanwise::Param& param) {

bool ConvBiasForwardImpl::AlgoChanwiseSmall::is_available(
const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.z_layout->ndim > 0)
return false;
#if CUDA_VERSION < 9000


+4 -0 dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp

@@ -23,6 +23,10 @@ using namespace conv_bias;

bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.bias_layout->ndim == 0 ||
args.bias_layout->eq_shape(*args.dst_layout))
return false;


+4 -0 dnn/src/cuda/conv_bias/group_conv.cpp

@@ -50,6 +50,10 @@ ConvBiasForwardImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(AlgoBase* impl)

bool ConvBiasForwardImpl::AlgoGroupConvGeneral::is_available(
const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.z_layout->ndim > 0 || args.filter_meta.group <= 1)
return false;
auto&& param = args.opr->param();


+5 -0 dnn/src/cuda/conv_bias/helper.cpp

@@ -136,6 +136,11 @@ void ConvBiasDesc::set_conv(DType data_type, const param::ConvBias& param,
namespace conv_bias {

bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}

// CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN
// on Tegra K1.
if (args.handle->is_tegra_k1())


+4 -0 dnn/src/cuda/conv_bias/matmul.cpp

@@ -20,6 +20,10 @@ using namespace cuda;
using namespace conv_bias;

bool ConvBiasForwardImpl::AlgoMatmul::is_available(const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.z_layout->ndim > 0)
return false;



+20 -7 dnn/src/cuda/conv_bias/opr_impl.cpp

@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/conv_bias/opr_impl.h"
#include "megdnn/dtype.h"
#include "src/cuda/conv_bias/helper.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/handle.h"
@@ -176,14 +177,26 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
conv_args = orig_args;
}

if (reproducible) {
return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd");
if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
if (reproducible) {
return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda convbias fwd");
} else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda convbias fwd");
}
} else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd");
if (reproducible) {
return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd");
} else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd");
}
}
}
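The effect of this split is that bfloat16 workloads can only ever receive a BFLOAT16-wrapped algorithm from the heuristic. An illustrative trigger (shapes invented):

// These layouts make src.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv,
// so get_algorithm_heuristic searches sm_algo_pack.bfloat16_algos and
// returns an algorithm named "BFLOAT16:<wrapped fp32 algo>".
TensorLayout src({32, 64, 56, 56}, dtype::BFloat16());
TensorLayout filter({64, 64, 3, 3}, dtype::BFloat16());
TensorLayout dst({32, 64, 54, 54}, dtype::BFloat16());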



+1 -0 dnn/src/cuda/conv_bias/opr_impl.h

@@ -57,6 +57,7 @@ public:
class AlgoInt8NCHW4IMMAImplicitGemm;
class AlgoInt8CHWN4IMMAImplicitGemmReorderFilter;
class AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth;
class AlgoBFloat16;

class AlgoPack;



+15 -7 dnn/src/cuda/convolution/backward_data/algo.cpp

@@ -33,11 +33,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {

// add gconv algos by AlgoGroupConvGeneral
auto all_algos_data = all_algos.data();
for (size_t i = 2; i < all_algos.size(); ++ i) {
size_t group_algo_start = 2;
for (size_t i = group_algo_start; i < all_algos.size(); ++ i) {
gconv.push_back({all_algos[i]});
}
for (size_t i = 2; i < all_algos.size(); ++ i) {
algo2gconv[all_algos[i]] = &gconv[i - 2];
for (size_t i = group_algo_start; i < all_algos.size(); ++ i) {
algo2gconv[all_algos[i]] = &gconv[i - group_algo_start];
}
for (auto &&i: gconv) {
all_algos.push_back(&i);
@@ -45,6 +46,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
megdnn_assert(all_algos_data == all_algos.data());

non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group matmul
size_t algo_size = all_algos.size();
for (size_t i=0; i<algo_size; ++i) {
bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
all_algos.push_back(bfloat16_refhold.back().get());
bfloat16_algos.push_back(bfloat16_refhold.back().get());
}
}

ConvolutionBackwardDataImpl::AlgoCUDNN*
@@ -65,18 +72,19 @@ ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
ConvolutionBackwardDataImpl *o,
const TensorLayout &filter, const TensorLayout &diff,
const TensorLayout &grad):
SizeArgs(o, o->check_layout_fwd(grad, filter, diff), diff, grad)
SizeArgs(o, filter, o->check_layout_fwd(grad, filter, diff), diff, grad)
{
}

ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
ConvolutionBackwardDataImpl *o,
const CanonizedFilterMeta &filter, const TensorLayout &diff,
ConvolutionBackwardDataImpl *o, const TensorLayout& filter,
const CanonizedFilterMeta &filter_meta, const TensorLayout &diff,
const TensorLayout &grad):
handle{concrete_handle(o->handle())},
filter_meta{filter},
filter_meta{filter_meta},
diff_layout{&diff},
grad_layout{&grad},
filter_layout{&filter},
opr{o}
{
}


+32 -9 dnn/src/cuda/convolution/backward_data/algo.h

@@ -31,22 +31,24 @@ class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm {
struct SizeArgs {
HandleImpl *handle;
CanonizedFilterMeta filter_meta;
const TensorLayout *diff_layout, *grad_layout;
const TensorLayout *diff_layout, *grad_layout, *filter_layout;
ConvolutionBackwardDataImpl *opr;

std::string to_string() const;
void init_desc(convolution::CUDNNBwdDataDescs &desc) const {
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
}
SizeArgs(ConvolutionBackwardDataImpl *opr,
const TensorLayout &filter, const TensorLayout &diff,
const TensorLayout &grad);
SizeArgs(ConvolutionBackwardDataImpl *opr,
const CanonizedFilterMeta &filter, const TensorLayout &diff,
const TensorLayout &grad);
SizeArgs(ConvolutionBackwardDataImpl* opr,
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad);
SizeArgs(ConvolutionBackwardDataImpl* opr,
const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad);

convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, grad_layout, filter_meta, diff_layout};
return {handle, grad_layout, filter_layout, filter_meta,
diff_layout};
}
};
struct ExecArgs: public SizeArgs {
@@ -170,6 +172,25 @@ class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final: public AlgoBase {
}
};

class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
public:
AlgoBFloat16(ConvolutionBackwardDataImpl::AlgoBase*);
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;

const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }

private:
std::string m_name;
ConvolutionBackwardDataImpl::AlgoBase* m_algorithm = nullptr;
SizeArgs float_args(const SizeArgs& args, ConvolutionBackwardDataImpl* opr,
TensorLayout& fsrc, TensorLayout& ffilter,
TensorLayout& fdst) const;
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

//! implement group conv by another algo
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase {
AlgoBase *m_impl;
@@ -210,12 +231,14 @@ class ConvolutionBackwardDataImpl::AlgoPack {
AlgoChanwiseSmall chanwise_small;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;

std::vector<AlgoBase*>
//! all algorithms
all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos;
non_cudnn_algos,
bfloat16_algos;

AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);
};


+115 -0 dnn/src/cuda/convolution/backward_data/bfloat16.cpp

@@ -0,0 +1,115 @@
/**
* \file dnn/src/cuda/convolution/backward_data/bfloat16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "./algo.h"
#include "src/cuda/convolution/chanwise/kern.cuh"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace convolution;

ConvolutionBackwardDataImpl::AlgoBFloat16::AlgoBFloat16(
ConvolutionBackwardDataImpl::AlgoBase* algorithm)
: m_algorithm(algorithm) {
megdnn_assert_internal(algorithm);
m_name = ssprintf("CONVOLUTION_BACKWARD_DATD_BFLOAT16:%s",
m_algorithm->name());
}

ConvolutionBackwardDataImpl::AlgoBase::SizeArgs
ConvolutionBackwardDataImpl::AlgoBFloat16::float_args(
const SizeArgs& args, ConvolutionBackwardDataImpl* opr,
TensorLayout& ffilter, TensorLayout& fdiff, TensorLayout& fgrad) const {
ffilter = *args.filter_layout;
fdiff = *args.diff_layout;
fgrad = *args.grad_layout;
auto change_dtype = [](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32();
}
};
change_dtype(ffilter);
change_dtype(fdiff);
change_dtype(fgrad);
opr->param() = args.opr->param();
opr->param().compute_mode = Param::ComputeMode::DEFAULT;
opr->execution_policy() = {m_algorithm};
return SizeArgs(opr, ffilter, fdiff, fgrad);
}

bool ConvolutionBackwardDataImpl::AlgoBFloat16::is_available(
const SizeArgs& args) const {
TensorLayout ffilter, fdiff, fgrad;
auto conv_back_data_opr =
args.handle->create_operator<ConvolutionBackwardData>();
SizeArgs fargs = float_args(
args,
static_cast<ConvolutionBackwardDataImpl*>(conv_back_data_opr.get()),
ffilter, fdiff, fgrad);
return args.diff_layout->dtype == args.filter_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16() &&
m_algorithm->is_available(fargs);
}

WorkspaceBundle ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
TensorLayout ffilter, fdiff, fgrad;
auto conv_back_data_opr =
args.handle->create_operator<ConvolutionBackwardData>();
SizeArgs fargs = float_args(
args,
static_cast<ConvolutionBackwardDataImpl*>(conv_back_data_opr.get()),
ffilter, fdiff, fgrad);
SmallVector<size_t> sizes;
auto get_workspace = [&sizes](const TensorLayout& src,
const TensorLayout& dst) {
if (src.dtype != dst.dtype) {
sizes.push_back(dst.span().dist_byte());
}
};
get_workspace(*args.filter_layout, ffilter);
get_workspace(*args.diff_layout, fdiff);
get_workspace(*args.grad_layout, fgrad);
sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
return {ptr, std::move(sizes)};
}

size_t ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_in_bytes(
const SizeArgs& args) const {
return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}

void ConvolutionBackwardDataImpl::AlgoBFloat16::exec(
const ExecArgs& args) const {
TensorND ffilter_tensor = *args.filter_tensor;
TensorND fdiff_tensor = *args.diff_tensor;
TensorND fgrad_tensor = *args.grad_tensor;
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle);
{
cvter.src_to_comp_type(*args.filter_tensor, ffilter_tensor)
.src_to_comp_type(*args.diff_tensor, fdiff_tensor)
.src_to_comp_type(*args.grad_tensor, fgrad_tensor);
}
{
auto conv_back_data_opr =
args.handle->create_operator<ConvolutionBackwardData>();
conv_back_data_opr->param() = args.opr->param();
conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
conv_back_data_opr->execution_policy() = {m_algorithm};
conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor,
cvter.workspace());
}
{ cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); }
}

// vim: syntax=cpp.doxygen

+4 -0 dnn/src/cuda/convolution/backward_data/chanwise.cpp

@@ -19,6 +19,10 @@ using namespace convolution;

bool ConvolutionBackwardDataImpl::AlgoChanwise::is_available(
const SizeArgs& args) const {
if (args.diff_layout->dtype == args.filter_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16()) {
return false;
}
auto&& fm = args.filter_meta;
return args.filter_meta.format == Param::Format::NCHW &&
args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&


+4 -0 dnn/src/cuda/convolution/backward_data/chanwise_small.cpp

@@ -29,6 +29,10 @@ inline bool is_available_small(const chanwise::Param& param) {

bool ConvolutionBackwardDataImpl::AlgoChanwiseSmall::is_available(
const SizeArgs &args) const {
if (args.diff_layout->dtype == args.filter_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16()) {
return false;
}
#if CUDA_VERSION < 9000
if (args.diff_layout->dtype.enumv() == DTypeEnum::Float16)
return false;


+4 -0 dnn/src/cuda/convolution/backward_data/group_conv.cpp

@@ -38,6 +38,10 @@ ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(

bool ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::is_available(
const SizeArgs &args) const {
if (args.diff_layout->dtype == args.filter_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16()) {
return false;
}
auto sub_args = args;
TensorLayout diff_pg, grad_pg;
modify_size_args(sub_args, diff_pg, grad_pg);


+ 4
- 0
dnn/src/cuda/convolution/backward_data/matmul.cpp View File

@@ -20,6 +20,10 @@ using namespace cuda;

bool ConvolutionBackwardDataImpl::AlgoMatmul::is_available(
const SizeArgs &args) const {
if (args.diff_layout->dtype == args.filter_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16()) {
return false;
}
auto &&fm = args.filter_meta;
return args.filter_meta.format == Param::Format::NCHW &&
args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&


+ 16
- 11
dnn/src/cuda/convolution/backward_filter/algo.cpp View File

@@ -43,6 +43,12 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() {
megdnn_assert(all_algos_data == all_algos.data());

non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group matmul
size_t algo_size = all_algos.size();
for (size_t i=0; i<algo_size; ++i) {
bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
all_algos.push_back(bfloat16_refhold.back().get());
bfloat16_algos.push_back(bfloat16_refhold.back().get());
}
}

ConvolutionBackwardFilterImpl::AlgoCUDNN*
@@ -64,21 +70,20 @@ ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs(
ConvolutionBackwardFilterImpl *o,
const TensorLayout &src, const TensorLayout &diff,
const TensorLayout &grad):
SizeArgs(o, src, diff, o->check_layout_fwd(src, grad, diff))
SizeArgs(o, src, diff, grad, o->check_layout_fwd(src, grad, diff))
{
}

ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs(
ConvolutionBackwardFilterImpl *o,
const TensorLayout &src, const TensorLayout &diff,
const CanonizedFilterMeta &grad):
handle{concrete_handle(o->handle())},
src_layout{&src},
diff_layout{&diff},
grad_filter_meta{grad},
opr{o}
{
}
ConvolutionBackwardFilterImpl* o, const TensorLayout& src,
const TensorLayout& diff, const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta)
: handle{concrete_handle(o->handle())},
src_layout{&src},
diff_layout{&diff},
grad_layout{&grad},
grad_filter_meta{grad_meta},
opr{o} {}

ConvolutionBackwardFilterImpl::AlgoBase::ExecArgs::ExecArgs(
ConvolutionBackwardFilterImpl *opr,


+ 29
- 6
dnn/src/cuda/convolution/backward_filter/algo.h View File

@@ -30,7 +30,7 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm {
public:
struct SizeArgs {
HandleImpl *handle;
const TensorLayout *src_layout, *diff_layout;
const TensorLayout *src_layout, *diff_layout, *grad_layout;
CanonizedFilterMeta grad_filter_meta;
ConvolutionBackwardFilterImpl *opr;

@@ -42,12 +42,14 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm {
SizeArgs(ConvolutionBackwardFilterImpl *opr,
const TensorLayout &src, const TensorLayout &diff,
const TensorLayout &grad);
SizeArgs(ConvolutionBackwardFilterImpl *opr,
const TensorLayout &src, const TensorLayout &diff,
const CanonizedFilterMeta &grad);
SizeArgs(ConvolutionBackwardFilterImpl* opr,
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta);

convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, src_layout, grad_filter_meta, diff_layout};
return {handle, src_layout, grad_layout, grad_filter_meta,
diff_layout};
}
};
struct ExecArgs: public SizeArgs {
@@ -157,6 +159,25 @@ class ConvolutionBackwardFilterImpl::AlgoChanwise final: public AlgoBase {
}
};

class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase {
public:
AlgoBFloat16(ConvolutionBackwardFilterImpl::AlgoBase*);
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;

const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }

private:
std::string m_name;
ConvolutionBackwardFilterImpl::AlgoBase* m_algorithm = nullptr;
SizeArgs float_args(const SizeArgs& args,
ConvolutionBackwardFilterImpl* opr, TensorLayout& fsrc,
TensorLayout& ffilter, TensorLayout& fdst) const;
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

//! implement group conv by another algo
class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase {
AlgoBase *m_impl;
@@ -196,12 +217,14 @@ class ConvolutionBackwardFilterImpl::AlgoPack {
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;

std::vector<AlgoBase*>
//! all algorithms
all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos;
non_cudnn_algos,
bfloat16_algos;

AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo);
};


+ 117
- 0
dnn/src/cuda/convolution/backward_filter/bfloat16.cpp View File

@@ -0,0 +1,117 @@
/**
* \file dnn/src/cuda/convolution/backward_filter/bfloat16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "./algo.h"
#include "src/cuda/convolution/chanwise/kern.cuh"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace convolution;

ConvolutionBackwardFilterImpl::AlgoBFloat16::AlgoBFloat16(
ConvolutionBackwardFilterImpl::AlgoBase* algorithm)
: m_algorithm(algorithm) {
megdnn_assert_internal(algorithm);
m_name = ssprintf("CONVOLUTION_BACKWARD_Filter_BFLOAT16:%s",
m_algorithm->name());
}

ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs
ConvolutionBackwardFilterImpl::AlgoBFloat16::float_args(
const SizeArgs& args, ConvolutionBackwardFilterImpl* opr,
TensorLayout& fsrc, TensorLayout& fdiff, TensorLayout& fgrad) const {
fsrc = *args.src_layout;
fdiff = *args.diff_layout;
fgrad = *args.grad_layout;
auto change_dtype = [](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32();
}
};
change_dtype(fsrc);
change_dtype(fdiff);
change_dtype(fgrad);
opr->param() = args.opr->param();
opr->param().compute_mode = Param::ComputeMode::DEFAULT;
opr->execution_policy() = {m_algorithm};
return SizeArgs(opr, fsrc, fdiff, fgrad);
}

bool ConvolutionBackwardFilterImpl::AlgoBFloat16::is_available(
const SizeArgs& args) const {
TensorLayout fsrc, fdiff, fgrad;
auto conv_back_filter_opr =
args.handle->create_operator<ConvolutionBackwardFilter>();
SizeArgs fargs = float_args(args,
static_cast<ConvolutionBackwardFilterImpl*>(
conv_back_filter_opr.get()),
fsrc, fdiff, fgrad);
return args.src_layout->dtype == args.diff_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16() &&
m_algorithm->is_available(fargs);
}

WorkspaceBundle
ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
TensorLayout fsrc, fdiff, fgrad;
auto conv_back_filter_opr =
args.handle->create_operator<ConvolutionBackwardFilter>();
SizeArgs fargs = float_args(args,
static_cast<ConvolutionBackwardFilterImpl*>(
conv_back_filter_opr.get()),
fsrc, fdiff, fgrad);
SmallVector<size_t> sizes;
auto get_workspace = [&sizes](const TensorLayout& src,
const TensorLayout& dst) {
if (src.dtype != dst.dtype) {
sizes.push_back(dst.span().dist_byte());
}
};
get_workspace(*args.src_layout, fsrc);
get_workspace(*args.diff_layout, fdiff);
get_workspace(*args.grad_layout, fgrad);
sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
return {ptr, std::move(sizes)};
}

size_t ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_in_bytes(
const SizeArgs& args) const {
return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}

void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec(
const ExecArgs& args) const {
TensorND fsrc_tensor = *args.src_tensor;
TensorND fdiff_tensor = *args.diff_tensor;
TensorND fgrad_tensor = *args.grad_tensor;
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle);
{
cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor)
.src_to_comp_type(*args.diff_tensor, fdiff_tensor)
.src_to_comp_type(*args.grad_tensor, fgrad_tensor);
}
{
auto conv_back_filter_opr =
args.handle->create_operator<ConvolutionBackwardFilter>();
conv_back_filter_opr->param() = args.opr->param();
conv_back_filter_opr->param().compute_mode =
Param::ComputeMode::DEFAULT;
conv_back_filter_opr->execution_policy() = {m_algorithm};
conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor,
cvter.workspace());
}
{ cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); }
}

// vim: syntax=cpp.doxygen

+ 4
- 0
dnn/src/cuda/convolution/backward_filter/chanwise.cpp View File

@@ -19,6 +19,10 @@ using namespace convolution;

bool ConvolutionBackwardFilterImpl::AlgoChanwise::is_available(
const SizeArgs &args) const {
if (args.src_layout->dtype == args.diff_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16()) {
return false;
}
auto &&fm = args.grad_filter_meta;
return fm.format == Param::Format::NCHW &&
args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&


+ 4
- 0
dnn/src/cuda/convolution/backward_filter/group_conv.cpp View File

@@ -38,6 +38,10 @@ ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(

bool ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::is_available(
const SizeArgs &args) const {
if (args.src_layout->dtype == args.diff_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16()) {
return false;
}
auto sub_args = args;
TensorLayout src_pg, diff_pg;
modify_size_args(sub_args, src_pg, diff_pg);


+ 4
- 0
dnn/src/cuda/convolution/backward_filter/matmul.cpp View File

@@ -19,6 +19,10 @@ using namespace cuda;

bool ConvolutionBackwardFilterImpl::AlgoMatmul::is_available(
const SizeArgs &args) const {
if (args.src_layout->dtype == args.diff_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16()) {
return false;
}
auto &&fm = args.grad_filter_meta;
return fm.format == Param::Format::NCHW &&
args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&


+ 4
- 0
dnn/src/cuda/convolution/helper.cpp View File

@@ -16,6 +16,10 @@ using namespace cuda;
using namespace convolution;

bool convolution::is_cudnn_supported(const ForwardSizeArgs &args) {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}

// CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN
// on Tegra K1.


+ 1
- 0
dnn/src/cuda/convolution/helper.h View File

@@ -25,6 +25,7 @@ namespace convolution {
struct ForwardSizeArgs {
HandleImpl *handle;
const TensorLayout *src_layout;
const TensorLayout *filter_layout;
CanonizedFilterMeta filter_meta;
const TensorLayout *dst_layout;
};


+ 54
- 28
dnn/src/cuda/convolution/opr_impl.cpp View File

@@ -102,7 +102,8 @@ void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
AlgoBase::ExecArgs args(this, filter, diff, grad, workspace);
auto algo = get_algorithm(this, args.filter_meta, diff.layout, grad.layout);
auto algo = get_algorithm(this, filter.layout, args.filter_meta,
diff.layout, grad.layout);
algo->check_workspace(args, workspace).exec(args);
}

@@ -120,16 +121,16 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) {
auto fm = check_layout_fwd(grad, filter, diff);
return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes,
reproducible);
return get_algorithm_heuristic(filter, fm, diff, grad,
workspace_limit_in_bytes, reproducible);
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const CanonizedFilterMeta& filter, const TensorLayout& diff,
ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) {
AlgoBase::SizeArgs args(this, filter, diff, grad);
AlgoBase::SizeArgs args(this, filter, filter_meta, diff, grad);

if (args.filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_reproducible(
@@ -209,14 +210,27 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
args = orig_args;
}

if (reproducible) {
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data");
if (args.filter_layout->dtype.enumv() !=
DTypeTrait<dtype::BFloat16>::enumv) {
if (reproducible) {
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_data");
} else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_data");
}
} else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data");
if (reproducible) {
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data");
} else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_data");
}
}
}
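
The four-way branch above repeats the reproducible/usable pair once per algo pool. An illustrative refactor (not part of the commit) that selects the pool by dtype first; Algo here is any type exposing the three predicates megdnn's algorithms provide:

#include <cstddef>
#include <vector>

template <typename Algo, typename Args>
Algo* pick_algo(const std::vector<Algo*>& non_cudnn_pool,
                const std::vector<Algo*>& bf16_pool, bool is_bf16,
                bool reproducible, const Args& args, size_t ws_limit) {
    // Choose the candidate pool by dtype once, then filter.
    for (Algo* a : is_bf16 ? bf16_pool : non_cudnn_pool) {
        if (!a->is_available(args)) continue;
        if (a->get_workspace_in_bytes(args) > ws_limit) continue;
        if (reproducible && !a->is_reproducible()) continue;
        return a;  // first acceptable candidate
    }
    return nullptr;  // caller reports "no usable algo", as the helpers do
}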

@@ -225,7 +239,7 @@ size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
const TensorLayout &diff,
const TensorLayout &grad) {
AlgoBase::SizeArgs args(this, filter, diff, grad);
return get_algorithm(this, args.filter_meta, diff, grad)->
return get_algorithm(this, filter, args.filter_meta, diff, grad)->
get_workspace_in_bytes(args);
}

@@ -241,7 +255,7 @@ void ConvolutionBackwardFilterImpl::exec(_megdnn_tensor_in src,
_megdnn_workspace workspace) {
AlgoBase::ExecArgs args(this, src, diff, grad, workspace);
auto algo = get_algorithm(this, src.layout, diff.layout,
args.grad_filter_meta);
grad.layout, args.grad_filter_meta);
algo->check_workspace(args, workspace).exec(args);
}

@@ -259,16 +273,16 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) {
auto fm = check_layout_fwd(src, grad, diff);
return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes,
reproducible);
return get_algorithm_heuristic(src, diff, grad, fm,
workspace_limit_in_bytes, reproducible);
}

ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes,
bool reproducible) {
AlgoBase::SizeArgs args(this, src, diff, grad);
const TensorLayout& grad, const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes, bool reproducible) {
AlgoBase::SizeArgs args(this, src, diff, grad, grad_meta);

if (args.grad_filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_reproducible(
@@ -349,14 +363,26 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
args = orig_args;
}

if (reproducible) {
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_filter");
if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
if (reproducible) {
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_filter");
} else {
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_filter");
}
} else {
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_filter");
if (reproducible) {
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_filter");
} else {
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda conv bwd_filter");
}
}
}

@@ -365,7 +391,7 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
const TensorLayout &diff,
const TensorLayout &grad) {
AlgoBase::SizeArgs args(this, src, diff, grad);
return get_algorithm(this, src, diff, args.grad_filter_meta)->
return get_algorithm(this, src, diff, grad, args.grad_filter_meta)->
get_workspace_in_bytes(args);
}



+ 9
- 6
dnn/src/cuda/convolution/opr_impl.h View File

@@ -60,11 +60,11 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible);
Algorithm* get_algorithm_heuristic(
const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_limit_in_bytes, bool reproducible);
size_t get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) override;
@@ -76,6 +76,7 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
class AlgoChanwise;
class AlgoChanwiseSmall;
class AlgoGroupConvGeneral;
class AlgoBFloat16;

class AlgoPack;

@@ -104,7 +105,8 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
bool reproducible) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const CanonizedFilterMeta& grad,
const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes,
bool reproducible);
size_t get_workspace_in_bytes(const TensorLayout& src,
@@ -117,6 +119,7 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
class AlgoMatmul;
class AlgoChanwise;
class AlgoGroupConvGeneral;
class AlgoBFloat16;

class AlgoPack;



+ 1
- 1
dnn/src/cuda/convolution3d/forward/algo.h View File

@@ -50,7 +50,7 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm {
const CanonizedFilterMeta &filter,
const TensorLayout &dst);
};
struct ExecArgs: public SizeArgs {
struct ExecArgs : public SizeArgs {
const TensorND *src_tensor, *filter_tensor, *dst_tensor;
Workspace workspace;



+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
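
Each of the generated kimpl files in this commit follows the same recipe: define which mode, arity, and element type to instantiate, then include the shared kern_impl.inl, which expands an instantiation callback once per enabled mode. A self-contained illustration of that X-macro mechanism (names are stand-ins, not megdnn's):

#include <cstdio>

template <int Mode, typename CType, int Arity>
void run_elemwise() {  // stand-in for the real kernel launcher
    std::printf("instantiated mode=%d arity=%d\n", Mode, Arity);
}

// --- what a generated .cu defines ---
#define KERN_IMPL_MODE(cb) cb(0)  /* stands in for MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) */
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE float     /* stands in for dt_bfloat16 */

// --- what a kern_impl.inl-style include conceptually expands to ---
#define INST_CB(mode) template void run_elemwise<mode, KERN_IMPL_CTYPE, KERN_IMPL_ARITY>();
KERN_IMPL_MODE(INST_CB)
#undef INST_CB

int main() { run_elemwise<0, float, 2>(); }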

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 18
- 0
dnn/src/cuda/elemwise/special_kimpl/special_dt_bfloat16.cu View File

@@ -0,0 +1,18 @@
/**
* \file dnn/src/cuda/elemwise/special_kimpl/special_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_special_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#include "../special_kerns.inl"
INST(::megdnn::dtype::BFloat16)
#undef INST
}
}
#endif

+ 3
- 0
dnn/src/cuda/elemwise_helper.cpp View File

@@ -141,6 +141,9 @@ INST_FOR_CTYPE
#define ct dt_float16
INST_FOR_CTYPE
#undef ct
#define ct dt_bfloat16
INST_FOR_CTYPE
#undef ct
#define ct dt_int8
INST_FOR_CTYPE
#undef ct


+ 12
- 0
dnn/src/cuda/elemwise_helper.cuh View File

@@ -68,6 +68,17 @@ namespace elemwise_intl {
return t;
}

struct __attribute__((aligned(8))) bhalf4 {
dt_bfloat16 x, y, z, w;
};

__device__ __forceinline__ bhalf4 make_bhalf4(dt_bfloat16 x, dt_bfloat16 y,
dt_bfloat16 z, dt_bfloat16 w) {
bhalf4 t;
t.x = x, t.y = y, t.z = z, t.w = w;
return t;
}

#define INST(_ctype, _vect_type) \
template <> \
class VectTypeTrait<_ctype> { \
@@ -87,6 +98,7 @@ namespace elemwise_intl {
INST(dt_uint8, uchar4);
INST(dt_float32, float4);
INST(dt_float16, half4);
INST(dt_bfloat16, bhalf4);
INST(dt_int32, int4);
INST(dt_int16, short4);
#undef as_raw
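
A brief note on bhalf4: declaring the four-lane struct 8-byte aligned lets the compiler emit one 64-bit load/store per thread instead of four 16-bit ones. A minimal kernel sketch, assuming the bhalf4 above and that dt_bfloat16 converts to and from float (as the bfloat16.hpp header added by this commit is meant to provide):

// Vectorized elementwise scale over bf16, four lanes per thread.
__global__ void scale_bf16x4(bhalf4* dst, const bhalf4* src, float k,
                             int n_vec) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n_vec)
        return;
    bhalf4 v = src[i];  // single 8-byte transaction
    v.x = dt_bfloat16(float(v.x) * k);
    v.y = dt_bfloat16(float(v.y) * k);
    v.z = dt_bfloat16(float(v.z) * k);
    v.w = dt_bfloat16(float(v.w) * k);
    dst[i] = v;         // single 8-byte transaction
}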


+ 5
- 0
dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu View File

@@ -17,6 +17,11 @@ __device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) {
__trap();
((int*)0)[0] = 1;
}

__device__ void atomicAdd(megdnn::dt_bfloat16 *, megdnn::dt_bfloat16) {
__trap();
((int*)0)[0] = 1;
}
#endif

__device__ void atomicAdd(megdnn::dt_int8 *, megdnn::dt_int8) {
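
The hunk above makes the no-native-atomics path trap for bf16, matching the existing fp16 stub. For reference, a hedged sketch of how a software fallback could be built from 32-bit atomicCAS (not megdnn code; the conversion here truncates rather than rounds):

#include <cstddef>

// bf16 bits <-> f32 (bf16 is the high half of an f32).
__device__ inline float bf16_bits_to_f32(unsigned short b) {
    return __uint_as_float((unsigned int)b << 16);
}
__device__ inline unsigned short f32_to_bf16_bits(float f) {
    return (unsigned short)(__float_as_uint(f) >> 16);  // truncates
}

// CAS loop on the containing 32-bit word, updating only our 16-bit half.
__device__ void atomic_add_bf16(unsigned short* addr, float delta) {
    unsigned int* base = reinterpret_cast<unsigned int*>(
            reinterpret_cast<size_t>(addr) & ~size_t(3));
    bool hi = reinterpret_cast<size_t>(addr) & 2;  // which half we own
    unsigned int old = *base, assumed;
    do {
        assumed = old;
        unsigned short cur = hi ? (unsigned short)(assumed >> 16)
                                : (unsigned short)(assumed & 0xffffu);
        unsigned short upd =
                f32_to_bf16_bits(bf16_bits_to_f32(cur) + delta);
        unsigned int next =
                hi ? (assumed & 0x0000ffffu) | ((unsigned int)upd << 16)
                   : (assumed & 0xffff0000u) | upd;
        old = atomicCAS(base, assumed, next);
    } while (assumed != old);  // retry if another thread raced us
}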


+ 4
- 0
dnn/src/cuda/matrix_mul/algos.cpp View File

@@ -29,6 +29,10 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&cublas_lt);
#endif
all_algos.push_back(&naive);
#if !MEGDNN_DISABLE_FLOAT16
cublas_bfloat16 = std::make_unique<AlgoBFloat16>(&cublas);
all_algos.push_back(cublas_bfloat16.get());
#endif
}

MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack;


+ 22
- 1
dnn/src/cuda/matrix_mul/algos.h View File

@@ -15,6 +15,7 @@
#include "src/cuda/matrix_mul/opr_impl.h"

#include <cuda.h>
#include <memory>
#if CUDA_VERSION >= 10010
#include <cublasLt.h>
#endif
@@ -140,6 +141,24 @@ public:
bool is_reproducible() const override { return true; }
};

#if !MEGDNN_DISABLE_FLOAT16
class MatrixMulForwardImpl::AlgoBFloat16 final : public AlgoBase {
public:
AlgoBFloat16(MatrixMulForwardImpl::AlgoBase*);
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }

private:
MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr;
std::string m_name;
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
SizeArgs float_args(const SizeArgs& args) const;
};
#endif

class MatrixMulForwardImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
@@ -154,7 +173,9 @@ public:
#if CUDA_VERSION >= 10010
AlgoCuBlasLt cublas_lt;
#endif

#if !MEGDNN_DISABLE_FLOAT16
std::unique_ptr<AlgoBFloat16> cublas_bfloat16;
#endif
std::vector<AlgoBase*> all_algos;
};



+ 91
- 0
dnn/src/cuda/matrix_mul/bfloat16.cpp View File

@@ -0,0 +1,91 @@
/**
* \file dnn/src/cuda/matrix_mul/bfloat16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

MatrixMulForwardImpl::AlgoBFloat16::AlgoBFloat16(
MatrixMulForwardImpl::AlgoBase* algorithm)
: m_algorithm(algorithm) {
megdnn_assert_internal(algorithm);
m_name = ssprintf("MATMUL_BFLOAT16:%s", m_algorithm->name());
}

MatrixMulForwardImpl::AlgoBase::SizeArgs
MatrixMulForwardImpl::AlgoBFloat16::float_args(const SizeArgs& args) const {
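// present the wrapped algorithm with float32 layouts: only the dtype is
// rewritten; shapes and strides are left untouched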
auto new_args = args;
auto change_dtype = [](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32();
}
};
change_dtype(new_args.layout_a);
change_dtype(new_args.layout_b);
change_dtype(new_args.layout_c);
return new_args;
}

bool MatrixMulForwardImpl::AlgoBFloat16::is_available(
const SizeArgs& args) const {
auto fargs = float_args(args);
return args.layout_a.dtype == dtype::BFloat16() &&
m_algorithm->is_available(fargs);
}

WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
auto fargs = float_args(args);
SmallVector<size_t> sizes;
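// one float32 staging buffer per bfloat16 tensor, with the wrapped
// algorithm's own workspace requirement appended last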
auto get_workspace = [&sizes](const TensorLayout& src) {
TensorLayout dst = src;
if (dst.dtype == dtype::BFloat16()) {
dst.dtype = dtype::Float32();
sizes.push_back(dst.span().dist_byte());
}
};
get_workspace(args.layout_a);
get_workspace(args.layout_b);
get_workspace(args.layout_c);
sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
return {ptr, std::move(sizes)};
}

size_t MatrixMulForwardImpl::AlgoBFloat16::get_workspace_in_bytes(
const SizeArgs& args) const {
return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}

void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
TensorND a = args.tensor_a;
TensorND b = args.tensor_b;
TensorND c = args.tensor_c;
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
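// upcast the bfloat16 operands into float32 staging buffers carved out of
// the workspace bundle; a, b and c are redirected to those copies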
auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
args.opr->handle(), &bundle);
ctypecvt.src_to_comp_type(args.tensor_a, a)
.src_to_comp_type(args.tensor_b, b)
.src_to_comp_type(args.tensor_c, c);
{
auto matmul_opr =
args.opr->handle()->create_operator<MatrixMulForward>();
matmul_opr->param() = args.opr->param();
matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
matmul_opr->execution_policy() = {m_algorithm};
matmul_opr->exec(a, b, c, ctypecvt.workspace());
}
ctypecvt.comp_to_dst_type(c, args.tensor_c);
}

// vim: syntax=cpp.doxygen
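
The convert-run-convert sequence in exec() above is the template for the whole bfloat16 port: every BFLOAT16 algorithm added in this diff (conv_bias, convolution backward data/filter, matrix_mul) wraps an existing float32 algorithm the same way. A minimal host-side sketch of the data flow, with hypothetical names and a truncating downcast for brevity (the real implementation converts on-device through the workspace bundle):

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// bfloat16 is the upper 16 bits of an IEEE-754 float32
static float bf16_to_f32(uint16_t b) {
    uint32_t u = uint32_t(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

static uint16_t f32_to_bf16(float f) {
    // truncation for brevity; a production converter would round to nearest even
    uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    return uint16_t(u >> 16);
}

// stand-in for the wrapped float32 algorithm (m_algorithm in AlgoBFloat16)
static void matmul_f32(const float* a, const float* b, float* c,
                       size_t m, size_t k, size_t n) {
    for (size_t i = 0; i < m; ++i)
        for (size_t j = 0; j < n; ++j) {
            float acc = 0.f;
            for (size_t p = 0; p < k; ++p)
                acc += a[i * k + p] * b[p * n + j];
            c[i * n + j] = acc;
        }
}

// the wrapper: upcast into float32 staging buffers, run the float32
// algorithm, downcast the result -- mirroring src_to_comp_type / exec /
// comp_to_dst_type above
void matmul_bf16(const uint16_t* a, const uint16_t* b, uint16_t* c,
                 size_t m, size_t k, size_t n) {
    std::vector<float> fa(m * k), fb(k * n), fc(m * n);
    for (size_t i = 0; i < fa.size(); ++i) fa[i] = bf16_to_f32(a[i]);
    for (size_t i = 0; i < fb.size(); ++i) fb[i] = bf16_to_f32(b[i]);
    matmul_f32(fa.data(), fb.data(), fc.data(), m, k, n);
    for (size_t i = 0; i < fc.size(); ++i) c[i] = f32_to_bf16(fc[i]);
}

Keeping the conversion outside the wrapped algorithm means any existing float32 algorithm can serve bfloat16 tensors unmodified, at the cost of extra workspace and two conversion passes.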

Some files were not shown because too many files changed in this diff
