Browse Source

feat(dnn/naive/norm,-dnn/cuda/norm,-dnn/test/norm): add norm dnn opr,

fwd only

GitOrigin-RevId: 989474168d
HuaHua404-patch-1
Megvii Engine Team 3 years ago
parent
commit
b55942a94d
17 changed files with 1464 additions and 1 deletions
  1. +29
    -0
      dnn/include/megdnn/oprs/general.h
  2. +8
    -0
      dnn/scripts/opr_param_defs.py
  3. +2
    -1
      dnn/src/common/handle_impl.h
  4. +43
    -0
      dnn/src/common/norm.cpp
  5. +1
    -0
      dnn/src/common/opr_trait.h
  6. +2
    -0
      dnn/src/cuda/handle_create.cpp
  7. +28
    -0
      dnn/src/cuda/norm/helper.cu
  8. +226
    -0
      dnn/src/cuda/norm/helper.h
  9. +180
    -0
      dnn/src/cuda/norm/opr_impl.cpp
  10. +25
    -0
      dnn/src/cuda/norm/opr_impl.h
  11. +1
    -0
      dnn/src/naive/handle.cpp
  12. +152
    -0
      dnn/src/naive/norm/helper.h
  13. +197
    -0
      dnn/src/naive/norm/opr_impl.cpp
  14. +23
    -0
      dnn/src/naive/norm/opr_impl.h
  15. +19
    -0
      dnn/test/common/norm.h
  16. +291
    -0
      dnn/test/cuda/norm.cpp
  17. +237
    -0
      dnn/test/naive/norm.cpp

+ 29
- 0
dnn/include/megdnn/oprs/general.h View File

@@ -1475,6 +1475,35 @@ protected:

using LAMB = LAMBUpdate;

class NormBase : public OperatorBase {
DEF_OPR_PARAM(Norm); // package norm params in Norm keyword from py declaration
DEF_OPR_IMPL(NormBase, OperatorBase, 1, 1); // constructor and static members

public:
virtual void deduce_layout(const TensorLayout& src, TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};

class NormForward : public NormBase {
DEF_OPR_IMPL(NormForward, NormBase, 1, 1);
using Mode = Param::Mode;

public:
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;
};
using Norm = NormForward;

} // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"


+ 8
- 0
dnn/scripts/opr_param_defs.py View File

@@ -1277,3 +1277,11 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
add_fields('bool', Doc('bias_correction', 'whether correct bias'), 'true').
add_fields('bool', Doc('always_adapt', 'apply adaptive lr to 0.0'), 'false')
)
(pdef("Norm").
add_enum('Mode',
Doc('P_NORM=0', 'calculate p-norm, parameter p would be ignored in other mode'),
Doc('INF_NORM=1', 'infinite norm'),
Doc('NEG_INF_NORM=2', 'negative infinite norm'), name_field="mode").
add_fields('float32', Doc('p', 'the order of norm'), '2').
add_fields('int32', Doc('dim', 'which dim the norm performed along'), '-1'),
)

+ 2
- 1
dnn/src/common/handle_impl.h View File

@@ -212,7 +212,8 @@ private:
cb(LAMBUpdate) \
cb(LSTMBackward) \
cb(SoftmaxForward) \
cb(SoftmaxBackward)
cb(SoftmaxBackward) \
cb(NormForward)
// clang-format on

