@@ -17,10 +17,11 @@ string(REPLACE "." ";" HIP_VERSION_LIST ${HIP_VERSION}) | |||
# Split the dotted HIP version into its major and minor components.
list(GET HIP_VERSION_LIST 0 HIP_VERSION_MAJOR)
list(GET HIP_VERSION_LIST 1 HIP_VERSION_MINOR)

# Any ROCm 3.x release is accepted; a minor version below 7 only produces a
# warning. The expansions are quoted so that an empty component cannot turn
# the condition into a syntax error.
if(NOT "${HIP_VERSION_MAJOR}" STREQUAL "3")
    message(FATAL_ERROR "ROCM version needed 3.x, Please update ROCM.")
elseif("${HIP_VERSION_MINOR}" LESS "7")
    message(WARNING "ROCM version 3.x which x(got ${HIP_VERSION_MINOR}) greater equal 7 is prefered.")
endif()

# Library names collected in MGE_ROCM_LIBS for the ROCm build.
set(MGE_ROCM_LIBS OpenCL amdhip64 MIOpen rocblas rocrand)
@@ -37,7 +38,7 @@ find_path(MIOPEN_LIBRARY_DIR | |||
DOC "Path to MIOPEN library directory." ) | |||
# Abort configuration when find_path() left MIOPEN_LIBRARY_DIR unset or
# *-NOTFOUND. A single FATAL_ERROR is enough: it stops processing, so a second
# message() after it could never run.
if(NOT MIOPEN_LIBRARY_DIR)
    message(FATAL_ERROR "Can not find MIOPEN Library")
endif()

# Resolve the miopen include directory that sits next to the HIP root.
get_filename_component(__found_miopen_include "${HIP_ROOT_DIR}/../miopen/include" REALPATH)
@@ -48,7 +49,7 @@ find_path(MIOPEN_INCLUDE_DIR | |||
DOC "Path to MIOPEN include directory." ) | |||
# Abort configuration when find_path() left MIOPEN_INCLUDE_DIR unset or
# *-NOTFOUND. The message now spells the library name correctly and is emitted
# once (FATAL_ERROR stops processing, so a repeated message() never runs).
if(NOT MIOPEN_INCLUDE_DIR)
    message(FATAL_ERROR "Can not find MIOPEN INCLUDE")
endif()
#rocblas | |||
@@ -60,7 +61,7 @@ find_path(ROCBLAS_LIBRARY_DIR | |||
DOC "Path to ROCBLAS library directory." ) | |||
# Abort configuration when find_path() left ROCBLAS_LIBRARY_DIR unset or
# *-NOTFOUND. A single FATAL_ERROR is enough: it stops processing, so a second
# message() after it could never run.
if(NOT ROCBLAS_LIBRARY_DIR)
    message(FATAL_ERROR "Can not find ROCBLAS Library")
endif()

# Resolve the rocblas include directory that sits next to the HIP root.
get_filename_component(__found_rocblas_include "${HIP_ROOT_DIR}/../rocblas/include" REALPATH)
@@ -71,7 +72,7 @@ find_path(ROCBLAS_INCLUDE_DIR | |||
DOC "Path to ROCBLAS include directory." ) | |||
# Abort configuration when find_path() left ROCBLAS_INCLUDE_DIR unset or
# *-NOTFOUND. Emitted once; FATAL_ERROR stops processing immediately.
if(NOT ROCBLAS_INCLUDE_DIR)
    message(FATAL_ERROR "Can not find ROCBLAS INCLUDE")
endif()
#rocrand | |||
@@ -83,7 +84,7 @@ find_path(ROCRAND_LIBRARY_DIR | |||
DOC "Path to ROCRAND library directory." ) | |||
# Abort configuration when find_path() left ROCRAND_LIBRARY_DIR unset or
# *-NOTFOUND. A single FATAL_ERROR is enough: it stops processing, so a second
# message() after it could never run.
if(NOT ROCRAND_LIBRARY_DIR)
    message(FATAL_ERROR "Can not find ROCRAND Library")
endif()

# Resolve the rocrand include directory that sits next to the HIP root.
get_filename_component(__found_rocrand_include "${HIP_ROOT_DIR}/../rocrand/include" REALPATH)
@@ -94,7 +95,7 @@ find_path(ROCRAND_INCLUDE_DIR | |||
DOC "Path to ROCRAND include directory." ) | |||
# Abort configuration when find_path() left ROCRAND_INCLUDE_DIR unset or
# *-NOTFOUND. Emitted once; FATAL_ERROR stops processing immediately.
if(NOT ROCRAND_INCLUDE_DIR)
    message(FATAL_ERROR "Can not find ROCRAND INCLUDE")
