diff options
author | John Bauer <horatio@gmail.com> | 2022-09-14 23:34:17 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-09-15 01:52:16 +0300 |
commit | ba3f64d5f571b1dc70121551364fc89d103ca1cd (patch) | |
tree | 16e5b6e6f2a13998f12fed8214835205e8ca3703 | |
parent | b9d75cc6a3e12de69ab6734295a3dcbf9fc919fc (diff) |
Switch dict + list to OrderedDict
-rw-r--r-- | stanza/pipeline/multilingual.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/stanza/pipeline/multilingual.py b/stanza/pipeline/multilingual.py index f8ab6abd..bcafba9a 100644 --- a/stanza/pipeline/multilingual.py +++ b/stanza/pipeline/multilingual.py @@ -4,6 +4,7 @@ Class for running multilingual pipelines import torch +from collections import OrderedDict import copy import logging @@ -35,8 +36,10 @@ class MultilingualPipeline: self.lang_id_config = {} if lang_id_config is None else copy.deepcopy(lang_id_config) self.lang_configs = {} if lang_configs is None else copy.deepcopy(lang_configs) self.max_cache_size = max_cache_size - self.pipeline_cache = {} - self.lang_request_history = [] + # OrderedDict so we can use it as a LRU cache + # most recent Pipeline goes to the end, pop the oldest one + # when we run out of space + self.pipeline_cache = OrderedDict() # if lang is not in any of the lang_configs, update them to # include the lang parameter. otherwise, the default language @@ -70,9 +73,8 @@ class MultilingualPipeline: """ # update request history - if lang in self.lang_request_history: - self.lang_request_history.remove(lang) - self.lang_request_history.append(lang) + if lang in self.pipeline_cache: + self.pipeline_cache.move_to_end(lang, last=True) # update language configs if lang not in self.lang_configs: @@ -83,9 +85,7 @@ class MultilingualPipeline: logger.debug("Loading unknown language in MultilingualPipeline: %s", lang) # clear least recently used lang from pipeline cache if len(self.pipeline_cache) == self.max_cache_size: - lru_lang = self.lang_request_history[0] - self.pipeline_cache.pop(lru_lang) - self.lang_request_history.remove(lru_lang) + self.pipeline_cache.popitem(last=False) self.pipeline_cache[lang] = Pipeline(dir=self.model_dir, **self.lang_configs[lang]) def process(self, doc): |