/*!


+ 43
- 0
dnn/src/common/norm.cpp View File

@@ -0,0 +1,43 @@
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {
void NormForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
megdnn_assert(
param().dim > -1 && param().dim < static_cast<dt_int32>(src.ndim),
"dim params must be passed and cannot be -1.");

SmallVector<size_t> shapeList;
for (size_t i = 0; i < src.ndim; ++i) {
if (static_cast<dt_int32>(i) != param().dim) {
shapeList.append(1, static_cast<size_t>(src.shape[i]));
} else {
shapeList.append(1, static_cast<size_t>(1));
}
}
dst = TensorLayout{TensorShape(shapeList), src.dtype};
return;
}

void NormBase::check_exec(
const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
megdnn_assert_eq_dtype(src, dst);

#if !MEGDNN_DISABLE_FLOAT16
megdnn_assert(
src.dtype.enumv() == DTypeEnum::Float16 ||
src.dtype.enumv() == DTypeEnum::Float32,
"Float16 or Float32 is only supported.");
#else
megdnn_assert(
src.dtype.enumv() == DTypeEnum::Float32, "Float32 is only supported.");
#endif

TensorLayout dst_expected;
deduce_layout(src, dst_expected);
megdnn_assert_eq_layout(dst_expected, dst);

auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
} // namespace megdnn

+ 1
- 0
dnn/src/common/opr_trait.h View File

@@ -16,6 +16,7 @@ struct OprTrait {};
static const bool can_deduce_layout = CanDeduceLayout; \
}

DEF(Norm, 2, true, true);
DEF(Padding, 2, false, true);
DEF(PaddingBackward, 2, false, false);
DEF(ConvolutionForward, 3, true, true);


+ 2
- 0
dnn/src/cuda/handle_create.cpp View File

@@ -47,6 +47,7 @@
#include "src/cuda/matrix_mul/opr_impl.h"
#include "src/cuda/max_tensor_diff/opr_impl.h"
#include "src/cuda/mesh_indexing/opr_impl.h"
#include "src/cuda/norm/opr_impl.h"
#include "src/cuda/padding/opr_impl.h"
#include "src/cuda/param_pack/opr_impl.h"
#include "src/cuda/pooling/opr_impl.h"
@@ -216,6 +217,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(NormForward);

template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {


+ 28
- 0
dnn/src/cuda/norm/helper.cu View File

@@ -0,0 +1,28 @@


#include "helper.h"
#include "megdnn/dtype.h"
#include "src/cuda/reduce_helper.cuh"

namespace megdnn {
namespace cuda {

using namespace device_reduce;
#define COMMA ,

INST_REDUCE(NormOp<dt_float32 COMMA dt_float32 COMMA dt_float32>, false);
INST_REDUCE(NormOp<dt_float16 COMMA dt_float16 COMMA dt_float16>, false);

INST_REDUCE(NormZeroOp<dt_float32 COMMA dt_float32 COMMA dt_float32>, false);
INST_REDUCE(NormZeroOp<dt_float16 COMMA dt_float16 COMMA dt_float16>, false);

INST_REDUCE(NormOneOp<dt_float32 COMMA dt_float32 COMMA dt_float32>, false);
INST_REDUCE(NormOneOp<dt_float16 COMMA dt_float16 COMMA dt_float16>, false);

INST_REDUCE(NormTwoOp<dt_float32 COMMA dt_float32 COMMA dt_float32>, false);
INST_REDUCE(NormTwoOp<dt_float16 COMMA dt_float16 COMMA dt_float16>, false);

#undef COMMA

} // namespace cuda
} // namespace megdnn

+ 226
- 0
dnn/src/cuda/norm/helper.h View File

@@ -0,0 +1,226 @@
#pragma once
#include "megdnn/dtype.h"

#if MEGDNN_CC_HOST
#include "megdnn/basic_types.h"
#endif

namespace megdnn {
namespace device_reduce {

template <typename src_ctype, typename dst_ctype, typename wtype_>
struct NormOp;

template <>
struct NormOp<dt_float32, dt_float32, dt_float32> {
typedef dt_float32 wtype;
typedef dt_float32 src_ctype;
typedef dt_float32 dst_ctype;
typedef wtype p_type;
const wtype INIT;

src_ctype* src;
dst_ctype* dst;
const size_t B;
const p_type p;

MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
return powf(fabsf(src[idx]), p);
}
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
dst[idx] = powf(val, 1.f / p);
}
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
return lhs + rhs;
}
MEGDNN_HOST MEGDNN_DEVICE NormOp(src_ctype* src, dst_ctype* dst, size_t B, p_type p)
: INIT(wtype(0)), src(src), dst(dst), B(B), p(static_cast<wtype>(p)) {}
};

#if !MEGDNN_DISABLE_FLOAT16
template <>
struct NormOp<dt_float16, dt_float16, dt_float16> {
typedef dt_float16 wtype;
typedef dt_float16 src_ctype;
typedef dt_float16 dst_ctype;
const wtype INIT;

src_ctype* src;
dst_ctype* dst;
const size_t B;
const wtype p;

// HALF_FLOAT API has dispatch host and device.
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
return half_float::detail::pow(half_float::detail::abs(src[idx]), p);
}
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
dst[idx] = half_float::detail::pow(val, static_cast<wtype>(1.f) / p);
}
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
return lhs + rhs;
}
MEGDNN_HOST MEGDNN_DEVICE
NormOp(src_ctype* src, dst_ctype* dst, size_t B, dt_float32 p)
: INIT(wtype(0)), src(src), dst(dst), B(B), p(static_cast<wtype>(p)) {}
};
#endif

// TODO: 0Norm impl need understand reduceop
template <typename src_ctype, typename dst_ctype, typename wtype_>
struct NormZeroOp;

template <>
struct NormZeroOp<dt_float32, dt_float32, dt_float32> {
typedef dt_float32 wtype;
typedef dt_float32 src_ctype;
typedef dt_float32 dst_ctype;
const wtype INIT;

src_ctype* src;
dst_ctype* dst;
const size_t B;
const wtype epsilon = 0.00001f;

MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
return fabsf(src[idx] - 0.0f) <= epsilon ? 0.0f : 1.0f;
}
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }

static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
return lhs + rhs;
}
MEGDNN_HOST MEGDNN_DEVICE NormZeroOp(src_ctype* src, dst_ctype* dst, size_t B)
: INIT(wtype(0)), src(src), dst(dst), B(B) {}
};

#if !MEGDNN_DISABLE_FLOAT16
template <>
struct NormZeroOp<dt_float16, dt_float16, dt_float16> {
typedef dt_float16 wtype;
typedef dt_float16 src_ctype;
typedef dt_float16 dst_ctype;
const wtype INIT;

src_ctype* src;
dst_ctype* dst;
const size_t B;
const wtype epsilon = half_float::half(0.00001f);

MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
return half_float::detail::fabs(src[idx] - half_float::half()) <= epsilon
? half_float::half(0.0f)
: half_float::half(1.0f);
}
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }

static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
return lhs + rhs;
}
MEGDNN_HOST MEGDNN_DEVICE NormZeroOp(src_ctype* src, dst_ctype* dst, size_t B)
: INIT(wtype(0)), src(src), dst(dst), B(B) {}
};
#endif

template <typename src_ctype, typename dst_ctype, typename wtype_>
struct NormOneOp;

template <>
struct NormOneOp<dt_float32, dt_float32, dt_float32> {
typedef dt_float32 wtype;
typedef dt_float32 src_ctype;
typedef dt_float32 dst_ctype;
const wtype INIT;

src_ctype* src;
dst_ctype* dst;
const size_t B;

MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return fabsf(src[idx]); }
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }

static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
return lhs + rhs;
}
MEGDNN_HOST MEGDNN_DEVICE NormOneOp(src_ctype* src, dst_ctype* dst, size_t B)
: INIT(wtype(0)), src(src), dst(dst), B(B) {}
};

#if !MEGDNN_DISABLE_FLOAT16
template <>
struct NormOneOp<dt_float16, dt_float16, dt_float16> {
typedef dt_float16 wtype;
typedef dt_float16 src_ctype;
typedef dt_float16 dst_ctype;
const wtype INIT;

src_ctype* src;
dst_ctype* dst;
const size_t B;

MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
return half_float::detail::abs(src[idx]);
}
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }

static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
return lhs + rhs;
}
MEGDNN_HOST MEGDNN_DEVICE NormOneOp(src_ctype* src, dst_ctype* dst, size_t B)
: INIT(wtype(0)), src(src), dst(dst), B(B) {}
};
#endif

template <typename src_ctype, typename dst_ctype, typename wtype_>
struct NormTwoOp;

template <>
struct NormTwoOp<dt_float32, dt_float32, dt_float32> {
typedef dt_float32 wtype;
typedef dt_float32 src_ctype;
typedef dt_float32 dst_ctype;
const wtype INIT;

src_ctype* src;
dst_ctype* dst;
const size_t B;

MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx] * src[idx]; }
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
dst[idx] = sqrtf(val);
}

static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
return lhs + rhs;
}
MEGDNN_HOST MEGDNN_DEVICE NormTwoOp(src_ctype* src, dst_ctype* dst, size_t B)
: INIT(wtype(0)), src(src), dst(dst), B(B) {}
};

#if !MEGDNN_DISABLE_FLOAT16
template <>
struct NormTwoOp<dt_float16, dt_float16, dt_float16> {
typedef dt_float16 wtype;
typedef dt_float16 src_ctype;
typedef dt_float16 dst_ctype;
const wtype INIT;

src_ctype* src;
dst_ctype* dst;
const size_t B;

MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx] * src[idx]; }
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
dst[idx] = half_float::detail::sqrt(val);
}

static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
return lhs + rhs;
}
MEGDNN_HOST MEGDNN_DEVICE NormTwoOp(src_ctype* src, dst_ctype* dst, size_t B)
: INIT(wtype(0)), src(src), dst(dst), B(B) {}
};
#endif

} // namespace device_reduce
} // namespace megdnn

+ 180
- 0
dnn/src/cuda/norm/opr_impl.cpp View File

@@ -0,0 +1,180 @@
#include "src/cuda/norm/opr_impl.h"
#include "helper.h"
#include "src/common/reduce_helper_device.h"
#include "src/common/utils.h"
#include "src/cuda/handle.h"
#include "src/cuda/reduce_helper.cuh"
#include "src/cuda/utils.h"

namespace megdnn {
namespace cuda {

using namespace device_reduce;
using Mode = Norm::Mode;

template <>
void NormForwardImpl::dispatch_mode<Mode::NEG_INF_NORM>(
_megdnn_tensor_inout src, _megdnn_tensor_inout dst, _megdnn_workspace workspace,
size_t A, size_t B, size_t C, cudaStream_t stream) {
#define CASE(dt) \
case DTypeTrait<dt>::enumv: { \
using ctype = DTypeTrait<dt>::ctype; \
auto reduceOp = \
MinOp<ctype, ctype, ctype>(src.ptr<ctype>(), dst.ptr<ctype>(), B); \
run_reduce<MinOp<ctype, ctype, ctype>, false>( \
workspace.ptr<ctype>(), A, B, C, stream, reduceOp); \
break; \
};
switch (src.layout.dtype.enumv()) {
CASE(::megdnn::dtype::Float32)
#if !MEGDNN_DISABLE_FLOAT16
CASE(::megdnn::dtype::Float16)
#endif
default:
megdnn_assert_internal(false);
}
#undef CASE
}

template <>
void NormForwardImpl::dispatch_mode<Mode::INF_NORM>(
_megdnn_tensor_inout src, _megdnn_tensor_inout dst, _megdnn_workspace workspace,
size_t A, size_t B, size_t C, cudaStream_t stream) {
#define CASE(dt) \
case DTypeTrait<dt>::enumv: { \
using ctype = DTypeTrait<dt>::ctype; \
auto reduceOp = \
MaxOp<ctype, ctype, ctype>(src.ptr<ctype>(), dst.ptr<ctype>(), B); \
run_reduce<MaxOp<ctype, ctype, ctype>, false>( \
workspace.ptr<ctype>(), A, B, C, stream, reduceOp); \
break; \
};
switch (src.layout.dtype.enumv()) {
CASE(::megdnn::dtype::Float32)
#if !MEGDNN_DISABLE_FLOAT16
CASE(::megdnn::dtype::Float16)
#endif
default:
megdnn_assert_internal(false);
}
#undef CASE
}

template <>
void NormForwardImpl::dispatch_mode<Mode::P_NORM>(
_megdnn_tensor_inout src, _megdnn_tensor_inout dst, _megdnn_workspace workspace,
size_t A, size_t B, size_t C, cudaStream_t stream) {
typedef dt_float32 p_type;

#define CASE(dt) \
case DTypeTrait<dt>::enumv: { \
using ctype = DTypeTrait<dt>::ctype; \
p_type epsilon = 0.000001f; \
if (fabs(param().p - 0.0f) < epsilon) { \
run_reduce<NormZeroOp<ctype, ctype, ctype>, false>( \
workspace.ptr<ctype>(), A, B, C, stream, \
NormZeroOp<ctype, ctype, ctype>( \
src.ptr<ctype>(), dst.ptr<ctype>(), B)); \
} else if (fabs(param().p - 1.0f) < epsilon) { \
run_reduce<NormOneOp<ctype, ctype, ctype>, false>( \
workspace.ptr<ctype>(), A, B, C, stream, \
NormOneOp<ctype, ctype, ctype>( \
src.ptr<ctype>(), dst.ptr<ctype>(), B)); \
} else if (fabs(param().p - 2.0f) < epsilon) { \
run_reduce<NormTwoOp<ctype, ctype, ctype>, false>( \
workspace.ptr<ctype>(), A, B, C, stream, \
NormTwoOp<ctype, ctype, ctype>( \
src.ptr<ctype>(), dst.ptr<ctype>(), B)); \
} else { \
run_reduce<NormOp<ctype, ctype, ctype>, false>( \
workspace.ptr<ctype>(), A, B, C, stream, \
NormOp<ctype, ctype, ctype>( \
src.ptr<ctype>(), dst.ptr<ctype>(), B, param().p)); \
} \
break; \
};

switch (src.layout.dtype.enumv()) {
CASE(::megdnn::dtype::Float32)
#if !MEGDNN_DISABLE_FLOAT16
CASE(::megdnn::dtype::Float16)
#endif
default:
megdnn_assert_internal(false);
}
#undef CASE
}

} // namespace cuda

namespace cuda {
void NormForwardImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
size_t A, B, C;
reduce::get_ABC(src.layout, A, B, C, param().dim);
auto stream = cuda_stream(this->handle());

#define CASE(mode) \
case mode: { \
dispatch_mode<mode>(src, dst, workspace, A, B, C, stream); \
break; \
};

switch (param().mode) {
CASE(Mode::P_NORM)
CASE(Mode::INF_NORM)
CASE(Mode::NEG_INF_NORM)
default:
megdnn_assert_internal(false);
}
#undef CASE

return;
}

size_t NormForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
using namespace device_reduce;
size_t A, B, C;
reduce::get_ABC(src, A, B, C, param().dim);

#define cb(dt, op) \
case DTypeTrait<dt>::enumv: { \
using ctype = DTypeTrait<dt>::ctype; \
return get_reduce_workspace_in_bytes<op<ctype, ctype, ctype>>(A, B, C); \
break; \
};

#if !MEGDNN_DISABLE_FLOAT16
#define CASE(mode, op) \
case mode: { \
switch (src.dtype.enumv()) { \
cb(::megdnn::dtype::Float32, op) cb(::megdnn::dtype::Float16, op) default \
: megdnn_assert_internal(false); \
} \
};
#else
#define CASE(mode, op) \
case mode: { \
switch (src.dtype.enumv()) { \
cb(::megdnn::dtype::Float32, op) default : megdnn_assert_internal(false); \
} \
};
#endif

// XXX: 0/1 norm dispathed to different Op, but workspace size same as
// NormOp
switch (param().mode) {
CASE(Mode::INF_NORM, MaxOp)
CASE(Mode::NEG_INF_NORM, MinOp)
CASE(Mode::P_NORM, NormOp)
default:
megdnn_assert_internal(false);
}
#undef CASE
#undef cb
}

} // namespace cuda
} // namespace megdnn

+ 25
- 0
dnn/src/cuda/norm/opr_impl.h View File

@@ -0,0 +1,25 @@
#pragma once
#include "megdnn/oprs.h"
#include "src/cuda/utils.h"

namespace megdnn {
namespace cuda {
class NormForwardImpl : public NormForward {
using Norm::Norm;

public:
void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) override;

protected:
template <Mode mode>
void dispatch_mode(
_megdnn_tensor_inout src, _megdnn_tensor_inout dst,
_megdnn_workspace workspace, size_t A, size_t B, size_t C,
cudaStream_t stream);
};
} // namespace cuda
} // namespace megdnn

+ 1
- 0
dnn/src/naive/handle.cpp View File

@@ -51,6 +51,7 @@
#include "src/naive/matrix_mul/opr_impl.h"
#include "src/naive/max_tensor_diff/opr_impl.h"
#include "src/naive/mesh_indexing/opr_impl.h"
#include "src/naive/norm/opr_impl.h"
#include "src/naive/padding/opr_impl.h"
#include "src/naive/param_pack/opr_impl.h"
#include "src/naive/pooling/opr_impl.h"


+ 152
- 0
dnn/src/naive/norm/helper.h View File

@@ -0,0 +1,152 @@
#pragma once
#include <algorithm>
#include <numeric>
#include "megdnn/basic_types.h"
#include "megdnn/dtype.h"
#include "src/common/utils.h"

using namespace megdnn;

/* anonymous namespace */
namespace {
using Mode = Reduce::Mode;

/* Reduce Trait */
template <Mode mode, typename ctype>
struct Trait;

template <typename ctype>
struct Trait<Mode::SUM, ctype> {
static const ctype INIT;

static ctype apply(ctype x, ctype y) { return x + y; }
static ctype visit(ctype x) { return x; }
static ctype write(ctype x, size_t) { return x; }
};
template <typename ctype>
const ctype Trait<Mode::SUM, ctype>::INIT = ctype(0);

template <typename ctype>
struct Trait<Mode::MEAN, ctype> {
static const ctype INIT;

static ctype apply(ctype x, ctype y) { return x + y; }
static ctype visit(ctype x) { return x; }
static ctype write(ctype x, size_t B) { return x / (ctype)B; }
};
template <typename ctype>
const ctype Trait<Mode::MEAN, ctype>::INIT = ctype(0);

template <typename ctype>
struct Trait<Mode::SUM_SQR, ctype> {
static const ctype INIT;

static ctype apply(ctype x, ctype y) { return x + y; }
static ctype visit(ctype x) { return x * x; }
static ctype write(ctype x, size_t) { return x; }
};
template <typename ctype>
const ctype Trait<Mode::SUM_SQR, ctype>::INIT = ctype(0);

template <typename ctype>
struct Trait<Mode::PRODUCT, ctype> {
static const ctype INIT;

static ctype apply(ctype x, ctype y) { return x * y; }
static ctype visit(ctype x) { return x; }
static ctype write(ctype x, size_t) { return x; }
};
template <typename ctype>
const ctype Trait<Mode::PRODUCT, ctype>::INIT = ctype(1);

template <typename ctype>
struct Trait<Mode::MIN, ctype> {
static ctype apply(ctype x, ctype y) { return x < y ? x : y; }
static ctype visit(ctype x) { return x; }
static ctype write(ctype x, size_t) { return x; }
};

template <>
struct Trait<Mode::MIN, dt_float32> {
using ctype = dt_float32;

static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x < y) ? x : y; }
static ctype visit(ctype x) { return x; }
static ctype write(ctype x, size_t) { return x; }
};

