Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Tomasz Dwojak <t.dwojak@amu.edu.pl> 2017-05-06 19:03:07 +0300
committer: Tomasz Dwojak <t.dwojak@amu.edu.pl> 2017-05-06 19:03:07 +0300
commit: a30f331556294823eca4ab173c1b63d193e19d19 (patch)
tree: af5fa1b0f81790ed13f19f47d8609f96ee226bf5
parent: 3487667e680aae9e034f751d4fb5a059db0d2577 (diff)
Ingoing work on device vector (branch: device_vector)
-rw-r--r-- src/amun/common/history.cpp 1
-rwxr-xr-x src/amun/common/search.cpp 19
-rw-r--r-- src/amun/common/sentences.h 2
-rw-r--r-- src/amun/common/soft_alignment.h 5
-rwxr-xr-x src/amun/common/translation_task.cpp 18
-rw-r--r-- src/amun/gpu/decoder/best_hyps.h 5
-rw-r--r-- src/amun/gpu/dl4mt/decoder.h 56
-rw-r--r-- src/amun/gpu/mblas/device_vector.h 31
-rw-r--r-- src/amun/gpu/mblas/matrix.h 85
-rw-r--r-- src/amun/gpu/mblas/matrix_functions.cu 12
-rw-r--r-- src/amun/gpu/mblas/matrix_functions.h 19
-rw-r--r-- src/amun/gpu/mblas/thrust_functions.h 6
12 files changed, 156 insertions, 103 deletions
diff --git a/src/amun/common/history.cpp b/src/amun/common/history.cpp
index c6b9db12..08aa3a14 100644
--- a/src/amun/common/history.cpp
+++ b/src/amun/common/history.cpp
@@ -69,6 +69,7 @@ Histories::Histories(const Sentences& sentences, bool normalize)
{
for (const auto& sentence : sentences) {
coll_.emplace_back(new History(sentence->GetLineNum(), normalize));
+ std::cerr << "SIZE: " << size() << std::endl;
}
}
diff --git a/src/amun/common/search.cpp b/src/amun/common/search.cpp
index 3803fd6f..62007e89 100755
--- a/src/amun/common/search.cpp
+++ b/src/amun/common/search.cpp
@@ -6,6 +6,10 @@
#include "common/filter.h"
#include "common/base_matrix.h"
+#ifdef CUDA
+#include <cuda.h>
+#endif
+
using namespace std;
namespace amunmt {
@@ -22,7 +26,7 @@ Search::~Search()
{
#ifdef CUDA
if (deviceInfo_.deviceType == GPUDevice) {
- cudaSetDevice(deviceInfo_.deviceId);
+ // cudaSetDevice(deviceInfo_.deviceId);
}
#endif
}
@@ -30,7 +34,6 @@ const DeviceInfo& Search::GetDeviceInfo() const
{
return deviceInfo_;
}
-
const std::vector<ScorerPtr>& Search::GetScorers() const
{
return scorers_;
@@ -102,9 +105,11 @@ void Search::PreProcess(
void Search::Encode(const Sentences& sentences, States& states) {
for (size_t i = 0; i < scorers_.size(); i++) {
- Scorer& scorer = *scorers_[i];
+ Scorer &scorer = *scorers_[i];
+ std::cerr << "Set Source" << std::endl;
scorer.SetSource(sentences);
+ std::cerr << "BEgin Sentencs State" << std::endl;
scorer.BeginSentenceState(*states[i], sentences.size());
}
}
@@ -157,6 +162,7 @@ bool Search::Decode(
const State &state = *states[i];
State &nextState = *nextStates[i];
+ std::cerr << "Decode" << std::endl;
scorer.Decode(state, nextState, beamSizes);
}
@@ -203,14 +209,20 @@ bool Search::CalcBeam(
bool returnAlignment = god.Get<bool>("return-alignment");
size_t batchSize = sentences.size();
+ std::cerr << "Calc Beam" << std::endl;
bestHyps_->CalcBeam(god, prevHyps, scorers_, filterIndices_, returnAlignment, beams, beamSizes);
+ std::cerr << "DONE" << std::endl;
+ std::cerr << "ADD to historues" << std::endl;
+ std::cerr << beams.size() << std::endl;
+ std::cerr << histories->size() << " " << sentences.size() << std::endl;
for (size_t i = 0; i < batchSize; ++i) {
if (!beams[i].empty()) {
histories->at(i)->Add(beams[i], histories->at(i)->size() == 3 * sentences.at(i)->GetWords().size());
}
}
+ std::cerr << "Surivors" << std::endl;
for (size_t batchID = 0; batchID < batchSize; ++batchID) {
for (auto& h : beams[batchID]) {
if (h->GetWord() != EOS_ID) {
@@ -225,6 +237,7 @@ bool Search::CalcBeam(
return false;
}
+ std::cerr << "Assemble" << std::endl;
for (size_t i = 0; i < scorers_.size(); i++) {
scorers_[i]->AssembleBeamState(*nextStates[i], survivors, *states[i]);
}
diff --git a/src/amun/common/sentences.h b/src/amun/common/sentences.h
index 7cb1838d..281f24da 100644
--- a/src/amun/common/sentences.h
+++ b/src/amun/common/sentences.h
@@ -23,7 +23,7 @@ class Sentences {
}
auto end() const -> decltype(coll_.cend()) {
- return coll_.begin();
+ return coll_.cend();
}
SentencePtr at(size_t id) const {
diff --git a/src/amun/common/soft_alignment.h b/src/amun/common/soft_alignment.h
index 41a808bc..98269003 100644
--- a/src/amun/common/soft_alignment.h
+++ b/src/amun/common/soft_alignment.h
@@ -3,11 +3,6 @@
#include <vector>
#include <memory>
-#ifdef CUDA
-#include <thrust/host_vector.h>
-using SoftAlignment = thrust::host_vector<float>;
-#else
using SoftAlignment = std::vector<float>;
-#endif
using SoftAlignmentPtr = std::shared_ptr<SoftAlignment>;
diff --git a/src/amun/common/translation_task.cpp b/src/amun/common/translation_task.cpp
index 015ae510..3059fe0a 100755
--- a/src/amun/common/translation_task.cpp
+++ b/src/amun/common/translation_task.cpp
@@ -31,13 +31,13 @@ std::shared_ptr<Histories> TranslationTask(const God &god, std::shared_ptr<Sente
//cerr << "histories=" << histories->size() << endl;
return histories;
}
-#ifdef CUDA
- catch(thrust::system_error &e)
- {
- std::cerr << "CUDA error during some_function: " << e.what() << std::endl;
- abort();
- }
-#endif
+// #ifdef CUDA
+ // catch(thrust::system_error &e)
+ // {
+ // std::cerr << "CUDA error during some_function: " << e.what() << std::endl;
+ // abort();
+ // }
+// #endif
catch(std::bad_alloc &e)
{
std::cerr << "Bad memory allocation during some_function: " << e.what() << std::endl;
@@ -48,9 +48,9 @@ std::shared_ptr<Histories> TranslationTask(const God &god, std::shared_ptr<Sente
std::cerr << "Runtime error during some_function: " << e.what() << std::endl;
abort();
}
- catch(...)
+ catch(std::exception& e)
{
- std::cerr << "Some other kind of error during some_function" << std::endl;
+ std::cerr << "Some other kind of error during some_function"<< e.what() << std::endl;
abort();
}
diff --git a/src/amun/gpu/decoder/best_hyps.h b/src/amun/gpu/decoder/best_hyps.h
index e15bacd0..a8eeee43 100644
--- a/src/amun/gpu/decoder/best_hyps.h
+++ b/src/amun/gpu/decoder/best_hyps.h
@@ -74,6 +74,7 @@ class BestHyps : public BestHypsBase
std::vector<size_t>& beamSizes
)
{
+ std::cerr << "CALC BEAM" << std::endl;
using namespace mblas;
mblas::Matrix& Probs = static_cast<mblas::Matrix&>(scorers[0]->GetProbs());
@@ -108,7 +109,10 @@ class BestHyps : public BestHypsBase
std::vector<float> bestCosts;
std::vector<unsigned> bestKeys;
+ std::cerr << "Find BEST" << std::endl;
FindBests(beamSizes, Probs, bestCosts, bestKeys, isFirst);
+ std::cerr << "Find BEST DONE" << std::endl;
+ std::cerr << "Find BEST DONE" << bestCosts.size() << " " << bestKeys.size() << std::endl;
std::vector<HostVector<float>> breakDowns;
bool doBreakdown = god.Get<bool>("n-best");
@@ -143,6 +147,7 @@ class BestHyps : public BestHypsBase
float cost = bestCosts[i];
HypothesisPtr hyp;
+ std::cerr << "ADDING HYPS" << std::endl;
if (returnAlignment) {
hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost,
GetAlignments(scorers, hypIndex)));
diff --git a/src/amun/gpu/dl4mt/decoder.h b/src/amun/gpu/dl4mt/decoder.h
index 61f58a64..172cef91 100644
--- a/src/amun/gpu/dl4mt/decoder.h
+++ b/src/amun/gpu/dl4mt/decoder.h
@@ -117,9 +117,12 @@ class Decoder {
void Init(const mblas::Matrix& SourceContext) {
using namespace mblas;
+ std::cerr << "INIT PROD" << std::endl;
Prod(/*h_[0],*/ SCU_, SourceContext, w_.U_);
if (w_.Gamma_1_) {
+ std::cerr << "INIT NORM" << std::endl;
Normalization(SCU_, SCU_, w_.Gamma_1_, w_.B_, 1e-9);
+ std::cerr << "INIT NORM DONE" << std::endl;
}
}
@@ -137,52 +140,63 @@ class Decoder {
batchMapping[k++] = i;
}
}
+ for (auto i : batchMapping) std::cerr << i << " ";
+ std::cerr << std::endl;
+ std::cerr << "COPY" << std::endl;
mblas::copy(batchMapping.data(),
batchMapping.size(),
dBatchMapping_.data(),
cudaMemcpyHostToDevice);
const size_t srcSize = mapping.size() / beamSizes.size();
-
+ std::cerr << srcSize << std::endl;
+
+ std::cerr << "TEMP2: " << std::endl;
+ Temp2_.debugDim();
+ std::cerr << "HP: " << std::endl;
+ HiddenState.debugDim();
+ std::cerr << "WW: " << std::endl;
+ w_.W_.debugDim();
+ // std::cerr << "HS: " << HiddenState.debugDim() << std::endl;
+ // std::cerr << "WW: " << w_.W_.debugDim() << std::endl;
+ std::cerr << "PROD" << std::endl;
Prod(/*h_[1],*/ Temp2_, HiddenState, w_.W_);
+ std::cerr << "PROD DONE" << std::endl;
if (w_.Gamma_2_) {
+ std::cerr << "NORM" << std::endl;
Normalization(Temp2_, Temp2_, w_.Gamma_2_, 1e-9);
} else {
+ std::cerr << "BROD" << std::endl;
BroadcastVec(_1 + _2, Temp2_, w_.B_/*, s_[1]*/);
}
+ std::cerr << "COPY" << std::endl;
Copy(Temp1_, SCU_);
- //std::cerr << std::endl;
- //std::cerr << "batchMapping=" << batchMapping.size() << std::endl;
- //std::cerr << "SCU_=" << SCU_.Debug() << std::endl;
- //std::cerr << "1Temp1_=" << Temp1_.Debug() << std::endl;
- //std::cerr << "Temp2_=" << Temp2_.Debug() << std::endl;
-
+ std::cerr << "BROD" << std::endl;
Broadcast(Tanh(_1 + _2), Temp1_, Temp2_, dBatchMapping_, srcSize);
- //std::cerr << "2Temp1_=" << Temp1_.Debug() << std::endl;
+ std::cerr << "RESHAPE" << std::endl;
Temp1_.Reshape2D();
- //std::cerr << "w_.V_=" << w_.V_.Debug() << std::endl;
- //std::cerr << "3Temp1_=" << Temp1_.Debug() << std::endl;
-
+ std::cerr << "PROD" << std::endl;
Prod(A_, w_.V_, Temp1_, false, true);
size_t rows1 = SourceContext.dim(0);
size_t rows2 = HiddenState.dim(0);
- //std::cerr << "1A_=" << A_.Debug() << std::endl;
+ std::cerr << "RESHAPE" << std::endl;
A_.Reshape(rows2, srcSize, 1, 1); // due to broadcasting above
- //std::cerr << "2A_=" << A_.Debug() << std::endl;
+ std::cerr << "SOFTMAX" << std::endl;
mblas::Softmax(A_, dBatchMapping_, mapping, srcSize);
+ std::cerr << "RESIZE" << std::endl;
AlignedSourceContext.Resize(A_.dim(0), SourceContext.dim(1));
- mblas::WeightedMean(AlignedSourceContext, A_, SourceContext, dBatchMapping_);
- //std::cerr << "AlignedSourceContext=" << AlignedSourceContext.Debug() << std::endl;
+ std::cerr << "WEIGHTED MEAN" << std::endl;
+ mblas::WeightedMean(AlignedSourceContext, A_, SourceContext, dBatchMapping_);
}
void GetAttention(mblas::Matrix& Attention) {
@@ -312,11 +326,14 @@ class Decoder {
const DeviceVector<int>& mapping,
const std::vector<size_t>& beamSizes) {
+ std::cerr << "Get Hidden State" << std::endl;
GetHiddenState(HiddenState_, State, Embeddings);
+ std::cerr << "Get Aligned Source Context" << std::endl;
GetAlignedSourceContext(AlignedSourceContext_, HiddenState_, SourceContext, mapping, beamSizes);
+ std::cerr << "Get Next State" << std::endl;
GetNextState(NextState, HiddenState_, AlignedSourceContext_);
+ std::cerr << "Get Get Probs" << std::endl;
GetProbs(NextState, Embeddings, AlignedSourceContext_);
-
}
mblas::Matrix& GetProbs() {
@@ -327,12 +344,17 @@ class Decoder {
const mblas::Matrix& SourceContext,
size_t batchSize,
const DeviceVector<int>& batchMapping) {
+ std::cerr << "EMPTY STATE" << std::endl;
rnn1_.InitializeState(State, SourceContext, batchSize, batchMapping);
+ std::cerr << "ALIGN INItk" << std::endl;
alignment_.Init(SourceContext);
+ std::cerr << "ALIGN INItk DONE" << std::endl;
}
void EmptyEmbedding(mblas::Matrix& Embedding, size_t batchSize = 1) {
- Embedding.Clear();
+ std::cerr << "BATCH SIZE: " << batchSize << std::endl;
+ // Embedding.Clear();
+ std::cerr << embeddings_.GetCols() << std::endl;
Embedding.Resize(batchSize, embeddings_.GetCols());
mblas::Fill(Embedding, 0);
}
diff --git a/src/amun/gpu/mblas/device_vector.h b/src/amun/gpu/mblas/device_vector.h
index 2434cf52..9d3708a7 100644
--- a/src/amun/gpu/mblas/device_vector.h
+++ b/src/amun/gpu/mblas/device_vector.h
@@ -70,17 +70,22 @@ class device_vector
if (newSize > realSize_) {
if (data_ == nullptr) {
HANDLE_ERROR( cudaMalloc((void**)&data_, newSize * sizeof(T)) );
+ realSize_ = newSize;
+ size_ = newSize;
} else {
- T* newData_;
- HANDLE_ERROR( cudaMalloc((void**)&newData_, newSize * sizeof(T)) );
- HANDLE_ERROR( cudaMemcpyAsync(
- newData_,
- data_,
- size_ * sizeof(T),
- cudaMemcpyDeviceToDevice,
- CudaStreamHandler::GetStream())
- );
- HANDLE_ERROR( cudaFree(data_) );
+ T* newData;
+ HANDLE_ERROR( cudaMalloc((void**)&newData, newSize * sizeof(T)) );
+ HANDLE_ERROR( cudaMemcpyAsync(
+ newData,
+ data_,
+ size_ * sizeof(T),
+ cudaMemcpyDeviceToDevice,
+ CudaStreamHandler::GetStream())
+ );
+ HANDLE_ERROR( cudaFree(data_) );
+ data_ = newData;
+ realSize_ = newSize;
+ size_ = newSize;
}
}
size_ = newSize;
@@ -98,6 +103,12 @@ class device_vector
return data_;
}
+ ~device_vector() {
+ if (data_) {
+ HANDLE_ERROR( cudaFree(data_) );
+ }
+ }
+
protected:
T* data_;
size_t size_;
diff --git a/src/amun/gpu/mblas/matrix.h b/src/amun/gpu/mblas/matrix.h
index 75937a27..ac6391c8 100644
--- a/src/amun/gpu/mblas/matrix.h
+++ b/src/amun/gpu/mblas/matrix.h
@@ -2,7 +2,7 @@
#include <memory>
#include <sstream>
-// #include <thrust/execution_policy.h>
+#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include "common/exception.h"
@@ -18,26 +18,28 @@ using namespace thrust::placeholders;
float Sum(const float *data, size_t count);
+
template <typename T>
class TMatrix : public BaseMatrix {
public:
typedef T value_type;
TMatrix()
- : rows_(0)
- , cols_(0)
- , beam_(0)
- , batches_(0)
- , arrSize_(0)
- , data_(nullptr)
- {}
+ : rows_(0)
+ , cols_(0)
+ , beam_(0)
+ , batches_(0)
+ , arrSize_(0)
+ , data_(nullptr)
+ {
+ }
TMatrix(size_t rows, size_t cols, size_t beam, size_t batches, bool zero = false)
- : rows_(rows)
- , cols_(cols)
- , beam_(1)
- , batches_(1)
- , arrSize_(size())
+ : rows_(rows)
+ , cols_(cols)
+ , beam_(1)
+ , batches_(1)
+ , arrSize_(size())
{
HANDLE_ERROR( cudaMalloc((void**)&data_, arrSize_ * sizeof(T)) );
if (zero) {
@@ -46,7 +48,7 @@ class TMatrix : public BaseMatrix {
}
TMatrix(TMatrix&& m)
- : TMatrix()
+ : TMatrix()
{
swap(m);
}
@@ -59,6 +61,10 @@ class TMatrix : public BaseMatrix {
, arrSize_(m.arrSize_)
{
HANDLE_ERROR( cudaMalloc((void**)&data_, arrSize_ * sizeof(T)) );
+ std::cerr << m.data_ << std::endl;
+ std::cerr << data_ << std::endl;
+
+ std::cerr << "COPY: " << size() << " " << m.size() << std::endl;
HANDLE_ERROR( cudaMemcpyAsync(
data_,
m.data_,
@@ -74,33 +80,40 @@ class TMatrix : public BaseMatrix {
virtual size_t dim(size_t i) const
{
- switch (i) {
- case 0: return rows_;
- case 1: return cols_;
- case 2: return beam_;
- case 3: return batches_;
- default: abort();
- }
+ switch (i) {
+ case 0: return rows_;
+ case 1: return cols_;
+ case 2: return beam_;
+ case 3: return batches_;
+ default:
+ abort();
+ }
}
void Resize(size_t rows, size_t cols, size_t beam = 1, size_t batches = 1) {
size_t newSize = cols * rows * beam * batches;
if (data_) {
if (newSize > arrSize_) {
- T *newData;
+ T* newData;
HANDLE_ERROR( cudaMalloc((void**)&newData, newSize * sizeof(T)) );
+ std::cerr << newData << std::endl;
+ std::cerr << data_ << std::endl;
+
+ std::cerr << "RESIZE: " << size() << " " << newSize << std::endl;
HANDLE_ERROR( cudaMemcpyAsync(
newData,
data_,
size() * sizeof(T),
cudaMemcpyDeviceToDevice,
- CudaStreamHandler::GetStream()) );
+ CudaStreamHandler::GetStream())
+ );
HANDLE_ERROR(cudaFree(data_));
data_ = newData;
arrSize_ = newSize;
- } else if (rows == 0 || cols == 0) {
+ }
+ else if (rows == 0 || cols == 0) {
Clear();
}
}
@@ -137,12 +150,12 @@ class TMatrix : public BaseMatrix {
std::stringstream strm;
strm << BaseMatrix::Debug(detailed) << " ";
strm << data_ << " "
- << arrSize_ << " "
+ << arrSize_ << " "
<< std::flush;
if (detailed) {
- // float sum = Sum(data(), size());
- // strm << "size=" << size() << " sum=" << sum << std::flush;
+ float sum = Sum(data(), size());
+ strm << "size=" << size() << " sum=" << sum << std::flush;
}
return strm.str();
@@ -158,18 +171,26 @@ class TMatrix : public BaseMatrix {
arrSize_ = 0;
}
- virtual value_type* data() {
+ value_type* data() {
return data_;
}
- virtual const value_type* data() const {
+ const value_type* data() const {
return data_;
}
- virtual size_t size() const {
+ size_t size() const {
+ // return data_.size();
return cols_ * rows_ * beam_ * batches_;
}
+ void debugDim() const {
+ std::cerr << "Rows: " << rows_ << std::endl;
+ std::cerr << "Cols: " << cols_ << std::endl;
+ std::cerr << "Beam: " << beam_ << std::endl;
+ std::cerr << "Bathces: " << batches_ << std::endl;
+ }
+
void swap(TMatrix &other)
{
std::swap(rows_, other.rows_);
@@ -185,13 +206,13 @@ class TMatrix : public BaseMatrix {
return (int)size() != 0;
}
- protected:
+ private:
size_t rows_;
size_t cols_;
size_t beam_;
size_t batches_;
size_t arrSize_;
- T* data_;
+ T *data_;
};
typedef TMatrix<float> Matrix;
diff --git a/src/amun/gpu/mblas/matrix_functions.cu b/src/amun/gpu/mblas/matrix_functions.cu
index 863ac667..33f90866 100644
--- a/src/amun/gpu/mblas/matrix_functions.cu
+++ b/src/amun/gpu/mblas/matrix_functions.cu
@@ -44,7 +44,7 @@ void Mean(Matrix& Out, const Matrix& In, const DeviceVector<int>& mapping) {
int nBlocks = (stateLength / 512) + ((stateLength % 512 == 0) ? 0 : 1);
gMean<<<nBlocks, nThreads, 0, CudaStreamHandler::GetStream()>>>
- (Out.data(), In.data(), mapping.data(),
+ (Out.data(), In.data(), thrust::raw_pointer_cast(mapping.data()),
batchNum, sentenceLength, stateLength);
}
@@ -75,7 +75,7 @@ void WeightedMean(Matrix& Out,const Matrix& Weights, const Matrix& In, const Dev
int nBlocks = (Out.size() / 512) + ((Out.size() % 512 == 0) ? 0 : 1);
gWeightedMean<<<nBlocks, nThreads, 0, CudaStreamHandler::GetStream()>>>
- (Out.data(), Weights.data(), In.data(), mapping.data(),
+ (Out.data(), Weights.data(), In.data(), thrust::raw_pointer_cast(mapping.data()),
numRows, numCols, Weights.dim(1));
}
@@ -203,7 +203,7 @@ Matrix& Assemble(Matrix& Out,
const Matrix& In,
const DeviceVector<size_t>& indeces) {
Out.Resize(indeces.size(), In.dim(1));
- CopyRows(Out, In, indeces.data(), indeces.size());
+ CopyRows(Out, In, thrust::raw_pointer_cast(indeces.data()), indeces.size());
return Out;
}
@@ -373,8 +373,8 @@ Matrix& Softmax(Matrix& Out, const DeviceVector<int>& batchIds, const DeviceVect
gSoftMax<<<blocks, threads, shared, CudaStreamHandler::GetStream()>>>
(Out.data(), Out.dim(0), Out.dim(1),
- batchIds.data(), batchIds.size(),
- srcMapping.data(), srcSize);
+ thrust::raw_pointer_cast(batchIds.data()), batchIds.size(),
+ thrust::raw_pointer_cast(srcMapping.data()), srcSize);
return Out;
}
@@ -512,7 +512,7 @@ void MapMatrix(Matrix& state, const DeviceVector<int>& mapping, size_t i) {
int numBlocks = (state.size() / numThreads) + 1;
float* d_in = state.data();
- const int* d_mapping = mapping.data();
+ const int* d_mapping = thrust::raw_pointer_cast(mapping.data());
gMapMatrix<<<numBlocks, numThreads, 0, CudaStreamHandler::GetStream()>>>
(d_in, batchSize, stateLength, sentenceLength, d_mapping, i);
diff --git a/src/amun/gpu/mblas/matrix_functions.h b/src/amun/gpu/mblas/matrix_functions.h
index d129afc1..f9a85bdb 100644
--- a/src/amun/gpu/mblas/matrix_functions.h
+++ b/src/amun/gpu/mblas/matrix_functions.h
@@ -5,7 +5,6 @@
#include <cmath>
#include <cublas_v2.h>
-#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <iostream>
@@ -31,8 +30,6 @@ void Debug(const M& m, size_t pos = 0, size_t l = 8) {
std::cerr << m.GetVec()[i * m.dim(1) + j] << " ";
}
std::cerr << std::endl;
- // if(i == 4)
- // break;
}
}
@@ -180,20 +177,8 @@ Matrix& Broadcast(Functor functor, Matrix& Out, const Matrix& In, const DeviceVe
int threads = 512;
int blocks = (Temp.size() / threads) + 1;
- /*
- std::cerr << "\nTemp=" << Temp.Debug() << std::endl;
- std::cerr << "Out=" << Out.Debug() << std::endl;
- std::cerr << "In=" << In.Debug() << std::endl;
- std::cerr << "srcSize=" << srcSize << std::endl;
-
- std::cerr << "batchMapping=" << batchMapping.size() << ":";
- for (size_t i = 0; i < batchMapping.size(); ++i) {
- std::cerr << batchMapping[i] << " ";
- }
- std::cerr << std::endl;
- */
gBroadcast<<<blocks, threads, 0, CudaStreamHandler::GetStream()>>>
- (functor, d_out, d_in1, d_in2, srcSize, batchMapping.size(), cols, thrust::raw_pointer_cast(batchMapping.data()),
+ (functor, d_out, d_in1, d_in2, srcSize, batchMapping.size(), cols, batchMapping.data(),
batchMapping.size(), Temp.size(), Out.size(), In.size(), In.dim(0)
);
@@ -238,7 +223,7 @@ Matrix& BroadcastVecColumn(Functor functor, Matrix& Out, const DeviceVector<floa
size_t cols = Out.dim(1);
float* d_out = Out.data();
- const float* d_in = thrust::raw_pointer_cast(In.data());
+ const float* d_in = In.data();
int threads = std::min(MAX_THREADS, (int)cols);
int blocks = cols / threads + (cols % threads != 0);
diff --git a/src/amun/gpu/mblas/thrust_functions.h b/src/amun/gpu/mblas/thrust_functions.h
index 1760fdcd..e0266fb3 100644
--- a/src/amun/gpu/mblas/thrust_functions.h
+++ b/src/amun/gpu/mblas/thrust_functions.h
@@ -1,8 +1,8 @@
#pragma once
#include <cmath>
-#include <cublas_v2.h>
-#include <thrust/device_vector.h>
+#include <cublas_v2.h>
+// #include <thrust/device_vector.h>
#include <thrust/functional.h>
@@ -12,7 +12,7 @@ namespace thrust
{
namespace functional
{
-
+
template<typename T>
struct unary_exp : public thrust::unary_function<T,T> {
__host__ __device__