Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Hieu Hoang <hieuhoang@gmail.com> 2017-12-01 23:47:30 +0300
committer Hieu Hoang <hieuhoang@gmail.com> 2017-12-01 23:47:30 +0300
commit 0421de9efaaf9168c8f6723186f54fead2b294b9 (patch)
tree 6e3b90df61880c4eadae1bfefc5e1e6fc9d1285f
parent 3391804806958803f4769122d2375712f77dc979 (diff)
Trying to use the proper dim in Broadcast. It gives different results.
-rw-r--r-- src/amun/gpu/mblas/matrix_functions.h 4
-rw-r--r-- src/amun/half/mblas/matrix_functions.h 3
2 files changed, 7 insertions, 0 deletions
diff --git a/src/amun/gpu/mblas/matrix_functions.h b/src/amun/gpu/mblas/matrix_functions.h
index 39e981c3..06786612 100644
--- a/src/amun/gpu/mblas/matrix_functions.h
+++ b/src/amun/gpu/mblas/matrix_functions.h
@@ -181,6 +181,8 @@ __global__ void gBroadcast(Functor functor,
outWrap[id] = functor(in1Wrap[(batchIdx * srcSize + srcId) * cols + stateIdx],
in2Wrap[beamIdx * cols + stateIdx]);
+ //outWrap(beamIdx, stateIdx, srcId, 0) = functor(in1Wrap(srcId, stateIdx, 0, batchIdx),
+ // in2Wrap(beamIdx, stateIdx, 0, 0));
}
}
@@ -192,6 +194,7 @@ Matrix& Broadcast(Functor functor,
const mblas::Vector<uint>& batchMapping,
size_t srcSize)
{
+ BEGIN_TIMER("Broadcast");
size_t sumOfBeamSizes = in2.dim(0);
//size_t rows = srcSize * sumOfBeamSizes;
@@ -225,6 +228,7 @@ Matrix& Broadcast(Functor functor,
HANDLE_ERROR(cudaDeviceSynchronize());
*/
+ PAUSE_TIMER("Broadcast");
return out;
}
diff --git a/src/amun/half/mblas/matrix_functions.h b/src/amun/half/mblas/matrix_functions.h
index 4e738e60..ebc5e728 100644
--- a/src/amun/half/mblas/matrix_functions.h
+++ b/src/amun/half/mblas/matrix_functions.h
@@ -223,6 +223,9 @@ __global__ void gBroadcast(Functor functor,
outWrap[id] = functor(in1Wrap[(batchIdx * srcSize + srcId) * cols + stateIdx],
in2Wrap[beamIdx * cols + stateIdx]);
+ //outWrap(beamIdx, stateIdx, srcId, 0) = functor(in1Wrap(srcId, stateIdx, 0, batchIdx),
+ // in2Wrap(beamIdx, stateIdx, 0, 0));
+
}
}