template <typename ctype>
struct Trait<Mode::MAX, ctype> {
static ctype apply(ctype x, ctype y) { return x > y ? x : y; }
static ctype visit(ctype x) { return x; }
static ctype write(ctype x, size_t) { return x; }
};

template <>
struct Trait<Mode::MAX, dt_float32> {
using ctype = dt_float32;

static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x > y) ? x : y; }
static ctype visit(ctype x) { return x; }
static ctype write(ctype x, size_t) { return x; }
};

/* NormOp */
template <typename ctype>
struct NormOp;

template <>
struct NormOp<dt_float32> {
typedef dt_float32 ctype;
static const ctype INIT;

static ctype apply(ctype x, ctype y) { return x + y; }
static ctype visit(ctype x, dt_float32 p) { return powf(fabs(x), p); }
static ctype write(ctype x, size_t, dt_float32 p) { return powf(x, 1.f / p); }
};

#if !MEGDNN_DISABLE_FLOAT16
template <>
struct NormOp<dt_float16> {
typedef dt_float16 ctype;
static const ctype INIT;

static ctype apply(ctype x, ctype y) { return x + y; }
static ctype visit(ctype x, dt_float32 p) {
return half_float::pow(half_float::abs(x), half_float::half(p));
}
static ctype write(ctype x, size_t, dt_float32 p) {
return half_float::pow(x, half_float::half(1.f / p));
}
};
#endif

