Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Bauer <horatio@gmail.com>2020-09-22 23:35:48 +0300
committerJohn Bauer <horatio@gmail.com>2022-08-30 22:30:48 +0300
commit532f7a4775319b07e19227cec4c74e38b474c41c (patch)
tree54b2926de74e054ba62ab0ad07cfb3b1bcc9c07b
parent2e0f4b65deb55114a991afec96f17792f6a9605d (diff)
Add support for elmoformanylangs to sentiment
Includes a matrix trained to connect the 3 layers of elmo instead of using the default averaging Also, a projection from elmo dim to a lower dimension (although this was less useful) Add a comment on how the sentiment processor doesn't load Elmo. Actually, in general this integration is unlikely to be used for much, but there's also no specific reason to throw this code away.
-rw-r--r--stanza/models/classifier.py11
-rw-r--r--stanza/models/classifiers/cnn_classifier.py52
-rw-r--r--stanza/models/common/utils.py10
-rw-r--r--stanza/pipeline/sentiment_processor.py6
4 files changed, 76 insertions, 3 deletions
diff --git a/stanza/models/classifier.py b/stanza/models/classifier.py
index a6583878..acaf2989 100644
--- a/stanza/models/classifier.py
+++ b/stanza/models/classifier.py
@@ -37,6 +37,8 @@ class DevScoring(Enum):
logger = logging.getLogger('stanza')
tlogger = logging.getLogger('stanza.classifiers.trainer')
+logging.getLogger('elmoformanylangs').setLevel(logging.WARNING)
+
DEFAULT_TRAIN='data/sentiment/en_sstplus.train.txt'
DEFAULT_DEV='data/sentiment/en_sst3roots.dev.txt'
DEFAULT_TEST='data/sentiment/en_sst3roots.test.txt'
@@ -187,6 +189,10 @@ def parse_args(args=None):
parser.add_argument('--charlm_projection', type=int, default=None, help="Project the charlm values to this dimension")
parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help="Use lowercased characters in character model.")
+ parser.add_argument('--elmo_model', default='extern_data/manyelmo/english', help='Directory with elmo model')
+ parser.add_argument('--use_elmo', dest='use_elmo', default=False, action='store_true', help='Use an elmo model as a source of parameters')
+ parser.add_argument('--elmo_projection', type=int, default=None, help='Project elmo to this many dimensions')
+
parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
@@ -554,6 +560,7 @@ def load_model(args):
Load both the pretrained embedding and other pieces from the args as well as the model itself
"""
pretrain = load_pretrain(args)
+ elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
charmodel_forward = load_charlm(args.charlm_forward_file)
charmodel_backward = load_charlm(args.charlm_backward_file)
@@ -563,7 +570,7 @@ def load_model(args):
load_name = os.path.join(args.save_dir, args.load_name)
if not os.path.exists(load_name):
raise FileNotFoundError("Could not find model to load in either %s or %s" % (args.load_name, load_name))
- return cnn_classifier.load(load_name, pretrain, charmodel_forward, charmodel_backward)
+ return cnn_classifier.load(load_name, pretrain, charmodel_forward, charmodel_backward, elmo_model)
def build_new_model(args, train_set):
"""
@@ -573,6 +580,7 @@ def build_new_model(args, train_set):
raise ValueError("Must have a train set to build a new model - needed for labels and delta word vectors")
pretrain = load_pretrain(args)
+ elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
charmodel_forward = load_charlm(args.charlm_forward_file)
charmodel_backward = load_charlm(args.charlm_backward_file)
@@ -586,6 +594,7 @@ def build_new_model(args, train_set):
labels=labels,
charmodel_forward=charmodel_forward,
charmodel_backward=charmodel_backward,
+ elmo_model=elmo_model,
bert_model=bert_model,
bert_tokenizer=bert_tokenizer,
args=args)
diff --git a/stanza/models/classifiers/cnn_classifier.py b/stanza/models/classifiers/cnn_classifier.py
index 5dbe1470..b454f3bd 100644
--- a/stanza/models/classifiers/cnn_classifier.py
+++ b/stanza/models/classifiers/cnn_classifier.py
@@ -52,7 +52,7 @@ tlogger = logging.getLogger('stanza.classifiers.trainer')
class CNNClassifier(nn.Module):
def __init__(self, pretrain, extra_vocab, labels,
- charmodel_forward, charmodel_backward, bert_model, bert_tokenizer,
+ charmodel_forward, charmodel_backward, elmo_model, bert_model, bert_tokenizer,
args):
"""
pretrain is a pretrained word embedding. should have .emb and .vocab
@@ -82,6 +82,8 @@ class CNNClassifier(nn.Module):
extra_wordvec_max_norm = args.extra_wordvec_max_norm,
char_lowercase = args.char_lowercase,
charlm_projection = args.charlm_projection,
+ use_elmo = args.use_elmo,
+ elmo_projection = args.elmo_projection,
bert_model = args.bert_model,
bilstm = args.bilstm,
bilstm_hidden_dim = args.bilstm_hidden_dim,
@@ -94,6 +96,7 @@ class CNNClassifier(nn.Module):
emb_matrix = pretrain.emb
self.add_unsaved_module('embedding', nn.Embedding.from_pretrained(torch.from_numpy(emb_matrix), freeze=True))
+ self.add_unsaved_module('elmo_model', elmo_model)
self.vocab_size = emb_matrix.shape[0]
self.embedding_dim = emb_matrix.shape[1]
@@ -172,6 +175,19 @@ class CNNClassifier(nn.Module):
self.charmodel_backward_projection = None
total_embedding_dim += charmodel_backward.hidden_dim()
+ if self.config.use_elmo:
+ if elmo_model is None:
+ raise ValueError("Model requires elmo, but elmo_model not passed in")
+ elmo_dim = elmo_model.sents2elmo([["Test"]])[0].shape[1]
+
+ # this mapping will combine 3 layers of elmo to 1 layer of features
+ self.elmo_combine_layers = nn.Linear(in_features=3, out_features=1, bias=False)
+ if self.config.elmo_projection:
+ self.elmo_projection = nn.Linear(in_features=elmo_dim, out_features=self.config.elmo_projection)
+ total_embedding_dim = total_embedding_dim + self.config.elmo_projection
+ else:
+ total_embedding_dim = total_embedding_dim + elmo_dim
+
if bert_model is not None:
if bert_tokenizer is None:
raise ValueError("Cannot have a bert model without a tokenizer")
@@ -284,6 +300,9 @@ class CNNClassifier(nn.Module):
extra_batch_indices = []
begin_paddings = []
end_paddings = []
+
+ elmo_batch_words = []
+
for phrase in inputs:
# we use random at training time to try to learn different
# positions of padding. at test time, though, we want to
@@ -332,6 +351,13 @@ class CNNClassifier(nn.Module):
extra_sentence_indices.extend([PAD_ID] * end_pad_width)
extra_batch_indices.append(extra_sentence_indices)
+ if self.config.use_elmo:
+ elmo_phrase_words = [""] * begin_pad_width
+ for word in phrase:
+ elmo_phrase_words.append(word)
+ elmo_phrase_words.extend([""] * end_pad_width)
+ elmo_batch_words.append(elmo_phrase_words)
+
# creating a single large list with all the indices lets us
# create a single tensor, which is much faster than creating
# many tiny tensors
@@ -370,6 +396,25 @@ class CNNClassifier(nn.Module):
char_reps_backward = self.build_char_reps(inputs, max_phrase_len, self.backward_charlm, self.charmodel_backward_projection, begin_paddings, device)
all_inputs.append(char_reps_backward)
+ if self.config.use_elmo:
+ # this will be N arrays of 3xMx1024 where M is the number of words
+ # and N is the number of sentences (and 1024 is actually the number of weights)
+ elmo_arrays = self.elmo_model.sents2elmo(elmo_batch_words, output_layer=-2)
+ elmo_tensors = [torch.tensor(x).to(device=device) for x in elmo_arrays]
+ # elmo_tensor will now be Nx3xMx1024
+ elmo_tensor = torch.stack(elmo_tensors)
+ # Nx1024xMx3
+ elmo_tensor = torch.transpose(elmo_tensor, 1, 3)
+ # NxMx1024x3
+ elmo_tensor = torch.transpose(elmo_tensor, 1, 2)
+ # NxMx1024x1
+ elmo_tensor = self.elmo_combine_layers(elmo_tensor)
+ # NxMx1024
+ elmo_tensor = elmo_tensor.squeeze(3)
+ if self.config.elmo_projection:
+ elmo_tensor = self.elmo_projection(elmo_tensor)
+ all_inputs.append(elmo_tensor)
+
if self.bert_model is not None:
bert_embeddings = self.extract_bert_embeddings(inputs, max_phrase_len, begin_paddings, device)
all_inputs.append(bert_embeddings)
@@ -425,7 +470,7 @@ def save(filename, model, skip_modules=True):
torch.save(params, filename, _use_new_zipfile_serialization=False)
logger.info("Model saved to {}".format(filename))
-def load(filename, pretrain, charmodel_forward, charmodel_backward, foundation_cache=None):
+def load(filename, pretrain, charmodel_forward, charmodel_backward, elmo_model, foundation_cache=None):
try:
checkpoint = torch.load(filename, lambda storage, loc: storage)
except BaseException:
@@ -434,6 +479,8 @@ def load(filename, pretrain, charmodel_forward, charmodel_backward, foundation_c
logger.debug("Loaded model {}".format(filename))
# TODO: should not be needed when all models have this value set
+ setattr(checkpoint['config'], 'use_elmo', getattr(checkpoint['config'], 'use_elmo', False))
+ setattr(checkpoint['config'], 'elmo_projection', getattr(checkpoint['config'], 'elmo_projection', False))
setattr(checkpoint['config'], 'char_lowercase', getattr(checkpoint['config'], 'char_lowercase', False))
setattr(checkpoint['config'], 'charlm_projection', getattr(checkpoint['config'], 'charlm_projection', None))
setattr(checkpoint['config'], 'bert_model', getattr(checkpoint['config'], 'bert_model', None))
@@ -453,6 +500,7 @@ def load(filename, pretrain, charmodel_forward, charmodel_backward, foundation_c
labels=checkpoint['labels'],
charmodel_forward=charmodel_forward,
charmodel_backward=charmodel_backward,
+ elmo_model=elmo_model,
bert_model=bert_model,
bert_tokenizer=bert_tokenizer,
args=checkpoint['config'])
diff --git a/stanza/models/common/utils.py b/stanza/models/common/utils.py
index e04d0ba2..af2d983f 100644
--- a/stanza/models/common/utils.py
+++ b/stanza/models/common/utils.py
@@ -429,3 +429,13 @@ def checkpoint_name(save_dir, save_name, checkpoint_name):
return save_name[:-3] + "_checkpoint.pt"
return save_name + "_checkpoint"
+
+def load_elmo(elmo_model):
+ # This import is here so that Elmo integration can be treated
+ # as an optional feature
+ import elmoformanylangs
+
+ logger.info("Loading elmo: %s" % elmo_model)
+ elmo_model = elmoformanylangs.Embedder(elmo_model)
+ return elmo_model
+
diff --git a/stanza/pipeline/sentiment_processor.py b/stanza/pipeline/sentiment_processor.py
index 8573e4e6..8c866304 100644
--- a/stanza/pipeline/sentiment_processor.py
+++ b/stanza/pipeline/sentiment_processor.py
@@ -34,10 +34,16 @@ class SentimentProcessor(UDProcessor):
backward_charlm_path = config.get('backward_charlm_path', None)
charmodel_backward = pipeline.foundation_cache.load_charlm(backward_charlm_path)
# set up model
+ # elmo does not have a convenient way to download intermediate
+ # models the way stanza downloads charlms & pretrains or
+ # transformers downloads bert etc
+ # however, elmo in general is not as good as using a
+ # transformer, so it is unlikely we will ever fix this
self._model = cnn_classifier.load(filename=config['model_path'],
pretrain=self._pretrain,
charmodel_forward=charmodel_forward,
charmodel_backward=charmodel_backward,
+ elmo_model=None,
foundation_cache=pipeline.foundation_cache)
# batch size counted as words
self._batch_size = config.get('batch_size', SentimentProcessor.DEFAULT_BATCH_SIZE)