Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Marcin Junczys-Dowmunt <junczys@amu.edu.pl> 2018-08-06 03:08:53 +0300
committer Marcin Junczys-Dowmunt <junczys@amu.edu.pl> 2018-08-06 03:08:53 +0300
commit fab48eb3623f4426280069c0fceea64dfe9e4445 (patch)
tree f858ead384f09f058a77ae6391ad5e42e961020c
parent 2bff43075c530e58447621eb2598d64a44b1f993 (diff)
use new io for optimizers
-rw-r--r-- src/optimizers/optimizers.cpp 69
1 files changed, 42 insertions, 27 deletions
diff --git a/src/optimizers/optimizers.cpp b/src/optimizers/optimizers.cpp
index c6b92ad1..fc266325 100644
--- a/src/optimizers/optimizers.cpp
+++ b/src/optimizers/optimizers.cpp
@@ -1,7 +1,7 @@
#include "optimizers.h"
#include "tensors/tensor_operators.h"
-#include "3rd_party/cnpy/cnpy.h"
+#include "common/io.h"
namespace marian {
@@ -49,19 +49,16 @@ void Adagrad::load(const std::string& name,
size_t totalSize = 0;
// @TODO: use new IO
- auto numpy = cnpy::npz_load(name);
- for(auto it : numpy) {
- auto name = it.first;
- auto np = it.second;
-
+ auto items = io::loadItems(name);
+ for(auto item : items) {
// get the size of gt_
- totalSize = np->shape[1];
+ totalSize = item.shape.elements();
// extract data into vectors
- if(name == "adagrad_gt") {
+ if(item.name == "adagrad_gt") {
vGt.resize(totalSize);
std::copy(
- (float*)np->data(), (float*)np->data() + totalSize, vGt.begin());
+ (float*)item.data(), (float*)item.data() + totalSize, vGt.begin());
}
}
@@ -110,9 +107,16 @@ void Adagrad::save(const std::string& name,
vGt.insert(vGt.end(), tmp.begin(), tmp.end());
}
- unsigned int shape[2] = {1, (unsigned int)vGt.size()};
+ io::Item item;
+ item.name = "adagrad_gt";
+ item.shape = Shape({1, (int)vGt.size()});
+ item.type = Type::float32;
+ item.bytes.resize(vGt.size() * sizeOf(item.type));
+ std::copy((char*)vGt.data(),
+ (char*)vGt.data() + vGt.size(),
+ item.bytes.begin());
- cnpy::npz_save(name, "adagrad_gt", vGt.data(), shape, 2, "w");
+ io::saveItems(name, {item});
}
void Adagrad::resetStats() {
@@ -166,24 +170,22 @@ void Adam::load(const std::string& name,
std::vector<float> vVt;
size_t totalSize = 0;
- auto numpy = cnpy::npz_load(name);
- for(auto it : numpy) {
- auto name = it.first;
- auto np = it.second;
+ auto items = io::loadItems(name);
+ for(auto item : items) {
// get the size of mt_ and vt_, they are the same
- totalSize = np->shape[1];
+ totalSize = item.shape.elements();
// extract data into vectors
- if(name == "adam_mt") {
+ if(item.name == "adam_mt") {
vMt.resize(totalSize);
std::copy(
- (float*)np->data(), (float*)np->data() + totalSize, vMt.begin());
+ (float*)item.data(), (float*)item.data() + totalSize, vMt.begin());
}
- if(name == "adam_vt") {
+ if(item.name == "adam_vt") {
vVt.resize(totalSize);
std::copy(
- (float*)np->data(), (float*)np->data() + totalSize, vVt.begin());
+ (float*)item.data(), (float*)item.data() + totalSize, vVt.begin());
}
}
@@ -238,13 +240,26 @@ void Adam::save(const std::string& name,
opt->vt_->get(tmp);
vVt.insert(vVt.end(), tmp.begin(), tmp.end());
}
-
- // the shape is the same for mt_ and vt_
- std::vector<unsigned int> shape{1, (unsigned int)vMt.size()};
-
- cnpy::npz_save(name,
- {cnpy::NpzItem("adam_mt", vMt, shape),
- cnpy::NpzItem("adam_vt", vVt, shape)});
+
+ io::Item itemMt;
+ itemMt.name = "adam_mt";
+ itemMt.shape = Shape({1, (int)vMt.size()});
+ itemMt.type = Type::float32;
+ itemMt.bytes.resize(vMt.size() * sizeOf(itemMt.type));
+ std::copy((char*)vMt.data(),
+ (char*)vMt.data() + vMt.size(),
+ itemMt.bytes.begin());
+
+ io::Item itemVt;
+ itemVt.name = "adam_vt";
+ itemVt.shape = Shape({1, (int)vVt.size()});
+ itemVt.type = Type::float32;
+ itemVt.bytes.resize(vVt.size() * sizeOf(itemVt.type));
+ std::copy((char*)vVt.data(),
+ (char*)vVt.data() + vVt.size(),
+ itemVt.bytes.begin());
+
+ io::saveItems(name, {itemMt, itemVt});
}
void Adam::resetStats() {