template <typename ctype>
struct NormZeroOp;

template <>
struct NormZeroOp<dt_float32> {
typedef dt_float32 ctype;
static const ctype INIT;

static ctype apply(ctype x, ctype y) { return x + y; }
static ctype visit(ctype x) { return x - 0.f < 0.00001f ? 0.f : 1.f; }
static ctype write(ctype x, size_t) { return x; }
};

#if !MEGDNN_DISABLE_FLOAT16
template <>
struct NormZeroOp<dt_float16> {
typedef dt_float16 ctype;
static const ctype INIT;

static ctype apply(ctype x, ctype y) { return x + y; }
static ctype visit(ctype x) {
return x - half_float::half(0.f) < half_float::half(0.00001f)
? half_float::half(0.f)
: half_float::half(1.f);
}
static ctype write(ctype x, size_t) { return x; }
};
#endif
} // namespace

+ 197
- 0
dnn/src/naive/norm/opr_impl.cpp View File

@@ -0,0 +1,197 @@
#include "src/naive/norm/opr_impl.h"

#include "helper.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"

namespace megdnn {
namespace naive {
using Mode = Norm::Mode;

template <>
void NormForwardImpl::dispatch_mode<Mode::NEG_INF_NORM>(
_megdnn_tensor_in src, _megdnn_tensor_out dst, size_t A, size_t B, size_t C) {
#define CASE(dt) \
case DTypeTrait<dt>::enumv: { \
using ctype = DTypeTrait<dt>::ctype; \
const ctype* __restrict sptr = src.ptr<ctype>(); \
ctype* __restrict dptr = dst.ptr<ctype>(); \
std::function<ctype(size_t, size_t, size_t, size_t)> func; \
func = [&](size_t a, size_t c, size_t bl, size_t br) -> ctype { \
if (bl + 1 < br) { \
size_t mid = bl + (br - bl) / 2; \
return Trait<ReduceForward::Mode::MIN, ctype>::apply( \
func(a, c, bl, mid), func(a, c, mid, br)); \
} else { \
return Trait<ReduceForward::Mode::MIN, ctype>::visit( \
sptr[a * B * C + bl * C + c]); \
} \
}; \
for (size_t a = 0; a < A; ++a) \
for (size_t c = 0; c < C; ++c) { \
dptr[a * C + c] = Trait<ReduceForward::Mode::MIN, ctype>::write( \
func(a, c, 0, B), B); \
} \
break; \
};

switch (src.layout.dtype.enumv()) {
CASE(::megdnn::dtype::Float32)
#if !MEGDNN_DISABLE_FLOAT16
CASE(::megdnn::dtype::Float16)
#endif
default:
megdnn_assert_internal(false);
}
#undef CASE
}

template <>
void NormForwardImpl::dispatch_mode<Mode::INF_NORM>(
_megdnn_tensor_in src, _megdnn_tensor_out dst, size_t A, size_t B, size_t C) {
#define CASE(dt) \
case DTypeTrait<dt>::enumv: { \
using ctype = DTypeTrait<dt>::ctype; \
const ctype* __restrict sptr = src.ptr<ctype>(); \
ctype* __restrict dptr = dst.ptr<ctype>(); \
std::function<ctype(size_t, size_t, size_t, size_t)> func; \
func = [&](size_t a, size_t c, size_t bl, size_t br) -> ctype { \
if (bl + 1 < br) { \
size_t mid = bl + (br - bl) / 2; \
return Trait<ReduceForward::Mode::MAX, ctype>::apply( \
func(a, c, bl, mid), func(a, c, mid, br)); \
} else { \
return Trait<ReduceForward::Mode::MAX, ctype>::visit( \
sptr[a * B * C + bl * C + c]); \
} \
}; \
for (size_t a = 0; a < A; ++a) \
for (size_t c = 0; c < C; ++c) { \
dptr[a * C + c] = Trait<ReduceForward::Mode::MAX, ctype>::write( \
func(a, c, 0, B), B); \
} \
break; \
};

switch (src.layout.dtype.enumv()) {
CASE(::megdnn::dtype::Float32)
#if !MEGDNN_DISABLE_FLOAT16
CASE(::megdnn::dtype::Float16)
#endif
default:
megdnn_assert_internal(false);
}
#undef CASE
}

template <>
void NormForwardImpl::dispatch_mode<Mode::P_NORM>(
_megdnn_tensor_in src, _megdnn_tensor_out dst, size_t A, size_t B, size_t C) {
#define CASE(dt) \
case DTypeTrait<dt>::enumv: { \
using ctype = DTypeTrait<dt>::ctype; \
const ctype* __restrict sptr = src.ptr<ctype>(); \
ctype* __restrict dptr = dst.ptr<ctype>(); \
std::function<ctype(size_t, size_t, size_t, size_t)> func; \
if (param().p - 0.f < 0.00001f) { \
func = [&](size_t a, size_t c, size_t bl, size_t br) -> ctype { \
if (bl + 1 < br) { \
size_t mid = bl + (br - bl) / 2; \
return NormZeroOp<ctype>::apply( \
func(a, c, bl, mid), func(a, c, mid, br)); \
} else { \
return NormZeroOp<ctype>::visit(sptr[a * B * C + bl * C + c]); \
} \
}; \
for (size_t a = 0; a < A; ++a) { \
for (size_t c = 0; c < C; ++c) { \
dptr[a * C + c] = NormZeroOp<ctype>::write(func(a, c, 0, B), B); \
} \
} \
} else { \
func = [&](size_t a, size_t c, size_t bl, size_t br) -> ctype { \
if (bl + 1 < br) { \
size_t mid = bl + (br - bl) / 2; \
return NormOp<ctype>::apply( \
func(a, c, bl, mid), func(a, c, mid, br)); \
} else { \
return NormOp<ctype>::visit( \
sptr[a * B * C + bl * C + c], param().p); \
} \
}; \
for (size_t a = 0; a < A; ++a) { \
for (size_t c = 0; c < C; ++c) { \
dptr[a * C + c] = \
NormOp<ctype>::write(func(a, c, 0, B), B, param().p); \
} \
} \
} \
break; \
};

switch (src.layout.dtype.enumv()) {
CASE(::megdnn::dtype::Float32)
#if !MEGDNN_DISABLE_FLOAT16
CASE(::megdnn::dtype::Float16)
#endif
default:
megdnn_assert_internal(false);
}
#undef CASE
}

void NormForwardImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
using namespace reduce;
size_t A, B, C;
reduce::get_ABC(src.layout, A, B, C, param().dim);
auto make_tensor = [&](DType comp_dtype, _megdnn_tensor_inout tensor,
dt_byte*& workspace_ptr) {
if (comp_dtype == tensor.layout.dtype)
return tensor;
auto layout = TensorLayout(tensor.layout, comp_dtype);
TensorND new_tensor{workspace_ptr, layout};
workspace_ptr += layout.span().dist_byte();
return new_tensor;
};
auto typecvt = handle()->create_operator<TypeCvt>();

auto copy_to = [&typecvt](const TensorND& from, const TensorND& to) {
if (from.raw_ptr() != to.raw_ptr())
typecvt->exec(from, to);
};

auto workspace_ptr = workspace.ptr<dt_byte>();

auto new_src = make_tensor(src.layout.dtype, src, workspace_ptr);
auto new_dst = make_tensor(dst.layout.dtype, dst, workspace_ptr);

#define CASE(mode) \
case mode: { \
copy_to(src, new_src); \
::megdnn::naive::HandleImpl* handlePtr = static_cast<HandleImpl*>(handle()); \
MEGDNN_DISPATCH_CPU_KERN( \
handlePtr, dispatch_mode<mode>(new_src, new_dst, A, B, C)); \
copy_to(new_dst, dst); \
break; \
};
switch (param().mode) {
CASE(Mode::P_NORM)
CASE(Mode::INF_NORM)
CASE(Mode::NEG_INF_NORM)
default:
megdnn_assert_internal(false);
}
#undef CASE
}

size_t NormForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(dst);
return 0;
}

} // namespace naive
} // namespace megdnn

