|
@@ -201,7 +201,7 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver): |
|
|
def forward(self, x_orig): |
|
|
def forward(self, x_orig): |
|
|
if self.enabled: |
|
|
if self.enabled: |
|
|
x = x_orig.detach() |
|
|
x = x_orig.detach() |
|
|
if is_distributed: |
|
|
|
|
|
|
|
|
if is_distributed(): |
|
|
min_x = all_reduce_min(x.min(), WORLD) |
|
|
min_x = all_reduce_min(x.min(), WORLD) |
|
|
max_x = all_reduce_max(x.max(), WORLD) |
|
|
max_x = all_reduce_max(x.max(), WORLD) |
|
|
else: |
|
|
else: |
|
|