author     Tomasz Dwojak <t.dwojak@amu.edu.pl>   2017-11-16 02:05:55 +0300
committer  Tomasz Dwojak <t.dwojak@amu.edu.pl>   2017-11-16 02:30:41 +0300
commit     0effc8d28d25fbf80e3cea34d6f5b4e42490a7f2
tree       0581ddd34c53d722bf9184fccc643cceca121c03
parent     ce44cd9e287804860538dc80af0f87ae3ff8cfec
Refactor convolution and poolingWithMasking
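The positional Convolution("Conv1", 3, 3, 32)(x) layer is replaced by a
keyword-driven factory (typedef Accumulator<Convolution> convolution), and
CharConvPooling now calls the new pooling_with_masking operator instead of
max_pooling2. A minimal sketch of the new call style, mirroring the
model_lenet.h hunk below (assuming g is the Ptr<ExpressionGraph> and x the
input Expr; "paddings" and "strides" are optional keys defaulting to (0,0)
and (1,1)):

    // Build a 3x3 convolution with 32 kernels through the option accumulator;
    // .apply(x) creates the kernel/bias parameters and the ConvolutionOp node.
    auto conv_1 = convolution(g)
                  ("prefix", "conv_1")
                  ("kernel-dims", std::make_pair(3, 3))
                  ("kernel-num", 32)
                  .apply(x);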
-rw-r--r--  src/CMakeLists.txt                 |   1
-rw-r--r--  src/examples/mnist/model_lenet.h   |  16
-rw-r--r--  src/graph/expression_operators.cu  |  31
-rw-r--r--  src/graph/expression_operators.h   |   5
-rw-r--r--  src/graph/node_operators_binary.h  |  14
-rw-r--r--  src/graph/node_operators_unary.h   |  41
-rw-r--r--  src/kernels/cudnn_wrappers.h       |   4
-rw-r--r--  src/kernels/tensor_operators.cu    | 129
-rw-r--r--  src/kernels/tensor_operators.h     |  13
-rw-r--r--  src/layers/convolution.cu          |  41
-rw-r--r--  src/layers/convolution.h           |  83

11 files changed, 294 insertions(+), 84 deletions(-)
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 1cdca70c..3a2a6e7d 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -21,6 +21,7 @@ cuda_add_library(marian
   layers/param_initializers.cu
   layers/generic.cpp
   layers/guided_alignment.cpp
+  layers/convolution.cu
   models/model_factory.cpp
   rnn/attention.cu
   rnn/cells.cu
diff --git a/src/examples/mnist/model_lenet.h b/src/examples/mnist/model_lenet.h
index 07dbc148..fa5aa831 100644
--- a/src/examples/mnist/model_lenet.h
+++ b/src/examples/mnist/model_lenet.h
@@ -31,9 +31,19 @@ protected:
     // Construct hidden layers
-    auto conv_1 = Convolution("Conv1", 3, 3, 32)(x);
-    auto conv_2 = relu(Convolution("Conv2", 3, 3, 64)(conv_1));
-    auto pool = max_pooling(conv_2, 2, 2, 1, 1, 1, 1);
+    auto conv_1 = convolution(g)
+                  ("prefix", "conv_1")
+                  ("kernel-dims", std::make_pair(3,3))
+                  ("kernel-num", 32)
+                  .apply(x);
+
+    auto conv_2 = convolution(g)
+                  ("prefix", "conv_2")
+                  ("kernel-dims", std::make_pair(3,3))
+                  ("kernel-num", 64)
+                  .apply(conv_1);
+    auto relued = relu(conv_2);
+    auto pool = max_pooling(relued, 2, 2, 1, 1, 1, 1);

     auto flatten = reshape(pool,
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index c0f7c2dd..4c4e0feb 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -1,5 +1,6 @@
 #include "graph/expression_operators.h"
 #include "kernels/sparse.h"
+#include "layers/constructors.h"

 #include "graph/node_operators.h"
 #include "graph/node_operators_binary.h"
@@ -278,15 +279,18 @@ Expr highway(Expr y, Expr x, Expr t) {
 }

 Expr highway(const std::string prefix, Expr x) {
-  size_t out_dim = x->shape()[-1];
-  auto g = Dense(prefix + "_highway_d1",
-                 out_dim,
-                 keywords::activation = act::logit)(x);
-  auto dense_2 = Dense(prefix+ "_highway_d2",
-                       out_dim,
-                       keywords::activation = act::linear)(x);
-  auto rr = relu(dense_2);
-  return (g * rr) + ((1 - g) * x);
+  size_t outDim = x->shape()[-1];
+  auto g = mlp::dense(x->graph())
+           ("prefix", prefix + "_highway_d1")
+           ("dim", outDim)
+           ("activation", mlp::act::logit)
+           .construct()->apply(x);
+  auto relued = mlp::dense(x->graph())
+               ("prefix", prefix + "_highway_d2")
+               ("dim", outDim)
+               ("activation", mlp::act::ReLU)
+               .construct()->apply(x);
+  return (g * relued) + ((1 - g) * x);
 }

 // Expr batch_norm(Expr x, Expr gamma, Expr beta) {
@@ -308,11 +312,6 @@ Expr shift(Expr a, Shape shift) {
 //  return Expression<LexicalProbNodeOp>(logits, att, eps, lf);
 //}

-Expr convolution(Expr x, Expr filters, Expr bias) {
-  std::vector<Expr> nodes = {x, filters, bias};
-  return Expression<ConvolutionOp>(nodes);
-}
-
 Expr avg_pooling(
     Expr x,
     int height,
@@ -386,4 +385,8 @@ Expr convertFromcudnnFormat(Expr x) {
   return reshape(rows(reshapedX, newIndeces), shape);
 }

+Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven) {
+  return Expression<PoolingWithMaskingOp>(x, mask, width, isEven);
+}
+
 }
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index e728edaf..37e7c137 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -135,8 +135,6 @@ Expr convert2cudnnFormat(Expr x);

 Expr convertFromcudnnFormat(Expr x);

-Expr convolution(Expr x, Expr filters, Expr bias);
-
 Expr avg_pooling(Expr x,
                  int height,
                  int width,
@@ -152,4 +150,7 @@ Expr max_pooling(Expr x,
                  int padWidth = 0,
                  int strideHeight = 1,
                  int strideWidth = 1);
+
+Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven=false);
+
 }
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index a3d24179..e6647134 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -710,9 +710,19 @@ struct HighwayNodeOp : public NaryNodeOp {

 class ConvolutionOp : public NaryNodeOp {
 public:
-  ConvolutionOp(const std::vector<Expr>& nodes)
+  ConvolutionOp(
+      const std::vector<Expr>& nodes,
+      int hPad = 0,
+      int wPad = 0,
+      int hStride = 1,
+      int wStride = 1)
     : NaryNodeOp(nodes),
-      conv_(nodes[1]->shape(), nodes[2]->shape()) {
+      conv_(nodes[1]->shape(),
+            nodes[2]->shape(),
+            hPad,
+            wPad,
+            hStride,
+            wStride) {
     conv_.getOutputShape(nodes[0]->shape(), shape_);
   }
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 90995ddf..faf21dee 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -1051,4 +1051,45 @@ protected:
   PoolingWrapper pooling_;
 };

+class PoolingWithMaskingOp : public UnaryNodeOp {
+  public:
+    PoolingWithMaskingOp(Expr x, Expr mask, int width, bool isEven=false)
+      : UnaryNodeOp(x),
+        mask_(mask),
+        width_(width),
+        isEven_(isEven)
+    {
+      auto xShape = x->shape();
+      int dimBatch = xShape[0];
+      int dimWord = xShape[1];
+      int cols = (isEven_) ? xShape[2] - 1 : xShape[2];
+      int dimSentence = (cols / width_) + (cols % width_ != 0);
+      shape_ = {dimBatch, dimWord, dimSentence};
+    }
+
+    NodeOps forwardOps() {
+      return {NodeOp(PoolingWithMaskingForward(val_,
+                                               child(0)->val(),
+                                               mask_->val(),
+                                               width_,
+                                               isEven_))};
+    }
+
+    NodeOps backwardOps() {
+      return {NodeOp(PoolingWithMaskingBackward(adj_,
+                                                child(0)->grad(),
+                                                child(0)->val(),
+                                                mask_->val(),
+                                                width_,
+                                                isEven_))};
+    }
+
+    const std::string type() { return "layer_pooling"; }
+
+  protected:
+    Expr mask_;
+    int width_;
+    bool isEven_;
+};
+
 }
diff --git a/src/kernels/cudnn_wrappers.h b/src/kernels/cudnn_wrappers.h
index ce8f44b9..fca4b6e0 100644
--- a/src/kernels/cudnn_wrappers.h
+++ b/src/kernels/cudnn_wrappers.h
@@ -29,8 +29,8 @@ class ConvolutionWrapper : public CUDNNWrapper {
 public:
   ConvolutionWrapper(const Shape& kernelShape,
                      const Shape& biasShape,
-                     int hPad = 1,
-                     int wPad = 1,
+                     int hPad = 0,
+                     int wPad = 0,
                      int hStride = 1,
                      int wStride = 1);
diff --git a/src/kernels/tensor_operators.cu b/src/kernels/tensor_operators.cu
index 271bc341..3c017cb3 100644
--- a/src/kernels/tensor_operators.cu
+++ b/src/kernels/tensor_operators.cu
@@ -2117,4 +2117,133 @@ void HighwayBackward(Tensor out1,
                      length);
 }

+__global__ void gMaxPoolingForward(float* out,
+                                   int outRows,
+                                   int outCols,
+                                   float* in,
+                                   int inRows,
+                                   int inCols,
+                                   float* mask,
+                                   int numKernels,
+                                   int maskCols,
+                                   int width,
+                                   int lastWidth) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+  if (tid >= outRows * outCols) return;
+
+  int rowId = tid / outRows;
+  int colId = tid % outRows;
+
+  float* b = in + (rowId * inCols) + (colId * width);
+
+  if (colId == outRows - 1) {
+    width = lastWidth;
+  }
+
+  float* localMask = mask + (rowId / numKernels) * maskCols + colId * width;
+  float currentMax = b[0] * localMask[0];
+  for (int i = 1; i < width; ++i) {
+    if (b[i] * localMask[i] > currentMax) {
+      currentMax = b[i] * localMask[i];
+    }
+  }
+
+  out[rowId + (colId * outCols)] = currentMax;
+}
+
+void PoolingWithMaskingForward(Tensor out,
+                               Tensor in,
+                               Tensor mask,
+                               int width,
+                               bool isEven) {
+  int n = out->shape().elements();
+  int threads = std::min(n, MAX_THREADS);
+  int blocks = n / threads + (n % threads != 0);
+
+  Shape& inShape = in->shape();
+  int inRows = inShape[0] * inShape[1];
+  int inCols = inShape[2];
+
+  Shape& outShape = out->shape();
+  int outRows = outShape[2];
+  int outCols = outShape[0] * outShape[1];
+
+  int lastWidth = ((inCols - isEven) % width == 0)
+                  ? width
+                  : (inCols - isEven) % width;
+
+  gMaxPoolingForward<<<blocks, threads>>>(
+      out->data(), outRows, outCols,
+      in->data(), inRows, inCols,
+      mask->data(), outShape[1], mask->shape()[2],
+      width, lastWidth);
+}
+
+__global__ void gMaxPoolingBackward(float* adj,
+                                    int adjRows,
+                                    int adjCols,
+                                    float* in,
+                                    float* adjIn,
+                                    int inRows,
+                                    int inCols,
+                                    float* mask,
+                                    int numKernels,
+                                    int maskCols,
+                                    int width,
+                                    int lastWidth)
+{
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+  if (tid >= adjRows * adjCols) return;
+
+  int rowId = tid / adjRows;
+  int colId = tid % adjRows;
+
+  float* b = in + (rowId * inCols) + (colId * width);
+
+  if (colId == adjRows - 1) {
+    width = lastWidth;
+  }
+
+  float* localMask = mask + (rowId / numKernels) * maskCols + colId * width;
+  size_t currentMaxIdx = 0;
+  for (int i = 1; i < width; ++i) {
+    if (b[i] * localMask[i] > b[currentMaxIdx] * localMask[currentMaxIdx]) {
+      currentMaxIdx = i;
+    }
+  }
+
+  adjIn[(rowId * inCols) + (colId * width) + currentMaxIdx]
+      += adj[rowId + (colId * adjCols)];
+}
+
+void PoolingWithMaskingBackward(Tensor adj,
+                                Tensor adjIn,
+                                Tensor in,
+                                Tensor mask,
+                                int width,
+                                bool isEven) {
+  int n = adj->shape().elements();
+  int threads = std::min(n, 512);
+  int blocks = n / threads + (n % threads != 0);
+
+  Shape& inShape = in->shape();
+  int inRows = inShape[0] * inShape[1];
+  int inCols = inShape[2];
+
+  Shape& adjShape = adj->shape();
+  int adjRows = adjShape[2];
+  int adjCols = adjShape[0] * adjShape[1];
+
+  int lastWidth = ((inCols - isEven) % width == 0)
+                  ? width
+                  : (inCols - isEven) % width;
+
+  gMaxPoolingBackward<<<blocks, threads>>>(
+      adj->data(), adjRows, adjCols,
+      in->data(), adjIn->data(), inRows, inCols,
+      mask->data(), adjShape[1], mask->shape()[2],
+      width, lastWidth);
+}
+
 }  // namespace marian
diff --git a/src/kernels/tensor_operators.h b/src/kernels/tensor_operators.h
index f9160cd1..06cb188c 100644
--- a/src/kernels/tensor_operators.h
+++ b/src/kernels/tensor_operators.h
@@ -383,4 +383,17 @@ void HighwayBackward(Tensor out1,
                      const Tensor in2,
                      const Tensor t,
                      const Tensor adj);
+
+void PoolingWithMaskingForward(Tensor out,
+                               Tensor in,
+                               Tensor mask,
+                               int width,
+                               bool isEven=false);
+
+void PoolingWithMaskingBackward(Tensor adj,
+                                Tensor adjIn,
+                                Tensor in,
+                                Tensor mask,
+                                int width,
+                                bool isEven=false);
 }
diff --git a/src/layers/convolution.cu b/src/layers/convolution.cu
new file mode 100644
index 00000000..958ff4b4
--- /dev/null
+++ b/src/layers/convolution.cu
@@ -0,0 +1,41 @@
+#include "layers/convolution.h"
+#include "graph/node_operators_binary.h"
+
+namespace marian {
+Convolution::Convolution(Ptr<ExpressionGraph> graph)
+  : Factory(graph) {}
+
+Expr Convolution::apply(Expr x) {
+  auto prefix = opt<std::string>("prefix");
+  auto kernelDims = opt<std::pair<int, int>>("kernel-dims");
+  auto kernelNum = opt<int>("kernel-num");
+  auto paddings = opt<std::pair<int, int>>("paddings", std::make_pair(0, 0));
+  auto strides = opt<std::pair<int, int>>("strides", std::make_pair(1, 1));
+
+  int layerIn = x->shape()[1];
+  auto kernel = graph_->param(prefix + "_conv_kernels",
+                              {layerIn,
+                               kernelNum,
+                               kernelDims.first,
+                               kernelDims.second},
+                              keywords::init=inits::glorot_uniform);
+
+  auto bias = graph_->param(prefix + "_conv_bias",
+                            {1, kernelNum, 1, 1},
+                            keywords::init=inits::zeros);
+
+  std::vector<Expr> nodes = {x, kernel, bias};
+  return Expression<ConvolutionOp>(nodes,
+                                   paddings.first,
+                                   paddings.second,
+                                   strides.first,
+                                   strides.second);
+}
+
+Expr Convolution::apply(const std::vector<Expr>&) {
+  ABORT("Can't apply convolution on many inputs at once");
+  return nullptr;
+}
+
+}
diff --git a/src/layers/convolution.h b/src/layers/convolution.h
index 36af652d..1416da84 100644
--- a/src/layers/convolution.h
+++ b/src/layers/convolution.h
@@ -3,72 +3,33 @@
 #include <string>

 #include "layers/generic.h"
+#include "graph/expression_graph.h"

 namespace marian {

-class Convolution {
-  public:
-    Convolution(
-        const std::string& prefix,
-        int kernelHeight = 3,
-        int kernelWidth = 3,
-        int kernelNum = 1,
-        int paddingHeight = 0,
-        int paddingWidth = 0,
-        int strideHeight = 1,
-        int strideWidth = 1)
-      : prefix_(prefix),
-        kernelHeight_(kernelHeight),
-        kernelWidth_(kernelWidth),
-        kernelNum_(kernelNum),
-        strideHeight_(strideHeight),
-        strideWidth_(strideWidth),
-        paddingHeight_(paddingHeight),
-        paddingWidth_(paddingWidth)
-    {
-    }
-
-    Expr operator()(Expr x) {
-      auto graph = x->graph();
-
-      int layerIn = x->shape()[1];
+class Convolution : public Factory {
+protected:
+  Ptr<Options> getOptions() { return options_; }

-      auto kernel = graph->param(prefix_ + "_conv_kernels",
-                                 {layerIn, kernelNum_, kernelHeight_, kernelWidth_},
-                                 keywords::init=inits::glorot_uniform);
-      auto bias = graph->param(prefix_ + "_conv_bias", {1, kernelNum_, 1, 1},
-                               keywords::init=inits::zeros);
+public:
+  Convolution(Ptr<ExpressionGraph> graph);

-      auto output = convolution(x, kernel, bias,
-                                paddingHeight_,
-                                paddingWidth_,
-                                strideHeight_,
-                                strideWidth_);
+  Expr apply(Expr x);

-      return output;
-    }
-
-  protected:
-    std::string prefix_;
-    int depth_;
-    int kernelHeight_;
-    int kernelWidth_;
-    int kernelNum_;
-    int strideHeight_;
-    int strideWidth_;
-    int paddingHeight_;
-    int paddingWidth_;
+  virtual Expr apply(const std::vector<Expr>&);
 };

+typedef Accumulator<Convolution> convolution;
+
 class CharConvPooling {
   public:
     CharConvPooling(
-        const std::string& name,
+        const std::string& prefix,
         int kernelHeight,
         std::vector<int> kernelWidths,
         std::vector<int> kernelNums)
-      : ConvPoolingBase(name),
+      : name_(prefix),
        size_(kernelNums.size()),
        kernelHeight_(kernelHeight),
        kernelWidths_(kernelWidths),
@@ -78,8 +39,8 @@ class CharConvPooling {
       auto graph = x->graph();

       auto masked = x * mask;
-      auto xNCHW = convert2NCHW(masked);
-      auto maskNCHW = convert2NCHW(mask);
+      auto xNCHW = convert2cudnnFormat(masked);
+      auto maskNCHW = convert2cudnnFormat(mask);

       int layerIn = xNCHW->shape()[1];
       Expr input = xNCHW;
@@ -87,18 +48,18 @@ class CharConvPooling {

       for (int i = 0; i < size_; ++i) {
         int kernelWidth = kernelWidths_[i];
-        int kernelDim = kernelNums_[i];
+        int kernelNum = kernelNums_[i];
         int padWidth = kernelWidth / 2;

-        auto kernel = graph->param(name_ + std::to_string(i),
-                                   {layerIn, kernelDim, kernelWidth, x->shape()[1]},
-                                   keywords::init=inits::glorot_uniform);
-        auto bias = graph->param(name_ + std::to_string(i) + "_bias", {1, kernelDim, 1, 1},
-                                 keywords::init=inits::zeros);
-        auto output = convolution(input, kernel, bias, padWidth, 0, 1, 1);
+        auto output = convolution(graph)
+                      ("prefix", name_)
+                      ("kernel-dims", std::make_pair(kernelWidth, x->shape()[-1]))
+                      ("kernel-num", kernelNum)
+                      ("paddings", std::make_pair(padWidth, 0))
+                      .apply(input);
         auto relued = relu(output);
-        auto output2 = max_pooling2(relued, maskNCHW, 5, kernelWidth % 2 == 0);
+        auto output2 = pooling_with_masking(relued, maskNCHW, 5, kernelWidth % 2 == 0);

         outputs.push_back(output2);
       }
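For reference, pooling_with_masking(x, mask, width, isEven) max-pools
non-overlapping windows of width along the last (sentence) axis, multiplying
values by the mask first; with isEven=true the final column is dropped, and a
trailing partial window is pooled over lastWidth = (inCols - isEven) % width
columns. A usage sketch with names taken from the CharConvPooling hunk above
(relued, maskNCHW, kernelWidth):

    // Masked max-pooling with window 5, as wired up in CharConvPooling;
    // e.g. 11 input columns at width 5 give 11/5 + (11%5 != 0) = 3 output
    // columns, the last one pooled over lastWidth = 11 % 5 = 1 column.
    auto output2 = pooling_with_masking(relued, maskNCHW, 5, kernelWidth % 2 == 0);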