Welcome to mirror list, hosted at ThFree Co, Russian Federation.

test_char_model.py « common « tests « stanza - github.com/stanfordnlp/stanza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 30b348518f222df2911d637a8da2b22dcb512b82 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""
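Tests building a charlm vocab from a few configurations of input files:
a single plain-text file, a single .xz file, a directory, multiple files,
and a frequency cutoff.

Also has a small end-to-end training test and tests of loading & saving
downloaded English charlms.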
"""

from collections import Counter
import glob
import lzma
import os
import tempfile

import pytest

from stanza.models import charlm
from stanza.models.common import char_model
from stanza.tests import TEST_MODELS_DIR

# Module-level marks: pytest applies these to every test in this module,
# so the whole file can be selected/deselected with `pytest -m travis`
# or `pytest -m pipeline`.  Must stay a list for pytest to unpack it.
pytestmark = [pytest.mark.travis, pytest.mark.pipeline]

# Tiny sample corpora shared by the vocab tests and the training test.
# Both begin and end with a newline, exactly as the triple-quoted
# literals they replace did.
fake_text_1 = (
    "\n"
    "Unban mox opal!\n"
    "I hate watching Peppa Pig\n"
)

fake_text_2 = "\nThis is plastic cheese\n"

class TestCharModel:
    """Tests for charlm vocab building, training, and model save/load."""

    def test_single_file_vocab(self):
        """
        Build a vocab from a single plain-text file.

        Every character of the text must be in the vocab and an unseen
        character ("Q") must not be.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            sample_file = os.path.join(tempdir, "text.txt")
            with open(sample_file, "w", encoding="utf-8") as fout:
                fout.write(fake_text_1)
            vocab = char_model.build_charlm_vocab(sample_file)

        for ch in fake_text_1:
            assert ch in vocab
        assert "Q" not in vocab

    def test_single_file_xz_vocab(self):
        """
        Build a vocab from a single .xz-compressed file.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            sample_file = os.path.join(tempdir, "text.txt.xz")
            with lzma.open(sample_file, "wt", encoding="utf-8") as fout:
                fout.write(fake_text_1)
            vocab = char_model.build_charlm_vocab(sample_file)

        for ch in fake_text_1:
            assert ch in vocab
        assert "Q" not in vocab

    def test_single_file_dir_vocab(self):
        """
        Build a vocab from a directory containing a single text file.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            sample_file = os.path.join(tempdir, "text.txt")
            with open(sample_file, "w", encoding="utf-8") as fout:
                fout.write(fake_text_1)
            vocab = char_model.build_charlm_vocab(tempdir)

        for ch in fake_text_1:
            assert ch in vocab
        assert "Q" not in vocab

    def test_multiple_files_vocab(self):
        """
        Build a vocab from a directory holding one plain and one .xz file;
        characters from both files must end up in the vocab.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            sample_file = os.path.join(tempdir, "t1.txt")
            with open(sample_file, "w", encoding="utf-8") as fout:
                fout.write(fake_text_1)
            sample_file = os.path.join(tempdir, "t2.txt.xz")
            with lzma.open(sample_file, "wt", encoding="utf-8") as fout:
                fout.write(fake_text_2)
            vocab = char_model.build_charlm_vocab(tempdir)

        for ch in fake_text_1:
            assert ch in vocab
        for ch in fake_text_2:
            assert ch in vocab
        assert "Q" not in vocab

    def test_cutoff_vocab(self):
        """
        With cutoff=2, characters seen fewer than 2 times across all files
        must be excluded from the vocab and all others must be kept.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            sample_file = os.path.join(tempdir, "t1.txt")
            with open(sample_file, "w", encoding="utf-8") as fout:
                fout.write(fake_text_1)
            sample_file = os.path.join(tempdir, "t2.txt.xz")
            with lzma.open(sample_file, "wt", encoding="utf-8") as fout:
                fout.write(fake_text_2)

            vocab = char_model.build_charlm_vocab(tempdir, cutoff=2)

        # counts are combined over both files, matching the directory input
        counts = Counter(fake_text_1) + Counter(fake_text_2)
        for letter, count in counts.items():
            if count < 2:
                assert letter not in vocab
            else:
                assert letter in vocab

    def test_build_model(self):
        """
        Test the whole training pipeline on a small dataset for two epochs.

        Checks that the model, vocab and checkpoint files are written and
        can be loaded back, and that get_current_lr works both for a
        trained trainer and for a freshly built one.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            eval_file = os.path.join(tempdir, "en_test.dev.txt")
            with open(eval_file, "w", encoding="utf-8") as fout:
                fout.write(fake_text_1)
            train_file = os.path.join(tempdir, "en_test.train.txt")
            with open(train_file, "w", encoding="utf-8") as fout:
                fout.write((fake_text_1 + "\n" + fake_text_2 + "\n") * 1000)
            # all output files share the en_test prefix used by the inputs
            save_name = 'en_test.forward.pt'
            vocab_save_name = 'en_test.vocab.pt'
            checkpoint_save_name = 'en_test.checkpoint.pt'
            cli_args = ['--train_file', train_file,
                        '--eval_file', eval_file,
                        '--eval_steps', '0', # eval once per epoch
                        '--epochs', '2',
                        '--cutoff', '1',
                        '--batch_size', '%d' % len(fake_text_1),
                        '--lang', 'en',
                        '--shorthand', 'en_test',
                        '--save_dir', tempdir,
                        '--save_name', save_name,
                        '--vocab_save_name', vocab_save_name,
                        '--checkpoint_save_name', checkpoint_save_name]
            args = charlm.parse_args(cli_args)
            charlm.train(args)

            vocab_path = os.path.join(tempdir, vocab_save_name)
            assert os.path.exists(vocab_path)

            # test that saving & loading of the model worked
            # (an exception from load() fails the test)
            save_path = os.path.join(tempdir, save_name)
            assert os.path.exists(save_path)
            char_model.CharacterLanguageModel.load(save_path)

            # test that saving & loading of the checkpoint worked
            checkpoint_path = os.path.join(tempdir, checkpoint_save_name)
            assert os.path.exists(checkpoint_path)
            char_model.CharacterLanguageModel.load(checkpoint_path)
            trainer = char_model.CharacterLanguageModelTrainer.load(args, checkpoint_path)

            assert trainer.global_step > 0
            assert trainer.epoch == 2

            # quick test to verify this method works with a trained model
            charlm.get_current_lr(trainer, args)

            # test loading a vocab built by the training method...
            vocab = charlm.load_char_vocab(vocab_path)
            trainer = char_model.CharacterLanguageModelTrainer.from_new_model(args, vocab)
            # ... and test the get_current_lr for an untrained model as well
            # this test is super "eager"
            assert charlm.get_current_lr(trainer, args) == args['lr0']

    @staticmethod
    def _load_first_charlm(direction_dir):
        """
        Load the first English charlm found in TEST_MODELS_DIR/en/<direction_dir>,
        eg stanza_test/models/en/forward_charlm/1billion.pt

        glob returns matches in arbitrary order, so they are sorted to pick
        the same file on every run.
        """
        models_path = os.path.join(TEST_MODELS_DIR, "en", direction_dir, "*")
        models = sorted(glob.glob(models_path))
        # we expect at least one English model downloaded for the tests
        assert models, "No English charlm found matching %s" % models_path
        return char_model.CharacterLanguageModel.load(models[0])

    @pytest.fixture(scope="class")
    def english_forward(self):
        """English forward charlm, loaded once per test class."""
        return self._load_first_charlm("forward_charlm")

    @pytest.fixture(scope="class")
    def english_backward(self):
        """English backward charlm, loaded once per test class."""
        return self._load_first_charlm("backward_charlm")

    def test_load_model(self, english_forward, english_backward):
        """
        Check that basic loading functions work
        """
        assert english_forward.is_forward_lm
        assert not english_backward.is_forward_lm

    def test_save_load_model(self, english_forward, english_backward):
        """
        Load, save, and load again; the reloaded model must keep its direction.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            for model in (english_forward, english_backward):
                # the same path is reused; each model is reloaded before the next save
                save_file = os.path.join(tempdir, "resaved", "charlm.pt")
                model.save(save_file)
                reloaded = char_model.CharacterLanguageModel.load(save_file)
                assert model.is_forward_lm == reloaded.is_forward_lm