+ 23
- 0
dnn/src/naive/norm/opr_impl.h View File

@@ -0,0 +1,23 @@
#pragma once
#include "megdnn/oprs.h"
#include "src/common/reduce_helper.h"
#include "src/naive/reduce/opr_impl.h"

namespace megdnn {
namespace naive {
class NormForwardImpl : public Norm {
public:
using Norm::Norm;
void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) override;

protected:
template <Mode mode>
void dispatch_mode(
_megdnn_tensor_in src, _megdnn_tensor_out dst, size_t, size_t, size_t);
};
} // namespace naive
} // namespace megdnn

+ 19
- 0
dnn/test/common/norm.h View File

@@ -0,0 +1,19 @@

#pragma once
#include <iostream>
#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"

namespace megdnn {
namespace test {
namespace norm {

struct TestArg {
param::Norm param;
TensorShape src;
TestArg(param::Norm param, TensorShape src) : param(param), src(src) {}
};

} // namespace norm
} // namespace test
} // namespace megdnn

+ 291
- 0
dnn/test/cuda/norm.cpp View File

@@ -0,0 +1,291 @@
#include "test/common/norm.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
// #include "test/naive/fixture.h"
// #include "test/common/benchmarker.h"
#include <iostream>
#include "test/cuda/benchmark.h"
#include "test/cuda/fixture.h"
#include "test/cuda/utils.h"

