GitOrigin-RevId: c35a219e52
tags/v1.7.2.m1
@@ -1344,7 +1344,7 @@ protected: | |||||
* \brief check whether input contains inf or nan value. | * \brief check whether input contains inf or nan value. | ||||
*/ | */ | ||||
class CheckNonFinite : public OperatorBase { | class CheckNonFinite : public OperatorBase { | ||||
DEF_OPR_PARAM(Empty); | |||||
DEF_OPR_PARAM(CheckNonFinite); | |||||
DEF_OPR_IMPL(CheckNonFinite, OperatorBase, -1, 1); | DEF_OPR_IMPL(CheckNonFinite, OperatorBase, -1, 1); | ||||
size_t m_size = 0; | size_t m_size = 0; | ||||
@@ -1176,6 +1176,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o | |||||
) | ) | ||||
pdef('Fill').add_fields('float32', 'value', '0') | pdef('Fill').add_fields('float32', 'value', '0') | ||||
pdef('CheckNonFinite').add_fields('float32', 'scale', '1.0') | |||||
PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), | PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), | ||||
Doc('REFLECT = 1', 'fedcba|abcdefgh|hgfedcb'), | Doc('REFLECT = 1', 'fedcba|abcdefgh|hgfedcb'), | ||||
@@ -156,37 +156,6 @@ struct MaxOp<src_ctype, dst_ctype, dt_float32> { | |||||
: INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | ||||
}; | }; | ||||
template <typename src_ctype, typename index_ctype, typename dst_ctype, typename wtype_> | |||||
struct CheckNonFiniteOp { | |||||
typedef wtype_ wtype; | |||||
const wtype INIT; | |||||
RefPtr* srcs; | |||||
RefPtr srcs_total_nr_elems; | |||||
RefPtr dst; | |||||
const size_t B; | |||||
wtype read(uint32_t idx) { | |||||
size_t x = idx / B; | |||||
size_t y = idx % B; | |||||
if (y < srcs_total_nr_elems.ptr<index_ctype>()[x]) { | |||||
RefPtr src = srcs[x]; | |||||
return !std::isfinite(src.ptr<src_ctype>()[y]); | |||||
} | |||||
return 0; | |||||
} | |||||
void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
static wtype apply(wtype lhs, wtype rhs) { return lhs | rhs; } | |||||
CheckNonFiniteOp( | |||||
RefPtr* srcs, const RefPtr& srcs_total_nr_elems, const RefPtr& dst, | |||||
size_t B) | |||||
: INIT(wtype(0)), | |||||
srcs(srcs), | |||||
srcs_total_nr_elems(srcs_total_nr_elems), | |||||
dst(dst), | |||||
B(B) {} | |||||
}; | |||||
void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis); | void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis); | ||||
} // namespace reduce | } // namespace reduce | ||||
@@ -194,6 +194,7 @@ struct CheckNonFiniteOp { | |||||
index_ctype* srcs_total_nr_elems; | index_ctype* srcs_total_nr_elems; | ||||
dst_ctype* dst; | dst_ctype* dst; | ||||
const size_t B; | const size_t B; | ||||
const src_ctype scale; | |||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { | MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { | ||||
size_t x = idx / B; | size_t x = idx / B; | ||||
@@ -204,6 +205,8 @@ struct CheckNonFiniteOp { | |||||
#else | #else | ||||
wtype val = std::isfinite(srcs[x][y]); | wtype val = std::isfinite(srcs[x][y]); | ||||
#endif | #endif | ||||
if (val) | |||||
srcs[x][y] *= scale; | |||||
return !val; | return !val; | ||||
} | } | ||||
return 0; | return 0; | ||||
@@ -214,12 +217,13 @@ struct CheckNonFiniteOp { | |||||
} | } | ||||
MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp( | MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp( | ||||
src_ctype** srcs, index_ctype* srcs_total_nr_elems, dst_ctype* dst, | src_ctype** srcs, index_ctype* srcs_total_nr_elems, dst_ctype* dst, | ||||
size_t B) | |||||
size_t B, src_ctype scale) | |||||
: INIT(wtype(0)), | : INIT(wtype(0)), | ||||
srcs(srcs), | srcs(srcs), | ||||
srcs_total_nr_elems(srcs_total_nr_elems), | srcs_total_nr_elems(srcs_total_nr_elems), | ||||
dst(dst), | dst(dst), | ||||
B(B) {} | |||||
B(B), | |||||
scale(scale) {} | |||||
}; | }; | ||||
} // namespace device_reduce | } // namespace device_reduce | ||||
@@ -97,7 +97,7 @@ void CheckNonFiniteImpl::exec( | |||||
workspace_gpu.total_size_in_bytes())), | workspace_gpu.total_size_in_bytes())), | ||||
1, m_size * total_nr_elems_max, 1, stream, | 1, m_size * total_nr_elems_max, 1, stream, | ||||
Op(srcs_gpu, srcs_total_nr_elems_gpu, dst.ptr<dt_int32>(), | Op(srcs_gpu, srcs_total_nr_elems_gpu, dst.ptr<dt_int32>(), | ||||
total_nr_elems_max)); | |||||
total_nr_elems_max, param().scale)); | |||||
} | } | ||||
} // namespace cuda | } // namespace cuda | ||||
@@ -19,7 +19,7 @@ using namespace megdnn; | |||||
#define wtype dt_int32 | #define wtype dt_int32 | ||||
void reduce_fwd(const TensorNDArray& srcs, wtype* dptr) { | |||||
void reduce_fwd(const TensorNDArray& srcs, wtype* dptr, dt_float32 scale) { | |||||
dptr[0] = 0; | dptr[0] = 0; | ||||
for (auto src : srcs) { | for (auto src : srcs) { | ||||
auto sptr = src.ptr<dt_float32>(); | auto sptr = src.ptr<dt_float32>(); | ||||
@@ -31,6 +31,8 @@ void reduce_fwd(const TensorNDArray& srcs, wtype* dptr) { | |||||
return func(l, mid) | func(mid, r); | return func(l, mid) | func(mid, r); | ||||
} else { | } else { | ||||
auto val = std::isfinite(sptr[l]); | auto val = std::isfinite(sptr[l]); | ||||
if (val) | |||||
sptr[l] *= scale; | |||||
return static_cast<wtype>(!val); | return static_cast<wtype>(!val); | ||||
} | } | ||||
}; | }; | ||||
@@ -47,9 +49,9 @@ void CheckNonFiniteImpl::exec( | |||||
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst, | _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst, | ||||
_megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
check_exec(srcs, dst, workspace.size); | check_exec(srcs, dst, workspace.size); | ||||
float scale = param().scale; | |||||
auto handle = static_cast<HandleImpl*>(this->handle()); | auto handle = static_cast<HandleImpl*>(this->handle()); | ||||
MEGDNN_DISPATCH_CPU_KERN(handle, reduce_fwd(srcs, dst.ptr<dt_int32>())); | |||||
MEGDNN_DISPATCH_CPU_KERN(handle, reduce_fwd(srcs, dst.ptr<dt_int32>(), scale)); | |||||
} | } | ||||
} // namespace naive | } // namespace naive | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -128,28 +128,28 @@ class GradScaler: | |||||
grad_tensors: Tensors needed to unscale grads. Should be all tensors | grad_tensors: Tensors needed to unscale grads. Should be all tensors | ||||
that are affected by ``target`` tensor in GradManager's backward. | that are affected by ``target`` tensor in GradManager's backward. | ||||
""" | """ | ||||
# to support tracing, _check_gradients should be applied to every grad. | |||||
if self._check_gradients([x.grad for x in grad_tensors]): | |||||
self._found_non_finite = True | |||||
if self._found_non_finite: | |||||
for tensor in grad_tensors: | |||||
if tensor is None or getattr(tensor, "grad", None) is None: | |||||
continue | |||||
tensor.grad = None | |||||
else: | |||||
if self.growth_interval == 0: | |||||
# use float64 for better precision | # use float64 for better precision | ||||
inv_scale = Tensor(1.0 / self.scale_factor) | inv_scale = Tensor(1.0 / self.scale_factor) | ||||
for tensor in grad_tensors: | for tensor in grad_tensors: | ||||
if tensor is None or getattr(tensor, "grad", None) is None: | if tensor is None or getattr(tensor, "grad", None) is None: | ||||
continue | continue | ||||
tensor.grad *= inv_scale | tensor.grad *= inv_scale | ||||
return self | |||||
# to support tracing, _check_gradients should be applied to every grad. | |||||
if self._check_gradients( | |||||
[x.grad for x in grad_tensors], 1.0 / self.scale_factor | |||||
): | |||||
self._found_non_finite = True | |||||
for tensor in grad_tensors: | |||||
if tensor is None or getattr(tensor, "grad", None) is None: | |||||
continue | |||||
tensor.grad = None | |||||
return self | return self | ||||
def _check_gradients(self, grad): | |||||
if self.growth_interval == 0: | |||||
return False | |||||
return _check_non_finite(grad) | |||||
def _check_gradients(self, grad, scale): | |||||
return _check_non_finite(grad, scale) | |||||
def update(self, new_scale: float = None): | def update(self, new_scale: float = None): | ||||
r"""Update the scale factor according to whether encountered overflow grad. | r"""Update the scale factor according to whether encountered overflow grad. | ||||
@@ -1183,7 +1183,7 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: | |||||
return U, sigma, V | return U, sigma, V | ||||
def _check_non_finite(inps: Iterable[Tensor]) -> Tensor: | |||||
def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor: | |||||
r"""Check whether input contains infinite or nan value. | r"""Check whether input contains infinite or nan value. | ||||
Args: | Args: | ||||
@@ -1192,7 +1192,11 @@ def _check_non_finite(inps: Iterable[Tensor]) -> Tensor: | |||||
Returns: | Returns: | ||||
a int32 scalar tensor, 0 for False and 1 for True. | a int32 scalar tensor, 0 for False and 1 for True. | ||||
""" | """ | ||||
op = builtin.CheckNonFinite() | |||||
(oup,) = apply(op, *inps) | |||||
oup._setscalar() | |||||
return oup | |||||
op = builtin.CheckNonFinite(scale=scale) | |||||
oups = apply(op, *inps) | |||||
out = oups[-1] | |||||
for i in range(len(inps)): | |||||
inps[i]._reset(oups[i]) | |||||
out._setscalar() | |||||
return out |
@@ -191,17 +191,21 @@ def test_sum_neg_axis(): | |||||
def test_non_finite(): | def test_non_finite(): | ||||
shape = (32, 3, 32, 32) | shape = (32, 3, 32, 32) | ||||
data1 = np.random.random(shape).astype(np.float32) | |||||
data2 = np.random.random(shape).astype(np.float32) | |||||
rst = F.math._check_non_finite([tensor(data1), tensor(data2)]) | |||||
data = [] | |||||
for i in range(2): | |||||
data.append(np.random.random(shape).astype(np.float32)) | |||||
tensorList = [tensor(x) for x in data] | |||||
rst = F.math._check_non_finite(tensorList, 0.7) | |||||
np.testing.assert_equal(rst.numpy(), [0]) | np.testing.assert_equal(rst.numpy(), [0]) | ||||
for i in range(len(tensorList)): | |||||
np.testing.assert_allclose(tensorList[i].numpy() / 0.7, data[i], rtol=1e-6) | |||||
data2[0][0][0][0] = float("inf") | |||||
rst = F.math._check_non_finite([tensor(data1), tensor(data2)]) | |||||
data[1][0][0][0][0] = float("inf") | |||||
rst = F.math._check_non_finite([tensor(x) for x in data], 0.7) | |||||
np.testing.assert_equal(rst.numpy(), [1]) | np.testing.assert_equal(rst.numpy(), [1]) | ||||
data2[0][0][0][0] = float("nan") | |||||
rst = F.math._check_non_finite([tensor(data1), tensor(data2)]) | |||||
data[1][0][0][0][0] = float("nan") | |||||
rst = F.math._check_non_finite([tensor(x) for x in data], 0.7) | |||||
np.testing.assert_equal(rst.numpy(), [1]) | np.testing.assert_equal(rst.numpy(), [1]) | ||||
@@ -17,44 +17,62 @@ namespace mgb { | |||||
namespace imperative { | namespace imperative { | ||||
namespace check_non_finite { | namespace check_non_finite { | ||||
SymbolVar apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
SymbolVarArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = def.cast_final_safe<CheckNonFinite>(); | auto&& op = def.cast_final_safe<CheckNonFinite>(); | ||||
OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
return opr::CheckNonFinite::make(inputs, {}, config); | |||||
return opr::CheckNonFinite::make(inputs, op.param(), config); | |||||
} | } | ||||
SmallVector<TensorPtr> apply_on_physical_tensor( | SmallVector<TensorPtr> apply_on_physical_tensor( | ||||
const OpDef& def, const SmallVector<TensorPtr>& inputs) { | const OpDef& def, const SmallVector<TensorPtr>& inputs) { | ||||
size_t size = inputs.size(); | size_t size = inputs.size(); | ||||
auto dest = Tensor::make( | |||||
auto&& op = def.cast_final_safe<CheckNonFinite>(); | |||||
SmallVector<TensorPtr> outputs(size + 1); | |||||
outputs[size] = Tensor::make( | |||||
TensorLayout(TensorShape({1}), dtype::Int32()), inputs[0]->comp_node()); | TensorLayout(TensorShape({1}), dtype::Int32()), inputs[0]->comp_node()); | ||||
auto dest = outputs[size]; | |||||
auto cn = dest->comp_node(); | auto cn = dest->comp_node(); | ||||
auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::CheckNonFinite>(cn); | auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::CheckNonFinite>(cn); | ||||
size_t wk_size = 0; | size_t wk_size = 0; | ||||
SmallVector<megdnn::TensorND> srcs(size); | SmallVector<megdnn::TensorND> srcs(size); | ||||
// copy an outputs to the dnn for inplace | |||||
for (size_t i = 0; i < size; ++i) { | for (size_t i = 0; i < size; ++i) { | ||||
srcs[i] = inputs[i]->dev_tensor().as_megdnn(); | |||||
outputs[i] = Tensor::make(inputs[i]->layout(), inputs[0]->comp_node()); | |||||
outputs[i]->dev_tensor().copy_from_fixlayout(inputs[i]->dev_tensor()); | |||||
srcs[i] = outputs[i]->dev_tensor().as_megdnn(); | |||||
} | } | ||||
megdnn::CheckNonFinite::Param param({op.scale}); | |||||
dnn_opr->param() = param; | |||||
wk_size = dnn_opr->get_workspace_in_bytes(srcs, dest->layout()); | wk_size = dnn_opr->get_workspace_in_bytes(srcs, dest->layout()); | ||||
auto wk = Blob::make(cn, wk_size); | auto wk = Blob::make(cn, wk_size); | ||||
megdnn::Workspace dnn_wk(wk->storage().get(), wk_size); | megdnn::Workspace dnn_wk(wk->storage().get(), wk_size); | ||||
dnn_opr->exec(srcs, dest->dev_tensor().as_megdnn(), dnn_wk); | dnn_opr->exec(srcs, dest->dev_tensor().as_megdnn(), dnn_wk); | ||||
return {dest}; | |||||
return outputs; | |||||
} | } | ||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | ||||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | ||||
SmallVector<LogicalTensorDesc> dests(1); | |||||
dests[0].comp_node = inputs[0].comp_node; | |||||
dests[0].layout = TensorLayout(TensorShape({1}), dtype::Int32()); | |||||
size_t size = inputs.size(); | |||||
SmallVector<LogicalTensorDesc> dests(size + 1); | |||||
for (size_t i = 0; i < size; ++i) { | |||||
dests[i].comp_node = inputs[i].comp_node; | |||||
dests[i].layout = inputs[i].layout; | |||||
} | |||||
dests[size].comp_node = inputs[0].comp_node; | |||||
dests[size].layout = TensorLayout(TensorShape({1}), dtype::Int32()); | |||||
return {dests, true}; | return {dests, true}; | ||||
} | } | ||||
SmallVector<LogicalTensorDesc> infer_output_attrs( | SmallVector<LogicalTensorDesc> infer_output_attrs( | ||||
const OpDef& def, const SmallVector<TensorPtr>& inputs) { | const OpDef& def, const SmallVector<TensorPtr>& inputs) { | ||||
SmallVector<LogicalTensorDesc> dests(1); | |||||
dests[0].comp_node = inputs[0]->comp_node(); | |||||
dests[0].layout = TensorLayout(TensorShape({1}), dtype::Int32()); | |||||
size_t size = inputs.size(); | |||||
SmallVector<LogicalTensorDesc> dests(size + 1); | |||||
for (size_t i = 0; i < size; ++i) { | |||||
dests[i].comp_node = inputs[i]->comp_node(); | |||||
dests[i].layout = inputs[i]->layout(); | |||||
} | |||||
dests[size].comp_node = inputs[0]->comp_node(); | |||||
dests[size].layout = TensorLayout(TensorShape({1}), dtype::Int32()); | |||||
return dests; | return dests; | ||||
} | } | ||||
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | ||||
@@ -397,7 +397,7 @@ def MagicMindRuntime: MgbHashableOp<"MagicMindRuntime"> { | |||||
def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; | def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; | ||||
def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [EmptyParam]>; | |||||
def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [CheckNonFiniteParam]>; | |||||
def FastpathCopy: MgbHashableOp<"FastpathCopy">; | def FastpathCopy: MgbHashableOp<"FastpathCopy">; | ||||
@@ -487,39 +487,60 @@ CheckNonFinite::CheckNonFinite( | |||||
const VarNodeArrayView& inp, const Param& param, | const VarNodeArrayView& inp, const Param& param, | ||||
const OperatorNodeConfig& config) | const OperatorNodeConfig& config) | ||||
: Super(OperatorNodeBaseCtorParam{ | : Super(OperatorNodeBaseCtorParam{ | ||||
inp[0]->owner_graph(), config, "check_non_finite", inp}) { | |||||
inp[0]->owner_graph(), config, "check_non_finite", inp}), | |||||
m_scale(param.scale) { | |||||
mgb_assert(!inp.empty()); | mgb_assert(!inp.empty()); | ||||
for (auto&& i : inp) { | for (auto&& i : inp) { | ||||
add_input({i}); | add_input({i}); | ||||
add_output(None) | |||||
->dtype(dtype::Float32()) | |||||
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | |||||
} | } | ||||
add_output(None)->dtype(dtype::Int32()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | add_output(None)->dtype(dtype::Int32()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | ||||
cg::add_workspace_output(this); | cg::add_workspace_output(this); | ||||
} | } | ||||
SymbolVar CheckNonFinite::make( | |||||
SymbolVarArray CheckNonFinite::make( | |||||
const VarNodeArrayView& inp, const Param& param, | const VarNodeArrayView& inp, const Param& param, | ||||
const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
mgb_assert(!inp.empty()); | mgb_assert(!inp.empty()); | ||||
intl::BatchedDTypePromotion dtp{inp}; | intl::BatchedDTypePromotion dtp{inp}; | ||||
return SymbolVar{inp[0]}.insert_single_output_opr<CheckNonFinite>( | |||||
dtp.get_vars(), param, config); | |||||
auto outputs = | |||||
inp[0]->owner_graph() | |||||
->insert_opr(std::make_unique<CheckNonFinite>(inp, param, config)) | |||||
->output(); | |||||
mgb_assert(outputs.size() == inp.size() + 2); | |||||
SymbolVarArray ret(outputs.size() - 1); | |||||
for (size_t i = 0; i < ret.size(); ++i) | |||||
ret[i] = outputs[i]; | |||||
return ret; | |||||
} | } | ||||
void CheckNonFinite::scn_do_execute() { | void CheckNonFinite::scn_do_execute() { | ||||
megdnn::TensorNDArray inp_arr(input().size()); | |||||
for (size_t i = 0; i < input().size(); ++i) { | |||||
inp_arr[i] = input()[i]->dev_tensor().as_megdnn(); | |||||
size_t size = input().size(); | |||||
megdnn::TensorNDArray oup_arr(size); | |||||
// copy an outputs to the dnn for inplace | |||||
for (size_t i = 0; i < size; ++i) { | |||||
oup_arr[i] = output(i) | |||||
->dev_tensor() | |||||
.copy_from_fixlayout(input(i)->dev_tensor()) | |||||
.as_megdnn(); | |||||
} | } | ||||
megdnn_opr()->param().scale = m_scale; | |||||
megdnn_opr()->exec( | megdnn_opr()->exec( | ||||
inp_arr, output(0)->dev_tensor().as_megdnn(), | |||||
intl::get_megdnn_workspace_from_var(output(1))); | |||||
oup_arr, output(size)->dev_tensor().as_megdnn(), | |||||
intl::get_megdnn_workspace_from_var(output(size + 1))); | |||||
} | } | ||||
void CheckNonFinite::init_output_static_infer_desc() { | void CheckNonFinite::init_output_static_infer_desc() { | ||||
using namespace cg::static_infer; | using namespace cg::static_infer; | ||||
auto&& mgr = owner_graph()->static_infer_manager(); | auto&& mgr = owner_graph()->static_infer_manager(); | ||||
size_t size = input().size(); | |||||
for (size_t i = 0; i < size; ++i) { | |||||
mgr.register_shape_infer(output(i), ShapeInferDesc::make_identity(input(i))); | |||||
} | |||||
auto infer_oshp = [](TensorShape& dest, const InpVal& iv) { | auto infer_oshp = [](TensorShape& dest, const InpVal& iv) { | ||||
TensorLayout dst; | TensorLayout dst; | ||||
dst.shape[0] = 1; | dst.shape[0] = 1; | ||||
@@ -532,7 +553,7 @@ void CheckNonFinite::init_output_static_infer_desc() { | |||||
DepVal deps; | DepVal deps; | ||||
for (auto i : input()) | for (auto i : input()) | ||||
deps.push_back({i, DepType::SHAPE}); | deps.push_back({i, DepType::SHAPE}); | ||||
mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_oshp}); | |||||
mgr.register_shape_infer(output(size), {SourceType::DEP, deps, infer_oshp}); | |||||
auto infer_wk = [this](TensorShape& dest, const InpVal& inp) { | auto infer_wk = [this](TensorShape& dest, const InpVal& inp) { | ||||
dest.ndim = 1; | dest.ndim = 1; | ||||
@@ -541,10 +562,11 @@ void CheckNonFinite::init_output_static_infer_desc() { | |||||
inp_arr[i] = {NULL, {inp.val.at(i).shape(), input(0)->dtype()}}; | inp_arr[i] = {NULL, {inp.val.at(i).shape(), input(0)->dtype()}}; | ||||
} | } | ||||
dest.shape[0] = megdnn_opr()->get_workspace_in_bytes( | dest.shape[0] = megdnn_opr()->get_workspace_in_bytes( | ||||
inp_arr, {output(0)->shape(), output(0)->dtype()}); | |||||
inp_arr, {output(input().size() + 1)->shape(), | |||||
output(input().size() + 1)->dtype()}); | |||||
return true; | return true; | ||||
}; | }; | ||||
mgr.register_shape_infer(output(1), {SourceType::DEP, deps, infer_wk}); | |||||
mgr.register_shape_infer(output(size + 1), {SourceType::DEP, deps, infer_wk}); | |||||
} | } | ||||
void CheckNonFinite::add_input_layout_constraint() { | void CheckNonFinite::add_input_layout_constraint() { | ||||
@@ -56,7 +56,16 @@ struct OprMaker<opr::TopK, 2> { | |||||
}; | }; | ||||
template <> | template <> | ||||
struct OprMaker<opr::CheckNonFinite, 0> : public OprMakerVariadic<opr::CheckNonFinite> { | |||||
struct OprMaker<opr::CheckNonFinite, 0> { | |||||
using Opr = opr::CheckNonFinite; | |||||
using Param = Opr::Param; | |||||
static cg::OperatorNodeBase* make( | |||||
const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph, | |||||
const OperatorNodeConfig& config) { | |||||
MGB_MARK_USED_VAR(graph); | |||||
auto out = Opr::make(inputs, param, config); | |||||
return out[0].node()->owner_opr(); | |||||
} | |||||
}; | }; | ||||
} // namespace serialization | } // namespace serialization | ||||
@@ -183,18 +183,19 @@ public: | |||||
const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||
}; | }; | ||||
MGB_DEFINE_OPR_CLASS(CheckNonFinite, intl::CheckNonFiniteBase) //{ | |||||
void scn_do_execute() override; | |||||
void init_output_static_infer_desc() override; | |||||
void add_input_layout_constraint() override; | |||||
MGB_DEFINE_OPR_CLASS(CheckNonFinite, intl::CheckNonFiniteBase) // { | |||||
void scn_do_execute() override; | |||||
void init_output_static_infer_desc() override; | |||||
void add_input_layout_constraint() override; | |||||
float m_scale = 1; | |||||
public: | public: | ||||
MGE_WIN_DECLSPEC_FUC CheckNonFinite( | |||||
const VarNodeArrayView& inp, const Param& param, | |||||
const OperatorNodeConfig& config); | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||||
const VarNodeArrayView& inp, const Param& param = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
MGE_WIN_DECLSPEC_FUC CheckNonFinite( | |||||
const VarNodeArrayView& inp, const Param& param, | |||||
const OperatorNodeConfig& config); | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||||
const VarNodeArrayView& inp, const Param& param = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
}; | }; | ||||
} // namespace opr | } // namespace opr | ||||
@@ -115,6 +115,7 @@ union OperatorParam { | |||||
param.SlidingWindowTranspose = 81, | param.SlidingWindowTranspose = 81, | ||||
param.Padding = 82, | param.Padding = 82, | ||||
param.ShuffleRNG = 83, | param.ShuffleRNG = 83, | ||||
param.CheckNonFinite = 84, | |||||
} | } | ||||
table Operator { | table Operator { | ||||