|
@@ -108,7 +108,7 @@ def full(shape, value, dtype="float32", device=None): |
|
|
if device is None: |
|
|
if device is None: |
|
|
device = get_default_device() |
|
|
device = get_default_device() |
|
|
(x,) = Const(value, dtype=dtype, device=device)() |
|
|
(x,) = Const(value, dtype=dtype, device=device)() |
|
|
if len(shape) == 0: # scalar |
|
|
|
|
|
|
|
|
if shape is (): # scalar.shape |
|
|
return x |
|
|
return x |
|
|
return broadcast_to(x, shape) |
|
|
return broadcast_to(x, shape) |
|
|
|
|
|
|
|
|