
feat(mge): add bfloat16 support

GitOrigin-RevId: a942ce6791
tags/v0.5.0
Megvii Engine Team (Xinran Xu), committed 5 years ago
parent commit: 0293d58ade
100 changed files with 4979 additions and 308 deletions
  1. +11 -0 dnn/include/megdnn/dtype.h
  2. +2965 -0 dnn/include/megdnn/dtype/bfloat16.hpp
  3. +2 -171 dnn/include/megdnn/dtype/half.hpp
  4. +48 -0 dnn/include/megdnn/dtype/half_common_epilogue.h
  5. +202 -0 dnn/include/megdnn/dtype/half_common_prologue.h
  6. +2 -2 dnn/scripts/gen_cond_take_kern_impls.py
  7. +2 -2 dnn/scripts/gen_elemwise_kern_impls.py
  8. +2 -2 dnn/scripts/gen_elemwise_special_kern_impls.py
  9. +2 -1 dnn/scripts/gen_elemwise_utils.py
  10. +6 -4 dnn/src/common/convolution.cpp
  11. +2 -1 dnn/src/common/elemwise/kern_defs.cuh
  12. +13 -12 dnn/src/common/matrix_mul.cpp
  13. +8 -0 dnn/src/common/rounding_converter.cuh
  14. +54 -0 dnn/src/common/utils.h
  15. +41 -36 dnn/src/common/warp_perspective.cpp
  16. +29 -0 dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu
  17. +7 -0 dnn/src/cuda/conv_bias/algo.cpp
  18. +25 -1 dnn/src/cuda/conv_bias/algo.h
  19. +120 -0 dnn/src/cuda/conv_bias/bfloat16.cpp
  20. +4 -0 dnn/src/cuda/conv_bias/chanwise.cpp
  21. +4 -0 dnn/src/cuda/conv_bias/chanwise_small.cpp
  22. +4 -0 dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
  23. +4 -0 dnn/src/cuda/conv_bias/group_conv.cpp
  24. +5 -0 dnn/src/cuda/conv_bias/helper.cpp
  25. +4 -0 dnn/src/cuda/conv_bias/matmul.cpp
  26. +20 -7 dnn/src/cuda/conv_bias/opr_impl.cpp
  27. +1 -0 dnn/src/cuda/conv_bias/opr_impl.h
  28. +15 -7 dnn/src/cuda/convolution/backward_data/algo.cpp
  29. +32 -9 dnn/src/cuda/convolution/backward_data/algo.h
  30. +115 -0 dnn/src/cuda/convolution/backward_data/bfloat16.cpp
  31. +4 -0 dnn/src/cuda/convolution/backward_data/chanwise.cpp
  32. +4 -0 dnn/src/cuda/convolution/backward_data/chanwise_small.cpp
  33. +4 -0 dnn/src/cuda/convolution/backward_data/group_conv.cpp
  34. +4 -0 dnn/src/cuda/convolution/backward_data/matmul.cpp
  35. +16 -11 dnn/src/cuda/convolution/backward_filter/algo.cpp
  36. +29 -6 dnn/src/cuda/convolution/backward_filter/algo.h
  37. +117 -0 dnn/src/cuda/convolution/backward_filter/bfloat16.cpp
  38. +4 -0 dnn/src/cuda/convolution/backward_filter/chanwise.cpp
  39. +4 -0 dnn/src/cuda/convolution/backward_filter/group_conv.cpp
  40. +4 -0 dnn/src/cuda/convolution/backward_filter/matmul.cpp
  41. +4 -0 dnn/src/cuda/convolution/helper.cpp
  42. +1 -0 dnn/src/cuda/convolution/helper.h
  43. +54 -28 dnn/src/cuda/convolution/opr_impl.cpp
  44. +9 -6 dnn/src/cuda/convolution/opr_impl.h
  45. +1 -1 dnn/src/cuda/convolution3d/forward/algo.h
  46. +17 -0 dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu
  47. +17 -0 dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu
  48. +17 -0 dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu
  49. +17 -0 dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu
  50. +17 -0 dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu
  51. +17 -0 dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu
  52. +17 -0 dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu
  53. +17 -0 dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu
  54. +17 -0 dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu
  55. +17 -0 dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu
  56. +17 -0 dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu
  57. +17 -0 dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu
  58. +17 -0 dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu
  59. +17 -0 dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu
  60. +17 -0 dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu
  61. +17 -0 dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu
  62. +17 -0 dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu
  63. +17 -0 dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu
  64. +17 -0 dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu
  65. +17 -0 dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu
  66. +17 -0 dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu
  67. +17 -0 dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu
  68. +17 -0 dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu
  69. +17 -0 dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu
  70. +17 -0 dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu
  71. +17 -0 dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu
  72. +17 -0 dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu
  73. +17 -0 dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu
  74. +17 -0 dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu
  75. +17 -0 dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu
  76. +17 -0 dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu
  77. +17 -0 dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu
  78. +17 -0 dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu
  79. +17 -0 dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu
  80. +17 -0 dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu
  81. +17 -0 dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu
  82. +17 -0 dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu
  83. +17 -0 dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu
  84. +17 -0 dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu
  85. +17 -0 dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu
  86. +17 -0 dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu
  87. +17 -0 dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu
  88. +17 -0 dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu
  89. +17 -0 dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu
  90. +17 -0 dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu
  91. +17 -0 dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu
  92. +17 -0 dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu
  93. +17 -0 dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu
  94. +18 -0 dnn/src/cuda/elemwise/special_kimpl/special_dt_bfloat16.cu
  95. +3 -0 dnn/src/cuda/elemwise_helper.cpp
  96. +12 -0 dnn/src/cuda/elemwise_helper.cuh
  97. +5 -0 dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu
  98. +4 -0 dnn/src/cuda/matrix_mul/algos.cpp
  99. +22 -1 dnn/src/cuda/matrix_mul/algos.h
  100. +91 -0 dnn/src/cuda/matrix_mul/bfloat16.cpp

dnn/include/megdnn/dtype.h (+11, -0)

@@ -29,6 +29,7 @@
 #define MEGDNN_FLOAT16_SELECT(_x, _y) _y
 #else
 #include "megdnn/dtype/half.hpp"
+#include "megdnn/dtype/bfloat16.hpp"
 #define MEGDNN_INC_FLOAT16(_x) _x
 #define MEGDNN_FLOAT16_SELECT(_x, _y) _x
 #endif
@@ -49,6 +50,7 @@ namespace megdnn {
     cb(IntB4) \
     cb(Byte) \
     MEGDNN_INC_FLOAT16(cb(Float16)) \
+    MEGDNN_INC_FLOAT16(cb(BFloat16)) \
     cb(UintB4) \

 /*!
@@ -62,6 +64,7 @@ namespace megdnn {
     cb(Int32) \
     cb(Byte) \
     MEGDNN_INC_FLOAT16(cb(Float16)) \
+    MEGDNN_INC_FLOAT16(cb(BFloat16)) \

 /*!
  * \brief iterate through each fractional byte dtype
@@ -101,6 +104,7 @@ namespace megdnn {
 #define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \
     cb(::megdnn::dtype::Float32) \
     MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \
+    MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16)) \

 /*!
  * \brief iterate through each dtype object that can be involved in integer
@@ -345,6 +349,7 @@ typedef int16_t dt_int16;
 typedef int8_t dt_int8;
 typedef uint8_t dt_uint8;
 MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
+MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)

 #define MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE 100000
 #if MEGDNN_CC_HOST
@@ -367,6 +372,9 @@ MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
     Float16,
 #endif
     UintB4 = 10,
+#if !MEGDNN_DISABLE_FLOAT16
+    BFloat16 = 11,
+#endif

 #define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE,
 #define D(_name) _name,
@@ -702,6 +710,9 @@ MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX);
 MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED,
                                  std::numeric_limits<dt_float16>::lowest(),
                                  std::numeric_limits<dt_float16>::max()));
+MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(BFloat16, dt_bfloat16, FLOAT, SIGNED,
+                                 std::numeric_limits<dt_bfloat16>::lowest(),
+                                 std::numeric_limits<dt_bfloat16>::max()));

 template <>
 struct DTypeTrait<dtype::Byte> {


dnn/include/megdnn/dtype/bfloat16.hpp (+2965, -0)

Diff suppressed because it is too large.
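
Although the 2965-line header is suppressed here, the type it introduces is conceptually small: bfloat16 keeps float32's sign bit and 8-bit exponent and truncates the mantissa to 7 bits, so conversion is mostly bit surgery on the upper half of a float32. A minimal sketch of the two conversions, with round-to-nearest-even on narrowing (an illustration only, not the actual half_bfloat16::bfloat16 implementation, which wraps this idea in a full class with operators and numeric_limits; NaN handling omitted):

#include <cstdint>
#include <cstring>

// Hypothetical helpers illustrating the bf16 <-> fp32 relationship.
inline uint16_t float_to_bf16_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    uint32_t rounding = 0x7fffu + ((u >> 16) & 1u);  // round to nearest even
    return static_cast<uint16_t>((u + rounding) >> 16);
}

inline float bf16_bits_to_float(uint16_t b) {
    uint32_t u = static_cast<uint32_t>(b) << 16;  // widening is exact
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}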


dnn/include/megdnn/dtype/half.hpp (+2, -171)

@@ -50,167 +50,7 @@
 #include <hip/hip_fp16.h>
 #endif

-/// Combined gcc version number.
-#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)
-
-//check C++11 language features
-#if defined(__clang__) //clang
-#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
-#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
-#endif
-#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
-#define HALF_ENABLE_CPP11_CONSTEXPR 1
-#endif
-#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
-#define HALF_ENABLE_CPP11_NOEXCEPT 1
-#endif
-#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
-#define HALF_ENABLE_CPP11_USER_LITERALS 1
-#endif
-#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
-#define HALF_ENABLE_CPP11_LONG_LONG 1
-#endif
-/*#elif defined(__INTEL_COMPILER) //Intel C++
-#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ????????
-#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
-#endif
-#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ????????
-#define HALF_ENABLE_CPP11_CONSTEXPR 1
-#endif
-#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ????????
-#define HALF_ENABLE_CPP11_NOEXCEPT 1
-#endif
-#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ????????
-#define HALF_ENABLE_CPP11_LONG_LONG 1
-#endif*/
-#elif defined(__GNUC__) //gcc
-#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
-#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
-#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
-#endif
-#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
-#define HALF_ENABLE_CPP11_CONSTEXPR 1
-#endif
-#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
-#define HALF_ENABLE_CPP11_NOEXCEPT 1
-#endif
-#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
-#define HALF_ENABLE_CPP11_USER_LITERALS 1
-#endif
-#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
-#define HALF_ENABLE_CPP11_LONG_LONG 1
-#endif
-#endif
-#elif defined(_MSC_VER) //Visual C++
-#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
-#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
-#endif
-#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
-#define HALF_ENABLE_CPP11_LONG_LONG 1
-#endif
-#define HALF_POP_WARNINGS 1
-#pragma warning(push)
-//! 4521 and 4522 is multiple copy/assigment operator specified
-#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned
-#endif
-
-//check C++11 library features
-#include <utility>
-#if defined(_LIBCPP_VERSION) //libc++
-#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
-#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
-#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
-#endif
-#ifndef HALF_ENABLE_CPP11_CSTDINT
-#define HALF_ENABLE_CPP11_CSTDINT 1
-#endif
-#ifndef HALF_ENABLE_CPP11_CMATH
-#define HALF_ENABLE_CPP11_CMATH 1
-#endif
-#ifndef HALF_ENABLE_CPP11_HASH
-#define HALF_ENABLE_CPP11_HASH 1
-#endif
-#endif
-#elif defined(__GLIBCXX__) //libstdc++
-#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
-#ifdef __clang__
-#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
-#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
-#endif
-#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
-#define HALF_ENABLE_CPP11_CSTDINT 1
-#endif
-#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
-#define HALF_ENABLE_CPP11_CMATH 1
-#endif
-#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
-#define HALF_ENABLE_CPP11_HASH 1
-#endif
-#else
-#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
-#define HALF_ENABLE_CPP11_CSTDINT 1
-#endif
-#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
-#define HALF_ENABLE_CPP11_CMATH 1
-#endif
-#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
-#define HALF_ENABLE_CPP11_HASH 1
-#endif
-#endif
-#endif
-#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++
-#if _CPPLIB_VER >= 520
-#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
-#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
-#endif
-#ifndef HALF_ENABLE_CPP11_CSTDINT
-#define HALF_ENABLE_CPP11_CSTDINT 1
-#endif
-#ifndef HALF_ENABLE_CPP11_HASH
-#define HALF_ENABLE_CPP11_HASH 1
-#endif
-#endif
-#if _CPPLIB_VER >= 610
-#ifndef HALF_ENABLE_CPP11_CMATH
-#define HALF_ENABLE_CPP11_CMATH 1
-#endif
-#endif
-#endif
-#undef HALF_GNUC_VERSION
-
-//support constexpr
-#if HALF_ENABLE_CPP11_CONSTEXPR
-#define HALF_CONSTEXPR constexpr
-#define HALF_CONSTEXPR_CONST constexpr
-#else
-#define HALF_CONSTEXPR
-#define HALF_CONSTEXPR_CONST const
-#endif
-
-//support noexcept
-#if HALF_ENABLE_CPP11_NOEXCEPT
-#define HALF_NOEXCEPT noexcept
-#define HALF_NOTHROW noexcept
-#else
-#define HALF_NOEXCEPT
-#define HALF_NOTHROW throw()
-#endif
-
-#include <algorithm>
-#include <limits>
-#include <climits>
-#include <cmath>
-#include <cstring>
-#if HALF_ENABLE_CPP11_TYPE_TRAITS
-#include <type_traits>
-#endif
-#if HALF_ENABLE_CPP11_CSTDINT
-#include <cstdint>
-#endif
-#if HALF_ENABLE_CPP11_HASH
-#include <functional>
-#endif
+#include "megdnn/dtype/half_common_prologue.h"

 /// Default rounding mode.
 /// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s and `float`s as well as
@@ -3141,16 +2981,7 @@ namespace std
 #endif
 }

-#undef HALF_CONSTEXPR
-#undef HALF_CONSTEXPR_CONST
-#undef HALF_NOEXCEPT
-#undef HALF_NOTHROW
-#ifdef HALF_POP_WARNINGS
-#pragma warning(pop)
-#undef HALF_POP_WARNINGS
-#endif
-
+#include "megdnn/dtype/half_common_epilogue.h"
 #endif

 // vim: syntax=cpp.doxygen

dnn/include/megdnn/dtype/half_common_epilogue.h (+48, -0)

@@ -0,0 +1,48 @@
/**
* half - IEEE 754-based half-precision floating point library.
*
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* Version 1.11.0
* \file
* Main header file for half precision functionality.
*
* --------------------------------------------------------------------------
* \file include/megdnn/dtype/half_common_epilogue.h
*
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
* This file has been modified by Megvii ("Megvii Modifications").
* All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
*
* --------------------------------------------------------------------------
*/

#undef HALF_CONSTEXPR
#undef HALF_CONSTEXPR_CONST
#undef HALF_NOEXCEPT
#undef HALF_NOTHROW
#ifdef HALF_POP_WARNINGS
#pragma warning(pop)
#undef HALF_POP_WARNINGS
#endif

// vim: syntax=cpp.doxygen

dnn/include/megdnn/dtype/half_common_prologue.h (+202, -0)

@@ -0,0 +1,202 @@
/**
* half - IEEE 754-based half-precision floating point library.
*
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* Version 1.11.0
* \file
* Main header file for half precision functionality.
*
* --------------------------------------------------------------------------
* \file dnn/include/megdnn/dtype/half_common_prologue.h
*
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
* This file has been modified by Megvii ("Megvii Modifications").
* All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
*
* --------------------------------------------------------------------------
*/

#include "megdnn/arch.h"

/// Combined gcc version number.
#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)

//check C++11 language features
#if defined(__clang__) //clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
/*#elif defined(__INTEL_COMPILER) //Intel C++
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ????????
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ????????
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ????????
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ????????
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif*/
#elif defined(__GNUC__) //gcc
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#endif
#elif defined(_MSC_VER) //Visual C++
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#define HALF_POP_WARNINGS 1
#pragma warning(push)
//! 4521 and 4522 is multiple copy/assigment operator specified
#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned
#endif

//check C++11 library features
#include <utility>
#if defined(_LIBCPP_VERSION) //libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#elif defined(__GLIBCXX__) //libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#endif
#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++
#if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#endif
#endif
#undef HALF_GNUC_VERSION

//support constexpr
#if HALF_ENABLE_CPP11_CONSTEXPR
#define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr
#else
#define HALF_CONSTEXPR
#define HALF_CONSTEXPR_CONST const
#endif

//support noexcept
#if HALF_ENABLE_CPP11_NOEXCEPT
#define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept
#else
#define HALF_NOEXCEPT
#define HALF_NOTHROW throw()
#endif

#include <algorithm>
#include <limits>
#include <climits>
#include <cmath>
#include <cstring>
#if HALF_ENABLE_CPP11_TYPE_TRAITS
#include <type_traits>
#endif
#if HALF_ENABLE_CPP11_CSTDINT
#include <cstdint>
#endif
#if HALF_ENABLE_CPP11_HASH
#include <functional>
#endif

// vim: syntax=cpp.doxygen
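
The two new headers exist so that half.hpp and the new bfloat16.hpp can share one copy of this compiler-feature preamble and of the matching macro cleanup, instead of each carrying its own. The resulting include discipline, sketched with class bodies elided (the prologue/epilogue file names are the real ones; the class outline is an assumption):

// Structure now shared by half.hpp and bfloat16.hpp:
#include "megdnn/dtype/half_common_prologue.h"   // feature tests; defines HALF_CONSTEXPR etc.

namespace half_bfloat16 {
class bfloat16 { /* conversions, operators, numeric_limits specialization */ };
}  // namespace half_bfloat16

#include "megdnn/dtype/half_common_epilogue.h"   // #undefs the HALF_* helpers, pops warnings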

dnn/scripts/gen_cond_take_kern_impls.py (+2, -2)

@@ -30,7 +30,7 @@ def main():
         w('// generated by gen_cond_take_kern_impls.py')
         w('#include "../kern.inl"')
         w('')
-        if dtype == 'dt_float16':
+        if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
             w('#if !MEGDNN_DISABLE_FLOAT16')
         w('namespace megdnn {')
         w('namespace cuda {')
@@ -48,7 +48,7 @@ def main():
         w('} // cond_take')
         w('} // cuda')
         w('} // megdnn')
-        if dtype == 'dt_float16':
+        if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
             w('#endif')

         print('generated {}'.format(fname))


dnn/scripts/gen_elemwise_kern_impls.py (+2, -2)

@@ -34,7 +34,7 @@ def main():
         w = lambda s: print(s, file=fout)
         w('// generated by gen_elemwise_kern_impls.py')

-        if ctype == 'dt_float16':
+        if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
             w('#if !MEGDNN_DISABLE_FLOAT16')

         w('#define KERN_IMPL_MODE(cb) {}'.format(formode))
@@ -42,7 +42,7 @@ def main():
         w('#define KERN_IMPL_CTYPE {}'.format(ctype))
         w('#include "../kern_impl.inl"')

-        if ctype == 'dt_float16':
+        if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
             w('#endif')

         print('generated {}'.format(fname))


dnn/scripts/gen_elemwise_special_kern_impls.py (+2, -2)

@@ -30,14 +30,14 @@ def main():
         w = lambda s: print(s, file=fout)

         w('// generated by gen_elemwise_special_kern_impls.py')
-        if dtype == 'dt_float16':
+        if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
             w('#if !MEGDNN_DISABLE_FLOAT16')
         w('#include "../special_kerns.inl"')
         w('INST(::megdnn::dtype::{})'.format(DTYPES[dtype][0]))
         w('#undef INST')
         w('}')
         w('}')
-        if dtype == 'dt_float16':
+        if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
             w('#endif')

         print('generated {}'.format(fname))


dnn/scripts/gen_elemwise_utils.py (+2, -1)

@@ -6,7 +6,8 @@ DTYPES = {'dt_int32': ('Int32', 'INT'),
           'dt_int8': ('Int8', 'INT'),
           'dt_int16': ('Int16', 'INT'),
           'dt_float32': ('Float32', 'FLOAT'),
-          'dt_float16': ('Float16', 'FLOAT')
+          'dt_float16': ('Float16', 'FLOAT'),
+          'dt_bfloat16': ('BFloat16', 'FLOAT')
           }

MODES = {
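
With these generator changes, every *_dt_bfloat16.cu file emitted under dnn/src/cuda gets the same #if !MEGDNN_DISABLE_FLOAT16 guard that fp16 kernels use. A sketch of one generated elemwise implementation file follows; the guard, CTYPE line, and kern_impl.inl include follow directly from the scripts above, while the exact KERN_IMPL_MODE and KERN_IMPL_ARITY values are guesses:

// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb)  // hypothetical
#define KERN_IMPL_ARITY 2                                        // hypothetical
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif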


dnn/src/common/convolution.cpp (+6, -4)

@@ -618,9 +618,10 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src,
     megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32
 #if !MEGDNN_DISABLE_FLOAT16
                           || src.enumv() == DTypeEnum::Float16
+                          || src.enumv() == DTypeEnum::BFloat16
 #endif
-                  ,
-                  "ComputeMode::FLOAT32 is only available for Float16 "
+                  ,
+                  "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
                   "input / output.");
 }

@@ -1036,9 +1037,10 @@ void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff,
     megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32
 #if !MEGDNN_DISABLE_FLOAT16
                           || filter.enumv() == DTypeEnum::Float16
+                          || filter.enumv() == DTypeEnum::BFloat16
 #endif
-                  ,
-                  "ComputeMode::FLOAT32 is only available for Float16 "
+                  ,
+                  "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
                   "input / output.");
 }
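
These asserts gate ComputeMode::FLOAT32, which asks convolution to accumulate in fp32 while the tensors stay in a 16-bit type; bf16 now joins fp16 as a legal pairing. A hedged host-side sketch (the compute_mode field and FLOAT32 value follow this diff; the param type spelling and surrounding operator setup are assumptions):

// Hypothetical setup accepted by the checks above:
megdnn::param::Convolution param;
param.compute_mode = megdnn::param::Convolution::ComputeMode::FLOAT32;
// Valid only when the input/output dtype is Float16 or, after this
// commit, BFloat16; with Float32 tensors the assert fires.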




dnn/src/common/elemwise/kern_defs.cuh (+2, -1)

@@ -87,7 +87,8 @@ namespace megdnn {
 //! define kernel for all float types
 #define DEF_KERN_FLOAT(_mode, _imp) \
     DEF_KERN(dt_float32, _mode, _imp); \
-    MEGDNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);)
+    MEGDNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \
+    MEGDNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);)

 //! define kernel for all int types
 #define DEF_KERN_INT(_mode, _imp) \


dnn/src/common/matrix_mul.cpp (+13, -12)

@@ -69,11 +69,11 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A,
         C = TensorLayout(TensorShape({A0, B1}), C.dtype);
     } else {
         auto do_deduce = [&](size_t pack_size) {
-            megdnn_assert(
-                    A.ndim == 4 && B.ndim == 3,
-                    "matmul requires input dimension to be A(4), B(3); get: %s %s",
-                    A.TensorShape::to_string().c_str(),
-                    B.TensorShape::to_string().c_str());
+            megdnn_assert(A.ndim == 4 && B.ndim == 3,
+                          "matmul requires input dimension to be A(4), B(3); "
+                          "get: %s %s",
+                          A.TensorShape::to_string().c_str(),
+                          B.TensorShape::to_string().c_str());
             A0 = A.shape[0];
             A1 = A.shape[1];
             B0 = B.shape[0];
@@ -82,11 +82,11 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A,
                 std::swap(A0, A1);
             if (m_param.transposeB)
                 std::swap(B0, B1);
-            megdnn_assert(
-                    A1 == B0,
-                    "shape mismatch in matmal: (transposed) A is (%zu,%zu,4,4), "
-                    "(transposed) B is (%zu,%zu,4)",
-                    A0, A1, B0, B1);
+            megdnn_assert(A1 == B0,
+                          "shape mismatch in matmal: (transposed) A is "
+                          "(%zu,%zu,4,4), "
+                          "(transposed) B is (%zu,%zu,4)",
+                          A0, A1, B0, B1);
             C = TensorLayout(TensorShape({A0, B1, pack_size}), C.dtype);
         };
         do_deduce(pack_size(param().format));
@@ -172,8 +172,9 @@ void MatrixMulForward::check_exec(const TensorLayout& A, const TensorLayout& B,
     }
     megdnn_assert(param().compute_mode !=
                           Param::ComputeMode::FLOAT32 MEGDNN_INC_FLOAT16(
-                                  || A.dtype == dtype::Float16()),
-                  "ComputeMode::FLOAT32 is only available for Float16 "
+                                  || A.dtype == dtype::Float16() ||
+                                  A.dtype == dtype::BFloat16()),
+                  "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
                   "input / output.");
     auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);


dnn/src/common/rounding_converter.cuh (+8, -0)

@@ -46,6 +46,14 @@ struct RoundingConverter<half_float::half> {
     }
 };

+template <>
+struct RoundingConverter<half_bfloat16::bfloat16> {
+    __host__ __device__ __forceinline__ half_bfloat16::bfloat16 operator()(
+            float x) const {
+        return static_cast<half_bfloat16::bfloat16>(x);
+    }
+};
+
 #endif  // #ifdef MEGDNN_DISABLE_FLOAT16

 template <>
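
RoundingConverter<T> is what CUDA kernels call when storing a float accumulator into a narrower element type; for bfloat16 it simply defers to the class's own converting constructor. A hypothetical kernel showing the intended call pattern (only the functor specialization comes from this diff; the kernel and the enclosing namespace spelling are assumptions):

#include "src/common/rounding_converter.cuh"

// Hypothetical CUDA kernel: accumulate in float, round once at the store.
template <typename ctype>
__global__ void scaled_add_kern(const ctype* x, const ctype* y, ctype* z,
                                float alpha, int n) {
    megdnn::rounding::RoundingConverter<ctype> round_cvt;  // namespace assumed
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        float acc = alpha * static_cast<float>(x[i]) + static_cast<float>(y[i]);
        z[i] = round_cvt(acc);  // single rounding at the store
    }
}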


dnn/src/common/utils.h (+54, -0)

@@ -16,6 +16,7 @@
 #include "megdnn/dtype.h"
 #include "megdnn/handle.h"
 #include "megdnn/thin/small_vector.h"
+#include "megdnn/oprs/general.h"

 #include "src/common/hash_ct.h"
 #include "src/common/utils.cuh"
@@ -548,6 +549,59 @@ public:
     std::string to_string() const;
 };

+/**!
+ * \brief helpers for oprs using typecvt between comp_type and dst_type
+ * \tparam SrcType src type
+ * \tparam CompType compute type, such as fp32 for conv
+ * \tparam DstType dst type
+ */
+template <typename SrcType, typename CompType, typename DstType = SrcType>
+struct CompTypeCvter {
+    std::unique_ptr<TypeCvt> m_cvt_opr;
+    WorkspaceBundle* m_workspace_bundle;
+    size_t m_workspace_idx;
+    CompTypeCvter(Handle* handle, WorkspaceBundle* bundle)
+            : m_workspace_bundle(bundle), m_workspace_idx(0) {
+        megdnn_assert(
+                (DTypeTrait<SrcType>::enumv != DTypeTrait<CompType>::enumv &&
+                 DTypeTrait<DstType>::enumv != DTypeTrait<CompType>::enumv),
+                "SrcType(%s) == CompType(%s) or DstType(%s) == CompType(%s) is "
+                "not "
+                "supportted.",
+                SrcType().name(), CompType().name(), DstType().name(),
+                CompType().name());
+        m_cvt_opr = handle->create_operator<TypeCvt>();
+    }
+
+    //! Convert tensor dtype from SrcType to CompType.
+    CompTypeCvter& src_to_comp_type(const TensorND& src, TensorND& comp) {
+        if (src.layout.dtype.enumv() == DTypeTrait<SrcType>::enumv) {
+            if (!comp.layout.dtype.valid() ||
+                comp.layout.dtype.enumv() != DTypeTrait<CompType>::enumv) {
+                comp.layout.dtype = CompType();
+                comp.layout.init_contiguous_stride();
+                comp.raw_ptr = m_workspace_bundle->get(m_workspace_idx++);
+                if (src.layout.ndim) {
+                    m_cvt_opr->exec(src, comp);
+                }
+            }
+        }
+        return *this;
+    }
+
+    //! Convert tensor dtype from CompType to DstType.
+    CompTypeCvter& comp_to_dst_type(const TensorND& comp, const TensorND& dst) {
+        megdnn_assert(comp.layout.dtype.enumv() == DTypeTrait<CompType>::enumv);
+        if (dst.layout.dtype.enumv() == DTypeTrait<DstType>::enumv) {
+            m_cvt_opr->exec(comp, dst);
+        }
+        return *this;
+    }
+
+    Workspace workspace() {
+        return m_workspace_bundle->get_workspace(m_workspace_idx);
+    }
+};
 } // namespace megdnn

 // vim: syntax=cpp.doxygen
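
CompTypeCvter is the engine behind every AlgoBFloat16 wrapper added later in this commit: convert inputs from SrcType to CompType into workspace memory, hand the remaining workspace to the real implementation, then convert the result back. A condensed usage sketch, mirroring the exec() bodies below (tensor and bundle construction elided; the free function is hypothetical):

using namespace megdnn;

void run_via_fp32(Handle* handle, WorkspaceBundle* bundle,
                  const TensorND& src_bf16, const TensorND& dst_bf16) {
    TensorND fsrc = src_bf16, fdst = dst_bf16;
    CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(handle, bundle);
    cvter.src_to_comp_type(src_bf16, fsrc)   // bf16 -> fp32, workspace-backed
         .src_to_comp_type(dst_bf16, fdst);
    // ... run the fp32 operator on fsrc/fdst, giving it cvter.workspace() ...
    cvter.comp_to_dst_type(fdst, dst_bf16);  // fp32 result -> bf16
}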

dnn/src/common/warp_perspective.cpp (+41, -36)

@@ -55,17 +55,19 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
     megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str());

     if (param().format == param::WarpPerspective::Format::NCHW) {
-        megdnn_assert(src.dtype.enumv() == DTypeEnum::Float32 ||
-                              MEGDNN_FLOAT16_SELECT(
-                                      src.dtype.enumv() == DTypeEnum::Float16,
-                                      false) ||
-                              src.dtype.enumv() == DTypeEnum::Int8 ||
-                              src.dtype.enumv() == DTypeEnum::Uint8 ||
-                              (src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                               src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
-                      "WarpPerspective NCHW input dtype should be "
-                      "Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
-                              "/Float16", "") ".");
+        megdnn_assert(
+                src.dtype.enumv() == DTypeEnum::Float32 ||
+                        MEGDNN_FLOAT16_SELECT(
+                                (src.dtype.enumv() == DTypeEnum::Float16 ||
+                                 src.dtype.enumv() == DTypeEnum::BFloat16),
+                                false) ||
+                        src.dtype.enumv() == DTypeEnum::Int8 ||
+                        src.dtype.enumv() == DTypeEnum::Uint8 ||
+                        (src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
+                         src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
+                "WarpPerspective NCHW input dtype should be "
+                "Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
+                        "/Float16/BFloat16", "") ".");
         megdnn_assert(
                 (src.dtype.category() == DTypeCategory::FLOAT &&
                  (src.dtype == mat.dtype ||
@@ -107,14 +109,17 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
                       param::WarpPerspective::BorderMode::ISOLATED);
     } else {
         megdnn_assert(param().format == param::WarpPerspective::Format::NHWCD4);
-        megdnn_assert(src.dtype == dtype::Float32() ||
-                              MEGDNN_FLOAT16_SELECT(
-                                      src.dtype == dtype::Float16(), false) ||
-                              src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                              src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
-                      "WarpPerspective NHWCD4 input dtype should be "
-                      "Float32" MEGDNN_FLOAT16_SELECT(
-                              "/Float16", "") ",QunatizedS8, Quantized8Asymm.");
+        megdnn_assert(
+                src.dtype == dtype::Float32() ||
+                        MEGDNN_FLOAT16_SELECT((src.dtype == dtype::Float16() ||
+                                               src.dtype == dtype::BFloat16()),
+                                              false) ||
+                        src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
+                        src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
+                "WarpPerspective NHWCD4 input dtype should be "
+                "Float32" MEGDNN_FLOAT16_SELECT(
+                        "/Float16/BFloat16",
+                        "") ",QunatizedS8, Quantized8Asymm.");
         megdnn_assert(
                 (src.dtype == mat.dtype || mat.dtype == dtype::Float32()),
                 "The input to WarpPerspective is in NHWCD4 format, in this "
@@ -253,30 +258,30 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
     }
 }

-void WarpPerspectiveBackwardData::check_exec(const TensorLayout &mat,
-                                             const TensorLayout &diff,
-                                             const TensorLayout &grad,
-                                             size_t workspace_in_bytes)
-{
+void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat,
+                                             const TensorLayout& diff,
+                                             const TensorLayout& grad,
+                                             size_t workspace_in_bytes) {
     check_layout_fwd(grad, mat, diff);
-    megdnn_assert(grad.dtype == dtype::Float32(),
-                  "Backward WarpPerspective only supports Float32.");
+    megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
+                          || grad.dtype == dtype::BFloat16()),
+                  "Backward WarpPerspective only supports Float32/BFloat16.");
     auto required_workspace_in_bytes = get_workspace_in_bytes(mat, diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
 }

-void WarpPerspectiveBackwardMat::check_exec(const TensorLayout &src,
-                                            const TensorLayout &mat,
-                                            const TensorLayout &diff,
-                                            const TensorLayout &grad,
-                                            size_t workspace_in_bytes)
-{
+void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src,
+                                            const TensorLayout& mat,
+                                            const TensorLayout& diff,
+                                            const TensorLayout& grad,
+                                            size_t workspace_in_bytes) {
     check_layout_fwd(src, mat, diff);
     megdnn_assert_eq_layout(mat, grad);
-    megdnn_assert(grad.dtype == dtype::Float32(),
-                  "Backward WarpPerspective only supports Float32.");
-    auto required_workspace_in_bytes = get_workspace_in_bytes(src,
-            mat, diff, grad);
+    megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
+                          || grad.dtype == dtype::BFloat16()),
+                  "Backward WarpPerspective only supports Float32/BFloat16.");
+    auto required_workspace_in_bytes =
+            get_workspace_in_bytes(src, mat, diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
 }
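
The pattern used throughout this file relies on MEGDNN_FLOAT16_SELECT choosing its first argument when 16-bit floats are compiled in and its second otherwise, which is why both the dtype predicates and the error-message string literals are wrapped in it. Its two expansions, restated from the dtype.h hunk at the top of this commit (the guard condition spelling is an assumption):

#if MEGDNN_DISABLE_FLOAT16
#define MEGDNN_FLOAT16_SELECT(_x, _y) _y  // fp16/bf16 stripped: dtype checks collapse to `false`
#else
#define MEGDNN_FLOAT16_SELECT(_x, _y) _x  // fp16/bf16 available
#endif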




dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu (+29, -0)

@@ -0,0 +1,29 @@
/**
* \file dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_cond_take_kern_impls.py
#include "../kern.inl"

#if !MEGDNN_DISABLE_FLOAT16
namespace megdnn {
namespace cuda {
namespace cond_take {

inst_genidx(::megdnn::dtype::BFloat16)
#undef inst_genidx

inst_copy(::megdnn::dtype::BFloat16)
#undef inst_copy
#undef inst_copy_

} // cond_take
} // cuda
} // megdnn
#endif

dnn/src/cuda/conv_bias/algo.cpp (+7, -0)

@@ -62,6 +62,13 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
     non_cudnn_algos.push_back(all_algos.rbegin()[1]);  // group batched_matmul
     non_cudnn_algos.push_back(all_algos.rbegin()[0]);  // group 1x1

+    algo_size = all_algos.size();
+    for (size_t i = 0; i < algo_size; ++i) {
+        bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
+        all_algos.push_back(bfloat16_refhold.back().get());
+        bfloat16_algos.push_back(bfloat16_refhold.back().get());
+    }
+
     size_t all_algo_size = all_algos.size();
 #if CUDA_VERSION >= 10000
     fill_imma_algos();
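
The registration loop above also fixes ownership: bfloat16_refhold (a vector of unique_ptr, declared in AlgoPack below) owns each wrapper, while all_algos and bfloat16_algos keep non-owning raw pointers for dispatch. The same idiom in miniature, standalone and with names simplified:

#include <memory>
#include <vector>

struct AlgoBase { virtual ~AlgoBase() = default; };
struct AlgoBFloat16 : AlgoBase { explicit AlgoBFloat16(AlgoBase*) {} };

std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;  // owns wrappers
std::vector<AlgoBase*> all_algos, bfloat16_algos;             // non-owning views

void wrap_existing_algos() {
    size_t n = all_algos.size();  // snapshot: wrap only pre-existing algos
    for (size_t i = 0; i < n; ++i) {
        bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
        all_algos.push_back(bfloat16_refhold.back().get());
        bfloat16_algos.push_back(bfloat16_refhold.back().get());
    }
}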


dnn/src/cuda/conv_bias/algo.h (+25, -1)

@@ -499,6 +499,28 @@ private:
 };
 #endif

+class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase {
+public:
+    AlgoBFloat16(AlgoBase* impl);
+
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    void exec(const ExecArgs& args) const override;
+
+    const char* name() const override { return m_name.c_str(); }
+
+    bool is_reproducible() const override { return m_impl->is_reproducible(); }
+
+private:
+    SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr,
+                        TensorLayout& fsrc, TensorLayout& ffilter,
+                        TensorLayout& fbias, TensorLayout& fz,
+                        TensorLayout& fdst) const;
+    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
+    AlgoBase* m_impl;
+    std::string m_name;
+};
+
 class ConvBiasForwardImpl::AlgoPack {
     AlgoPack(const AlgoPack&) = delete;
     AlgoPack& operator=(const AlgoPack&) = delete;
@@ -508,7 +530,8 @@ public:

     std::vector<AlgoBase*> all_algos,
             //! non-cudnn algos, used for heuristic if cudnn is not supported
-            non_cudnn_algos;
+            non_cudnn_algos,
+            bfloat16_algos;
     std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations;
     std::vector<AlgoCUDNNConv> cudnn_convs;
     AlgoChanwise chanwise;
@@ -531,6 +554,7 @@ public:
             int8_chwn4_imma_unroll_width;
 #endif
     std::vector<std::unique_ptr<AlgoGroupConvGeneral>> gconv_refhold;
+    std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
     std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;

     AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);


dnn/src/cuda/conv_bias/bfloat16.cpp (+120, -0)

@@ -0,0 +1,120 @@
/**
* \file dnn/src/cuda/conv_bias/bfloat16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace conv_bias;

ConvBiasForwardImpl::AlgoBFloat16::AlgoBFloat16(
ConvBiasForwardImpl::AlgoBase* algorithm)
: m_impl(algorithm) {
megdnn_assert_internal(algorithm);
m_name = ssprintf("BFLOAT16:%s", m_impl->name());
}

ConvBiasForwardImpl::AlgoBase::SizeArgs
ConvBiasForwardImpl::AlgoBFloat16::float_args(
const SizeArgs& args, ConvBiasForwardImpl* opr, TensorLayout& fsrc,
TensorLayout& ffilter, TensorLayout& fbias, TensorLayout& fz,
TensorLayout& fdst) const {
fsrc = *args.src_layout;
ffilter = *args.filter_layout;
fbias = *args.bias_layout;
fz = *args.z_layout;
fdst = *args.dst_layout;
auto change_dtype = [](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32();
}
};
change_dtype(fsrc);
change_dtype(ffilter);
change_dtype(fbias);
change_dtype(fz);
change_dtype(fdst);
opr->param() = args.opr->param();
opr->param().compute_mode = Param::ComputeMode::DEFAULT;
opr->execution_policy() = {m_impl};
return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst);
}

bool ConvBiasForwardImpl::AlgoBFloat16::is_available(
const SizeArgs& args) const {
TensorLayout fsrc, ffilter, fbias, fz, fdst;
auto convbias_opr = args.handle->create_operator<ConvBias>();
SizeArgs fargs = float_args(
args, static_cast<ConvBiasForwardImpl*>(convbias_opr.get()), fsrc,
ffilter, fbias, fz, fdst);
return args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16() &&
m_impl->is_available(fargs);
}

WorkspaceBundle ConvBiasForwardImpl::AlgoBFloat16::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
TensorLayout fsrc, ffilter, fbias, fz, fdst;
auto convbias_opr = args.handle->create_operator<ConvBias>();
SizeArgs fargs = float_args(
args, static_cast<ConvBiasForwardImpl*>(convbias_opr.get()), fsrc,
ffilter, fbias, fz, fdst);
SmallVector<size_t> sizes;
auto get_workspace = [&sizes](const TensorLayout& src,
const TensorLayout& dst) {
if (src.dtype != dst.dtype) {
sizes.push_back(dst.span().dist_byte());
}
};
get_workspace(*args.src_layout, fsrc);
get_workspace(*args.filter_layout, ffilter);
get_workspace(*args.bias_layout, fbias);
get_workspace(*args.z_layout, fz);
get_workspace(*args.dst_layout, fdst);
sizes.push_back(m_impl->get_workspace_in_bytes(fargs));
return {ptr, std::move(sizes)};
}

size_t ConvBiasForwardImpl::AlgoBFloat16::get_workspace_in_bytes(
const SizeArgs& args) const {
return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}

void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
TensorND fsrc_tensor = *args.src_tensor;
TensorND ffilter_tensor = *args.filter_tensor;
TensorND fbias_tensor = *args.bias_tensor;
TensorND fz_tensor = *args.z_tensor;
TensorND fdst_tensor = *args.dst_tensor;
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle);
{
cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor)
.src_to_comp_type(*args.filter_tensor, ffilter_tensor)
.src_to_comp_type(*args.bias_tensor, fbias_tensor)
.src_to_comp_type(*args.z_tensor, fz_tensor)
.src_to_comp_type(*args.dst_tensor, fdst_tensor);
}
{
auto convbias_opr = args.handle->create_operator<ConvBias>();
convbias_opr->param() = args.opr->param();
convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
convbias_opr->execution_policy() = {m_impl};
convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor,
fdst_tensor, cvter.workspace());
}
{ cvter.comp_to_dst_type(fdst_tensor, *args.dst_tensor); }
}

// vim: syntax=cpp.doxygen
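
Worth spelling out: no kernel in this commit computes in bfloat16. AlgoBFloat16 is a decorator that re-types every bf16 tensor to fp32 in workspace memory, runs the wrapped fp32 algorithm with ComputeMode::DEFAULT, and converts only the final dst back, paying one TypeCvt per tensor. The ssprintf in the constructor also means these wrappers show up in algorithm listings with a "BFLOAT16:" prefix. The decorator shape in a standalone, hedged sketch independent of megdnn types:

#include <string>

struct Algo {
    virtual ~Algo() = default;
    virtual void run_fp32() const = 0;  // stands in for exec() on fp32 data
    virtual const char* name() const = 0;
};

// Decorator: dtype conversion wrapped around an existing fp32 algorithm.
class BFloat16Algo final : public Algo {
public:
    explicit BFloat16Algo(Algo* impl)
            : m_impl(impl), m_name(std::string("BFLOAT16:") + impl->name()) {}
    void run_fp32() const override {
        // convert bf16 inputs -> fp32, delegate, convert the result back
        m_impl->run_fp32();
    }
    const char* name() const override { return m_name.c_str(); }
private:
    Algo* m_impl;
    std::string m_name;
};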

dnn/src/cuda/conv_bias/chanwise.cpp (+4, -0)

@@ -20,6 +20,10 @@ using namespace conv_bias;

 bool ConvBiasForwardImpl::AlgoChanwise::is_available(
         const SizeArgs& args) const {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     if (args.z_layout->ndim > 0)
         return false;
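
This four-line guard reappears verbatim in each of the conv_bias algorithms below: the raw algorithms must refuse pure-bf16 arguments so that the only route for bf16 tensors is the AlgoBFloat16 wrapper, which re-presents them as fp32. The repeated check could be factored into a predicate; a hypothetical helper (the commit deliberately inlines it instead):

static inline bool args_are_pure_bfloat16(const megdnn::TensorLayout* src,
                                          const megdnn::TensorLayout* filter) {
    return src->dtype == filter->dtype &&
           src->dtype == megdnn::dtype::BFloat16();
}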




dnn/src/cuda/conv_bias/chanwise_small.cpp (+4, -0)

@@ -30,6 +30,10 @@ inline bool is_available_small(const chanwise::Param& param) {

 bool ConvBiasForwardImpl::AlgoChanwiseSmall::is_available(
         const SizeArgs& args) const {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     if (args.z_layout->ndim > 0)
         return false;
 #if CUDA_VERSION < 9000


dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp (+4, -0)

@@ -23,6 +23,10 @@ using namespace conv_bias;

 bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
         const SizeArgs& args) const {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     if (args.bias_layout->ndim == 0 ||
         args.bias_layout->eq_shape(*args.dst_layout))
         return false;


dnn/src/cuda/conv_bias/group_conv.cpp (+4, -0)

@@ -50,6 +50,10 @@ ConvBiasForwardImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(AlgoBase* impl)

 bool ConvBiasForwardImpl::AlgoGroupConvGeneral::is_available(
         const SizeArgs& args) const {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     if (args.z_layout->ndim > 0 || args.filter_meta.group <= 1)
         return false;
     auto&& param = args.opr->param();


dnn/src/cuda/conv_bias/helper.cpp (+5, -0)

@@ -136,6 +136,11 @@ void ConvBiasDesc::set_conv(DType data_type, const param::ConvBias& param,
 namespace conv_bias {

 bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
+
     // CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN
     // on Tegra K1.
     if (args.handle->is_tegra_k1())


dnn/src/cuda/conv_bias/matmul.cpp (+4, -0)

@@ -20,6 +20,10 @@ using namespace cuda;
 using namespace conv_bias;

 bool ConvBiasForwardImpl::AlgoMatmul::is_available(const SizeArgs& args) const {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     if (args.z_layout->ndim > 0)
         return false;




dnn/src/cuda/conv_bias/opr_impl.cpp (+20, -7)

@@ -9,6 +9,7 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 #include "src/cuda/conv_bias/opr_impl.h"
+#include "megdnn/dtype.h"
 #include "src/cuda/conv_bias/helper.h"
 #include "src/cuda/conv_bias/algo.h"
 #include "src/cuda/handle.h"
@@ -176,14 +177,26 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
         conv_args = orig_args;
     }

-    if (reproducible) {
-        return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
-                sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
-                "cuda convbias fwd");
+    if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
+        if (reproducible) {
+            return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
+                    sm_algo_pack.non_cudnn_algos, args,
+                    workspace_limit_in_bytes, "cuda convbias fwd");
+        } else {
+            return megdnn::get_usable_algo<ConvBiasForwardImpl>(
+                    sm_algo_pack.non_cudnn_algos, args,
+                    workspace_limit_in_bytes, "cuda convbias fwd");
+        }
     } else {
-        return megdnn::get_usable_algo<ConvBiasForwardImpl>(
-                sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
-                "cuda convbias fwd");
+        if (reproducible) {
+            return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
+                    sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
+                    "cuda convbias fwd");
+        } else {
+            return megdnn::get_usable_algo<ConvBiasForwardImpl>(
+                    sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
+                    "cuda convbias fwd");
+        }
     }
 }




dnn/src/cuda/conv_bias/opr_impl.h (+1, -0)

@@ -57,6 +57,7 @@ public:
     class AlgoInt8NCHW4IMMAImplicitGemm;
     class AlgoInt8CHWN4IMMAImplicitGemmReorderFilter;
     class AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth;
+    class AlgoBFloat16;

     class AlgoPack;




dnn/src/cuda/convolution/backward_data/algo.cpp (+15, -7)

@@ -33,11 +33,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {

     // add gconv algos by AlgoGroupConvGeneral
     auto all_algos_data = all_algos.data();
-    for (size_t i = 2; i < all_algos.size(); ++ i) {
+    size_t group_algo_start = 2;
+    for (size_t i = group_algo_start; i < all_algos.size(); ++ i) {
         gconv.push_back({all_algos[i]});
     }
-    for (size_t i = 2; i < all_algos.size(); ++ i) {
-        algo2gconv[all_algos[i]] = &gconv[i - 2];
+    for (size_t i = group_algo_start; i < all_algos.size(); ++ i) {
+        algo2gconv[all_algos[i]] = &gconv[i - group_algo_start];
     }
     for (auto &&i: gconv) {
         all_algos.push_back(&i);
@@ -45,6 +46,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
     megdnn_assert(all_algos_data == all_algos.data());

     non_cudnn_algos.push_back(all_algos.rbegin()[0]);  // group matmul
+    size_t algo_size = all_algos.size();
+    for (size_t i=0; i<algo_size; ++i) {
+        bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
+        all_algos.push_back(bfloat16_refhold.back().get());
+        bfloat16_algos.push_back(bfloat16_refhold.back().get());
+    }
 }

 ConvolutionBackwardDataImpl::AlgoCUDNN*
@@ -65,18 +72,19 @@ ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
         ConvolutionBackwardDataImpl *o,
         const TensorLayout &filter, const TensorLayout &diff,
         const TensorLayout &grad):
-    SizeArgs(o, o->check_layout_fwd(grad, filter, diff), diff, grad)
+    SizeArgs(o, filter, o->check_layout_fwd(grad, filter, diff), diff, grad)
 {
 }

 ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
-        ConvolutionBackwardDataImpl *o,
-        const CanonizedFilterMeta &filter, const TensorLayout &diff,
+        ConvolutionBackwardDataImpl *o, const TensorLayout& filter,
+        const CanonizedFilterMeta &filter_meta, const TensorLayout &diff,
         const TensorLayout &grad):
     handle{concrete_handle(o->handle())},
-    filter_meta{filter},
+    filter_meta{filter_meta},
     diff_layout{&diff},
     grad_layout{&grad},
+    filter_layout{&filter},
     opr{o}
 {
 }


+ 32
- 9
dnn/src/cuda/convolution/backward_data/algo.h View File

@@ -31,22 +31,24 @@ class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm {
     struct SizeArgs {
         HandleImpl *handle;
         CanonizedFilterMeta filter_meta;
-        const TensorLayout *diff_layout, *grad_layout;
+        const TensorLayout *diff_layout, *grad_layout, *filter_layout;
         ConvolutionBackwardDataImpl *opr;

         std::string to_string() const;
         void init_desc(convolution::CUDNNBwdDataDescs &desc) const {
             desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
         }
-        SizeArgs(ConvolutionBackwardDataImpl *opr,
-                 const TensorLayout &filter, const TensorLayout &diff,
-                 const TensorLayout &grad);
-        SizeArgs(ConvolutionBackwardDataImpl *opr,
-                 const CanonizedFilterMeta &filter, const TensorLayout &diff,
-                 const TensorLayout &grad);
+        SizeArgs(ConvolutionBackwardDataImpl* opr,
+                 const TensorLayout& filter, const TensorLayout& diff,
+                 const TensorLayout& grad);
+        SizeArgs(ConvolutionBackwardDataImpl* opr,
+                 const TensorLayout& filter,
+                 const CanonizedFilterMeta& filter_meta,
+                 const TensorLayout& diff, const TensorLayout& grad);

         convolution::ForwardSizeArgs as_fwd_args() const {
-            return {handle, grad_layout, filter_meta, diff_layout};
+            return {handle, grad_layout, filter_layout, filter_meta,
+                    diff_layout};
         }
     };
     struct ExecArgs: public SizeArgs {
@@ -170,6 +172,25 @@ class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final: public AlgoBase {
     }
 };


class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
public:
AlgoBFloat16(ConvolutionBackwardDataImpl::AlgoBase*);
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;

const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }

private:
std::string m_name;
ConvolutionBackwardDataImpl::AlgoBase* m_algorithm = nullptr;
    SizeArgs float_args(const SizeArgs& args, ConvolutionBackwardDataImpl* opr,
                        TensorLayout& ffilter, TensorLayout& fdiff,
                        TensorLayout& fgrad) const;
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

//! implement group conv by another algo
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase {
    AlgoBase *m_impl;
@@ -210,12 +231,14 @@ class ConvolutionBackwardDataImpl::AlgoPack {
     AlgoChanwiseSmall chanwise_small;
     std::vector<AlgoGroupConvGeneral> gconv;
     std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
+    std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;

     std::vector<AlgoBase*>
             //! all algorithms
             all_algos,
             //! non-cudnn algos, used for heuristic if cudnn is not supported
-            non_cudnn_algos;
+            non_cudnn_algos,
+            bfloat16_algos;

     AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);
 };


+ 115
- 0
dnn/src/cuda/convolution/backward_data/bfloat16.cpp

@@ -0,0 +1,115 @@
/**
* \file dnn/src/cuda/convolution/backward_data/bfloat16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "./algo.h"
#include "src/cuda/convolution/chanwise/kern.cuh"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace convolution;

ConvolutionBackwardDataImpl::AlgoBFloat16::AlgoBFloat16(
ConvolutionBackwardDataImpl::AlgoBase* algorithm)
: m_algorithm(algorithm) {
megdnn_assert_internal(algorithm);
m_name = ssprintf("CONVOLUTION_BACKWARD_DATD_BFLOAT16:%s",
m_algorithm->name());
}

ConvolutionBackwardDataImpl::AlgoBase::SizeArgs
ConvolutionBackwardDataImpl::AlgoBFloat16::float_args(
const SizeArgs& args, ConvolutionBackwardDataImpl* opr,
TensorLayout& ffilter, TensorLayout& fdiff, TensorLayout& fgrad) const {
ffilter = *args.filter_layout;
fdiff = *args.diff_layout;
fgrad = *args.grad_layout;
auto change_dtype = [](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32();
}
};
change_dtype(ffilter);
change_dtype(fdiff);
change_dtype(fgrad);
opr->param() = args.opr->param();
opr->param().compute_mode = Param::ComputeMode::DEFAULT;
opr->execution_policy() = {m_algorithm};
return SizeArgs(opr, ffilter, fdiff, fgrad);
}

bool ConvolutionBackwardDataImpl::AlgoBFloat16::is_available(
const SizeArgs& args) const {
TensorLayout ffilter, fdiff, fgrad;
auto conv_back_data_opr =
args.handle->create_operator<ConvolutionBackwardData>();
SizeArgs fargs = float_args(
args,
static_cast<ConvolutionBackwardDataImpl*>(conv_back_data_opr.get()),
ffilter, fdiff, fgrad);
return args.diff_layout->dtype == args.filter_layout->dtype &&
args.diff_layout->dtype == dtype::BFloat16() &&
m_algorithm->is_available(fargs);
}

WorkspaceBundle ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
TensorLayout ffilter, fdiff, fgrad;
auto conv_back_data_opr =
args.handle->create_operator<ConvolutionBackwardData>();
SizeArgs fargs = float_args(
args,
static_cast<ConvolutionBackwardDataImpl*>(conv_back_data_opr.get()),
ffilter, fdiff, fgrad);
SmallVector<size_t> sizes;
auto get_workspace = [&sizes](const TensorLayout& src,
const TensorLayout& dst) {
if (src.dtype != dst.dtype) {
sizes.push_back(dst.span().dist_byte());
}
};
get_workspace(*args.filter_layout, ffilter);
get_workspace(*args.diff_layout, fdiff);
get_workspace(*args.grad_layout, fgrad);
sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
return {ptr, std::move(sizes)};
}

size_t ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_in_bytes(
const SizeArgs& args) const {
return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}
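get_workspace_bundle reserves one float32 staging buffer for each tensor whose dtype actually changes under change_dtype, plus whatever the wrapped float32 algorithm itself needs, and total_size_in_bytes() sums them. As a rough feel for the staging overhead (the shape below is made up for illustration, not taken from this diff):

// Staging cost of widening one BFloat16 tensor to float32,
// assuming an NCHW gradient of shape (32, 64, 56, 56).
size_t elems = 32u * 64u * 56u * 56u;       // 6,422,528 elements
size_t fp32_bytes = elems * sizeof(float);  // 25,690,112 bytes (~24.5 MiB)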

void ConvolutionBackwardDataImpl::AlgoBFloat16::exec(
const ExecArgs& args) const {
TensorND ffilter_tensor = *args.filter_tensor;
TensorND fdiff_tensor = *args.diff_tensor;
TensorND fgrad_tensor = *args.grad_tensor;
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle);
{
cvter.src_to_comp_type(*args.filter_tensor, ffilter_tensor)
.src_to_comp_type(*args.diff_tensor, fdiff_tensor)
.src_to_comp_type(*args.grad_tensor, fgrad_tensor);
}
{
auto conv_back_data_opr =
args.handle->create_operator<ConvolutionBackwardData>();
conv_back_data_opr->param() = args.opr->param();
conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
conv_back_data_opr->execution_policy() = {m_algorithm};
conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor,
cvter.workspace());
}
{ cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); }
}

// vim: syntax=cpp.doxygen
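The exec path is a plain round trip: operands are widened to float32, the wrapped algorithm runs entirely in float32, and only the result is narrowed back into the caller's bfloat16 tensor. Widening is exact, because a bfloat16 value is a float32 value with the low 16 mantissa bits zeroed; all rounding happens on the way back. A self-contained illustration of that bit-level relationship (not MegDNN's converter; NaN handling omitted for brevity):

#include <cstdint>
#include <cstring>

// bfloat16 keeps the top 16 bits of an IEEE-754 float32:
// 1 sign bit, 8 exponent bits, 7 mantissa bits.
float bf16_bits_to_f32(uint16_t b) {
    uint32_t u = static_cast<uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;  // exact: every bf16 value is already a float32 value
}

uint16_t f32_to_bf16_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    u += 0x7FFFu + ((u >> 16) & 1u);  // round to nearest, ties to even
    return static_cast<uint16_t>(u >> 16);
}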

+ 4
- 0
dnn/src/cuda/convolution/backward_data/chanwise.cpp

@@ -19,6 +19,10 @@ using namespace convolution;


 bool ConvolutionBackwardDataImpl::AlgoChanwise::is_available(
         const SizeArgs& args) const {
+    if (args.diff_layout->dtype == args.filter_layout->dtype &&
+        args.diff_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     auto&& fm = args.filter_meta;
     return args.filter_meta.format == Param::Format::NCHW &&
            args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&


+ 4
- 0
dnn/src/cuda/convolution/backward_data/chanwise_small.cpp

@@ -29,6 +29,10 @@ inline bool is_available_small(const chanwise::Param& param) {


 bool ConvolutionBackwardDataImpl::AlgoChanwiseSmall::is_available(
         const SizeArgs &args) const {
+    if (args.diff_layout->dtype == args.filter_layout->dtype &&
+        args.diff_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
 #if CUDA_VERSION < 9000
     if (args.diff_layout->dtype.enumv() == DTypeEnum::Float16)
         return false;


+ 4
- 0
dnn/src/cuda/convolution/backward_data/group_conv.cpp

@@ -38,6 +38,10 @@ ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(


 bool ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::is_available(
         const SizeArgs &args) const {
+    if (args.diff_layout->dtype == args.filter_layout->dtype &&
+        args.diff_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     auto sub_args = args;
     TensorLayout diff_pg, grad_pg;
     modify_size_args(sub_args, diff_pg, grad_pg);


+ 4
- 0
dnn/src/cuda/convolution/backward_data/matmul.cpp

@@ -20,6 +20,10 @@ using namespace cuda;


 bool ConvolutionBackwardDataImpl::AlgoMatmul::is_available(
         const SizeArgs &args) const {
+    if (args.diff_layout->dtype == args.filter_layout->dtype &&
+        args.diff_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     auto &&fm = args.filter_meta;
     return args.filter_meta.format == Param::Format::NCHW &&
            args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&


+ 16
- 11
dnn/src/cuda/convolution/backward_filter/algo.cpp

@@ -43,6 +43,12 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() {
     megdnn_assert(all_algos_data == all_algos.data());

     non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group matmul
+    size_t algo_size = all_algos.size();
+    for (size_t i = 0; i < algo_size; ++i) {
+        bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
+        all_algos.push_back(bfloat16_refhold.back().get());
+        bfloat16_algos.push_back(bfloat16_refhold.back().get());
+    }
 }


ConvolutionBackwardFilterImpl::AlgoCUDNN*
@@ -64,21 +70,20 @@ ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs(
         ConvolutionBackwardFilterImpl *o,
         const TensorLayout &src, const TensorLayout &diff,
         const TensorLayout &grad):
-    SizeArgs(o, src, diff, o->check_layout_fwd(src, grad, diff))
+    SizeArgs(o, src, diff, grad, o->check_layout_fwd(src, grad, diff))
 {
 }

 ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs(
-        ConvolutionBackwardFilterImpl *o,
-        const TensorLayout &src, const TensorLayout &diff,
-        const CanonizedFilterMeta &grad):
-    handle{concrete_handle(o->handle())},
-    src_layout{&src},
-    diff_layout{&diff},
-    grad_filter_meta{grad},
-    opr{o}
-{
-}
+        ConvolutionBackwardFilterImpl* o, const TensorLayout& src,
+        const TensorLayout& diff, const TensorLayout& grad,
+        const CanonizedFilterMeta& grad_meta)
+        : handle{concrete_handle(o->handle())},
+          src_layout{&src},
+          diff_layout{&diff},
+          grad_layout{&grad},
+          grad_filter_meta{grad_meta},
+          opr{o} {}

 ConvolutionBackwardFilterImpl::AlgoBase::ExecArgs::ExecArgs(
         ConvolutionBackwardFilterImpl *opr,


+ 29
- 6
dnn/src/cuda/convolution/backward_filter/algo.h

@@ -30,7 +30,7 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm {
 public:
     struct SizeArgs {
         HandleImpl *handle;
-        const TensorLayout *src_layout, *diff_layout;
+        const TensorLayout *src_layout, *diff_layout, *grad_layout;
         CanonizedFilterMeta grad_filter_meta;
         ConvolutionBackwardFilterImpl *opr;


@@ -42,12 +42,14 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm {
         SizeArgs(ConvolutionBackwardFilterImpl *opr,
                  const TensorLayout &src, const TensorLayout &diff,
                  const TensorLayout &grad);
-        SizeArgs(ConvolutionBackwardFilterImpl *opr,
-                 const TensorLayout &src, const TensorLayout &diff,
-                 const CanonizedFilterMeta &grad);
+        SizeArgs(ConvolutionBackwardFilterImpl* opr,
+                 const TensorLayout& src, const TensorLayout& diff,
+                 const TensorLayout& grad,
+                 const CanonizedFilterMeta& grad_meta);

         convolution::ForwardSizeArgs as_fwd_args() const {
-            return {handle, src_layout, grad_filter_meta, diff_layout};
+            return {handle, src_layout, grad_layout, grad_filter_meta,
+                    diff_layout};
         }
     };
     struct ExecArgs: public SizeArgs {
@@ -157,6 +159,25 @@ class ConvolutionBackwardFilterImpl::AlgoChanwise final: public AlgoBase {
     }
 };


class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase {
public:
AlgoBFloat16(ConvolutionBackwardFilterImpl::AlgoBase*);
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;

const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }

private:
std::string m_name;
ConvolutionBackwardFilterImpl::AlgoBase* m_algorithm = nullptr;
    SizeArgs float_args(const SizeArgs& args,
                        ConvolutionBackwardFilterImpl* opr, TensorLayout& fsrc,
                        TensorLayout& fdiff, TensorLayout& fgrad) const;
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};

//! implement group conv by another algo
class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase {
    AlgoBase *m_impl;
@@ -196,12 +217,14 @@ class ConvolutionBackwardFilterImpl::AlgoPack {
     AlgoChanwise chanwise;
     std::vector<AlgoGroupConvGeneral> gconv;
     std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
+    std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;

     std::vector<AlgoBase*>
             //! all algorithms
             all_algos,
             //! non-cudnn algos, used for heuristic if cudnn is not supported
-            non_cudnn_algos;
+            non_cudnn_algos,
+            bfloat16_algos;

     AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo);
 };


+ 117
- 0
dnn/src/cuda/convolution/backward_filter/bfloat16.cpp

@@ -0,0 +1,117 @@
/**
* \file dnn/src/cuda/convolution/backward_filter/bfloat16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "./algo.h"
#include "src/cuda/convolution/chanwise/kern.cuh"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace convolution;

ConvolutionBackwardFilterImpl::AlgoBFloat16::AlgoBFloat16(
ConvolutionBackwardFilterImpl::AlgoBase* algorithm)
: m_algorithm(algorithm) {
megdnn_assert_internal(algorithm);
m_name = ssprintf("CONVOLUTION_BACKWARD_Filter_BFLOAT16:%s",
m_algorithm->name());
}

ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs
ConvolutionBackwardFilterImpl::AlgoBFloat16::float_args(
const SizeArgs& args, ConvolutionBackwardFilterImpl* opr,
TensorLayout& fsrc, TensorLayout& fdiff, TensorLayout& fgrad) const {
fsrc = *args.src_layout;
fdiff = *args.diff_layout;
fgrad = *args.grad_layout;
auto change_dtype = [](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32();
}
};
change_dtype(fsrc);
change_dtype(fdiff);
change_dtype(fgrad);
opr->param() = args.opr->param();
opr->param().compute_mode = Param::ComputeMode::DEFAULT;
opr->execution_policy() = {m_algorithm};
return SizeArgs(opr, fsrc, fdiff, fgrad);
}

bool ConvolutionBackwardFilterImpl::AlgoBFloat16::is_available(
const SizeArgs& args) const {
TensorLayout fsrc, fdiff, fgrad;
auto conv_back_filter_opr =
args.handle->create_operator<ConvolutionBackwardFilter>();
SizeArgs fargs = float_args(args,
static_cast<ConvolutionBackwardFilterImpl*>(
conv_back_filter_opr.get()),
fsrc, fdiff, fgrad);
return args.src_layout->dtype == args.diff_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16() &&
m_algorithm->is_available(fargs);
}

WorkspaceBundle
ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
TensorLayout fsrc, fdiff, fgrad;
auto conv_back_filter_opr =
args.handle->create_operator<ConvolutionBackwardFilter>();
SizeArgs fargs = float_args(args,
static_cast<ConvolutionBackwardFilterImpl*>(
conv_back_filter_opr.get()),
fsrc, fdiff, fgrad);
SmallVector<size_t> sizes;
auto get_workspace = [&sizes](const TensorLayout& src,
const TensorLayout& dst) {
if (src.dtype != dst.dtype) {
sizes.push_back(dst.span().dist_byte());
}
};
get_workspace(*args.src_layout, fsrc);
get_workspace(*args.diff_layout, fdiff);
get_workspace(*args.grad_layout, fgrad);
sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
return {ptr, std::move(sizes)};
}

size_t ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_in_bytes(
const SizeArgs& args) const {
return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}

void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec(
const ExecArgs& args) const {
TensorND fsrc_tensor = *args.src_tensor;
TensorND fdiff_tensor = *args.diff_tensor;
TensorND fgrad_tensor = *args.grad_tensor;
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle);
{
cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor)
.src_to_comp_type(*args.diff_tensor, fdiff_tensor)
.src_to_comp_type(*args.grad_tensor, fgrad_tensor);
}
{
auto conv_back_filter_opr =
args.handle->create_operator<ConvolutionBackwardFilter>();
conv_back_filter_opr->param() = args.opr->param();
conv_back_filter_opr->param().compute_mode =
Param::ComputeMode::DEFAULT;
conv_back_filter_opr->execution_policy() = {m_algorithm};
conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor,
cvter.workspace());
}
{ cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); }
}

// vim: syntax=cpp.doxygen
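Apart from the operand names (src/diff/grad instead of filter/diff/grad) and the wrapped operator type, this file mirrors backward_data/bfloat16.cpp: the same dtype rewrite in float_args, the same workspace-bundle layout, and the same convert-run-convert structure in exec.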

+ 4
- 0
dnn/src/cuda/convolution/backward_filter/chanwise.cpp

@@ -19,6 +19,10 @@ using namespace convolution;


 bool ConvolutionBackwardFilterImpl::AlgoChanwise::is_available(
         const SizeArgs &args) const {
+    if (args.src_layout->dtype == args.diff_layout->dtype &&
+        args.diff_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     auto &&fm = args.grad_filter_meta;
     return fm.format == Param::Format::NCHW &&
            args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&


+ 4
- 0
dnn/src/cuda/convolution/backward_filter/group_conv.cpp

@@ -38,6 +38,10 @@ ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(


 bool ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::is_available(
         const SizeArgs &args) const {
+    if (args.src_layout->dtype == args.diff_layout->dtype &&
+        args.diff_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     auto sub_args = args;
     TensorLayout src_pg, diff_pg;
     modify_size_args(sub_args, src_pg, diff_pg);


+ 4
- 0
dnn/src/cuda/convolution/backward_filter/matmul.cpp

@@ -19,6 +19,10 @@ using namespace cuda;


 bool ConvolutionBackwardFilterImpl::AlgoMatmul::is_available(
         const SizeArgs &args) const {
+    if (args.src_layout->dtype == args.diff_layout->dtype &&
+        args.diff_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     auto &&fm = args.grad_filter_meta;
     return fm.format == Param::Format::NCHW &&
            args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&


+ 4
- 0
dnn/src/cuda/convolution/helper.cpp

@@ -16,6 +16,10 @@ using namespace cuda;
 using namespace convolution;

 bool convolution::is_cudnn_supported(const ForwardSizeArgs &args) {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }

     // CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN
     // on Tegra K1.
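This early return means a bfloat16 convolution is never reported as cuDNN-supported, so the heuristics can only reach it through the AlgoBFloat16 wrappers, which internally run a float32 algorithm; presumably the cuDNN releases targeted here expose no bfloat16 convolution.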


+ 1
- 0
dnn/src/cuda/convolution/helper.h

@@ -25,6 +25,7 @@ namespace convolution {
 struct ForwardSizeArgs {
     HandleImpl *handle;
     const TensorLayout *src_layout;
+    const TensorLayout *filter_layout;
     CanonizedFilterMeta filter_meta;
     const TensorLayout *dst_layout;
 };
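The raw filter_layout pointer now travels alongside the canonized meta because the bfloat16 wrappers read *args.filter_layout (in float_args and get_workspace_bundle) to build their float32 shadow layouts and to decide whether a conversion buffer is needed; the digested CanonizedFilterMeta alone is not enough to reconstruct that layout.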


+ 54
- 28
dnn/src/cuda/convolution/opr_impl.cpp

@@ -102,7 +102,8 @@ void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter,
                                         _megdnn_tensor_out grad,
                                         _megdnn_workspace workspace) {
     AlgoBase::ExecArgs args(this, filter, diff, grad, workspace);
-    auto algo = get_algorithm(this, args.filter_meta, diff.layout, grad.layout);
+    auto algo = get_algorithm(this, filter.layout, args.filter_meta,
+                              diff.layout, grad.layout);
     algo->check_workspace(args, workspace).exec(args);
 }


@@ -120,16 +121,16 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { bool reproducible) {
auto fm = check_layout_fwd(grad, filter, diff); auto fm = check_layout_fwd(grad, filter, diff);
return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes,
reproducible);
return get_algorithm_heuristic(filter, fm, diff, grad,
workspace_limit_in_bytes, reproducible);
} }


ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const CanonizedFilterMeta& filter, const TensorLayout& diff,
ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { bool reproducible) {
AlgoBase::SizeArgs args(this, filter, diff, grad);
AlgoBase::SizeArgs args(this, filter, filter_meta, diff, grad);


if (args.filter_meta.group > 1 && if (args.filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_reproducible( sm_algo_pack.chanwise.is_available_reproducible(
@@ -209,14 +210,27 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
         args = orig_args;
     }

-    if (reproducible) {
-        return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>(
-                sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
-                "cuda conv bwd_data");
+    if (args.filter_layout->dtype.enumv() !=
+        DTypeTrait<dtype::BFloat16>::enumv) {
+        if (reproducible) {
+            return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>(
+                    sm_algo_pack.non_cudnn_algos, args,
+                    workspace_limit_in_bytes, "cuda conv bwd_data");
+        } else {
+            return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
+                    sm_algo_pack.non_cudnn_algos, args,
+                    workspace_limit_in_bytes, "cuda conv bwd_data");
+        }
     } else {
-        return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
-                sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
-                "cuda conv bwd_data");
+        if (reproducible) {
+            return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>(
+                    sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
+                    "cuda conv bwd_data");
+        } else {
+            return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
+                    sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
+                    "cuda conv bwd_data");
+        }
     }
 }
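Algorithm fallback is now keyed on dtype first: float problems keep drawing from non_cudnn_algos as before, while bfloat16 problems draw exclusively from bfloat16_algos; within either list, the reproducible flag still selects between get_reproducible_algo and get_usable_algo. The backward-filter heuristic below receives the same restructuring.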


@@ -225,7 +239,7 @@ size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
const TensorLayout &diff, const TensorLayout &diff,
const TensorLayout &grad) { const TensorLayout &grad) {
AlgoBase::SizeArgs args(this, filter, diff, grad); AlgoBase::SizeArgs args(this, filter, diff, grad);
return get_algorithm(this, args.filter_meta, diff, grad)->
return get_algorithm(this, filter, args.filter_meta, diff, grad)->
get_workspace_in_bytes(args); get_workspace_in_bytes(args);
} }


@@ -241,7 +255,7 @@ void ConvolutionBackwardFilterImpl::exec(_megdnn_tensor_in src,
                                          _megdnn_workspace workspace) {
     AlgoBase::ExecArgs args(this, src, diff, grad, workspace);
     auto algo = get_algorithm(this, src.layout, diff.layout,
-                              args.grad_filter_meta);
+                              grad.layout, args.grad_filter_meta);
     algo->check_workspace(args, workspace).exec(args);
 }


@@ -259,16 +273,16 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) { bool reproducible) {
auto fm = check_layout_fwd(src, grad, diff); auto fm = check_layout_fwd(src, grad, diff);
return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes,
reproducible);
return get_algorithm_heuristic(src, diff, grad, fm,
workspace_limit_in_bytes, reproducible);
} }


ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes,
bool reproducible) {
AlgoBase::SizeArgs args(this, src, diff, grad);
const TensorLayout& grad, const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes, bool reproducible) {
AlgoBase::SizeArgs args(this, src, diff, grad, grad_meta);


if (args.grad_filter_meta.group > 1 && if (args.grad_filter_meta.group > 1 &&
sm_algo_pack.chanwise.is_available_reproducible( sm_algo_pack.chanwise.is_available_reproducible(
@@ -349,14 +363,26 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
         args = orig_args;
     }

-    if (reproducible) {
-        return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>(
-                sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
-                "cuda conv bwd_filter");
+    if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
+        if (reproducible) {
+            return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>(
+                    sm_algo_pack.non_cudnn_algos, args,
+                    workspace_limit_in_bytes, "cuda conv bwd_filter");
+        } else {
+            return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
+                    sm_algo_pack.non_cudnn_algos, args,
+                    workspace_limit_in_bytes, "cuda conv bwd_filter");
+        }
     } else {
-        return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
-                sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
-                "cuda conv bwd_filter");
+        if (reproducible) {
+            return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>(
+                    sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
+                    "cuda conv bwd_filter");
+        } else {
+            return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
+                    sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
+                    "cuda conv bwd_filter");
+        }
     }
 }


@@ -365,7 +391,7 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
const TensorLayout &diff, const TensorLayout &diff,
const TensorLayout &grad) { const TensorLayout &grad) {
AlgoBase::SizeArgs args(this, src, diff, grad); AlgoBase::SizeArgs args(this, src, diff, grad);
return get_algorithm(this, src, diff, args.grad_filter_meta)->
return get_algorithm(this, src, diff, grad, args.grad_filter_meta)->
get_workspace_in_bytes(args); get_workspace_in_bytes(args);
} }




+ 9
- 6
dnn/src/cuda/convolution/opr_impl.h

@@ -60,11 +60,11 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
const TensorLayout& grad, const TensorLayout& grad,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; bool reproducible) override;
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible);
Algorithm* get_algorithm_heuristic(
const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_limit_in_bytes, bool reproducible);
size_t get_workspace_in_bytes(const TensorLayout& filter, size_t get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad) override; const TensorLayout& grad) override;
@@ -76,6 +76,7 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
     class AlgoChanwise;
     class AlgoChanwiseSmall;
     class AlgoGroupConvGeneral;
+    class AlgoBFloat16;

     class AlgoPack;


@@ -104,7 +105,8 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
                                        bool reproducible) override;
     Algorithm* get_algorithm_heuristic(const TensorLayout& src,
                                        const TensorLayout& diff,
-                                       const CanonizedFilterMeta& grad,
+                                       const TensorLayout& grad,
+                                       const CanonizedFilterMeta& grad_meta,
                                        size_t workspace_limit_in_bytes,
                                        bool reproducible);
     size_t get_workspace_in_bytes(const TensorLayout& src,
@@ -117,6 +119,7 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
     class AlgoMatmul;
     class AlgoChanwise;
     class AlgoGroupConvGeneral;
+    class AlgoBFloat16;

     class AlgoPack;




+ 1
- 1
dnn/src/cuda/convolution3d/forward/algo.h

@@ -50,7 +50,7 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm {
const CanonizedFilterMeta &filter, const CanonizedFilterMeta &filter,
const TensorLayout &dst); const TensorLayout &dst);
}; };
struct ExecArgs: public SizeArgs {
struct ExecArgs : public SizeArgs {
const TensorND *src_tensor, *filter_tensor, *dst_tensor; const TensorND *src_tensor, *filter_tensor, *dst_tensor;
Workspace workspace; Workspace workspace;




+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
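This file and the dozens that follow are emitted by gen_elemwise_kern_impls.py and are deliberately thin translation units: each pins one elementwise mode, its arity, and the element ctype through macros, then includes the shared kern_impl.inl, which expands them into the actual kernel instantiations. Splitting one mode per .cu keeps nvcc's per-file compile time and memory bounded and lets the build parallelize. A schematic of how such an .inl can consume the macros (a hypothetical sketch with made-up names; the real kern_impl.inl is not part of this diff):

// kern_impl_sketch.inl -- hypothetical consumer of the per-file macros.
#if !defined(KERN_IMPL_MODE) || !defined(KERN_IMPL_ARITY) || \
    !defined(KERN_IMPL_CTYPE)
#error "KERN_IMPL_MODE, KERN_IMPL_ARITY and KERN_IMPL_CTYPE must be defined"
#endif

// MEGDNN_ELEMWISE_MODE_ENABLE(MODE, cb) is expected to expand to cb(MODE)
// when MODE is enabled, so handing it an instantiation macro stamps out
// kernels for exactly this mode/ctype pair.
#define INST(mode) \
    template struct ElemwiseKernSketch<mode##_tag, KERN_IMPL_CTYPE>;
KERN_IMPL_MODE(INST)
#undef INST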

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 17
- 0
dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu View File

@@ -0,0 +1,17 @@
/**
* \file dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif

+ 18
- 0
dnn/src/cuda/elemwise/special_kimpl/special_dt_bfloat16.cu View File

@@ -0,0 +1,18 @@
/**
* \file dnn/src/cuda/elemwise/special_kimpl/special_dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_special_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#include "../special_kerns.inl"
INST(::megdnn::dtype::BFloat16)
#undef INST
}
}
#endif

+ 3
- 0
dnn/src/cuda/elemwise_helper.cpp View File

@@ -141,6 +141,9 @@ INST_FOR_CTYPE
#define ct dt_float16
INST_FOR_CTYPE
#undef ct
#define ct dt_bfloat16
INST_FOR_CTYPE
#undef ct
#define ct dt_int8
INST_FOR_CTYPE
#undef ct


+ 12
- 0
dnn/src/cuda/elemwise_helper.cuh View File

@@ -68,6 +68,17 @@ namespace elemwise_intl {
return t;
}


struct __attribute__((aligned(8))) bhalf4 {
dt_bfloat16 x, y, z, w;
};

__device__ __forceinline__ bhalf4 make_bhalf4(dt_bfloat16 x, dt_bfloat16 y,
dt_bfloat16 z, dt_bfloat16 w) {
bhalf4 t;
t.x = x, t.y = y, t.z = z, t.w = w;
return t;
}

#define INST(_ctype, _vect_type) \
template <> \
class VectTypeTrait<_ctype> { \
@@ -87,6 +98,7 @@ namespace elemwise_intl {
INST(dt_uint8, uchar4);
INST(dt_float32, float4);
INST(dt_float16, half4);
INST(dt_bfloat16, bhalf4);
INST(dt_int32, int4);
INST(dt_int16, short4);
#undef as_raw

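The bhalf4 struct added above mirrors half4 for dt_float16: four 2-byte bfloat16 lanes packed into one 8-byte-aligned word, so a vectorized element-wise kernel can move all four with a single 64-bit memory transaction. A minimal host-side sketch of the idea follows, using local stand-in types rather than the MegDNN headers.

// Build: g++ -std=c++17 bhalf4_sketch.cpp && ./a.out
#include <cstddef>
#include <cstdint>

struct bfloat16_t {          // stand-in for megdnn::dt_bfloat16
    std::uint16_t bits;      // payload: top 16 bits of an IEEE-754 binary32
};

struct alignas(8) bhalf4_t { // mirrors the bhalf4 added in the diff
    bfloat16_t x, y, z, w;
};

static_assert(sizeof(bhalf4_t) == 8, "4 x 2-byte lanes in one 64-bit word");
static_assert(alignof(bhalf4_t) == 8, "eligible for one 8-byte transaction");

// Trait in the spirit of VectTypeTrait: scalar type -> packed vector type.
template <typename ctype>
struct VectTrait;

template <>
struct VectTrait<bfloat16_t> {
    using vect_type = bhalf4_t;
    static constexpr std::size_t lanes = 4;
};

int main() {
    return VectTrait<bfloat16_t>::lanes == 4 ? 0 : 1;
}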

+ 5
- 0
dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu View File

@@ -17,6 +17,11 @@ __device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) {
__trap(); __trap();
((int*)0)[0] = 1; ((int*)0)[0] = 1;
} }

__device__ void atomicAdd(megdnn::dt_bfloat16 *, megdnn::dt_bfloat16) {
__trap();
((int*)0)[0] = 1;
}
#endif


__device__ void atomicAdd(megdnn::dt_int8 *, megdnn::dt_int8) {


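The added atomicAdd overload for dt_bfloat16 deliberately traps: the generic indexing kernel is instantiated for every dtype, so an overload has to exist for bfloat16 even though no atomic path is wired up, and __trap() (with the null write presumably keeping the compiler from treating the stub as a normal return) makes an accidental call fail loudly instead of corrupting data. Below is a host-side analogue of the same compile-everywhere, fail-loudly pattern; all names are hypothetical.

// Build: g++ -std=c++17 trap_stub_sketch.cpp && ./a.out
#include <atomic>
#include <cstdio>
#include <cstdlib>

struct bfloat16_t { unsigned short bits; };

// Real path: int has a genuine atomic add.
inline void atomic_add(std::atomic<int>* p, int v) {
    p->fetch_add(v, std::memory_order_relaxed);
}

// Stub path: lets the bf16 instantiation compile, aborts if ever reached.
inline void atomic_add(bfloat16_t*, bfloat16_t) {
    std::fprintf(stderr, "atomic_add: dtype has no atomic support\n");
    std::abort();
}

// One generic kernel body, instantiated for every dtype the framework knows.
template <typename Ptr, typename T>
void incr_kernel(Ptr* dst, T v) {
    atomic_add(dst, v);
}

int main() {
    std::atomic<int> acc{0};
    incr_kernel(&acc, 3);  // supported dtype: runs normally
    // incr_kernel(&some_bf16, bfloat16_t{});  // compiles, aborts at run time
    return acc.load() == 3 ? 0 : 1;
}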
+ 4
- 0
dnn/src/cuda/matrix_mul/algos.cpp View File

@@ -29,6 +29,10 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&cublas_lt);
#endif
all_algos.push_back(&naive);
#if !MEGDNN_DISABLE_FLOAT16
cublas_bfloat16 = std::make_unique<AlgoBFloat16>(&cublas);
all_algos.push_back(cublas_bfloat16.get());
#endif
}


MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack;


+ 22
- 1
dnn/src/cuda/matrix_mul/algos.h View File

@@ -15,6 +15,7 @@
#include "src/cuda/matrix_mul/opr_impl.h" #include "src/cuda/matrix_mul/opr_impl.h"


#include <cuda.h>
#include <memory>
#if CUDA_VERSION >= 10010
#include <cublasLt.h>
#endif
@@ -140,6 +141,24 @@ public:
bool is_reproducible() const override { return true; }
};


#if !MEGDNN_DISABLE_FLOAT16
class MatrixMulForwardImpl::AlgoBFloat16 final : public AlgoBase {
public:
AlgoBFloat16(MatrixMulForwardImpl::AlgoBase*);
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }

private:
MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr;
std::string m_name;
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
SizeArgs float_args(const SizeArgs& args) const;
};
#endif

class MatrixMulForwardImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
@@ -154,7 +173,9 @@ public:
#if CUDA_VERSION >= 10010
AlgoCuBlasLt cublas_lt;
#endif

#if !MEGDNN_DISABLE_FLOAT16
std::unique_ptr<AlgoBFloat16> cublas_bfloat16;
#endif
std::vector<AlgoBase*> all_algos;
};




+ 91
- 0
dnn/src/cuda/matrix_mul/bfloat16.cpp View File

@@ -0,0 +1,91 @@
/**
* \file dnn/src/cuda/matrix_mul/bfloat16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

MatrixMulForwardImpl::AlgoBFloat16::AlgoBFloat16(
MatrixMulForwardImpl::AlgoBase* algorithm)
: m_algorithm(algorithm) {
megdnn_assert_internal(algorithm);
m_name = ssprintf("MATMUL_BFLOAT16:%s", m_algorithm->name());
}

MatrixMulForwardImpl::AlgoBase::SizeArgs
MatrixMulForwardImpl::AlgoBFloat16::float_args(const SizeArgs& args) const {
auto new_args = args;
auto change_dtype = [](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32();
}
};
change_dtype(new_args.layout_a);
change_dtype(new_args.layout_b);
change_dtype(new_args.layout_c);
return new_args;
}

bool MatrixMulForwardImpl::AlgoBFloat16::is_available(
const SizeArgs& args) const {
auto fargs = float_args(args);
return args.layout_a.dtype == dtype::BFloat16() &&
m_algorithm->is_available(fargs);
}

WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
auto fargs = float_args(args);
SmallVector<size_t> sizes;
auto get_workspace = [&sizes](const TensorLayout& src) {
TensorLayout dst = src;
if (dst.dtype == dtype::BFloat16()) {
dst.dtype = dtype::Float32();
sizes.push_back(dst.span().dist_byte());
}
};
get_workspace(args.layout_a);
get_workspace(args.layout_b);
get_workspace(args.layout_c);
sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
return {ptr, std::move(sizes)};
}

size_t MatrixMulForwardImpl::AlgoBFloat16::get_workspace_in_bytes(
const SizeArgs& args) const {
return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}

void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
TensorND a = args.tensor_a;
TensorND b = args.tensor_b;
TensorND c = args.tensor_c;
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
args.opr->handle(), &bundle);
ctypecvt.src_to_comp_type(args.tensor_a, a)
.src_to_comp_type(args.tensor_b, b)
.src_to_comp_type(args.tensor_c, c);
{
auto matmul_opr =
args.opr->handle()->create_operator<MatrixMulForward>();
matmul_opr->param() = args.opr->param();
matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
matmul_opr->execution_policy() = {m_algorithm};
matmul_opr->exec(a, b, c, ctypecvt.workspace());
}
ctypecvt.comp_to_dst_type(c, args.tensor_c);
}

// vim: syntax=cpp.doxygen

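Note that AlgoBFloat16 never computes in bfloat16: exec converts the operands to float32 workspace tensors via CompTypeCvter, runs the wrapped float32 algorithm (cuBLAS) with ComputeMode::DEFAULT, and converts the result back. The following self-contained sketch illustrates that convert-compute-convert-back strategy, assuming the standard bfloat16 encoding (top 16 bits of an IEEE-754 binary32) with round-to-nearest-even on the way down; the naive triple loop is a stand-in for the wrapped cuBLAS call, and all names are local to the sketch.

// Build: g++ -std=c++17 bf16_matmul_sketch.cpp && ./a.out
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// bf16 <-> f32 conversion. A production converter also special-cases
// NaN, which this sketch omits.
static std::uint16_t f32_to_bf16(float f) {
    std::uint32_t u;
    std::memcpy(&u, &f, 4);
    std::uint32_t bias = 0x7FFF + ((u >> 16) & 1);  // round to nearest even
    return static_cast<std::uint16_t>((u + bias) >> 16);
}

static float bf16_to_f32(std::uint16_t h) {
    std::uint32_t u = static_cast<std::uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &u, 4);
    return f;
}

// bf16 in, float32 compute, bf16 out: the same shape as AlgoBFloat16::exec.
void matmul_bf16(const std::uint16_t* A, const std::uint16_t* B,
                 std::uint16_t* C, std::size_t M, std::size_t K,
                 std::size_t N) {
    std::vector<float> a(M * K), b(K * N), c(M * N, 0.f);
    for (std::size_t i = 0; i < M * K; ++i) a[i] = bf16_to_f32(A[i]);
    for (std::size_t i = 0; i < K * N; ++i) b[i] = bf16_to_f32(B[i]);
    for (std::size_t m = 0; m < M; ++m)        // stand-in for the wrapped
        for (std::size_t k = 0; k < K; ++k)    // float32 algorithm
            for (std::size_t n = 0; n < N; ++n)
                c[m * N + n] += a[m * K + k] * b[k * N + n];
    for (std::size_t i = 0; i < M * N; ++i) C[i] = f32_to_bf16(c[i]);
}

int main() {
    std::uint16_t one = f32_to_bf16(1.0f), two = f32_to_bf16(2.0f);
    std::uint16_t A[4] = {one, one, one, one};  // 2x2 matrix of 1.0
    std::uint16_t B[4] = {two, two, two, two};  // 2x2 matrix of 2.0
    std::uint16_t C[4];
    matmul_bf16(A, B, C, 2, 2, 2);
    std::printf("%g\n", bf16_to_f32(C[0]));     // prints 4
    return 0;
}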
Some files were not shown because too many files changed in this diff
