GitOrigin-RevId: ac757ca307
release-1.7
@@ -53,6 +53,18 @@ std::vector<TestArg> get_args() { | |||
TensorShape{1, 3, 1, 1}, dtype::Float16()); | |||
} | |||
// case 3: 1 x 1 x 1 x C | |||
for (size_t i = 4; i < 257; i *= 4) { | |||
param::BN param; | |||
param.fwd_mode = param::BN::FwdMode::TRAINING; | |||
param.param_dim = param::BN::ParamDim::DIM_111C; | |||
args.emplace_back(param, TensorShape{3, i, i, 3}, | |||
TensorShape{1, 1, 1, 3}, dtype::Float32()); | |||
args.emplace_back(param, TensorShape{3, i, i, 3}, | |||
TensorShape{1, 1, 1, 3}, dtype::Float16()); | |||
} | |||
return args; | |||
} | |||
@@ -60,4 +72,4 @@ std::vector<TestArg> get_args() { | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
// vim: syntax=cpp.doxygen |
@@ -419,7 +419,7 @@ void CheckerHelper::copy_tensors_from_device(const TensorValueArray& dest, | |||
void CheckerHelper::check_tensors(const TensorValueArray& expected, | |||
const TensorValueArray& computed) { | |||
for (size_t i = 0; i < expected.size(); ++i) { | |||
if (expected[i].layout.ndim == 0) | |||
if (expected[i].layout.ndim == 0 || m_bypass.find(i) != m_bypass.end()) | |||
continue; | |||
if (m_allow_invalid_check) { | |||
MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID( | |||
@@ -69,6 +69,7 @@ protected: | |||
std::unordered_map<size_t, RNG*> m_rng; | |||
std::unordered_map<size_t, DType> m_dtype; | |||
std::unordered_map<size_t, TensorFormat> m_fmt; | |||
std::set<size_t> m_bypass; | |||
float_t m_epsilon = 1e-3, m_max_avg_error = 1e-3, | |||
m_max_avg_biased_error = 1e-3; | |||
float_t m_perf_check_threshold = -1; | |||
@@ -184,6 +185,10 @@ public: | |||
m_rng[idx] = rng; | |||
return *this; | |||
} | |||
//! Exclude the output tensor at index \p idx from correctness checking
//! (e.g. an opaque reserve-space buffer whose contents are backend-defined).
//! Returns *this to allow the usual fluent set_* chaining.
Checker& set_bypass(size_t idx) {
    m_bypass.emplace(idx);
    return *this;
}
//! max error of a single element | |||
Checker& set_epsilon(dt_float32 epsilon) { | |||
m_epsilon = epsilon; | |||
@@ -82,6 +82,15 @@ struct DeduceLayoutProxy<Opr, 8, true> { | |||
} | |||
}; | |||
template <typename Opr> | |||
struct DeduceLayoutProxy<Opr, 9, true> { | |||
static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) { | |||
megdnn_assert(layouts.size() == 9); | |||
opr->deduce_layout(layouts[0], layouts[1], layouts[2], layouts[3], | |||
layouts[4], layouts[5], layouts[6], layouts[7], | |||
layouts[8]); | |||
} | |||
}; | |||
} // namespace test | |||
} // namespace megdnn | |||
@@ -23,6 +23,23 @@ template <typename Opr, size_t Arity, bool has_workspace> | |||
struct ExecProxy; | |||
template <typename Opr> | |||
struct ExecProxy<Opr, 9, true> { | |||
WorkspaceWrapper W; | |||
void exec(Opr* opr, const TensorNDArray& tensors) { | |||
if (!W.valid()) { | |||
W = WorkspaceWrapper(opr->handle(), 0); | |||
} | |||
W.update(opr->get_workspace_in_bytes( | |||
tensors[0].layout, tensors[1].layout, tensors[2].layout, | |||
tensors[3].layout, tensors[4].layout, tensors[5].layout, | |||
tensors[6].layout, tensors[7].layout, tensors[8].layout)); | |||
opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], | |||
tensors[5], tensors[6], tensors[7], tensors[8], | |||
W.workspace()); | |||
} | |||
}; | |||
template <typename Opr> | |||
struct ExecProxy<Opr, 8, true> { | |||
WorkspaceWrapper W; | |||
void exec(Opr* opr, const TensorNDArray& tensors) { | |||
@@ -211,6 +211,10 @@ void IIDRNG::gen(const TensorND& tensor) { | |||
} | |||
return; | |||
} | |||
if (tensor.layout.dtype.enumv() == DTypeEnum::Byte) { | |||
memset(tensor.raw_ptr, 0, tensor.layout.access_bytes()); | |||
return; | |||
} | |||
megdnn_assert(0, "IIDRNG does not know how to generate value for DType %s", | |||
tensor.layout.dtype.name()); | |||
} | |||
@@ -6,10 +6,13 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
 * implied. | |||
