Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/ctranslate2/specs/model_spec.py4
-rw-r--r--python/tests/test_spec.py10
2 files changed, 7 insertions, 7 deletions
diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py
index ea137472..405e7a83 100644
--- a/python/ctranslate2/specs/model_spec.py
+++ b/python/ctranslate2/specs/model_spec.py
@@ -161,8 +161,9 @@ class LayerSpec(FrozenAttr, metaclass=FrozenMeta):
if not isinstance(value, np.ndarray):
return
+ key = _split_scope(name)[-1]
scale = None
- is_quantizable = hasattr(spec, "%s_scale" % name)
+ is_quantizable = hasattr(spec, "%s_scale" % key)
if is_quantizable:
if quantization == "int16":
@@ -194,7 +195,6 @@ class LayerSpec(FrozenAttr, metaclass=FrozenMeta):
if value.dtype == np.float16:
value = value.astype(np.float32)
- key = _split_scope(name)[-1]
setattr(spec, key, value)
if scale is not None:
setattr(spec, "%s_scale" % key, scale)
diff --git a/python/tests/test_spec.py b/python/tests/test_spec.py
index e5f2cc57..fb93f36c 100644
--- a/python/tests/test_spec.py
+++ b/python/tests/test_spec.py
@@ -44,6 +44,8 @@ def test_layer_spec_optimize():
class SubSpec(ctranslate2.specs.LayerSpec):
def __init__(self):
self.a = np.ones([6], dtype=np.float32)
+ self.weight = np.ones([5, 4], dtype=np.float32)
+ self.weight_scale = OPTIONAL
class Spec(ctranslate2.specs.LayerSpec):
def __init__(self):
@@ -51,8 +53,6 @@ def test_layer_spec_optimize():
self.b = np.ones([5], dtype=np.float32)
self.c = np.zeros([5], dtype=np.int32)
self.d = np.dtype("float32").type(3.14)
- self.weight = np.ones([5, 4], dtype=np.float32)
- self.weight_scale = OPTIONAL
self.sub = SubSpec()
spec = Spec()
@@ -61,8 +61,8 @@ def test_layer_spec_optimize():
assert spec.b == "a"
assert spec.c.dtype == np.int32
assert spec.d.dtype == np.float32
- assert spec.weight.dtype == np.int16
- assert spec.weight_scale.dtype == np.float32
+ assert spec.sub.weight.dtype == np.int16
+ assert spec.sub.weight_scale.dtype == np.float32
spec = Spec()
spec.optimize(quantization="float16")
@@ -70,7 +70,7 @@ def test_layer_spec_optimize():
assert spec.b == "a"
assert spec.c.dtype == np.int32
assert spec.d.dtype == np.float32
- assert spec.weight.dtype == np.float16
+ assert spec.sub.weight.dtype == np.float16
assert spec.sub.a.dtype == np.float16