@@ -17,10 +17,11 @@ string(REPLACE "." ";" HIP_VERSION_LIST ${HIP_VERSION}) | |||||
list(GET HIP_VERSION_LIST 0 HIP_VERSION_MAJOR) | list(GET HIP_VERSION_LIST 0 HIP_VERSION_MAJOR) | ||||
list(GET HIP_VERSION_LIST 1 HIP_VERSION_MINOR) | list(GET HIP_VERSION_LIST 1 HIP_VERSION_MINOR) | ||||
if (NOT ${HIP_VERSION_MAJOR} STREQUAL "3") | if (NOT ${HIP_VERSION_MAJOR} STREQUAL "3") | ||||
message(FATAL_ERROR "ROCM version needed 3.7.Please update ROCM.") | |||||
endif() | |||||
if (NOT ${HIP_VERSION_MINOR} STREQUAL "7") | |||||
message(FATAL_ERROR "ROCM version needed 3.7.Please update ROCM.") | |||||
message(FATAL_ERROR "ROCM version needed 3.x, Please update ROCM.") | |||||
else() | |||||
if (${HIP_VERSION_MINOR} LESS "7") | |||||
message(WARNING "ROCM version 3.x which x(got ${HIP_VERSION_MINOR}) greater equal 7 is prefered.") | |||||
endif() | |||||
endif() | endif() | ||||
set(MGE_ROCM_LIBS OpenCL amdhip64 MIOpen rocblas rocrand) | set(MGE_ROCM_LIBS OpenCL amdhip64 MIOpen rocblas rocrand) | ||||
@@ -37,7 +38,7 @@ find_path(MIOPEN_LIBRARY_DIR | |||||
DOC "Path to MIOPEN library directory." ) | DOC "Path to MIOPEN library directory." ) | ||||
if(MIOPEN_LIBRARY_DIR STREQUAL "MIOPEN_LIBRARY_DIR-NOTFOUND") | if(MIOPEN_LIBRARY_DIR STREQUAL "MIOPEN_LIBRARY_DIR-NOTFOUND") | ||||
message(FATAL_ERROR "Can not find MIOPEN Library") | |||||
message(FATAL_ERROR "Can not find MIOPEN Library") | |||||
endif() | endif() | ||||
get_filename_component(__found_miopen_include ${HIP_ROOT_DIR}/../miopen/include REALPATH) | get_filename_component(__found_miopen_include ${HIP_ROOT_DIR}/../miopen/include REALPATH) | ||||
@@ -48,7 +49,7 @@ find_path(MIOPEN_INCLUDE_DIR | |||||
DOC "Path to MIOPEN include directory." ) | DOC "Path to MIOPEN include directory." ) | ||||
if(MIOPEN_INCLUDE_DIR STREQUAL "MIOPEN_INCLUDE_DIR-NOTFOUND") | if(MIOPEN_INCLUDE_DIR STREQUAL "MIOPEN_INCLUDE_DIR-NOTFOUND") | ||||
message(FATAL_ERROR "Can not find MIOEPN INCLUDE") | |||||
message(FATAL_ERROR "Can not find MIOEPN INCLUDE") | |||||
endif() | endif() | ||||
#rocblas | #rocblas | ||||
@@ -60,7 +61,7 @@ find_path(ROCBLAS_LIBRARY_DIR | |||||
DOC "Path to ROCBLAS library directory." ) | DOC "Path to ROCBLAS library directory." ) | ||||
if(ROCBLAS_LIBRARY_DIR STREQUAL "ROCBLAS_LIBRARY_DIR-NOTFOUND") | if(ROCBLAS_LIBRARY_DIR STREQUAL "ROCBLAS_LIBRARY_DIR-NOTFOUND") | ||||
message(FATAL_ERROR "Can not find ROCBLAS Library") | |||||
message(FATAL_ERROR "Can not find ROCBLAS Library") | |||||
endif() | endif() | ||||
get_filename_component(__found_rocblas_include ${HIP_ROOT_DIR}/../rocblas/include REALPATH) | get_filename_component(__found_rocblas_include ${HIP_ROOT_DIR}/../rocblas/include REALPATH) | ||||
@@ -71,7 +72,7 @@ find_path(ROCBLAS_INCLUDE_DIR | |||||
DOC "Path to ROCBLAS include directory." ) | DOC "Path to ROCBLAS include directory." ) | ||||
if(ROCBLAS_INCLUDE_DIR STREQUAL "ROCBLAS_INCLUDE_DIR-NOTFOUND") | if(ROCBLAS_INCLUDE_DIR STREQUAL "ROCBLAS_INCLUDE_DIR-NOTFOUND") | ||||
message(FATAL_ERROR "Can not find ROCBLAS INCLUDE") | |||||
message(FATAL_ERROR "Can not find ROCBLAS INCLUDE") | |||||
endif() | endif() | ||||
#rocrand | #rocrand | ||||
@@ -83,7 +84,7 @@ find_path(ROCRAND_LIBRARY_DIR | |||||
DOC "Path to ROCRAND library directory." ) | DOC "Path to ROCRAND library directory." ) | ||||
if(ROCRAND_LIBRARY_DIR STREQUAL "ROCRAND_LIBRARY_DIR-NOTFOUND") | if(ROCRAND_LIBRARY_DIR STREQUAL "ROCRAND_LIBRARY_DIR-NOTFOUND") | ||||
message(FATAL_ERROR "Can not find ROCRAND Library") | |||||
message(FATAL_ERROR "Can not find ROCRAND Library") | |||||
endif() | endif() | ||||
get_filename_component(__found_rocrand_include ${HIP_ROOT_DIR}/../rocrand/include REALPATH) | get_filename_component(__found_rocrand_include ${HIP_ROOT_DIR}/../rocrand/include REALPATH) | ||||
@@ -94,7 +95,7 @@ find_path(ROCRAND_INCLUDE_DIR | |||||
DOC "Path to ROCRAND include directory." ) | DOC "Path to ROCRAND include directory." ) | ||||
if(ROCRAND_INCLUDE_DIR STREQUAL "ROCRAND_INCLUDE_DIR-NOTFOUND") | if(ROCRAND_INCLUDE_DIR STREQUAL "ROCRAND_INCLUDE_DIR-NOTFOUND") | ||||
message(FATAL_ERROR "Can not find ROCRAND INCLUDE") | |||||
message(FATAL_ERROR "Can not find ROCRAND INCLUDE") | |||||
endif() | endif() | ||||
@@ -0,0 +1,116 @@ | |||||
/** | |||||
* \file dnn/src/rocm/batch_normalization/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "./opr_impl.h" | |||||
#include "src/rocm/utils.h" | |||||
namespace megdnn { | |||||
namespace rocm { | |||||
namespace batch_normalization { | |||||
void BNTensorDescHolder::setup(const TensorLayout& x, | |||||
const ParamDim& param_dim) { | |||||
TensorShape xy_shape(x); | |||||
switch (param_dim) { | |||||
case ParamDim::DIM_11HW: | |||||
// xy: N, C, H, W --> (N*C), 1, H, W | |||||
xy_shape.shape[0] = xy_shape.shape[0] * xy_shape.shape[1]; | |||||
xy_shape.shape[1] = 1; | |||||
bn_mode = miopenBNPerActivation; | |||||
break; | |||||
case ParamDim::DIM_1CHW: | |||||
bn_mode = miopenBNPerActivation; | |||||
break; | |||||
case ParamDim::DIM_1C11: | |||||
bn_mode = miopenBNSpatial; | |||||
break; | |||||
default: | |||||
megdnn_throw(megdnn_mangle( | |||||
"Unknown param dim type of batch normalization.")); | |||||
} | |||||
xy_desc.set(TensorLayout(xy_shape, x.dtype)); | |||||
param_desc.set(xy_desc.desc, bn_mode); | |||||
} | |||||
} // namespace batch_normalization | |||||
void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||||
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | |||||
_megdnn_tensor_out variance, | |||||
_megdnn_tensor_out batch_mean, | |||||
_megdnn_tensor_out batch_inv_variance, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||||
check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, | |||||
variance.layout, batch_mean.layout, batch_inv_variance.layout, | |||||
dst.layout, workspace.size); | |||||
auto handle = concrete_handle(this->handle())->miopen_handle(); | |||||
m_tensor_desc.setup(src.layout, m_param.param_dim); | |||||
float alpha = 1.0f, beta = 0.0f; | |||||
switch (m_param.fwd_mode) { | |||||
case param::BN::FwdMode::TRAINING: | |||||
miopen_check(miopenBatchNormalizationForwardTraining( | |||||
handle, m_tensor_desc.bn_mode, &alpha, &beta, | |||||
m_tensor_desc.xy_desc.desc, // xDesc | |||||
src.raw_ptr, // x | |||||
m_tensor_desc.xy_desc.desc, // yDesc | |||||
dst.raw_ptr, // y | |||||
m_tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||||
bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, | |||||
mean.raw_ptr, variance.raw_ptr, m_param.epsilon, | |||||
batch_mean.raw_ptr, batch_inv_variance.raw_ptr)); | |||||
break; | |||||
case param::BN::FwdMode::INFERENCE: | |||||
miopen_check(miopenBatchNormalizationForwardInference( | |||||
handle, m_tensor_desc.bn_mode, &alpha, &beta, | |||||
m_tensor_desc.xy_desc.desc, src.raw_ptr, | |||||
m_tensor_desc.xy_desc.desc, dst.raw_ptr, | |||||
m_tensor_desc.param_desc.desc, bn_scale.raw_ptr, | |||||
bn_bias.raw_ptr, mean.raw_ptr, variance.raw_ptr, | |||||
m_param.epsilon)); | |||||
break; | |||||
default: | |||||
megdnn_throw(megdnn_mangle( | |||||
"Unknown forward mode type of batch normalization.")); | |||||
} | |||||
} | |||||
void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | |||||
_megdnn_tensor_in saved_batch_mean, | |||||
_megdnn_tensor_in saved_batch_inv_variance, | |||||
_megdnn_tensor_in bn_scale, | |||||
_megdnn_tensor_out d_bn_scale, | |||||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | |||||
_megdnn_workspace workspace) { | |||||
check_exec(x.layout, dy.layout, saved_batch_mean.layout, | |||||
saved_batch_inv_variance.layout, bn_scale.layout, | |||||
d_bn_scale.layout, d_bn_bias.layout, dx.layout, workspace.size); | |||||
auto handle = concrete_handle(this->handle())->miopen_handle(); | |||||
m_tensor_desc.setup(x.layout, m_param.param_dim); | |||||
float alpha = 1.0, beta = 0.0; | |||||
miopen_check(miopenBatchNormalizationBackward( | |||||
handle, m_tensor_desc.bn_mode, &alpha, &beta, &alpha, &beta, | |||||
m_tensor_desc.xy_desc.desc, x.raw_ptr, m_tensor_desc.xy_desc.desc, | |||||
dy.raw_ptr, m_tensor_desc.xy_desc.desc, dx.raw_ptr, | |||||
m_tensor_desc.param_desc.desc, bn_scale.raw_ptr, d_bn_scale.raw_ptr, | |||||
d_bn_bias.raw_ptr, m_param.epsilon, saved_batch_mean.raw_ptr, | |||||
saved_batch_inv_variance.raw_ptr)); | |||||
} | |||||
} // namespace rocm | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,80 @@ | |||||
/** | |||||
* \file dnn/src/rocm/batch_normalization/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "src/rocm/miopen_wrapper.h" | |||||
namespace megdnn { | |||||
namespace rocm { | |||||
namespace batch_normalization { | |||||
struct BNTensorDescHolder { | |||||
using ParamDim = param::BN::ParamDim; | |||||
TensorDesc xy_desc; | |||||
BNParamDesc param_desc; | |||||
miopenBatchNormMode_t bn_mode; | |||||
void setup(const TensorLayout& x, const ParamDim& param_dim); | |||||
}; | |||||
} // namespace batch_normalization | |||||
class BNForwardImpl final : public BNForward { | |||||
public: | |||||
using BNForward::BNForward; | |||||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||||
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | |||||
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, | |||||
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
private: | |||||
batch_normalization::BNTensorDescHolder m_tensor_desc; | |||||
}; | |||||
class BNBackwardImpl final : public BNBackward { | |||||
public: | |||||
using BNBackward::BNBackward; | |||||
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | |||||
_megdnn_tensor_in saved_batch_mean, | |||||
_megdnn_tensor_in saved_batch_inv_variance, | |||||
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, | |||||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
private: | |||||
batch_normalization::BNTensorDescHolder m_tensor_desc; | |||||
}; | |||||
} // namespace rocm | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -35,6 +35,7 @@ | |||||
#include "src/rocm/linspace/opr_impl.h" | #include "src/rocm/linspace/opr_impl.h" | ||||
#include "src/rocm/argmxx/opr_impl.h" | #include "src/rocm/argmxx/opr_impl.h" | ||||
#include "src/rocm/sleep/opr_impl.h" | #include "src/rocm/sleep/opr_impl.h" | ||||
#include "src/rocm/batch_normalization/opr_impl.h" | |||||
#include <miopen/version.h> | #include <miopen/version.h> | ||||
#include <hip/hip_version.h> | #include <hip/hip_version.h> | ||||
@@ -171,6 +172,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward); | |||||
#pragma GCC diagnostic push | #pragma GCC diagnostic push | ||||
#pragma GCC diagnostic ignored "-Wpragmas" | #pragma GCC diagnostic ignored "-Wpragmas" | ||||
@@ -0,0 +1,71 @@ | |||||
/** | |||||
* \file dnn/test/rocm/bn.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "test/rocm/fixture.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
#include "megdnn/oprs.h" | |||||
#include "test/common/bn.h" | |||||
#include "test/common/checker.h" | |||||
#include "test/common/rng.h" | |||||
#include "test/common/tensor.h" | |||||
#include "test/common/workspace_wrapper.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
TEST_F(ROCM, BN_FORWARD) { | |||||
using namespace batch_normalization; | |||||
std::vector<TestArg> args = get_args(); | |||||
Checker<BNForward> checker(handle_rocm()); | |||||
for (auto&& arg : args) { | |||||
for (int i = 0; i < 8; ++i) { | |||||
checker.set_dtype(i, dtype::Float32()); | |||||
} | |||||
checker.set_dtype(0, arg.dtype); | |||||
checker.set_epsilon(1e-3).set_param(arg.param); | |||||
for (bool need_statistic : {false, true}) | |||||
checker.exec({ | |||||
arg.src, | |||||
arg.param_shape, // bn_scale | |||||
arg.param_shape, // bn_bias | |||||
need_statistic ? arg.param_shape | |||||
: TensorShape({0}), // mean | |||||
need_statistic ? arg.param_shape | |||||
: TensorShape({0}), // variance | |||||
arg.param_shape, // batch_mean | |||||
arg.param_shape, // batch_inv_variance | |||||
{} // dst | |||||
}); | |||||
} | |||||
} | |||||
TEST_F(ROCM, BN_BACKWARD) { | |||||
using namespace batch_normalization; | |||||
std::vector<TestArg> args = get_args(); | |||||
Checker<BNBackward> checker(handle_rocm()); | |||||
for (auto&& arg : args) { | |||||
for (int i = 0; i < 8; ++i) { | |||||
checker.set_dtype(i, dtype::Float32()); | |||||
} | |||||
checker.set_dtype(0, arg.dtype) // x | |||||
.set_dtype(1, arg.dtype) // dy | |||||
.set_dtype(7, arg.dtype); // dx | |||||
checker.set_epsilon(1e-3).set_param(arg.param).exec( | |||||
{arg.src, arg.src, arg.param_shape, arg.param_shape, | |||||
arg.param_shape, arg.param_shape, arg.param_shape, arg.src}); | |||||
} | |||||
} | |||||
} // namespace test | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |