diff options (stat)
-rw-r--r--  python/ctranslate2/specs/model_spec.py |  4 ++--
-rw-r--r--  python/tests/test_spec.py              | 10 +++++-----
2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py
index ea137472..405e7a83 100644
--- a/python/ctranslate2/specs/model_spec.py
+++ b/python/ctranslate2/specs/model_spec.py
@@ -161,8 +161,9 @@ class LayerSpec(FrozenAttr, metaclass=FrozenMeta):
             if not isinstance(value, np.ndarray):
                 return
 
+            key = _split_scope(name)[-1]
             scale = None
-            is_quantizable = hasattr(spec, "%s_scale" % name)
+            is_quantizable = hasattr(spec, "%s_scale" % key)
             if is_quantizable:
                 if quantization == "int16":
@@ -194,7 +195,6 @@ class LayerSpec(FrozenAttr, metaclass=FrozenMeta):
                 if value.dtype == np.float16:
                     value = value.astype(np.float32)
 
-            key = _split_scope(name)[-1]
             setattr(spec, key, value)
             if scale is not None:
                 setattr(spec, "%s_scale" % key, scale)
diff --git a/python/tests/test_spec.py b/python/tests/test_spec.py
index e5f2cc57..fb93f36c 100644
--- a/python/tests/test_spec.py
+++ b/python/tests/test_spec.py
@@ -44,6 +44,8 @@ def test_layer_spec_optimize():
     class SubSpec(ctranslate2.specs.LayerSpec):
         def __init__(self):
             self.a = np.ones([6], dtype=np.float32)
+            self.weight = np.ones([5, 4], dtype=np.float32)
+            self.weight_scale = OPTIONAL
 
     class Spec(ctranslate2.specs.LayerSpec):
         def __init__(self):
@@ -51,8 +53,6 @@ def test_layer_spec_optimize():
             self.b = np.ones([5], dtype=np.float32)
             self.c = np.zeros([5], dtype=np.int32)
             self.d = np.dtype("float32").type(3.14)
-            self.weight = np.ones([5, 4], dtype=np.float32)
-            self.weight_scale = OPTIONAL
             self.sub = SubSpec()
 
     spec = Spec()
@@ -61,8 +61,8 @@ def test_layer_spec_optimize():
     assert spec.b == "a"
     assert spec.c.dtype == np.int32
     assert spec.d.dtype == np.float32
-    assert spec.weight.dtype == np.int16
-    assert spec.weight_scale.dtype == np.float32
+    assert spec.sub.weight.dtype == np.int16
+    assert spec.sub.weight_scale.dtype == np.float32
 
     spec = Spec()
     spec.optimize(quantization="float16")
@@ -70,7 +70,7 @@ def test_layer_spec_optimize():
     assert spec.b == "a"
     assert spec.c.dtype == np.int32
     assert spec.d.dtype == np.float32
-    assert spec.weight.dtype == np.float16
+    assert spec.sub.weight.dtype == np.float16
     assert spec.sub.a.dtype == np.float16