namespace megdnn {
namespace test {
// CORRECT
// L2, fp32, dim
TEST_F(CUDA, L2NORM_FP32_DIM0) {
Checker<Norm> checker(handle_cuda());
Norm::Param param;
param.p = 2;
param.dim = 0;
checker.set_param(param);
checker.exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
});
}
TEST_F(CUDA, L2NORM_FP32_DIM1) {
Checker<Norm> checker(handle_cuda());
Norm::Param param;
param.p = 2;
param.dim = 1;
checker.set_param(param);
checker.exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue(
{1, 1, 3, 4}, dtype::Float32(),
{12.000, 13.0384, 14.1421, 15.2971, 16.4924, 17.7200,
18.9737, 20.2485, 21.5407, 22.8473, 24.1661, 25.4951}),
});
}
TEST_F(CUDA, L2NORM_FP32_DIM3) {
Checker<Norm> checker(handle_cuda());
Norm::Param param;
param.p = 2;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue(
{1, 2, 3, 1}, dtype::Float32(),
{3.7417, 11.2250, 19.1311, 27.0924, 35.0714, 43.0581})});
}
// TODO: support -1 dim param, or test for assert
// l2, fp16
TEST_F(CUDA, L2NORM_FP16_DIM3) {
Checker<Norm> checker(handle_cuda());
Norm::Param param;
param.p = 2;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float16(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue(
{1, 2, 3, 1}, dtype::Float16(),
{3.7422, 11.2266, 19.1250, 27.0938, 35.0625, 43.0625})});
}
// l1, fp32,fp16
TEST_F(CUDA, L1NORM_FP32_DIM3) {
Checker<Norm> checker(handle_cuda());
Norm::Param param;
param.p = 1;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue(
{1, 2, 3, 1}, dtype::Float32(), {6, 22, 38, 54, 70, 86}),
});
}
TEST_F(CUDA, L1NORM_FP16_DIM3) {
Checker<Norm> checker(handle_cuda());
Norm::Param param;
param.p = 1;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float16(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue(
{1, 2, 3, 1}, dtype::Float16(), {6, 22, 38, 54, 70, 86}),
});
}
// l0, fp32,fp16
TEST_F(CUDA, L0NORM_FP32_DIM3) {
Checker<Norm> checker(handle_cuda());
Norm::Param param;
param.p = 0;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue({1, 2, 3, 1}, dtype::Float32(), {3, 4, 4, 4, 4, 4}),
});
}
TEST_F(CUDA, L0NORM_FP16_DIM3) {
Checker<Norm> checker(handle_cuda());
Norm::Param param;
param.p = 0;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float16(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue({1, 2, 3, 1}, dtype::Float16(), {3, 4, 4, 4, 4, 4}),
});
}
// inf
TEST_F(CUDA, INF_NORM_FP32_DIM3) {
Checker<Norm> checker(handle_cuda());
Norm::Param param;
using Mode = Norm::Param::Mode;

param.dim = 3;
param.mode = Mode::INF_NORM;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue({1, 2, 3, 1}, dtype::Float32(), {3, 7, 11, 15, 19, 23}),
});
}
TEST_F(CUDA, INF_NORM_FP16_DIM3) {
Checker<Norm> checker(handle_cuda());
Norm::Param param;
using Mode = Norm::Param::Mode;

param.dim = 3;
param.mode = Mode::INF_NORM;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float16(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue({1, 2, 3, 1}, dtype::Float16(), {3, 7, 11, 15, 19, 23}),
});
}
// -inf
TEST_F(CUDA, NEG_INF_NORM_FP32_DIM3) {
Checker<Norm> checker(handle_cuda());
Norm::Param param;
param.mode = Norm::Param::Mode::NEG_INF_NORM;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue({1, 2, 3, 1}, dtype::Float32(), {0, 4, 8, 12, 16, 20}),
});
}
TEST_F(CUDA, NEG_INF_NORM_FP16_DIM3) {
Checker<Norm> checker(handle_cuda());
Norm::Param param;
param.mode = Norm::Param::Mode::NEG_INF_NORM;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float16(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue({1, 2, 3, 1}, dtype::Float16(), {0, 4, 8, 12, 16, 20}),
});
}

