Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hieuhoang@gmail.com>2018-01-25 18:56:07 +0300
committerHieu Hoang <hieuhoang@gmail.com>2018-01-25 18:56:07 +0300
commit86d2ebe60b1cd3847730d4091827625e5524a9a0 (patch)
tree468ca7f64b842178380ad9d0763759f7b8623eb9
parentd1b1212dc706df6a6bf5afccbff4d5331a23afde (diff)
template matrix
-rw-r--r--contrib/fpga/main.cpp8
-rw-r--r--contrib/fpga/matrix.h7
2 files changed, 8 insertions, 7 deletions
diff --git a/contrib/fpga/main.cpp b/contrib/fpga/main.cpp
index a26ffad4..f66e2abd 100644
--- a/contrib/fpga/main.cpp
+++ b/contrib/fpga/main.cpp
@@ -26,10 +26,10 @@ int main()
cl_kernel kernel = CreateKernel("kernels/OutputLayer.cl", "OutputLayer_float", openCLInfo);
cerr << "CreateKernel done" << endl;
- Matrix W(openCLInfo, true, 85000, 512);
- Matrix X(openCLInfo, true, 512, 640);
- Matrix B(openCLInfo, true, 1, 85000);
- Matrix Y(openCLInfo, true, 85000, 640);
+ Matrix<float> W(openCLInfo, true, 85000, 512);
+ Matrix<float> X(openCLInfo, true, 512, 640);
+ Matrix<float> B(openCLInfo, true, 1, 85000);
+ Matrix<float> Y(openCLInfo, true, 85000, 640);
vector<float> vec;
diff --git a/contrib/fpga/matrix.h b/contrib/fpga/matrix.h
index 9e47ade2..7aac294d 100644
--- a/contrib/fpga/matrix.h
+++ b/contrib/fpga/matrix.h
@@ -2,6 +2,7 @@
#include <cassert>
#include "types-fpga.h"
+template<typename T>
class Matrix
{
public:
@@ -14,7 +15,7 @@ public:
size_ = a * b;
cl_int err;
- mem_ = clCreateBuffer(openCLInfo.context, CL_MEM_READ_WRITE, sizeof(float) * size(), NULL, &err);
+ mem_ = clCreateBuffer(openCLInfo.context, CL_MEM_READ_WRITE, sizeof(T) * size(), NULL, &err);
CheckError(err);
}
@@ -33,10 +34,10 @@ public:
unsigned size() const
{ return size_; }
- void Set(const float *arr, size_t count)
+ void Set(const T *arr, size_t count)
{
assert(count <= size_);
- size_t bytes = count * sizeof(float);
+ size_t bytes = count * sizeof(T);
CheckError( clEnqueueWriteBuffer(
openCLInfo_.commands,
mem_,