You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

batch_conv_bias.cpp 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. #include "megdnn/oprs.h"
  2. #include "megdnn/oprs/nn_int.h"
  3. #include "src/common/utils.h"
  4. namespace megdnn {
  5. void BatchConvBiasForward::deduce_dtype(
  6. DType src, DType filter, DType /* bias */, DType /* z */, DType& dst) {
  7. check_or_deduce_dtype_fwd(src, filter, dst);
  8. }
  9. void BatchConvBiasForward::deduce_layout(
  10. const TensorLayout& src, const TensorLayout& filter,
  11. const TensorLayout& /* bias */, const TensorLayout& /* z */,
  12. TensorLayout& dst) {
  13. TensorLayout non_batch_filter;
  14. non_batch_filter.ndim = filter.ndim - 1;
  15. non_batch_filter.dtype = filter.dtype;
  16. for (size_t i = 0; i < non_batch_filter.ndim; i++) {
  17. non_batch_filter[i] = filter[i + 1];
  18. non_batch_filter.stride[i] = filter.stride[i + 1];
  19. }
  20. non_batch_filter.format = filter.format;
  21. deduce_layout_fwd(src, non_batch_filter, dst);
  22. }
  23. BatchConvBiasForward::CanonizedFilterMeta BatchConvBiasForward::check_exec(
  24. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias,
  25. const TensorLayout& z, const TensorLayout& dst, size_t workspace_in_bytes) {
  26. megdnn_assert(
  27. src.dtype.enumv() == filter.dtype.enumv() &&
  28. src.dtype.enumv() == DTypeEnum::QuantizedS8,
  29. "batch conv only support qint8");
  30. float scale_src = src.dtype.param<dtype::QuantizedS8>().scale;
  31. float scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale;
  32. float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale;
  33. megdnn_assert(
  34. std::abs(scale_src * scale_filter - scale_bias) < 1e-6,
  35. "scale_bias is not equal to the product of scale_src and "
  36. "scale_filter (scale_src: %f scale_filter: %f scale_bias: %f).",
  37. scale_src, scale_filter, scale_bias);
  38. TensorLayout non_batch_filter;
  39. non_batch_filter.ndim = filter.ndim - 1;
  40. non_batch_filter.dtype = filter.dtype;
  41. for (size_t i = 0; i < non_batch_filter.ndim; i++) {
  42. non_batch_filter[i] = filter[i + 1];
  43. non_batch_filter.stride[i] = filter.stride[i + 1];
  44. }
  45. non_batch_filter.format = filter.format;
  46. auto ret = check_layout_fwd(src, non_batch_filter, dst);
  47. megdnn_assert_contiguous(bias);
  48. auto required_workspace_in_bytes =
  49. get_workspace_in_bytes(src, filter, bias, z, dst);
  50. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  51. if (bias.ndim != 0) {
  52. //! bias.layout == dst.layout failed, no assert information
  53. auto check_eq = [](const TensorLayout& bias, const TensorLayout& dst) {
  54. if (dst.dtype.category() == DTypeCategory::QUANTIZED) {
  55. return bias.eq_shape(dst);
  56. } else {
  57. return bias.eq_layout(dst);
  58. }
  59. };
  60. if (check_eq(bias, dst))
  61. return ret;
  62. if (param().format == param::BatchConvBias::Format::NCHW4) {
  63. megdnn_assert(bias.shape[0] == 1);
  64. megdnn_assert(
  65. bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
  66. bias.to_string().c_str(), dst.to_string().c_str());
  67. megdnn_assert(bias.shape[2] == 1);
  68. megdnn_assert(bias.shape[3] == 1);
  69. megdnn_assert(bias.shape[4] == 4);
  70. }
  71. }
  72. if (z.ndim != 0) {
  73. megdnn_assert(z.dtype.enumv() == dst.dtype.enumv());
  74. megdnn_assert(z.eq_shape(dst));
  75. }
  76. return ret;
  77. }
  78. } // namespace megdnn
  79. // vim: syntax=cpp.doxygen