Browse Source

fix(mge/module): fix missing import

GitOrigin-RevId: 7cdf6ac814
release-1.2
Megvii Engine Team 4 years ago
parent
commit
aa953c3bd6
2 changed files with 13 additions and 10 deletions
  1. +1
    -0
      imperative/python/megengine/module/__init__.py
  2. +12
    -10
      imperative/python/megengine/module/normalization.py

+ 1
- 0
imperative/python/megengine/module/__init__.py View File

@@ -20,6 +20,7 @@ from .embedding import Embedding
from .identity import Identity
from .linear import Linear
from .module import Module
from .normalization import GroupNorm, InstanceNorm, LayerNorm
from .pooling import AvgPool2d, MaxPool2d
from .quant_dequant import DequantStub, QuantStub
from .sequential import Sequential

+ 12
- 10
imperative/python/megengine/module/normalization.py View File

@@ -8,11 +8,13 @@
import numpy as np

import megengine.functional as F
import megengine.module as M
from megengine import Parameter

from .init import ones_, zeros_
from .module import Module

class GroupNorm(M.Module):

class GroupNorm(Module):
"""
Simple implementation of GroupNorm. Only support 4d tensor now.
Reference: https://arxiv.org/pdf/1803.08494.pdf.
@@ -35,8 +37,8 @@ class GroupNorm(M.Module):

def reset_parameters(self):
if self.affine:
M.init.ones_(self.weight)
M.init.zeros_(self.bias)
ones_(self.weight)
zeros_(self.bias)

def forward(self, x):
N, C, H, W = x.shape
@@ -61,7 +63,7 @@ class GroupNorm(M.Module):
return s.format(**self.__dict__)


class InstanceNorm(M.Module):
class InstanceNorm(Module):
"""
Simple implementation of InstanceNorm. Only support 4d tensor now.
Reference: https://arxiv.org/abs/1607.08022.
@@ -83,8 +85,8 @@ class InstanceNorm(M.Module):

def reset_parameters(self):
if self.affine:
M.init.ones_(self.weight)
M.init.zeros_(self.bias)
ones_(self.weight)
zeros_(self.bias)

def forward(self, x):
N, C, H, W = x.shape
@@ -105,7 +107,7 @@ class InstanceNorm(M.Module):
return s.format(**self.__dict__)


class LayerNorm(M.Module):
class LayerNorm(Module):
"""
Simple implementation of LayerNorm. Only support 4d tensor now.
Reference: https://arxiv.org/pdf/1803.08494.pdf.
@@ -127,8 +129,8 @@ class LayerNorm(M.Module):

def reset_parameters(self):
if self.affine:
M.init.ones_(self.weight)
M.init.zeros_(self.bias)
ones_(self.weight)
zeros_(self.bias)

def forward(self, x):
N, C, H, W = x.shape


Loading…
Cancel
Save