Browse Source

change parameters of oneflow.full to fit different versions

tags/1.0.0beta2
x54-729 2 years ago
parent
commit
2128d8d97c
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      fastNLP/core/collators/padders/oneflow_padder.py

+ 2
- 2
fastNLP/core/collators/padders/oneflow_padder.py View File

@@ -169,7 +169,7 @@ class OneflowTensorPadder(Padder):
else:
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]

tensor = oneflow.full(max_shape, value=pad_val, dtype=dtype, device=device)
tensor = oneflow.full(max_shape, pad_val, dtype, device)
for i, field in enumerate(batch_field):
slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
tensor[slices] = field
@@ -221,6 +221,6 @@ def get_padded_oneflow_tensor(batch_field, dtype=None, pad_val=0):
:return:
"""
shapes = get_shape(batch_field)
tensor = oneflow.full(shapes, dtype=dtype, value=pad_val)
tensor = oneflow.full(shapes, pad_val, dtype)
tensor = fill_tensor(batch_field, tensor, dtype=dtype)
return tensor

Loading…
Cancel
Save