diff options
author | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-05-06 19:03:07 +0300 |
---|---|---|
committer | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-05-06 19:03:07 +0300 |
commit | a30f331556294823eca4ab173c1b63d193e19d19 (patch) | |
tree | af5fa1b0f81790ed13f19f47d8609f96ee226bf5 | |
parent | 3487667e680aae9e034f751d4fb5a059db0d2577 (diff) |
Ingoing work on device vector [device_vector]
-rw-r--r-- | src/amun/common/history.cpp | 1 | ||||
-rwxr-xr-x | src/amun/common/search.cpp | 19 | ||||
-rw-r--r-- | src/amun/common/sentences.h | 2 | ||||
-rw-r--r-- | src/amun/common/soft_alignment.h | 5 | ||||
-rwxr-xr-x | src/amun/common/translation_task.cpp | 18 | ||||
-rw-r--r-- | src/amun/gpu/decoder/best_hyps.h | 5 | ||||
-rw-r--r-- | src/amun/gpu/dl4mt/decoder.h | 56 | ||||
-rw-r--r-- | src/amun/gpu/mblas/device_vector.h | 31 | ||||
-rw-r--r-- | src/amun/gpu/mblas/matrix.h | 85 | ||||
-rw-r--r-- | src/amun/gpu/mblas/matrix_functions.cu | 12 | ||||
-rw-r--r-- | src/amun/gpu/mblas/matrix_functions.h | 19 | ||||
-rw-r--r-- | src/amun/gpu/mblas/thrust_functions.h | 6 |
12 files changed, 156 insertions, 103 deletions
diff --git a/src/amun/common/history.cpp b/src/amun/common/history.cpp index c6b9db12..08aa3a14 100644 --- a/src/amun/common/history.cpp +++ b/src/amun/common/history.cpp @@ -69,6 +69,7 @@ Histories::Histories(const Sentences& sentences, bool normalize) { for (const auto& sentence : sentences) { coll_.emplace_back(new History(sentence->GetLineNum(), normalize)); + std::cerr << "SIZE: " << size() << std::endl; } } diff --git a/src/amun/common/search.cpp b/src/amun/common/search.cpp index 3803fd6f..62007e89 100755 --- a/src/amun/common/search.cpp +++ b/src/amun/common/search.cpp @@ -6,6 +6,10 @@ #include "common/filter.h" #include "common/base_matrix.h" +#ifdef CUDA +#include <cuda.h> +#endif + using namespace std; namespace amunmt { @@ -22,7 +26,7 @@ Search::~Search() { #ifdef CUDA if (deviceInfo_.deviceType == GPUDevice) { - cudaSetDevice(deviceInfo_.deviceId); + // cudaSetDevice(deviceInfo_.deviceId); } #endif } @@ -30,7 +34,6 @@ const DeviceInfo& Search::GetDeviceInfo() const { return deviceInfo_; } - const std::vector<ScorerPtr>& Search::GetScorers() const { return scorers_; @@ -102,9 +105,11 @@ void Search::PreProcess( void Search::Encode(const Sentences& sentences, States& states) { for (size_t i = 0; i < scorers_.size(); i++) { - Scorer& scorer = *scorers_[i]; + Scorer &scorer = *scorers_[i]; + std::cerr << "Set Source" << std::endl; scorer.SetSource(sentences); + std::cerr << "BEgin Sentencs State" << std::endl; scorer.BeginSentenceState(*states[i], sentences.size()); } } @@ -157,6 +162,7 @@ bool Search::Decode( const State &state = *states[i]; State &nextState = *nextStates[i]; + std::cerr << "Decode" << std::endl; scorer.Decode(state, nextState, beamSizes); } @@ -203,14 +209,20 @@ bool Search::CalcBeam( bool returnAlignment = god.Get<bool>("return-alignment"); size_t batchSize = sentences.size(); + std::cerr << "Calc Beam" << std::endl; bestHyps_->CalcBeam(god, prevHyps, scorers_, filterIndices_, returnAlignment, beams, beamSizes); + std::cerr << "DONE" << 
std::endl; + std::cerr << "ADD to historues" << std::endl; + std::cerr << beams.size() << std::endl; + std::cerr << histories->size() << " " << sentences.size() << std::endl; for (size_t i = 0; i < batchSize; ++i) { if (!beams[i].empty()) { histories->at(i)->Add(beams[i], histories->at(i)->size() == 3 * sentences.at(i)->GetWords().size()); } } + std::cerr << "Surivors" << std::endl; for (size_t batchID = 0; batchID < batchSize; ++batchID) { for (auto& h : beams[batchID]) { if (h->GetWord() != EOS_ID) { @@ -225,6 +237,7 @@ bool Search::CalcBeam( return false; } + std::cerr << "Assemble" << std::endl; for (size_t i = 0; i < scorers_.size(); i++) { scorers_[i]->AssembleBeamState(*nextStates[i], survivors, *states[i]); } diff --git a/src/amun/common/sentences.h b/src/amun/common/sentences.h index 7cb1838d..281f24da 100644 --- a/src/amun/common/sentences.h +++ b/src/amun/common/sentences.h @@ -23,7 +23,7 @@ class Sentences { } auto end() const -> decltype(coll_.cend()) { - return coll_.begin(); + return coll_.cend(); } SentencePtr at(size_t id) const { diff --git a/src/amun/common/soft_alignment.h b/src/amun/common/soft_alignment.h index 41a808bc..98269003 100644 --- a/src/amun/common/soft_alignment.h +++ b/src/amun/common/soft_alignment.h @@ -3,11 +3,6 @@ #include <vector> #include <memory> -#ifdef CUDA -#include <thrust/host_vector.h> -using SoftAlignment = thrust::host_vector<float>; -#else using SoftAlignment = std::vector<float>; -#endif using SoftAlignmentPtr = std::shared_ptr<SoftAlignment>; diff --git a/src/amun/common/translation_task.cpp b/src/amun/common/translation_task.cpp index 015ae510..3059fe0a 100755 --- a/src/amun/common/translation_task.cpp +++ b/src/amun/common/translation_task.cpp @@ -31,13 +31,13 @@ std::shared_ptr<Histories> TranslationTask(const God &god, std::shared_ptr<Sente //cerr << "histories=" << histories->size() << endl; return histories; } -#ifdef CUDA - catch(thrust::system_error &e) - { - std::cerr << "CUDA error during some_function: 
" << e.what() << std::endl; - abort(); - } -#endif +// #ifdef CUDA + // catch(thrust::system_error &e) + // { + // std::cerr << "CUDA error during some_function: " << e.what() << std::endl; + // abort(); + // } +// #endif catch(std::bad_alloc &e) { std::cerr << "Bad memory allocation during some_function: " << e.what() << std::endl; @@ -48,9 +48,9 @@ std::shared_ptr<Histories> TranslationTask(const God &god, std::shared_ptr<Sente std::cerr << "Runtime error during some_function: " << e.what() << std::endl; abort(); } - catch(...) + catch(std::exception& e) { - std::cerr << "Some other kind of error during some_function" << std::endl; + std::cerr << "Some other kind of error during some_function"<< e.what() << std::endl; abort(); } diff --git a/src/amun/gpu/decoder/best_hyps.h b/src/amun/gpu/decoder/best_hyps.h index e15bacd0..a8eeee43 100644 --- a/src/amun/gpu/decoder/best_hyps.h +++ b/src/amun/gpu/decoder/best_hyps.h @@ -74,6 +74,7 @@ class BestHyps : public BestHypsBase std::vector<size_t>& beamSizes ) { + std::cerr << "CALC BEAM" << std::endl; using namespace mblas; mblas::Matrix& Probs = static_cast<mblas::Matrix&>(scorers[0]->GetProbs()); @@ -108,7 +109,10 @@ class BestHyps : public BestHypsBase std::vector<float> bestCosts; std::vector<unsigned> bestKeys; + std::cerr << "Find BEST" << std::endl; FindBests(beamSizes, Probs, bestCosts, bestKeys, isFirst); + std::cerr << "Find BEST DONE" << std::endl; + std::cerr << "Find BEST DONE" << bestCosts.size() << " " << bestKeys.size() << std::endl; std::vector<HostVector<float>> breakDowns; bool doBreakdown = god.Get<bool>("n-best"); @@ -143,6 +147,7 @@ class BestHyps : public BestHypsBase float cost = bestCosts[i]; HypothesisPtr hyp; + std::cerr << "ADDING HYPS" << std::endl; if (returnAlignment) { hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost, GetAlignments(scorers, hypIndex))); diff --git a/src/amun/gpu/dl4mt/decoder.h b/src/amun/gpu/dl4mt/decoder.h index 61f58a64..172cef91 100644 --- 
a/src/amun/gpu/dl4mt/decoder.h +++ b/src/amun/gpu/dl4mt/decoder.h @@ -117,9 +117,12 @@ class Decoder { void Init(const mblas::Matrix& SourceContext) { using namespace mblas; + std::cerr << "INIT PROD" << std::endl; Prod(/*h_[0],*/ SCU_, SourceContext, w_.U_); if (w_.Gamma_1_) { + std::cerr << "INIT NORM" << std::endl; Normalization(SCU_, SCU_, w_.Gamma_1_, w_.B_, 1e-9); + std::cerr << "INIT NORM DONE" << std::endl; } } @@ -137,52 +140,63 @@ class Decoder { batchMapping[k++] = i; } } + for (auto i : batchMapping) std::cerr << i << " "; + std::cerr << std::endl; + std::cerr << "COPY" << std::endl; mblas::copy(batchMapping.data(), batchMapping.size(), dBatchMapping_.data(), cudaMemcpyHostToDevice); const size_t srcSize = mapping.size() / beamSizes.size(); - + std::cerr << srcSize << std::endl; + + std::cerr << "TEMP2: " << std::endl; + Temp2_.debugDim(); + std::cerr << "HP: " << std::endl; + HiddenState.debugDim(); + std::cerr << "WW: " << std::endl; + w_.W_.debugDim(); + // std::cerr << "HS: " << HiddenState.debugDim() << std::endl; + // std::cerr << "WW: " << w_.W_.debugDim() << std::endl; + std::cerr << "PROD" << std::endl; Prod(/*h_[1],*/ Temp2_, HiddenState, w_.W_); + std::cerr << "PROD DONE" << std::endl; if (w_.Gamma_2_) { + std::cerr << "NORM" << std::endl; Normalization(Temp2_, Temp2_, w_.Gamma_2_, 1e-9); } else { + std::cerr << "BROD" << std::endl; BroadcastVec(_1 + _2, Temp2_, w_.B_/*, s_[1]*/); } + std::cerr << "COPY" << std::endl; Copy(Temp1_, SCU_); - //std::cerr << std::endl; - //std::cerr << "batchMapping=" << batchMapping.size() << std::endl; - //std::cerr << "SCU_=" << SCU_.Debug() << std::endl; - //std::cerr << "1Temp1_=" << Temp1_.Debug() << std::endl; - //std::cerr << "Temp2_=" << Temp2_.Debug() << std::endl; - + std::cerr << "BROD" << std::endl; Broadcast(Tanh(_1 + _2), Temp1_, Temp2_, dBatchMapping_, srcSize); - //std::cerr << "2Temp1_=" << Temp1_.Debug() << std::endl; + std::cerr << "RESHAPE" << std::endl; Temp1_.Reshape2D(); - //std::cerr << 
"w_.V_=" << w_.V_.Debug() << std::endl; - //std::cerr << "3Temp1_=" << Temp1_.Debug() << std::endl; - + std::cerr << "PROD" << std::endl; Prod(A_, w_.V_, Temp1_, false, true); size_t rows1 = SourceContext.dim(0); size_t rows2 = HiddenState.dim(0); - //std::cerr << "1A_=" << A_.Debug() << std::endl; + std::cerr << "RESHAPE" << std::endl; A_.Reshape(rows2, srcSize, 1, 1); // due to broadcasting above - //std::cerr << "2A_=" << A_.Debug() << std::endl; + std::cerr << "SOFTMAX" << std::endl; mblas::Softmax(A_, dBatchMapping_, mapping, srcSize); + std::cerr << "RESIZE" << std::endl; AlignedSourceContext.Resize(A_.dim(0), SourceContext.dim(1)); - mblas::WeightedMean(AlignedSourceContext, A_, SourceContext, dBatchMapping_); - //std::cerr << "AlignedSourceContext=" << AlignedSourceContext.Debug() << std::endl; + std::cerr << "WEIGHTED MEAN" << std::endl; + mblas::WeightedMean(AlignedSourceContext, A_, SourceContext, dBatchMapping_); } void GetAttention(mblas::Matrix& Attention) { @@ -312,11 +326,14 @@ class Decoder { const DeviceVector<int>& mapping, const std::vector<size_t>& beamSizes) { + std::cerr << "Get Hidden State" << std::endl; GetHiddenState(HiddenState_, State, Embeddings); + std::cerr << "Get Aligned Source Context" << std::endl; GetAlignedSourceContext(AlignedSourceContext_, HiddenState_, SourceContext, mapping, beamSizes); + std::cerr << "Get Next State" << std::endl; GetNextState(NextState, HiddenState_, AlignedSourceContext_); + std::cerr << "Get Get Probs" << std::endl; GetProbs(NextState, Embeddings, AlignedSourceContext_); - } mblas::Matrix& GetProbs() { @@ -327,12 +344,17 @@ class Decoder { const mblas::Matrix& SourceContext, size_t batchSize, const DeviceVector<int>& batchMapping) { + std::cerr << "EMPTY STATE" << std::endl; rnn1_.InitializeState(State, SourceContext, batchSize, batchMapping); + std::cerr << "ALIGN INItk" << std::endl; alignment_.Init(SourceContext); + std::cerr << "ALIGN INItk DONE" << std::endl; } void EmptyEmbedding(mblas::Matrix& 
Embedding, size_t batchSize = 1) { - Embedding.Clear(); + std::cerr << "BATCH SIZE: " << batchSize << std::endl; + // Embedding.Clear(); + std::cerr << embeddings_.GetCols() << std::endl; Embedding.Resize(batchSize, embeddings_.GetCols()); mblas::Fill(Embedding, 0); } diff --git a/src/amun/gpu/mblas/device_vector.h b/src/amun/gpu/mblas/device_vector.h index 2434cf52..9d3708a7 100644 --- a/src/amun/gpu/mblas/device_vector.h +++ b/src/amun/gpu/mblas/device_vector.h @@ -70,17 +70,22 @@ class device_vector if (newSize > realSize_) { if (data_ == nullptr) { HANDLE_ERROR( cudaMalloc((void**)&data_, newSize * sizeof(T)) ); + realSize_ = newSize; + size_ = newSize; } else { - T* newData_; - HANDLE_ERROR( cudaMalloc((void**)&newData_, newSize * sizeof(T)) ); - HANDLE_ERROR( cudaMemcpyAsync( - newData_, - data_, - size_ * sizeof(T), - cudaMemcpyDeviceToDevice, - CudaStreamHandler::GetStream()) - ); - HANDLE_ERROR( cudaFree(data_) ); + T* newData; + HANDLE_ERROR( cudaMalloc((void**)&newData, newSize * sizeof(T)) ); + HANDLE_ERROR( cudaMemcpyAsync( + newData, + data_, + size_ * sizeof(T), + cudaMemcpyDeviceToDevice, + CudaStreamHandler::GetStream()) + ); + HANDLE_ERROR( cudaFree(data_) ); + data_ = newData; + realSize_ = newSize; + size_ = newSize; } } size_ = newSize; @@ -98,6 +103,12 @@ class device_vector return data_; } + ~device_vector() { + if (data_) { + HANDLE_ERROR( cudaFree(data_) ); + } + } + protected: T* data_; size_t size_; diff --git a/src/amun/gpu/mblas/matrix.h b/src/amun/gpu/mblas/matrix.h index 75937a27..ac6391c8 100644 --- a/src/amun/gpu/mblas/matrix.h +++ b/src/amun/gpu/mblas/matrix.h @@ -2,7 +2,7 @@ #include <memory> #include <sstream> -// #include <thrust/execution_policy.h> +#include <thrust/execution_policy.h> #include <thrust/functional.h> #include "common/exception.h" @@ -18,26 +18,28 @@ using namespace thrust::placeholders; float Sum(const float *data, size_t count); + template <typename T> class TMatrix : public BaseMatrix { public: typedef T 
value_type; TMatrix() - : rows_(0) - , cols_(0) - , beam_(0) - , batches_(0) - , arrSize_(0) - , data_(nullptr) - {} + : rows_(0) + , cols_(0) + , beam_(0) + , batches_(0) + , arrSize_(0) + , data_(nullptr) + { + } TMatrix(size_t rows, size_t cols, size_t beam, size_t batches, bool zero = false) - : rows_(rows) - , cols_(cols) - , beam_(1) - , batches_(1) - , arrSize_(size()) + : rows_(rows) + , cols_(cols) + , beam_(1) + , batches_(1) + , arrSize_(size()) { HANDLE_ERROR( cudaMalloc((void**)&data_, arrSize_ * sizeof(T)) ); if (zero) { @@ -46,7 +48,7 @@ class TMatrix : public BaseMatrix { } TMatrix(TMatrix&& m) - : TMatrix() + : TMatrix() { swap(m); } @@ -59,6 +61,10 @@ class TMatrix : public BaseMatrix { , arrSize_(m.arrSize_) { HANDLE_ERROR( cudaMalloc((void**)&data_, arrSize_ * sizeof(T)) ); + std::cerr << m.data_ << std::endl; + std::cerr << data_ << std::endl; + + std::cerr << "COPY: " << size() << " " << m.size() << std::endl; HANDLE_ERROR( cudaMemcpyAsync( data_, m.data_, @@ -74,33 +80,40 @@ class TMatrix : public BaseMatrix { virtual size_t dim(size_t i) const { - switch (i) { - case 0: return rows_; - case 1: return cols_; - case 2: return beam_; - case 3: return batches_; - default: abort(); - } + switch (i) { + case 0: return rows_; + case 1: return cols_; + case 2: return beam_; + case 3: return batches_; + default: + abort(); + } } void Resize(size_t rows, size_t cols, size_t beam = 1, size_t batches = 1) { size_t newSize = cols * rows * beam * batches; if (data_) { if (newSize > arrSize_) { - T *newData; + T* newData; HANDLE_ERROR( cudaMalloc((void**)&newData, newSize * sizeof(T)) ); + std::cerr << newData << std::endl; + std::cerr << data_ << std::endl; + + std::cerr << "RESIZE: " << size() << " " << newSize << std::endl; HANDLE_ERROR( cudaMemcpyAsync( newData, data_, size() * sizeof(T), cudaMemcpyDeviceToDevice, - CudaStreamHandler::GetStream()) ); + CudaStreamHandler::GetStream()) + ); HANDLE_ERROR(cudaFree(data_)); data_ = newData; arrSize_ = 
newSize; - } else if (rows == 0 || cols == 0) { + } + else if (rows == 0 || cols == 0) { Clear(); } } @@ -137,12 +150,12 @@ class TMatrix : public BaseMatrix { std::stringstream strm; strm << BaseMatrix::Debug(detailed) << " "; strm << data_ << " " - << arrSize_ << " " + << arrSize_ << " " << std::flush; if (detailed) { - // float sum = Sum(data(), size()); - // strm << "size=" << size() << " sum=" << sum << std::flush; + float sum = Sum(data(), size()); + strm << "size=" << size() << " sum=" << sum << std::flush; } return strm.str(); @@ -158,18 +171,26 @@ class TMatrix : public BaseMatrix { arrSize_ = 0; } - virtual value_type* data() { + value_type* data() { return data_; } - virtual const value_type* data() const { + const value_type* data() const { return data_; } - virtual size_t size() const { + size_t size() const { + // return data_.size(); return cols_ * rows_ * beam_ * batches_; } + void debugDim() const { + std::cerr << "Rows: " << rows_ << std::endl; + std::cerr << "Cols: " << cols_ << std::endl; + std::cerr << "Beam: " << beam_ << std::endl; + std::cerr << "Bathces: " << batches_ << std::endl; + } + void swap(TMatrix &other) { std::swap(rows_, other.rows_); @@ -185,13 +206,13 @@ class TMatrix : public BaseMatrix { return (int)size() != 0; } - protected: + private: size_t rows_; size_t cols_; size_t beam_; size_t batches_; size_t arrSize_; - T* data_; + T *data_; }; typedef TMatrix<float> Matrix; diff --git a/src/amun/gpu/mblas/matrix_functions.cu b/src/amun/gpu/mblas/matrix_functions.cu index 863ac667..33f90866 100644 --- a/src/amun/gpu/mblas/matrix_functions.cu +++ b/src/amun/gpu/mblas/matrix_functions.cu @@ -44,7 +44,7 @@ void Mean(Matrix& Out, const Matrix& In, const DeviceVector<int>& mapping) { int nBlocks = (stateLength / 512) + ((stateLength % 512 == 0) ? 
0 : 1); gMean<<<nBlocks, nThreads, 0, CudaStreamHandler::GetStream()>>> - (Out.data(), In.data(), mapping.data(), + (Out.data(), In.data(), thrust::raw_pointer_cast(mapping.data()), batchNum, sentenceLength, stateLength); } @@ -75,7 +75,7 @@ void WeightedMean(Matrix& Out,const Matrix& Weights, const Matrix& In, const Dev int nBlocks = (Out.size() / 512) + ((Out.size() % 512 == 0) ? 0 : 1); gWeightedMean<<<nBlocks, nThreads, 0, CudaStreamHandler::GetStream()>>> - (Out.data(), Weights.data(), In.data(), mapping.data(), + (Out.data(), Weights.data(), In.data(), thrust::raw_pointer_cast(mapping.data()), numRows, numCols, Weights.dim(1)); } @@ -203,7 +203,7 @@ Matrix& Assemble(Matrix& Out, const Matrix& In, const DeviceVector<size_t>& indeces) { Out.Resize(indeces.size(), In.dim(1)); - CopyRows(Out, In, indeces.data(), indeces.size()); + CopyRows(Out, In, thrust::raw_pointer_cast(indeces.data()), indeces.size()); return Out; } @@ -373,8 +373,8 @@ Matrix& Softmax(Matrix& Out, const DeviceVector<int>& batchIds, const DeviceVect gSoftMax<<<blocks, threads, shared, CudaStreamHandler::GetStream()>>> (Out.data(), Out.dim(0), Out.dim(1), - batchIds.data(), batchIds.size(), - srcMapping.data(), srcSize); + thrust::raw_pointer_cast(batchIds.data()), batchIds.size(), + thrust::raw_pointer_cast(srcMapping.data()), srcSize); return Out; } @@ -512,7 +512,7 @@ void MapMatrix(Matrix& state, const DeviceVector<int>& mapping, size_t i) { int numBlocks = (state.size() / numThreads) + 1; float* d_in = state.data(); - const int* d_mapping = mapping.data(); + const int* d_mapping = thrust::raw_pointer_cast(mapping.data()); gMapMatrix<<<numBlocks, numThreads, 0, CudaStreamHandler::GetStream()>>> (d_in, batchSize, stateLength, sentenceLength, d_mapping, i); diff --git a/src/amun/gpu/mblas/matrix_functions.h b/src/amun/gpu/mblas/matrix_functions.h index d129afc1..f9a85bdb 100644 --- a/src/amun/gpu/mblas/matrix_functions.h +++ b/src/amun/gpu/mblas/matrix_functions.h @@ -5,7 +5,6 @@ #include 
<cmath> #include <cublas_v2.h> -#include <thrust/execution_policy.h> #include <thrust/functional.h> #include <iostream> @@ -31,8 +30,6 @@ void Debug(const M& m, size_t pos = 0, size_t l = 8) { std::cerr << m.GetVec()[i * m.dim(1) + j] << " "; } std::cerr << std::endl; - // if(i == 4) - // break; } } @@ -180,20 +177,8 @@ Matrix& Broadcast(Functor functor, Matrix& Out, const Matrix& In, const DeviceVe int threads = 512; int blocks = (Temp.size() / threads) + 1; - /* - std::cerr << "\nTemp=" << Temp.Debug() << std::endl; - std::cerr << "Out=" << Out.Debug() << std::endl; - std::cerr << "In=" << In.Debug() << std::endl; - std::cerr << "srcSize=" << srcSize << std::endl; - - std::cerr << "batchMapping=" << batchMapping.size() << ":"; - for (size_t i = 0; i < batchMapping.size(); ++i) { - std::cerr << batchMapping[i] << " "; - } - std::cerr << std::endl; - */ gBroadcast<<<blocks, threads, 0, CudaStreamHandler::GetStream()>>> - (functor, d_out, d_in1, d_in2, srcSize, batchMapping.size(), cols, thrust::raw_pointer_cast(batchMapping.data()), + (functor, d_out, d_in1, d_in2, srcSize, batchMapping.size(), cols, batchMapping.data(), batchMapping.size(), Temp.size(), Out.size(), In.size(), In.dim(0) ); @@ -238,7 +223,7 @@ Matrix& BroadcastVecColumn(Functor functor, Matrix& Out, const DeviceVector<floa size_t cols = Out.dim(1); float* d_out = Out.data(); - const float* d_in = thrust::raw_pointer_cast(In.data()); + const float* d_in = In.data(); int threads = std::min(MAX_THREADS, (int)cols); int blocks = cols / threads + (cols % threads != 0); diff --git a/src/amun/gpu/mblas/thrust_functions.h b/src/amun/gpu/mblas/thrust_functions.h index 1760fdcd..e0266fb3 100644 --- a/src/amun/gpu/mblas/thrust_functions.h +++ b/src/amun/gpu/mblas/thrust_functions.h @@ -1,8 +1,8 @@ #pragma once #include <cmath> -#include <cublas_v2.h> -#include <thrust/device_vector.h> +#include <cublas_v2.h> +// #include <thrust/device_vector.h> #include <thrust/functional.h> @@ -12,7 +12,7 @@ namespace thrust 
{ namespace functional { - + template<typename T> struct unary_exp : public thrust::unary_function<T,T> { __host__ __device__ |