*/ | |||
#include "test/cuda/fixture.h" | |||
#include "src/cuda/batch_normalization/opr_impl.h" | |||
#include "src/cuda/utils.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megdnn/oprs.h" | |||
#include "test/common/bn.h" | |||
@@ -21,15 +24,26 @@ | |||
namespace megdnn { | |||
namespace test { | |||
TEST_F(CUDA, BN_FORWARD) { | |||
TEST_F(CUDA, BN_FORWARD_BACKWARD) { | |||
using namespace batch_normalization; | |||
using cuda::cudnn_handle; | |||
using cuda::batch_normalization::BNTensorDescHolder; | |||
using cuda::batch_normalization::get_reserve_size; | |||
std::vector<TestArg> args = get_args(); | |||
Checker<BNForward> checker(handle_cuda()); | |||
Checker<BNBackward> checker_bwd(handle_cuda()); | |||
for (auto&& arg : args) { | |||
for (int i = 0; i < 8; ++i) { | |||
auto tensor_desc = BNTensorDescHolder({arg.src, arg.dtype}, arg.param.param_dim, | |||
arg.param.fwd_mode); | |||
auto reserve = get_reserve_size(cudnn_handle(handle_cuda()), tensor_desc); | |||
// Forward | |||
for (int i = 0; i < 9; ++i) { | |||
checker.set_dtype(i, dtype::Float32()); | |||
} | |||
checker.set_dtype(0, arg.dtype); | |||
checker.set_dtype(7, dtype::Byte()); | |||
checker.set_dtype(8, arg.dtype); | |||
checker.set_bypass(7); | |||
checker.set_epsilon(1e-3).set_param(arg.param); | |||
for (bool need_statistic : {false, true}) | |||
checker.exec({ | |||
@@ -40,27 +54,26 @@ TEST_F(CUDA, BN_FORWARD) { | |||
: TensorShape({0}), // mean | |||
need_statistic ? arg.param_shape | |||
: TensorShape({0}), // variance | |||
arg.param_shape, // batch_mean | |||
arg.param_shape, // batch_inv_variance | |||
{} // dst | |||
arg.param_shape, // batch_mean | |||
arg.param_shape, // batch_inv_variance | |||
{reserve}, // reserve | |||
arg.src // dst | |||
}); | |||
} | |||
} | |||
TEST_F(CUDA, BN_BACKWARD) { | |||
using namespace batch_normalization; | |||
std::vector<TestArg> args = get_args(); | |||
Checker<BNBackward> checker(handle_cuda()); | |||
for (auto&& arg : args) { | |||
for (int i = 0; i < 8; ++i) { | |||
checker.set_dtype(i, dtype::Float32()); | |||
// Backward | |||
for (int i = 0; i < 9; ++i) { | |||
checker_bwd.set_dtype(i, dtype::Float32()); | |||
} | |||
checker.set_dtype(0, arg.dtype) // x | |||
.set_dtype(1, arg.dtype) // dy | |||
.set_dtype(7, arg.dtype); // dx | |||
checker.set_epsilon(1e-3).set_param(arg.param).exec( | |||
checker_bwd | |||
.set_dtype(0, arg.dtype) // x | |||
.set_dtype(1, arg.dtype) // dy | |||
.set_dtype(5, dtype::Byte()) // reserve | |||
.set_dtype(8, arg.dtype) // dx | |||
.set_bypass(5); | |||
checker_bwd.set_epsilon(1e-3).set_param(arg.param).exec( | |||
{arg.src, arg.src, arg.param_shape, arg.param_shape, | |||
arg.param_shape, arg.param_shape, arg.param_shape, arg.src}); | |||
arg.param_shape, {reserve}, arg.param_shape, arg.param_shape, | |||
arg.src}); | |||
} | |||
} | |||
@@ -31,6 +31,7 @@ TEST_F(ROCM, BN_FORWARD) { | |||
checker.set_dtype(i, dtype::Float32()); | |||
} | |||
checker.set_dtype(0, arg.dtype); | |||
checker.set_dtype(8, arg.dtype); | |||
checker.set_epsilon(1e-3).set_param(arg.param); | |||
for (bool need_statistic : {false, true}) | |||
checker.exec({ | |||
@@ -43,7 +44,8 @@ TEST_F(ROCM, BN_FORWARD) { | |||
: TensorShape({0}), // variance | |||
arg.param_shape, // batch_mean | |||
arg.param_shape, // batch_inv_variance | |||
{} // dst | |||
{0}, // reserve | |||
arg.src // dst | |||
}); | |||
} | |||
} | |||
@@ -53,15 +55,16 @@ TEST_F(ROCM, BN_BACKWARD) { | |||
std::vector<TestArg> args = get_args(); | |||
Checker<BNBackward> checker(handle_rocm()); | |||
for (auto&& arg : args) { | |||
for (int i = 0; i < 8; ++i) { | |||
for (int i = 0; i < 9; ++i) { | |||
checker.set_dtype(i, dtype::Float32()); | |||
} | |||
checker.set_dtype(0, arg.dtype) // x | |||
.set_dtype(1, arg.dtype) // dy | |||
.set_dtype(7, arg.dtype); // dx | |||
.set_dtype(8, arg.dtype); // dx | |||
checker.set_epsilon(1e-3).set_param(arg.param).exec( | |||
{arg.src, arg.src, arg.param_shape, arg.param_shape, | |||
arg.param_shape, arg.param_shape, arg.param_shape, arg.src}); | |||
arg.param_shape, {0}, arg.param_shape, arg.param_shape, | |||
arg.src}); | |||
} | |||
} | |||