@@ -104,24 +104,27 @@ from .utils.persistent_cache import PersistentCacheOnServer as _PersistentCacheO
from .version import __version__
logger = get_logger(__name__)
ngpus = get_device_count("gpu")
supported_sm_versions = re.findall(r"sm_(\d+)", _get_supported_sm_versions())
for idx in range(ngpus):
prop = get_cuda_device_property(idx)
cur_sm = str(prop.major * 10 + prop.minor)
if not cur_sm in supported_sm_versions:
logger.warning(
"{} with CUDA capability sm_{} is not compatible with the current MegEngine installation. The current MegEngine install supports CUDA {} {}. If you want to use the {} with MegEngine, please check the instructions at https://github.com/MegEngine/MegEngine/blob/master/scripts/cmake-build/BUILD_README.md".format(
prop.name,
cur_sm,
"capabilities" if len(supported_sm_versions) > 1 else "capability",
" ".join(["sm_" + v for v in supported_sm_versions]),
prop.name,
def _check_sm_version():
cur_logger = get_logger(__name__)
ngpus = get_device_count("gpu")
supported_sm_versions = re.findall(r"sm_(\d+)", _get_supported_sm_versions())
for idx in range(ngpus):
prop = get_cuda_device_property(idx)
cur_sm = str(prop.major * 10 + prop.minor)
if cur_sm not in supported_sm_versions:
cur_logger.warning(
"{} with CUDA capability sm_{} is not compatible with the current MegEngine installation. The current MegEngine install supports CUDA {} {}. If you want to use the {} with MegEngine, please check the instructions at https://github.com/MegEngine/MegEngine/blob/master/scripts/cmake-build/BUILD_README.md".format(
prop.name,
cur_sm,
"capabilities" if len(supported_sm_versions) > 1 else "capability",
" ".join(["sm_" + v for v in supported_sm_versions]),
prop.name,
)
)
)
_check_sm_version()
_set_fork_exec_path_for_timed_func(
sys.executable,
os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"),
@@ -16,6 +16,7 @@ from ..core._imperative_rt.core2 import (
adaptive_pool2d_cpp,
apply,
dtype_promotion,
pixel_shuffle_cpp,
)
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core.ops import builtin
@@ -1849,16 +1850,7 @@ def _get_layerPixelShuffle(device, dtype, dim_order):
return layerPixelShuffle
def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
"""
Rearranges elements in a tensor of shape (*, C x r^2, H, W) to a tensor of
shape (*, C, H x r, W x r), where r is an upscale factor, where * is zero
or more batch dimensions.
:param inp: input tensor.
:param upscale_factor: upscale factor of pixel_shuffle.
:return: output tensor.
"""
def layerPixelShuffle_traceable(inp, upscale_factor):
assert upscale_factor > 0, "upscale_factor should be larger than 0"
assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than or equal to 3"
assert (
@@ -1899,6 +1891,19 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
return outvar
def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
"""
Rearranges elements in a tensor of shape `(..., C * r^2, H, W)` to a tensor of
shape `(..., C, H * r, W * r)`, where `r` is the upscale factor and `...` denotes
zero or more batch dimensions.
:param inp: input tensor.
:param upscale_factor: upscale factor of pixel_shuffle.
:return: output tensor.
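Example (a small illustration of the shape rearrangement; the shapes here are chosen only for demonstration):
.. code-block:: python
import numpy as np
import megengine as mge
import megengine.functional as F
x = mge.Tensor(np.zeros((1, 4, 2, 2), dtype="float32"))
out = F.pixel_shuffle(x, 2)
print(out.numpy().shape)  # (1, 1, 4, 4)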
""" | |||
return pixel_shuffle_cpp(inp, upscale_factor, layerPixelShuffle_traceable) | |||
from .quantized import conv_bias_activation # isort:skip | |||
from .loss import * # isort:skip | |||
from .metric import * # isort:skip | |||
@@ -349,6 +349,28 @@ std::optional<ValueRefList> removeAxis_grad_rule( | |||
return imperative::apply(op, inputs); | |||
} | |||
std::optional<ValueRefList> pixelShuffle_grad_rule( | |||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
CustomBackward& backward) { | |||
auto&& pixelShuffle = op.cast_final_safe<PixelShuffle>(); | |||
mgb_assert(inputs.size() == 1); | |||
bool flag = inputs_require_grad[0]; | |||
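// the gradient of PixelShuffle is computed by PixelShuffleBackward with the same upscale factor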
auto&& grad_op = PixelShuffleBackward::make(pixelShuffle.factor);
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
SmallVector<ValueRef> ret(1);
if (grad && flag_) {
ret[0] = imperative::apply(*grad_op_, grad)[0];
}
return ret;
});
maker.finalize();
return imperative::apply(op, inputs);
}
std::optional<ValueRefList> fastpathcopy_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
@@ -382,6 +404,8 @@ struct Init {
RemoveAxis::typeinfo(), removeAxis_grad_rule);
CustomBackward::register_grad_rule(
FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
CustomBackward::register_grad_rule(
PixelShuffle::typeinfo(), pixelShuffle_grad_rule);
}
} _;
@@ -438,6 +438,7 @@ WRAP_FUNC_PY35(batched_matmul_cpp);
WRAP_FUNC_PY35(convert_single_value_cpp);
WRAP_FUNC_PY35(convert_inputs_cpp);
WRAP_FUNC_PY35(astensor1d_cpp);
WRAP_FUNC_PY35(pixel_shuffle_cpp);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -595,6 +596,7 @@ void init_tensor(py::module m) {
MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp),
MGE_PY_INTERFACE(pixel_shuffle_cpp, pixel_shuffle_cpp),
{nullptr, nullptr, 0, nullptr}};
for (auto&& def : method_defs) {
if (def.ml_meth != nullptr) {
@@ -1378,7 +1378,7 @@ py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
} else {
auto&& inp_ndim = get_ndim_safe(inp_hdl);
ndim += inp_ndim.first;
unknown_ndim &= ~inp_ndim.second;
unknown_ndim &= !inp_ndim.second;
}
for (size_t i = 0; i < axis.size(); ++i) {
if (axis[i] < 0) {
@@ -1446,6 +1446,7 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
return ret[0];
}
py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
py::object obj = _expand_args(args);
py::list lis;
@@ -1562,6 +1563,19 @@ py::object _batched_matmul_cpp(
}
}
py::object _pixel_shuffle_cpp(py::handle inp, py::handle val, py::handle func) {
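// fast path: dispatch the dedicated PixelShuffle op when the imperative fast path is enabled and the factor is a plain Python int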
if (enable_fastpath(inp) && PyLong_Check(val.ptr())) {
std::shared_ptr<OpDef> op = PixelShuffle::make(val.cast<int32_t>());
py::object Op = py::cast(op);
PyObject* p[2] = {Op.ptr(), inp.ptr()};
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
return ret[0];
} else {
// fallback to the traceable subgraph implementation
return func(inp, val);
}
}
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _make_shape_tuple(args[0]).release().ptr();
@@ -1632,6 +1646,13 @@ PyObject* adaptive_pool2d_cpp(PyObject* self, PyObject* const* args, size_t narg
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _pixel_shuffle_cpp(args[0], args[1], args[2]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _Const(args[0], args[1], args[2], args[3]).release().ptr();
@@ -40,4 +40,6 @@ PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs
PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs);
} // namespace mgb::imperative::python
@@ -462,3 +462,19 @@ def test_dot():
grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy())
def test_pixel_shuffle():
x = np.random.rand(2, 3, 16, 3, 4).astype("float32")
x = mge.Tensor(x)
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
def f(x):
p = F.pixel_shuffle(x, 2)
return p * p
y = f(x)
grad(y, F.ones_like(y))
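# the gradient of sum(p * p) w.r.t. x is pixel_shuffle's backward applied to 2 * p, i.e. 2 * x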
np.testing.assert_equal(2 * x.numpy(), x.grad.numpy())
@@ -255,6 +255,7 @@ def test_conv_bias_int4():
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
@pytest.mark.require_ngpu(1)
@pytest.mark.skipif(
get_cuda_compute_capability(0) < 61,
reason="does not support int8 when gpu compute capability less than 6.1",
@@ -290,6 +290,7 @@ def test_deformable_ps_roi_pooling():
check_pygraph_dump(fwd, [inp, rois, trans], [result])
@pytest.mark.require_ngpu(1)
@pytest.mark.skipif(
get_cuda_compute_capability(0) < 61,
reason="does not support int8 when gpu compute capability less than 6.1",
@@ -0,0 +1,157 @@
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
using namespace megdnn;
namespace mgb::imperative {
namespace pixel_shuffle {
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op = def.cast_final_safe<PixelShuffle>();
auto&& src = inputs[0];
auto&& layout = src->layout();
mgb_assert(
layout.ndim >= 3,
"the input dimension of pixel_shuffle should be larger than or equal to 3");
size_t idx = layout.ndim - 3;
mgb_assert(
layout[idx] % (op.factor * op.factor) == 0,
"the -3 dimension should be divided by (upscale_factor ** 2)"); | |||
TensorLayout tlayout;
TensorShape tshp; // {N, C, r, r, H, W}
TensorShape vshp; // {..., C, Hr, Wr}
tshp.ndim = 6;
vshp.ndim = layout.ndim;
tshp[0] = 1;
for (size_t i = 0; i < idx; ++i) {
tshp[0] *= layout[i];
vshp[i] = layout[i];
}
tshp[1] = layout[idx] / (op.factor * op.factor);
tshp[2] = tshp[3] = op.factor;
tshp[4] = layout[idx + 1];
tshp[5] = layout[idx + 2];
vshp[idx] = tshp[1];
vshp[idx + 1] = layout[idx + 1] * op.factor;
vshp[idx + 2] = layout[idx + 2] * op.factor;
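// view the input as (N, C, r, r, H, W) and permute it to (N, C, H, r, W, r); relayouting and collapsing the trailing axes then yields the (..., C, H*r, W*r) output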
tlayout = layout.reshape(tshp).dimshuffle({0, 1, 4, 2, 5, 3});
TensorPtr out = Tensor::make(src->blob(), src->offset(), tlayout);
out->to_contiguous_inplace(); // relayout
tlayout = out->layout().reshape(vshp);
return {Tensor::make(out->blob(), out->offset(), tlayout)};
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op = def.cast_final_safe<PixelShuffle>();
mgb_assert(op.factor > 0, "upscale_factor should be larger than 0");
auto&& src = inputs[0];
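// the input shape is not known yet: return an empty layout and mark the result as not validated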
if (src.layout.ndim == 0) {
return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
}
mgb_assert(
src.layout.ndim >= 3,
"the input dimension of pixel_shuffle should be larger than or equal to 3");
size_t idx = src.layout.ndim - 3;
mgb_assert(
src.layout[idx] % (op.factor * op.factor) == 0,
"the -3 dimension should be divided by (upscale_factor ** 2)"); | |||
TensorShape tshp;
tshp.ndim = src.layout.ndim;
for (size_t i = 0; i < idx; ++i) {
tshp[i] = src.layout[i];
}
tshp[idx] = src.layout[idx] / (op.factor * op.factor);
tshp[idx + 1] = src.layout[idx + 1] * op.factor;
tshp[idx + 2] = src.layout[idx + 2] * op.factor;
return {{{TensorLayout(tshp, src.layout.dtype), src.comp_node}}, true};
}
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
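// the reshape-based implementation above requires a contiguous input layout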
layout_checker[0] = [](const TensorLayout& layout) {
return layout.is_contiguous();
};
return layout_checker;
}
OP_TRAIT_REG(PixelShuffle, PixelShuffle)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.get_input_layout_constraint(get_input_layout_constraint)
.fallback();
} // namespace pixel_shuffle
namespace pixel_shuffle_backward {
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op = def.cast_final_safe<PixelShuffleBackward>();
auto&& src = inputs[0];
auto&& layout = src->layout();
size_t idx = layout.ndim - 3;
TensorLayout tlayout;
TensorShape tshp; // {N, C, H, r, W, r}
TensorShape vshp; // {..., Cr^2, H, W}
tshp.ndim = 6;
vshp.ndim = layout.ndim;
tshp[0] = 1;
for (size_t i = 0; i < idx; ++i) {
tshp[0] *= layout[i];
vshp[i] = layout[i];
}
tshp[1] = layout[idx];
tshp[3] = tshp[5] = op.factor;
tshp[2] = layout[idx + 1] / op.factor;
tshp[4] = layout[idx + 2] / op.factor;
vshp[idx] = tshp[1] * op.factor * op.factor;
vshp[idx + 1] = tshp[2];
vshp[idx + 2] = tshp[4];
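// view the incoming gradient as (N, C, H, r, W, r) and permute it back to (N, C, r, r, H, W); collapsing the trailing axes recovers the (..., C*r^2, H, W) layout of the forward input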
tlayout = layout.reshape(tshp).dimshuffle({0, 1, 3, 5, 2, 4});
TensorPtr out = Tensor::make(src->blob(), src->offset(), tlayout);
out->to_contiguous_inplace(); // relayout
tlayout = out->layout().reshape(vshp);
return {Tensor::make(out->blob(), out->offset(), tlayout)};
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op = def.cast_final_safe<PixelShuffleBackward>();
auto&& src = inputs[0];
if (src.layout.ndim == 0) {
return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
}
size_t idx = src.layout.ndim - 3;
TensorShape tshp;
tshp.ndim = src.layout.ndim;
for (size_t i = 0; i < idx; ++i) {
tshp[i] = src.layout[i];
}
tshp[idx] = src.layout[idx] * op.factor * op.factor;
tshp[idx + 1] = src.layout[idx + 1] / op.factor;
tshp[idx + 2] = src.layout[idx + 2] / op.factor;
return {{{TensorLayout(tshp, src.layout.dtype), src.comp_node}}, true};
}
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
layout_checker[0] = [](const TensorLayout& layout) {
return layout.is_contiguous();
};
return layout_checker;
}
OP_TRAIT_REG(PixelShuffleBackward, PixelShuffleBackward)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.get_input_layout_constraint(get_input_layout_constraint)
.fallback();
} // namespace pixel_shuffle_backward
} // namespace mgb::imperative
@@ -435,6 +435,18 @@ def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [CheckNonFiniteParam]>;
def FastpathCopy: MgbHashableOp<"FastpathCopy">;
def PixelShuffle: MgbHashableOp<"PixelShuffle"> {
let extraArguments = (ins
MgbI32Attr:$factor
);
}
def PixelShuffleBackward: MgbHashableOp<"PixelShuffleBackward"> {
let extraArguments = (ins
MgbI32Attr:$factor
);
}
def ExternOpr: MgbHashableOp<"ExternOpr"> {
let extraArguments = (ins
MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$output_shapes,