Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hieuhoang@gmail.com>2017-12-04 19:33:09 +0300
committerHieu Hoang <hieuhoang@gmail.com>2017-12-04 19:33:09 +0300
commitf2581bf7f968d7dce02f52f89a19b3afad3b5824 (patch)
tree45ad0da63cb2e3b710023a7a8ceb310e15f58ab3
parent2b4a5550bc53e612fefc9ca5ba122eca1c276ad0 (diff)
fastest
-rw-r--r--src/amun/gpu/mblas/matrix_functions.h28
1 files changed, 14 insertions, 14 deletions
diff --git a/src/amun/gpu/mblas/matrix_functions.h b/src/amun/gpu/mblas/matrix_functions.h
index c162910f..bd791797 100644
--- a/src/amun/gpu/mblas/matrix_functions.h
+++ b/src/amun/gpu/mblas/matrix_functions.h
@@ -151,16 +151,16 @@ __global__ void gBroadcast(Functor functor,
const MatrixWrapper<float> in2Wrap,
const VectorWrapper<uint> batchMappingWrap)
{
- //size_t srcSize = outWrap.dim(0);
- //size_t inRows = in2Wrap.dim(0);
- //size_t cols = in1Wrap.dim(1);
+ uint srcSize = outWrap.dim(0);
+ uint inRows = in2Wrap.dim(0);
+ uint cols = in1Wrap.dim(1);
int id = threadIdx.x + blockIdx.x * blockDim.x;
if (id < outWrap.size()) {
+ /*
uint indices[SHAPE_SIZE];
outWrap.id2Indices(id, indices);
- /*
int row = id / cols; // len * batch for in1
int srcId = row % srcSize; // source pos for in1
@@ -171,25 +171,25 @@ __global__ void gBroadcast(Functor functor,
in2Wrap(batchMappingIdx, indices[1], 0, 0) );
*/
- //int row = id / cols;
- //int stateIdx = id % cols;
-
- //int beamIdx = row / srcSize;
- //int srcId = row % srcSize;
+ uint row = id / cols;
+ uint stateIdx = id % cols;
+ uint beamIdx = row / srcSize;
+ uint srcId = row % srcSize;
+ uint batchIdx = batchMappingWrap[ beamIdx ];
- uint batchIdx = batchMappingWrap[ indices[2] ];
+ //uint batchIdx = batchMappingWrap[ indices[2] ];
//assert(srcId == indices[0]);
//assert(stateIdx == indices[1]);
//assert(beamIdx == indices[2]);
//assert(0 == indices[3]);
- //outWrap[id] = functor(in1Wrap[(batchIdx * srcSize + srcId) * cols + stateIdx],
- // in2Wrap[beamIdx * cols + stateIdx]);
+ outWrap[id] = functor(in1Wrap[(batchIdx * srcSize + srcId) * cols + stateIdx],
+ in2Wrap[beamIdx * cols + stateIdx]);
//outWrap[id] = functor(in1Wrap(indices[0], indices[1], 0, batchIdx),
// in2Wrap(indices[2], indices[1], 0, 0));
- outWrap(indices[0], indices[1], indices[2], 0) = functor(in1Wrap(indices[0], indices[1], 0, batchIdx),
- in2Wrap(indices[2], indices[1], 0, 0));
+ //outWrap(srcId, stateIdx, beamIdx, 0) = functor(in1Wrap(srcId, stateIdx, 0, batchIdx),
+ // in2Wrap(beamIdx, stateIdx, 0, 0));
}
}