// PERF
TEST_F(CUDA, L2NORM_SPEED_FP32) {
auto benchmarker = Benchmarker<Norm>(handle_cuda());
benchmarker.set_dtype(0, dtype::Float32());
benchmarker.set_dtype(1, dtype::Float32());
Norm::Param param;
param.mode = Norm::Param::Mode::P_NORM;
param.dim = 0;
param.p = 2;
SmallVector<TensorShape> shapes{{4194304}, {}};
NormalRNG rng(0, 1);
float eachTime;
float totalTime = 0.f;
#define ITER 10
for (auto i = 0; i < ITER; i++) {
eachTime = benchmarker.set_param(param).set_rng(0, &rng).exec(shapes);
// printf("PNORM_SPEED_FP32 cuda time: %.6fms\n", eachTime);
totalTime += eachTime;
}
totalTime /= ITER;
printf("PNORM_SPEED_FP32 AVG TIME: %.6fms\n", totalTime);
#undef ITER
}
TEST_F(CUDA, INFNORM_SPEED_FP32) {
auto benchmarker = Benchmarker<Norm>(handle_cuda());
benchmarker.set_dtype(0, dtype::Float32());
benchmarker.set_dtype(1, dtype::Float32());
Norm::Param param;
param.mode = Norm::Param::Mode::INF_NORM;
param.dim = 0;
SmallVector<TensorShape> shapes{{4194304}, {}};
NormalRNG rng(0, 1);
float time_fp32 = benchmarker.set_param(param).set_rng(0, &rng).exec(shapes);
printf("INF_SPEED_FP32 cuda time: float=%.6fms\n", time_fp32);
}
TEST_F(CUDA, NEG_INFNORM_SPEED_FP32) {
auto benchmarker = Benchmarker<Norm>(handle_cuda());
benchmarker.set_dtype(0, dtype::Float32());
benchmarker.set_dtype(1, dtype::Float32());
Norm::Param param;
param.mode = Norm::Param::Mode::NEG_INF_NORM;
param.dim = 0;
SmallVector<TensorShape> shapes{{4194304}, {}};
NormalRNG rng(0, 1);
float time_fp32 = benchmarker.set_param(param).set_rng(0, &rng).exec(shapes);
printf("NEG_INF_SPEED_FP32 cuda time: float=%.6fms\n", time_fp32);
}
} // namespace test
} // namespace megdnn

