
fix(dnn/softmax): call cpu dispatch for softmax opr

GitOrigin-RevId: a606e66101
Megvii Engine Team, 2 years ago (master)
parent commit 582dd4ceb8
2 changed files with 47 additions and 26 deletions
  1. dnn/src/fallback/softmax/opr_impl.cpp (+32, -26)
  2. imperative/python/test/unit/functional/test_functional.py (+15, -0)

dnn/src/fallback/softmax/opr_impl.cpp (+32, -26)

@@ -6,35 +6,19 @@
 
 namespace megdnn {
 namespace fallback {
-void SoftmaxForwardImpl::exec(
-        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
-    auto axis = param().axis;
-    if (axis < 0)
-        axis += src.layout.ndim;
-    megdnn_assert(axis >= 0);
-    check_exec(src.layout, dst.layout, workspace.size);
-
-    if (!usable(src.layout)) {
-        naive::SoftmaxForwardImpl::exec(src, dst, workspace);
-        return;
-    }
-
-    typedef DTypeTrait<dtype::Float32>::ctype Float32;
-    auto sptr = src.ptr<Float32>();
-    auto dptr = dst.ptr<Float32>();
-
-    constexpr auto float_min = std::numeric_limits<Float32>::min();
-    constexpr auto step = GI_SIMD_LEN_BYTE / sizeof(Float32);
-
-    size_t A, B, C;
-    reduce::get_ABC(src.layout, A, B, C, axis);
+static void do_softmax(
+        const float* sptr, float* dptr, size_t A, size_t B, size_t C,
+        _megdnn_workspace workspace) {
+    constexpr auto float_min = std::numeric_limits<float>::min();
+    constexpr auto step = GI_SIMD_LEN_BYTE / sizeof(float);
     // TODO: When C=2,3,4..., src_ptr span is relatively large, the performance may
     // be poor
 
     if (C != 1) {
         WorkspaceBundle workspace_bundle{
-                workspace.raw_ptr, {A * C * sizeof(Float32), A * C * sizeof(Float32)}};
-        Float32* max = workspace_bundle.get_workspace(0).raw_ptr->as<Float32>();
+                workspace.raw_ptr, {A * C * sizeof(float), A * C * sizeof(float)}};
+        float* max = workspace_bundle.get_workspace(0).raw_ptr->as<float>();
         GI_FLOAT32_t v_max = GiBroadcastFloat32(float_min);
         size_t i = 0;
         for (; i + step <= A * C; i += step)
@@ -60,8 +44,8 @@ void SoftmaxForwardImpl::exec(
             }
         }
 
-        Float32* sum = workspace_bundle.get_workspace(1).raw_ptr->as<Float32>();
-        memset(sum, 0, A * C * sizeof(Float32));
+        float* sum = workspace_bundle.get_workspace(1).raw_ptr->as<float>();
+        memset(sum, 0, A * C * sizeof(float));
         for (size_t a = 0; a < A; a++) {
             for (size_t b = 0; b < B; b++) {
                 auto max_ptr = max + a * C;
@@ -157,6 +141,28 @@ void SoftmaxForwardImpl::exec(
     }
 }
 
+void SoftmaxForwardImpl::exec(
+        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
+    auto axis = param().axis;
+    if (axis < 0)
+        axis += src.layout.ndim;
+    megdnn_assert(axis >= 0);
+    check_exec(src.layout, dst.layout, workspace.size);
+
+    if (!usable(src.layout)) {
+        naive::SoftmaxForwardImpl::exec(src, dst, workspace);
+        return;
+    }
+
+    typedef DTypeTrait<dtype::Float32>::ctype Float32;
+    auto sptr = src.ptr<Float32>();
+    auto dptr = dst.ptr<Float32>();
+
+    size_t A, B, C;
+    reduce::get_ABC(src.layout, A, B, C, axis);
+    MEGDNN_DISPATCH_CPU_KERN_OPR(do_softmax(sptr, dptr, A, B, C, workspace));
+}
 
 } // namespace fallback
 } // namespace megdnn
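
Note: the shape of this change (hoisting the kernel body into a static free function so that exec() only validates its arguments and then wraps the call in MEGDNN_DISPATCH_CPU_KERN_OPR) is what the commit title means by "call cpu dispatch": the computation is handed to the handle's CPU dispatcher instead of running inline in exec(). Below is a minimal illustrative sketch of that pattern only; dispatch_on_cpu, do_scale and exec_scale are hypothetical stand-ins, not MegDNN APIs.

#include <cstddef>
#include <functional>

// Hypothetical stand-in for a CPU dispatcher. A real dispatcher would
// enqueue the closure on a worker queue; here it runs immediately so the
// sketch stays self-contained.
static void dispatch_on_cpu(std::function<void()> kern) {
    kern();
}

// Kernel body hoisted into a free function (mirroring do_softmax) so the
// exec-style wrapper can capture its arguments and defer the work.
static void do_scale(const float* src, float* dst, std::size_t n, float k) {
    for (std::size_t i = 0; i < n; ++i)
        dst[i] = src[i] * k;
}

static void exec_scale(const float* src, float* dst, std::size_t n, float k) {
    // Argument checks would stay on the calling thread; only the heavy
    // loop is handed to the dispatcher, analogous to
    // MEGDNN_DISPATCH_CPU_KERN_OPR(do_softmax(...)) above.
    dispatch_on_cpu([=] { do_scale(src, dst, n, k); });
}

Because a dispatcher may run the kernel after the wrapper returns, whatever the closure captures (here the raw pointers and sizes) has to stay valid until the kernel has finished.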



imperative/python/test/unit/functional/test_functional.py (+15, -0)

@@ -1653,3 +1653,18 @@ def test_conv_transpose3d():
     np.testing.assert_equal(
         output_shape.numpy(), np.array([20, 33, 32, 96, 197], dtype=np.int32)
     )
+
+
+@pytest.mark.skip(reason="pytest aborted")
+def test_softmax():
+    def np_softmax(x):
+        return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)
+
+    data = (np.random.random(size=(1, 16, 224, 224)).astype(np.float32) - 0.5) * 100
+    desired = np_softmax(data[:, :3, 0, 0])
+
+    data = Tensor(data)
+    data = data[:, :3, 0, 0]
+    actual = F.softmax(data)
+
+    np.testing.assert_allclose(actual.numpy(), desired, rtol=1e-5)
