Welcome to mirror list, hosted at ThFree Co, Russian Federation.

npz_converter.cu « gpu « amun « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 14bb5dd852be6d519b01a18f208ca646201f3817 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include "npz_converter.h"
#include "common/exception.h"
#include "mblas/matrix_functions.h"

using namespace std;

namespace amunmt {
namespace GPU {

NpzConverter::NpzConverter(const std::string& file)
  : model_(cnpy::npz_load(file)),
    destructed_(false)
{
}

// Releases the loaded archive unless Destruct() already did so.
NpzConverter::~NpzConverter() {
  // Destruct() sets destructed_ after calling model_.destruct(); skip the
  // second release in that case.
  if (destructed_) {
    return;
  }
  model_.destruct();
}

// Explicitly releases the loaded archive before the converter goes away.
//
// Safe to call more than once: only the first call invokes model_.destruct().
// Nothing visible here shows model_.destruct() to be idempotent. The
// destructor already relies on destructed_ to avoid a second release, so
// repeated explicit calls are guarded the same way.
void NpzConverter::Destruct() {
  if (destructed_) {
    return;
  }
  model_.destruct();
  destructed_ = true;
}

// Debug helper: returns the sum of the `size` elements starting at `data`.
//
// Elements are accumulated front to back into a T that starts at 0, so
// floating-point results match a plain left-to-right summation. With
// size == 0 nothing is read and 0 is returned. Not referenced in the
// visible part of this file.
template<typename T>
T Debug(const T *data, size_t size)
{
  T total = 0;
  const T *const last = data + size;
  for (const T *cursor = data; cursor != last; ++cursor) {
    total += *cursor;
  }
  return total;
}

std::shared_ptr<mblas::Matrix> NpzConverter::get(const std::string& key, bool mandatory, bool transpose) const
{
  //mblas::TestMemCpy();

  std::shared_ptr<mblas::Matrix> ret;
  auto it = model_.find(key);
  if(it != model_.end()) {
    NpyMatrixWrapper np(it->second);
    size_t size = np.size();

    mblas::Matrix *matrix = new mblas::Matrix(np.size1(), np.size2(), 1, 1);
    mblas::copy(np.data(), size, matrix->data(), cudaMemcpyHostToDevice);

    if (transpose) {
      mblas::Transpose(*matrix);
    }

    ret.reset(matrix);
  }
  else if (mandatory) {
    std::cerr << "Error: Matrix not found:" << key << std::endl;
    //amunmt_UTIL_THROW2(strm.str()); //  << key << std::endl
    abort();
  }
  else {
    mblas::Matrix *matrix = new mblas::Matrix();
    ret.reset(matrix);
  }

  //std::cerr << "key=" << key << " " << matrix.Debug(1) << std::endl;
  return ret;
}

std::shared_ptr<mblas::Matrix> NpzConverter::getFirstOfMany(const std::vector<std::pair<std::string, bool>> keys, bool mandatory) const
{
  std::shared_ptr<mblas::Matrix> ret;
  for (auto key : keys) {
    auto it = model_.find(key.first);
    if(it != model_.end()) {
      NpyMatrixWrapper np(it->second);
      mblas::Matrix *matrix = new mblas::Matrix(np.size1(), np.size2(), 1, 1);
      mblas::copy(np.data(), np.size(), matrix->data(), cudaMemcpyHostToDevice);

      if (key.second) {
        mblas::Transpose(*matrix);
      }
      ret.reset(matrix);
      return ret;
    }
  }

  if (mandatory) {
    std::cerr << "Error: Matrix not found:" << keys[0].first << std::endl;
    //amunmt_UTIL_THROW2(strm.str()); //  << key << std::endl
    abort();
  }
  else {
    std::cerr << "Optional matrix not found, continuing: " << keys[0].first << std::endl;
  }

  return ret;

}


}
}