# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn

from nni.nas.pytorch import mutables

from ops import ConvBN, LinearCombine, AvgPool, MaxPool, RNN, Attention, BatchNorm
from utils import GlobalMaxPool, GlobalAvgPool


class Layer(mutables.MutableScope):
    """One searchable layer: an input choice over previous outputs, a layer choice
    over candidate ops, optional skip connections, and a batch norm."""

    def __init__(self, key, prev_keys, hidden_units, choose_from_k,
                 cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask):
        super(Layer, self).__init__(key)

        def conv_shortcut(kernel_size):
            return ConvBN(kernel_size, hidden_units, hidden_units, cnn_keep_prob, False, True)

        self.n_candidates = len(prev_keys)
        if self.n_candidates:
            # choose exactly one of the last `choose_from_k` previous layers as input
            self.prec = mutables.InputChoice(choose_from=prev_keys[-choose_from_k:], n_chosen=1)
        else:
            # first layer, skip input choice
            self.prec = None
        self.op = mutables.LayerChoice([
            conv_shortcut(1),
            conv_shortcut(3),
            conv_shortcut(5),
            conv_shortcut(7),
            AvgPool(3, False, True),
            MaxPool(3, False, True),
            RNN(hidden_units, lstm_keep_prob),
            Attention(hidden_units, 4, att_keep_prob, att_mask)
        ])
        if self.n_candidates:
            # optional residual connections from any subset of previous layers
            self.skipconnect = mutables.InputChoice(choose_from=prev_keys)
        else:
            self.skipconnect = None
        self.bn = BatchNorm(hidden_units, False, True)

    def forward(self, last_layer, prev_layers, mask):
        # pass an extra last_layer to deal with layer 0 (prev_layers is empty)
        if self.prec is None:
            prec = last_layer
        else:
            # only the last `choose_from_k` outputs are candidates for the input choice
            prec = self.prec(prev_layers[-self.prec.n_candidates:])
        out = self.op(prec, mask)
        if self.skipconnect is not None:
            connection = self.skipconnect(prev_layers[-self.skipconnect.n_candidates:])
            if connection is not None:
                out += connection
        out = self.bn(out, mask)
        return out


class Model(nn.Module):
    def __init__(self, embedding, hidden_units=256, num_layers=24, num_classes=5,
                 choose_from_k=5, lstm_keep_prob=0.5, cnn_keep_prob=0.5,
                 att_keep_prob=0.5, att_mask=True, embed_keep_prob=0.5,
                 final_output_keep_prob=1.0, global_pool="avg"):
        super(Model, self).__init__()

        # load pretrained word embeddings and keep them trainable
        self.embedding = nn.Embedding.from_pretrained(embedding, freeze=False)
        self.hidden_units = hidden_units
        self.num_layers = num_layers
        self.num_classes = num_classes

        # first layer: 1x1 conv projecting the embedding dimension to hidden_units
        self.init_conv = ConvBN(1, self.embedding.embedding_dim, hidden_units, cnn_keep_prob, False, True)

        self.layers = nn.ModuleList()
        candidate_keys_pool = []  # grows as ["layer_0", "layer_1", ...]
        for layer_id in range(self.num_layers):
            k = "layer_{}".format(layer_id)
            self.layers.append(Layer(k, candidate_keys_pool, hidden_units, choose_from_k,
                                     cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask))
            candidate_keys_pool.append(k)

        self.linear_combine = LinearCombine(self.num_layers)
        self.linear_out = nn.Linear(self.hidden_units, self.num_classes)

        self.embed_dropout = nn.Dropout(p=1 - embed_keep_prob)
        self.output_dropout = nn.Dropout(p=1 - final_output_keep_prob)

        assert global_pool in ["max", "avg"]
        if global_pool == "max":
            self.global_pool = GlobalMaxPool()
        elif global_pool == "avg":
            self.global_pool = GlobalAvgPool()

    def forward(self, inputs):
        sent_ids, mask = inputs
        seq = self.embedding(sent_ids.long())
        seq = self.embed_dropout(seq)
        # (batch_size, seq_len, emb_dim) -> (batch_size, emb_dim, seq_len)
        seq = torch.transpose(seq, 1, 2)

        x = self.init_conv(seq, mask)
        prev_layers = []
        for layer in self.layers:
            x = layer(x, prev_layers, mask)
            prev_layers.append(x)

        # learned weighted sum over all layer outputs, then pooling and classification
        x = self.linear_combine(torch.stack(prev_layers))
        x = self.global_pool(x, mask)
        x = self.output_dropout(x)
        x = self.linear_out(x)
        return x
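

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). Assumptions: NNI's classic NAS
# API (nni.nas.pytorch) is installed, a searched architecture has already been
# exported to a JSON file, and the vocabulary size, embedding dimension, batch
# size, and sequence length below are placeholders. "arc_checkpoint.json" is a
# hypothetical path, and the mask semantics (1.0 for real tokens, 0.0 for
# padding) are assumed from how the ops in ops.py are typically used.
if __name__ == "__main__":
    from nni.nas.pytorch.fixed import apply_fixed_architecture

    embedding = torch.randn(10000, 300)              # placeholder word vectors
    model = Model(embedding, hidden_units=256, num_layers=24, num_classes=5)

    # LayerChoice/InputChoice modules need a mutator (e.g. an ENAS trainer) or a
    # fixed architecture applied before the forward pass can be run.
    apply_fixed_architecture(model, "arc_checkpoint.json")  # hypothetical checkpoint

    sent_ids = torch.randint(0, 10000, (8, 64))      # (batch_size, seq_len)
    mask = torch.ones(8, 64)                         # assumed: 1.0 = token, 0.0 = pad
    logits = model((sent_ids, mask))                 # (batch_size, num_classes)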