GitOrigin-RevId: 29e069fb23
tags/v1.3.0
@@ -8,6 +8,7 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
# pylint: disable=redefined-builtin | # pylint: disable=redefined-builtin | ||||
from .elemwise import * | from .elemwise import * | ||||
from .img_proc import * | |||||
from .math import * | from .math import * | ||||
from .nn import * | from .nn import * | ||||
from .tensor import * | from .tensor import * | ||||
@@ -0,0 +1,50 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
from ..core._imperative_rt.core2 import apply | |||||
from ..core.ops import builtin | |||||
from ..tensor import Tensor | |||||
__all__ = [ | |||||
"cvt_color", | |||||
] | |||||
def cvt_color(inp: Tensor, mode: str = ""): | |||||
r""" | |||||
Convert images from one format to another | |||||
:param inp: input images. | |||||
:param mode: format mode. | |||||
:return: convert result. | |||||
Examples: | |||||
.. testcode:: | |||||
import numpy as np | |||||
import megengine as mge | |||||
import megengine.functional as F | |||||
x = mge.tensor(np.array([[[[-0.58675045, 1.7526233, 0.10702174]]]]).astype(np.float32)) | |||||
y = F.img_proc.cvt_color(x, mode="RGB2GRAY") | |||||
print(y.numpy()) | |||||
Outputs: | |||||
.. testoutput:: | |||||
[[[[0.86555195]]]] | |||||
""" | |||||
assert mode in builtin.CvtColor.Mode.__dict__, "unspport mode for cvt_color" | |||||
mode = getattr(builtin.CvtColor.Mode, mode) | |||||
assert isinstance(mode, builtin.CvtColor.Mode) | |||||
op = builtin.CvtColor(mode=mode) | |||||
(out,) = apply(op, inp) | |||||
return out |
@@ -704,3 +704,14 @@ def test_argmxx_on_inf(): | |||||
assert all(run_argmax() >= 0) | assert all(run_argmax() >= 0) | ||||
assert all(run_argmin() >= 0) | assert all(run_argmin() >= 0) | ||||
def test_cvt_color(): | |||||
def rgb2gray(rgb): | |||||
return np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) | |||||
inp = np.random.randn(3, 3, 3, 3).astype(np.float32) | |||||
out = np.expand_dims(rgb2gray(inp), 3).astype(np.float32) | |||||
x = tensor(inp) | |||||
y = F.img_proc.cvt_color(x, mode="RGB2GRAY") | |||||
np.testing.assert_allclose(y.numpy(), out, atol=1e-5) |
@@ -0,0 +1,33 @@ | |||||
/** | |||||
* \file imperative/src/impl/ops/img_proc.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "megbrain/opr/imgproc.h" | |||||
#include "../op_trait.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
namespace { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const CvtColor&>(def); | |||||
mgb_assert(inputs.size() == 1); | |||||
return opr::CvtColor::make(inputs[0], op.param()); | |||||
} | |||||
OP_TRAIT_REG(CvtColor, CvtColor) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} | |||||
} | |||||
} |
@@ -254,4 +254,6 @@ def TensorRTRuntime: MgbHashableOp<"TensorRTRuntime"> { | |||||
); | ); | ||||
} | } | ||||
def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; | |||||
#endif // MGB_OPS | #endif // MGB_OPS |