
github.com/stanfordnlp/stanza.git
author    John Bauer <horatio@gmail.com>  2022-08-14 08:57:46 +0300
committer John Bauer <horatio@gmail.com>  2022-08-15 03:11:01 +0300
commit    3c416301250ef84d78a1ce3be5598303aa9c9d9b (patch)
tree      a77b8696db9ec516344d0b49d444757104d714ce
parent    b304bacee3c9f4f3cf22cbef32d433cb2e4b5657 (diff)
2d conv (branch: sentiment_lstm)

Uses the width of a conv feature to rescale the output channels, making it possible to combine the original full-width features with the smaller conv features. Adds some logging for how big each filter is.
-rw-r--r--  stanza/models/classifiers/cnn_classifier.py | 46
-rw-r--r--  stanza/tests/classifiers/test_classifier.py | 11
2 files changed, 48 insertions, 9 deletions
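
As a rough illustration of the channel rescaling described in the commit message (not part of the diff itself; all sizes below are made-up examples), the arithmetic used for a narrow (height, width) filter looks like this:

# Hedged sketch, not part of the commit: channel-rescaling arithmetic for a
# narrow 2d filter, with made-up example values.
filter_channels_config = 100          # --filter_channels
conv_input_dim = 60                   # total embedding dimension (example)
maxpool_width = 1                     # --maxpool_width
filter_height, filter_width = 3, 2    # a (3, 2) entry in --filter_sizes

# A width-2 filter strided by 2 covers conv_input_dim // filter_width = 30
# positions per token, so its channel count is scaled down to keep its total
# contribution to the fully connected input comparable to a full-width filter.
filter_channels = max(1, filter_channels_config // (conv_input_dim // filter_width))
fc_delta = filter_channels * (conv_input_dim // filter_width) // maxpool_width
print(filter_channels, fc_delta)      # 3 channels per position, 90 features to the FC layers
# A full-width filter instead contributes filter_channels_config // maxpool_width = 100.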
diff --git a/stanza/models/classifiers/cnn_classifier.py b/stanza/models/classifiers/cnn_classifier.py
index 38763194..c889968e 100644
--- a/stanza/models/classifiers/cnn_classifier.py
+++ b/stanza/models/classifiers/cnn_classifier.py
@@ -182,21 +182,41 @@ class CNNClassifier(nn.Module):
             conv_input_dim = total_embedding_dim
             self.bilstm = None
 
-        self.conv_layers = nn.ModuleList([nn.Conv2d(in_channels=1,
-                                                     out_channels=self.config.filter_channels,
-                                                     kernel_size=(filter_size, conv_input_dim))
-                                           for filter_size in self.config.filter_sizes])
+        self.fc_input_size = 0
+        self.conv_layers = nn.ModuleList()
+        self.max_window = 0
+        for filter_size in self.config.filter_sizes:
+            if isinstance(filter_size, int):
+                self.max_window = max(self.max_window, filter_size)
+                fc_delta = self.config.filter_channels // self.config.maxpool_width
+                logger.debug("Adding full width filter %d. Output channels: %d -> %d", filter_size, self.config.filter_channels, fc_delta)
+                self.fc_input_size += fc_delta
+                self.conv_layers.append(nn.Conv2d(in_channels=1,
+                                                  out_channels=self.config.filter_channels,
+                                                  kernel_size=(filter_size, conv_input_dim)))
+            elif isinstance(filter_size, tuple) and len(filter_size) == 2:
+                filter_height, filter_width = filter_size
+                self.max_window = max(self.max_window, filter_width)
+                filter_channels = max(1, self.config.filter_channels // (conv_input_dim // filter_width))
+                fc_delta = filter_channels * (conv_input_dim // filter_width) // self.config.maxpool_width
+                logger.debug("Adding filter %s. Output channels: %d -> %d", filter_size, filter_channels, fc_delta)
+                self.fc_input_size += fc_delta
+                self.conv_layers.append(nn.Conv2d(in_channels=1,
+                                                  out_channels=filter_channels,
+                                                  stride=(1, filter_width),
+                                                  kernel_size=(filter_height, filter_width)))
+            else:
+                raise ValueError("Expected int or 2d tuple for conv size")
 
-        previous_layer_size = len(self.config.filter_sizes) * (self.config.filter_channels // self.config.maxpool_width)
+        logger.debug("Input dim to FC layers: %d", self.fc_input_size)
         fc_layers = []
+        previous_layer_size = self.fc_input_size
         for shape in self.config.fc_shapes:
             fc_layers.append(nn.Linear(previous_layer_size, shape))
             previous_layer_size = shape
         fc_layers.append(nn.Linear(previous_layer_size, self.config.num_classes))
         self.fc_layers = nn.ModuleList(fc_layers)
 
-        self.max_window = max(self.config.filter_sizes)
-
         self.dropout = nn.Dropout(self.config.dropout)
 
     def add_unsaved_module(self, name, module):
@@ -348,8 +368,16 @@ class CNNClassifier(nn.Module):
         # reshape to fit the input tensors
         x = input_vectors.unsqueeze(1)
 
-        conv_outs = [self.dropout(F.relu(conv(x).squeeze(3)))
-                     for conv in self.conv_layers]
+        conv_outs = []
+        for conv, filter_size in zip(self.conv_layers, self.config.filter_sizes):
+            # TODO: non-int filter sizes
+            if isinstance(filter_size, int):
+                conv_out = self.dropout(F.relu(conv(x).squeeze(3)))
+                conv_outs.append(conv_out)
+            else:
+                conv_out = conv(x).transpose(2, 3).flatten(1, 2)
+                conv_out = self.dropout(F.relu(conv_out))
+                conv_outs.append(conv_out)
         pool_outs = [F.max_pool2d(out, (self.config.maxpool_width, out.shape[2])).squeeze(2) for out in conv_outs]
         pooled = torch.cat(pool_outs, dim=1)
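
To see how the narrow-filter branch of forward() lines up with the fc_input_size computed in the constructor, here is a small shape-checking sketch (again not part of the commit; the sizes are made up and match the example above):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hedged illustration only: mirrors the transpose/flatten/maxpool sequence above.
batch, seq_len, conv_input_dim = 2, 7, 60
filter_height, filter_width = 3, 2
filter_channels = 3                   # rescaled as in the constructor: 100 // (60 // 2)
maxpool_width = 1

x = torch.randn(batch, 1, seq_len, conv_input_dim)            # input after unsqueeze(1)
conv = nn.Conv2d(1, filter_channels, kernel_size=(filter_height, filter_width),
                 stride=(1, filter_width))

conv_out = conv(x)                                            # (2, 3, 5, 30)
conv_out = conv_out.transpose(2, 3).flatten(1, 2)             # (2, 90, 5)
pool_out = F.max_pool2d(conv_out, (maxpool_width, conv_out.shape[2])).squeeze(2)
print(pool_out.shape)                                         # torch.Size([2, 90]) == fc_delta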
diff --git a/stanza/tests/classifiers/test_classifier.py b/stanza/tests/classifiers/test_classifier.py
index e47a6654..df762866 100644
--- a/stanza/tests/classifiers/test_classifier.py
+++ b/stanza/tests/classifiers/test_classifier.py
@@ -178,3 +178,14 @@ def test_train_maxpool_width(tmp_path, fake_embeddings, train_file, dev_file):
     args = ["--maxpool_width", "3", "--filter_channels", "20"]
     run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
+
+def test_train_conv_2d(tmp_path, fake_embeddings, train_file, dev_file):
+    args = ["--filter_sizes", "(3,4,5)", "--filter_channels", "20"]
+    run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
+
+    args = ["--filter_sizes", "((3,2),)", "--filter_channels", "20"]
+    run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
+
+    args = ["--filter_sizes", "((3,2),3)", "--filter_channels", "20"]
+    run_training(tmp_path, fake_embeddings, train_file, dev_file, args)
+
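
The new test passes filter sizes as Python-literal strings. A minimal sketch of how such strings could map onto the int / 2-tuple cases handled in the constructor (ast.literal_eval is used here purely for illustration and is an assumption; the actual --filter_sizes parsing in stanza may differ):

import ast

# Assumption: illustrative only; stanza's real argument parsing may differ.
for raw in ["(3,4,5)", "((3,2),)", "((3,2),3)"]:
    sizes = ast.literal_eval(raw)
    # ints become full-width filters, 2-tuples become narrow (height, width) filters
    print(raw, "->", [("full-width", s) if isinstance(s, int) else ("2d", s) for s in sizes])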