GitOrigin-RevId: 132605c7d9
tags/v1.7.2.m1
@@ -246,6 +246,8 @@ NamedTensorShape NamedTensorShape::make_named_tensor_shape(Format format) {
            return {{"N//8"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}, {"N%8"}};
        case Format::NCHW44_DOT:
            return {{"N//4"}, {"C//4"}, {"H"}, {"W"}, {"N%4"}, {"C%4"}};
        case Format::NHWCD4:
            return {{"N"}, {"H"}, {"C//4"}, {"W"}, {"C%4"}};
        default:
            megdnn_throw(ssprintf("Format unimplement(%d)", static_cast<int>(format))
                                 .c_str());
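
The new NHWCD4 case above describes the layout as the named shape {N, H, C//4, W, C%4}: channels are split into blocks of four, the block index sits before W, and the intra-block offset comes last. A minimal standalone sketch of that shape arithmetic (plain C++, not MegEngine code; the function name is illustrative only):

#include <array>
#include <cstddef>
#include <cstdio>

// Map an NCHW logical shape onto the NHWCD4 storage shape {N, H, C//4, W, C%4}.
static std::array<size_t, 5> nchw_to_nhwcd4(const std::array<size_t, 4>& nchw) {
    const size_t n = nchw[0], c = nchw[1], h = nchw[2], w = nchw[3];
    // Channels are grouped into blocks of 4; a partial block is padded up.
    return {n, h, (c + 3) / 4, w, 4};
}

int main() {
    auto shp = nchw_to_nhwcd4({1, 6, 32, 32});
    std::printf("%zu %zu %zu %zu %zu\n", shp[0], shp[1], shp[2], shp[3], shp[4]);
    // prints: 1 32 2 32 4  (6 channels round up to two blocks of 4)
    return 0;
}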
@@ -229,6 +229,30 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW64> {
    }
};

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NHWCD4> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NHWCD4;
        config.config_id = OprFormatConfigID::NHWCD4;
        bool available =
                opr->input(0)->dtype().enumv() == DTypeEnum::Float32 ||
                DNN_FLOAT16_SELECT(
                        (opr->input(0)->dtype().enumv() == DTypeEnum::Float16), true) ||
                opr->input(0)->dtype().enumv() == DTypeEnum::Int8 ||
                opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        config.input_dtypes = {opr->input(0)->dtype().enumv()};
        config.input_tensor_types = {TensorType::FEATURE};
        config.output_dtypes = {opr->output(0)->dtype().enumv()};
        config.input_tensor_formats = {TensorFormats::NHCWc4};
        config.output_tensor_formats = {TensorFormats::NHCWc4};
        if (available)
            return config;
        return None;
    }
};

template <typename Opr, OprFormatConfigID config_id>
struct ConvTensorFormatsDispatcherImpl;
@@ -814,6 +838,55 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_DOT_HYBRID
    }
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NHWCD4> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NHWCD4;
        config.config_id = OprFormatConfigID::NHWCD4;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            if (opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                opr->input(1)->dtype().enumv() == DTypeEnum::Quantized8Asymm) {
                config.input_tensor_formats = {
                        TensorFormats::NHCWc4, TensorFormats::KRSCk4c4,
                        TensorFormats::NHCWc4, TensorFormats::NHCWc4};
            } else {
                config.input_tensor_formats = {
                        TensorFormats::NHCWc4, TensorFormats::KRSCk4,
                        TensorFormats::NHCWc4, TensorFormats::NHCWc4};
            }
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                config.input_tensor_formats = {
                        TensorFormats::NHCWc4, TensorFormats::C1RSc4,
                        TensorFormats::NHCWc4, TensorFormats::NHCWc4};
            } else {
                if (opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                    opr->input(1)->dtype().enumv() == DTypeEnum::Quantized8Asymm) {
                    config.input_tensor_formats = {
                            TensorFormats::NHCWc4, TensorFormats::GKRSCk4c4,
                            TensorFormats::NHCWc4, TensorFormats::NHCWc4};
                } else {
                    config.input_tensor_formats = {
                            TensorFormats::NHCWc4, TensorFormats::GKRSCk4,
                            TensorFormats::NHCWc4, TensorFormats::NHCWc4};
                }
            }
        }
        config.output_tensor_formats = {TensorFormats::NHCWc4};
        return config;
    }
};

template <>
struct ConvTensorFormatsDispatcherImpl<
        opr::ConvolutionBackwardData, OprFormatConfigID::NCHW> {
@@ -919,6 +992,57 @@ struct ConvTensorFormatsDispatcherImpl<
    }
};

template <>
struct ConvTensorFormatsDispatcherImpl<
        opr::ConvolutionBackwardData, OprFormatConfigID::NHWCD4> {
    using Opr = opr::ConvolutionBackwardData;
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NHWCD4;
        config.config_id = OprFormatConfigID::NHWCD4;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type = i == 0 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                opr->input(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm) {
                config.input_tensor_formats = {
                        TensorFormats::KRSCk4c4, TensorFormats::NHCWc4,
                        TensorFormats::NHCWc4, TensorFormats::NHCWc4};
            } else {
                config.input_tensor_formats = {
                        TensorFormats::KRSCk4, TensorFormats::NHCWc4,
                        TensorFormats::NHCWc4, TensorFormats::NHCWc4};
            }
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                config.input_tensor_formats = {
                        TensorFormats::C1RSc4, TensorFormats::NHCWc4,
                        TensorFormats::NHCWc4, TensorFormats::NHCWc4};
            } else {
                if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                    opr->input(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm) {
                    config.input_tensor_formats = {
                            TensorFormats::GKRSCk4c4, TensorFormats::NHCWc4,
                            TensorFormats::NHCWc4, TensorFormats::NHCWc4};
                } else {
                    config.input_tensor_formats = {
                            TensorFormats::GKRSCk4, TensorFormats::NHCWc4,
                            TensorFormats::NHCWc4, TensorFormats::NHCWc4};
                }
            }
        }
        config.output_tensor_formats = {TensorFormats::NHCWc4};
        return config;
    }
};

struct StaticData {
    struct KeyHash {
        size_t operator()(const std::pair<Typeinfo*, OprFormatConfigID>& val) const {
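
For the NHWCD4 conv dispatchers above, the feature-map inputs and the output stay in NHCWc4 and only the weight layout varies: KRSCk4c4 for 8-bit quantized dense weights, KRSCk4 for float dense weights, C1RSc4 for channel-wise group weights, and the G-prefixed variants for ordinary group convolution. A standalone sketch of that selection (plain C++, not MegEngine code; the enum and function names are illustrative only):

#include <cassert>

enum class WeightLayout { KRSCk4c4, KRSCk4, C1RSc4, GKRSCk4c4, GKRSCk4 };

// quantized_8bit stands for the QuantizedS8 / Quantized8Asymm branch above.
WeightLayout pick_weight_layout(bool dense, bool channel_wise, bool quantized_8bit) {
    if (dense)
        return quantized_8bit ? WeightLayout::KRSCk4c4 : WeightLayout::KRSCk4;
    if (channel_wise)  // group conv with one channel per group
        return WeightLayout::C1RSc4;
    return quantized_8bit ? WeightLayout::GKRSCk4c4 : WeightLayout::GKRSCk4;
}

int main() {
    assert(pick_weight_layout(true, false, true) == WeightLayout::KRSCk4c4);
    assert(pick_weight_layout(false, true, false) == WeightLayout::C1RSc4);
    assert(pick_weight_layout(false, false, false) == WeightLayout::GKRSCk4);
    return 0;
}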
@@ -969,6 +1093,7 @@ StaticData::StaticData() {
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_HYBRID);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT_HYBRID);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NHWCD4);

    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NHWC);

@@ -979,15 +1104,18 @@ StaticData::StaticData() {
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_HYBRID);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT_HYBRID);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NHWCD4);

    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW4);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWCD4);

    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NHWC);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW4);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW64);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NHWCD4);

    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NHWC);

@@ -997,10 +1125,12 @@ StaticData::StaticData() {
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NHWCD4);

    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NHWCD4);

#undef OPR_TENSOR_FORMATS_CONFIG_REG
#undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
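
Conceptually, each OPR_TENSOR_FORMATS_CONFIG_REG / OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG line above fills one slot of the StaticData table keyed by (operator type, OprFormatConfigID), so the new NHWCD4 entries are what make the profiler consider the format for those operators. A rough standalone sketch of such a table (plain C++, not the actual macro expansion; all names are illustrative only):

#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>

enum class ConfigID { NCHW, NHWC, NHWCD4 };
struct Config { ConfigID id; };
using OprType = std::string;  // stands in for Typeinfo*
using Dispatcher = std::function<std::optional<Config>()>;

struct KeyHash {
    size_t operator()(const std::pair<OprType, ConfigID>& k) const {
        return std::hash<OprType>{}(k.first) ^ (static_cast<size_t>(k.second) << 1);
    }
};

int main() {
    std::unordered_map<std::pair<OprType, ConfigID>, Dispatcher, KeyHash> table;
    std::pair<OprType, ConfigID> key{"ConvBias", ConfigID::NHWCD4};
    // Mirrors OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NHWCD4) in spirit.
    table[key] = [] { return std::optional<Config>{Config{ConfigID::NHWCD4}}; };
    auto it = table.find(key);
    return (it != table.end() && it->second().has_value()) ? 0 : 1;
}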
@@ -22,6 +22,7 @@
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/plugin/base.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/tensor_format.h"

using namespace mgb;
using namespace cg;

@@ -281,9 +282,6 @@ float ProfilerImpl::profile_operator(
            std::min(config.input_tensor_formats.size(), opr->input().size());
    for (; i < nr_input_tensor; ++i) {
        auto&& var = opr->input(i);
        auto&& cn = var->comp_node();
        auto&& dtype = var->dtype();
        auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
        TensorShape aligned_shape;
        if (config.input_tensor_types[i] == TensorType::WEIGHT) {
            mgb_assert(base_config.input_tensor_types[i] == TensorType::WEIGHT);
@@ -299,9 +297,12 @@
                    var, base_config.input_tensor_formats[i],
                    config.input_tensor_formats[i], extra_attribute);
        }
        dval->resize(aligned_shape);
        std::shared_ptr<DeviceTensorND> dval = create_device_tensor_helper(
                config, i, var, aligned_shape, extra_attribute);
        if (config.input_tensor_types[i] == TensorType::WEIGHT) {
            new_inps[i] = opr::SharedDeviceTensor::make_const(*graph, dval).node();
            new_inps[i] =
                    opr::SharedDeviceTensorWithFormat::make_const(*graph, dval).node();
        } else {
            new_inps[i] = opr::VolatileSharedDeviceTensor::make(*graph, dval).node();
        }

@@ -368,10 +369,27 @@ float ProfilerImpl::profile_var_node(
        const VarNode* var, TensorFormats base_format, const ReformatKey& key) const {
    auto&& cn = var->comp_node();
    auto&& dtype = var->dtype();
    auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
    auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape(
            var, base_format, key.input_format, key.attribute);
    dval->resize(aligned_tensor_shape);
    std::shared_ptr<DeviceTensorND> dval;
    if (key.input_format == TensorFormats::NHCWc4 &&
        key.attribute & ReformatAttribute::IMAGE2D) {
        size_t align_axis = 2;
        auto named_tensor = tensor_formats_to_named_tensor_shape(key.input_format);
        for (size_t n = 0; n < named_tensor.ndim; n++) {
            if (named_tensor[n].name() == megdnn::Dimension::Name::C) {
                align_axis = n;
                break;
            }
        }
        dval = std::make_shared<DeviceTensorND>(
                cn, aligned_tensor_shape, dtype,
                megdnn::Image2DPack4TensorFormat::make(
                        align_axis, opr::intl::get_megdnn_handle(cn)));
    } else
        dval = std::make_shared<DeviceTensorND>(cn, aligned_tensor_shape, dtype);

    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    graph->options().var_sanity_check_first_run = false;
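
In profile_var_node above, when the target format is NHCWc4 and the reformat carries the IMAGE2D attribute, the device tensor is created with an Image2DPack4TensorFormat whose align_axis is the index of the first C-named dimension of the named shape (2 for NHCWc4 = {N, H, C//4, W, C%4}). A standalone sketch of that index lookup (plain C++, not MegEngine code; names are illustrative only):

#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

// Return the index of the first dimension whose base name is C ("C//4", "C%4", ...).
static size_t find_align_axis(const std::vector<std::string>& named_shape,
                              size_t fallback = 2) {
    for (size_t i = 0; i < named_shape.size(); ++i) {
        if (!named_shape[i].empty() && named_shape[i][0] == 'C')
            return i;
    }
    return fallback;
}

int main() {
    std::vector<std::string> nhcwc4 = {"N", "H", "C//4", "W", "C%4"};
    std::printf("align_axis = %zu\n", find_align_axis(nhcwc4));  // prints 2
    return 0;
}

The create_device_tensor_helper added further below performs the same lookup and additionally forces the axis to 1 for channel-wise C1RSc4 weights.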
@@ -516,6 +534,8 @@ ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id(
            return OprFormatConfigID::NHWC;
        case TensorFormats::CHWNc4:
            return OprFormatConfigID::CHWN4;
        case TensorFormats::NHCWc4:
            return OprFormatConfigID::NHWCD4;
        default:
            mgb_throw(
                    MegBrainError, "tensor format(%u) is not supported",

@@ -523,6 +543,39 @@ ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id(
    }
}

std::shared_ptr<DeviceTensorND> ProfilerImpl::create_device_tensor_helper(
        const OprTensorFormatsConfiguration& config, const size_t inp_idx,
        const VarNode* var, const TensorShape aligned_shape,
        ReformatAttribute extra_attribute) const {
    auto&& cn = var->comp_node();
    auto&& dtype = var->dtype();
    std::shared_ptr<DeviceTensorND> dval;
    if (config.config_id == OprFormatConfigID::NHWCD4 &&
        extra_attribute & ReformatAttribute::IMAGE2D) {
        size_t align_axis = 2;
        auto named_tensor = tensor_formats_to_named_tensor_shape(
                config.input_tensor_formats[inp_idx]);
        for (size_t n = 0; n < named_tensor.ndim; n++) {
            if (named_tensor[n].name() == megdnn::Dimension::Name::C) {
                align_axis = n;
                break;
            }
        }
        // channel wise weight
        bool is_channel_wise =
                config.input_tensor_formats[inp_idx] == TensorFormats::C1RSc4;
        if (is_channel_wise)
            align_axis = 1;
        dval = std::make_shared<DeviceTensorND>(
                cn, aligned_shape, dtype,
                megdnn::Image2DPack4TensorFormat::make(
                        align_axis, opr::intl::get_megdnn_handle(cn)));
    } else {
        dval = std::make_shared<DeviceTensorND>(cn, aligned_shape, dtype);
    }
    return dval;
}

/* ================== ProfilerBase =================*/
std::string ProfilerBase::OperatorNodeRecord::to_string() const {
    auto str = ssprintf(
@@ -249,7 +249,7 @@ ReformatManager::ReformatManager() {
        m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
    }
    {
        auto i = TensorFormats::KCRS, o = TensorFormats::GKRSCk4;
        auto i = TensorFormats::GKCRS, o = TensorFormats::GKRSCk4;
        auto&& impl = [](const VarNodeArray& vars) {
            return opr::RelayoutFormat::make(
                           vars[0],

@@ -259,7 +259,7 @@ ReformatManager::ReformatManager() {
        m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
    }
    {
        auto i = TensorFormats::KCRS, o = TensorFormats::C1RSc4;
        auto i = TensorFormats::C11RS, o = TensorFormats::C1RSc4;
        auto&& impl = [](const VarNodeArray& vars) {
            return opr::RelayoutFormat::make(
                           vars[0],

@@ -270,6 +270,21 @@ ReformatManager::ReformatManager() {
    }
    {
        auto i = TensorFormats::NCHW, o = TensorFormats::NHCWc4;
        auto&& impl1 = [](const VarNodeArray& vars) {
            return opr::RelayoutFormat::make(
                           vars[0], megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4)
                    .node();
        };
        m_cache.emplace(ReformatKey{i, o}, impl1);
        auto&& impl2 = [](const VarNodeArray& vars) {
            return opr::RelayoutFormat::make(
                           vars[0], megdnn::param::RelayoutFormat::Mode::NHWCD4_NCHW)
                    .node();
        };
        m_cache.emplace(ReformatKey{o, i}, impl2);
    }
    {
        auto i = TensorFormats::NCHW, o = TensorFormats::NHCWc4;
        auto&& impl = [](const VarNodeArray& vars) {
            return opr::RelayoutFormat::make(
                           vars[0], megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I)
@@ -281,7 +296,7 @@ ReformatManager::ReformatManager() {
        auto i = TensorFormats::NHCWc4, o = TensorFormats::NCHW;
        auto&& impl = [](const VarNodeArray& vars) {
            return opr::RelayoutFormat::make(
                           vars[0], megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I)
                           vars[0], megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW)
                    .node();
        };
        m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);

@@ -346,6 +361,15 @@ ReformatManager::ReformatImpl ReformatManager::get(const ReformatKey& key) const
            return rst;
        }
    }
    if (key.attribute == Attribute::IMAGE2D) {
        auto key_ = key;
        key_.input_dtype = DTypeEnum::Float32;
        key_.output_dtype = DTypeEnum::Float32;
        auto find = m_cache.find(key_);
        if (find != m_cache.end()) {
            return find->second;
        }
    }
    mgb_assert(
            !(key.attribute & Attribute::IMAGE2D) &&
            !(key.attribute & Attribute::IC_SMALL));
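
The block added to ReformatManager::get above is a fallback: when a key with the IMAGE2D attribute misses the cache, the lookup is retried with both dtypes rewritten to Float32, presumably because the image2d entries registered in the constructor are cached under the default Float32 dtypes. A standalone sketch of the same lookup pattern (plain C++ with std::map, not MegEngine code; the key layout and names are illustrative only):

#include <functional>
#include <map>
#include <optional>
#include <tuple>

enum class Fmt { NCHW, NHCWc4 };
enum class DType { Float32, QuantizedS8 };
using Key = std::tuple<Fmt, Fmt, DType, DType, bool /*image2d*/>;
using Impl = std::function<int(int)>;

std::optional<Impl> lookup(const std::map<Key, Impl>& cache, Key key) {
    if (auto it = cache.find(key); it != cache.end())
        return it->second;
    if (std::get<4>(key)) {  // IMAGE2D: retry with dtypes wildcarded to Float32
        std::get<2>(key) = DType::Float32;
        std::get<3>(key) = DType::Float32;
        if (auto it = cache.find(key); it != cache.end())
            return it->second;
    }
    return std::nullopt;
}

int main() {
    std::map<Key, Impl> cache;
    Key f32_key{Fmt::NCHW, Fmt::NHCWc4, DType::Float32, DType::Float32, true};
    cache[f32_key] = [](int x) { return x; };
    // A QuantizedS8 query still hits the Float32 entry via the fallback.
    Key q8_key{Fmt::NCHW, Fmt::NHCWc4, DType::QuantizedS8, DType::QuantizedS8, true};
    return lookup(cache, q8_key) ? 0 : 1;
}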
@@ -682,7 +706,8 @@ TensorShape ReformatManager::make_aligned_weight_shape(
    auto target_shape = tensor_formats_to_named_tensor_shape(target_formats);
    for (size_t i = 0; i < target_shape.ndim; ++i) {
        auto name = target_shape[i].name();
        if ((name == Dimension::Name::K || name == Dimension::Name::N) &&
        if ((name == Dimension::Name::K || name == Dimension::Name::N ||
             (extra_formats == TensorFormats::NHCWc4 && name == Dimension::Name::C)) &&
            target_shape[i].extent() == UNDETERMINED_EXTENT) {
            size_t out_channels = tshp[i] * target_shape[i].stride();
            tshp[i] = divup(out_channels, out_channel_alignment) *
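
The widened condition above means that for NHCWc4 weights the C dimension is now rounded up to the out-channel alignment in the same way as K and N. A standalone sketch of the alignment arithmetic (plain C++, not MegEngine code; names are illustrative, and it assumes the statement truncated by the hunk completes as divup(out_channels, alignment) * alignment divided back by the dimension's stride):

#include <cassert>
#include <cstddef>

static size_t divup(size_t a, size_t b) {
    return (a + b - 1) / b;
}

// out_channels covered by this dimension = tshp * stride; round that up to the
// required alignment, then convert back to the dimension's own unit.
static size_t align_dim(size_t tshp, size_t stride, size_t alignment) {
    size_t out_channels = tshp * stride;
    return divup(out_channels, alignment) * alignment / stride;
}

int main() {
    // e.g. a C//4 dimension (stride 4) holding 3 blocks = 12 channels,
    // aligned to 16 channels -> 4 blocks.
    assert(align_dim(3, 4, 16) == 4);
    return 0;
}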
@@ -32,6 +32,7 @@ static inline const char* opr_format_to_string(
        cb(NCHW44);
        cb(NCHW88);
        cb(NCHW44_DOT);
        cb(NHWCD4);
        default:
            mgb_assert(
                    false, "Invalid opr format(got:%u)",

@@ -63,6 +64,7 @@ static inline const char* config_id_to_string(
        cb(NCHW88_HYBRID);
        cb(NCHW44_DOT);
        cb(NCHW44_DOT_HYBRID);
        cb(NHWCD4);
        default:
            mgb_assert(
                    false, "Invalid config id(got:%u)",

@@ -95,6 +97,8 @@ static inline TensorFormats opr_format_to_tensor_formats(
            return TensorFormats::NCHWc8;
        case OprFormat::NCHW44_DOT:
            return TensorFormats::NCHWc4;
        case OprFormat::NHWCD4:
            return TensorFormats::NHCWc4;
        default:
            mgb_throw(
                    AssertionError, "format(%s) is not supported",

@@ -202,6 +202,11 @@ protected:
            const ReformatKey& key) const;
    OprFormatConfigID tensor_formats_to_config_id(TensorFormats tensor_format) const;
    std::shared_ptr<DeviceTensorND> create_device_tensor_helper(
            const OprTensorFormatsConfiguration& config, const size_t inp_idx,
            const VarNode* var, const TensorShape aligned_shape,
            ReformatAttribute extra_attribute) const;

    OprFootprint m_opr_footprint;
    float m_opr_threshold;  /// a threshold, when the computation of the newly
                            /// created operator that is built in some opr
@@ -336,6 +336,10 @@ cg::OperatorNodeBase::NodeProp* VolatileSharedDeviceTensor::do_make_node_prop()
    return ret;
}

void VolatileSharedDeviceTensor::init_output_format() {
    output(0)->format(get_dev_tensor().format());
}

SymbolVar VolatileSharedDeviceTensor::make(
        ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data,
        const OperatorNodeConfig& config) {

@@ -337,6 +337,8 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
public:
    using Super::Super;

    void init_output_format() override;

    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data,
            const OperatorNodeConfig& config = {});