#include "megdnn/oprs.h" #include "src/common/utils.h" #include namespace megdnn { void TileRepeatBase::check_layout_fwd( const TensorLayout& src, const TensorLayout& dst) { auto errmsg = megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " + "times=" + param().times.to_string(); auto errmsg_c = errmsg.c_str(); MEGDNN_MARK_USED_VAR(errmsg_c); megdnn_assert_contiguous(src); megdnn_assert_contiguous(dst); auto expected_ndim = param().times.ndim; megdnn_assert(expected_ndim == src.ndim, "%s", errmsg_c); megdnn_assert(expected_ndim == dst.ndim, "%s", errmsg_c); rep(i, expected_ndim) { megdnn_assert(dst.shape[i] == param().times[i] * src.shape[i], "%s", errmsg_c); } megdnn_assert(src.dtype == dst.dtype); } void TileRepeatBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) { dst.ndim = src.ndim; rep(i, src.ndim) { dst.shape[i] = src.shape[i] * param().times[i]; } dst.dtype = src.dtype; dst.init_contiguous_stride(); check_layout_fwd(src, dst); } size_t TileRepeatBase::get_workspace_in_bytes_fwd( const TensorShape& /* src */, const TensorShape& dst, const TensorShape& times, DType dtype) { size_t nr_workspace = 0; auto nr_reduces = count_not_ones_in_shape(times); if (nr_reduces == 0) { // case 1: no tile/repeat is needed, let alone workspace. nr_workspace = 0; } else if (nr_reduces == 1) { // case 2: only one tile/repeat is needed, so we don't need workspace. nr_workspace = 0; } else if (nr_reduces == 2) { // case 3: two tile/repeats are needed, so we need a single workspace. nr_workspace = 1; } else { // case 4: multiple tile/repeats are needed, so we need two workspace in // an alternate fashion. nr_workspace = 2; } if (nr_workspace == 0) { return 0; } else { WorkspaceBundle workspaces{ nullptr, {nr_workspace, dst.total_nr_elems() * dtype.size()}}; return workspaces.total_size_in_bytes(); } } void TileBase::simplify_shape( const TensorShape& src, const TensorShape& dst, const TensorShape& times, TensorShape& src2, TensorShape& dst2, TensorShape& times2) { size_t n = 0; for (size_t i = 0; i < src.ndim; ++i) { if (times.shape[i] == 1 && n > 0) { src2.shape[n - 1] *= src.shape[i]; dst2.shape[n - 1] *= dst.shape[i]; } else { src2.shape[n] = src.shape[i]; dst2.shape[n] = dst.shape[i]; times2.shape[n] = times.shape[i]; ++n; } } src2.ndim = dst2.ndim = times2.ndim = n; } size_t TileBase::get_workspace_in_bytes_fwd( const TensorLayout& src_, const TensorLayout& dst_) { TensorShape src, dst, times; simplify_shape(src_, dst_, param().times, src, dst, times); return TileRepeatBase::get_workspace_in_bytes_fwd(src, dst, times, src_.dtype); } void TileForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { deduce_layout_fwd(src, dst); } void TileForward::check_exec( const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } void TileBackward::check_exec( const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) { check_layout_fwd(grad, diff); auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } void RepeatBase::simplify_shape( const TensorShape& src, const TensorShape& /* dst */, const TensorShape& times, TensorShape& src2, TensorShape& dst2, TensorShape& times2) { auto n = 0u; size_t i = 0; while (i < times.ndim) { size_t j = i; while (j < times.ndim && times.shape[j] == 1) ++j; // Here: j is times.ndim, or times.shape[j] != 1 if (j < times.ndim) ++j; src2.shape[n] = std::accumulate( src.shape + i, src.shape + j, 1_z, SafeMultiplies()); times2.shape[n] = times.shape[j - 1]; dst2.shape[n] = src2.shape[n] * times2.shape[n]; ++n; i = j; } src2.ndim = dst2.ndim = times2.ndim = n; } size_t RepeatBase::get_workspace_in_bytes_fwd( const TensorLayout& src_, const TensorLayout& dst_) { TensorShape src, dst, times; simplify_shape(src_, dst_, param().times, src, dst, times); return TileRepeatBase::get_workspace_in_bytes_fwd(src, dst, times, src_.dtype); } void RepeatForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { deduce_layout_fwd(src, dst); } void RepeatForward::check_exec( const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } void RepeatBackward::check_exec( const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) { check_layout_fwd(grad, diff); auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } } // namespace megdnn // vim: syntax=cpp.doxygen