diff options
author | Hieu Hoang <hieuhoang@gmail.com> | 2018-02-27 02:36:40 +0300 |
---|---|---|
committer | Hieu Hoang <hieuhoang@gmail.com> | 2018-02-27 02:36:40 +0300 |
commit | 5867e6e10e9e2222238cb8da03e69e7bad07e8c0 (patch) | |
tree | 3f74216d3491acf70dd6720c223419eb138cf789 | |
parent | 43aeb5539ec0f0df9f16d0773561350ce9219fb4 (diff) |
MatrixWrapper -> TensorWrapper
-rw-r--r-- | src/amun/gpu/dl4mt/gru.cu | 14 | ||||
-rw-r--r-- | src/amun/gpu/dl4mt/gru.h | 28 | ||||
-rw-r--r-- | src/amun/gpu/mblas/matrix_functions.cu | 104 | ||||
-rw-r--r-- | src/amun/gpu/mblas/matrix_functions.h | 48 | ||||
-rw-r--r-- | src/amun/gpu/mblas/matrix_wrapper.h | 14 | ||||
-rw-r--r-- | src/amun/gpu/mblas/nth_element.cu | 4 | ||||
-rw-r--r-- | src/amun/gpu/mblas/nth_element_kernels.cu | 8 | ||||
-rw-r--r-- | src/amun/gpu/mblas/nth_element_kernels.h | 8 |
8 files changed, 114 insertions, 114 deletions
diff --git a/src/amun/gpu/dl4mt/gru.cu b/src/amun/gpu/dl4mt/gru.cu index a3b4d7ae..4b371e1d 100644 --- a/src/amun/gpu/dl4mt/gru.cu +++ b/src/amun/gpu/dl4mt/gru.cu @@ -5,13 +5,13 @@ using namespace std; namespace amunmt { namespace GPU { -__global__ void gElementwiseOps(mblas::MatrixWrapper<float> outWrap, - const mblas::MatrixWrapper<float> stateWrap, - const mblas::MatrixWrapper<float> ruhWrap, - const mblas::MatrixWrapper<float> tempWrap, - const mblas::MatrixWrapper<float> bWrap, - const mblas::MatrixWrapper<float> bx1Wrap, - const mblas::MatrixWrapper<float> bx2Wrap) +__global__ void gElementwiseOps(mblas::TensorWrapper<float> outWrap, + const mblas::TensorWrapper<float> stateWrap, + const mblas::TensorWrapper<float> ruhWrap, + const mblas::TensorWrapper<float> tempWrap, + const mblas::TensorWrapper<float> bWrap, + const mblas::TensorWrapper<float> bx1Wrap, + const mblas::TensorWrapper<float> bx2Wrap) { const unsigned rows = stateWrap.dim(0); const unsigned cols = stateWrap.dim(1); diff --git a/src/amun/gpu/dl4mt/gru.h b/src/amun/gpu/dl4mt/gru.h index f9bd1471..5dcca9fb 100644 --- a/src/amun/gpu/dl4mt/gru.h +++ b/src/amun/gpu/dl4mt/gru.h @@ -102,13 +102,13 @@ class SlowGRU: public Cell { /////////////////////////////////////////////////////////////////////////////////////////////// -__global__ void gElementwiseOps(mblas::MatrixWrapper<float> outWrap, - const mblas::MatrixWrapper<float> stateWrap, - const mblas::MatrixWrapper<float> ruhWrap, - const mblas::MatrixWrapper<float> tempWrap, - const mblas::MatrixWrapper<float> bWrap, - const mblas::MatrixWrapper<float> bx1Wrap, - const mblas::MatrixWrapper<float> bx2Wrap); +__global__ void gElementwiseOps(mblas::TensorWrapper<float> outWrap, + const mblas::TensorWrapper<float> stateWrap, + const mblas::TensorWrapper<float> ruhWrap, + const mblas::TensorWrapper<float> tempWrap, + const mblas::TensorWrapper<float> bWrap, + const mblas::TensorWrapper<float> bx1Wrap, + const mblas::TensorWrapper<float> bx2Wrap); template <class Weights> class FastGRU: public Cell { @@ -200,13 +200,13 @@ class FastGRU: public Cell { NextState.NewSize(State.dim(0), State.dim(1), 1, 1); //std::cerr << "NextState=" << NextState.Debug() << std::endl; - mblas::MatrixWrapper<float> nextWrap(NextState); - const mblas::MatrixWrapper<float> stateWrap(State); - const mblas::MatrixWrapper<float> ruhWrap(RUH); - const mblas::MatrixWrapper<float> tempWrap(Temp); - const mblas::MatrixWrapper<float> bWrap(*w_.B_); - const mblas::MatrixWrapper<float> bx1Wrap(*w_.Bx1_); - const mblas::MatrixWrapper<float> bx2Wrap(*w_.Bx2_); + mblas::TensorWrapper<float> nextWrap(NextState); + const mblas::TensorWrapper<float> stateWrap(State); + const mblas::TensorWrapper<float> ruhWrap(RUH); + const mblas::TensorWrapper<float> tempWrap(Temp); + const mblas::TensorWrapper<float> bWrap(*w_.B_); + const mblas::TensorWrapper<float> bx1Wrap(*w_.Bx1_); + const mblas::TensorWrapper<float> bx2Wrap(*w_.Bx2_); /* std::cerr << "nextWrap=" << nextWrap.Debug() << std::endl; diff --git a/src/amun/gpu/mblas/matrix_functions.cu b/src/amun/gpu/mblas/matrix_functions.cu index 73640dd8..82f8dbff 100644 --- a/src/amun/gpu/mblas/matrix_functions.cu +++ b/src/amun/gpu/mblas/matrix_functions.cu @@ -16,8 +16,8 @@ Tensor& Swap(Tensor& Out, Tensor& In) { return Out; } -__global__ void gMean(MatrixWrapper<float> out, - const MatrixWrapper<float> in, +__global__ void gMean(TensorWrapper<float> out, + const TensorWrapper<float> in, const VectorWrapper<unsigned> sentenceLengths) { // out = batches * states @@ -65,8 +65,8 @@ void Mean(Tensor& Out, unsigned stateLength = Out.dim(1); unsigned sentenceLength = (In.dim(0) * In.dim(2) * In.dim(3)) / batchNum; - MatrixWrapper<float> outWrap(Out); - MatrixWrapper<float> inWrap(In); + TensorWrapper<float> outWrap(Out); + TensorWrapper<float> inWrap(In); //cerr << "outWrap=" << outWrap.Debug() << endl; VectorWrapper<unsigned> sentenceLengthsWrap(sentenceLengths); @@ -81,9 +81,9 @@ void Mean(Tensor& Out, } -__global__ void gWeightedMean(MatrixWrapper<float> out, - const MatrixWrapper<float> weights, - const MatrixWrapper<float> in, +__global__ void gWeightedMean(TensorWrapper<float> out, + const TensorWrapper<float> weights, + const TensorWrapper<float> in, const VectorWrapper<unsigned> mapping ) { @@ -114,9 +114,9 @@ void WeightedMean(Tensor& Out,const Tensor& Weights, const Tensor& In, const mbl Out.NewSize(numHypos, states); - MatrixWrapper<float> outWrap(Out); - MatrixWrapper<float> weightsWrap(Weights); - MatrixWrapper<float> inWrap(In); + TensorWrapper<float> outWrap(Out); + TensorWrapper<float> weightsWrap(Weights); + TensorWrapper<float> inWrap(In); VectorWrapper<unsigned> mappingWrap(mapping); unsigned size = Out.size(); @@ -179,8 +179,8 @@ Tensor& Copy(Tensor& Out, const Tensor& In) { return Out; } -__global__ void gPasteRows(MatrixWrapper<float> out, - const MatrixWrapper<float> in, +__global__ void gPasteRows(TensorWrapper<float> out, + const TensorWrapper<float> in, int rowNo, int colNo) { int inRows = in.dim(0); @@ -200,8 +200,8 @@ __global__ void gPasteRows(MatrixWrapper<float> out, void PasteRows(Tensor& Out, const Tensor& In, const unsigned rowNo, unsigned colNo) { - MatrixWrapper<float> outWrap(Out); - MatrixWrapper<float> inWrap(In); + TensorWrapper<float> outWrap(Out); + TensorWrapper<float> inWrap(In); unsigned size = In.size(); unsigned nThreads = std::min((unsigned) MAX_THREADS, (unsigned)size); @@ -238,8 +238,8 @@ Tensor& CopyRow(Tensor& Out, return Out; } -__global__ void gCopyRows(MatrixWrapper<float> out, - const MatrixWrapper<float> in, +__global__ void gCopyRows(TensorWrapper<float> out, + const TensorWrapper<float> in, const VectorWrapper<unsigned> indicesWrap) { int id = threadIdx.x + blockIdx.x * blockDim.x; @@ -279,8 +279,8 @@ Tensor& CopyRows(Tensor& Out, unsigned numPairs = indices.size(); - MatrixWrapper<float> outWrap(Out); - const MatrixWrapper<float> inWrap(In); + TensorWrapper<float> outWrap(Out); + const TensorWrapper<float> inWrap(In); const VectorWrapper<unsigned> indicesWrap(indices); //cerr << "size=" << size << endl; @@ -305,8 +305,8 @@ Tensor& Assemble(Tensor& Out, return Out; } -__global__ void gSlice(MatrixWrapper<float> out, - const MatrixWrapper<float> in, +__global__ void gSlice(TensorWrapper<float> out, + const TensorWrapper<float> in, unsigned n, unsigned dim) { unsigned row = blockIdx.x; @@ -332,8 +332,8 @@ Tensor& Slice(Tensor& Out, Out.NewSize(In.dim(0), dim); - MatrixWrapper<float> outWrap(Out); - const MatrixWrapper<float> inWrap(In); + TensorWrapper<float> outWrap(Out); + const TensorWrapper<float> inWrap(In); /* cerr << "outWrap=" << outWrap.Debug() << endl; @@ -433,7 +433,7 @@ Tensor& Prod(Tensor& C, const Tensor& A, const Tensor& B, return ret; } -__global__ void gSoftMax(MatrixWrapper<float> out, +__global__ void gSoftMax(TensorWrapper<float> out, const VectorWrapper<unsigned> batchIdsWrap, const VectorWrapper<unsigned> sentenceLengthsWrap, unsigned shareSize) @@ -525,7 +525,7 @@ Tensor& Softmax(Tensor& Out, { unsigned maxLength = Out.dim(1); - MatrixWrapper<float> outWrap(Out); + TensorWrapper<float> outWrap(Out); const VectorWrapper<unsigned> batchIdsWrap(batchIds); const VectorWrapper<unsigned> sentenceLengthsWrap(sentenceLengths); @@ -540,7 +540,7 @@ Tensor& Softmax(Tensor& Out, return Out; } -__global__ void gLogSoftMax(MatrixWrapper<float> out, unsigned shareSize) +__global__ void gLogSoftMax(TensorWrapper<float> out, unsigned shareSize) { extern __shared__ float _share[]; @@ -622,7 +622,7 @@ __global__ void gLogSoftMax(MatrixWrapper<float> out, unsigned shareSize) Tensor& LogSoftmax(Tensor& Out) { - MatrixWrapper<float> outWrap(Out); + TensorWrapper<float> outWrap(Out); int blocks = std::min(MAX_BLOCKS, (int)Out.dim(0)); int threads = std::min(MAX_THREADS, (int)Out.dim(1)); @@ -635,7 +635,7 @@ Tensor& LogSoftmax(Tensor& Out) return Out; } -__global__ void gSetColumn(MatrixWrapper<float> in, int noColumn, float value) { +__global__ void gSetColumn(TensorWrapper<float> in, int noColumn, float value) { int n_rows = in.dim(0); int rowNumber = threadIdx.x + blockDim.x * blockIdx.x; @@ -650,14 +650,14 @@ void SetColumn(Tensor& In, int noColumn, float value) { int nBlocks = nRows / MAX_THREADS + ((nRows % MAX_THREADS == 0) ? 0 : 1); int nThreads = std::min(MAX_THREADS, nRows); - MatrixWrapper<float> inWrap(In); + TensorWrapper<float> inWrap(In); gSetColumn<<<nBlocks, nThreads, 0, mblas::CudaStreamHandler::GetStream()>>> (inWrap, noColumn, value); HANDLE_ERROR(cudaGetLastError()); } -__global__ void gFill(MatrixWrapper<float> in, float val) { +__global__ void gFill(TensorWrapper<float> in, float val) { int index = threadIdx.x + blockDim.x * blockIdx.x; if (index < in.size()) { in[index] = val; @@ -671,7 +671,7 @@ void Fill(Tensor& In, float value) { int nThreads = std::min(MAX_THREADS, (int)size); int nBlocks = (size / nThreads) + ((size % nThreads == 0) ? 0 : 1); - MatrixWrapper<float> inWrap(In); + TensorWrapper<float> inWrap(In); gFill<<<nBlocks, nThreads, 0, CudaStreamHandler::GetStream()>>> (inWrap, value); @@ -684,7 +684,7 @@ void Fill(Tensor& In, float value) { } __global__ -void gMapMatrix(MatrixWrapper<float> in, +void gMapMatrix(TensorWrapper<float> in, const VectorWrapper<unsigned> sentenceLengthsWrap, int i) { @@ -712,7 +712,7 @@ void MapMatrix(Tensor& state, int numThreads = std::min((int)state.size(), MAX_THREADS); int numBlocks = (state.size() / numThreads) + ((state.size() % numThreads == 0) ? 0 : 1); - MatrixWrapper<float> stateWrap(state); + TensorWrapper<float> stateWrap(state); VectorWrapper<unsigned> sentenceLengthsWrap(sentenceLengths); gMapMatrix<<<numBlocks, numThreads, 0, CudaStreamHandler::GetStream()>>> @@ -738,10 +738,10 @@ __device__ unsigned getIndex(const dim3 &dim, const dim3 &val) } -__global__ void gLNormalization(MatrixWrapper<float> out, - const MatrixWrapper<float> in, - const MatrixWrapper<float> alphaWrap, - const MatrixWrapper<float> betaWrap, +__global__ void gLNormalization(TensorWrapper<float> out, + const TensorWrapper<float> in, + const TensorWrapper<float> alphaWrap, + const TensorWrapper<float> betaWrap, float eps=0.00001) { extern __shared__ float _share[]; @@ -831,10 +831,10 @@ void Normalization(Tensor &out, dim3 numBlocks(in.dim(0), in.dim(2), in.dim(3)); int shared = numThreads * sizeof(float) * 2; - MatrixWrapper<float> outWrap(out); - const MatrixWrapper<float> inWrap(in); - const MatrixWrapper<float> alphaWrap(alpha); - MatrixWrapper<float> *betaWrap = beta ? new MatrixWrapper<float>(*beta) : new MatrixWrapper<float>(); + TensorWrapper<float> outWrap(out); + const TensorWrapper<float> inWrap(in); + const TensorWrapper<float> alphaWrap(alpha); + TensorWrapper<float> *betaWrap = beta ? new TensorWrapper<float>(*beta) : new TensorWrapper<float>(); gLNormalization<<<numBlocks, numThreads, shared, CudaStreamHandler::GetStream()>>> (outWrap, inWrap, alphaWrap, *betaWrap, eps); @@ -928,7 +928,7 @@ void gBeamSizeInit(VectorWrapper<unsigned> hypo2BeamSizeWrap, } __device__ -float GetMaxScore(const MatrixWrapper<NthOutBatch> &nBestMatrix) +float GetMaxScore(const TensorWrapper<NthOutBatch> &nBestMatrix) { float ret = LOWEST_FLOAT; for (unsigned i = 0; i < nBestMatrix.dim(1); ++i) { @@ -1018,8 +1018,8 @@ void MergeElement(float &minScore, __device__ void NBestAndMax(VectorWrapper<NthOutBatch> &nBestCandidatesWrap, float &topScore, - const MatrixWrapper<float> &in, - const MatrixWrapper<float> &b4Wrap, + const TensorWrapper<float> &in, + const TensorWrapper<float> &b4Wrap, unsigned hypoInd, unsigned maxBeamSize, bool forbidUNK, @@ -1029,10 +1029,10 @@ void NBestAndMax(VectorWrapper<NthOutBatch> &nBestCandidatesWrap, extern __shared__ char _sharePtr[]; // placeholder for shared mem in subsequent function SumAndLogSoftMax - //MatrixWrapper<float> maxMatrix((float*)_sharePtr, blockDim.x, 1, 1, 1); + //TensorWrapper<float> maxMatrix((float*)_sharePtr, blockDim.x, 1, 1, 1); void *ptrOffset = _sharePtr + sizeof(float) * blockDim.x; - MatrixWrapper<NthOutBatch> nBestMatrix((NthOutBatch*)ptrOffset, blockDim.x, maxBeamSize, 1, 1); + TensorWrapper<NthOutBatch> nBestMatrix((NthOutBatch*)ptrOffset, blockDim.x, maxBeamSize, 1, 1); VectorWrapper<NthOutBatch> row = nBestMatrix.Row(threadIdx.x); unsigned vocabSize = in.dim(1); @@ -1107,8 +1107,8 @@ void NBestAndMax(VectorWrapper<NthOutBatch> &nBestCandidatesWrap, /////////////////////////////////////////////////////////////////////////////////////////////////////// __device__ void SumAndLogSoftMax(VectorWrapper<NthOutBatch> &nBestCandidatesWrap, - const MatrixWrapper<float> &in, - const MatrixWrapper<float> &b4Wrap, + const TensorWrapper<float> &in, + const TensorWrapper<float> &b4Wrap, unsigned hypoInd, unsigned maxBeamSize, float topScore, @@ -1160,8 +1160,8 @@ void SumAndLogSoftMax(VectorWrapper<NthOutBatch> &nBestCandidatesWrap, /////////////////////////////////////////////////////////////////////////////////////////////////////// __global__ void gLogSoftMax(VectorWrapper<NthOutBatch> nBestCandidatesWrap, - const MatrixWrapper<float> in, - const MatrixWrapper<float> b4Wrap, + const TensorWrapper<float> in, + const TensorWrapper<float> b4Wrap, unsigned maxBeamSize, bool forbidUNK, const VectorWrapper<unsigned> hypo2BeamSizeWrap, @@ -1205,7 +1205,7 @@ __global__ void gLogSoftMax(VectorWrapper<NthOutBatch> nBestCandidatesWrap, /////////////////////////////////////////////////////////////////////////////////////////////////////// __global__ void gNBestPerBatch(VectorWrapper<NthOutBatch> nBestWrap, VectorWrapper<NthOutBatch> nBestCandidatesWrap, - const MatrixWrapper<float> in, + const TensorWrapper<float> in, const VectorWrapper<float> costsWrap, unsigned maxBeamSize, bool forbidUNK, @@ -1353,8 +1353,8 @@ void LogSoftmaxAndNBest(mblas::Vector<NthOutBatch> &nBest, cerr << endl; */ - MatrixWrapper<float> inWrap(in); - MatrixWrapper<float> b4Wrap(b4); + TensorWrapper<float> inWrap(in); + TensorWrapper<float> b4Wrap(b4); VectorWrapper<unsigned> hypo2BeamSizeWrap(hypo2BeamSize); VectorWrapper<unsigned> hypo2CandidateWrap(hypo2Candidate); VectorWrapper<unsigned> batch2HypoWrap(batch2Hypo); diff --git a/src/amun/gpu/mblas/matrix_functions.h b/src/amun/gpu/mblas/matrix_functions.h index aa4a6a7c..55e86955 100644 --- a/src/amun/gpu/mblas/matrix_functions.h +++ b/src/amun/gpu/mblas/matrix_functions.h @@ -130,9 +130,9 @@ Tensor& LogSoftmax(Tensor& Out); template <class Functor> __global__ void gBroadcast(Functor functor, - MatrixWrapper<float> outWrap, - const MatrixWrapper<float> in1Wrap, - const MatrixWrapper<float> in2Wrap, + TensorWrapper<float> outWrap, + const TensorWrapper<float> in1Wrap, + const TensorWrapper<float> in2Wrap, const VectorWrapper<unsigned> batchMappingWrap) { int id = threadIdx.x + blockIdx.x * blockDim.x; @@ -183,9 +183,9 @@ Tensor& Broadcast(Functor functor, out.NewSize(srcSize, cols, sumOfBeamSizes); - MatrixWrapper<float> outWrap(out); - const MatrixWrapper<float> in1Wrap(in1); - const MatrixWrapper<float> in2Wrap(in2); + TensorWrapper<float> outWrap(out); + const TensorWrapper<float> in1Wrap(in1); + const TensorWrapper<float> in2Wrap(in2); const VectorWrapper<unsigned> batchMappingWrap(batchMapping); unsigned size = out.size(); @@ -215,7 +215,7 @@ Tensor& Broadcast(Functor functor, template <class Functor> __global__ void gBroadcastVecColumn(Functor functor, - MatrixWrapper<float> outWrap, + TensorWrapper<float> outWrap, const VectorWrapper<float> inWrap) { extern __shared__ float sdataOrig[]; @@ -245,7 +245,7 @@ Tensor& BroadcastVecColumn(Functor functor, Tensor& Out, const mblas::Vector<flo unsigned rows = Out.dim(0); unsigned cols = Out.dim(1); - MatrixWrapper<float> outWrap(Out); + TensorWrapper<float> outWrap(Out); const VectorWrapper<float> inWrap(In); int threads = std::min(MAX_THREADS, (int)cols); @@ -260,8 +260,8 @@ Tensor& BroadcastVecColumn(Functor functor, Tensor& Out, const mblas::Vector<flo template <class Functor> __global__ void gBroadcastVec(Functor functor, - MatrixWrapper<float> outWrap, - const MatrixWrapper<float> inWrap) + TensorWrapper<float> outWrap, + const TensorWrapper<float> inWrap) { unsigned cols = outWrap.dim(1); @@ -289,8 +289,8 @@ Tensor& BroadcastVec(Functor functor, Tensor& Out, const Tensor& In) unsigned cols = Out.dim(1); - MatrixWrapper<float> outWrap(Out); - const MatrixWrapper<float> inWrap(In); + TensorWrapper<float> outWrap(Out); + const TensorWrapper<float> inWrap(In); int threads = std::min(MAX_THREADS, (int)cols); int blocks = cols / threads + ((cols % threads == 0) ? 0 : 1); @@ -305,7 +305,7 @@ Tensor& BroadcastVec(Functor functor, Tensor& Out, const Tensor& In) template <class Functor> __global__ void gElement(Functor functor, - MatrixWrapper<float> outWrap) + TensorWrapper<float> outWrap) { unsigned ind = blockIdx.x * blockDim.x + threadIdx.x; if (ind < outWrap.size()) { @@ -322,7 +322,7 @@ Tensor& Element(Functor functor, unsigned blocks = size / threads + ((size % threads == 0) ? 0 : 1); const cudaStream_t& stream = CudaStreamHandler::GetStream(); - MatrixWrapper<float> outWrap(Out); + TensorWrapper<float> outWrap(Out); gElement<<<blocks, threads, 0, stream>>> (functor, outWrap); @@ -333,8 +333,8 @@ Tensor& Element(Functor functor, template <class Functor> __global__ void gElement(Functor functor, - MatrixWrapper<float> outWrap, - const MatrixWrapper<float> inWrap) + TensorWrapper<float> outWrap, + const TensorWrapper<float> inWrap) { unsigned ind = blockIdx.x * blockDim.x + threadIdx.x; if (ind < outWrap.size()) { @@ -353,8 +353,8 @@ Tensor& Element(Functor functor, unsigned blocks = size / threads + ((size % threads == 0) ? 0 : 1); const cudaStream_t& stream = CudaStreamHandler::GetStream(); - MatrixWrapper<float> outWrap(Out); - const MatrixWrapper<float> inWrap(In); + TensorWrapper<float> outWrap(Out); + const TensorWrapper<float> inWrap(In); gElement<<<blocks, threads, 0, stream>>> (functor, outWrap, inWrap); @@ -365,9 +365,9 @@ Tensor& Element(Functor functor, template <class Functor> __global__ void gElement(Functor functor, - MatrixWrapper<float> outWrap, - const MatrixWrapper<float> in1Wrap, - const MatrixWrapper<float> in2Wrap) + TensorWrapper<float> outWrap, + const TensorWrapper<float> in1Wrap, + const TensorWrapper<float> in2Wrap) { unsigned ind = blockIdx.x * blockDim.x + threadIdx.x; if (ind < outWrap.size()) { @@ -395,9 +395,9 @@ Tensor& Element(Functor functor, //std::cerr << "Element3=" << In1.Debug(0) << std::endl; //std::cerr << "Element3=" << In2.Debug(0) << std::endl; //std::cerr << std::endl; - MatrixWrapper<float> outWrap(Out); - const MatrixWrapper<float> in1Wrap(In1); - const MatrixWrapper<float> in2Wrap(In2); + TensorWrapper<float> outWrap(Out); + const TensorWrapper<float> in1Wrap(In1); + const TensorWrapper<float> in2Wrap(In2); //std::cerr << "outWrap=" << outWrap.Debug() << std::endl; gElement<<<blocks, threads, 0, stream>>> diff --git a/src/amun/gpu/mblas/matrix_wrapper.h b/src/amun/gpu/mblas/matrix_wrapper.h index bd3de2dd..c2c6917e 100644 --- a/src/amun/gpu/mblas/matrix_wrapper.h +++ b/src/amun/gpu/mblas/matrix_wrapper.h @@ -7,10 +7,10 @@ namespace GPU { namespace mblas { template <typename T> -class MatrixWrapper +class TensorWrapper { public: - MatrixWrapper() + TensorWrapper() { dim_[0] = 0; dim_[1] = 0; @@ -22,7 +22,7 @@ public: dataConst_ = nullptr; } - MatrixWrapper(const TTensor<T> &matrix) + TensorWrapper(const TTensor<T> &matrix) { dim_[0] = matrix.dim(0); dim_[1] = matrix.dim(1); @@ -34,7 +34,7 @@ public: dataConst_ = matrix.data(); } - MatrixWrapper(TTensor<T> &matrix) + TensorWrapper(TTensor<T> &matrix) { dim_[0] = matrix.dim(0); dim_[1] = matrix.dim(1); @@ -46,7 +46,7 @@ public: dataConst_ = data_; } - MatrixWrapper(unsigned a, unsigned b, unsigned c, unsigned d) + TensorWrapper(unsigned a, unsigned b, unsigned c, unsigned d) { // test constructor dim_[0] = a; dim_[1] = b; @@ -59,7 +59,7 @@ public: } __device__ - MatrixWrapper(T *ptr, unsigned a, unsigned b, unsigned c, unsigned d) + TensorWrapper(T *ptr, unsigned a, unsigned b, unsigned c, unsigned d) { dim_[0] = a; dim_[1] = b; @@ -309,7 +309,7 @@ protected: inline void testidToMatrixInd() { - MatrixWrapper<float> matrix(2, 4, 3, 5); + TensorWrapper<float> matrix(2, 4, 3, 5); std::cerr << "matrix=" << matrix.Debug() << std::endl; diff --git a/src/amun/gpu/mblas/nth_element.cu b/src/amun/gpu/mblas/nth_element.cu index 7bf72af7..f7c221de 100644 --- a/src/amun/gpu/mblas/nth_element.cu +++ b/src/amun/gpu/mblas/nth_element.cu @@ -100,7 +100,7 @@ void NthElement::getNBestList(mblas::Tensor &probs, cudaMemcpyHostToDevice); mblas::VectorWrapper<NthOut> outWrap(d_out); - mblas::MatrixWrapper<float> probsWrap(probs); + mblas::TensorWrapper<float> probsWrap(probs); mblas::VectorWrapper<unsigned> batchPositionWrap(d_batchPosition); mblas::VectorWrapper<NthOut> resWrap(d_res); mblas::VectorWrapper<unsigned> cumBeamSizesWrap(d_cumBeamSizes); @@ -161,7 +161,7 @@ void NthElement::getValueByKey(std::vector<float>& out, const mblas::Tensor &d_i out.resize(d_breakdown.size()); //mblas::VectorWrapper<float> breakdownWrap(d_breakdown); - //const mblas::MatrixWrapper<float> inWrap(d_in); + //const mblas::TensorWrapper<float> inWrap(d_in); //gGetValueByKey<<<1, lastN_, 0, stream_>>> // (breakdownWrap, inWrap, h_res_idx, lastN_); /* diff --git a/src/amun/gpu/mblas/nth_element_kernels.cu b/src/amun/gpu/mblas/nth_element_kernels.cu index f7707f71..5c1ae460 100644 --- a/src/amun/gpu/mblas/nth_element_kernels.cu +++ b/src/amun/gpu/mblas/nth_element_kernels.cu @@ -20,7 +20,7 @@ void UnrollMaxArgLoop(unsigned n, unsigned max, unsigned tid, float *sdata, unsi } __global__ void gMaxElement(mblas::VectorWrapper<NthOut> out, - const mblas::MatrixWrapper<float> probsWrap, + const mblas::TensorWrapper<float> probsWrap, const mblas::VectorWrapper<unsigned> batchPositionWrap, unsigned numBatches) { extern __shared__ float sdata[]; @@ -98,7 +98,7 @@ __global__ void gMaxElement(mblas::VectorWrapper<NthOut> out, } __global__ void gMaxElementUpdate(mblas::VectorWrapper<NthOut> out, - mblas::MatrixWrapper<float> probsWrap, + mblas::TensorWrapper<float> probsWrap, mblas::VectorWrapper<NthOut> resWrap, const mblas::VectorWrapper<unsigned> batchPositionWrap, const mblas::VectorWrapper<unsigned> cumBeamSizesWrap, @@ -253,8 +253,8 @@ __global__ void gMaxElementUpdate(mblas::VectorWrapper<NthOut> out, } } -__global__ void gGetValueByKey(mblas::MatrixWrapper<float> out, - const mblas::MatrixWrapper<float> in, +__global__ void gGetValueByKey(mblas::TensorWrapper<float> out, + const mblas::TensorWrapper<float> in, unsigned* indices, unsigned n) { unsigned tid = threadIdx.x + blockDim.x * blockIdx.x; diff --git a/src/amun/gpu/mblas/nth_element_kernels.h b/src/amun/gpu/mblas/nth_element_kernels.h index aeefcdd7..094f6181 100644 --- a/src/amun/gpu/mblas/nth_element_kernels.h +++ b/src/amun/gpu/mblas/nth_element_kernels.h @@ -111,19 +111,19 @@ inline std::ostream& operator<<(std::ostream &out, const NthOutBatch &obj) ///////////////////////////////////////////////////////////////////////////////////////// __global__ void gMaxElement(mblas::VectorWrapper<NthOut> out, - const mblas::MatrixWrapper<float> probsWrap, + const mblas::TensorWrapper<float> probsWrap, const mblas::VectorWrapper<unsigned> batchPositionWrap, unsigned numBatches); __global__ void gMaxElementUpdate(mblas::VectorWrapper<NthOut> out, - mblas::MatrixWrapper<float> probsWrap, + mblas::TensorWrapper<float> probsWrap, mblas::VectorWrapper<NthOut> resWrap, const mblas::VectorWrapper<unsigned> batchPositionWrap, const mblas::VectorWrapper<unsigned> cumBeamSizesWrap, unsigned numBlocks); -__global__ void gGetValueByKey(mblas::MatrixWrapper<float> out, - const mblas::MatrixWrapper<float> in, +__global__ void gGetValueByKey(mblas::TensorWrapper<float> out, + const mblas::TensorWrapper<float> in, unsigned* indices, unsigned n); } |