diff options
author | John Bauer <horatio@gmail.com> | 2022-08-14 08:57:46 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-08-15 03:11:01 +0300 |
commit | 3c416301250ef84d78a1ce3be5598303aa9c9d9b (patch) | |
tree | a77b8696db9ec516344d0b49d444757104d714ce | |
parent | b304bacee3c9f4f3cf22cbef32d433cb2e4b5657 (diff) |
2d conv. Uses the width of a conv feature to rescale the outputsentiment_lstm
channels, making it possible to combine the original full-width
features with the smaller conv features
Adds some logging for how big each filter is
-rw-r--r-- | stanza/models/classifiers/cnn_classifier.py | 46 | ||||
-rw-r--r-- | stanza/tests/classifiers/test_classifier.py | 11 |
2 files changed, 48 insertions, 9 deletions
diff --git a/stanza/models/classifiers/cnn_classifier.py b/stanza/models/classifiers/cnn_classifier.py index 38763194..c889968e 100644 --- a/stanza/models/classifiers/cnn_classifier.py +++ b/stanza/models/classifiers/cnn_classifier.py @@ -182,21 +182,41 @@ class CNNClassifier(nn.Module): conv_input_dim = total_embedding_dim self.bilstm = None - self.conv_layers = nn.ModuleList([nn.Conv2d(in_channels=1, - out_channels=self.config.filter_channels, - kernel_size=(filter_size, conv_input_dim)) - for filter_size in self.config.filter_sizes]) + self.fc_input_size = 0 + self.conv_layers = nn.ModuleList() + self.max_window = 0 + for filter_size in self.config.filter_sizes: + if isinstance(filter_size, int): + self.max_window = max(self.max_window, filter_size) + fc_delta = self.config.filter_channels // self.config.maxpool_width + logger.debug("Adding full width filter %d. Output channels: %d -> %d", filter_size, self.config.filter_channels, fc_delta) + self.fc_input_size += fc_delta + self.conv_layers.append(nn.Conv2d(in_channels=1, + out_channels=self.config.filter_channels, + kernel_size=(filter_size, conv_input_dim))) + elif isinstance(filter_size, tuple) and len(filter_size) == 2: + filter_height, filter_width = filter_size + self.max_window = max(self.max_window, filter_width) + filter_channels = max(1, self.config.filter_channels // (conv_input_dim // filter_width)) + fc_delta = filter_channels * (conv_input_dim // filter_width) // self.config.maxpool_width + logger.debug("Adding filter %s. Output channels: %d -> %d", filter_size, filter_channels, fc_delta) + self.fc_input_size += fc_delta + self.conv_layers.append(nn.Conv2d(in_channels=1, + out_channels=filter_channels, + stride=(1, filter_width), + kernel_size=(filter_height, filter_width))) + else: + raise ValueError("Expected int or 2d tuple for conv size") - previous_layer_size = len(self.config.filter_sizes) * (self.config.filter_channels // self.config.maxpool_width) + logger.debug("Input dim to FC layers: %d", self.fc_input_size) fc_layers = [] + previous_layer_size = self.fc_input_size for shape in self.config.fc_shapes: fc_layers.append(nn.Linear(previous_layer_size, shape)) previous_layer_size = shape fc_layers.append(nn.Linear(previous_layer_size, self.config.num_classes)) self.fc_layers = nn.ModuleList(fc_layers) - self.max_window = max(self.config.filter_sizes) - self.dropout = nn.Dropout(self.config.dropout) def add_unsaved_module(self, name, module): @@ -348,8 +368,16 @@ class CNNClassifier(nn.Module): # reshape to fit the input tensors x = input_vectors.unsqueeze(1) - conv_outs = [self.dropout(F.relu(conv(x).squeeze(3))) - for conv in self.conv_layers] + conv_outs = [] + for conv, filter_size in zip(self.conv_layers, self.config.filter_sizes): + # TODO: non-int filter sizes + if isinstance(filter_size, int): + conv_out = self.dropout(F.relu(conv(x).squeeze(3))) + conv_outs.append(conv_out) + else: + conv_out = conv(x).transpose(2, 3).flatten(1, 2) + conv_out = self.dropout(F.relu(conv_out)) + conv_outs.append(conv_out) pool_outs = [F.max_pool2d(out, (self.config.maxpool_width, out.shape[2])).squeeze(2) for out in conv_outs] pooled = torch.cat(pool_outs, dim=1) diff --git a/stanza/tests/classifiers/test_classifier.py b/stanza/tests/classifiers/test_classifier.py index e47a6654..df762866 100644 --- a/stanza/tests/classifiers/test_classifier.py +++ b/stanza/tests/classifiers/test_classifier.py @@ -178,3 +178,14 @@ def test_train_maxpool_width(tmp_path, fake_embeddings, train_file, dev_file): args = ["--maxpool_width", "3", "--filter_channels", "20"] run_training(tmp_path, fake_embeddings, train_file, dev_file, args) + +def test_train_conv_2d(tmp_path, fake_embeddings, train_file, dev_file): + args = ["--filter_sizes", "(3,4,5)", "--filter_channels", "20"] + run_training(tmp_path, fake_embeddings, train_file, dev_file, args) + + args = ["--filter_sizes", "((3,2),)", "--filter_channels", "20"] + run_training(tmp_path, fake_embeddings, train_file, dev_file, args) + + args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "20"] + run_training(tmp_path, fake_embeddings, train_file, dev_file, args) + |