# python: 3.6
# encoding: utf-8

import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_
# import torch.nn.functional as F


class Conv(nn.Module):
    """Basic 1-d convolution module followed by an activation.

    The convolution weight is initialized with ``xavier_uniform_``; the bias
    (when enabled) keeps the default ``nn.Conv1d`` initialization.

    Inputs to ``forward`` are laid out as ``[N, L, C]`` (channels last) and
    are transposed internally to the ``[N, C, L]`` layout that ``nn.Conv1d``
    operates on; the output is transposed back to ``[N, L, C]``.

    Args:
        in_channels: Number of input channels, forwarded to ``nn.Conv1d``.
        out_channels: Number of output channels, forwarded to ``nn.Conv1d``.
        kernel_size: Convolution kernel size, forwarded to ``nn.Conv1d``.
        stride: Forwarded to ``nn.Conv1d``.
        padding: Forwarded to ``nn.Conv1d``.
        dilation: Forwarded to ``nn.Conv1d``.
        groups: Forwarded to ``nn.Conv1d``.
        bias: Forwarded to ``nn.Conv1d``.
        activation: Name of the activation applied after the convolution;
            one of ``'relu'`` or ``'tanh'``.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 activation='relu'):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        xavier_uniform_(self.conv.weight)

        # Map names to module classes so that only the selected activation
        # is actually instantiated.
        activations = {
            'relu': nn.ReLU,
            'tanh': nn.Tanh}
        if activation not in activations:
            # ValueError subclasses Exception, so callers that caught the
            # previous bare Exception keep working.
            raise ValueError(
                'Should choose activation function from: ' +
                ', '.join(activations))
        self.activation = activations[activation]()

    def forward(self, x):
        """Apply the convolution and activation to a channels-last input.

        Args:
            x: Tensor laid out as ``[N, L, C_in]``.

        Returns:
            Tensor laid out as ``[N, L_out, C_out]``, where ``L_out`` is
            whatever length ``nn.Conv1d`` produces for the configured
            kernel size, stride, padding and dilation.
        """
        x = torch.transpose(x, 1, 2)  # [N,L,C] -> [N,C,L]
        x = self.conv(x)  # [N,C_in,L] -> [N,C_out,L]
        x = self.activation(x)
        x = torch.transpose(x, 1, 2)  # [N,C,L] -> [N,L,C]
        return x