You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

deformable_psroi_pooling.py 995 B

12345678910111213141516171819202122232425262728293031323334353637383940
  1. # -*- coding: utf-8 -*-
  2. from ..functional import deformable_psroi_pooling
  3. from .module import Module
  4. class DeformablePSROIPooling(Module):
  5. def __init__(
  6. self,
  7. no_trans,
  8. part_size,
  9. pooled_h,
  10. pooled_w,
  11. sample_per_part,
  12. spatial_scale,
  13. trans_std: float = 0.1,
  14. **kwargs
  15. ):
  16. super().__init__(**kwargs)
  17. self.no_trans = no_trans
  18. self.part_size = part_size
  19. self.pooled_h = pooled_h
  20. self.pooled_w = pooled_w
  21. self.sample_per_part = sample_per_part
  22. self.spatial_scale = spatial_scale
  23. self.trans_std = trans_std
  24. def forward(self, inp, rois, trans):
  25. return deformable_psroi_pooling(
  26. inp,
  27. rois,
  28. trans,
  29. self.no_trans,
  30. self.part_size,
  31. self.pooled_h,
  32. self.pooled_w,
  33. self.sample_per_part,
  34. self.spatial_scale,
  35. self.trans_std,
  36. )