+ 237
- 0
dnn/test/naive/norm.cpp View File

@@ -0,0 +1,237 @@
#include "test/common/norm.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"

namespace megdnn {
namespace test {
TEST_F(NAIVE, L2NORM_FP32_DIM0) {
Checker<Norm> checker(handle(), false);
Norm::Param param;
param.p = 2;
param.dim = 0;
checker.set_param(param);
checker.exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
});
}
TEST_F(NAIVE, L2NORM_FP32_DIM1) {
Checker<Norm> checker(handle());
Norm::Param param;
param.p = 2;
param.dim = 1;
checker.set_param(param);
checker.exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue(
{1, 1, 3, 4}, dtype::Float32(),
{12.000, 13.0384, 14.1421, 15.2971, 16.4924, 17.7200,
18.9737, 20.2485, 21.5407, 22.8473, 24.1661, 25.4951}),
});
}
TEST_F(NAIVE, L2NORM_FP32_DIM3) {
Checker<Norm> checker(handle());
Norm::Param param;
param.p = 2;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue(
{1, 2, 3, 1}, dtype::Float32(),
{3.7417, 11.2250, 19.1311, 27.0924, 35.0714, 43.0581})});
}
// l2, fp16
TEST_F(NAIVE, L2NORM_FP16_DIM3) {
Checker<Norm> checker(handle());
Norm::Param param;
param.p = 2;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float16(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue(
{1, 2, 3, 1}, dtype::Float16(),
{3.7422, 11.2266, 19.1250, 27.0938, 35.0625, 43.0625})});
}
// l1, fp32,fp16
TEST_F(NAIVE, L1NORM_FP32_DIM3) {
Checker<Norm> checker(handle());
Norm::Param param;
param.p = 1;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue(
{1, 2, 3, 1}, dtype::Float32(), {6, 22, 38, 54, 70, 86}),
});
}
TEST_F(NAIVE, L1NORM_FP16_DIM3) {
Checker<Norm> checker(handle());
Norm::Param param;
param.p = 1;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float16(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue(
{1, 2, 3, 1}, dtype::Float16(), {6, 22, 38, 54, 70, 86}),
});
}
// l0, fp32,fp16
TEST_F(NAIVE, L0NORM_FP32_DIM3) {
Checker<Norm> checker(handle());
Norm::Param param;
param.p = 0;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue({1, 2, 3, 1}, dtype::Float32(), {3, 4, 4, 4, 4, 4}),
});
}
TEST_F(NAIVE, L0NORM_FP16_DIM3) {
Checker<Norm> checker(handle());
Norm::Param param;
param.p = 0;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float16(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue({1, 2, 3, 1}, dtype::Float16(), {3, 4, 4, 4, 4, 4}),
});
}
// inf
TEST_F(NAIVE, INF_NORM_FP32_DIM3) {
Checker<Norm> checker(handle());
Norm::Param param;
using Mode = Norm::Param::Mode;

param.dim = 3;
param.mode = Mode::INF_NORM;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue({1, 2, 3, 1}, dtype::Float32(), {3, 7, 11, 15, 19, 23}),
});
}
TEST_F(NAIVE, INF_NORM_FP16_DIM3) {
Checker<Norm> checker(handle());
Norm::Param param;
using Mode = Norm::Param::Mode;

param.dim = 3;
param.mode = Mode::INF_NORM;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float16(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue({1, 2, 3, 1}, dtype::Float16(), {3, 7, 11, 15, 19, 23}),
});
}
// -inf
TEST_F(NAIVE, NEG_INF_NORM_FP32_DIM3) {
Checker<Norm> checker(handle());
Norm::Param param;
param.mode = Norm::Param::Mode::NEG_INF_NORM;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float32(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue({1, 2, 3, 1}, dtype::Float32(), {0, 4, 8, 12, 16, 20}),
});
}
TEST_F(NAIVE, NEG_INF_NORM_FP16_DIM3) {
Checker<Norm> checker(handle());
Norm::Param param;
param.mode = Norm::Param::Mode::NEG_INF_NORM;
param.dim = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 2, 3, 4}, dtype::Float16(),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
{}},
Testcase{
{},
TensorValue({1, 2, 3, 1}, dtype::Float16(), {0, 4, 8, 12, 16, 20}),
});
}

} // namespace test
} // namespace megdnn

Loading…
Cancel
Save