GitOrigin-RevId: 82833f41d9
release-1.1
@@ -179,6 +179,11 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||||
add_enum_alias('Format', 'ConvolutionV0') | add_enum_alias('Format', 'ConvolutionV0') | ||||
) | ) | ||||
(pdef('AdaptivePooling'). | |||||
add_enum_alias('Mode', 'Pooling'). | |||||
add_enum_alias('Format', 'ConvolutionV0') | |||||
) | |||||
(pdef('LRN', | (pdef('LRN', | ||||
'see ImageNet Classification with Deep Convolutional Neural Networks for' | 'see ImageNet Classification with Deep Convolutional Neural Networks for' | ||||
' meaning of the fields'). | ' meaning of the fields'). | ||||
@@ -0,0 +1,148 @@ | |||||
/** | |||||
* \file src/opr/impl/dnn/adaptive_pooling.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||||
#include "../internal/megdnn_opr_wrapper.inl" | |||||
#include "megbrain/graph/grad_impl.h" | |||||
#include "megbrain/opr/utility.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
#include "megdnn/oprs/nn.h" | |||||
using namespace mgb; | |||||
using namespace opr; | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(AdaptivePoolingForward); | |||||
AdaptivePoolingForward::AdaptivePoolingForward(VarNode* src, VarNode* out_shape, | |||||
const Param& param, | |||||
const OperatorNodeConfig& config) | |||||
: Super(OperatorNodeBaseCtorParam{src->owner_graph(), | |||||
config, | |||||
"adaptive_pooling", | |||||
{src, out_shape}}) { | |||||
init_megdnn_opr(*this, param); | |||||
add_input({src, out_shape}); | |||||
outshape_by_symvar_enable(1, 1); | |||||
} | |||||
SymbolVar AdaptivePoolingForward::make(SymbolVar src, SymbolVar out_shape, | |||||
const Param& param, | |||||
const OperatorNodeConfig& config) { | |||||
return src.insert_single_output_opr<AdaptivePoolingForward>( | |||||
src.node(), out_shape.node(), param, config); | |||||
} | |||||
void AdaptivePoolingForward::scn_do_execute() { | |||||
megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), | |||||
output(0)->dev_tensor().as_megdnn(), | |||||
intl::get_megdnn_workspace_from_var(output().back())); | |||||
} | |||||
void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape( | |||||
TensorShape& dest, const ShapeInferInfo& shpinfo) { | |||||
TensorShape oshp2d; | |||||
cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0)); | |||||
auto src = shpinfo.shape_inp_shp.at(0); | |||||
mgb_assert(src.ndim == 4 && oshp2d.ndim == 2, | |||||
"shape mismatch for AdaptivePooling: src=%s, out2d=%s", | |||||
src.to_string().c_str(), oshp2d.to_string().c_str()); | |||||
mgb_assert(param().format == Param::Format::NCHW, | |||||
"AdaptivePooling only support NCHW"); | |||||
dest.ndim = 4; | |||||
dest.shape[0] = src.shape[0]; | |||||
dest.shape[1] = src.shape[1]; | |||||
dest.shape[2] = oshp2d.shape[0]; | |||||
dest.shape[3] = oshp2d.shape[1]; | |||||
} | |||||
size_t AdaptivePoolingForward::get_workspace_size_bytes( | |||||
const TensorShapeArray& input_shapes, | |||||
const TensorShapeArray& output_shapes) const { | |||||
return megdnn_opr()->get_workspace_in_bytes( | |||||
{input_shapes[0], this->input(0)->dtype(), | |||||
this->input(0)->format()}, | |||||
{output_shapes[0], this->output(0)->dtype(), | |||||
this->output(0)->format()}); | |||||
} | |||||
void AdaptivePoolingForward::init_output_dtype() { | |||||
output(0)->dtype(input(0)->dtype()); | |||||
} | |||||
void AdaptivePoolingForward::add_input_layout_constraint() { | |||||
mixin::megdnn_utils::add_input_layout_constraint_contig(*this); | |||||
} | |||||
void AdaptivePoolingForward::init_output_static_infer_desc() { | |||||
Super::init_output_static_infer_desc(); | |||||
init_output_static_infer_desc_workspace(false); | |||||
} | |||||
void AdaptivePoolingForward::record_execute_deps(ExecDependencyArray& deps) { | |||||
record_megdnn_opr(deps); | |||||
} | |||||
#ifdef MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(AdaptivePoolingForward) { | |||||
if (wrt_idx == 0) { | |||||
// wrt src | |||||
SymbolVar grad = AdaptivePoolingBackward::make( | |||||
opr.input(0), opr.input(1), opr.output(0), out_grad[0], | |||||
opr.param()); | |||||
return grad.node(); | |||||
} else { | |||||
mgb_assert(wrt_idx == 1); | |||||
return InvalidGrad::make(opr, wrt_idx); | |||||
} | |||||
} | |||||
#endif | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(AdaptivePoolingBackward); | |||||
AdaptivePoolingBackward::AdaptivePoolingBackward( | |||||
VarNode* src, VarNode* out_shape, VarNode* dst, VarNode* diff, | |||||
const Param& param, const OperatorNodeConfig& config) | |||||
: Super(OperatorNodeBaseCtorParam{src->owner_graph(), | |||||
config, | |||||
"adaptive_pooling_bwd", | |||||
{src}}, | |||||
0, true) { | |||||
init_megdnn_opr(*this, param); | |||||
add_input({src, out_shape, dst, diff}); | |||||
} | |||||
SymbolVar AdaptivePoolingBackward::make(SymbolVar src, SymbolVar out_shape, | |||||
SymbolVar dst, SymbolVar diff, | |||||
const Param& param, | |||||
const OperatorNodeConfig& config) { | |||||
return src.insert_single_output_opr<AdaptivePoolingBackward>( | |||||
src.node(), out_shape.node(), dst.node(), diff.node(), param, | |||||
config); | |||||
} | |||||
void AdaptivePoolingBackward::scn_do_execute() { | |||||
megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), | |||||
input(2)->dev_tensor().as_megdnn(), | |||||
input(3)->dev_tensor().as_megdnn(), | |||||
output(0)->dev_tensor().as_megdnn(), | |||||
intl::get_megdnn_workspace_from_var(output().back())); | |||||
} | |||||
size_t AdaptivePoolingBackward::get_workspace_size_bytes( | |||||
const TensorShapeArray& input_shapes, | |||||
const TensorShapeArray& output_shapes) const { | |||||
return megdnn_opr()->get_workspace_in_bytes( | |||||
{input_shapes[0], input(0)->dtype(), input(0)->format()}, | |||||
{input_shapes[2], input(2)->dtype(), input(2)->format()}, | |||||
{input_shapes[3], input(3)->dtype(), input(3)->format()}, | |||||
{output_shapes[0], output(0)->dtype(), output(0)->format()}); | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -140,6 +140,13 @@ decl_opr('Pooling', | |||||
inputs=['src'], | inputs=['src'], | ||||
params='Pooling') | params='Pooling') | ||||
decl_opr('AdaptivePooling', | |||||
inputs=[Doc('src', 'input image, shape (n, c, ih, iw)'), | |||||
Doc('out_shape', 'output image shape, containing two elements specifying output height and width.')], | |||||
params='AdaptivePooling', | |||||
desc='Adaptive Pooling.' | |||||
'The output shape is (n, c, oh, ow), where (oh, ow) is given by *out_shape*.') | |||||
decl_opr('ROIPooling', outputs=[0], | decl_opr('ROIPooling', outputs=[0], | ||||
inputs=[Doc('src', 'input image, shape (n, c, ih, iw)'), | inputs=[Doc('src', 'input image, shape (n, c, ih, iw)'), | ||||
Doc('rois', 'regions of interest, shape (m, 5). ' | Doc('rois', 'regions of interest, shape (m, 5). ' | ||||
@@ -258,7 +265,7 @@ decl_opr('ROIAlign', outputs=[0], | |||||
'store it as a float, but it should be an integral value.' | 'store it as a float, but it should be an integral value.' | ||||
' The rois[:, 1:5] are (x0, y0, x1, y1) for each ROI, ' | ' The rois[:, 1:5] are (x0, y0, x1, y1) for each ROI, ' | ||||
'which would be multiplied by the scale value given in ' | 'which would be multiplied by the scale value given in ' | ||||
'param.')], | |||||
'param.')], | |||||
params='ROIAlign', | params='ROIAlign', | ||||
desc='ROI Align, see ' | desc='ROI Align, see ' | ||||
'Mask-RCNN: https://arxiv.org/pdf/1703.06870.pdf, ' | 'Mask-RCNN: https://arxiv.org/pdf/1703.06870.pdf, ' | ||||
@@ -295,7 +302,7 @@ decl_opr('BatchConvBiasForward', | |||||
('execution_policy', 'ExecutionPolicy')], | ('execution_policy', 'ExecutionPolicy')], | ||||
desc=Doc(None, | desc=Doc(None, | ||||
r""" | r""" | ||||
Apply a convolution of input tensor and filter tensor whose weights are not shared in batch dimensions. Outputs with batch index use the same weight. | |||||
Apply a convolution of input tensor and filter tensor whose weights are not shared in batch dimensions. Outputs with batch index use the same weight. | |||||
Assume input shape is :math:`(N, IC, IH, IW)` and filter shape is :math:`(batch, OC, IC, FH, FW)`, the output shape will be :math:`(N, OC, OH, OW)` where :math:`(OH, OW)` would be computed from padding, stride, :math:`(FH, FW)` and :math:`(IH, IW)`, as in convolution. | Assume input shape is :math:`(N, IC, IH, IW)` and filter shape is :math:`(batch, OC, IC, FH, FW)`, the output shape will be :math:`(N, OC, OH, OW)` where :math:`(OH, OW)` would be computed from padding, stride, :math:`(FH, FW)` and :math:`(IH, IW)`, as in convolution. | ||||
for each output location, we have; | for each output location, we have; | ||||
@@ -13,6 +13,7 @@ | |||||
#include "megbrain/opr/dnn/convolution.h" | #include "megbrain/opr/dnn/convolution.h" | ||||
#include "megbrain/opr/dnn/images2neibs.h" | #include "megbrain/opr/dnn/images2neibs.h" | ||||
#include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||||
#include "megbrain/opr/dnn/roi_pooling.h" | #include "megbrain/opr/dnn/roi_pooling.h" | ||||
#include "megbrain/opr/dnn/roi_align.h" | #include "megbrain/opr/dnn/roi_align.h" | ||||
#include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
@@ -388,6 +389,9 @@ namespace opr { | |||||
MGB_SEREG_OPR(Pooling, 1); | MGB_SEREG_OPR(Pooling, 1); | ||||
MGB_SEREG_OPR(PoolingBackward, 3); | MGB_SEREG_OPR(PoolingBackward, 3); | ||||
MGB_SEREG_OPR(AdaptivePooling, 2); | |||||
MGB_SEREG_OPR(AdaptivePoolingBackward, 4); | |||||
MGB_SEREG_OPR(ROIPooling, 3); | MGB_SEREG_OPR(ROIPooling, 3); | ||||
MGB_SEREG_OPR(ROIPoolingBackward, 4); | MGB_SEREG_OPR(ROIPoolingBackward, 4); | ||||
@@ -0,0 +1,76 @@ | |||||
/** | |||||
* \file src/opr/include/megbrain/opr/dnn/adaptive_pooling.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
#include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
#include "megdnn/oprs/nn.h" | |||||
namespace mgb { | |||||
namespace opr { | |||||
MGB_DEFINE_OPR_CLASS( | |||||
AdaptivePoolingForward, | |||||
intl::WorkspaceSizeInfer<intl::OutshapeBySymvarSCNOpr< | |||||
mixin::MegDNNOprHolderImpl<megdnn::AdaptivePoolingForward>>>) // { | |||||
public: | |||||
AdaptivePoolingForward(VarNode * src, VarNode * out_shape, | |||||
const Param& param, | |||||
const OperatorNodeConfig& config); | |||||
static SymbolVar make(SymbolVar src, SymbolVar out_shape, | |||||
const Param& param, | |||||
const OperatorNodeConfig& config = {}); | |||||
static SymbolVar make(SymbolVar src, const TensorShape& out_shape, | |||||
const Param& param, | |||||
const OperatorNodeConfig& config = {}) { | |||||
return make(src, cg::var_from_tensor_shape(src, out_shape), param, | |||||
config); | |||||
} | |||||
private: | |||||
void scn_do_execute() override; | |||||
void outshape_by_symvar_do_get_output_shape( | |||||
TensorShape & dest, const ShapeInferInfo& shpinfo) override; | |||||
size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, | |||||
const TensorShapeArray& output_shapes) | |||||
const override; | |||||
void init_output_dtype() override; | |||||
void add_input_layout_constraint() override; | |||||
void init_output_static_infer_desc() override; | |||||
void record_execute_deps(ExecDependencyArray& deps) override; | |||||
}; | |||||
using AdaptivePooling = AdaptivePoolingForward; | |||||
MGB_DEFINE_OPR_CLASS( | |||||
AdaptivePoolingBackward, | |||||
intl::MegDNNOprWrapperBwd<megdnn::AdaptivePoolingBackward>) // { | |||||
public: | |||||
AdaptivePoolingBackward(VarNode * src, VarNode * out_shape, VarNode * dst, | |||||
VarNode * diff, const Param& param, | |||||
const OperatorNodeConfig& config); | |||||
static SymbolVar make(SymbolVar src, SymbolVar out_shape, SymbolVar dst, | |||||
SymbolVar diff, const Param& param, | |||||
const OperatorNodeConfig& config = {}); | |||||
private: | |||||
void scn_do_execute() override; | |||||
size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, | |||||
const TensorShapeArray& output_shapes) | |||||
const override; | |||||
}; | |||||
} // namespace opr | |||||
} // namespace mgb | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,87 @@ | |||||
/** | |||||
* \file src/opr/test/dnn/adaptive_pooling.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||||
#include "megbrain/comp_node_env.h" | |||||
#include "megbrain/opr/dnn/pooling.h" | |||||
#include "megbrain/opr/tensor_manip.h" | |||||
#include "megbrain/test/autocheck.h" | |||||
#include "megbrain/test/megdnn_helper.h" | |||||
#include "megdnn/dtype.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
using namespace std; | |||||
using namespace mgb; | |||||
namespace { | |||||
using Param = opr::AdaptivePoolingForward::Param; | |||||
void run(Param::Mode mode) { | |||||
using Checker = AutoOprChecker<2, 1>; | |||||
Param param{mode}; | |||||
auto make_graph = | |||||
[&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||||
auto o0 = opr::GetVarShape::make(inputs[1]); | |||||
auto o1 = opr::AdaptivePoolingForward::make(inputs[0], o0, param); | |||||
return {o1}; | |||||
}; | |||||
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||||
auto opr = MegDNNHandle::get( | |||||
CompNodeEnv::from_comp_node(CompNode::default_cpu())) | |||||
->create_operator<megdnn::AdaptivePoolingForward>(); | |||||
opr->param() = param; | |||||
size_t N = inp[0].get()->shape(0), C = inp[0].get()->shape(1); | |||||
size_t OH = inp[1].get()->shape(0), OW = inp[1].get()->shape(1); | |||||
dest[0].resize(TensorShape{N, C, OH, OW}); | |||||
opr->exec(inp[0]->as_megdnn(), dest[0].as_megdnn(), {}); | |||||
}; | |||||
auto gen = [&](HostTensorND& src) { | |||||
if (mode == Param::Mode::MAX) { | |||||
HostTensorGenerator<dtype::Float32, RandomDistribution::CONSECUTIVE> | |||||
src_gen(1.0f, 0.1f); | |||||
src = *src_gen(src.shape(), src.comp_node()); | |||||
} else { | |||||
HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> | |||||
src_gen(10.f); | |||||
src = *src_gen(src.shape(), src.comp_node()); | |||||
} | |||||
}; | |||||
Checker::RunOptions opt; | |||||
opt.numdiff_max_err = 1e-2; | |||||
Checker checker{make_graph, fwd}; | |||||
checker.set_input_allow_grad(1, false) | |||||
.set_input_generator(0, gen); | |||||
checker.run({TensorShape{1, 1, 10, 7}, TensorShape{5, 4}}, opt); | |||||
checker.run({TensorShape{1, 1, 9, 7}, TensorShape{5, 4}}, opt); | |||||
checker.run({TensorShape{1, 2, 8, 9}, TensorShape{3, 4}}, opt); | |||||
} | |||||
} // anonymous namespace | |||||
TEST(TestOprDNN, AdaptivePoolingMax) { | |||||
run(Param::Mode::MAX); | |||||
} | |||||
TEST(TestOprDNN, AdaptivePoolingAverage) { | |||||
run(Param::Mode::AVERAGE); | |||||
} | |||||
TEST(TestOprDNN, AdaptivePoolingAverageCountExcludePadding) { | |||||
run(Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING); | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -99,6 +99,7 @@ union OperatorParam { | |||||
DType = 67, | DType = 67, | ||||
param.Remap = 68, | param.Remap = 68, | ||||
param.NMSKeep = 69, | param.NMSKeep = 69, | ||||
param.AdaptivePooling = 70, | |||||
} | } | ||||
table Operator { | table Operator { | ||||
@@ -113,6 +113,20 @@ dtype, RandomDistribution::CONSTANT>::operator ()( | |||||
return ret; | return ret; | ||||
} | } | ||||
template<typename dtype> | |||||
std::shared_ptr<HostTensorND> HostTensorGenerator< | |||||
dtype, RandomDistribution::CONSECUTIVE>::operator ()( | |||||
const TensorShape &shape, CompNode cn) { | |||||
if (!cn.valid()) | |||||
cn = CompNode::load("xpu0"); | |||||
std::shared_ptr<HostTensorND> ret = | |||||
std::make_shared<HostTensorND>(cn, shape, dtype()); | |||||
auto ptr = ret->ptr<ctype>(); | |||||
for (size_t i = 0, it = shape.total_nr_elems(); i < it; ++ i) { | |||||
ptr[i] = m_val + i * m_delta; | |||||
} | |||||
return ret; | |||||
} | |||||
// explicit instantialization of HostTensorGenerator | // explicit instantialization of HostTensorGenerator | ||||
namespace mgb { | namespace mgb { | ||||
@@ -123,12 +137,16 @@ namespace mgb { | |||||
template class HostTensorGenerator< | template class HostTensorGenerator< | ||||
dtype::Float32, RandomDistribution::CONSTANT>; | dtype::Float32, RandomDistribution::CONSTANT>; | ||||
template class HostTensorGenerator< | template class HostTensorGenerator< | ||||
dtype::Float32, RandomDistribution::CONSECUTIVE>; | |||||
template class HostTensorGenerator< | |||||
dtype::Float16, RandomDistribution::GAUSSIAN>; | dtype::Float16, RandomDistribution::GAUSSIAN>; | ||||
template class HostTensorGenerator< | template class HostTensorGenerator< | ||||
dtype::Int8, RandomDistribution::UNIFORM>; | dtype::Int8, RandomDistribution::UNIFORM>; | ||||
template class HostTensorGenerator< | template class HostTensorGenerator< | ||||
dtype::Int8, RandomDistribution::CONSTANT>; | dtype::Int8, RandomDistribution::CONSTANT>; | ||||
template class HostTensorGenerator< | template class HostTensorGenerator< | ||||
dtype::Int8, RandomDistribution::CONSECUTIVE>; | |||||
template class HostTensorGenerator< | |||||
dtype::Uint8, RandomDistribution::UNIFORM>; | dtype::Uint8, RandomDistribution::UNIFORM>; | ||||
template class HostTensorGenerator< | template class HostTensorGenerator< | ||||
dtype::Uint8, RandomDistribution::CONSTANT>; | dtype::Uint8, RandomDistribution::CONSTANT>; | ||||
@@ -168,7 +168,7 @@ class RNGxorshf { | |||||
}; | }; | ||||
enum class RandomDistribution { | enum class RandomDistribution { | ||||
GAUSSIAN, UNIFORM, CONSTANT | |||||
GAUSSIAN, UNIFORM, CONSTANT, CONSECUTIVE | |||||
}; | }; | ||||
template<class dtype> | template<class dtype> | ||||
@@ -342,6 +342,29 @@ class HostTensorGenerator<dtype, RandomDistribution::CONSTANT> final: | |||||
private: | private: | ||||
ctype m_default_val; | ctype m_default_val; | ||||
}; | }; | ||||
//! consecutive value | |||||
template<class dtype> | |||||
class HostTensorGenerator<dtype, RandomDistribution::CONSECUTIVE> final: | |||||
public HostTensorGeneratorBase { | |||||
public: | |||||
using ctype = typename DTypeTrait<dtype>::ctype; | |||||
HostTensorGenerator(ctype val, ctype delta) | |||||
: HostTensorGeneratorBase{next_rand_seed()}, | |||||
m_val{val}, m_delta{delta} {} | |||||
std::shared_ptr<HostTensorND> operator ()( | |||||
const TensorShape &shape, CompNode cn = {}) override; | |||||
using HostTensorGeneratorBase::operator(); | |||||
private: | |||||
ctype m_val; | |||||
ctype m_delta; | |||||
}; | |||||
template <> | template <> | ||||
class HostTensorGenerator<dtype::Bool, RandomDistribution::UNIFORM> final | class HostTensorGenerator<dtype::Bool, RandomDistribution::UNIFORM> final | ||||
: public HostTensorGeneratorBase { | : public HostTensorGeneratorBase { | ||||