@@ -29,6 +29,7 @@
 #define MEGDNN_FLOAT16_SELECT(_x, _y) _y
 #else
 #include "megdnn/dtype/half.hpp"
+#include "megdnn/dtype/bfloat16.hpp"
 #define MEGDNN_INC_FLOAT16(_x) _x
 #define MEGDNN_FLOAT16_SELECT(_x, _y) _x
 #endif
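These two macros are the single switch for half-precision support: with MEGDNN_DISABLE_FLOAT16 set, MEGDNN_INC_FLOAT16 drops its argument and MEGDNN_FLOAT16_SELECT yields the fallback; otherwise both keep the float16 path, and the new bfloat16 header is pulled in on the same branch. A minimal standalone sketch of how the pair is used (the process() overloads and stub types are hypothetical, not part of this diff; the disabled branch of MEGDNN_INC_FLOAT16 is assumed to expand to nothing — only the enabled branch is visible in this hunk):

// Minimal sketch, assuming the disabled branch drops the argument entirely.
#if MEGDNN_DISABLE_FLOAT16
#define MEGDNN_INC_FLOAT16(_x)
#define MEGDNN_FLOAT16_SELECT(_x, _y) _y
#else
#define MEGDNN_INC_FLOAT16(_x) _x
#define MEGDNN_FLOAT16_SELECT(_x, _y) _x
#endif

struct dt_float16_stub {};   // stand-ins so the sketch compiles on its own
struct dt_bfloat16_stub {};

MEGDNN_INC_FLOAT16(void process(dt_float16_stub);)   // kept only with fp16 on
MEGDNN_INC_FLOAT16(void process(dt_bfloat16_stub);)  // same guard covers bf16
void process(float);

const char* kFloats = MEGDNN_FLOAT16_SELECT("f32/f16/bf16", "f32");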
@@ -49,6 +50,7 @@ namespace megdnn {
     cb(IntB4) \
     cb(Byte) \
     MEGDNN_INC_FLOAT16(cb(Float16)) \
+    MEGDNN_INC_FLOAT16(cb(BFloat16)) \
     cb(UintB4) \
 /*!
@@ -62,6 +64,7 @@ namespace megdnn {
     cb(Int32) \
     cb(Byte) \
     MEGDNN_INC_FLOAT16(cb(Float16)) \
+    MEGDNN_INC_FLOAT16(cb(BFloat16)) \
 /*!
  * \brief iterate through each fractional byte dtype
@@ -101,6 +104,7 @@ namespace megdnn {
 #define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \
     cb(::megdnn::dtype::Float32) \
     MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \
+    MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16)) \
 /*!
  * \brief iterate through each dtype object that can be involved in integer
@@ -345,6 +349,7 @@ typedef int16_t dt_int16;
 typedef int8_t dt_int8;
 typedef uint8_t dt_uint8;
 MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
+MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
 #define MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE 100000
 #if MEGDNN_CC_HOST
@@ -367,6 +372,9 @@ MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
     Float16,
 #endif
     UintB4 = 10,
+#if !MEGDNN_DISABLE_FLOAT16
+    BFloat16 = 11,
+#endif
 #define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE,
 #define D(_name) _name,
@@ -702,6 +710,9 @@ MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX);
 MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED,
                                  std::numeric_limits<dt_float16>::lowest(),
                                  std::numeric_limits<dt_float16>::max()));
+MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(BFloat16, dt_bfloat16, FLOAT, SIGNED,
+                                 std::numeric_limits<dt_bfloat16>::lowest(),
+                                 std::numeric_limits<dt_bfloat16>::max()));
 template <>
 struct DTypeTrait<dtype::Byte> {
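With BFloat16 added to the same X-macro lists and to DTypeEnum, every kernel family instantiated through these foreach macros picks up a bf16 variant with no further per-kernel edits. A hedged sketch of the usual instantiation idiom (my_kernel is a hypothetical function template; DTypeTrait<T>::ctype is the dtype-to-C-type mapping used throughout megdnn):

// Sketch: one explicit instantiation per computing float dtype, now
// including ::megdnn::dtype::BFloat16. my_kernel is hypothetical.
#define INST(_dt) template void my_kernel<DTypeTrait<_dt>::ctype>(size_t n);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST)
#undef INST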
@@ -50,167 +50,7 @@
 #include <hip/hip_fp16.h>
 #endif
-/// Combined gcc version number.
-#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)
-//check C++11 language features
-#if defined(__clang__) //clang
-#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
-#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
-#endif
-#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
-#define HALF_ENABLE_CPP11_CONSTEXPR 1
-#endif
-#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
-#define HALF_ENABLE_CPP11_NOEXCEPT 1
-#endif
-#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
-#define HALF_ENABLE_CPP11_USER_LITERALS 1
-#endif
-#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
-#define HALF_ENABLE_CPP11_LONG_LONG 1
-#endif
-/*#elif defined(__INTEL_COMPILER) //Intel C++
-#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ????????
-#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
-#endif
-#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ????????
-#define HALF_ENABLE_CPP11_CONSTEXPR 1
-#endif
-#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ????????
-#define HALF_ENABLE_CPP11_NOEXCEPT 1
-#endif
-#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ????????
-#define HALF_ENABLE_CPP11_LONG_LONG 1
-#endif*/
-#elif defined(__GNUC__) //gcc
-#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
-#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
-#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
-#endif
-#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
-#define HALF_ENABLE_CPP11_CONSTEXPR 1
-#endif
-#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
-#define HALF_ENABLE_CPP11_NOEXCEPT 1
-#endif
-#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
-#define HALF_ENABLE_CPP11_USER_LITERALS 1
-#endif
-#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
-#define HALF_ENABLE_CPP11_LONG_LONG 1
-#endif
-#endif
-#elif defined(_MSC_VER) //Visual C++
-#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
-#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
-#endif
-#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
-#define HALF_ENABLE_CPP11_LONG_LONG 1
-#endif
-#define HALF_POP_WARNINGS 1
-#pragma warning(push)
-//! 4521 and 4522 is multiple copy/assigment operator specified
-#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned
-#endif
-//check C++11 library features
-#include <utility>
-#if defined(_LIBCPP_VERSION) //libc++
-#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
-#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
-#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
-#endif
-#ifndef HALF_ENABLE_CPP11_CSTDINT
-#define HALF_ENABLE_CPP11_CSTDINT 1
-#endif
-#ifndef HALF_ENABLE_CPP11_CMATH
-#define HALF_ENABLE_CPP11_CMATH 1
-#endif
-#ifndef HALF_ENABLE_CPP11_HASH
-#define HALF_ENABLE_CPP11_HASH 1
-#endif
-#endif
-#elif defined(__GLIBCXX__) //libstdc++
-#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
-#ifdef __clang__
-#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
-#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
-#endif
-#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
-#define HALF_ENABLE_CPP11_CSTDINT 1
-#endif
-#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
-#define HALF_ENABLE_CPP11_CMATH 1
-#endif
-#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
-#define HALF_ENABLE_CPP11_HASH 1
-#endif
-#else
-#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
-#define HALF_ENABLE_CPP11_CSTDINT 1
-#endif
-#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
-#define HALF_ENABLE_CPP11_CMATH 1
-#endif
-#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
-#define HALF_ENABLE_CPP11_HASH 1
-#endif
-#endif
-#endif
-#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++
-#if _CPPLIB_VER >= 520
-#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
-#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
-#endif
-#ifndef HALF_ENABLE_CPP11_CSTDINT
-#define HALF_ENABLE_CPP11_CSTDINT 1
-#endif
-#ifndef HALF_ENABLE_CPP11_HASH
-#define HALF_ENABLE_CPP11_HASH 1
-#endif
-#endif
-#if _CPPLIB_VER >= 610
-#ifndef HALF_ENABLE_CPP11_CMATH
-#define HALF_ENABLE_CPP11_CMATH 1
-#endif
-#endif
-#endif
-#undef HALF_GNUC_VERSION
-//support constexpr
-#if HALF_ENABLE_CPP11_CONSTEXPR
-#define HALF_CONSTEXPR constexpr
-#define HALF_CONSTEXPR_CONST constexpr
-#else
-#define HALF_CONSTEXPR
-#define HALF_CONSTEXPR_CONST const
-#endif
-//support noexcept
-#if HALF_ENABLE_CPP11_NOEXCEPT
-#define HALF_NOEXCEPT noexcept
-#define HALF_NOTHROW noexcept
-#else
-#define HALF_NOEXCEPT
-#define HALF_NOTHROW throw()
-#endif
-#include <algorithm>
-#include <limits>
-#include <climits>
-#include <cmath>
-#include <cstring>
-#if HALF_ENABLE_CPP11_TYPE_TRAITS
-#include <type_traits>
-#endif
-#if HALF_ENABLE_CPP11_CSTDINT
-#include <cstdint>
-#endif
-#if HALF_ENABLE_CPP11_HASH
-#include <functional>
-#endif
+#include "megdnn/dtype/half_common_prologue.h"
 /// Default rounding mode.
 /// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s and `float`s as well as
@@ -3141,16 +2981,7 @@ namespace std
 #endif
 }
-#undef HALF_CONSTEXPR
-#undef HALF_CONSTEXPR_CONST
-#undef HALF_NOEXCEPT
-#undef HALF_NOTHROW
-#ifdef HALF_POP_WARNINGS
-#pragma warning(pop)
-#undef HALF_POP_WARNINGS
-#endif
+#include "megdnn/dtype/half_common_epilogue.h"
 #endif
 // vim: syntax=cpp.doxygen
@@ -0,0 +1,48 @@
+/**
+ * half - IEEE 754-based half-precision floating point library.
+ *
+ * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
+ * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
+ * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+ * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
+ * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *
+ * Version 1.11.0
+ * \file
+ * Main header file for half precision functionality.
+ *
+ * --------------------------------------------------------------------------
+ * \file include/megdnn/dtype/half_common_epilogue.h
+ *
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *
+ * This file has been modified by Megvii ("Megvii Modifications").
+ * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
+ *
+ * --------------------------------------------------------------------------
+ */
+#undef HALF_CONSTEXPR
+#undef HALF_CONSTEXPR_CONST
+#undef HALF_NOEXCEPT
+#undef HALF_NOTHROW
+#ifdef HALF_POP_WARNINGS
+#pragma warning(pop)
+#undef HALF_POP_WARNINGS
+#endif
+// vim: syntax=cpp.doxygen
@@ -0,0 +1,202 @@
+/**
+ * half - IEEE 754-based half-precision floating point library.
+ *
+ * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
+ * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
+ * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+ * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
+ * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *
+ * Version 1.11.0
+ * \file
+ * Main header file for half precision functionality.
+ *
+ * --------------------------------------------------------------------------
+ * \file dnn/include/megdnn/dtype/half_common_prologue.h
+ *
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *
+ * This file has been modified by Megvii ("Megvii Modifications").
+ * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
+ *
+ * --------------------------------------------------------------------------
+ */
+#include "megdnn/arch.h"
+/// Combined gcc version number.
+#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)
+//check C++11 language features
+#if defined(__clang__) //clang
+#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
+#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
+#endif
+#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
+#define HALF_ENABLE_CPP11_CONSTEXPR 1
+#endif
+#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
+#define HALF_ENABLE_CPP11_NOEXCEPT 1
+#endif
+#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
+#define HALF_ENABLE_CPP11_USER_LITERALS 1
+#endif
+#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
+#define HALF_ENABLE_CPP11_LONG_LONG 1
+#endif
+/*#elif defined(__INTEL_COMPILER) //Intel C++
+#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ????????
+#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
+#endif
+#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ????????
+#define HALF_ENABLE_CPP11_CONSTEXPR 1
+#endif
+#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ????????
+#define HALF_ENABLE_CPP11_NOEXCEPT 1
+#endif
+#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ????????
+#define HALF_ENABLE_CPP11_LONG_LONG 1
+#endif*/
+#elif defined(__GNUC__) //gcc
+#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
+#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
+#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
+#endif
+#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
+#define HALF_ENABLE_CPP11_CONSTEXPR 1
+#endif
+#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
+#define HALF_ENABLE_CPP11_NOEXCEPT 1
+#endif
+#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
+#define HALF_ENABLE_CPP11_USER_LITERALS 1
+#endif
+#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
+#define HALF_ENABLE_CPP11_LONG_LONG 1
+#endif
+#endif
+#elif defined(_MSC_VER) //Visual C++
+#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
+#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
+#endif
+#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
+#define HALF_ENABLE_CPP11_LONG_LONG 1
+#endif
+#define HALF_POP_WARNINGS 1
+#pragma warning(push)
+//! 4521 and 4522 is multiple copy/assigment operator specified
+#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned
+#endif
+//check C++11 library features
+#include <utility>
+#if defined(_LIBCPP_VERSION) //libc++
+#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
+#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
+#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
+#endif
+#ifndef HALF_ENABLE_CPP11_CSTDINT
+#define HALF_ENABLE_CPP11_CSTDINT 1
+#endif
+#ifndef HALF_ENABLE_CPP11_CMATH
+#define HALF_ENABLE_CPP11_CMATH 1
+#endif
+#ifndef HALF_ENABLE_CPP11_HASH
+#define HALF_ENABLE_CPP11_HASH 1
+#endif
+#endif
+#elif defined(__GLIBCXX__) //libstdc++
+#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
+#ifdef __clang__
+#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
+#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
+#endif
+#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
+#define HALF_ENABLE_CPP11_CSTDINT 1
+#endif
+#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
+#define HALF_ENABLE_CPP11_CMATH 1
+#endif
+#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
+#define HALF_ENABLE_CPP11_HASH 1
+#endif
+#else
+#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
+#define HALF_ENABLE_CPP11_CSTDINT 1
+#endif
+#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
+#define HALF_ENABLE_CPP11_CMATH 1
+#endif
+#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
+#define HALF_ENABLE_CPP11_HASH 1
+#endif
+#endif
+#endif
+#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++
+#if _CPPLIB_VER >= 520
+#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
+#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
+#endif
+#ifndef HALF_ENABLE_CPP11_CSTDINT
+#define HALF_ENABLE_CPP11_CSTDINT 1
+#endif
+#ifndef HALF_ENABLE_CPP11_HASH
+#define HALF_ENABLE_CPP11_HASH 1
+#endif
+#endif
+#if _CPPLIB_VER >= 610
+#ifndef HALF_ENABLE_CPP11_CMATH
+#define HALF_ENABLE_CPP11_CMATH 1
+#endif
+#endif
+#endif
+#undef HALF_GNUC_VERSION
+//support constexpr
+#if HALF_ENABLE_CPP11_CONSTEXPR
+#define HALF_CONSTEXPR constexpr
+#define HALF_CONSTEXPR_CONST constexpr
+#else
+#define HALF_CONSTEXPR
+#define HALF_CONSTEXPR_CONST const
+#endif
+//support noexcept
+#if HALF_ENABLE_CPP11_NOEXCEPT
+#define HALF_NOEXCEPT noexcept
+#define HALF_NOTHROW noexcept
+#else
+#define HALF_NOEXCEPT
+#define HALF_NOTHROW throw()
+#endif
+#include <algorithm>
+#include <limits>
+#include <climits>
+#include <cmath>
+#include <cstring>
+#if HALF_ENABLE_CPP11_TYPE_TRAITS
+#include <type_traits>
+#endif
+#if HALF_ENABLE_CPP11_CSTDINT
+#include <cstdint>
+#endif
+#if HALF_ENABLE_CPP11_HASH
+#include <functional>
+#endif
+// vim: syntax=cpp.doxygen
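The prologue and epilogue are byte-for-byte the old top and bottom of half.hpp: the prologue detects compiler and library C++11 features and defines HALF_CONSTEXPR, HALF_NOEXCEPT and friends, and the epilogue undefines them again and pops the MSVC warning state, so the pair can now bracket both half.hpp and the new bfloat16.hpp. A hedged sketch of the include discipline such a header follows (class and member names illustrative, not from this diff):

// bfloat16-like header, illustrative only.
#include "megdnn/dtype/half_common_prologue.h"  // defines the HALF_* helpers

namespace half_bfloat16 {
class bfloat16 {
public:
    HALF_CONSTEXPR bfloat16() HALF_NOTHROW : data_(0) {}
private:
    unsigned short data_;  // raw bf16 bit pattern
};
}  // namespace half_bfloat16

#include "megdnn/dtype/half_common_epilogue.h"  // #undefs the HALF_* helpers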
@@ -30,7 +30,7 @@ def main():
         w('// generated by gen_cond_take_kern_impls.py')
         w('#include "../kern.inl"')
         w('')
-        if dtype == 'dt_float16':
+        if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
             w('#if !MEGDNN_DISABLE_FLOAT16')
         w('namespace megdnn {')
         w('namespace cuda {')
@@ -48,7 +48,7 @@ def main():
         w('} // cond_take')
         w('} // cuda')
         w('} // megdnn')
-        if dtype == 'dt_float16':
+        if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
             w('#endif')
         print('generated {}'.format(fname))
@@ -34,7 +34,7 @@ def main():
         w = lambda s: print(s, file=fout)
         w('// generated by gen_elemwise_kern_impls.py')
-        if ctype == 'dt_float16':
+        if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
             w('#if !MEGDNN_DISABLE_FLOAT16')
         w('#define KERN_IMPL_MODE(cb) {}'.format(formode))
@@ -42,7 +42,7 @@ def main():
         w('#define KERN_IMPL_CTYPE {}'.format(ctype))
         w('#include "../kern_impl.inl"')
-        if ctype == 'dt_float16':
+        if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
             w('#endif')
         print('generated {}'.format(fname))
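Putting the w(...) calls together, a bf16 impl emitted by gen_elemwise_kern_impls.py should come out wrapped in the same float16 guard as fp16. An inferred example of a generated file (the mode macro body is a placeholder; only the guard, CTYPE, and include lines are established by this diff):

// Inferred shape of a generated file for ctype dt_bfloat16.
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) cb(RELU)  // placeholder for the real formode
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif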
@@ -30,14 +30,14 @@ def main():
         w = lambda s: print(s, file=fout)
         w('// generated by gen_elemwise_special_kern_impls.py')
-        if dtype == 'dt_float16':
+        if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
             w('#if !MEGDNN_DISABLE_FLOAT16')
         w('#include "../special_kerns.inl"')
         w('INST(::megdnn::dtype::{})'.format(DTYPES[dtype][0]))
         w('#undef INST')
         w('}')
         w('}')
-        if dtype == 'dt_float16':
+        if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
             w('#endif')
         print('generated {}'.format(fname))
@@ -6,7 +6,8 @@ DTYPES = {'dt_int32': ('Int32', 'INT'),
           'dt_int8': ('Int8', 'INT'),
           'dt_int16': ('Int16', 'INT'),
           'dt_float32': ('Float32', 'FLOAT'),
-          'dt_float16': ('Float16', 'FLOAT')
+          'dt_float16': ('Float16', 'FLOAT'),
+          'dt_bfloat16': ('BFloat16', 'FLOAT')
           }
 MODES = {
@@ -618,9 +618,10 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src,
         megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32
 #if !MEGDNN_DISABLE_FLOAT16
                               || src.enumv() == DTypeEnum::Float16
+                              || src.enumv() == DTypeEnum::BFloat16
 #endif
-                      ,
-                      "ComputeMode::FLOAT32 is only available for Float16 "
+                      ,
+                      "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
                       "input / output.");
     }
@@ -1036,9 +1037,10 @@ void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff,
     megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32
 #if !MEGDNN_DISABLE_FLOAT16
                           || filter.enumv() == DTypeEnum::Float16
+                          || filter.enumv() == DTypeEnum::BFloat16
 #endif
-                  ,
-                  "ComputeMode::FLOAT32 is only available for Float16 "
+                  ,
+                  "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
                   "input / output.");
 }
@@ -87,7 +87,8 @@ namespace megdnn {
 //! define kernel for all float types
 #define DEF_KERN_FLOAT(_mode, _imp) \
     DEF_KERN(dt_float32, _mode, _imp); \
-    MEGDNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);)
+    MEGDNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \
+    MEGDNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);)
 //! define kernel for all int types
 #define DEF_KERN_INT(_mode, _imp) \
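Note the trailing backslash now required on the dt_float16 line so the macro continues to the new third entry; every mode defined through DEF_KERN_FLOAT then fans out to a bf16 kernel as well. A standalone toy re-creation of the fan-out (not the real megdnn DEF_KERN macros):

// Toy re-creation: one DEF_KERN_FLOAT line yields one kernel per float ctype.
#include <cstdio>
#define DEF_KERN(_ctype, _name) \
    void _name##_##_ctype() { std::printf(#_name "/" #_ctype "\n"); }
#define DEF_KERN_FLOAT(_name)   \
    DEF_KERN(dt_float32, _name) \
    DEF_KERN(dt_float16, _name) \
    DEF_KERN(dt_bfloat16, _name)
DEF_KERN_FLOAT(relu)  // defines relu_dt_float32, relu_dt_float16, relu_dt_bfloat16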
@@ -69,11 +69,11 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A,
         C = TensorLayout(TensorShape({A0, B1}), C.dtype);
     } else {
         auto do_deduce = [&](size_t pack_size) {
-            megdnn_assert(
-                    A.ndim == 4 && B.ndim == 3,
-                    "matmul requires input dimension to be A(4), B(3); get: %s %s",
-                    A.TensorShape::to_string().c_str(),
-                    B.TensorShape::to_string().c_str());
+            megdnn_assert(A.ndim == 4 && B.ndim == 3,
+                          "matmul requires input dimension to be A(4), B(3); "
+                          "get: %s %s",
+                          A.TensorShape::to_string().c_str(),
+                          B.TensorShape::to_string().c_str());
             A0 = A.shape[0];
             A1 = A.shape[1];
             B0 = B.shape[0];
@@ -82,11 +82,11 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A,
                 std::swap(A0, A1);
             if (m_param.transposeB)
                 std::swap(B0, B1);
-            megdnn_assert(
-                    A1 == B0,
-                    "shape mismatch in matmal: (transposed) A is (%zu,%zu,4,4), "
-                    "(transposed) B is (%zu,%zu,4)",
-                    A0, A1, B0, B1);
+            megdnn_assert(A1 == B0,
+                          "shape mismatch in matmal: (transposed) A is "
+                          "(%zu,%zu,4,4), "
+                          "(transposed) B is (%zu,%zu,4)",
+                          A0, A1, B0, B1);
             C = TensorLayout(TensorShape({A0, B1, pack_size}), C.dtype);
         };
         do_deduce(pack_size(param().format));
@@ -172,8 +172,9 @@ void MatrixMulForward::check_exec(const TensorLayout& A, const TensorLayout& B,
     }
     megdnn_assert(param().compute_mode !=
                           Param::ComputeMode::FLOAT32 MEGDNN_INC_FLOAT16(
-                                  || A.dtype == dtype::Float16()),
-                  "ComputeMode::FLOAT32 is only available for Float16 "
+                                  || A.dtype == dtype::Float16() ||
+                                  A.dtype == dtype::BFloat16()),
+                  "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
                   "input / output.");
     auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
@@ -46,6 +46,14 @@ struct RoundingConverter<half_float::half> {
     }
 };
+template <>
+struct RoundingConverter<half_bfloat16::bfloat16> {
+    __host__ __device__ __forceinline__ half_bfloat16::bfloat16 operator()(
+            float x) const {
+        return static_cast<half_bfloat16::bfloat16>(x);
+    }
+};
 #endif // #ifdef MEGDNN_DISABLE_FLOAT16
 template <>
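RoundingConverter is the functor kernels apply when writing fp32 results back to the storage dtype; the new specialization leans on bfloat16's converting constructor, which rounds rather than truncates. A hedged sketch of the call pattern (the rounding namespace and the loop are assumptions, not shown in this hunk):

// Sketch: narrowing an fp32 accumulator buffer to the storage ctype.
template <typename ctype>
__host__ __device__ void store_rounded(const float* acc, ctype* dst, size_t n) {
    megdnn::rounding::RoundingConverter<ctype> cvt;  // namespace assumed
    for (size_t i = 0; i < n; ++i)
        dst[i] = cvt(acc[i]);  // rounds on the narrowing conversion
}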
@@ -16,6 +16,7 @@
 #include "megdnn/dtype.h"
 #include "megdnn/handle.h"
 #include "megdnn/thin/small_vector.h"
+#include "megdnn/oprs/general.h"
 #include "src/common/hash_ct.h"
 #include "src/common/utils.cuh"
@@ -548,6 +549,59 @@ public:
     std::string to_string() const;
 };
+/**!
+ * \brief helpers for oprs using typecvt between comp_type and dst_type
+ * \tparam SrcType src type
+ * \tparam CompType compute type, such as fp32 for conv
+ * \tparam DstType dst type
+ */
+template <typename SrcType, typename CompType, typename DstType = SrcType>
+struct CompTypeCvter {
+    std::unique_ptr<TypeCvt> m_cvt_opr;
+    WorkspaceBundle* m_workspace_bundle;
+    size_t m_workspace_idx;
+    CompTypeCvter(Handle* handle, WorkspaceBundle* bundle)
+            : m_workspace_bundle(bundle), m_workspace_idx(0) {
+        megdnn_assert(
+                (DTypeTrait<SrcType>::enumv != DTypeTrait<CompType>::enumv &&
+                 DTypeTrait<DstType>::enumv != DTypeTrait<CompType>::enumv),
+                "SrcType(%s) == CompType(%s) or DstType(%s) == CompType(%s) is "
+                "not "
+                "supportted.",
+                SrcType().name(), CompType().name(), DstType().name(),
+                CompType().name());
+        m_cvt_opr = handle->create_operator<TypeCvt>();
+    }
+    //! Convert tensor dtype from SrcType to CompType.
+    CompTypeCvter& src_to_comp_type(const TensorND& src, TensorND& comp) {
+        if (src.layout.dtype.enumv() == DTypeTrait<SrcType>::enumv) {
+            if (!comp.layout.dtype.valid() ||
+                comp.layout.dtype.enumv() != DTypeTrait<CompType>::enumv) {
+                comp.layout.dtype = CompType();
+                comp.layout.init_contiguous_stride();
+                comp.raw_ptr = m_workspace_bundle->get(m_workspace_idx++);
+                if (src.layout.ndim) {
+                    m_cvt_opr->exec(src, comp);
+                }
+            }
+        }
+        return *this;
+    }
+    //! Convert tensor dtype from CompType to DstType.
+    CompTypeCvter& comp_to_dst_type(const TensorND& comp, const TensorND& dst) {
+        megdnn_assert(comp.layout.dtype.enumv() == DTypeTrait<CompType>::enumv);
+        if (dst.layout.dtype.enumv() == DTypeTrait<DstType>::enumv) {
+            m_cvt_opr->exec(comp, dst);
+        }
+        return *this;
+    }
+    Workspace workspace() {
+        return m_workspace_bundle->get_workspace(m_workspace_idx);
+    }
+};
 }  // namespace megdnn
 // vim: syntax=cpp.doxygen
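CompTypeCvter is built for chaining: convert each bf16 tensor into an fp32 scratch slot of the workspace bundle, run the fp32 operator, then convert the result back. A condensed sketch of the intended sequence (names illustrative; the real call site is ConvBiasForwardImpl::AlgoBFloat16::exec later in this diff):

// Condensed usage sketch; handle, args, src, dst, and run_fp32_opr are
// illustrative stand-ins, not code from this diff.
WorkspaceBundle bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(handle, &bundle);
TensorND fsrc = src, fdst = dst;       // start as the bf16 tensors
cvter.src_to_comp_type(src, fsrc)      // bf16 -> f32 into bundle slots
     .src_to_comp_type(dst, fdst);
run_fp32_opr(fsrc, fdst, cvter.workspace());  // hypothetical fp32 operator call
cvter.comp_to_dst_type(fdst, dst);     // f32 result -> bf16 output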
@@ -55,17 +55,19 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
     megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str());
     if (param().format == param::WarpPerspective::Format::NCHW) {
-        megdnn_assert(src.dtype.enumv() == DTypeEnum::Float32 ||
-                              MEGDNN_FLOAT16_SELECT(
-                                      src.dtype.enumv() == DTypeEnum::Float16,
-                                      false) ||
-                              src.dtype.enumv() == DTypeEnum::Int8 ||
-                              src.dtype.enumv() == DTypeEnum::Uint8 ||
-                              (src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                               src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
-                      "WarpPerspective NCHW input dtype should be "
-                      "Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
-                              "/Float16", "") ".");
+        megdnn_assert(
+                src.dtype.enumv() == DTypeEnum::Float32 ||
+                        MEGDNN_FLOAT16_SELECT(
+                                (src.dtype.enumv() == DTypeEnum::Float16 ||
+                                 src.dtype.enumv() == DTypeEnum::BFloat16),
+                                false) ||
+                        src.dtype.enumv() == DTypeEnum::Int8 ||
+                        src.dtype.enumv() == DTypeEnum::Uint8 ||
+                        (src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
+                         src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
+                "WarpPerspective NCHW input dtype should be "
+                "Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
+                        "/Float16/BFloat16", "") ".");
         megdnn_assert(
                 (src.dtype.category() == DTypeCategory::FLOAT &&
                  (src.dtype == mat.dtype ||
@@ -107,14 +109,17 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
                       param::WarpPerspective::BorderMode::ISOLATED);
     } else {
         megdnn_assert(param().format == param::WarpPerspective::Format::NHWCD4);
-        megdnn_assert(src.dtype == dtype::Float32() ||
-                              MEGDNN_FLOAT16_SELECT(
-                                      src.dtype == dtype::Float16(), false) ||
-                              src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                              src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
-                      "WarpPerspective NHWCD4 input dtype should be "
-                      "Float32" MEGDNN_FLOAT16_SELECT(
-                              "/Float16", "") ",QunatizedS8, Quantized8Asymm.");
+        megdnn_assert(
+                src.dtype == dtype::Float32() ||
+                        MEGDNN_FLOAT16_SELECT((src.dtype == dtype::Float16() ||
+                                               src.dtype == dtype::BFloat16()),
+                                              false) ||
+                        src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
+                        src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
+                "WarpPerspective NHWCD4 input dtype should be "
+                "Float32" MEGDNN_FLOAT16_SELECT(
+                        "/Float16/BFloat16",
+                        "") ",QunatizedS8, Quantized8Asymm.");
         megdnn_assert(
                 (src.dtype == mat.dtype || mat.dtype == dtype::Float32()),
                 "The input to WarpPerspective is in NHWCD4 format, in this "
@@ -253,30 +258,30 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
     }
 }
-void WarpPerspectiveBackwardData::check_exec(const TensorLayout &mat,
-                                             const TensorLayout &diff,
-                                             const TensorLayout &grad,
-                                             size_t workspace_in_bytes)
-{
+void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat,
+                                             const TensorLayout& diff,
+                                             const TensorLayout& grad,
+                                             size_t workspace_in_bytes) {
     check_layout_fwd(grad, mat, diff);
-    megdnn_assert(grad.dtype == dtype::Float32(),
-                  "Backward WarpPerspective only supports Float32.");
+    megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
+                          || grad.dtype == dtype::BFloat16()),
+                  "Backward WarpPerspective only supports Float32/BFloat16.");
     auto required_workspace_in_bytes = get_workspace_in_bytes(mat, diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
 }
-void WarpPerspectiveBackwardMat::check_exec(const TensorLayout &src,
-                                            const TensorLayout &mat,
-                                            const TensorLayout &diff,
-                                            const TensorLayout &grad,
-                                            size_t workspace_in_bytes)
-{
+void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src,
+                                            const TensorLayout& mat,
+                                            const TensorLayout& diff,
+                                            const TensorLayout& grad,
+                                            size_t workspace_in_bytes) {
     check_layout_fwd(src, mat, diff);
     megdnn_assert_eq_layout(mat, grad);
-    megdnn_assert(grad.dtype == dtype::Float32(),
-                  "Backward WarpPerspective only supports Float32.");
-    auto required_workspace_in_bytes = get_workspace_in_bytes(src,
-            mat, diff, grad);
+    megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
+                          || grad.dtype == dtype::BFloat16()),
+                  "Backward WarpPerspective only supports Float32/BFloat16.");
+    auto required_workspace_in_bytes =
+            get_workspace_in_bytes(src, mat, diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
 }
@@ -0,0 +1,29 @@
+/**
+ * \file dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+// generated by gen_cond_take_kern_impls.py
+#include "../kern.inl"
+#if !MEGDNN_DISABLE_FLOAT16
+namespace megdnn {
+namespace cuda {
+namespace cond_take {
+inst_genidx(::megdnn::dtype::BFloat16)
+#undef inst_genidx
+inst_copy(::megdnn::dtype::BFloat16)
+#undef inst_copy
+#undef inst_copy_
+} // cond_take
+} // cuda
+} // megdnn
+#endif
@@ -62,6 +62,13 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
     non_cudnn_algos.push_back(all_algos.rbegin()[1]);  // group batched_matmul
     non_cudnn_algos.push_back(all_algos.rbegin()[0]);  // group 1x1
+    algo_size = all_algos.size();
+    for (size_t i = 0; i < algo_size; ++i) {
+        bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
+        all_algos.push_back(bfloat16_refhold.back().get());
+        bfloat16_algos.push_back(bfloat16_refhold.back().get());
+    }
     size_t all_algo_size = all_algos.size();
 #if CUDA_VERSION >= 10000
     fill_imma_algos();
@@ -499,6 +499,28 @@ private:
 };
 #endif
+class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase {
+public:
+    AlgoBFloat16(AlgoBase* impl);
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    void exec(const ExecArgs& args) const override;
+    const char* name() const override { return m_name.c_str(); }
+    bool is_reproducible() const override { return m_impl->is_reproducible(); }
+private:
+    SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr,
+                        TensorLayout& fsrc, TensorLayout& ffilter,
+                        TensorLayout& fbias, TensorLayout& fz,
+                        TensorLayout& fdst) const;
+    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
+    AlgoBase* m_impl;
+    std::string m_name;
+};
 class ConvBiasForwardImpl::AlgoPack {
     AlgoPack(const AlgoPack&) = delete;
     AlgoPack& operator=(const AlgoPack&) = delete;
@@ -508,7 +530,8 @@ public:
     std::vector<AlgoBase*> all_algos,
             //! non-cudnn algos, used for heuristic if cudnn is not supported
-            non_cudnn_algos;
+            non_cudnn_algos,
+            bfloat16_algos;
     std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations;
     std::vector<AlgoCUDNNConv> cudnn_convs;
     AlgoChanwise chanwise;
@@ -531,6 +554,7 @@ public:
             int8_chwn4_imma_unroll_width;
 #endif
     std::vector<std::unique_ptr<AlgoGroupConvGeneral>> gconv_refhold;
+    std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
     std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
     AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);
@@ -0,0 +1,120 @@
+/**
+ * \file dnn/src/cuda/conv_bias/bfloat16.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "src/cuda/conv_bias/algo.h"
+#include "src/cuda/handle.h"
+#include "src/cuda/utils.cuh"
+#include "src/cuda/utils.h"
+using namespace megdnn;
+using namespace cuda;
+using namespace conv_bias;
+ConvBiasForwardImpl::AlgoBFloat16::AlgoBFloat16(
+        ConvBiasForwardImpl::AlgoBase* algorithm)
+        : m_impl(algorithm) {
+    megdnn_assert_internal(algorithm);
+    m_name = ssprintf("BFLOAT16:%s", m_impl->name());
+}
+ConvBiasForwardImpl::AlgoBase::SizeArgs
+ConvBiasForwardImpl::AlgoBFloat16::float_args(
+        const SizeArgs& args, ConvBiasForwardImpl* opr, TensorLayout& fsrc,
+        TensorLayout& ffilter, TensorLayout& fbias, TensorLayout& fz,
+        TensorLayout& fdst) const {
+    fsrc = *args.src_layout;
+    ffilter = *args.filter_layout;
+    fbias = *args.bias_layout;
+    fz = *args.z_layout;
+    fdst = *args.dst_layout;
+    auto change_dtype = [](TensorLayout& layout) {
+        if (layout.dtype == dtype::BFloat16()) {
+            layout.dtype = dtype::Float32();
+        }
+    };
+    change_dtype(fsrc);
+    change_dtype(ffilter);
+    change_dtype(fbias);
+    change_dtype(fz);
+    change_dtype(fdst);
+    opr->param() = args.opr->param();
+    opr->param().compute_mode = Param::ComputeMode::DEFAULT;
+    opr->execution_policy() = {m_impl};
+    return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst);
+}
+bool ConvBiasForwardImpl::AlgoBFloat16::is_available(
+        const SizeArgs& args) const {
+    TensorLayout fsrc, ffilter, fbias, fz, fdst;
+    auto convbias_opr = args.handle->create_operator<ConvBias>();
+    SizeArgs fargs = float_args(
+            args, static_cast<ConvBiasForwardImpl*>(convbias_opr.get()), fsrc,
+            ffilter, fbias, fz, fdst);
+    return args.src_layout->dtype == args.filter_layout->dtype &&
+           args.src_layout->dtype == dtype::BFloat16() &&
+           m_impl->is_available(fargs);
+}
+WorkspaceBundle ConvBiasForwardImpl::AlgoBFloat16::get_workspace_bundle(
+        void* ptr, const SizeArgs& args) const {
+    TensorLayout fsrc, ffilter, fbias, fz, fdst;
+    auto convbias_opr = args.handle->create_operator<ConvBias>();
+    SizeArgs fargs = float_args(
+            args, static_cast<ConvBiasForwardImpl*>(convbias_opr.get()), fsrc,
+            ffilter, fbias, fz, fdst);
+    SmallVector<size_t> sizes;
+    auto get_workspace = [&sizes](const TensorLayout& src,
+                                  const TensorLayout& dst) {
+        if (src.dtype != dst.dtype) {
+            sizes.push_back(dst.span().dist_byte());
+        }
+    };
+    get_workspace(*args.src_layout, fsrc);
+    get_workspace(*args.filter_layout, ffilter);
+    get_workspace(*args.bias_layout, fbias);
+    get_workspace(*args.z_layout, fz);
+    get_workspace(*args.dst_layout, fdst);
+    sizes.push_back(m_impl->get_workspace_in_bytes(fargs));
+    return {ptr, std::move(sizes)};
+}
+size_t ConvBiasForwardImpl::AlgoBFloat16::get_workspace_in_bytes(
+        const SizeArgs& args) const {
+    return get_workspace_bundle(nullptr, args).total_size_in_bytes();
+}
+void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
+    TensorND fsrc_tensor = *args.src_tensor;
+    TensorND ffilter_tensor = *args.filter_tensor;
+    TensorND fbias_tensor = *args.bias_tensor;
+    TensorND fz_tensor = *args.z_tensor;
+    TensorND fdst_tensor = *args.dst_tensor;
+    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
+    CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle);
+    {
+        cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor)
+                .src_to_comp_type(*args.filter_tensor, ffilter_tensor)
+                .src_to_comp_type(*args.bias_tensor, fbias_tensor)
+                .src_to_comp_type(*args.z_tensor, fz_tensor)
+                .src_to_comp_type(*args.dst_tensor, fdst_tensor);
+    }
+    {
+        auto convbias_opr = args.handle->create_operator<ConvBias>();
+        convbias_opr->param() = args.opr->param();
+        convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
+        convbias_opr->execution_policy() = {m_impl};
+        convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor,
+                           fdst_tensor, cvter.workspace());
+    }
+    { cvter.comp_to_dst_type(fdst_tensor, *args.dst_tensor); }
+}
+// vim: syntax=cpp.doxygen
@@ -20,6 +20,10 @@ using namespace conv_bias;
 bool ConvBiasForwardImpl::AlgoChanwise::is_available(
         const SizeArgs& args) const {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     if (args.z_layout->ndim > 0)
         return false;
@@ -30,6 +30,10 @@ inline bool is_available_small(const chanwise::Param& param) {
 bool ConvBiasForwardImpl::AlgoChanwiseSmall::is_available(
         const SizeArgs& args) const {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     if (args.z_layout->ndim > 0)
         return false;
 #if CUDA_VERSION < 9000
@@ -23,6 +23,10 @@ using namespace conv_bias;
 bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
         const SizeArgs& args) const {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     if (args.bias_layout->ndim == 0 ||
         args.bias_layout->eq_shape(*args.dst_layout))
         return false;
@@ -50,6 +50,10 @@ ConvBiasForwardImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(AlgoBase* impl)
 bool ConvBiasForwardImpl::AlgoGroupConvGeneral::is_available(
         const SizeArgs& args) const {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     if (args.z_layout->ndim > 0 || args.filter_meta.group <= 1)
         return false;
     auto&& param = args.opr->param();
@@ -136,6 +136,11 @@ void ConvBiasDesc::set_conv(DType data_type, const param::ConvBias& param,
 namespace conv_bias {
 bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     // CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN
     // on Tegra K1.
     if (args.handle->is_tegra_k1())
@@ -20,6 +20,10 @@ using namespace cuda;
 bool ConvBiasForwardImpl::AlgoMatmul::is_available(const SizeArgs& args) const {
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
     if (args.z_layout->ndim > 0)
         return false;
@@ -9,6 +9,7 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 #include "src/cuda/conv_bias/opr_impl.h"
+#include "megdnn/dtype.h"
 #include "src/cuda/conv_bias/helper.h"
 #include "src/cuda/conv_bias/algo.h"
 #include "src/cuda/handle.h"
@@ -176,14 +177,26 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
         conv_args = orig_args;
     }
-    if (reproducible) {
-        return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
-                sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
-                "cuda convbias fwd");
+    if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
+        if (reproducible) {
+            return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
+                    sm_algo_pack.non_cudnn_algos, args,
+                    workspace_limit_in_bytes, "cuda convbias fwd");
+        } else {
+            return megdnn::get_usable_algo<ConvBiasForwardImpl>(
+                    sm_algo_pack.non_cudnn_algos, args,
+                    workspace_limit_in_bytes, "cuda convbias fwd");
+        }
     } else {
-        return megdnn::get_usable_algo<ConvBiasForwardImpl>(
-                sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
-                "cuda convbias fwd");
+        if (reproducible) {
+            return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
+                    sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
+                    "cuda convbias fwd");
+        } else {
+            return megdnn::get_usable_algo<ConvBiasForwardImpl>(
+                    sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
+                    "cuda convbias fwd");
+        }
     }
 }
@@ -57,6 +57,7 @@ public:
     class AlgoInt8NCHW4IMMAImplicitGemm;
     class AlgoInt8CHWN4IMMAImplicitGemmReorderFilter;
    class AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth;
+    class AlgoBFloat16;
     class AlgoPack;
@@ -33,11 +33,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
     // add gconv algos by AlgoGroupConvGeneral
     auto all_algos_data = all_algos.data();
-    for (size_t i = 2; i < all_algos.size(); ++ i) {
+    size_t group_algo_start = 2;
+    for (size_t i = group_algo_start; i < all_algos.size(); ++ i) {
         gconv.push_back({all_algos[i]});
     }
-    for (size_t i = 2; i < all_algos.size(); ++ i) {
-        algo2gconv[all_algos[i]] = &gconv[i - 2];
+    for (size_t i = group_algo_start; i < all_algos.size(); ++ i) {
+        algo2gconv[all_algos[i]] = &gconv[i - group_algo_start];
     }
     for (auto &&i: gconv) {
         all_algos.push_back(&i);
@@ -45,6 +46,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
     megdnn_assert(all_algos_data == all_algos.data());
     non_cudnn_algos.push_back(all_algos.rbegin()[0]);  // group matmul
+    size_t algo_size = all_algos.size();
+    for (size_t i=0; i<algo_size; ++i) {
+        bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
+        all_algos.push_back(bfloat16_refhold.back().get());
+        bfloat16_algos.push_back(bfloat16_refhold.back().get());
+    }
 }
 ConvolutionBackwardDataImpl::AlgoCUDNN*
@@ -65,18 +72,19 @@ ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
         ConvolutionBackwardDataImpl *o,
         const TensorLayout &filter, const TensorLayout &diff,
         const TensorLayout &grad):
-    SizeArgs(o, o->check_layout_fwd(grad, filter, diff), diff, grad)
+    SizeArgs(o, filter, o->check_layout_fwd(grad, filter, diff), diff, grad)
 {
 }
 ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
-        ConvolutionBackwardDataImpl *o,
-        const CanonizedFilterMeta &filter, const TensorLayout &diff,
+        ConvolutionBackwardDataImpl *o, const TensorLayout& filter,
+        const CanonizedFilterMeta &filter_meta, const TensorLayout &diff,
         const TensorLayout &grad):
     handle{concrete_handle(o->handle())},
-    filter_meta{filter},
+    filter_meta{filter_meta},
    diff_layout{&diff},
     grad_layout{&grad},
+    filter_layout{&filter},
     opr{o}
 {
 }
@@ -31,22 +31,24 @@ class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm { | |||||
struct SizeArgs { | struct SizeArgs { | ||||
HandleImpl *handle; | HandleImpl *handle; | ||||
CanonizedFilterMeta filter_meta; | CanonizedFilterMeta filter_meta; | ||||
const TensorLayout *diff_layout, *grad_layout; | |||||
const TensorLayout *diff_layout, *grad_layout, *filter_layout; | |||||
ConvolutionBackwardDataImpl *opr; | ConvolutionBackwardDataImpl *opr; | ||||
std::string to_string() const; | std::string to_string() const; | ||||
void init_desc(convolution::CUDNNBwdDataDescs &desc) const { | void init_desc(convolution::CUDNNBwdDataDescs &desc) const { | ||||
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | ||||
} | } | ||||
SizeArgs(ConvolutionBackwardDataImpl *opr, | |||||
const TensorLayout &filter, const TensorLayout &diff, | |||||
const TensorLayout &grad); | |||||
SizeArgs(ConvolutionBackwardDataImpl *opr, | |||||
const CanonizedFilterMeta &filter, const TensorLayout &diff, | |||||
const TensorLayout &grad); | |||||
SizeArgs(ConvolutionBackwardDataImpl* opr, | |||||
const TensorLayout& filter, const TensorLayout& diff, | |||||
const TensorLayout& grad); | |||||
SizeArgs(ConvolutionBackwardDataImpl* opr, | |||||
const TensorLayout& filter, | |||||
const CanonizedFilterMeta& filter_meta, | |||||
const TensorLayout& diff, const TensorLayout& grad); | |||||
convolution::ForwardSizeArgs as_fwd_args() const { | convolution::ForwardSizeArgs as_fwd_args() const { | ||||
return {handle, grad_layout, filter_meta, diff_layout}; | |||||
return {handle, grad_layout, filter_layout, filter_meta, | |||||
diff_layout}; | |||||
} | } | ||||
}; | }; | ||||
struct ExecArgs: public SizeArgs { | struct ExecArgs: public SizeArgs { | ||||
@@ -170,6 +172,25 @@ class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final: public AlgoBase { | |||||
} | } | ||||
}; | }; | ||||
class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase { | |||||
public: | |||||
AlgoBFloat16(ConvolutionBackwardDataImpl::AlgoBase*); | |||||
bool is_available(const SizeArgs& args) const override; | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
void exec(const ExecArgs& args) const override; | |||||
const char* name() const override { return m_name.c_str(); } | |||||
bool is_reproducible() const override { return true; } | |||||
private: | |||||
std::string m_name; | |||||
ConvolutionBackwardDataImpl::AlgoBase* m_algorithm = nullptr; | |||||
SizeArgs float_args(const SizeArgs& args, ConvolutionBackwardDataImpl* opr, | |||||
TensorLayout& ffilter, TensorLayout& fdiff, | |||||
TensorLayout& fgrad) const; | |||||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||||
}; | |||||
//! implement group conv by another algo | //! implement group conv by another algo | ||||
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase { | class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase { | ||||
AlgoBase *m_impl; | AlgoBase *m_impl; | ||||
@@ -210,12 +231,14 @@ class ConvolutionBackwardDataImpl::AlgoPack { | |||||
AlgoChanwiseSmall chanwise_small; | AlgoChanwiseSmall chanwise_small; | ||||
std::vector<AlgoGroupConvGeneral> gconv; | std::vector<AlgoGroupConvGeneral> gconv; | ||||
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | ||||
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||||
std::vector<AlgoBase*> | std::vector<AlgoBase*> | ||||
//! all algorithms | //! all algorithms | ||||
all_algos, | all_algos, | ||||
//! non-cudnn algos, used for heuristic if cudnn is not supported | //! non-cudnn algos, used for heuristic if cudnn is not supported | ||||
non_cudnn_algos; | |||||
non_cudnn_algos, | |||||
bfloat16_algos; | |||||
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); | AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); | ||||
}; | }; | ||||
@@ -0,0 +1,115 @@ | |||||
/** | |||||
* \file dnn/src/cuda/convolution/backward_data/bfloat16.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "./algo.h" | |||||
#include "src/cuda/convolution/chanwise/kern.cuh" | |||||
#include "src/cuda/utils.h" | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
using namespace convolution; | |||||
ConvolutionBackwardDataImpl::AlgoBFloat16::AlgoBFloat16( | |||||
ConvolutionBackwardDataImpl::AlgoBase* algorithm) | |||||
: m_algorithm(algorithm) { | |||||
megdnn_assert_internal(algorithm); | |||||
m_name = ssprintf("CONVOLUTION_BACKWARD_DATD_BFLOAT16:%s", | |||||
m_algorithm->name()); | |||||
} | |||||
ConvolutionBackwardDataImpl::AlgoBase::SizeArgs | |||||
ConvolutionBackwardDataImpl::AlgoBFloat16::float_args( | |||||
const SizeArgs& args, ConvolutionBackwardDataImpl* opr, | |||||
TensorLayout& ffilter, TensorLayout& fdiff, TensorLayout& fgrad) const { | |||||
ffilter = *args.filter_layout; | |||||
fdiff = *args.diff_layout; | |||||
fgrad = *args.grad_layout; | |||||
auto change_dtype = [](TensorLayout& layout) { | |||||
if (layout.dtype == dtype::BFloat16()) { | |||||
layout.dtype = dtype::Float32(); | |||||
} | |||||
}; | |||||
change_dtype(ffilter); | |||||
change_dtype(fdiff); | |||||
change_dtype(fgrad); | |||||
opr->param() = args.opr->param(); | |||||
opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||||
opr->execution_policy() = {m_algorithm}; | |||||
return SizeArgs(opr, ffilter, fdiff, fgrad); | |||||
} | |||||
bool ConvolutionBackwardDataImpl::AlgoBFloat16::is_available( | |||||
const SizeArgs& args) const { | |||||
TensorLayout ffilter, fdiff, fgrad; | |||||
auto conv_back_data_opr = | |||||
args.handle->create_operator<ConvolutionBackwardData>(); | |||||
SizeArgs fargs = float_args( | |||||
args, | |||||
static_cast<ConvolutionBackwardDataImpl*>(conv_back_data_opr.get()), | |||||
ffilter, fdiff, fgrad); | |||||
return args.diff_layout->dtype == args.filter_layout->dtype && | |||||
args.diff_layout->dtype == dtype::BFloat16() && | |||||
m_algorithm->is_available(fargs); | |||||
} | |||||
WorkspaceBundle ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_bundle( | |||||
void* ptr, const SizeArgs& args) const { | |||||
TensorLayout ffilter, fdiff, fgrad; | |||||
auto conv_back_data_opr = | |||||
args.handle->create_operator<ConvolutionBackwardData>(); | |||||
SizeArgs fargs = float_args( | |||||
args, | |||||
static_cast<ConvolutionBackwardDataImpl*>(conv_back_data_opr.get()), | |||||
ffilter, fdiff, fgrad); | |||||
SmallVector<size_t> sizes; | |||||
auto get_workspace = [&sizes](const TensorLayout& src, | |||||
const TensorLayout& dst) { | |||||
if (src.dtype != dst.dtype) { | |||||
sizes.push_back(dst.span().dist_byte()); | |||||
} | |||||
}; | |||||
get_workspace(*args.filter_layout, ffilter); | |||||
get_workspace(*args.diff_layout, fdiff); | |||||
get_workspace(*args.grad_layout, fgrad); | |||||
sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs)); | |||||
return {ptr, std::move(sizes)}; | |||||
} | |||||
size_t ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_in_bytes( | |||||
const SizeArgs& args) const { | |||||
return get_workspace_bundle(nullptr, args).total_size_in_bytes(); | |||||
} | |||||
void ConvolutionBackwardDataImpl::AlgoBFloat16::exec( | |||||
const ExecArgs& args) const { | |||||
TensorND ffilter_tensor = *args.filter_tensor; | |||||
TensorND fdiff_tensor = *args.diff_tensor; | |||||
TensorND fgrad_tensor = *args.grad_tensor; | |||||
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | |||||
CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle); | |||||
{ | |||||
cvter.src_to_comp_type(*args.filter_tensor, ffilter_tensor) | |||||
.src_to_comp_type(*args.diff_tensor, fdiff_tensor) | |||||
.src_to_comp_type(*args.grad_tensor, fgrad_tensor); | |||||
} | |||||
{ | |||||
auto conv_back_data_opr = | |||||
args.handle->create_operator<ConvolutionBackwardData>(); | |||||
conv_back_data_opr->param() = args.opr->param(); | |||||
conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||||
conv_back_data_opr->execution_policy() = {m_algorithm}; | |||||
conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor, | |||||
cvter.workspace()); | |||||
} | |||||
{ cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); } | |||||
} | |||||
// vim: syntax=cpp.doxygen |
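A note on get_workspace_bundle() above: the wrapper reserves one float32 staging buffer per tensor whose dtype actually changes, plus the wrapped algorithm's own workspace. A hedged sketch of the accounting in plain C++ (not the WorkspaceBundle API):

    #include <cstddef>
    #include <numeric>
    #include <vector>

    // total = every float32 staging copy + the inner algorithm's workspace
    std::size_t bf16_wrapper_workspace(
            const std::vector<std::size_t>& staging_bytes,
            std::size_t inner_algo_bytes) {
        return std::accumulate(staging_bytes.begin(), staging_bytes.end(),
                               inner_algo_bytes);
    }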
@@ -19,6 +19,10 @@ using namespace convolution; | |||||
bool ConvolutionBackwardDataImpl::AlgoChanwise::is_available( | bool ConvolutionBackwardDataImpl::AlgoChanwise::is_available( | ||||
const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
if (args.diff_layout->dtype == args.filter_layout->dtype && | |||||
args.diff_layout->dtype == dtype::BFloat16()) { | |||||
return false; | |||||
} | |||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
return args.filter_meta.format == Param::Format::NCHW && | return args.filter_meta.format == Param::Format::NCHW && | ||||
args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | ||||
@@ -29,6 +29,10 @@ inline bool is_available_small(const chanwise::Param& param) { | |||||
bool ConvolutionBackwardDataImpl::AlgoChanwiseSmall::is_available( | bool ConvolutionBackwardDataImpl::AlgoChanwiseSmall::is_available( | ||||
const SizeArgs &args) const { | const SizeArgs &args) const { | ||||
if (args.diff_layout->dtype == args.filter_layout->dtype && | |||||
args.diff_layout->dtype == dtype::BFloat16()) { | |||||
return false; | |||||
} | |||||
#if CUDA_VERSION < 9000 | #if CUDA_VERSION < 9000 | ||||
if (args.diff_layout->dtype.enumv() == DTypeEnum::Float16) | if (args.diff_layout->dtype.enumv() == DTypeEnum::Float16) | ||||
return false; | return false; | ||||
@@ -38,6 +38,10 @@ ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral( | |||||
bool ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::is_available( | bool ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::is_available( | ||||
const SizeArgs &args) const { | const SizeArgs &args) const { | ||||
if (args.diff_layout->dtype == args.filter_layout->dtype && | |||||
args.diff_layout->dtype == dtype::BFloat16()) { | |||||
return false; | |||||
} | |||||
auto sub_args = args; | auto sub_args = args; | ||||
TensorLayout diff_pg, grad_pg; | TensorLayout diff_pg, grad_pg; | ||||
modify_size_args(sub_args, diff_pg, grad_pg); | modify_size_args(sub_args, diff_pg, grad_pg); | ||||
@@ -20,6 +20,10 @@ using namespace cuda; | |||||
bool ConvolutionBackwardDataImpl::AlgoMatmul::is_available( | bool ConvolutionBackwardDataImpl::AlgoMatmul::is_available( | ||||
const SizeArgs &args) const { | const SizeArgs &args) const { | ||||
if (args.diff_layout->dtype == args.filter_layout->dtype && | |||||
args.diff_layout->dtype == dtype::BFloat16()) { | |||||
return false; | |||||
} | |||||
auto &&fm = args.filter_meta; | auto &&fm = args.filter_meta; | ||||
return args.filter_meta.format == Param::Format::NCHW && | return args.filter_meta.format == Param::Format::NCHW && | ||||
args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | ||||
@@ -43,6 +43,12 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() { | |||||
megdnn_assert(all_algos_data == all_algos.data()); | megdnn_assert(all_algos_data == all_algos.data()); | ||||
non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group matmul | non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group matmul | ||||
size_t algo_size = all_algos.size(); | |||||
for (size_t i=0; i<algo_size; ++i) { | |||||
bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i])); | |||||
all_algos.push_back(bfloat16_refhold.back().get()); | |||||
bfloat16_algos.push_back(bfloat16_refhold.back().get()); | |||||
} | |||||
} | } | ||||
ConvolutionBackwardFilterImpl::AlgoCUDNN* | ConvolutionBackwardFilterImpl::AlgoCUDNN* | ||||
@@ -64,21 +70,20 @@ ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( | |||||
ConvolutionBackwardFilterImpl *o, | ConvolutionBackwardFilterImpl *o, | ||||
const TensorLayout &src, const TensorLayout &diff, | const TensorLayout &src, const TensorLayout &diff, | ||||
const TensorLayout &grad): | const TensorLayout &grad): | ||||
SizeArgs(o, src, diff, o->check_layout_fwd(src, grad, diff)) | |||||
SizeArgs(o, src, diff, grad, o->check_layout_fwd(src, grad, diff)) | |||||
{ | { | ||||
} | } | ||||
ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( | ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( | ||||
ConvolutionBackwardFilterImpl *o, | |||||
const TensorLayout &src, const TensorLayout &diff, | |||||
const CanonizedFilterMeta &grad): | |||||
handle{concrete_handle(o->handle())}, | |||||
src_layout{&src}, | |||||
diff_layout{&diff}, | |||||
grad_filter_meta{grad}, | |||||
opr{o} | |||||
{ | |||||
} | |||||
ConvolutionBackwardFilterImpl* o, const TensorLayout& src, | |||||
const TensorLayout& diff, const TensorLayout& grad, | |||||
const CanonizedFilterMeta& grad_meta) | |||||
: handle{concrete_handle(o->handle())}, | |||||
src_layout{&src}, | |||||
diff_layout{&diff}, | |||||
grad_layout{&grad}, | |||||
grad_filter_meta{grad_meta}, | |||||
opr{o} {} | |||||
ConvolutionBackwardFilterImpl::AlgoBase::ExecArgs::ExecArgs( | ConvolutionBackwardFilterImpl::AlgoBase::ExecArgs::ExecArgs( | ||||
ConvolutionBackwardFilterImpl *opr, | ConvolutionBackwardFilterImpl *opr, | ||||
@@ -30,7 +30,7 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm { | |||||
public: | public: | ||||
struct SizeArgs { | struct SizeArgs { | ||||
HandleImpl *handle; | HandleImpl *handle; | ||||
const TensorLayout *src_layout, *diff_layout; | |||||
const TensorLayout *src_layout, *diff_layout, *grad_layout; | |||||
CanonizedFilterMeta grad_filter_meta; | CanonizedFilterMeta grad_filter_meta; | ||||
ConvolutionBackwardFilterImpl *opr; | ConvolutionBackwardFilterImpl *opr; | ||||
@@ -42,12 +42,14 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm { | |||||
SizeArgs(ConvolutionBackwardFilterImpl *opr, | SizeArgs(ConvolutionBackwardFilterImpl *opr, | ||||
const TensorLayout &src, const TensorLayout &diff, | const TensorLayout &src, const TensorLayout &diff, | ||||
const TensorLayout &grad); | const TensorLayout &grad); | ||||
SizeArgs(ConvolutionBackwardFilterImpl *opr, | |||||
const TensorLayout &src, const TensorLayout &diff, | |||||
const CanonizedFilterMeta &grad); | |||||
SizeArgs(ConvolutionBackwardFilterImpl* opr, | |||||
const TensorLayout& src, const TensorLayout& diff, | |||||
const TensorLayout& grad, | |||||
const CanonizedFilterMeta& grad_meta); | |||||
convolution::ForwardSizeArgs as_fwd_args() const { | convolution::ForwardSizeArgs as_fwd_args() const { | ||||
return {handle, src_layout, grad_filter_meta, diff_layout}; | |||||
return {handle, src_layout, grad_layout, grad_filter_meta, | |||||
diff_layout}; | |||||
} | } | ||||
}; | }; | ||||
struct ExecArgs: public SizeArgs { | struct ExecArgs: public SizeArgs { | ||||
@@ -157,6 +159,25 @@ class ConvolutionBackwardFilterImpl::AlgoChanwise final: public AlgoBase { | |||||
} | } | ||||
}; | }; | ||||
class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase { | |||||
public: | |||||
AlgoBFloat16(ConvolutionBackwardFilterImpl::AlgoBase*); | |||||
bool is_available(const SizeArgs& args) const override; | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
void exec(const ExecArgs& args) const override; | |||||
const char* name() const override { return m_name.c_str(); } | |||||
bool is_reproducible() const override { return true; } | |||||
private: | |||||
std::string m_name; | |||||
ConvolutionBackwardFilterImpl::AlgoBase* m_algorithm = nullptr; | |||||
SizeArgs float_args(const SizeArgs& args, | |||||
ConvolutionBackwardFilterImpl* opr, TensorLayout& fsrc, | |||||
TensorLayout& fdiff, TensorLayout& fgrad) const; | |||||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||||
}; | |||||
//! implement group conv by another algo | //! implement group conv by another algo | ||||
class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase { | class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase { | ||||
AlgoBase *m_impl; | AlgoBase *m_impl; | ||||
@@ -196,12 +217,14 @@ class ConvolutionBackwardFilterImpl::AlgoPack { | |||||
AlgoChanwise chanwise; | AlgoChanwise chanwise; | ||||
std::vector<AlgoGroupConvGeneral> gconv; | std::vector<AlgoGroupConvGeneral> gconv; | ||||
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | ||||
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||||
std::vector<AlgoBase*> | std::vector<AlgoBase*> | ||||
//! all algorithms | //! all algorithms | ||||
all_algos, | all_algos, | ||||
//! non-cudnn algos, used for heuristic if cudnn is not supported | //! non-cudnn algos, used for heuristic if cudnn is not supported | ||||
non_cudnn_algos; | |||||
non_cudnn_algos, | |||||
bfloat16_algos; | |||||
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | ||||
}; | }; | ||||
@@ -0,0 +1,117 @@ | |||||
/** | |||||
* \file dnn/src/cuda/convolution/backward_filter/bfloat16.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "./algo.h" | |||||
#include "src/cuda/convolution/chanwise/kern.cuh" | |||||
#include "src/cuda/utils.h" | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
using namespace convolution; | |||||
ConvolutionBackwardFilterImpl::AlgoBFloat16::AlgoBFloat16( | |||||
ConvolutionBackwardFilterImpl::AlgoBase* algorithm) | |||||
: m_algorithm(algorithm) { | |||||
megdnn_assert_internal(algorithm); | |||||
m_name = ssprintf("CONVOLUTION_BACKWARD_Filter_BFLOAT16:%s", | |||||
m_algorithm->name()); | |||||
} | |||||
ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs | |||||
ConvolutionBackwardFilterImpl::AlgoBFloat16::float_args( | |||||
const SizeArgs& args, ConvolutionBackwardFilterImpl* opr, | |||||
TensorLayout& fsrc, TensorLayout& fdiff, TensorLayout& fgrad) const { | |||||
fsrc = *args.src_layout; | |||||
fdiff = *args.diff_layout; | |||||
fgrad = *args.grad_layout; | |||||
auto change_dtype = [](TensorLayout& layout) { | |||||
if (layout.dtype == dtype::BFloat16()) { | |||||
layout.dtype = dtype::Float32(); | |||||
} | |||||
}; | |||||
change_dtype(fsrc); | |||||
change_dtype(fdiff); | |||||
change_dtype(fgrad); | |||||
opr->param() = args.opr->param(); | |||||
opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||||
opr->execution_policy() = {m_algorithm}; | |||||
return SizeArgs(opr, fsrc, fdiff, fgrad); | |||||
} | |||||
bool ConvolutionBackwardFilterImpl::AlgoBFloat16::is_available( | |||||
const SizeArgs& args) const { | |||||
TensorLayout fsrc, fdiff, fgrad; | |||||
auto conv_back_filter_opr = | |||||
args.handle->create_operator<ConvolutionBackwardFilter>(); | |||||
SizeArgs fargs = float_args(args, | |||||
static_cast<ConvolutionBackwardFilterImpl*>( | |||||
conv_back_filter_opr.get()), | |||||
fsrc, fdiff, fgrad); | |||||
return args.src_layout->dtype == args.diff_layout->dtype && | |||||
args.src_layout->dtype == dtype::BFloat16() && | |||||
m_algorithm->is_available(fargs); | |||||
} | |||||
WorkspaceBundle | |||||
ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_bundle( | |||||
void* ptr, const SizeArgs& args) const { | |||||
TensorLayout fsrc, fdiff, fgrad; | |||||
auto conv_back_filter_opr = | |||||
args.handle->create_operator<ConvolutionBackwardFilter>(); | |||||
SizeArgs fargs = float_args(args, | |||||
static_cast<ConvolutionBackwardFilterImpl*>( | |||||
conv_back_filter_opr.get()), | |||||
fsrc, fdiff, fgrad); | |||||
SmallVector<size_t> sizes; | |||||
auto get_workspace = [&sizes](const TensorLayout& src, | |||||
const TensorLayout& dst) { | |||||
if (src.dtype != dst.dtype) { | |||||
sizes.push_back(dst.span().dist_byte()); | |||||
} | |||||
}; | |||||
get_workspace(*args.src_layout, fsrc); | |||||
get_workspace(*args.diff_layout, fdiff); | |||||
get_workspace(*args.grad_layout, fgrad); | |||||
sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs)); | |||||
return {ptr, std::move(sizes)}; | |||||
} | |||||
size_t ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_in_bytes( | |||||
const SizeArgs& args) const { | |||||
return get_workspace_bundle(nullptr, args).total_size_in_bytes(); | |||||
} | |||||
void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec( | |||||
const ExecArgs& args) const { | |||||
TensorND fsrc_tensor = *args.src_tensor; | |||||
TensorND fdiff_tensor = *args.diff_tensor; | |||||
TensorND fgrad_tensor = *args.grad_tensor; | |||||
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | |||||
CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle); | |||||
{ | |||||
cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor) | |||||
.src_to_comp_type(*args.diff_tensor, fdiff_tensor) | |||||
.src_to_comp_type(*args.grad_tensor, fgrad_tensor); | |||||
} | |||||
{ | |||||
auto conv_back_filter_opr = | |||||
args.handle->create_operator<ConvolutionBackwardFilter>(); | |||||
conv_back_filter_opr->param() = args.opr->param(); | |||||
conv_back_filter_opr->param().compute_mode = | |||||
Param::ComputeMode::DEFAULT; | |||||
conv_back_filter_opr->execution_policy() = {m_algorithm}; | |||||
conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor, | |||||
cvter.workspace()); | |||||
} | |||||
{ cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); } | |||||
} | |||||
// vim: syntax=cpp.doxygen |
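Both exec() implementations follow the same staging pattern: convert bfloat16 operands into float32 scratch space, run the wrapped algorithm entirely in float32, then convert the result back into the caller's bfloat16 output. A condensed sketch under assumed types (this is not the CompTypeCvter API):

    #include <functional>

    struct TensorRef { void* data; bool is_bf16; };

    void run_bf16_via_f32(
            TensorRef src, TensorRef diff, TensorRef grad,
            const std::function<TensorRef(TensorRef)>& to_f32,
            const std::function<void(TensorRef, TensorRef)>& back_to_bf16,
            const std::function<void(TensorRef, TensorRef, TensorRef)>& f32_algo) {
        // stage bf16 operands into float32 workspace copies
        TensorRef fsrc  = src.is_bf16  ? to_f32(src)  : src;
        TensorRef fdiff = diff.is_bf16 ? to_f32(diff) : diff;
        TensorRef fgrad = grad.is_bf16 ? to_f32(grad) : grad;
        f32_algo(fsrc, fdiff, fgrad);       // wrapped algo runs in float32
        if (grad.is_bf16)
            back_to_bf16(fgrad, grad);      // write back in original dtype
    }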
@@ -19,6 +19,10 @@ using namespace convolution; | |||||
bool ConvolutionBackwardFilterImpl::AlgoChanwise::is_available( | bool ConvolutionBackwardFilterImpl::AlgoChanwise::is_available( | ||||
const SizeArgs &args) const { | const SizeArgs &args) const { | ||||
if (args.src_layout->dtype == args.diff_layout->dtype && | |||||
args.diff_layout->dtype == dtype::BFloat16()) { | |||||
return false; | |||||
} | |||||
auto &&fm = args.grad_filter_meta; | auto &&fm = args.grad_filter_meta; | ||||
return fm.format == Param::Format::NCHW && | return fm.format == Param::Format::NCHW && | ||||
args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | ||||
@@ -38,6 +38,10 @@ ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral( | |||||
bool ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::is_available( | bool ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::is_available( | ||||
const SizeArgs &args) const { | const SizeArgs &args) const { | ||||
if (args.src_layout->dtype == args.diff_layout->dtype && | |||||
args.diff_layout->dtype == dtype::BFloat16()) { | |||||
return false; | |||||
} | |||||
auto sub_args = args; | auto sub_args = args; | ||||
TensorLayout src_pg, diff_pg; | TensorLayout src_pg, diff_pg; | ||||
modify_size_args(sub_args, src_pg, diff_pg); | modify_size_args(sub_args, src_pg, diff_pg); | ||||
@@ -19,6 +19,10 @@ using namespace cuda; | |||||
bool ConvolutionBackwardFilterImpl::AlgoMatmul::is_available( | bool ConvolutionBackwardFilterImpl::AlgoMatmul::is_available( | ||||
const SizeArgs &args) const { | const SizeArgs &args) const { | ||||
if (args.src_layout->dtype == args.diff_layout->dtype && | |||||
args.diff_layout->dtype == dtype::BFloat16()) { | |||||
return false; | |||||
} | |||||
auto &&fm = args.grad_filter_meta; | auto &&fm = args.grad_filter_meta; | ||||
return fm.format == Param::Format::NCHW && | return fm.format == Param::Format::NCHW && | ||||
args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | ||||
@@ -16,6 +16,10 @@ using namespace cuda; | |||||
using namespace convolution; | using namespace convolution; | ||||
bool convolution::is_cudnn_supported(const ForwardSizeArgs &args) { | bool convolution::is_cudnn_supported(const ForwardSizeArgs &args) { | ||||
if (args.src_layout->dtype == args.filter_layout->dtype && | |||||
args.src_layout->dtype == dtype::BFloat16()) { | |||||
return false; | |||||
} | |||||
// CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN | // CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN | ||||
// on Tegra K1. | // on Tegra K1. | ||||
@@ -25,6 +25,7 @@ namespace convolution { | |||||
struct ForwardSizeArgs { | struct ForwardSizeArgs { | ||||
HandleImpl *handle; | HandleImpl *handle; | ||||
const TensorLayout *src_layout; | const TensorLayout *src_layout; | ||||
const TensorLayout *filter_layout; | |||||
CanonizedFilterMeta filter_meta; | CanonizedFilterMeta filter_meta; | ||||
const TensorLayout *dst_layout; | const TensorLayout *dst_layout; | ||||
}; | }; | ||||
@@ -102,7 +102,8 @@ void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter, | |||||
_megdnn_tensor_out grad, | _megdnn_tensor_out grad, | ||||
_megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
AlgoBase::ExecArgs args(this, filter, diff, grad, workspace); | AlgoBase::ExecArgs args(this, filter, diff, grad, workspace); | ||||
auto algo = get_algorithm(this, args.filter_meta, diff.layout, grad.layout); | |||||
auto algo = get_algorithm(this, filter.layout, args.filter_meta, | |||||
diff.layout, grad.layout); | |||||
algo->check_workspace(args, workspace).exec(args); | algo->check_workspace(args, workspace).exec(args); | ||||
} | } | ||||
@@ -120,16 +121,16 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | bool reproducible) { | ||||
auto fm = check_layout_fwd(grad, filter, diff); | auto fm = check_layout_fwd(grad, filter, diff); | ||||
return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes, | |||||
reproducible); | |||||
return get_algorithm_heuristic(filter, fm, diff, grad, | |||||
workspace_limit_in_bytes, reproducible); | |||||
} | } | ||||
ConvolutionBackwardDataImpl::Algorithm* | ConvolutionBackwardDataImpl::Algorithm* | ||||
ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||||
const CanonizedFilterMeta& filter, const TensorLayout& diff, | |||||
ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, | |||||
const CanonizedFilterMeta& filter_meta, const TensorLayout& diff, | |||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | bool reproducible) { | ||||
AlgoBase::SizeArgs args(this, filter, diff, grad); | |||||
AlgoBase::SizeArgs args(this, filter, filter_meta, diff, grad); | |||||
if (args.filter_meta.group > 1 && | if (args.filter_meta.group > 1 && | ||||
sm_algo_pack.chanwise.is_available_reproducible( | sm_algo_pack.chanwise.is_available_reproducible( | ||||
@@ -209,14 +210,27 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||||
args = orig_args; | args = orig_args; | ||||
} | } | ||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( | |||||
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
"cuda conv bwd_data"); | |||||
if (args.filter_layout->dtype.enumv() != | |||||
DTypeTrait<dtype::BFloat16>::enumv) { | |||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( | |||||
sm_algo_pack.non_cudnn_algos, args, | |||||
workspace_limit_in_bytes, "cuda conv bwd_data"); | |||||
} else { | |||||
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | |||||
sm_algo_pack.non_cudnn_algos, args, | |||||
workspace_limit_in_bytes, "cuda conv bwd_data"); | |||||
} | |||||
} else { | } else { | ||||
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | |||||
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
"cuda conv bwd_data"); | |||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( | |||||
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||||
"cuda conv bwd_data"); | |||||
} else { | |||||
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | |||||
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||||
"cuda conv bwd_data"); | |||||
} | |||||
} | } | ||||
} | } | ||||
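The restructured heuristic above dispatches on dtype: non-bfloat16 problems keep searching the plain non-cuDNN pool, while bfloat16 problems search only the wrapper pool, so a bfloat16 query can never land on an algorithm whose is_available() would reject it. The shape of the rule, simplified (illustrative code, not the actual dispatch):

    #include <vector>

    struct Algo;

    const std::vector<Algo*>& pick_pool(bool is_bf16,
                                        const std::vector<Algo*>& non_cudnn,
                                        const std::vector<Algo*>& bf16_algos) {
        return is_bf16 ? bf16_algos : non_cudnn;  // bf16 sees only wrappers
    }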
@@ -225,7 +239,7 @@ size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( | |||||
const TensorLayout &diff, | const TensorLayout &diff, | ||||
const TensorLayout &grad) { | const TensorLayout &grad) { | ||||
AlgoBase::SizeArgs args(this, filter, diff, grad); | AlgoBase::SizeArgs args(this, filter, diff, grad); | ||||
return get_algorithm(this, args.filter_meta, diff, grad)-> | |||||
return get_algorithm(this, filter, args.filter_meta, diff, grad)-> | |||||
get_workspace_in_bytes(args); | get_workspace_in_bytes(args); | ||||
} | } | ||||
@@ -241,7 +255,7 @@ void ConvolutionBackwardFilterImpl::exec(_megdnn_tensor_in src, | |||||
_megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
AlgoBase::ExecArgs args(this, src, diff, grad, workspace); | AlgoBase::ExecArgs args(this, src, diff, grad, workspace); | ||||
auto algo = get_algorithm(this, src.layout, diff.layout, | auto algo = get_algorithm(this, src.layout, diff.layout, | ||||
args.grad_filter_meta); | |||||
grad.layout, args.grad_filter_meta); | |||||
algo->check_workspace(args, workspace).exec(args); | algo->check_workspace(args, workspace).exec(args); | ||||
} | } | ||||
@@ -259,16 +273,16 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | bool reproducible) { | ||||
auto fm = check_layout_fwd(src, grad, diff); | auto fm = check_layout_fwd(src, grad, diff); | ||||
return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, | |||||
reproducible); | |||||
return get_algorithm_heuristic(src, diff, grad, fm, | |||||
workspace_limit_in_bytes, reproducible); | |||||
} | } | ||||
ConvolutionBackwardFilterImpl::Algorithm* | ConvolutionBackwardFilterImpl::Algorithm* | ||||
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, | |||||
bool reproducible) { | |||||
AlgoBase::SizeArgs args(this, src, diff, grad); | |||||
const TensorLayout& grad, const CanonizedFilterMeta& grad_meta, | |||||
size_t workspace_limit_in_bytes, bool reproducible) { | |||||
AlgoBase::SizeArgs args(this, src, diff, grad, grad_meta); | |||||
if (args.grad_filter_meta.group > 1 && | if (args.grad_filter_meta.group > 1 && | ||||
sm_algo_pack.chanwise.is_available_reproducible( | sm_algo_pack.chanwise.is_available_reproducible( | ||||
@@ -349,14 +363,26 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||||
args = orig_args; | args = orig_args; | ||||
} | } | ||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( | |||||
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
"cuda conv bwd_filter"); | |||||
if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) { | |||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( | |||||
sm_algo_pack.non_cudnn_algos, args, | |||||
workspace_limit_in_bytes, "cuda conv bwd_filter"); | |||||
} else { | |||||
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( | |||||
sm_algo_pack.non_cudnn_algos, args, | |||||
workspace_limit_in_bytes, "cuda conv bwd_filter"); | |||||
} | |||||
} else { | } else { | ||||
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( | |||||
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
"cuda conv bwd_filter"); | |||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( | |||||
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||||
"cuda conv bwd_filter"); | |||||
} else { | |||||
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( | |||||
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||||
"cuda conv bwd_filter"); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -365,7 +391,7 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( | |||||
const TensorLayout &diff, | const TensorLayout &diff, | ||||
const TensorLayout &grad) { | const TensorLayout &grad) { | ||||
AlgoBase::SizeArgs args(this, src, diff, grad); | AlgoBase::SizeArgs args(this, src, diff, grad); | ||||
return get_algorithm(this, src, diff, args.grad_filter_meta)-> | |||||
return get_algorithm(this, src, diff, grad, args.grad_filter_meta)-> | |||||
get_workspace_in_bytes(args); | get_workspace_in_bytes(args); | ||||
} | } | ||||
@@ -60,11 +60,11 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||||
const TensorLayout& grad, | const TensorLayout& grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | bool reproducible) override; | ||||
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad, | |||||
size_t workspace_limit_in_bytes, | |||||
bool reproducible); | |||||
Algorithm* get_algorithm_heuristic( | |||||
const TensorLayout& filter, | |||||
const CanonizedFilterMeta& filter_meta, | |||||
const TensorLayout& diff, const TensorLayout& grad, | |||||
size_t workspace_limit_in_bytes, bool reproducible); | |||||
size_t get_workspace_in_bytes(const TensorLayout& filter, | size_t get_workspace_in_bytes(const TensorLayout& filter, | ||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
@@ -76,6 +76,7 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||||
class AlgoChanwise; | class AlgoChanwise; | ||||
class AlgoChanwiseSmall; | class AlgoChanwiseSmall; | ||||
class AlgoGroupConvGeneral; | class AlgoGroupConvGeneral; | ||||
class AlgoBFloat16; | |||||
class AlgoPack; | class AlgoPack; | ||||
@@ -104,7 +105,8 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||||
bool reproducible) override; | bool reproducible) override; | ||||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | Algorithm* get_algorithm_heuristic(const TensorLayout& src, | ||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const CanonizedFilterMeta& grad, | |||||
const TensorLayout& grad, | |||||
const CanonizedFilterMeta& grad_meta, | |||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible); | bool reproducible); | ||||
size_t get_workspace_in_bytes(const TensorLayout& src, | size_t get_workspace_in_bytes(const TensorLayout& src, | ||||
@@ -117,6 +119,7 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||||
class AlgoMatmul; | class AlgoMatmul; | ||||
class AlgoChanwise; | class AlgoChanwise; | ||||
class AlgoGroupConvGeneral; | class AlgoGroupConvGeneral; | ||||
class AlgoBFloat16; | |||||
class AlgoPack; | class AlgoPack; | ||||
@@ -50,7 +50,7 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm { | |||||
const CanonizedFilterMeta &filter, | const CanonizedFilterMeta &filter, | ||||
const TensorLayout &dst); | const TensorLayout &dst); | ||||
}; | }; | ||||
struct ExecArgs: public SizeArgs { | |||||
struct ExecArgs : public SizeArgs { | |||||
const TensorND *src_tensor, *filter_tensor, *dst_tensor; | const TensorND *src_tensor, *filter_tensor, *dst_tensor; | ||||
Workspace workspace; | Workspace workspace; | ||||
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
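For context on the dt_bfloat16 element type that these generated kernels instantiate: bfloat16 keeps the top 16 bits of an IEEE-754 float32 (1 sign bit, 8 exponent bits, 7 mantissa bits), so converting from float32 amounts to a rounded truncation. A minimal round-to-nearest-even sketch (a hypothetical helper for illustration, not the half_bfloat16 implementation this patch includes; NaN handling omitted):

    #include <cstdint>
    #include <cstring>

    std::uint16_t float_to_bf16(float f) {
        std::uint32_t bits;
        std::memcpy(&bits, &f, sizeof(bits));
        // round to nearest even before dropping the low 16 bits
        std::uint32_t rounding = 0x7fffu + ((bits >> 16) & 1u);
        return static_cast<std::uint16_t>((bits + rounding) >> 16);
    }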
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
// generated by gen_elemwise_kern_impls.py | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||||
#include "../kern_impl.inl" | |||||
#endif |
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,17 @@
/**
 * \file dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
@@ -0,0 +1,18 @@
/**
 * \file dnn/src/cuda/elemwise/special_kimpl/special_dt_bfloat16.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// generated by gen_elemwise_special_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#include "../special_kerns.inl"
INST(::megdnn::dtype::BFloat16)
#undef INST
}
}
#endif
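// The two bare closing braces above are expected: "../special_kerns.inl"
// opens namespaces (presumably megdnn and cuda, as elsewhere in the tree)
// and defines INST, leaving the namespaces open for the including file to
// close after instantiation. A rough sketch of the assumed layout, with a
// hypothetical kernel name:
//
//     namespace megdnn {
//     namespace cuda {
//     template <typename ctype>
//     void run_special_kern(const ctype* src, ctype* dst, size_t nr_elems);
//     #define INST(_dt)                                           \
//         template void run_special_kern<DTypeTrait<_dt>::ctype>( \
//                 const DTypeTrait<_dt>::ctype*,                  \
//                 DTypeTrait<_dt>::ctype*, size_t);
//     // namespaces intentionally left open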
@@ -141,6 +141,9 @@ INST_FOR_CTYPE
#define ct dt_float16
INST_FOR_CTYPE
#undef ct
#define ct dt_bfloat16
INST_FOR_CTYPE
#undef ct
#define ct dt_int8
INST_FOR_CTYPE
#undef ct
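// Same X-macro idea as the generated kimpl files: INST_FOR_CTYPE is assumed
// to expand to the per-ctype kernel instantiations in terms of the `ct`
// macro defined just above each use, so enabling dt_bfloat16 here is only
// the three added lines.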
@@ -68,6 +68,17 @@ namespace elemwise_intl {
    return t;
}

struct __attribute__((aligned(8))) bhalf4 {
    dt_bfloat16 x, y, z, w;
};

__device__ __forceinline__ bhalf4 make_bhalf4(dt_bfloat16 x, dt_bfloat16 y,
                                              dt_bfloat16 z, dt_bfloat16 w) {
    bhalf4 t;
    t.x = x, t.y = y, t.z = z, t.w = w;
    return t;
}

#define INST(_ctype, _vect_type)  \
    template <>                   \
    class VectTypeTrait<_ctype> { \
@@ -87,6 +98,7 @@ namespace elemwise_intl {
INST(dt_uint8, uchar4);
INST(dt_float32, float4);
INST(dt_float16, half4);
INST(dt_bfloat16, bhalf4);
INST(dt_int32, int4);
INST(dt_int16, short4);
#undef as_raw
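// With the aligned(8) bhalf4 above, INST(dt_bfloat16, bhalf4) gives bfloat16
// kernels the same 8-byte vectorized load/store path that dt_float16 gets
// from half4. A plausible expansion of the truncated INST macro (member
// names assumed for illustration, not verbatim from this diff):
//
//     template <>
//     class VectTypeTrait<dt_bfloat16> {
//     public:
//         typedef bhalf4 vect_type;      // four packed bfloat16 elements
//         static const size_t packed_size = 4;
//         static __device__ __forceinline__ vect_type make_vector(
//                 dt_bfloat16 x, dt_bfloat16 y, dt_bfloat16 z,
//                 dt_bfloat16 w) {
//             return make_bhalf4(x, y, z, w);
//         }
//     };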
@@ -17,6 +17,11 @@ __device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) {
    __trap();
    ((int*)0)[0] = 1;
}
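// Like the dt_float16 overload above, the bfloat16 stub below exists only to
// satisfy overload resolution on targets without a native implementation; it
// must never be reached at runtime, hence the trap plus the deliberate null
// write that forces an unmistakable fault.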
__device__ void atomicAdd(megdnn::dt_bfloat16 *, megdnn::dt_bfloat16) {
    __trap();
    ((int*)0)[0] = 1;
}
#endif

__device__ void atomicAdd(megdnn::dt_int8 *, megdnn::dt_int8) {
@@ -29,6 +29,10 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
    all_algos.push_back(&cublas_lt);
#endif
    all_algos.push_back(&naive);
#if !MEGDNN_DISABLE_FLOAT16
    cublas_bfloat16 = std::make_unique<AlgoBFloat16>(&cublas);
    all_algos.push_back(cublas_bfloat16.get());
#endif
}

MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack;
@@ -15,6 +15,7 @@
#include "src/cuda/matrix_mul/opr_impl.h"
#include <cuda.h>
#include <memory>
#if CUDA_VERSION >= 10010
#include <cublasLt.h>
#endif
@@ -140,6 +141,24 @@ public:
    bool is_reproducible() const override { return true; }
};

#if !MEGDNN_DISABLE_FLOAT16
class MatrixMulForwardImpl::AlgoBFloat16 final : public AlgoBase {
public:
    AlgoBFloat16(MatrixMulForwardImpl::AlgoBase*);
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    void exec(const ExecArgs& args) const override;
    bool is_reproducible() const override { return true; }

private:
    MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr;
    std::string m_name;
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    SizeArgs float_args(const SizeArgs& args) const;
};
#endif

class MatrixMulForwardImpl::AlgoPack {
    AlgoPack(const AlgoPack&) = delete;
    AlgoPack& operator=(const AlgoPack&) = delete;
@@ -154,7 +173,9 @@ public:
#if CUDA_VERSION >= 10010
    AlgoCuBlasLt cublas_lt;
#endif
#if !MEGDNN_DISABLE_FLOAT16
    std::unique_ptr<AlgoBFloat16> cublas_bfloat16;
#endif
    std::vector<AlgoBase*> all_algos;
};
@@ -0,0 +1,91 @@
/**
 * \file dnn/src/cuda/matrix_mul/bfloat16.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

MatrixMulForwardImpl::AlgoBFloat16::AlgoBFloat16(
        MatrixMulForwardImpl::AlgoBase* algorithm)
        : m_algorithm(algorithm) {
    megdnn_assert_internal(algorithm);
    m_name = ssprintf("MATMUL_BFLOAT16:%s", m_algorithm->name());
}

MatrixMulForwardImpl::AlgoBase::SizeArgs
MatrixMulForwardImpl::AlgoBFloat16::float_args(const SizeArgs& args) const {
    auto new_args = args;
    auto change_dtype = [](TensorLayout& layout) {
        if (layout.dtype == dtype::BFloat16()) {
            layout.dtype = dtype::Float32();
        }
    };
    change_dtype(new_args.layout_a);
    change_dtype(new_args.layout_b);
    change_dtype(new_args.layout_c);
    return new_args;
}

bool MatrixMulForwardImpl::AlgoBFloat16::is_available(
        const SizeArgs& args) const {
    auto fargs = float_args(args);
    return args.layout_a.dtype == dtype::BFloat16() &&
           m_algorithm->is_available(fargs);
}

WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle(
        void* ptr, const SizeArgs& args) const {
    auto fargs = float_args(args);
    SmallVector<size_t> sizes;
    auto get_workspace = [&sizes](const TensorLayout& src) {
        TensorLayout dst = src;
        if (dst.dtype == dtype::BFloat16()) {
            dst.dtype = dtype::Float32();
            sizes.push_back(dst.span().dist_byte());
        }
    };
    get_workspace(args.layout_a);
    get_workspace(args.layout_b);
    get_workspace(args.layout_c);
    sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
    return {ptr, std::move(sizes)};
}

size_t MatrixMulForwardImpl::AlgoBFloat16::get_workspace_in_bytes(
        const SizeArgs& args) const {
    return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}
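// Worked size example (assuming contiguous layouts; purely illustrative):
// for a BFloat16 matmul with A of shape 128x256, B of 256x64 and C of
// 128x64, the bundle holds three Float32 staging buffers:
//     A: 128 * 256 * 4 = 131072 bytes
//     B: 256 *  64 * 4 =  65536 bytes
//     C: 128 *  64 * 4 =  32768 bytes
// plus whatever workspace the wrapped Float32 algorithm reports for the
// converted layouts; total_size_in_bytes() folds all of this into the single
// allocation the caller must provide.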
void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
    TensorND a = args.tensor_a;
    TensorND b = args.tensor_b;
    TensorND c = args.tensor_c;
    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
    auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
            args.opr->handle(), &bundle);
    ctypecvt.src_to_comp_type(args.tensor_a, a)
            .src_to_comp_type(args.tensor_b, b)
            .src_to_comp_type(args.tensor_c, c);
    {
        auto matmul_opr =
                args.opr->handle()->create_operator<MatrixMulForward>();
        matmul_opr->param() = args.opr->param();
        matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
        matmul_opr->execution_policy() = {m_algorithm};
        matmul_opr->exec(a, b, c, ctypecvt.workspace());
    }
    ctypecvt.comp_to_dst_type(c, args.tensor_c);
}

// vim: syntax=cpp.doxygen
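// A minimal host-side usage sketch under stated assumptions: the `handle`
// variable and the tensor/workspace setup are abbreviated, and only calls
// that already appear in this diff (create_operator, get_workspace_in_bytes,
// exec) are relied on:
//
//     auto opr = handle->create_operator<MatrixMulForward>();
//     TensorLayout a{{128, 256}, dtype::BFloat16()};
//     TensorLayout b{{256, 64}, dtype::BFloat16()};
//     TensorLayout c{{128, 64}, dtype::BFloat16()};
//     size_t ws_size = opr->get_workspace_in_bytes(a, b, c);
//     // allocate ws_size bytes on device, wrap them as a Workspace, then:
//     // opr->exec(tensor_a, tensor_b, tensor_c, workspace);
//
// When the layouts carry BFloat16 and AlgoBFloat16 is selected, exec()
// round-trips through Float32: the inputs are converted into the staging
// buffers, the wrapped cuBLAS algorithm runs in Float32, and the result is
// converted back into the caller's BFloat16 tensor in place.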