@@ -29,6 +29,7 @@
#define MEGDNN_FLOAT16_SELECT(_x, _y) _y
#else
#include "megdnn/dtype/half.hpp"
#include "megdnn/dtype/bfloat16.hpp"
#define MEGDNN_INC_FLOAT16(_x) _x
#define MEGDNN_FLOAT16_SELECT(_x, _y) _x
#endif
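Since BFloat16 piggybacks on the existing half-precision guards, both 16-bit float types compile in and out together. A minimal sketch of the two expansion modes of the macros above (illustrative only):

    // MEGDNN_DISABLE_FLOAT16 defined (the #if branch):
    //   MEGDNN_INC_FLOAT16(cb(BFloat16))  -> expands to nothing
    //   MEGDNN_FLOAT16_SELECT(_x, _y)     -> _y
    // float16 enabled (the #else branch above):
    //   MEGDNN_INC_FLOAT16(cb(BFloat16))  -> cb(BFloat16)
    //   MEGDNN_FLOAT16_SELECT(_x, _y)     -> _x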
@@ -49,6 +50,7 @@ namespace megdnn {
cb(IntB4) \
cb(Byte) \
MEGDNN_INC_FLOAT16(cb(Float16)) \
MEGDNN_INC_FLOAT16(cb(BFloat16)) \
cb(UintB4) \
/*!
@@ -62,6 +64,7 @@ namespace megdnn {
cb(Int32) \
cb(Byte) \
MEGDNN_INC_FLOAT16(cb(Float16)) \
MEGDNN_INC_FLOAT16(cb(BFloat16)) \
/*!
* \brief iterate through each fractional byte dtype
@@ -101,6 +104,7 @@ namespace megdnn {
#define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \
cb(::megdnn::dtype::Float32) \
MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \
MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16)) \
/*!
* \brief iterate through each dtype object that can be involved in integer
@@ -345,6 +349,7 @@ typedef int16_t dt_int16;
typedef int8_t dt_int8;
typedef uint8_t dt_uint8;
MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
#define MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE 100000
#if MEGDNN_CC_HOST
@@ -367,6 +372,9 @@ MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
Float16,
#endif
UintB4 = 10,
#if !MEGDNN_DISABLE_FLOAT16
BFloat16 = 11,
#endif
#define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE,
#define D(_name) _name,
@@ -702,6 +710,9 @@ MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX);
MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED,
std::numeric_limits<dt_float16>::lowest(),
std::numeric_limits<dt_float16>::max()));
MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(BFloat16, dt_bfloat16, FLOAT, SIGNED,
std::numeric_limits<dt_bfloat16>::lowest(),
std::numeric_limits<dt_bfloat16>::max()));
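The MEGDNN_DEF_DT call above requires std::numeric_limits to be specialized for dt_bfloat16 (bfloat16.hpp provides this, mirroring half.hpp). A hedged sanity-check sketch; the printed range should be roughly +/-3.4e38, since bfloat16 keeps float32's 8 exponent bits while truncating the mantissa:

    #include <cstdio>
    #include <limits>
    #include "megdnn/dtype.h"

    int main() {
        std::printf("bf16 lowest=%g max=%g\n",
                    static_cast<double>(
                            std::numeric_limits<megdnn::dt_bfloat16>::lowest()),
                    static_cast<double>(
                            std::numeric_limits<megdnn::dt_bfloat16>::max()));
    }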
template <>
struct DTypeTrait<dtype::Byte> {
@@ -50,167 +50,7 @@
#include <hip/hip_fp16.h>
#endif
/// Combined gcc version number.
#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)
//check C++11 language features
#if defined(__clang__) //clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
/*#elif defined(__INTEL_COMPILER) //Intel C++
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ????????
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ????????
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ????????
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ????????
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif*/
#elif defined(__GNUC__) //gcc
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#endif
#elif defined(_MSC_VER) //Visual C++
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#define HALF_POP_WARNINGS 1
#pragma warning(push)
//! 4521 and 4522: multiple copy/assignment operators specified
#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned
#endif
//check C++11 library features
#include <utility>
#if defined(_LIBCPP_VERSION) //libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#elif defined(__GLIBCXX__) //libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#endif
#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++
#if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#endif
#endif
#undef HALF_GNUC_VERSION
//support constexpr
#if HALF_ENABLE_CPP11_CONSTEXPR
#define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr
#else
#define HALF_CONSTEXPR
#define HALF_CONSTEXPR_CONST const
#endif
//support noexcept
#if HALF_ENABLE_CPP11_NOEXCEPT
#define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept
#else
#define HALF_NOEXCEPT
#define HALF_NOTHROW throw()
#endif
#include <algorithm>
#include <limits>
#include <climits>
#include <cmath>
#include <cstring>
#if HALF_ENABLE_CPP11_TYPE_TRAITS
#include <type_traits>
#endif
#if HALF_ENABLE_CPP11_CSTDINT
#include <cstdint>
#endif
#if HALF_ENABLE_CPP11_HASH
#include <functional>
#endif
#include "megdnn/dtype/half_common_prologue.h"
/// Default rounding mode.
/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s and `float`s as well as
@@ -3141,16 +2981,7 @@ namespace std
#endif
}
#undef HALF_CONSTEXPR
#undef HALF_CONSTEXPR_CONST
#undef HALF_NOEXCEPT
#undef HALF_NOTHROW
#ifdef HALF_POP_WARNINGS
#pragma warning(pop)
#undef HALF_POP_WARNINGS
#endif
#include "megdnn/dtype/half_common_epilogue.h"
#endif
// vim: syntax=cpp.doxygen
@@ -0,0 +1,48 @@
/**
* half - IEEE 754-based half-precision floating point library.
*
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* Version 1.11.0
* \file
* Main header file for half precision functionality.
*
* --------------------------------------------------------------------------
* \file include/megdnn/dtype/half_common_epilogue.h
*
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*
* This file has been modified by Megvii ("Megvii Modifications").
* All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
*
* --------------------------------------------------------------------------
*/
#undef HALF_CONSTEXPR
#undef HALF_CONSTEXPR_CONST
#undef HALF_NOEXCEPT
#undef HALF_NOTHROW
#ifdef HALF_POP_WARNINGS
#pragma warning(pop)
#undef HALF_POP_WARNINGS
#endif
// vim: syntax=cpp.doxygen
@@ -0,0 +1,202 @@
/**
* half - IEEE 754-based half-precision floating point library.
*
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* Version 1.11.0
* \file
* Main header file for half precision functionality.
*
* --------------------------------------------------------------------------
* \file dnn/include/megdnn/dtype/half_common_prologue.h
*
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*
* This file has been modified by Megvii ("Megvii Modifications").
* All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
*
* --------------------------------------------------------------------------
*/
#include "megdnn/arch.h"
/// Combined gcc version number.
#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)
//check C++11 language features
#if defined(__clang__) //clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
/*#elif defined(__INTEL_COMPILER) //Intel C++
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ????????
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ????????
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ????????
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ????????
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif*/
#elif defined(__GNUC__) //gcc
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#endif
#elif defined(_MSC_VER) //Visual C++
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#define HALF_POP_WARNINGS 1
#pragma warning(push)
//! 4521 and 4522: multiple copy/assignment operators specified
#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned
#endif
//check C++11 library features
#include <utility>
#if defined(_LIBCPP_VERSION) //libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#elif defined(__GLIBCXX__) //libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#endif
#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++
#if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#endif
#endif
#undef HALF_GNUC_VERSION
//support constexpr
#if HALF_ENABLE_CPP11_CONSTEXPR
#define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr
#else
#define HALF_CONSTEXPR
#define HALF_CONSTEXPR_CONST const
#endif
//support noexcept
#if HALF_ENABLE_CPP11_NOEXCEPT
#define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept
#else
#define HALF_NOEXCEPT
#define HALF_NOTHROW throw()
#endif
#include <algorithm>
#include <limits>
#include <climits>
#include <cmath>
#include <cstring>
#if HALF_ENABLE_CPP11_TYPE_TRAITS
#include <type_traits>
#endif
#if HALF_ENABLE_CPP11_CSTDINT
#include <cstdint>
#endif
#if HALF_ENABLE_CPP11_HASH
#include <functional>
#endif
// vim: syntax=cpp.doxygen
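Together the two new headers bracket any half-precision implementation file: the prologue detects compiler/library features and defines the HALF_* helper macros, and the epilogue (shown earlier) undefines them again so nothing leaks into including code. A minimal usage sketch, with an illustrative struct standing in for the real half/bfloat16 classes:

    #include "megdnn/dtype/half_common_prologue.h"

    struct toy_scalar {
        // HALF_CONSTEXPR / HALF_NOTHROW degrade gracefully on pre-C++11
        // toolchains, per the prologue's feature checks.
        HALF_CONSTEXPR toy_scalar() HALF_NOTHROW : bits(0) {}
        unsigned short bits;
    };

    #include "megdnn/dtype/half_common_epilogue.h"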
@@ -30,7 +30,7 @@ def main():
w('// generated by gen_cond_take_kern_impls.py')
w('#include "../kern.inl"')
w('')
if dtype == 'dt_float16':
if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
w('#if !MEGDNN_DISABLE_FLOAT16')
w('namespace megdnn {')
w('namespace cuda {')
@@ -48,7 +48,7 @@ def main():
w('} // cond_take')
w('} // cuda')
w('} // megdnn')
if dtype == 'dt_float16':
if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
w('#endif')
print('generated {}'.format(fname))
@@ -34,7 +34,7 @@ def main():
w = lambda s: print(s, file=fout)
w('// generated by gen_elemwise_kern_impls.py')
if ctype == 'dt_float16':
if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
w('#if !MEGDNN_DISABLE_FLOAT16')
w('#define KERN_IMPL_MODE(cb) {}'.format(formode))
@@ -42,7 +42,7 @@ def main():
w('#define KERN_IMPL_CTYPE {}'.format(ctype))
w('#include "../kern_impl.inl"')
if ctype == 'dt_float16':
if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
w('#endif')
print('generated {}'.format(fname))
@@ -30,14 +30,14 @@ def main():
w = lambda s: print(s, file=fout)
w('// generated by gen_elemwise_special_kern_impls.py')
if dtype == 'dt_float16':
if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
w('#if !MEGDNN_DISABLE_FLOAT16')
w('#include "../special_kerns.inl"')
w('INST(::megdnn::dtype::{})'.format(DTYPES[dtype][0]))
w('#undef INST')
w('}')
w('}')
if dtype == 'dt_float16':
if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
w('#endif')
print('generated {}'.format(fname))
@@ -6,7 +6,8 @@ DTYPES = {'dt_int32': ('Int32', 'INT'),
'dt_int8': ('Int8', 'INT'),
'dt_int16': ('Int16', 'INT'),
'dt_float32': ('Float32', 'FLOAT'),
'dt_float16': ('Float16', 'FLOAT')
'dt_float16': ('Float16', 'FLOAT'),
'dt_bfloat16': ('BFloat16', 'FLOAT')
}
MODES = {
@@ -618,9 +618,10 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src,
megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32
#if !MEGDNN_DISABLE_FLOAT16
|| src.enumv() == DTypeEnum::Float16
|| src.enumv() == DTypeEnum::BFloat16
#endif
,
"ComputeMode::FLOAT32 is only available for Float16 "
,
"ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
"input / output.");
}
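With this relaxation, ComputeMode::FLOAT32 is accepted for BFloat16 as well as Float16 tensors. A hedged sketch of the call pattern this enables (handle and layouts are assumed to exist; only the param names come from this diff):

    // Assumption: src/filter layouts carry dtype::BFloat16.
    auto conv = handle->create_operator<megdnn::ConvolutionForward>();
    conv->param().compute_mode =
            megdnn::param::Convolution::ComputeMode::FLOAT32;
    // check_or_deduce_dtype_fwd() now passes instead of asserting,
    // and accumulation runs in fp32.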
@@ -1036,9 +1037,10 @@ void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff,
megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32
#if !MEGDNN_DISABLE_FLOAT16
|| filter.enumv() == DTypeEnum::Float16
|| filter.enumv() == DTypeEnum::BFloat16
#endif
,
"ComputeMode::FLOAT32 is only available for Float16 "
,
"ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
"input / output.");
}
@@ -87,7 +87,8 @@ namespace megdnn {
//! define kernel for all float types
#define DEF_KERN_FLOAT(_mode, _imp) \
DEF_KERN(dt_float32, _mode, _imp); \
MEGDNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);)
MEGDNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \
MEGDNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);)
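Assuming DEF_KERN(_ctype, _mode, _imp) instantiates one elemwise kernel per dtype, the float list now expands (with float16 enabled) to three kernels:

    DEF_KERN(dt_float32, _mode, _imp);
    DEF_KERN(dt_float16, _mode, _imp);
    DEF_KERN(dt_bfloat16, _mode, _imp);  // new in this change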
//! define kernel for all int types
#define DEF_KERN_INT(_mode, _imp) \
@@ -69,11 +69,11 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A,
C = TensorLayout(TensorShape({A0, B1}), C.dtype);
} else {
auto do_deduce = [&](size_t pack_size) {
megdnn_assert(
A.ndim == 4 && B.ndim == 3,
"matmul requires input dimension to be A(4), B(3); get: %s %s",
A.TensorShape::to_string().c_str(),
B.TensorShape::to_string().c_str());
megdnn_assert(A.ndim == 4 && B.ndim == 3,
"matmul requires input dimension to be A(4), B(3); "
"get: %s %s",
A.TensorShape::to_string().c_str(),
B.TensorShape::to_string().c_str());
A0 = A.shape[0];
A1 = A.shape[1];
B0 = B.shape[0];
@@ -82,11 +82,11 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A,
std::swap(A0, A1);
if (m_param.transposeB)
std::swap(B0, B1);
megdnn_assert(
A1 == B0,
"shape mismatch in matmal: (transposed) A is (%zu,%zu,4,4), " | |||
"(transposed) B is (%zu,%zu,4)", | |||
A0, A1, B0, B1); | |||
megdnn_assert(A1 == B0, | |||
"shape mismatch in matmal: (transposed) A is " | |||
"(%zu,%zu,4,4), " | |||
"(transposed) B is (%zu,%zu,4)", | |||
A0, A1, B0, B1); | |||
C = TensorLayout(TensorShape({A0, B1, pack_size}), C.dtype); | |||
}; | |||
do_deduce(pack_size(param().format)); | |||
@@ -172,8 +172,9 @@ void MatrixMulForward::check_exec(const TensorLayout& A, const TensorLayout& B, | |||
} | |||
megdnn_assert(param().compute_mode != | |||
Param::ComputeMode::FLOAT32 MEGDNN_INC_FLOAT16( | |||
|| A.dtype == dtype::Float16()), | |||
"ComputeMode::FLOAT32 is only available for Float16 " | |||
|| A.dtype == dtype::Float16() || | |||
A.dtype == dtype::BFloat16()), | |||
"ComputeMode::FLOAT32 is only available for Float16/BFloat16 " | |||
"input / output."); | |||
auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
@@ -46,6 +46,14 @@ struct RoundingConverter<half_float::half> { | |||
} | |||
}; | |||
template <> | |||
struct RoundingConverter<half_bfloat16::bfloat16> { | |||
__host__ __device__ __forceinline__ half_bfloat16::bfloat16 operator()( | |||
float x) const { | |||
return static_cast<half_bfloat16::bfloat16>(x); | |||
} | |||
}; | |||
#endif // #ifdef MEGDNN_DISABLE_FLOAT16 | |||
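A hedged usage sketch for the new specialization (the kernel below and its enclosing namespace are hypothetical; only the converter itself comes from this diff):

    // Narrow an fp32 accumulator to bf16 at store time.
    __device__ void store_bf16(half_bfloat16::bfloat16* dst, float acc) {
        RoundingConverter<half_bfloat16::bfloat16> cvt;
        *dst = cvt(acc);  // static_cast<bfloat16>(acc), per the operator() above
    }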
template <>
@@ -16,6 +16,7 @@
#include "megdnn/dtype.h"
#include "megdnn/handle.h"
#include "megdnn/thin/small_vector.h"
#include "megdnn/oprs/general.h"
#include "src/common/hash_ct.h"
#include "src/common/utils.cuh"
@@ -548,6 +549,59 @@ public:
std::string to_string() const;
};
/*!
* \brief helpers for oprs using typecvt between comp_type and dst_type
* \tparam SrcType src type
* \tparam CompType compute type, such as fp32 for conv
* \tparam DstType dst type
*/
template <typename SrcType, typename CompType, typename DstType = SrcType>
struct CompTypeCvter {
std::unique_ptr<TypeCvt> m_cvt_opr;
WorkspaceBundle* m_workspace_bundle;
size_t m_workspace_idx;
CompTypeCvter(Handle* handle, WorkspaceBundle* bundle)
: m_workspace_bundle(bundle), m_workspace_idx(0) {
megdnn_assert(
(DTypeTrait<SrcType>::enumv != DTypeTrait<CompType>::enumv &&
DTypeTrait<DstType>::enumv != DTypeTrait<CompType>::enumv),
"SrcType(%s) == CompType(%s) or DstType(%s) == CompType(%s) is " | |||
"not " | |||
"supportted.", | |||
SrcType().name(), CompType().name(), DstType().name(),
CompType().name());
m_cvt_opr = handle->create_operator<TypeCvt>();
}
//! Convert tensor dtype from SrcType to CompType.
CompTypeCvter& src_to_comp_type(const TensorND& src, TensorND& comp) {
if (src.layout.dtype.enumv() == DTypeTrait<SrcType>::enumv) {
if (!comp.layout.dtype.valid() ||
comp.layout.dtype.enumv() != DTypeTrait<CompType>::enumv) {
comp.layout.dtype = CompType();
comp.layout.init_contiguous_stride();
comp.raw_ptr = m_workspace_bundle->get(m_workspace_idx++);
if (src.layout.ndim) {
m_cvt_opr->exec(src, comp);
}
}
}
return *this;
}
//! Convert tensor dtype from CompType to DstType.
CompTypeCvter& comp_to_dst_type(const TensorND& comp, const TensorND& dst) {
megdnn_assert(comp.layout.dtype.enumv() == DTypeTrait<CompType>::enumv);
if (dst.layout.dtype.enumv() == DTypeTrait<DstType>::enumv) {
m_cvt_opr->exec(comp, dst);
}
return *this;
}
Workspace workspace() {
return m_workspace_bundle->get_workspace(m_workspace_idx);
}
};
} // namespace megdnn
// vim: syntax=cpp.doxygen
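A condensed sketch of the round-trip pattern the CUDA BFloat16 algorithms below build on: convert bf16 operands into fp32 workspace tensors, run the wrapped fp32 implementation, then convert the result back (the surrounding opr and workspace bundle names are assumptions):

    // The bundle's last slot is reserved for the wrapped opr's own workspace.
    CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(handle, &bundle);
    TensorND fsrc = src, fdst = dst;
    cvter.src_to_comp_type(src, fsrc).src_to_comp_type(dst, fdst);
    fp32_opr->exec(fsrc, fdst, cvter.workspace());  // hypothetical fp32 opr
    cvter.comp_to_dst_type(fdst, dst);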
@@ -55,17 +55,19 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str());
if (param().format == param::WarpPerspective::Format::NCHW) {
megdnn_assert(src.dtype.enumv() == DTypeEnum::Float32 ||
MEGDNN_FLOAT16_SELECT(
src.dtype.enumv() == DTypeEnum::Float16,
false) ||
src.dtype.enumv() == DTypeEnum::Int8 ||
src.dtype.enumv() == DTypeEnum::Uint8 ||
(src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
"WarpPerspective NCHW input dtype should be "
"Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
"/Float16", "") ".");
megdnn_assert(
src.dtype.enumv() == DTypeEnum::Float32 ||
MEGDNN_FLOAT16_SELECT(
(src.dtype.enumv() == DTypeEnum::Float16 ||
src.dtype.enumv() == DTypeEnum::BFloat16),
false) ||
src.dtype.enumv() == DTypeEnum::Int8 ||
src.dtype.enumv() == DTypeEnum::Uint8 ||
(src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
"WarpPerspective NCHW input dtype should be "
"Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
"/Float16/BFloat16", "") ".");
megdnn_assert(
(src.dtype.category() == DTypeCategory::FLOAT &&
(src.dtype == mat.dtype ||
@@ -107,14 +109,17 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
param::WarpPerspective::BorderMode::ISOLATED);
} else {
megdnn_assert(param().format == param::WarpPerspective::Format::NHWCD4);
megdnn_assert(src.dtype == dtype::Float32() ||
MEGDNN_FLOAT16_SELECT(
src.dtype == dtype::Float16(), false) ||
src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
"WarpPerspective NHWCD4 input dtype should be "
"Float32" MEGDNN_FLOAT16_SELECT(
"/Float16", "") ",QunatizedS8, Quantized8Asymm."); | |||
megdnn_assert(
src.dtype == dtype::Float32() ||
MEGDNN_FLOAT16_SELECT((src.dtype == dtype::Float16() ||
src.dtype == dtype::BFloat16()),
false) ||
src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
"WarpPerspective NHWCD4 input dtype should be "
"Float32" MEGDNN_FLOAT16_SELECT(
"/Float16/BFloat16",
"") ",QunatizedS8, Quantized8Asymm."); | |||
megdnn_assert(
(src.dtype == mat.dtype || mat.dtype == dtype::Float32()),
"The input to WarpPerspective is in NHWCD4 format, in this "
@@ -253,30 +258,30 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
}
}
void WarpPerspectiveBackwardData::check_exec(const TensorLayout &mat,
const TensorLayout &diff,
const TensorLayout &grad,
size_t workspace_in_bytes)
{
void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_in_bytes) {
check_layout_fwd(grad, mat, diff);
megdnn_assert(grad.dtype == dtype::Float32(),
"Backward WarpPerspective only supports Float32.");
megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
|| grad.dtype == dtype::BFloat16()),
"Backward WarpPerspective only supports Float32/BFloat16.");
auto required_workspace_in_bytes = get_workspace_in_bytes(mat, diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void WarpPerspectiveBackwardMat::check_exec(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &diff,
const TensorLayout &grad,
size_t workspace_in_bytes)
{
void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src,
const TensorLayout& mat,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_in_bytes) {
check_layout_fwd(src, mat, diff);
megdnn_assert_eq_layout(mat, grad);
megdnn_assert(grad.dtype == dtype::Float32(),
"Backward WarpPerspective only supports Float32.");
auto required_workspace_in_bytes = get_workspace_in_bytes(src,
mat, diff, grad);
megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
|| grad.dtype == dtype::BFloat16()),
"Backward WarpPerspective only supports Float32/BFloat16.");
auto required_workspace_in_bytes =
get_workspace_in_bytes(src, mat, diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
@@ -0,0 +1,29 @@
/**
* \file dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/
// generated by gen_cond_take_kern_impls.py
#include "../kern.inl"
#if !MEGDNN_DISABLE_FLOAT16
namespace megdnn {
namespace cuda {
namespace cond_take {
inst_genidx(::megdnn::dtype::BFloat16)
#undef inst_genidx
inst_copy(::megdnn::dtype::BFloat16)
#undef inst_copy
#undef inst_copy_
} // cond_take
} // cuda
} // megdnn
#endif
@@ -62,6 +62,13 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group batched_matmul
non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group 1x1
algo_size = all_algos.size();
for (size_t i = 0; i < algo_size; ++i) {
bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
all_algos.push_back(bfloat16_refhold.back().get());
bfloat16_algos.push_back(bfloat16_refhold.back().get());
}
size_t all_algo_size = all_algos.size();
#if CUDA_VERSION >= 10000
fill_imma_algos();
@@ -499,6 +499,28 @@ private:
};
#endif
class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase {
public:
AlgoBFloat16(AlgoBase* impl);
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return m_impl->is_reproducible(); }
private:
SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr,
TensorLayout& fsrc, TensorLayout& ffilter,
TensorLayout& fbias, TensorLayout& fz,
TensorLayout& fdst) const;
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
AlgoBase* m_impl;
std::string m_name;
};
class ConvBiasForwardImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
@@ -508,7 +530,8 @@ public:
std::vector<AlgoBase*> all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos;
non_cudnn_algos,
bfloat16_algos;
std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations;
std::vector<AlgoCUDNNConv> cudnn_convs;
AlgoChanwise chanwise;
@@ -531,6 +554,7 @@ public:
int8_chwn4_imma_unroll_width;
#endif
std::vector<std::unique_ptr<AlgoGroupConvGeneral>> gconv_refhold;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);
@@ -0,0 +1,120 @@
/**
* \file dnn/src/cuda/conv_bias/bfloat16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace conv_bias;
ConvBiasForwardImpl::AlgoBFloat16::AlgoBFloat16(
ConvBiasForwardImpl::AlgoBase* algorithm)
: m_impl(algorithm) {
megdnn_assert_internal(algorithm);
m_name = ssprintf("BFLOAT16:%s", m_impl->name());
}
ConvBiasForwardImpl::AlgoBase::SizeArgs
ConvBiasForwardImpl::AlgoBFloat16::float_args(
const SizeArgs& args, ConvBiasForwardImpl* opr, TensorLayout& fsrc,
TensorLayout& ffilter, TensorLayout& fbias, TensorLayout& fz,
TensorLayout& fdst) const {
fsrc = *args.src_layout;
ffilter = *args.filter_layout;
fbias = *args.bias_layout;
fz = *args.z_layout;
fdst = *args.dst_layout;
auto change_dtype = [](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32();
}
};
change_dtype(fsrc);
change_dtype(ffilter);
change_dtype(fbias);
change_dtype(fz);
change_dtype(fdst);
opr->param() = args.opr->param();
opr->param().compute_mode = Param::ComputeMode::DEFAULT;
opr->execution_policy() = {m_impl};
return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst);
}
bool ConvBiasForwardImpl::AlgoBFloat16::is_available(
const SizeArgs& args) const {
TensorLayout fsrc, ffilter, fbias, fz, fdst;
auto convbias_opr = args.handle->create_operator<ConvBias>();
SizeArgs fargs = float_args(
args, static_cast<ConvBiasForwardImpl*>(convbias_opr.get()), fsrc,
ffilter, fbias, fz, fdst);
return args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16() &&
m_impl->is_available(fargs);
}
WorkspaceBundle ConvBiasForwardImpl::AlgoBFloat16::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
TensorLayout fsrc, ffilter, fbias, fz, fdst;
auto convbias_opr = args.handle->create_operator<ConvBias>();
SizeArgs fargs = float_args(
args, static_cast<ConvBiasForwardImpl*>(convbias_opr.get()), fsrc,
ffilter, fbias, fz, fdst);
SmallVector<size_t> sizes;
auto get_workspace = [&sizes](const TensorLayout& src,
const TensorLayout& dst) {
if (src.dtype != dst.dtype) {
sizes.push_back(dst.span().dist_byte());
}
};
get_workspace(*args.src_layout, fsrc);
get_workspace(*args.filter_layout, ffilter);
get_workspace(*args.bias_layout, fbias);
get_workspace(*args.z_layout, fz);
get_workspace(*args.dst_layout, fdst);
sizes.push_back(m_impl->get_workspace_in_bytes(fargs));
return {ptr, std::move(sizes)};
}
size_t ConvBiasForwardImpl::AlgoBFloat16::get_workspace_in_bytes(
const SizeArgs& args) const {
return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}
void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
TensorND fsrc_tensor = *args.src_tensor;
TensorND ffilter_tensor = *args.filter_tensor;
TensorND fbias_tensor = *args.bias_tensor;
TensorND fz_tensor = *args.z_tensor;
TensorND fdst_tensor = *args.dst_tensor;
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle);
{
cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor)
.src_to_comp_type(*args.filter_tensor, ffilter_tensor)
.src_to_comp_type(*args.bias_tensor, fbias_tensor)
.src_to_comp_type(*args.z_tensor, fz_tensor)
.src_to_comp_type(*args.dst_tensor, fdst_tensor);
}
{
auto convbias_opr = args.handle->create_operator<ConvBias>();
convbias_opr->param() = args.opr->param();
convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
convbias_opr->execution_policy() = {m_impl};
convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor,
fdst_tensor, cvter.workspace());
}
{ cvter.comp_to_dst_type(fdst_tensor, *args.dst_tensor); }
}
// vim: syntax=cpp.doxygen
@@ -20,6 +20,10 @@ using namespace conv_bias;
bool ConvBiasForwardImpl::AlgoChanwise::is_available(
const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.z_layout->ndim > 0)
return false;
@@ -30,6 +30,10 @@ inline bool is_available_small(const chanwise::Param& param) {
bool ConvBiasForwardImpl::AlgoChanwiseSmall::is_available(
const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.z_layout->ndim > 0)
return false;
#if CUDA_VERSION < 9000
@@ -23,6 +23,10 @@ using namespace conv_bias;
bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.bias_layout->ndim == 0 ||
args.bias_layout->eq_shape(*args.dst_layout))
return false;
@@ -50,6 +50,10 @@ ConvBiasForwardImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(AlgoBase* impl)
bool ConvBiasForwardImpl::AlgoGroupConvGeneral::is_available(
const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.z_layout->ndim > 0 || args.filter_meta.group <= 1)
return false;
auto&& param = args.opr->param();
@@ -136,6 +136,11 @@ void ConvBiasDesc::set_conv(DType data_type, const param::ConvBias& param,
namespace conv_bias {
bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
// CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN
// on Tegra K1.
if (args.handle->is_tegra_k1())
@@ -20,6 +20,10 @@ using namespace cuda;
using namespace conv_bias;
bool ConvBiasForwardImpl::AlgoMatmul::is_available(const SizeArgs& args) const {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.z_layout->ndim > 0)
return false;
@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/
#include "src/cuda/conv_bias/opr_impl.h"
#include "megdnn/dtype.h"
#include "src/cuda/conv_bias/helper.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/handle.h"
@@ -176,14 +177,26 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
conv_args = orig_args;
}
if (reproducible) {
return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd");
if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
if (reproducible) {
return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda convbias fwd");
} else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda convbias fwd");
}
} else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd");
if (reproducible) {
return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd");
} else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
"cuda convbias fwd");
}
}
}
@@ -57,6 +57,7 @@ public:
class AlgoInt8NCHW4IMMAImplicitGemm;
class AlgoInt8CHWN4IMMAImplicitGemmReorderFilter;
class AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth;
class AlgoBFloat16;
class AlgoPack;
@@ -33,11 +33,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
// add gconv algos by AlgoGroupConvGeneral
auto all_algos_data = all_algos.data();
for (size_t i = 2; i < all_algos.size(); ++ i) {
size_t group_algo_start = 2;
for (size_t i = group_algo_start; i < all_algos.size(); ++ i) {
gconv.push_back({all_algos[i]});
}
for (size_t i = 2; i < all_algos.size(); ++ i) {
algo2gconv[all_algos[i]] = &gconv[i - 2];
for (size_t i = group_algo_start; i < all_algos.size(); ++ i) {
algo2gconv[all_algos[i]] = &gconv[i - group_algo_start];
}
for (auto &&i: gconv) {
all_algos.push_back(&i);
@@ -45,6 +46,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
megdnn_assert(all_algos_data == all_algos.data());
non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group matmul
size_t algo_size = all_algos.size();
for (size_t i=0; i<algo_size; ++i) {
bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i]));
all_algos.push_back(bfloat16_refhold.back().get());
bfloat16_algos.push_back(bfloat16_refhold.back().get());
}
}
ConvolutionBackwardDataImpl::AlgoCUDNN*
@@ -65,18 +72,19 @@ ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
ConvolutionBackwardDataImpl *o,
const TensorLayout &filter, const TensorLayout &diff,
const TensorLayout &grad):
SizeArgs(o, o->check_layout_fwd(grad, filter, diff), diff, grad)
SizeArgs(o, filter, o->check_layout_fwd(grad, filter, diff), diff, grad)
{
}
ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
ConvolutionBackwardDataImpl *o,
const CanonizedFilterMeta &filter, const TensorLayout &diff,
ConvolutionBackwardDataImpl *o, const TensorLayout& filter,
const CanonizedFilterMeta &filter_meta, const TensorLayout &diff,
const TensorLayout &grad):
handle{concrete_handle(o->handle())},
filter_meta{filter},
filter_meta{filter_meta},
diff_layout{&diff},
grad_layout{&grad},
filter_layout{&filter},
opr{o}
{
}
@@ -31,22 +31,24 @@ class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm {
struct SizeArgs {
HandleImpl *handle;
CanonizedFilterMeta filter_meta;
const TensorLayout *diff_layout, *grad_layout;
const TensorLayout *diff_layout, *grad_layout, *filter_layout;
ConvolutionBackwardDataImpl *opr;
std::string to_string() const;
void init_desc(convolution::CUDNNBwdDataDescs &desc) const {
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
}
SizeArgs(ConvolutionBackwardDataImpl *opr,
const TensorLayout &filter, const TensorLayout &diff,
const TensorLayout &grad);
SizeArgs(ConvolutionBackwardDataImpl *opr,
const CanonizedFilterMeta &filter, const TensorLayout &diff,
const TensorLayout &grad);
SizeArgs(ConvolutionBackwardDataImpl* opr,
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad);
SizeArgs(ConvolutionBackwardDataImpl* opr,
const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad);
convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, grad_layout, filter_meta, diff_layout};
return {handle, grad_layout, filter_layout, filter_meta,
diff_layout};
}
};
struct ExecArgs: public SizeArgs {
@@ -170,6 +172,25 @@ class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final: public AlgoBase {
}
};
class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
public:
AlgoBFloat16(ConvolutionBackwardDataImpl::AlgoBase*);
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; }
private:
std::string m_name;
ConvolutionBackwardDataImpl::AlgoBase* m_algorithm = nullptr;
SizeArgs float_args(const SizeArgs& args, ConvolutionBackwardDataImpl* opr,
TensorLayout& fsrc, TensorLayout& ffilter,
TensorLayout& fdst) const;
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};
//! implement group conv by another algo
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase {
AlgoBase *m_impl;
@@ -210,12 +231,14 @@ class ConvolutionBackwardDataImpl::AlgoPack {
AlgoChanwiseSmall chanwise_small;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
std::vector<AlgoBase*>
//! all algorithms
all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos;
non_cudnn_algos,
bfloat16_algos;
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);
};
@@ -0,0 +1,115 @@
/**
* \file src/cuda/convolution/backward_data/bfloat16.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/
#include "./algo.h"
#include "src/cuda/convolution/chanwise/kern.cuh"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace convolution;
ConvolutionBackwardDataImpl::AlgoBFloat16::AlgoBFloat16(
ConvolutionBackwardDataImpl::AlgoBase* algorithm)
: m_algorithm(algorithm) {
megdnn_assert_internal(algorithm);
m_name = ssprintf("CONVOLUTION_BACKWARD_DATD_BFLOAT16:%s", | |||
m_algorithm->name()); | |||
} | |||
ConvolutionBackwardDataImpl::AlgoBase::SizeArgs | |||
ConvolutionBackwardDataImpl::AlgoBFloat16::float_args( | |||
const SizeArgs& args, ConvolutionBackwardDataImpl* opr, | |||
TensorLayout& ffilter, TensorLayout& fdiff, TensorLayout& fgrad) const { | |||
ffilter = *args.filter_layout; | |||
fdiff = *args.diff_layout; | |||
fgrad = *args.grad_layout; | |||
auto change_dtype = [](TensorLayout& layout) { | |||
if (layout.dtype == dtype::BFloat16()) { | |||
layout.dtype = dtype::Float32(); | |||
} | |||
}; | |||
change_dtype(ffilter); | |||
change_dtype(fdiff); | |||
change_dtype(fgrad); | |||
opr->param() = args.opr->param(); | |||
opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
opr->execution_policy() = {m_algorithm}; | |||
return SizeArgs(opr, ffilter, fdiff, fgrad); | |||
} | |||
bool ConvolutionBackwardDataImpl::AlgoBFloat16::is_available( | |||
const SizeArgs& args) const { | |||
TensorLayout ffilter, fdiff, fgrad; | |||
auto conv_back_data_opr = | |||
args.handle->create_operator<ConvolutionBackwardData>(); | |||
SizeArgs fargs = float_args( | |||
args, | |||
static_cast<ConvolutionBackwardDataImpl*>(conv_back_data_opr.get()), | |||
ffilter, fdiff, fgrad); | |||
return args.diff_layout->dtype == args.filter_layout->dtype && | |||
args.diff_layout->dtype == dtype::BFloat16() && | |||
m_algorithm->is_available(fargs); | |||
} | |||
WorkspaceBundle ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_bundle( | |||
void* ptr, const SizeArgs& args) const { | |||
TensorLayout ffilter, fdiff, fgrad; | |||
auto conv_back_data_opr = | |||
args.handle->create_operator<ConvolutionBackwardData>(); | |||
SizeArgs fargs = float_args( | |||
args, | |||
static_cast<ConvolutionBackwardDataImpl*>(conv_back_data_opr.get()), | |||
ffilter, fdiff, fgrad); | |||
SmallVector<size_t> sizes; | |||
auto get_workspace = [&sizes](const TensorLayout& src, | |||
const TensorLayout& dst) { | |||
if (src.dtype != dst.dtype) { | |||
sizes.push_back(dst.span().dist_byte()); | |||
} | |||
}; | |||
get_workspace(*args.filter_layout, ffilter); | |||
get_workspace(*args.diff_layout, fdiff); | |||
get_workspace(*args.grad_layout, fgrad); | |||
sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs)); | |||
return {ptr, std::move(sizes)}; | |||
} | |||
size_t ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_in_bytes( | |||
const SizeArgs& args) const { | |||
return get_workspace_bundle(nullptr, args).total_size_in_bytes(); | |||
} | |||
void ConvolutionBackwardDataImpl::AlgoBFloat16::exec( | |||
const ExecArgs& args) const { | |||
TensorND ffilter_tensor = *args.filter_tensor; | |||
TensorND fdiff_tensor = *args.diff_tensor; | |||
TensorND fgrad_tensor = *args.grad_tensor; | |||
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | |||
CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle); | |||
{ | |||
cvter.src_to_comp_type(*args.filter_tensor, ffilter_tensor) | |||
.src_to_comp_type(*args.diff_tensor, fdiff_tensor) | |||
.src_to_comp_type(*args.grad_tensor, fgrad_tensor); | |||
} | |||
{ | |||
auto conv_back_data_opr = | |||
args.handle->create_operator<ConvolutionBackwardData>(); | |||
conv_back_data_opr->param() = args.opr->param(); | |||
conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
conv_back_data_opr->execution_policy() = {m_algorithm}; | |||
conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor, | |||
cvter.workspace()); | |||
} | |||
{ cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); } | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -19,6 +19,10 @@ using namespace convolution; | |||
bool ConvolutionBackwardDataImpl::AlgoChanwise::is_available( | |||
const SizeArgs& args) const { | |||
if (args.diff_layout->dtype == args.filter_layout->dtype && | |||
args.diff_layout->dtype == dtype::BFloat16()) { | |||
return false; | |||
} | |||
auto&& fm = args.filter_meta; | |||
return args.filter_meta.format == Param::Format::NCHW && | |||
args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | |||
@@ -29,6 +29,10 @@ inline bool is_available_small(const chanwise::Param& param) { | |||
bool ConvolutionBackwardDataImpl::AlgoChanwiseSmall::is_available( | |||
const SizeArgs &args) const { | |||
if (args.diff_layout->dtype == args.filter_layout->dtype && | |||
args.diff_layout->dtype == dtype::BFloat16()) { | |||
return false; | |||
} | |||
#if CUDA_VERSION < 9000 | |||
if (args.diff_layout->dtype.enumv() == DTypeEnum::Float16) | |||
return false; | |||
@@ -38,6 +38,10 @@ ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral( | |||
bool ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::is_available( | |||
const SizeArgs &args) const { | |||
if (args.diff_layout->dtype == args.filter_layout->dtype && | |||
args.diff_layout->dtype == dtype::BFloat16()) { | |||
return false; | |||
} | |||
auto sub_args = args; | |||
TensorLayout diff_pg, grad_pg; | |||
modify_size_args(sub_args, diff_pg, grad_pg); | |||
@@ -20,6 +20,10 @@ using namespace cuda; | |||
bool ConvolutionBackwardDataImpl::AlgoMatmul::is_available( | |||
const SizeArgs &args) const { | |||
if (args.diff_layout->dtype == args.filter_layout->dtype && | |||
args.diff_layout->dtype == dtype::BFloat16()) { | |||
return false; | |||
} | |||
auto &&fm = args.filter_meta; | |||
return args.filter_meta.format == Param::Format::NCHW && | |||
args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | |||
@@ -43,6 +43,12 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() { | |||
megdnn_assert(all_algos_data == all_algos.data()); | |||
non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group matmul | |||
size_t algo_size = all_algos.size(); | |||
for (size_t i=0; i<algo_size; ++i) { | |||
bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i])); | |||
all_algos.push_back(bfloat16_refhold.back().get()); | |||
bfloat16_algos.push_back(bfloat16_refhold.back().get()); | |||
} | |||
} | |||
ConvolutionBackwardFilterImpl::AlgoCUDNN* | |||
@@ -64,21 +70,20 @@ ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( | |||
ConvolutionBackwardFilterImpl *o, | |||
const TensorLayout &src, const TensorLayout &diff, | |||
const TensorLayout &grad): | |||
SizeArgs(o, src, diff, o->check_layout_fwd(src, grad, diff)) | |||
SizeArgs(o, src, diff, grad, o->check_layout_fwd(src, grad, diff)) | |||
{ | |||
} | |||
ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( | |||
ConvolutionBackwardFilterImpl *o, | |||
const TensorLayout &src, const TensorLayout &diff, | |||
const CanonizedFilterMeta &grad): | |||
handle{concrete_handle(o->handle())}, | |||
src_layout{&src}, | |||
diff_layout{&diff}, | |||
grad_filter_meta{grad}, | |||
opr{o} | |||
{ | |||
} | |||
ConvolutionBackwardFilterImpl* o, const TensorLayout& src, | |||
const TensorLayout& diff, const TensorLayout& grad, | |||
const CanonizedFilterMeta& grad_meta) | |||
: handle{concrete_handle(o->handle())}, | |||
src_layout{&src}, | |||
diff_layout{&diff}, | |||
grad_layout{&grad}, | |||
grad_filter_meta{grad_meta}, | |||
opr{o} {} | |||
ConvolutionBackwardFilterImpl::AlgoBase::ExecArgs::ExecArgs( | |||
ConvolutionBackwardFilterImpl *opr, | |||
@@ -30,7 +30,7 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm { | |||
public: | |||
struct SizeArgs { | |||
HandleImpl *handle; | |||
const TensorLayout *src_layout, *diff_layout; | |||
const TensorLayout *src_layout, *diff_layout, *grad_layout; | |||
CanonizedFilterMeta grad_filter_meta; | |||
ConvolutionBackwardFilterImpl *opr; | |||
@@ -42,12 +42,14 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm { | |||
SizeArgs(ConvolutionBackwardFilterImpl *opr, | |||
const TensorLayout &src, const TensorLayout &diff, | |||
const TensorLayout &grad); | |||
SizeArgs(ConvolutionBackwardFilterImpl *opr, | |||
const TensorLayout &src, const TensorLayout &diff, | |||
const CanonizedFilterMeta &grad); | |||
SizeArgs(ConvolutionBackwardFilterImpl* opr, | |||
const TensorLayout& src, const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
const CanonizedFilterMeta& grad_meta); | |||
convolution::ForwardSizeArgs as_fwd_args() const { | |||
return {handle, src_layout, grad_filter_meta, diff_layout}; | |||
return {handle, src_layout, grad_layout, grad_filter_meta, | |||
diff_layout}; | |||
} | |||
}; | |||
struct ExecArgs: public SizeArgs { | |||
@@ -157,6 +159,25 @@ class ConvolutionBackwardFilterImpl::AlgoChanwise final: public AlgoBase { | |||
} | |||
}; | |||
class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase { | |||
public: | |||
AlgoBFloat16(ConvolutionBackwardFilterImpl::AlgoBase*); | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { return m_name.c_str(); } | |||
bool is_reproducible() const override { return true; } | |||
private: | |||
std::string m_name; | |||
ConvolutionBackwardFilterImpl::AlgoBase* m_algorithm = nullptr; | |||
SizeArgs float_args(const SizeArgs& args, | |||
ConvolutionBackwardFilterImpl* opr, TensorLayout& fsrc, | |||
TensorLayout& ffilter, TensorLayout& fdst) const; | |||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
}; | |||
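// AlgoBFloat16 decorates a wrapped float32 algorithm (m_algorithm):
// float_args() rebuilds the SizeArgs with bfloat16 layouts promoted to
// float32, and get_workspace_bundle() adds one conversion buffer per
// promoted tensor on top of the inner algorithm's own workspace.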
//! implement group conv by another algo | |||
class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||
AlgoBase *m_impl; | |||
@@ -196,12 +217,14 @@ class ConvolutionBackwardFilterImpl::AlgoPack { | |||
AlgoChanwise chanwise; | |||
std::vector<AlgoGroupConvGeneral> gconv; | |||
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||
std::vector<AlgoBase*> | |||
//! all algorithms | |||
all_algos, | |||
//! non-cudnn algos, used for heuristic if cudnn is not supported | |||
non_cudnn_algos; | |||
non_cudnn_algos, | |||
bfloat16_algos; | |||
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | |||
}; | |||
@@ -0,0 +1,117 @@ | |||
/** | |||
* \file src/cuda/convolution/backward_filter/bfloat16.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "./algo.h" | |||
#include "src/cuda/convolution/chanwise/kern.cuh" | |||
#include "src/cuda/utils.h" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
using namespace convolution; | |||
ConvolutionBackwardFilterImpl::AlgoBFloat16::AlgoBFloat16( | |||
ConvolutionBackwardFilterImpl::AlgoBase* algorithm) | |||
: m_algorithm(algorithm) { | |||
megdnn_assert_internal(algorithm); | |||
    m_name = ssprintf("CONVOLUTION_BACKWARD_FILTER_BFLOAT16:%s", | |||
m_algorithm->name()); | |||
} | |||
ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs | |||
ConvolutionBackwardFilterImpl::AlgoBFloat16::float_args( | |||
const SizeArgs& args, ConvolutionBackwardFilterImpl* opr, | |||
TensorLayout& fsrc, TensorLayout& fdiff, TensorLayout& fgrad) const { | |||
fsrc = *args.src_layout; | |||
fdiff = *args.diff_layout; | |||
fgrad = *args.grad_layout; | |||
auto change_dtype = [](TensorLayout& layout) { | |||
if (layout.dtype == dtype::BFloat16()) { | |||
layout.dtype = dtype::Float32(); | |||
} | |||
}; | |||
change_dtype(fsrc); | |||
change_dtype(fdiff); | |||
change_dtype(fgrad); | |||
opr->param() = args.opr->param(); | |||
opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
opr->execution_policy() = {m_algorithm}; | |||
return SizeArgs(opr, fsrc, fdiff, fgrad); | |||
} | |||
bool ConvolutionBackwardFilterImpl::AlgoBFloat16::is_available( | |||
const SizeArgs& args) const { | |||
TensorLayout fsrc, fdiff, fgrad; | |||
auto conv_back_filter_opr = | |||
args.handle->create_operator<ConvolutionBackwardFilter>(); | |||
SizeArgs fargs = float_args(args, | |||
static_cast<ConvolutionBackwardFilterImpl*>( | |||
conv_back_filter_opr.get()), | |||
fsrc, fdiff, fgrad); | |||
return args.src_layout->dtype == args.diff_layout->dtype && | |||
args.src_layout->dtype == dtype::BFloat16() && | |||
m_algorithm->is_available(fargs); | |||
} | |||
WorkspaceBundle | |||
ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_bundle( | |||
void* ptr, const SizeArgs& args) const { | |||
TensorLayout fsrc, fdiff, fgrad; | |||
auto conv_back_filter_opr = | |||
args.handle->create_operator<ConvolutionBackwardFilter>(); | |||
SizeArgs fargs = float_args(args, | |||
static_cast<ConvolutionBackwardFilterImpl*>( | |||
conv_back_filter_opr.get()), | |||
fsrc, fdiff, fgrad); | |||
SmallVector<size_t> sizes; | |||
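    // Reserve one scratch buffer for each tensor whose dtype changes under
    // promotion (the bfloat16 operands), sized to hold the float32 copy.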
auto get_workspace = [&sizes](const TensorLayout& src, | |||
const TensorLayout& dst) { | |||
if (src.dtype != dst.dtype) { | |||
sizes.push_back(dst.span().dist_byte()); | |||
} | |||
}; | |||
get_workspace(*args.src_layout, fsrc); | |||
get_workspace(*args.diff_layout, fdiff); | |||
get_workspace(*args.grad_layout, fgrad); | |||
sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs)); | |||
return {ptr, std::move(sizes)}; | |||
} | |||
size_t ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_in_bytes( | |||
const SizeArgs& args) const { | |||
return get_workspace_bundle(nullptr, args).total_size_in_bytes(); | |||
} | |||
void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec( | |||
const ExecArgs& args) const { | |||
TensorND fsrc_tensor = *args.src_tensor; | |||
TensorND fdiff_tensor = *args.diff_tensor; | |||
TensorND fgrad_tensor = *args.grad_tensor; | |||
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | |||
CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle); | |||
{ | |||
cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor) | |||
.src_to_comp_type(*args.diff_tensor, fdiff_tensor) | |||
.src_to_comp_type(*args.grad_tensor, fgrad_tensor); | |||
} | |||
{ | |||
auto conv_back_filter_opr = | |||
args.handle->create_operator<ConvolutionBackwardFilter>(); | |||
conv_back_filter_opr->param() = args.opr->param(); | |||
conv_back_filter_opr->param().compute_mode = | |||
Param::ComputeMode::DEFAULT; | |||
conv_back_filter_opr->execution_policy() = {m_algorithm}; | |||
conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor, | |||
cvter.workspace()); | |||
} | |||
{ cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); } | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -19,6 +19,10 @@ using namespace convolution; | |||
bool ConvolutionBackwardFilterImpl::AlgoChanwise::is_available( | |||
const SizeArgs &args) const { | |||
    if (args.src_layout->dtype == args.diff_layout->dtype && | |||
args.diff_layout->dtype == dtype::BFloat16()) { | |||
return false; | |||
} | |||
auto &&fm = args.grad_filter_meta; | |||
return fm.format == Param::Format::NCHW && | |||
args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | |||
@@ -38,6 +38,10 @@ ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral( | |||
bool ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::is_available( | |||
const SizeArgs &args) const { | |||
    if (args.src_layout->dtype == args.diff_layout->dtype && | |||
args.diff_layout->dtype == dtype::BFloat16()) { | |||
return false; | |||
} | |||
auto sub_args = args; | |||
TensorLayout src_pg, diff_pg; | |||
modify_size_args(sub_args, src_pg, diff_pg); | |||
@@ -19,6 +19,10 @@ using namespace cuda; | |||
bool ConvolutionBackwardFilterImpl::AlgoMatmul::is_available( | |||
const SizeArgs &args) const { | |||
    if (args.src_layout->dtype == args.diff_layout->dtype && | |||
args.diff_layout->dtype == dtype::BFloat16()) { | |||
return false; | |||
} | |||
auto &&fm = args.grad_filter_meta; | |||
return fm.format == Param::Format::NCHW && | |||
args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | |||
@@ -16,6 +16,10 @@ using namespace cuda; | |||
using namespace convolution; | |||
bool convolution::is_cudnn_supported(const ForwardSizeArgs &args) { | |||
if (args.src_layout->dtype == args.filter_layout->dtype && | |||
args.src_layout->dtype == dtype::BFloat16()) { | |||
return false; | |||
} | |||
// CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN | |||
// on Tegra K1. | |||
@@ -25,6 +25,7 @@ namespace convolution { | |||
struct ForwardSizeArgs { | |||
HandleImpl *handle; | |||
const TensorLayout *src_layout; | |||
const TensorLayout *filter_layout; | |||
CanonizedFilterMeta filter_meta; | |||
const TensorLayout *dst_layout; | |||
}; | |||
@@ -102,7 +102,8 @@ void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) { | |||
AlgoBase::ExecArgs args(this, filter, diff, grad, workspace); | |||
auto algo = get_algorithm(this, args.filter_meta, diff.layout, grad.layout); | |||
auto algo = get_algorithm(this, filter.layout, args.filter_meta, | |||
diff.layout, grad.layout); | |||
algo->check_workspace(args, workspace).exec(args); | |||
} | |||
@@ -120,16 +121,16 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
bool reproducible) { | |||
auto fm = check_layout_fwd(grad, filter, diff); | |||
return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes, | |||
reproducible); | |||
return get_algorithm_heuristic(filter, fm, diff, grad, | |||
workspace_limit_in_bytes, reproducible); | |||
} | |||
ConvolutionBackwardDataImpl::Algorithm* | |||
ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||
const CanonizedFilterMeta& filter, const TensorLayout& diff, | |||
ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, | |||
const CanonizedFilterMeta& filter_meta, const TensorLayout& diff, | |||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
bool reproducible) { | |||
AlgoBase::SizeArgs args(this, filter, diff, grad); | |||
AlgoBase::SizeArgs args(this, filter, filter_meta, diff, grad); | |||
if (args.filter_meta.group > 1 && | |||
sm_algo_pack.chanwise.is_available_reproducible( | |||
@@ -209,14 +210,27 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||
args = orig_args; | |||
} | |||
if (reproducible) { | |||
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( | |||
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||
"cuda conv bwd_data"); | |||
if (args.filter_layout->dtype.enumv() != | |||
DTypeTrait<dtype::BFloat16>::enumv) { | |||
if (reproducible) { | |||
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( | |||
sm_algo_pack.non_cudnn_algos, args, | |||
workspace_limit_in_bytes, "cuda conv bwd_data"); | |||
} else { | |||
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | |||
sm_algo_pack.non_cudnn_algos, args, | |||
workspace_limit_in_bytes, "cuda conv bwd_data"); | |||
} | |||
} else { | |||
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | |||
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||
"cuda conv bwd_data"); | |||
if (reproducible) { | |||
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( | |||
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||
"cuda conv bwd_data"); | |||
} else { | |||
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | |||
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||
"cuda conv bwd_data"); | |||
} | |||
} | |||
} | |||
@@ -225,7 +239,7 @@ size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( | |||
const TensorLayout &diff, | |||
const TensorLayout &grad) { | |||
AlgoBase::SizeArgs args(this, filter, diff, grad); | |||
return get_algorithm(this, args.filter_meta, diff, grad)-> | |||
return get_algorithm(this, filter, args.filter_meta, diff, grad)-> | |||
get_workspace_in_bytes(args); | |||
} | |||
@@ -241,7 +255,7 @@ void ConvolutionBackwardFilterImpl::exec(_megdnn_tensor_in src, | |||
_megdnn_workspace workspace) { | |||
AlgoBase::ExecArgs args(this, src, diff, grad, workspace); | |||
auto algo = get_algorithm(this, src.layout, diff.layout, | |||
args.grad_filter_meta); | |||
grad.layout, args.grad_filter_meta); | |||
algo->check_workspace(args, workspace).exec(args); | |||
} | |||
@@ -259,16 +273,16 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
bool reproducible) { | |||
auto fm = check_layout_fwd(src, grad, diff); | |||
return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, | |||
reproducible); | |||
return get_algorithm_heuristic(src, diff, grad, fm, | |||
workspace_limit_in_bytes, reproducible); | |||
} | |||
ConvolutionBackwardFilterImpl::Algorithm* | |||
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||
const TensorLayout& src, const TensorLayout& diff, | |||
const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, | |||
bool reproducible) { | |||
AlgoBase::SizeArgs args(this, src, diff, grad); | |||
const TensorLayout& grad, const CanonizedFilterMeta& grad_meta, | |||
size_t workspace_limit_in_bytes, bool reproducible) { | |||
AlgoBase::SizeArgs args(this, src, diff, grad, grad_meta); | |||
if (args.grad_filter_meta.group > 1 && | |||
sm_algo_pack.chanwise.is_available_reproducible( | |||
@@ -349,14 +363,26 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||
args = orig_args; | |||
} | |||
if (reproducible) { | |||
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( | |||
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||
"cuda conv bwd_filter"); | |||
if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) { | |||
if (reproducible) { | |||
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( | |||
sm_algo_pack.non_cudnn_algos, args, | |||
workspace_limit_in_bytes, "cuda conv bwd_filter"); | |||
} else { | |||
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( | |||
sm_algo_pack.non_cudnn_algos, args, | |||
workspace_limit_in_bytes, "cuda conv bwd_filter"); | |||
} | |||
} else { | |||
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( | |||
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||
"cuda conv bwd_filter"); | |||
if (reproducible) { | |||
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( | |||
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||
"cuda conv bwd_filter"); | |||
} else { | |||
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( | |||
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||
"cuda conv bwd_filter"); | |||
} | |||
} | |||
} | |||
@@ -365,7 +391,7 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( | |||
const TensorLayout &diff, | |||
const TensorLayout &grad) { | |||
AlgoBase::SizeArgs args(this, src, diff, grad); | |||
return get_algorithm(this, src, diff, args.grad_filter_meta)-> | |||
return get_algorithm(this, src, diff, grad, args.grad_filter_meta)-> | |||
get_workspace_in_bytes(args); | |||
} | |||
@@ -60,11 +60,11 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible); | |||
Algorithm* get_algorithm_heuristic( | |||
const TensorLayout& filter, | |||
const CanonizedFilterMeta& filter_meta, | |||
const TensorLayout& diff, const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, bool reproducible); | |||
size_t get_workspace_in_bytes(const TensorLayout& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
@@ -76,6 +76,7 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||
class AlgoChanwise; | |||
class AlgoChanwiseSmall; | |||
class AlgoGroupConvGeneral; | |||
class AlgoBFloat16; | |||
class AlgoPack; | |||
@@ -104,7 +105,8 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||
bool reproducible) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const CanonizedFilterMeta& grad, | |||
const TensorLayout& grad, | |||
const CanonizedFilterMeta& grad_meta, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible); | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
@@ -117,6 +119,7 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||
class AlgoMatmul; | |||
class AlgoChanwise; | |||
class AlgoGroupConvGeneral; | |||
class AlgoBFloat16; | |||
class AlgoPack; | |||
@@ -50,7 +50,7 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm { | |||
const CanonizedFilterMeta &filter, | |||
const TensorLayout &dst); | |||
}; | |||
struct ExecArgs: public SizeArgs { | |||
struct ExecArgs : public SizeArgs { | |||
const TensorND *src_tensor, *filter_tensor, *dst_tensor; | |||
Workspace workspace; | |||
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
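// Every generated kimpl translation unit fixes one (mode, ctype) pair and
// defers the actual kernel instantiation to kern_impl.inl: KERN_IMPL_MODE
// enables the elemwise mode through the cb callback, KERN_IMPL_ARITY gives
// the operand count, and KERN_IMPL_CTYPE selects the element type. A
// hypothetical new unary mode FOO for bfloat16 would follow the same
// template (sketch only, not a file in this change):
//
//   #if !MEGDNN_DISABLE_FLOAT16
//   #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FOO, cb)
//   #define KERN_IMPL_ARITY 1
//   #define KERN_IMPL_CTYPE dt_bfloat16
//   #include "../kern_impl.inl"
//   #endif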
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) | |||
#define KERN_IMPL_ARITY 3 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) | |||
#define KERN_IMPL_ARITY 3 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) | |||
#define KERN_IMPL_ARITY 1 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,17 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -0,0 +1,18 @@ | |||
/** | |||
* \file dnn/src/cuda/elemwise/special_kimpl/special_dt_bfloat16.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// generated by gen_elemwise_special_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#include "../special_kerns.inl" | |||
INST(::megdnn::dtype::BFloat16) | |||
#undef INST | |||
} | |||
} | |||
#endif |
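The two bare closing braces after INST look unbalanced in isolation; they only make sense if ../special_kerns.inl opens namespaces and deliberately leaves them open for the including .cu file to close after instantiation. A compilable toy of that assumed split (names illustrative, not MegDNN's):

    // ---- assumed shape of special_kerns.inl ----
    namespace megdnn {
    namespace cuda {
    template <typename T>
    void special_kern_stub(T) {}  // stand-in for the real special kernels
    #define INST(_dt) template void special_kern_stub<_dt>(_dt);
    // namespaces intentionally left open here
    // ---- assumed shape of a generated .cu ----
    INST(float)
    #undef INST
    }  // namespace cuda
    }  // namespace megdnn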
@@ -141,6 +141,9 @@ INST_FOR_CTYPE | |||
#define ct dt_float16 | |||
INST_FOR_CTYPE | |||
#undef ct | |||
#define ct dt_bfloat16 | |||
INST_FOR_CTYPE | |||
#undef ct | |||
#define ct dt_int8 | |||
INST_FOR_CTYPE | |||
#undef ct | |||
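The `#define ct ... / INST_FOR_CTYPE / #undef ct` chain above is a classic X-macro: INST_FOR_CTYPE is written once in terms of the placeholder macro `ct`, and each redefinition of `ct` stamps out instantiations for another element type. A self-contained demo of the idiom (names here are illustrative, not MegDNN's):

    #include <cstddef>

    template <typename T>
    void fill(T* dst, T value, size_t n) {
        for (size_t i = 0; i < n; ++i)
            dst[i] = value;
    }

    // one definition, stamped out once per element type
    #define INST_FOR_CTYPE template void fill<ct>(ct*, ct, size_t);

    #define ct float
    INST_FOR_CTYPE
    #undef ct
    #define ct int
    INST_FOR_CTYPE
    #undef ct

    #undef INST_FOR_CTYPE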
@@ -68,6 +68,17 @@ namespace elemwise_intl { | |||
return t; | |||
} | |||
struct __attribute__((aligned(8))) bhalf4 { | |||
dt_bfloat16 x, y, z, w; | |||
}; | |||
__device__ __forceinline__ bhalf4 make_bhalf4(dt_bfloat16 x, dt_bfloat16 y, | |||
dt_bfloat16 z, dt_bfloat16 w) { | |||
bhalf4 t; | |||
t.x = x, t.y = y, t.z = z, t.w = w; | |||
return t; | |||
} | |||
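The aligned(8) attribute is what lets bhalf4 be read and written as a single 8-byte word in vectorized kernels, which only holds if dt_bfloat16 is exactly 2 bytes. A compile-time guard one could place next to the struct (assuming nothing beyond what the layout already requires):

    static_assert(sizeof(dt_bfloat16) == 2,
                  "bfloat16 payload must be 2 bytes");
    static_assert(sizeof(bhalf4) == 8 && alignof(bhalf4) == 8,
                  "bhalf4 must be accessible as one 8-byte vector word");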
#define INST(_ctype, _vect_type) \ | |||
template <> \ | |||
class VectTypeTrait<_ctype> { \ | |||
@@ -87,6 +98,7 @@ namespace elemwise_intl { | |||
INST(dt_uint8, uchar4); | |||
INST(dt_float32, float4); | |||
INST(dt_float16, half4); | |||
INST(dt_bfloat16, bhalf4); | |||
INST(dt_int32, int4); | |||
INST(dt_int16, short4); | |||
#undef as_raw | |||
@@ -17,6 +17,11 @@ __device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) { | |||
__trap(); | |||
((int*)0)[0] = 1; | |||
} | |||
__device__ void atomicAdd(megdnn::dt_bfloat16 *, megdnn::dt_bfloat16) { | |||
__trap(); | |||
((int*)0)[0] = 1; | |||
} | |||
#endif | |||
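The bfloat16 atomicAdd overload above is a deliberate dead end, matching the dt_float16 one: it exists only so that templated reduction code type-checks, and it traps at runtime (the null write after __trap() guarantees a fault) because no kernel is expected to reach it. For intuition only, an emulated version would have to CAS the containing 32-bit word. The sketch below is not MegDNN API; it assumes raw unsigned short storage with the bf16 payload equal to the high half of an fp32 bit pattern, and uses truncating conversion where production code may round:

    __device__ inline void atomic_add_bf16_emulated(unsigned short* addr,
                                                    float val) {
        // CAS must operate on the aligned 32-bit word containing *addr.
        unsigned* base = (unsigned*)((size_t)addr & ~(size_t)3);
        unsigned shift = ((size_t)addr & 2) ? 16u : 0u;
        unsigned old_word = *base, assumed;
        do {
            assumed = old_word;
            unsigned raw = (assumed >> shift) & 0xffffu;
            float cur = __uint_as_float(raw << 16);           // bf16 -> f32
            unsigned sum = __float_as_uint(cur + val) >> 16;  // f32 -> bf16
            unsigned new_word =
                    (assumed & ~(0xffffu << shift)) | (sum << shift);
            old_word = atomicCAS(base, assumed, new_word);
        } while (assumed != old_word);                        // retry on contention
    }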
__device__ void atomicAdd(megdnn::dt_int8 *, megdnn::dt_int8) { | |||
@@ -29,6 +29,10 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(&cublas_lt); | |||
#endif | |||
all_algos.push_back(&naive); | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
cublas_bfloat16 = std::make_unique<AlgoBFloat16>(&cublas); | |||
all_algos.push_back(cublas_bfloat16.get()); | |||
#endif | |||
} | |||
MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; | |||
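AlgoPack simply appends every candidate to all_algos; note that the bfloat16 entry is constructed behind the MEGDNN_DISABLE_FLOAT16 guard and wraps the plain cublas algorithm rather than standing alone. Selection elsewhere in MegDNN then walks the list. A first-fit sketch of that consultation (illustrative only; the real dispatcher also honors execution_policy and heuristics):

    // hypothetical first-fit picker over the pack built above
    AlgoBase* pick_algo(const std::vector<AlgoBase*>& all_algos,
                        const SizeArgs& args) {
        for (AlgoBase* algo : all_algos)
            if (algo->is_available(args))
                return algo;
        return nullptr;  // caller must handle "no usable algorithm"
    }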
@@ -15,6 +15,7 @@ | |||
#include "src/cuda/matrix_mul/opr_impl.h" | |||
#include <cuda.h> | |||
#include <memory> | |||
#if CUDA_VERSION >= 10010 | |||
#include <cublasLt.h> | |||
#endif | |||
@@ -140,6 +141,24 @@ public: | |||
bool is_reproducible() const override { return true; } | |||
}; | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
class MatrixMulForwardImpl::AlgoBFloat16 final : public AlgoBase { | |||
public: | |||
AlgoBFloat16(MatrixMulForwardImpl::AlgoBase*); | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
const char* name() const override { return m_name.c_str(); } | |||
void exec(const ExecArgs& args) const override; | |||
bool is_reproducible() const override { return true; } | |||
private: | |||
MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr; | |||
std::string m_name; | |||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
SizeArgs float_args(const SizeArgs& args) const; | |||
}; | |||
#endif | |||
class MatrixMulForwardImpl::AlgoPack { | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator=(const AlgoPack&) = delete; | |||
@@ -154,7 +173,9 @@ public: | |||
#if CUDA_VERSION >= 10010 | |||
AlgoCuBlasLt cublas_lt; | |||
#endif | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
std::unique_ptr<AlgoBFloat16> cublas_bfloat16; | |||
#endif | |||
std::vector<AlgoBase*> all_algos; | |||
}; | |||
@@ -0,0 +1,91 @@ | |||
/** | |||
* \file dnn/src/cuda/matrix_mul/bfloat16.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/cuda/handle.h" | |||
#include "src/cuda/matrix_mul/algos.h" | |||
#include "src/cuda/utils.h" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
MatrixMulForwardImpl::AlgoBFloat16::AlgoBFloat16( | |||
MatrixMulForwardImpl::AlgoBase* algorithm) | |||
: m_algorithm(algorithm) { | |||
megdnn_assert_internal(algorithm); | |||
m_name = ssprintf("MATMUL_BFLOAT16:%s", m_algorithm->name()); | |||
} | |||
MatrixMulForwardImpl::AlgoBase::SizeArgs | |||
MatrixMulForwardImpl::AlgoBFloat16::float_args(const SizeArgs& args) const { | |||
auto new_args = args; | |||
auto change_dtype = [](TensorLayout& layout) { | |||
if (layout.dtype == dtype::BFloat16()) { | |||
layout.dtype = dtype::Float32(); | |||
} | |||
}; | |||
change_dtype(new_args.layout_a); | |||
change_dtype(new_args.layout_b); | |||
change_dtype(new_args.layout_c); | |||
return new_args; | |||
} | |||
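float_args rewrites only the dtype; shapes and strides pass through untouched. For example, a BFloat16 GEMM with A {64,128}, B {128,32}, C {64,32} maps to identical Float32 layouts, so the wrapped algorithm's is_available and workspace queries can be asked directly about the equivalent fp32 problem. That keeps this wrapper agnostic to which shapes the inner algorithm actually supports.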
bool MatrixMulForwardImpl::AlgoBFloat16::is_available( | |||
const SizeArgs& args) const { | |||
auto fargs = float_args(args); | |||
return args.layout_a.dtype == dtype::BFloat16() && | |||
m_algorithm->is_available(fargs); | |||
} | |||
WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle( | |||
void* ptr, const SizeArgs& args) const { | |||
auto fargs = float_args(args); | |||
SmallVector<size_t> sizes; | |||
auto get_workspace = [&sizes](const TensorLayout& src) { | |||
TensorLayout dst = src; | |||
if (dst.dtype == dtype::BFloat16()) { | |||
dst.dtype = dtype::Float32(); | |||
sizes.push_back(dst.span().dist_byte()); | |||
} | |||
}; | |||
get_workspace(args.layout_a); | |||
get_workspace(args.layout_b); | |||
get_workspace(args.layout_c); | |||
sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs)); | |||
return {ptr, std::move(sizes)}; | |||
} | |||
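The bundle therefore holds one fp32 staging buffer per BFloat16 operand, followed by the wrapped algorithm's own workspace. Worked example, assuming contiguous layouts: for A {64,128}, B {128,32}, C {64,32}, all BFloat16, the staging sizes are 64*128*4 = 32768 bytes, 128*32*4 = 16384 bytes, and 64*32*4 = 8192 bytes (4 bytes per fp32 element), i.e. 57344 bytes of staging before the inner fp32 algorithm's requirement is appended. get_workspace_in_bytes below just totals this bundle with a null base pointer.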
size_t MatrixMulForwardImpl::AlgoBFloat16::get_workspace_in_bytes( | |||
const SizeArgs& args) const { | |||
return get_workspace_bundle(nullptr, args).total_size_in_bytes(); | |||
} | |||
void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { | |||
TensorND a = args.tensor_a; | |||
TensorND b = args.tensor_b; | |||
TensorND c = args.tensor_c; | |||
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | |||
auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>( | |||
args.opr->handle(), &bundle); | |||
ctypecvt.src_to_comp_type(args.tensor_a, a) | |||
.src_to_comp_type(args.tensor_b, b) | |||
.src_to_comp_type(args.tensor_c, c); | |||
{ | |||
auto matmul_opr = | |||
args.opr->handle()->create_operator<MatrixMulForward>(); | |||
matmul_opr->param() = args.opr->param(); | |||
matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
matmul_opr->execution_policy() = {m_algorithm}; | |||
matmul_opr->exec(a, b, c, ctypecvt.workspace()); | |||
} | |||
ctypecvt.comp_to_dst_type(c, args.tensor_c); | |||
} | |||
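exec is the whole trick in three moves: stage a, b and c as Float32 through CompTypeCvter, run the wrapped algorithm as an ordinary fp32 matmul (forcing ComputeMode::DEFAULT so it does not recurse into mixed-precision handling), then convert the fp32 result back into the BFloat16 tensor_c. A self-contained CPU toy of the same convert-compute-convert pattern, with hand-rolled conversions exploiting the fact that bf16 is the top half of an fp32 bit pattern (all names hypothetical; truncating conversion where real code may round):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>
    #include <vector>

    static float bf16_to_f32(uint16_t b) {
        uint32_t u = uint32_t(b) << 16;  // bf16 occupies the high 16 bits
        float f;
        std::memcpy(&f, &u, 4);
        return f;
    }

    static uint16_t f32_to_bf16(float f) {
        uint32_t u;
        std::memcpy(&u, &f, 4);
        return uint16_t(u >> 16);        // truncate; production code may round
    }

    // C = A * B with bf16 storage and fp32 compute (convert-compute-convert)
    void matmul_bf16_via_f32(const uint16_t* a, const uint16_t* b, uint16_t* c,
                             size_t m, size_t k, size_t n) {
        std::vector<float> fa(m * k), fb(k * n), fc(m * n, 0.f);
        for (size_t i = 0; i < m * k; ++i) fa[i] = bf16_to_f32(a[i]);
        for (size_t i = 0; i < k * n; ++i) fb[i] = bf16_to_f32(b[i]);
        for (size_t i = 0; i < m; ++i)        // naive fp32 GEMM core
            for (size_t j = 0; j < n; ++j)
                for (size_t p = 0; p < k; ++p)
                    fc[i * n + j] += fa[i * k + p] * fb[p * n + j];
        for (size_t i = 0; i < m * n; ++i)
            c[i] = f32_to_bf16(fc[i]);
    }

The design choice mirrors the CUDA path above: bf16 inputs sacrifice mantissa bits only at the edges, while the accumulation itself runs at full fp32 precision.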
// vim: syntax=cpp.doxygen |