endif()
@@ -0,0 +1,116 @@ | |||
/** | |||
* \file dnn/src/rocm/batch_normalization/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||
*/ | |||
#include "./opr_impl.h" | |||
#include "src/rocm/utils.h" | |||
namespace megdnn { | |||
namespace rocm { | |||
namespace batch_normalization { | |||
void BNTensorDescHolder::setup(const TensorLayout& x, | |||
const ParamDim& param_dim) { | |||
TensorShape xy_shape(x); | |||
switch (param_dim) { | |||
case ParamDim::DIM_11HW: | |||
// xy: N, C, H, W --> (N*C), 1, H, W | |||
xy_shape.shape[0] = xy_shape.shape[0] * xy_shape.shape[1]; | |||
xy_shape.shape[1] = 1; | |||
bn_mode = miopenBNPerActivation; | |||
break; | |||
case ParamDim::DIM_1CHW: | |||
bn_mode = miopenBNPerActivation; | |||
break; | |||
case ParamDim::DIM_1C11: | |||
bn_mode = miopenBNSpatial; | |||
break; | |||
default: | |||
megdnn_throw(megdnn_mangle( | |||
"Unknown param dim type of batch normalization.")); | |||
} | |||
xy_desc.set(TensorLayout(xy_shape, x.dtype)); | |||
param_desc.set(xy_desc.desc, bn_mode); | |||
} | |||
} // namespace batch_normalization | |||
/**
 * Run batch-normalization forward through MIOpen.
 *
 * After check_exec() validates the layouts and workspace size, the
 * descriptors are rebuilt from \p src and the call is dispatched on
 * \c m_param.fwd_mode to either the training or the inference entry point.
 * Any other forward mode is reported through megdnn_throw.
 */
void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
                         _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean,
                         _megdnn_tensor_out variance,
                         _megdnn_tensor_out batch_mean,
                         _megdnn_tensor_out batch_inv_variance,
                         _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout,
               variance.layout, batch_mean.layout, batch_inv_variance.layout,
               dst.layout, workspace.size);

    auto miopen_hdl = concrete_handle(this->handle())->miopen_handle();
    m_tensor_desc.setup(src.layout, m_param.param_dim);

    // Read the descriptors only after setup() has refreshed them.
    auto& io_desc = m_tensor_desc.xy_desc.desc;        // used for both x and y
    auto& stat_desc = m_tensor_desc.param_desc.desc;   // bnScaleBiasMeanVarDesc
    const auto mode = m_tensor_desc.bn_mode;

    float alpha = 1.0f;
    float beta = 0.0f;

    using FwdMode = param::BN::FwdMode;
    if (m_param.fwd_mode == FwdMode::TRAINING) {
        miopen_check(miopenBatchNormalizationForwardTraining(
                miopen_hdl, mode, &alpha, &beta,
                io_desc, src.raw_ptr,  // xDesc, x
                io_desc, dst.raw_ptr,  // yDesc, y
                stat_desc, bn_scale.raw_ptr, bn_bias.raw_ptr,
                m_param.avg_factor, mean.raw_ptr, variance.raw_ptr,
                m_param.epsilon, batch_mean.raw_ptr,
                batch_inv_variance.raw_ptr));
    } else if (m_param.fwd_mode == FwdMode::INFERENCE) {
        miopen_check(miopenBatchNormalizationForwardInference(
                miopen_hdl, mode, &alpha, &beta,
                io_desc, src.raw_ptr,  // xDesc, x
                io_desc, dst.raw_ptr,  // yDesc, y
                stat_desc, bn_scale.raw_ptr, bn_bias.raw_ptr, mean.raw_ptr,
                variance.raw_ptr, m_param.epsilon));
    } else {
        megdnn_throw(megdnn_mangle(
                "Unknown forward mode type of batch normalization."));
    }
}
/**
 * Run batch-normalization backward through MIOpen.
 *
 * After check_exec() validates the layouts and workspace size, the
 * descriptors are rebuilt from \p x and a single
 * miopenBatchNormalizationBackward call writes \p dx, \p d_bn_scale and
 * \p d_bn_bias.
 */
