#include "megdnn/oprs.h" #include "megdnn/oprs/nn_int.h" #include "src/common/utils.h" namespace megdnn { void BatchConvBiasForward::deduce_dtype( DType src, DType filter, DType /* bias */, DType /* z */, DType& dst) { check_or_deduce_dtype_fwd(src, filter, dst); } void BatchConvBiasForward::deduce_layout( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& /* bias */, const TensorLayout& /* z */, TensorLayout& dst) { TensorLayout non_batch_filter; non_batch_filter.ndim = filter.ndim - 1; non_batch_filter.dtype = filter.dtype; for (size_t i = 0; i < non_batch_filter.ndim; i++) { non_batch_filter[i] = filter[i + 1]; non_batch_filter.stride[i] = filter.stride[i + 1]; } non_batch_filter.format = filter.format; deduce_layout_fwd(src, non_batch_filter, dst); } BatchConvBiasForward::CanonizedFilterMeta BatchConvBiasForward::check_exec( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst, size_t workspace_in_bytes) { megdnn_assert( src.dtype.enumv() == filter.dtype.enumv() && src.dtype.enumv() == DTypeEnum::QuantizedS8, "batch conv only support qint8"); float scale_src = src.dtype.param().scale; float scale_filter = filter.dtype.param().scale; float scale_bias = bias.dtype.param().scale; megdnn_assert( std::abs(scale_src * scale_filter - scale_bias) < 1e-6, "scale_bias is not equal to the product of scale_src and " "scale_filter (scale_src: %f scale_filter: %f scale_bias: %f).", scale_src, scale_filter, scale_bias); TensorLayout non_batch_filter; non_batch_filter.ndim = filter.ndim - 1; non_batch_filter.dtype = filter.dtype; for (size_t i = 0; i < non_batch_filter.ndim; i++) { non_batch_filter[i] = filter[i + 1]; non_batch_filter.stride[i] = filter.stride[i + 1]; } non_batch_filter.format = filter.format; auto ret = check_layout_fwd(src, non_batch_filter, dst); megdnn_assert_contiguous(bias); auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter, bias, z, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); if (bias.ndim != 0) { //! bias.layout == dst.layout failed, no assert information auto check_eq = [](const TensorLayout& bias, const TensorLayout& dst) { if (dst.dtype.category() == DTypeCategory::QUANTIZED) { return bias.eq_shape(dst); } else { return bias.eq_layout(dst); } }; if (check_eq(bias, dst)) return ret; if (param().format == param::BatchConvBias::Format::NCHW4) { megdnn_assert(bias.shape[0] == 1); megdnn_assert( bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", bias.to_string().c_str(), dst.to_string().c_str()); megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[4] == 4); } } if (z.ndim != 0) { megdnn_assert(z.dtype.enumv() == dst.dtype.enumv()); megdnn_assert(z.eq_shape(dst)); } return ret; } } // namespace megdnn // vim: syntax=cpp.doxygen