github.com/marian-nmt/marian.git
author    Hieu Hoang <hieuhoang@gmail.com>  2017-11-27 15:24:45 +0300
committer Hieu Hoang <hieuhoang@gmail.com>  2017-11-27 15:24:45 +0300
commit    b24b03346cec4ed399946da57bd62ca59c5ebd64 (patch)
tree      69d4abd0997cf33d3ab41e81c1f6abbd19ce6603
parent    b9371d035e949c0a047499b5be9e1676621ada93 (diff)
use vector in matrix
-rw-r--r--  src/amun/gpu/mblas/matrix.h | 126
-rw-r--r--  src/amun/gpu/mblas/vector.h |  88
2 files changed, 59 insertions(+), 155 deletions(-)
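
In brief: this commit replaces TMatrix's hand-rolled device-memory bookkeeping (a raw T* data_ plus an arrSize_ capacity counter, with cudaMalloc/cudaFree scattered across the constructor, destructor, Resize, NewSize, and reserve) with a single Vector<T> member that owns the buffer. The same commit renames Vector's members from the m_ prefix style (m_size, m_maxSize, m_arr) to the trailing-underscore style (size_, maxSize_, data_) already used by TMatrix. Below is a minimal sketch of the ownership pattern the diff converges on; names follow the diff, but this is a simplified reconstruction, not the full class (HANDLE_ERROR, stream-aware copies, and the four-dimensional shape are omitted):

#include <cuda_runtime.h>
#include <cstddef>

template <typename T>
class Vector {
public:
  Vector() : size_(0), maxSize_(0), data_(nullptr) {}
  ~Vector() { cudaFree(data_); }            // cudaFree(nullptr) is a no-op, so this is safe

  // Grow capacity if needed, discarding old contents; then set the logical size.
  void newSize(std::size_t n) {
    if (n > maxSize_) {
      cudaFree(data_);
      cudaMalloc(&data_, n * sizeof(T));    // templated overload from cuda_runtime.h
      maxSize_ = n;
    }
    size_ = n;
  }

  T *data()                { return data_; }
  std::size_t size() const { return size_; }

private:
  std::size_t size_, maxSize_;
  T *data_;
};

template <typename T>
class TMatrix {
public:
  void NewSize(std::size_t rows, std::size_t cols) {
    vec_.newSize(rows * cols);              // allocation policy now lives in Vector
    dim_[0] = rows;
    dim_[1] = cols;
  }
  T *data() { return vec_.data(); }

private:
  std::size_t dim_[2];
  Vector<T> vec_;                           // sole owner of the device buffer
};
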
diff --git a/src/amun/gpu/mblas/matrix.h b/src/amun/gpu/mblas/matrix.h
index 8dc419a9..f016f9b3 100644
--- a/src/amun/gpu/mblas/matrix.h
+++ b/src/amun/gpu/mblas/matrix.h
@@ -55,8 +55,6 @@ class TMatrix : public BaseMatrix {
typedef T value_type;
TMatrix()
- : arrSize_(0)
- , data_(nullptr)
{
dim_[0] = 0;
dim_[1] = 0;
@@ -70,11 +68,12 @@ class TMatrix : public BaseMatrix {
dim_[1] = cols;
dim_[2] = beam;
dim_[3] = batches;
- arrSize_ = size();
- HANDLE_ERROR( cudaMalloc(&data_, arrSize_ * sizeof(T)) );
+ uint newSize = size();
+ vec_.newSize(newSize);
+
if (zero) {
- HANDLE_ERROR( cudaMemsetAsync(data_, 0, arrSize_ * sizeof(T), CudaStreamHandler::GetStream()) );
+ HANDLE_ERROR( cudaMemsetAsync(vec_.data(), 0, newSize * sizeof(T), CudaStreamHandler::GetStream()) );
}
}
@@ -85,26 +84,16 @@ class TMatrix : public BaseMatrix {
}
TMatrix(const TMatrix& m)
- : arrSize_(m.arrSize_)
+ : vec_(m.vec_)
{
dim_[0] = m.dim_[0];
dim_[1] = m.dim_[1];
dim_[2] = m.dim_[2];
dim_[3] = m.dim_[3];
-
- HANDLE_ERROR( cudaMalloc(&data_, arrSize_ * sizeof(T)) );
- //std::cerr << "malloc data2:" << data_ << std::endl;
- HANDLE_ERROR( cudaMemcpyAsync(
- data_,
- m.data_,
- arrSize_ * sizeof(T),
- cudaMemcpyDeviceToDevice,
- CudaStreamHandler::GetStream()) );
}
~TMatrix()
{
- HANDLE_ERROR(cudaFree(data_));
}
virtual size_t dim(size_t i) const
@@ -114,41 +103,7 @@ class TMatrix : public BaseMatrix {
void Resize(size_t rows, size_t cols, size_t beam = 1, size_t batches = 1) {
size_t newSize = cols * rows * beam * batches;
- if (data_) {
- if (newSize > arrSize_) {
- T *newData;
- HANDLE_ERROR( cudaMalloc(&newData, newSize * sizeof(T)) );
- //std::cerr << "malloc data3:" << data_ << std::endl;
-
- //size_t count = std::min(arrSize_, newSize);
-
- HANDLE_ERROR( cudaMemcpyAsync(
- newData,
- data_,
- size() * sizeof(T),
- cudaMemcpyDeviceToDevice,
- CudaStreamHandler::GetStream()) );
-
- //std::cerr << "free data1:" << data_ << std::endl;
- HANDLE_ERROR(cudaFree(data_));
- data_ = newData;
- arrSize_ = newSize;
- }
- else if (rows == 0 || cols == 0) {
- HANDLE_ERROR(cudaFree(data_));
- data_ = nullptr;
- dim_[0] = 0;
- dim_[1] = 0;
- dim_[2] = 0;
- dim_[3] = 0;
- arrSize_ = 0;
- }
- }
- else {
- HANDLE_ERROR( cudaMalloc(&data_, newSize * sizeof(T)) );
- //std::cerr << "malloc data4:" << data_ << std::endl;
- arrSize_ = newSize;
- }
+ vec_.resize(newSize);
dim_[0] = rows;
dim_[1] = cols;
@@ -158,29 +113,7 @@ class TMatrix : public BaseMatrix {
void NewSize(size_t rows, size_t cols, size_t beam = 1, size_t batches = 1) {
size_t newSize = cols * rows * beam * batches;
- if (data_) {
- if (newSize > arrSize_) {
- T *newData;
- HANDLE_ERROR( cudaMalloc(&newData, newSize * sizeof(T)) );
- HANDLE_ERROR( cudaFree(data_));
- data_ = newData;
- arrSize_ = newSize;
- }
- else if (rows == 0 || cols == 0) {
- HANDLE_ERROR( cudaFree(data_));
- data_ = nullptr;
- dim_[0] = 0;
- dim_[1] = 0;
- dim_[2] = 0;
- dim_[3] = 0;
- arrSize_ = 0;
- }
- }
- else {
- HANDLE_ERROR( cudaMalloc(&data_, newSize * sizeof(T)) );
- //std::cerr << "malloc data4:" << data_ << std::endl;
- arrSize_ = newSize;
- }
+ vec_.newSize(newSize);
dim_[0] = rows;
dim_[1] = cols;
@@ -188,40 +121,13 @@ class TMatrix : public BaseMatrix {
dim_[3] = batches;
}
- void reserve(size_t size)
- {
- assert(data_ == nullptr);
- HANDLE_ERROR( cudaMalloc(&data_, size * sizeof(T)) );
- arrSize_ = size;
- }
-
- /*
- void ReduceDimensions()
- {
- if (dim_[2] == 1) {
- dim_[2] = dim_[3];
- dim_[3] = 1;
- }
- if (dim_[0] == 1) {
- dim_[0] = dim_[2];
- dim_[2] = dim_[3];
- dim_[3] = 1;
- }
- if (dim_[1] == 1) {
- dim_[1] = dim_[0];
- dim_[0] = dim_[2];
- dim_[2] = dim_[3];
- dim_[3] = 1;
- }
- }
- */
-
virtual std::string Debug(size_t verbosity = 1) const
{
std::stringstream strm;
strm << BaseMatrix::Debug(verbosity) << " ";
- strm << data_ << " "
- << arrSize_ << " "
+ strm << vec_.data() << " "
+ << vec_.size() << " "
+ << vec_.maxSize() << " "
<< std::flush;
if (verbosity) {
@@ -234,7 +140,7 @@ class TMatrix : public BaseMatrix {
HANDLE_ERROR( cudaMemcpyAsync(
&h_data,
- data_,
+ vec_.data(),
size() * sizeof(T),
cudaMemcpyDeviceToHost,
stream) );
@@ -250,26 +156,22 @@ class TMatrix : public BaseMatrix {
}
value_type* data() {
- return data_;
+ return vec_.data();
}
const value_type* data() const {
- return data_;
+ return vec_.data();
}
void swap(TMatrix &other)
{
std::swap(dim_, other.dim_);
- std::swap(arrSize_, other.arrSize_);
- std::swap(data_, other.data_);
+ vec_.swap(other.vec_);
}
private:
size_t dim_[SHAPE_SIZE];
-
Vector<T> vec_;
- size_t arrSize_;
- T *data_;
};
typedef TMatrix<float> Matrix;
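
After the refactor, the difference between Resize and NewSize reduces to what their Vector counterparts do on growth: Resize calls vec_.resize, which device-to-device copies the existing elements into the freshly allocated buffer, while NewSize calls vec_.newSize, which only reserves capacity and leaves the contents undefined. A hedged usage sketch (Matrix is the typedef above; the shapes are invented for illustration):

Matrix m;                          // TMatrix<float>
m.NewSize(8, 512);                 // scratch buffer: prior contents do not matter
// ... kernels write into m.data() ...
m.Resize(16, 512);                 // grows the buffer and copies the old 8x512 block over
m.NewSize(4, 512);                 // shrinks logically: no reallocation, no copy
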
diff --git a/src/amun/gpu/mblas/vector.h b/src/amun/gpu/mblas/vector.h
index 6d66158e..985af0a2 100644
--- a/src/amun/gpu/mblas/vector.h
+++ b/src/amun/gpu/mblas/vector.h
@@ -17,20 +17,20 @@ class Vector
{
public:
Vector()
- :m_size(0)
- ,m_maxSize(0)
- ,m_arr(nullptr)
+ :size_(0)
+ ,maxSize_(0)
+ ,data_(nullptr)
{
}
Vector(size_t size)
- :m_maxSize(0)
+ :maxSize_(0)
{
newSize(size);
}
Vector(size_t size, const T &val)
- :m_maxSize(0)
+ :maxSize_(0)
{
newSize(size);
@@ -38,109 +38,111 @@ public:
abort();
}
else {
- HANDLE_ERROR(cudaMemsetAsync(m_arr, 0, m_size * sizeof(float), CudaStreamHandler::GetStream()));
+ HANDLE_ERROR(cudaMemsetAsync(data_, 0, size_ * sizeof(float), CudaStreamHandler::GetStream()));
}
}
Vector(const std::vector<T> &vec)
- :m_maxSize(0)
+ :maxSize_(0)
{
newSize(vec.size());
- HANDLE_ERROR( cudaMemcpyAsync(m_arr, vec.data(), vec.size() * sizeof(T), cudaMemcpyHostToDevice, CudaStreamHandler::GetStream()) );
+ HANDLE_ERROR( cudaMemcpyAsync(data_, vec.data(), vec.size() * sizeof(T), cudaMemcpyHostToDevice, CudaStreamHandler::GetStream()) );
}
Vector(const Vector<T> &other)
- :m_maxSize(other.m_size)
+ :maxSize_(other.size_)
+ ,size_(other.size_)
{
- HANDLE_ERROR( cudaMalloc(&m_arr, m_size * sizeof(T)) );
+ HANDLE_ERROR( cudaMalloc(&data_, size_ * sizeof(T)) );
//std::cerr << "malloc data2:" << data_ << std::endl;
HANDLE_ERROR( cudaMemcpyAsync(
- m_arr,
- other.m_arr,
- m_size * sizeof(T),
+ data_,
+ other.data_,
+ size_ * sizeof(T),
cudaMemcpyDeviceToDevice,
CudaStreamHandler::GetStream()) );
}
~Vector()
{
- HANDLE_ERROR(cudaFree(m_arr));
+ HANDLE_ERROR(cudaFree(data_));
}
size_t size() const
- { return m_size; }
+ { return size_; }
+
+ size_t maxSize() const
+ { return maxSize_; }
T *data()
- { return m_arr; }
+ { return data_; }
const T *data() const
- { return m_arr; }
-
- void setdata(T *val)
- {
- m_arr = val;
- }
+ { return data_; }
void resize(size_t newSize)
{
- if (newSize > m_maxSize) {
+ if (newSize > maxSize_) {
T *newData;
HANDLE_ERROR( cudaMalloc(&newData, newSize * sizeof(T)) );
- if (m_maxSize) {
- assert(m_arr);
+ if (maxSize_) {
+ assert(data_);
HANDLE_ERROR( cudaMemcpyAsync(
newData,
- m_arr,
- m_size * sizeof(T),
+ data_,
+ size_ * sizeof(T),
cudaMemcpyDeviceToDevice,
CudaStreamHandler::GetStream()) );
- HANDLE_ERROR(cudaFree(m_arr));
+ HANDLE_ERROR(cudaFree(data_));
+ }
+ else {
+ assert(data_ == nullptr);
}
- m_arr = newData;
- m_maxSize = newSize;
+ data_ = newData;
+ maxSize_ = newSize;
}
- m_size = newSize;
+ size_ = newSize;
}
void newSize(size_t newSize)
{
reserve(newSize);
- m_size = newSize;
+ size_ = newSize;
}
void reserve(size_t newSize)
{
- if (newSize > m_maxSize) {
- if (m_maxSize) {
- HANDLE_ERROR(cudaFree(m_arr));
+ if (newSize > maxSize_) {
+ if (maxSize_) {
+ HANDLE_ERROR(cudaFree(data_));
}
- HANDLE_ERROR( cudaMalloc(&m_arr, newSize * sizeof(T)) );
- m_maxSize = newSize;
+ HANDLE_ERROR( cudaMalloc(&data_, newSize * sizeof(T)) );
+ maxSize_ = newSize;
}
}
void clear()
{
- m_size = 0;
+ size_ = 0;
}
void swap(Vector &other)
{
- std::swap(m_size, other.m_size);
- std::swap(m_maxSize, other.m_maxSize);
- std::swap(m_arr, other.m_arr);
+ std::swap(size_, other.size_);
+ std::swap(maxSize_, other.maxSize_);
+ std::swap(data_, other.data_);
}
protected:
- size_t m_size, m_maxSize;
- T *m_arr;
+ size_t size_, maxSize_;
+ T *data_;
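
The renamed Vector now exposes three distinct sizing contracts: resize() grows capacity while preserving the first size_ elements through a device-to-device copy on the managed stream; newSize() and reserve() grow capacity but deliberately drop old contents (reserve frees the old buffer before the new cudaMalloc); clear() resets only the logical size and keeps the allocation for reuse. A short sketch of the distinction, assuming the simplified Vector<float> shown earlier:

Vector<float> v;
v.newSize(1024);     // capacity 1024, contents undefined
// ... device code fills v.data() ...
v.resize(2048);      // grows; the first 1024 floats are copied into the new buffer
v.newSize(4096);     // grows again; previous contents are NOT preserved
v.clear();           // size() == 0, but the 4096-float allocation is retained
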