|
|
@@ -98,6 +98,22 @@ dtype, RandomDistribution::UNIFORM>::operator ()( |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
template<typename dtype> |
|
|
|
std::shared_ptr<HostTensorND> HostTensorGenerator< |
|
|
|
dtype, RandomDistribution::CONSTANT>::operator ()( |
|
|
|
const TensorShape &shape, CompNode cn) { |
|
|
|
if (!cn.valid()) |
|
|
|
cn = CompNode::load("xpu0"); |
|
|
|
std::shared_ptr<HostTensorND> ret = |
|
|
|
std::make_shared<HostTensorND>(cn, shape, dtype()); |
|
|
|
auto ptr = ret->ptr<ctype>(); |
|
|
|
for (size_t i = 0, it = shape.total_nr_elems(); i < it; ++ i) { |
|
|
|
ptr[i] = m_default_val; |
|
|
|
} |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// explicit instantialization of HostTensorGenerator |
|
|
|
namespace mgb { |
|
|
|
template class HostTensorGenerator< |
|
|
@@ -105,15 +121,25 @@ namespace mgb { |
|
|
|
template class HostTensorGenerator< |
|
|
|
dtype::Float32, RandomDistribution::UNIFORM>; |
|
|
|
template class HostTensorGenerator< |
|
|
|
dtype::Float32, RandomDistribution::CONSTANT>; |
|
|
|
template class HostTensorGenerator< |
|
|
|
dtype::Float16, RandomDistribution::GAUSSIAN>; |
|
|
|
template class HostTensorGenerator< |
|
|
|
dtype::Int8, RandomDistribution::UNIFORM>; |
|
|
|
template class HostTensorGenerator< |
|
|
|
dtype::Int8, RandomDistribution::CONSTANT>; |
|
|
|
template class HostTensorGenerator< |
|
|
|
dtype::Uint8, RandomDistribution::UNIFORM>; |
|
|
|
template class HostTensorGenerator< |
|
|
|
dtype::Uint8, RandomDistribution::CONSTANT>; |
|
|
|
template class HostTensorGenerator< |
|
|
|
dtype::Int16, RandomDistribution::UNIFORM>; |
|
|
|
template class HostTensorGenerator< |
|
|
|
dtype::Int16, RandomDistribution::CONSTANT>; |
|
|
|
template class HostTensorGenerator< |
|
|
|
dtype::Int32, RandomDistribution::UNIFORM>; |
|
|
|
template class HostTensorGenerator< |
|
|
|
dtype::Int32, RandomDistribution::CONSTANT>; |
|
|
|
std::shared_ptr<HostTensorND> |
|
|
|
HostTensorGenerator<dtype::QuantizedS8, RandomDistribution::UNIFORM>:: |
|
|
|
operator()(const TensorShape& shape, CompNode cn) { |
|
|
|