diff options
Diffstat (limited to 'stanza/models/constituency/base_model.py')
-rw-r--r-- | stanza/models/constituency/base_model.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/stanza/models/constituency/base_model.py b/stanza/models/constituency/base_model.py index 23781bce..de73554c 100644 --- a/stanza/models/constituency/base_model.py +++ b/stanza/models/constituency/base_model.py @@ -204,6 +204,16 @@ class BaseModel(ABC): for tree in trees] return self.initial_state_from_preterminals(preterminal_lists, gold_trees=trees) + def build_batch_from_states(self, batch_size, data_iterator): + state_batch = [] + for _ in range(batch_size): + state = next(data_iterator, None) + if state is None: + break + state_batch.append(state) + + return state_batch + def build_batch_from_trees(self, batch_size, data_iterator): """ Read from the data_iterator batch_size trees and turn them into new parsing states |