
github.com/marian-nmt/marian.git
author     Tomasz Dwojak <t.dwojak@amu.edu.pl>   2017-11-16 02:05:55 +0300
committer  Tomasz Dwojak <t.dwojak@amu.edu.pl>   2017-11-16 02:30:41 +0300
commit     0effc8d28d25fbf80e3cea34d6f5b4e42490a7f2 (patch)
tree       0581ddd34c53d722bf9184fccc643cceca121c03
parent     ce44cd9e287804860538dc80af0f87ae3ff8cfec (diff)
Refactor convolution and poolingWithMasking
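
The positional Convolution and Dense layer constructors are replaced by
keyword-style factories (convolution, mlp::dense), the free-standing
convolution(x, filters, bias) operator moves into the new Convolution
factory, and max_pooling2 is replaced by an explicit pooling_with_masking
operator backed by new CUDA kernels. A before/after sketch of the layer
API, taken from the model_lenet.h hunk below (g is the ExpressionGraph):

    // before
    auto conv_1 = Convolution("Conv1", 3, 3, 32)(x);

    // after: the factory accumulates keyword options, then builds the node
    auto conv_1 = convolution(g)
                  ("prefix", "conv_1")
                  ("kernel-dims", std::make_pair(3, 3))
                  ("kernel-num", 32)
                  .apply(x);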
-rw-r--r--  src/CMakeLists.txt                 |   1
-rw-r--r--  src/examples/mnist/model_lenet.h   |  16
-rw-r--r--  src/graph/expression_operators.cu  |  31
-rw-r--r--  src/graph/expression_operators.h   |   5
-rw-r--r--  src/graph/node_operators_binary.h  |  14
-rw-r--r--  src/graph/node_operators_unary.h   |  41
-rw-r--r--  src/kernels/cudnn_wrappers.h       |   4
-rw-r--r--  src/kernels/tensor_operators.cu    | 129
-rw-r--r--  src/kernels/tensor_operators.h     |  13
-rw-r--r--  src/layers/convolution.cu          |  41
-rw-r--r--  src/layers/convolution.h           |  83
11 files changed, 294 insertions(+), 84 deletions(-)
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 1cdca70c..3a2a6e7d 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -21,6 +21,7 @@ cuda_add_library(marian
layers/param_initializers.cu
layers/generic.cpp
layers/guided_alignment.cpp
+ layers/convolution.cu
models/model_factory.cpp
rnn/attention.cu
rnn/cells.cu
diff --git a/src/examples/mnist/model_lenet.h b/src/examples/mnist/model_lenet.h
index 07dbc148..fa5aa831 100644
--- a/src/examples/mnist/model_lenet.h
+++ b/src/examples/mnist/model_lenet.h
@@ -31,9 +31,19 @@ protected:
// Construct hidden layers
- auto conv_1 = Convolution("Conv1", 3, 3, 32)(x);
- auto conv_2 = relu(Convolution("Conv2", 3, 3, 64)(conv_1));
- auto pool = max_pooling(conv_2, 2, 2, 1, 1, 1, 1);
+ auto conv_1 = convolution(g)
+ ("prefix", "conv_1")
+ ("kernel-dims", std::make_pair(3,3))
+ ("kernel-num", 32)
+ .apply(x);
+
+ auto conv_2 = convolution(g)
+ ("prefix", "conv_2")
+ ("kernel-dims", std::make_pair(3,3))
+ ("kernel-num", 64)
+ .apply(conv_1);
+ auto relued = relu(conv_2);
+ auto pool = max_pooling(relued, 2, 2, 1, 1, 1, 1);
auto flatten
= reshape(pool,
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index c0f7c2dd..4c4e0feb 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -1,5 +1,6 @@
#include "graph/expression_operators.h"
#include "kernels/sparse.h"
+#include "layers/constructors.h"
#include "graph/node_operators.h"
#include "graph/node_operators_binary.h"
@@ -278,15 +279,18 @@ Expr highway(Expr y, Expr x, Expr t) {
}
Expr highway(const std::string prefix, Expr x) {
- size_t out_dim = x->shape()[-1];
- auto g = Dense(prefix + "_highway_d1",
- out_dim,
- keywords::activation = act::logit)(x);
- auto dense_2 = Dense(prefix+ "_highway_d2",
- out_dim,
- keywords::activation = act::linear)(x);
- auto rr = relu(dense_2);
- return (g * rr) + ((1 - g) * x);
+ size_t outDim = x->shape()[-1];
+ auto g = mlp::dense(x->graph())
+ ("prefix", prefix + "_highway_d1")
+ ("dim", outDim)
+ ("activation", mlp::act::logit)
+ .construct()->apply(x);
+ auto relued = mlp::dense(x->graph())
+ ("prefix", prefix + "_highway_d2")
+ ("dim", outDim)
+ ("activation", mlp::act::ReLU)
+ .construct()->apply(x);
+ return (g * relued) + ((1 - g) * x);
}
// Expr batch_norm(Expr x, Expr gamma, Expr beta) {
@@ -308,11 +312,6 @@ Expr shift(Expr a, Shape shift) {
// return Expression<LexicalProbNodeOp>(logits, att, eps, lf);
//}
-Expr convolution(Expr x, Expr filters, Expr bias) {
- std::vector<Expr> nodes = {x, filters, bias};
- return Expression<ConvolutionOp>(nodes);
-}
-
Expr avg_pooling(
Expr x,
int height,
@@ -386,4 +385,8 @@ Expr convertFromcudnnFormat(Expr x) {
return reshape(rows(reshapedX, newIndeces), shape);
}
+Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven) {
+ return Expression<PoolingWithMaskingOp>(x, mask, width, isEven);
+}
+
}
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index e728edaf..37e7c137 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -135,8 +135,6 @@ Expr convert2cudnnFormat(Expr x);
Expr convertFromcudnnFormat(Expr x);
-Expr convolution(Expr x, Expr filters, Expr bias);
-
Expr avg_pooling(Expr x,
int height,
int width,
@@ -152,4 +150,7 @@ Expr max_pooling(Expr x,
int padWidth = 0,
int strideHeight = 1,
int strideWidth = 1);
+
+Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven=false);
+
}
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index a3d24179..e6647134 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -710,9 +710,19 @@ struct HighwayNodeOp : public NaryNodeOp {
class ConvolutionOp : public NaryNodeOp {
public:
- ConvolutionOp(const std::vector<Expr>& nodes)
+ ConvolutionOp(
+ const std::vector<Expr>& nodes,
+ int hPad = 0,
+ int wPad = 0,
+ int hStride = 1,
+ int wStride = 1)
: NaryNodeOp(nodes),
- conv_(nodes[1]->shape(), nodes[2]->shape()) {
+ conv_(nodes[1]->shape(),
+ nodes[2]->shape(),
+ hPad,
+ wPad,
+ hStride,
+ wStride) {
conv_.getOutputShape(nodes[0]->shape(), shape_);
}
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 90995ddf..faf21dee 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -1051,4 +1051,45 @@ protected:
PoolingWrapper pooling_;
};
+class PoolingWithMaskingOp : public UnaryNodeOp {
+ public:
+ PoolingWithMaskingOp(Expr x, Expr mask, int width, bool isEven=false)
+ : UnaryNodeOp(x),
+ mask_(mask),
+ width_(width),
+ isEven_(isEven)
+ {
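+ // Pool along the sentence axis (xShape[2]) in windows of `width`;
+ // if isEven, one trailing column is dropped first. The pooled
+ // length is ceil(cols / width).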
+ auto xShape = x->shape();
+ int dimBatch = xShape[0];
+ int dimWord = xShape[1];
+ int cols = (isEven_) ? xShape[2] - 1 : xShape[2];
+ int dimSentence = (cols / width_) + (cols % width_ != 0);
+ shape_ = {dimBatch, dimWord, dimSentence};
+ }
+
+ NodeOps forwardOps() {
+ return {NodeOp(PoolingWithMaskingForward(val_,
+ child(0)->val(),
+ mask_->val(),
+ width_,
+ isEven_))};
+ }
+
+ NodeOps backwardOps() {
+ return {NodeOp(PoolingWithMaskingBackward(adj_,
+ child(0)->grad(),
+ child(0)->val(),
+ mask_->val(),
+ width_,
+ isEven_))};
+ }
+
+ const std::string type() {return "layer_pooling";}
+
+ protected:
+ Expr mask_;
+ int width_;
+ bool isEven_;
+};
+
}
diff --git a/src/kernels/cudnn_wrappers.h b/src/kernels/cudnn_wrappers.h
index ce8f44b9..fca4b6e0 100644
--- a/src/kernels/cudnn_wrappers.h
+++ b/src/kernels/cudnn_wrappers.h
@@ -29,8 +29,8 @@ class ConvolutionWrapper : public CUDNNWrapper {
public:
ConvolutionWrapper(const Shape& kernelShape,
const Shape& biasShape,
- int hPad = 1,
- int wPad = 1,
+ int hPad = 0,
+ int wPad = 0,
int hStride = 1,
int wStride = 1);
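(The wrapper's default padding changes from 1 to 0 here, matching the new
Convolution factory's "paddings" default of (0, 0).)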
diff --git a/src/kernels/tensor_operators.cu b/src/kernels/tensor_operators.cu
index 271bc341..3c017cb3 100644
--- a/src/kernels/tensor_operators.cu
+++ b/src/kernels/tensor_operators.cu
@@ -2117,4 +2117,133 @@ void HighwayBackward(Tensor out1,
length);
}
+__global__ void gMaxPoolingForward(float* out,
+ int outRows,
+ int outCols,
+ float* in,
+ int inRows,
+ int inCols,
+ float* mask,
+ int numKernels,
+ int maskCols,
+ int width,
+ int lastWidth) {
+ int tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+ if (tid >= outRows * outCols) return;
+
+ int rowId = tid / outRows;
+ int colId = tid % outRows;
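+ // Thread-to-element mapping: despite the names, rowId ranges over the
+ // outCols batch*word rows and colId over the outRows pooled positions.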
+
+ float* b = in + (rowId * inCols) + (colId * width);
+
+ if (colId == outRows - 1) {
+ width = lastWidth;
+ }
+
+ float* localMask = mask + (rowId / numKernels) * maskCols + colId * width;
+ float currentMax = b[0] * localMask[0];
+ for (int i = 1; i < width; ++i) {
+ if (b[i] * localMask[i] > currentMax) {
+ currentMax = b[i] * localMask[i];
+ }
+ }
+
+ out[rowId + (colId * outCols)] = currentMax;
+}
+
+void PoolingWithMaskingForward(Tensor out,
+ Tensor in,
+ Tensor mask,
+ int width,
+ bool isEven) {
+ int n = out->shape().elements();
+ int threads = std::min(n, MAX_THREADS);
+ int blocks = n / threads + (n % threads != 0);
+
+ Shape& inShape = in->shape();
+ int inRows = inShape[0] * inShape[1];
+ int inCols = inShape[2];
+
+ Shape& outShape = out->shape();
+ int outRows = outShape[2];
+ int outCols = outShape[0] * outShape[1];
+
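+ // The final pooling window may be narrower than `width` when the
+ // (isEven-trimmed) input length is not a multiple of it.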
+ int lastWidth = ((inCols - isEven) % width == 0)
+ ? width
+ : (inCols - isEven) % width;
+
+ gMaxPoolingForward<<<blocks, threads>>>(
+ out->data(), outRows, outCols,
+ in->data(), inRows, inCols,
+ mask->data(), outShape[1], mask->shape()[2],
+ width, lastWidth);
+}
+
+__global__ void gMaxPoolingBackward(float* adj,
+ int adjRows,
+ int adjCols,
+ float* in,
+ float* adjIn,
+ int inRows,
+ int inCols,
+ float* mask,
+ int numKernels,
+ int maskCols,
+ int width,
+ int lastWidth)
+{
+ int tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+ if (tid >= adjRows * adjCols) return;
+
+ int rowId = tid / adjRows;
+ int colId = tid % adjRows;
+
+ float* b = in + (rowId * inCols) + (colId * width);
+
+ if (colId == adjRows - 1) {
+ width = lastWidth;
+ }
+
+ float* localMask = mask + (rowId / numKernels) * maskCols + colId * width;
+ size_t currentMaxIdx = 0;
+ for (int i = 1; i < width; ++i) {
+ if (b[i] * localMask[i] > b[currentMaxIdx] * localMask[currentMaxIdx]) {
+ currentMaxIdx = i;
+ }
+ }
+
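+ // Route the incoming gradient to the input position that held the max.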
+ adjIn[(rowId * inCols) + (colId * width) + currentMaxIdx] += adj[rowId + (colId * adjCols)];
+}
+
+void PoolingWithMaskingBackward(Tensor adj,
+ Tensor adjIn,
+ Tensor in,
+ Tensor mask,
+ int width,
+ bool isEven) {
+ int n = adj->shape().elements();
+ int threads = std::min(n, 512);
+ int blocks = n / threads + (n % threads != 0);
+
+ Shape& inShape = in->shape();
+ int inRows = inShape[0] * inShape[1];
+ int inCols = inShape[2];
+
+ Shape& adjShape = adj->shape();
+ int adjRows = adjShape[2];
+ int adjCols = adjShape[0] * adjShape[1];
+
+ int lastWidth = ((inCols - isEven) % width == 0)
+ ? width
+ : (inCols - isEven) % width;
+
+ gMaxPoolingBackward<<<blocks, threads>>>(
+ adj->data(), adjRows, adjCols,
+ in->data(), adjIn->data(), inRows, inCols,
+ mask->data(), adjShape[1], mask->shape()[2],
+ width, lastWidth);
+}
+
} // namespace marian
diff --git a/src/kernels/tensor_operators.h b/src/kernels/tensor_operators.h
index f9160cd1..06cb188c 100644
--- a/src/kernels/tensor_operators.h
+++ b/src/kernels/tensor_operators.h
@@ -383,4 +383,17 @@ void HighwayBackward(Tensor out1,
const Tensor in2,
const Tensor t,
const Tensor adj);
+
+void PoolingWithMaskingForward(Tensor out,
+ Tensor in,
+ Tensor mask,
+ int width,
+ bool isEven=false);
+
+void PoolingWithMaskingBackward(Tensor adj,
+ Tensor adjIn,
+ Tensor in,
+ Tensor mask,
+ int width,
+ bool isEven=false);
}
diff --git a/src/layers/convolution.cu b/src/layers/convolution.cu
new file mode 100644
index 00000000..958ff4b4
--- /dev/null
+++ b/src/layers/convolution.cu
@@ -0,0 +1,41 @@
+#include "layers/convolution.h"
+#include "graph/node_operators_binary.h"
+
+namespace marian {
+Convolution::Convolution(Ptr<ExpressionGraph> graph)
+ : Factory(graph) {}
+
+Expr Convolution::apply(Expr x) {
+ auto prefix = opt<std::string>("prefix");
+ auto kernelDims = opt<std::pair<int, int>>("kernel-dims");
+ auto kernelNum = opt<int>("kernel-num");
+ auto paddings = opt<std::pair<int, int>>("paddings", std::make_pair(0, 0));
+ auto strides = opt<std::pair<int, int>>("strides", std::make_pair(1, 1));
+
+ int layerIn = x->shape()[1];
+ auto kernel = graph_->param(prefix + "_conv_kernels",
+ {layerIn,
+ kernelNum,
+ kernelDims.first,
+ kernelDims.second},
+ keywords::init=inits::glorot_uniform);
+
+ auto bias = graph_->param(prefix + "_conv_bias",
+ {1, kernelNum, 1, 1},
+ keywords::init=inits::zeros);
+
+ std::vector<Expr> nodes = {x, kernel, bias};
+ return Expression<ConvolutionOp>(nodes,
+ paddings.first,
+ paddings.second,
+ strides.first,
+ strides.second);
+}
+
+Expr Convolution::apply(const std::vector<Expr>&) {
+ ABORT("Can't apply convolution on many inputs at once");
+ return nullptr;
+}
+
+}
+
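Usage note: besides "prefix", "kernel-dims" and "kernel-num", the factory
reads optional "paddings" and "strides" pairs (defaults (0, 0) and (1, 1),
height first). A minimal sketch, assuming x is already in the cudnn NCHW
layout so that x->shape()[1] is the input-channel dimension:

    auto conv = convolution(graph)
                ("prefix", "conv")
                ("kernel-dims", std::make_pair(3, 3))
                ("kernel-num", 64)
                ("paddings", std::make_pair(1, 1))  // (height, width)
                ("strides", std::make_pair(1, 1))
                .apply(x);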
diff --git a/src/layers/convolution.h b/src/layers/convolution.h
index 36af652d..1416da84 100644
--- a/src/layers/convolution.h
+++ b/src/layers/convolution.h
@@ -3,72 +3,33 @@
#include <string>
#include "layers/generic.h"
+#include "graph/expression_graph.h"
namespace marian {
-class Convolution {
- public:
- Convolution(
- const std::string& prefix,
- int kernelHeight = 3,
- int kernelWidth = 3,
- int kernelNum = 1,
- int paddingHeight = 0,
- int paddingWidth = 0,
- int strideHeight = 1,
- int strideWidth = 1)
- : prefix_(prefix),
- kernelHeight_(kernelHeight),
- kernelWidth_(kernelWidth),
- kernelNum_(kernelNum),
- strideHeight_(strideHeight),
- strideWidth_(strideWidth),
- paddingHeight_(paddingHeight),
- paddingWidth_(paddingWidth)
- {
- }
-
- Expr operator()(Expr x) {
- auto graph = x->graph();
-
- int layerIn = x->shape()[1];
+class Convolution : public Factory {
+protected:
+ Ptr<Options> getOptions() { return options_; }
- auto kernel = graph->param(prefix_ + "_conv_kernels",
- {layerIn, kernelNum_, kernelHeight_, kernelWidth_},
- keywords::init=inits::glorot_uniform);
- auto bias = graph->param(prefix_ + "_conv_bias", {1, kernelNum_, 1, 1},
- keywords::init=inits::zeros);
+public:
+ Convolution(Ptr<ExpressionGraph> graph);
- auto output = convolution(x, kernel, bias,
- paddingHeight_,
- paddingWidth_,
- strideHeight_,
- strideWidth_);
+ Expr apply(Expr x);
- return output;
- }
-
- protected:
- std::string prefix_;
- int depth_;
- int kernelHeight_;
- int kernelWidth_;
- int kernelNum_;
- int strideHeight_;
- int strideWidth_;
- int paddingHeight_;
- int paddingWidth_;
+ virtual Expr apply(const std::vector<Expr>&);
};
+typedef Accumulator<Convolution> convolution;
+
class CharConvPooling {
public:
CharConvPooling(
- const std::string& name,
+ const std::string& prefix,
int kernelHeight,
std::vector<int> kernelWidths,
std::vector<int> kernelNums)
- : ConvPoolingBase(name),
+ : name_(prefix),
size_(kernelNums.size()),
kernelHeight_(kernelHeight),
kernelWidths_(kernelWidths),
@@ -78,8 +39,8 @@ class CharConvPooling {
auto graph = x->graph();
auto masked = x * mask;
- auto xNCHW = convert2NCHW(masked);
- auto maskNCHW = convert2NCHW(mask);
+ auto xNCHW = convert2cudnnFormat(masked);
+ auto maskNCHW = convert2cudnnFormat(mask);
int layerIn = xNCHW->shape()[1];
Expr input = xNCHW;
@@ -87,18 +48,18 @@ class CharConvPooling {
for (int i = 0; i < size_; ++i) {
int kernelWidth = kernelWidths_[i];
- int kernelDim = kernelNums_[i];
+ int kernelNum = kernelNums_[i];
int padWidth = kernelWidth / 2;
- auto kernel = graph->param(name_ + std::to_string(i),
- {layerIn, kernelDim, kernelWidth, x->shape()[1]},
- keywords::init=inits::glorot_uniform);
- auto bias = graph->param(name_ + std::to_string(i) + "_bias", {1, kernelDim, 1, 1},
- keywords::init=inits::zeros);
- auto output = convolution(input, kernel, bias, padWidth, 0, 1, 1);
+ auto output = convolution(graph)
+ ("prefix", name_)
+ ("kernel-dims", std::make_pair(kernelWidth, x->shape()[-1]))
+ ("kernel-num", kernelNum)
+ ("paddings", std::make_pair(padWidth, 0))
+ .apply(input);
auto relued = relu(output);
- auto output2 = max_pooling2(relued, maskNCHW, 5, kernelWidth % 2 == 0);
+ auto output2 = pooling_with_masking(relued, maskNCHW, 5, kernelWidth % 2 == 0);
outputs.push_back(output2);
}