GitOrigin-RevId: ac757ca307
release-1.7
@@ -53,6 +53,18 @@ std::vector<TestArg> get_args() { | |||||
TensorShape{1, 3, 1, 1}, dtype::Float16()); | TensorShape{1, 3, 1, 1}, dtype::Float16()); | ||||
} | } | ||||
// case 3: 1 x 1 x 1 x C | |||||
for (size_t i = 4; i < 257; i *= 4) { | |||||
param::BN param; | |||||
param.fwd_mode = param::BN::FwdMode::TRAINING; | |||||
param.param_dim = param::BN::ParamDim::DIM_111C; | |||||
args.emplace_back(param, TensorShape{3, i, i, 3}, | |||||
TensorShape{1, 1, 1, 3}, dtype::Float32()); | |||||
args.emplace_back(param, TensorShape{3, i, i, 3}, | |||||
TensorShape{1, 1, 1, 3}, dtype::Float16()); | |||||
} | |||||
return args; | return args; | ||||
} | } | ||||
@@ -60,4 +72,4 @@ std::vector<TestArg> get_args() { | |||||
} // namespace test | } // namespace test | ||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | |||||
// vim: syntax=cpp.doxygen |
@@ -419,7 +419,7 @@ void CheckerHelper::copy_tensors_from_device(const TensorValueArray& dest, | |||||
void CheckerHelper::check_tensors(const TensorValueArray& expected, | void CheckerHelper::check_tensors(const TensorValueArray& expected, | ||||
const TensorValueArray& computed) { | const TensorValueArray& computed) { | ||||
for (size_t i = 0; i < expected.size(); ++i) { | for (size_t i = 0; i < expected.size(); ++i) { | ||||
if (expected[i].layout.ndim == 0) | |||||
if (expected[i].layout.ndim == 0 || m_bypass.find(i) != m_bypass.end()) | |||||
continue; | continue; | ||||
if (m_allow_invalid_check) { | if (m_allow_invalid_check) { | ||||
MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID( | MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID( | ||||
@@ -69,6 +69,7 @@ protected: | |||||
std::unordered_map<size_t, RNG*> m_rng; | std::unordered_map<size_t, RNG*> m_rng; | ||||
std::unordered_map<size_t, DType> m_dtype; | std::unordered_map<size_t, DType> m_dtype; | ||||
std::unordered_map<size_t, TensorFormat> m_fmt; | std::unordered_map<size_t, TensorFormat> m_fmt; | ||||
std::set<size_t> m_bypass; | |||||
float_t m_epsilon = 1e-3, m_max_avg_error = 1e-3, | float_t m_epsilon = 1e-3, m_max_avg_error = 1e-3, | ||||
m_max_avg_biased_error = 1e-3; | m_max_avg_biased_error = 1e-3; | ||||
float_t m_perf_check_threshold = -1; | float_t m_perf_check_threshold = -1; | ||||
@@ -184,6 +185,10 @@ public: | |||||
m_rng[idx] = rng; | m_rng[idx] = rng; | ||||
return *this; | return *this; | ||||
} | } | ||||
//! Register tensor index \p idx in the m_bypass set and return this checker
//! so that calls can be chained.
//! NOTE(review): presumably this is the same m_bypass that
//! CheckerHelper::check_tensors consults to skip comparing a tensor -- confirm
//! against the class hierarchy.
Checker& set_bypass(size_t idx) { | |||||
m_bypass.insert(idx); | |||||
return *this; | |||||
} | |||||
//! max error of a single element | //! max error of a single element | ||||
Checker& set_epsilon(dt_float32 epsilon) { | Checker& set_epsilon(dt_float32 epsilon) { | ||||
m_epsilon = epsilon; | m_epsilon = epsilon; | ||||
@@ -82,6 +82,15 @@ struct DeduceLayoutProxy<Opr, 8, true> { | |||||
} | } | ||||
}; | }; | ||||
//! Layout-deduction proxy specialisation for operators taking nine layouts.
//! Asserts that exactly nine layouts were supplied, then forwards them, in
//! order, to Opr::deduce_layout.
template <typename Opr> | |||||
struct DeduceLayoutProxy<Opr, 9, true> { | |||||
static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) { | |||||
// Guard the fixed indexing below: layouts[0] .. layouts[8] are all read.
megdnn_assert(layouts.size() == 9); | |||||
opr->deduce_layout(layouts[0], layouts[1], layouts[2], layouts[3], | |||||
layouts[4], layouts[5], layouts[6], layouts[7], | |||||
layouts[8]); | |||||
} | |||||
}; | |||||
} // namespace test | } // namespace test | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -23,6 +23,23 @@ template <typename Opr, size_t Arity, bool has_workspace> | |||||
struct ExecProxy; | struct ExecProxy; | ||||
//! Execution proxy specialisation for operators taking nine tensors plus a
//! workspace. The WorkspaceWrapper member W is kept across exec() calls.
template <typename Opr> | template <typename Opr> | ||||
struct ExecProxy<Opr, 9, true> { | |||||
WorkspaceWrapper W; | |||||
//! Query the workspace size for the nine tensor layouts, update W with it,
//! then run opr->exec on tensors[0..8] with W.workspace().
//! NOTE(review): no size check is performed here; tensors is assumed to
//! hold at least nine entries -- confirm callers guarantee this.
void exec(Opr* opr, const TensorNDArray& tensors) { | |||||
// Create the wrapper on first use, bound to the operator's handle and
// starting from a size of 0.
if (!W.valid()) { | |||||
W = WorkspaceWrapper(opr->handle(), 0); | |||||
} | |||||
// Presumably resizes the wrapped workspace to the requested byte count
// -- see WorkspaceWrapper::update.
W.update(opr->get_workspace_in_bytes( | |||||
tensors[0].layout, tensors[1].layout, tensors[2].layout, | |||||
tensors[3].layout, tensors[4].layout, tensors[5].layout, | |||||
tensors[6].layout, tensors[7].layout, tensors[8].layout)); | |||||
opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], | |||||
tensors[5], tensors[6], tensors[7], tensors[8], | |||||
W.workspace()); | |||||
} | |||||
}; | |||||
template <typename Opr> | |||||
struct ExecProxy<Opr, 8, true> { | struct ExecProxy<Opr, 8, true> { | ||||
WorkspaceWrapper W; | WorkspaceWrapper W; | ||||
void exec(Opr* opr, const TensorNDArray& tensors) { | void exec(Opr* opr, const TensorNDArray& tensors) { | ||||
@@ -211,6 +211,10 @@ void IIDRNG::gen(const TensorND& tensor) { | |||||
} | } | ||||
return; | return; | ||||
} | } | ||||
if (tensor.layout.dtype.enumv() == DTypeEnum::Byte) { | |||||
memset(tensor.raw_ptr, 0, tensor.layout.access_bytes()); | |||||
return; | |||||
} | |||||
megdnn_assert(0, "IIDRNG does not know how to generate value for DType %s", | megdnn_assert(0, "IIDRNG does not know how to generate value for DType %s", | ||||
tensor.layout.dtype.name()); | tensor.layout.dtype.name()); | ||||
} | } | ||||
@@ -6,10 +6,13 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "test/cuda/fixture.h" | #include "test/cuda/fixture.h" | ||||
#include "src/cuda/batch_normalization/opr_impl.h" | |||||
#include "src/cuda/utils.h" | |||||
#include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "test/common/bn.h" | #include "test/common/bn.h" | ||||
@@ -21,15 +24,26 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace test { | namespace test { | ||||
TEST_F(CUDA, BN_FORWARD) { | |||||
TEST_F(CUDA, BN_FORWARD_BACKWARD) { | |||||
using namespace batch_normalization; | using namespace batch_normalization; | ||||
using cuda::cudnn_handle; | |||||
using cuda::batch_normalization::BNTensorDescHolder; | |||||
using cuda::batch_normalization::get_reserve_size; | |||||
std::vector<TestArg> args = get_args(); | std::vector<TestArg> args = get_args(); | ||||
Checker<BNForward> checker(handle_cuda()); | Checker<BNForward> checker(handle_cuda()); | ||||
Checker<BNBackward> checker_bwd(handle_cuda()); | |||||
for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
for (int i = 0; i < 8; ++i) { | |||||
auto tensor_desc = BNTensorDescHolder({arg.src, arg.dtype}, arg.param.param_dim, | |||||
arg.param.fwd_mode); | |||||
auto reserve = get_reserve_size(cudnn_handle(handle_cuda()), tensor_desc); | |||||
// Forward | |||||
for (int i = 0; i < 9; ++i) { | |||||
checker.set_dtype(i, dtype::Float32()); | checker.set_dtype(i, dtype::Float32()); | ||||
} | } | ||||
checker.set_dtype(0, arg.dtype); | checker.set_dtype(0, arg.dtype); | ||||
checker.set_dtype(7, dtype::Byte()); | |||||
checker.set_dtype(8, arg.dtype); | |||||
checker.set_bypass(7); | |||||
checker.set_epsilon(1e-3).set_param(arg.param); | checker.set_epsilon(1e-3).set_param(arg.param); | ||||
for (bool need_statistic : {false, true}) | for (bool need_statistic : {false, true}) | ||||
checker.exec({ | checker.exec({ | ||||
@@ -40,27 +54,26 @@ TEST_F(CUDA, BN_FORWARD) { | |||||
: TensorShape({0}), // mean | : TensorShape({0}), // mean | ||||
need_statistic ? arg.param_shape | need_statistic ? arg.param_shape | ||||
: TensorShape({0}), // variance | : TensorShape({0}), // variance | ||||
arg.param_shape, // batch_mean | |||||
arg.param_shape, // batch_inv_variance | |||||
{} // dst | |||||
arg.param_shape, // batch_mean | |||||
arg.param_shape, // batch_inv_variance | |||||
{reserve}, // reserve | |||||
arg.src // dst | |||||
}); | }); | ||||
} | |||||
} | |||||
TEST_F(CUDA, BN_BACKWARD) { | |||||
using namespace batch_normalization; | |||||
std::vector<TestArg> args = get_args(); | |||||
Checker<BNBackward> checker(handle_cuda()); | |||||
for (auto&& arg : args) { | |||||
for (int i = 0; i < 8; ++i) { | |||||
checker.set_dtype(i, dtype::Float32()); | |||||
// Backward | |||||
for (int i = 0; i < 9; ++i) { | |||||
checker_bwd.set_dtype(i, dtype::Float32()); | |||||
} | } | ||||
checker.set_dtype(0, arg.dtype) // x | |||||
.set_dtype(1, arg.dtype) // dy | |||||
.set_dtype(7, arg.dtype); // dx | |||||
checker.set_epsilon(1e-3).set_param(arg.param).exec( | |||||
checker_bwd | |||||
.set_dtype(0, arg.dtype) // x | |||||
.set_dtype(1, arg.dtype) // dy | |||||
.set_dtype(5, dtype::Byte()) // reserve | |||||
.set_dtype(8, arg.dtype) // dx | |||||
.set_bypass(5); | |||||
checker_bwd.set_epsilon(1e-3).set_param(arg.param).exec( | |||||
{arg.src, arg.src, arg.param_shape, arg.param_shape, | {arg.src, arg.src, arg.param_shape, arg.param_shape, | ||||
arg.param_shape, arg.param_shape, arg.param_shape, arg.src}); | |||||
arg.param_shape, {reserve}, arg.param_shape, arg.param_shape, | |||||
arg.src}); | |||||
} | } | ||||
} | } | ||||
@@ -31,6 +31,7 @@ TEST_F(ROCM, BN_FORWARD) { | |||||
checker.set_dtype(i, dtype::Float32()); | checker.set_dtype(i, dtype::Float32()); | ||||
} | } | ||||
checker.set_dtype(0, arg.dtype); | checker.set_dtype(0, arg.dtype); | ||||
checker.set_dtype(8, arg.dtype); | |||||
checker.set_epsilon(1e-3).set_param(arg.param); | checker.set_epsilon(1e-3).set_param(arg.param); | ||||
for (bool need_statistic : {false, true}) | for (bool need_statistic : {false, true}) | ||||
checker.exec({ | checker.exec({ | ||||
@@ -43,7 +44,8 @@ TEST_F(ROCM, BN_FORWARD) { | |||||
: TensorShape({0}), // variance | : TensorShape({0}), // variance | ||||
arg.param_shape, // batch_mean | arg.param_shape, // batch_mean | ||||
arg.param_shape, // batch_inv_variance | arg.param_shape, // batch_inv_variance | ||||
{} // dst | |||||
{0}, // reserve | |||||
arg.src // dst | |||||
}); | }); | ||||
} | } | ||||
} | } | ||||
@@ -53,15 +55,16 @@ TEST_F(ROCM, BN_BACKWARD) { | |||||
std::vector<TestArg> args = get_args(); | std::vector<TestArg> args = get_args(); | ||||
Checker<BNBackward> checker(handle_rocm()); | Checker<BNBackward> checker(handle_rocm()); | ||||
for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
for (int i = 0; i < 8; ++i) { | |||||
for (int i = 0; i < 9; ++i) { | |||||
checker.set_dtype(i, dtype::Float32()); | checker.set_dtype(i, dtype::Float32()); | ||||
} | } | ||||
checker.set_dtype(0, arg.dtype) // x | checker.set_dtype(0, arg.dtype) // x | ||||
.set_dtype(1, arg.dtype) // dy | .set_dtype(1, arg.dtype) // dy | ||||
.set_dtype(7, arg.dtype); // dx | |||||
.set_dtype(8, arg.dtype); // dx | |||||
checker.set_epsilon(1e-3).set_param(arg.param).exec( | checker.set_epsilon(1e-3).set_param(arg.param).exec( | ||||
{arg.src, arg.src, arg.param_shape, arg.param_shape, | {arg.src, arg.src, arg.param_shape, arg.param_shape, | ||||
arg.param_shape, arg.param_shape, arg.param_shape, arg.src}); | |||||
arg.param_shape, {0}, arg.param_shape, arg.param_shape, | |||||
arg.src}); | |||||
} | } | ||||
} | } | ||||