@@ -8,10 +8,9 @@ namespace cuda { | |||||
#define COMMA , | #define COMMA , | ||||
#define cb(_dtype) \ | |||||
INST_REDUCE( \ | |||||
device_reduce::CheckNonFiniteOp< \ | |||||
_dtype COMMA dt_float32 COMMA dt_int32 COMMA dt_int32>, \ | |||||
#define cb(_dtype) \ | |||||
INST_REDUCE( \ | |||||
device_reduce::CheckNonFiniteOp<_dtype COMMA dt_int32 COMMA dt_int32>, \ | |||||
false); | false); | ||||
cb(dt_float32); | cb(dt_float32); | ||||
@@ -14,7 +14,7 @@ using device_reduce::CheckNonFiniteOp; | |||||
template <typename T> | template <typename T> | ||||
size_t CheckNonFiniteImpl::_get_workspace_in_bytes() { | size_t CheckNonFiniteImpl::_get_workspace_in_bytes() { | ||||
// Call the _get_workspace_in_bytes to reduce the loop fetch workspace bytes | // Call the _get_workspace_in_bytes to reduce the loop fetch workspace bytes | ||||
typedef CheckNonFiniteOp<T, dt_float32, dt_int32, dt_int32> Op; | |||||
typedef CheckNonFiniteOp<T, dt_int32, dt_int32> Op; | |||||
megdnn_assert(m_size > 0); | megdnn_assert(m_size > 0); | ||||
WorkspaceBundle bundle( | WorkspaceBundle bundle( | ||||
nullptr, { | nullptr, { | ||||
@@ -59,7 +59,7 @@ void CheckNonFiniteImpl::_exec( | |||||
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst, | _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst, | ||||
_megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
check_exec(srcs, dst, workspace.size); | check_exec(srcs, dst, workspace.size); | ||||
typedef CheckNonFiniteOp<T, dt_float32, dt_int32, dt_int32> Op; | |||||
typedef CheckNonFiniteOp<T, dt_int32, dt_int32> Op; | |||||
auto stream = cuda_stream(this->handle()); | auto stream = cuda_stream(this->handle()); | ||||
SmallVector<size_t> workspace_sizes{ | SmallVector<size_t> workspace_sizes{ | ||||
sizeof(T*) * m_size, | sizeof(T*) * m_size, | ||||
@@ -247,4 +247,4 @@ def _override( | |||||
def _get_actual_op_param(function_param, config_param): | def _get_actual_op_param(function_param, config_param): | ||||
return function_param if config_param is "default" else config_param | |||||
return function_param if config_param == "default" else config_param |
@@ -97,7 +97,7 @@ class Optimizer(metaclass=ABCMeta): | |||||
"optimizer can only optimize Parameters, but one of the params is " | "optimizer can only optimize Parameters, but one of the params is " | ||||
+ str(type(param)) | + str(type(param)) | ||||
) | ) | ||||
param._reset(Tensor(param, no_cache=True)) | |||||
param._reset(Tensor(param.numpy(), no_cache=True, format=param.format)) | |||||
for name, default in self._defaults.items(): | for name, default in self._defaults.items(): | ||||
if default is required and name not in param_group: | if default is required and name not in param_group: | ||||
@@ -581,9 +581,9 @@ ValueRefList FormatTransformation::apply_transformation( | |||||
(GenericFunction&)inputs[1].cast<FunctionValue>(); | (GenericFunction&)inputs[1].cast<FunctionValue>(); | ||||
// make param grads as FormattedTensor | // make param grads as FormattedTensor | ||||
GenericFunction new_callback = | GenericFunction new_callback = | ||||
[this, callback, format](Span<ValueRef> inputs_) -> ValueRefList { | |||||
[&, callback, format](Span<ValueRef> inputs_) -> ValueRefList { | |||||
auto wrapped_inputs = SmallVector<ValueRef>{ | auto wrapped_inputs = SmallVector<ValueRef>{ | ||||
this->value_type().make(inputs_.item(), format)}; | |||||
m_value_type.make(inputs_.item(), format)}; | |||||
auto ret = callback(wrapped_inputs); | auto ret = callback(wrapped_inputs); | ||||
return ret; | return ret; | ||||
}; | }; | ||||
@@ -67,7 +67,6 @@ template <typename T> | |||||
class Type : public IType { | class Type : public IType { | ||||
protected: | protected: | ||||
Type(std::string name) : IType(std::move(name)) {} | Type(std::string name) : IType(std::move(name)) {} | ||||
Type(IType&& type) : IType(std::move(type)) {} | |||||
// TODO: each type owns an allocator | // TODO: each type owns an allocator | ||||
public: | public: | ||||
@@ -105,7 +104,6 @@ template <typename T> | |||||
class ObjectType : public Type<T> { | class ObjectType : public Type<T> { | ||||
public: | public: | ||||
ObjectType(std::string name) : Type<T>(name) {} | ObjectType(std::string name) : Type<T>(name) {} | ||||
ObjectType(IType&& type) : Type<T>(std::move(type)) {} | |||||
}; | }; | ||||
/** | /** | ||||