Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian-regression-tests.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/tools
diff options
context:
space:
mode:
authorAlham Fikri Aji <afaji321@gmail.com>2020-11-10 08:22:37 +0300
committerAlham Fikri Aji <afaji321@gmail.com>2020-11-10 08:22:37 +0300
commit7c3a2be843fcff48bfb21c6ddbf4e2231093de8c (patch)
tree6587227098971009bcf45f348360eb50aadc2200 /tools
parentcdad78089484d7817d91c803d6fc7049328e20db (diff)
add regression tests for model quantization
Diffstat (limited to 'tools')
-rwxr-xr-xtools/check-model-unique-vals.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/tools/check-model-unique-vals.py b/tools/check-model-unique-vals.py
new file mode 100755
index 0000000..97c1e91
--- /dev/null
+++ b/tools/check-model-unique-vals.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import sys
+import argparse
+import re
+
+import numpy as np
+
+def main():
+ exit_code = 0
+ args = parse_user_args()
+
+ with np.load(args.file) as data:
+ for key in data:
+ # skip special:model.yml
+ if "special" in key:
+ continue
+
+ # if one of the dimension is 1, then it is a bias
+ # skip if it is bias and bias is not included
+ smallest_dim = sorted(data[key].shape)[0]
+ if(smallest_dim == 1 and not args.with_bias):
+ continue
+
+ if (np.unique(data[key]).size > 2**args.bits):
+ message("Tensor {} has more than {} unique values".format( \
+ key, \
+ 2**args.bits), args)
+ exit_code = 1
+
+ return exit_code
+
+
+def message(text, args):
+ if not text.endswith("\n"):
+ text += "\n"
+ args.output.write(text)
+ if not args.quiet \
+ and args.output is not sys.stdout \
+ and args.output is not sys.stderr:
+ sys.stderr.write(text)
+
+
+def parse_user_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("file", type=str)
+ parser.add_argument("-o", "--output", type=argparse.FileType('w'), metavar="FILE", default=sys.stdout)
+ parser.add_argument("-b", "--bits", type=int)
+ parser.add_argument("--with_bias", action="store_true")
+ parser.add_argument("-q", "--quiet", action="store_true")
+ return parser.parse_args()
+
+if __name__ == '__main__':
+ code = main()
+ exit(code)