#include "megdnn/oprs.h"

#include "src/common/utils.h"

namespace megdnn {

using Param = GroupNormBase::Param;

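// Deduce output layouts for the forward pass: dst shares the layout of the
// input, while mean/rstd hold one float32 statistic per (batch, group) pair.
// For example, data of shape (N=2, C=6, H=4, W=4) with group=3 yields
// dst of shape (2, 6, 4, 4) and mean/rstd of shape (2, 3).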
void GroupNormBase::deduce_layout_fwd(
        const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
        TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
    MEGDNN_MARK_USED_VAR(weight);
    MEGDNN_MARK_USED_VAR(bias);
    size_t N = data.shape[0];
    size_t group = param().group;
    TensorLayout unnormalized_layout({N, group}, dtype::Float32());
    dst = data;
    mean = unnormalized_layout;
    rstd = unnormalized_layout;
}

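// Validate forward layouts: every tensor must be contiguous, dst must match
// data, bias must match weight, and rstd must match mean. The channel count
// (data.shape[1]) must be divisible by param().group, which must be positive.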
void GroupNormBase::check_layout_fwd(
        const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
        const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) {
    megdnn_assert_contiguous(data);
    megdnn_assert_contiguous(weight);
    megdnn_assert_contiguous(bias);
    megdnn_assert_contiguous(dst);
    megdnn_assert_contiguous(mean);
    megdnn_assert_contiguous(rstd);
    auto errmsg = [&]() {
        return megdnn_layout_msg(data) + ", " + megdnn_layout_msg(weight) + ", " +
               megdnn_layout_msg(bias) + ", " + megdnn_layout_msg(dst) + ", " +
               megdnn_layout_msg(mean) + ", " + megdnn_layout_msg(rstd);
    };
    MEGDNN_MARK_USED_VAR(errmsg);

    megdnn_assert(data.eq_layout(dst), "%s", errmsg().c_str());
    megdnn_assert(weight.eq_layout(bias), "%s", errmsg().c_str());
    megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str());

    auto p = param();
    size_t C = data.shape[1];
    size_t group = p.group;
    megdnn_assert(
            group > 0, "Expected num_groups to be greater than 0, got %zu", group);
    megdnn_assert(
            C % group == 0,
            "Expected the number of input channels to be divisible by num_groups, "
            "but got %zu channels and num_groups=%zu",
            C, group);
}

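// Forward layout deduction simply delegates to GroupNormBase::deduce_layout_fwd.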
void GroupNormForward::deduce_layout(
        const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
        TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
    deduce_layout_fwd(data, weight, bias, dst, mean, rstd);
}

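// Pre-execution check for the forward kernel: validate the layouts and make
// sure the caller provides at least the workspace size reported by
// get_workspace_in_bytes().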
void GroupNormForward::check_exec(
        const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
        const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd,
        size_t workspace_in_bytes) {
    check_layout_fwd(data, weight, bias, dst, mean, rstd);
    auto required_workspace_in_bytes =
            get_workspace_in_bytes(data, weight, bias, dst, mean, rstd);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

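// Deduce gradient layouts for the backward pass: ddata mirrors the input
// layout, and dweight/dbias mirror the weight layout.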
void GroupNormBackward::deduce_layout(
        const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight,
        const TensorLayout& mean, const TensorLayout& rstd, TensorLayout& ddata,
        TensorLayout& dweight, TensorLayout& dbias) {
    MEGDNN_MARK_USED_VAR(diff);
    MEGDNN_MARK_USED_VAR(mean);
    MEGDNN_MARK_USED_VAR(rstd);
    ddata = data;
    dweight = weight;
    dbias = weight;
}

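// Pre-execution check for the backward kernel: verify the workspace size,
// require contiguous tensors, and check that ddata matches data and rstd
// matches mean; weight/dweight/dbias are only checked when affine is enabled.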
void GroupNormBackward::check_exec(
        const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight,
        const TensorLayout& mean, const TensorLayout& rstd, const TensorLayout& ddata,
        const TensorLayout& dweight, const TensorLayout& dbias,
        size_t workspace_in_bytes) {
    auto p = param();
    auto required_workspace_in_bytes = get_workspace_in_bytes(
            diff, data, weight, mean, rstd, ddata, dweight, dbias);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);

    megdnn_assert_contiguous(diff);
    megdnn_assert_contiguous(data);
    megdnn_assert_contiguous(mean);
    megdnn_assert_contiguous(rstd);
    megdnn_assert_contiguous(ddata);
    if (p.affine) {
        megdnn_assert_contiguous(weight);
        megdnn_assert_contiguous(dweight);
        megdnn_assert_contiguous(dbias);
    }

    auto errmsg = [&]() {
        return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(data) + ", " +
               megdnn_layout_msg(weight) + ", " + megdnn_layout_msg(mean) + ", " +
               megdnn_layout_msg(rstd) + ", " + megdnn_layout_msg(ddata) + ", " +
               megdnn_layout_msg(dweight) + ", " + megdnn_layout_msg(dbias);
    };
    MEGDNN_MARK_USED_VAR(errmsg);

    megdnn_assert(data.eq_layout(ddata), "%s", errmsg().c_str());
    megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str());
    if (p.affine) {
        megdnn_assert(weight.eq_layout(dweight), "%s", errmsg().c_str());
        megdnn_assert(weight.eq_layout(dbias), "%s", errmsg().c_str());
    }
}

} // namespace megdnn

// vim: syntax=cpp.doxygen