blob: 24a58b575035951ef4527c0230b338658b7c2033 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
|
"""
Test a couple basic functions - load & save an existing model
"""
import pytest
import glob
import os
import tempfile
from stanza.models.lemma import trainer
from stanza.tests import *
# Apply these marks to every test in this module.
pytestmark = [pytest.mark.pipeline, pytest.mark.travis]


@pytest.fixture
def english_model():
    """
    Load one of the English lemmatizer models found under TEST_MODELS_DIR.

    glob.glob returns matches in arbitrary, filesystem-dependent order, so
    the matches are sorted to make the choice of model reproducible across
    runs and machines.

    Returns:
        a trainer.Trainer built from the first (sorted) model file.
    """
    models_path = os.path.join(TEST_MODELS_DIR, "en", "lemma", "*")
    models = sorted(glob.glob(models_path))
    # we expect at least one English model downloaded for the tests
    assert models, "No English lemma model found matching %s" % models_path
    model_file = models[0]
    return trainer.Trainer(model_file=model_file)
def test_load_model(english_model):
    """
    Test that loading works.

    The loading itself happens in the english_model fixture. Here we only
    confirm that the fixture produced a Trainer.
    """
    assert isinstance(english_model, trainer.Trainer)
def test_save_load_model(english_model):
    """
    Load, save, and load again.

    The save path points into a "resaved" subdirectory of a fresh temporary
    directory. That subdirectory does not exist beforehand, so
    Trainer.save must create it for this test to pass.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        save_file = os.path.join(tempdir, "resaved", "lemma.pt")
        english_model.save(save_file)
        # confirm the save actually produced a file before trying to reload it
        assert os.path.isfile(save_file)
        reloaded = trainer.Trainer(model_file=save_file)
        assert isinstance(reloaded, trainer.Trainer)
|