Welcome to mirror list, hosted at ThFree Co, Russian Federation.

average.py « scripts - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
blob: 4d5c0a2f16fc5b340c364dd876b7b50c1ef48e53 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#!/usr/bin/env python

import sys
import numpy as np;

def average_models(filenames, log=None):
    """Return the element-wise average of the arrays stored in ``.npz`` files.

    Floating-point arrays are summed in float64 and divided by the number of
    models that actually contributed to them. The result is then cast back to
    the dtype of the first array seen for that key, so float32 inputs produce
    float32 outputs.

    Non-floating arrays cannot be averaged without changing their dtype.
    (A plain ``array / n`` floor-divides them under Python 2 and turns them
    into float64 under Python 3.) The copy from the first model that contains
    the key is therefore kept unchanged.

    Args:
        filenames: paths of the ``.npz`` files to average.
        log: text stream for progress messages; defaults to ``sys.stdout``
            resolved at call time.

    Returns:
        dict mapping each array name (in first-seen order) to its averaged
        array.

    Raises:
        ValueError: if *filenames* is empty.
    """
    if log is None:
        log = sys.stdout
    if not filenames:
        raise ValueError("no input models given")

    sums = {}    # key -> float64 running sum, or the kept non-float array
    counts = {}  # key -> number of models that contributed to sums[key]
    dtypes = {}  # key -> dtype of the first array seen, restored on output
    for filename in filenames:
        log.write("Loading {}\n".format(filename))
        with open(filename, "rb") as mfile:
            m = np.load(mfile)
            # Read every array while the file is still open: NpzFile loads lazily.
            for k in m.files:
                value = m[k]
                if k not in sums:
                    dtypes[k] = value.dtype
                    counts[k] = 1
                    if np.issubdtype(value.dtype, np.floating):
                        # astype() copies, so later "+=" never touches m's data.
                        sums[k] = value.astype(np.float64)
                    else:
                        sums[k] = value
                elif not np.issubdtype(dtypes[k], np.floating):
                    # Non-float arrays are taken verbatim from the first model.
                    continue
                elif sums[k].shape != value.shape:
                    # Best-effort: skip the mismatching array, but say so
                    # instead of silently producing a mis-scaled average.
                    sys.stderr.write(
                        "Warning: skipping '{}' from {}: shape {} != {}\n".format(
                            k, filename, value.shape, sums[k].shape))
                else:
                    sums[k] += value
                    counts[k] += 1

    average = {}
    for k, total in sums.items():
        if np.issubdtype(dtypes[k], np.floating):
            # Divide by the models that contributed, not by len(filenames),
            # so keys missing from (or mismatched in) some models stay
            # correctly scaled.
            average[k] = (total / counts[k]).astype(dtypes[k])
        else:
            average[k] = total
    return average


def main(argv=None):
    """Command-line entry point: ``average.py MODEL.npz [MODEL.npz ...] OUTPUT``.

    Averages every input model and writes the result with ``np.savez``
    (which appends ``.npz`` to OUTPUT if it lacks that suffix).

    Returns:
        Process exit status: 0 on success, 2 on a usage error.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        prog = argv[0] if argv else "average.py"
        sys.stderr.write(
            "Usage: {} MODEL.npz [MODEL.npz ...] OUTPUT.npz\n".format(prog))
        return 2

    inputs, output = argv[1:-1], argv[-1]
    average = average_models(inputs)

    # Single-argument print() with str.format works in both Python 2 and 3.
    print("Saving to {}".format(output))
    np.savez(output, **average)
    return 0


if __name__ == "__main__":
    sys.exit(main())