void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy,
                          _megdnn_tensor_in saved_batch_mean,
                          _megdnn_tensor_in saved_batch_inv_variance,
                          _megdnn_tensor_in bn_scale,
                          _megdnn_tensor_out d_bn_scale,
                          _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx,
                          _megdnn_workspace workspace) {
    check_exec(x.layout, dy.layout, saved_batch_mean.layout,
               saved_batch_inv_variance.layout, bn_scale.layout,
               d_bn_scale.layout, d_bn_bias.layout, dx.layout, workspace.size);

    auto miopen_hdl = concrete_handle(this->handle())->miopen_handle();
    m_tensor_desc.setup(x.layout, m_param.param_dim);

    // Read the descriptors only after setup() has refreshed them.
    auto& io_desc = m_tensor_desc.xy_desc.desc;       // shared by x, dy and dx
    auto& stat_desc = m_tensor_desc.param_desc.desc;  // scale/bias descriptor
    const auto mode = m_tensor_desc.bn_mode;

    float alpha = 1.0f;
    float beta = 0.0f;

    // The same alpha/beta pair is supplied for both scaling-factor slots.
    miopen_check(miopenBatchNormalizationBackward(
            miopen_hdl, mode, &alpha, &beta, &alpha, &beta,
            io_desc, x.raw_ptr,
            io_desc, dy.raw_ptr,
            io_desc, dx.raw_ptr,
            stat_desc, bn_scale.raw_ptr, d_bn_scale.raw_ptr,
            d_bn_bias.raw_ptr, m_param.epsilon, saved_batch_mean.raw_ptr,
            saved_batch_inv_variance.raw_ptr));
}
} // namespace rocm | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,80 @@ | |||
/** | |||
* \file dnn/src/rocm/batch_normalization/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/rocm/miopen_wrapper.h" | |||
namespace megdnn { | |||
namespace rocm { | |||
namespace batch_normalization { | |||
struct BNTensorDescHolder { | |||
using ParamDim = param::BN::ParamDim; | |||
TensorDesc xy_desc; | |||
BNParamDesc param_desc; | |||
miopenBatchNormMode_t bn_mode; | |||
void setup(const TensorLayout& x, const ParamDim& param_dim); | |||
}; | |||
} // namespace batch_normalization | |||
class BNForwardImpl final : public BNForward { | |||
public: | |||
using BNForward::BNForward; | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | |||
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, | |||
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
private: | |||
batch_normalization::BNTensorDescHolder m_tensor_desc; | |||
}; | |||
class BNBackwardImpl final : public BNBackward { | |||
public: | |||
using BNBackward::BNBackward; | |||
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | |||
_megdnn_tensor_in saved_batch_mean, | |||
_megdnn_tensor_in saved_batch_inv_variance, | |||
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, | |||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
private: | |||
batch_normalization::BNTensorDescHolder m_tensor_desc; | |||
}; | |||
} // namespace rocm | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -35,6 +35,7 @@ | |||
#include "src/rocm/linspace/opr_impl.h" | |||
#include "src/rocm/argmxx/opr_impl.h" | |||
#include "src/rocm/sleep/opr_impl.h" | |||
#include "src/rocm/batch_normalization/opr_impl.h" | |||
#include <miopen/version.h> | |||
#include <hip/hip_version.h> | |||
@@ -171,6 +172,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace); | |||
// One macro invocation per operator; the last two entries cover the
// batch-normalization forward and backward operators.
// NOTE(review): presumably the macro specializes operator creation for the
// matching *Impl class -- confirm against the macro definition.
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward); | |||
#pragma GCC diagnostic push | |||
#pragma GCC diagnostic ignored "-Wpragmas" | |||
@@ -0,0 +1,71 @@ | |||
/** | |||
* \file dnn/test/rocm/bn.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||
*/ | |||
#include "test/rocm/fixture.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megdnn/oprs.h" | |||
#include "test/common/bn.h" | |||
#include "test/common/checker.h" | |||
#include "test/common/rng.h" | |||
#include "test/common/tensor.h" | |||
#include "test/common/workspace_wrapper.h" | |||
namespace megdnn { | |||
namespace test { | |||
// Checks BNForward on the ROCm handle for every argument set returned by
// get_args(), once with empty ({0}) mean/variance shapes and once with them
// shaped like the parameters.
TEST_F(ROCM, BN_FORWARD) {
    using namespace batch_normalization;
    std::vector<TestArg> args = get_args();
    Checker<BNForward> checker(handle_rocm());

    for (auto&& arg : args) {
        // Every tensor defaults to Float32; only the source follows the case.
        for (int idx = 0; idx < 8; ++idx) {
            checker.set_dtype(idx, dtype::Float32());
        }
        checker.set_dtype(0, arg.dtype).set_epsilon(1e-3).set_param(arg.param);

        for (bool need_statistic : {false, true}) {
            const auto stat_shape =
                    need_statistic ? arg.param_shape : TensorShape({0});
            checker.exec({
                    arg.src,
                    arg.param_shape,  // bn_scale
                    arg.param_shape,  // bn_bias
                    stat_shape,       // mean
                    stat_shape,       // variance
                    arg.param_shape,  // batch_mean
                    arg.param_shape,  // batch_inv_variance
                    {}                // dst
            });
        }
    }
}
// Checks BNBackward on the ROCm handle for every argument set returned by
// get_args(); x, dy and dx use the dtype of the case, everything else Float32.
TEST_F(ROCM, BN_BACKWARD) {
    using namespace batch_normalization;
    std::vector<TestArg> args = get_args();
    Checker<BNBackward> checker(handle_rocm());

    for (auto&& arg : args) {
        for (int idx = 0; idx < 8; ++idx) {
            checker.set_dtype(idx, dtype::Float32());
        }
        checker.set_dtype(0, arg.dtype);  // x
        checker.set_dtype(1, arg.dtype);  // dy
        checker.set_dtype(7, arg.dtype);  // dx

        const auto& data_shape = arg.src;
        const auto& stat_shape = arg.param_shape;

        checker.set_epsilon(1e-3);
        checker.set_param(arg.param);
        checker.exec({data_shape, data_shape, stat_shape, stat_shape,
                      stat_shape, stat_shape, stat_shape, data_shape});
    }
}
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |