
github.com/marian-nmt/marian.git
author    Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2018-06-27 23:04:44 +0300
committer Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2018-06-27 23:04:44 +0300
commit    c6350c666f293c64f3745d20d9cb9796eed849c8 (patch)
tree      d39a93fdcf734cc7abcb884992cab29702af998f
parent    d1d0df98d2b0df859d06fb38ba376a5a717c3bfc (diff)
fix transpose operator
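
This patch threads a beta accumulation factor through TransposeND on both
backends: beta == 0 overwrites the output (forward pass), while a non-zero
beta scales and adds the existing contents, following the BLAS-like
convention out = transpose(in) + beta * out. The backward ops now pass 1.f
so gradients accumulate instead of overwriting contributions already
deposited by other consumers of the same node. A minimal standalone sketch
of that convention (plain C++, not Marian code; shapes and values are
illustrative):

#include <cstdio>
#include <vector>

// out = transpose(in) + beta * out for a rows x cols row-major matrix `in`.
void transpose10(std::vector<float>& out, const std::vector<float>& in,
                 int rows, int cols, float beta) {
  for(int r = 0; r < rows; ++r)
    for(int c = 0; c < cols; ++c)
      out[c * rows + r] = in[r * cols + c] + beta * out[c * rows + r];
}

int main() {
  std::vector<float> in = {1, 2, 3, 4, 5, 6}; // 2x3 input
  std::vector<float> grad(6, 10.f);           // pre-existing 3x2 gradient
  transpose10(grad, in, 2, 3, 1.f);           // accumulate, as in backwardOps
  for(float v : grad)
    std::printf("%.0f ", v);                  // prints: 11 14 12 15 13 16
  std::printf("\n");
}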
-rw-r--r--  src/graph/expression_operators.h      |  9
-rw-r--r--  src/graph/node_operators_unary.h      |  4
-rw-r--r--  src/tensors/cpu/tensor_operators.cpp  | 26
-rw-r--r--  src/tensors/gpu/tensor_operators.cu   | 17
-rw-r--r--  src/tensors/tensor_operators.h        |  2

5 files changed, 38 insertions, 20 deletions
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 036445ee..c8c04763 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -34,6 +34,7 @@ Expr log(Expr a);
Expr exp(Expr a);
+// check
Expr clip(Expr a, float c);
Expr operator-(Expr a);
@@ -79,10 +80,14 @@ Expr affine(Expr a,
bool transB = false,
float scalar = 1.f);
+// check
Expr transpose(Expr a);
+// check
Expr transpose(Expr a, const std::vector<int>& axes);
+// check
Expr concatenate(const std::vector<Expr>& concats, keywords::axis_k ax = 0);
+// check
Expr repeat(Expr a, size_t repeats, keywords::axis_k ax = 0);
Expr reshape(Expr a, Shape shape);
@@ -96,9 +101,11 @@ Expr atleast_nd(Expr a, size_t dims);
Expr flatten(Expr a);
Expr flatten_2d(Expr a);
+// check
Expr rows(Expr a, const std::vector<size_t>& indices);
+// check
Expr cols(Expr a, const std::vector<size_t>& indices);
-
+// check
Expr select(Expr a, int axis, const std::vector<size_t>& indices);
/*********************************************************/
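
The header changes above only add "// check" markers (apparently the
author's own review notes); none of the declared signatures change. For
context, a hypothetical use of the declarations touched here (expressions
x and y are assumed to already exist in a graph; this is illustrative and
not part of the patch):

Expr t = transpose(x, {0, 2, 1, 3});               // permute axes
Expr r = rows(x, {0, 2, 4});                       // gather rows 0, 2 and 4
Expr c = concatenate({x, y}, keywords::axis = -1); // join along last axis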
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index bb0b66f4..0fc17d28 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -806,11 +806,11 @@ struct TransposeNodeOp : public UnaryNodeOp {
: UnaryNodeOp(a, newShape(a, axes)), axes_{axes} {}
NodeOps forwardOps() {
- return {NodeOp(TransposeND(val_, child(0)->val(), axes_))};
+ return {NodeOp(TransposeND(val_, child(0)->val(), axes_, 0.f))};
}
NodeOps backwardOps() {
- return {NodeOp(TransposeND(child(0)->grad(), adj_, axes_))};
+ return {NodeOp(TransposeND(child(0)->grad(), adj_, axes_, 1.f))};
}
template <class... Args>
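
The asymmetry above is the heart of the fix: forwardOps() may overwrite
val_ (beta = 0.f), but backwardOps() writes into child(0)->grad(), a buffer
that other consumers of the same child also write into, so it must
accumulate (beta = 1.f). A commented sketch of the failure mode the old
overwrite caused (not Marian code; names are illustrative):

// y = f(transpose(x)) + g(transpose(x))
// backward visits both consumers and deposits into the same grad(x):
//   grad(x)  = transpose(adj_f)   // old code: overwrite
//   grad(x)  = transpose(adj_g)   // old code: first contribution lost!
// with beta = 1 the second write becomes
//   grad(x) += transpose(adj_g)   // contributions now sum correctly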
diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp
index 74748c6d..cba03c12 100644
--- a/src/tensors/cpu/tensor_operators.cpp
+++ b/src/tensors/cpu/tensor_operators.cpp
@@ -124,7 +124,7 @@ void Deconcatenate(std::vector<Tensor>& outputs, const Tensor in, int ax) {
SplitCont(outputs, in, ax);
}
-void Transpose0213(Tensor out, Tensor in) {
+void Transpose0213(Tensor out, Tensor in, float beta) {
int cols = in->shape()[-1];
int rows = in->shape().elements() / in->shape()[-1];
@@ -141,7 +141,15 @@ void Transpose0213(Tensor out, Tensor in) {
const float* inRow = in->data() + src * cols ;
float* outRow = out->data() + dst * cols;
- std::copy(inRow, inRow + cols, outRow);
+ if(beta == 0) {
+   // mostly for fast forward computation
+   std::copy(inRow, inRow + cols, outRow);
+ }
+ else {
+   for(int i = 0; i < cols; ++i) {
+     outRow[i] = inRow[i] + beta * outRow[i];
+   }
+ }
}
}
}
@@ -186,7 +194,7 @@ void Transpose10(Tensor out, const Tensor in) {
}
// @TODO: optimize this, currently it's quite horrible
-void TransposeGeneric(Tensor out, Tensor in, const std::vector<int>& vAxis) {
+void TransposeGeneric(Tensor out, Tensor in, const std::vector<int>& vAxis, float beta) {
functional::Array<int, functional::Shape::size()> permute;
int diff = functional::Shape::size() - vAxis.size();
for(int i = 0; i < permute.size(); ++i)
@@ -207,19 +215,19 @@ void TransposeGeneric(Tensor out, Tensor in, const std::vector<int>& vAxis) {
gOut.shape().dims(index, oDims);
for(int i = 0; i < N; ++i)
pDims[permute[i]] = oDims[i];
- gOut[index] = gIn[pDims];
+ gOut[index] = gIn[pDims] + beta * gOut[index];
}
}
-void TransposeND(Tensor out, Tensor in, const std::vector<int>& vAxis) {
+void TransposeND(Tensor out, Tensor in, const std::vector<int>& vAxis, float beta) {
if(vAxis == std::vector<int>({0, 2, 1, 3}))
- Transpose0213(out, in);
- else if(vAxis == std::vector<int>({1, 0})
- && in->shape()[-1] % 16 == 0
+ Transpose0213(out, in, beta);
+ else if(vAxis == std::vector<int>({1, 0}) && beta == 0
+ && in->shape()[-1] % 16 == 0
&& in->shape()[-2] % 16 == 0)
Transpose10(out, in);
else
- TransposeGeneric(out, in, vAxis);
+ TransposeGeneric(out, in, vAxis, beta);
}
void Softmax(Tensor out_, Tensor in_, Tensor mask_) {
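
Note the dispatch logic in the CPU TransposeND above: Transpose0213 learned
the beta parameter, but the blocked Transpose10 fast path (which requires
both trailing dimensions to be multiples of 16) did not, so it is now taken
only when beta == 0; every other case falls through to TransposeGeneric.
A standalone sketch of the permuted copy TransposeGeneric performs, with the
beta term added by this patch (plain C++, fixed at 4 axes for brevity; not
Marian code):

#include <array>
#include <vector>

// perm[i] names the input axis that feeds output axis i, matching the
// patch's pDims[permute[i]] = oDims[i].
void transposeGeneric(std::vector<float>& out, const std::vector<float>& in,
                      const std::array<int, 4>& oShape,
                      const std::array<int, 4>& perm, float beta) {
  std::array<int, 4> iShape;
  for(int i = 0; i < 4; ++i)
    iShape[perm[i]] = oShape[i];
  const std::array<int, 4> iStride = {iShape[1] * iShape[2] * iShape[3],
                                      iShape[2] * iShape[3], iShape[3], 1};
  int idx = 0;
  for(int a = 0; a < oShape[0]; ++a)
    for(int b = 0; b < oShape[1]; ++b)
      for(int c = 0; c < oShape[2]; ++c)
        for(int d = 0; d < oShape[3]; ++d, ++idx) {
          const std::array<int, 4> o = {a, b, c, d};
          int iIdx = 0;
          for(int i = 0; i < 4; ++i)
            iIdx += o[i] * iStride[perm[i]];
          out[idx] = in[iIdx] + beta * out[idx]; // beta == 0: plain overwrite
        }
}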
diff --git a/src/tensors/gpu/tensor_operators.cu b/src/tensors/gpu/tensor_operators.cu
index 06f44d3d..70cdd900 100644
--- a/src/tensors/gpu/tensor_operators.cu
+++ b/src/tensors/gpu/tensor_operators.cu
@@ -235,7 +235,8 @@ void Deconcatenate(std::vector<Tensor>& outputs, const Tensor in, int ax) {
__global__ void gTransposeND(
functional::Tensor<float> out,
const functional::Tensor<float> in,
- const functional::Array<int, functional::Shape::size()> permute) {
+ const functional::Array<int, functional::Shape::size()> permute,
+ float beta) {
constexpr size_t N = functional::Shape::size();
functional::Array<int, N> oDims;
functional::Array<int, N> pDims;
@@ -247,7 +248,7 @@ __global__ void gTransposeND(
out.shape().dims(index, oDims);
for(int i = 0; i < N; ++i)
pDims[permute[i]] = oDims[i];
- out[index] = in[pDims];
+ out[index] = in[pDims] + beta * out[index];
}
}
}
@@ -257,7 +258,8 @@ void gTranspose0213(float* out, const float* in,
int rows,
int cols,
int stride1,
- int stride2) {
+ int stride2,
+ float beta) {
int stride = stride1 * stride2;
for(int bid = 0; bid < rows; bid += gridDim.x) {
@@ -275,14 +277,14 @@ void gTranspose0213(float* out, const float* in,
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols)
- rowOut[i] = rowIn[i];
+ rowOut[i] = rowIn[i] + beta * rowOut[i];
}
}
}
}
-void TransposeND(Tensor out, Tensor in, const std::vector<int>& vAxis) {
+void TransposeND(Tensor out, Tensor in, const std::vector<int>& vAxis, float beta) {
cudaSetDevice(out->getDevice().no);
if(vAxis == std::vector<int>({0, 2, 1, 3})) {
@@ -295,7 +297,8 @@ void TransposeND(Tensor out, Tensor in, const std::vector<int>& vAxis) {
int stride1 = out->shape()[-2];
int stride2 = out->shape()[-3];
- gTranspose0213<<<blocks, threads>>>(out->data(), in->data(), rows, cols, stride1, stride2);
+ gTranspose0213<<<blocks, threads>>>(out->data(), in->data(),
+ rows, cols, stride1, stride2, beta);
}
else {
@@ -311,7 +314,7 @@ void TransposeND(Tensor out, Tensor in, const std::vector<int>& vAxis) {
int threads = std::min(MAX_THREADS, length);
int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
- gTransposeND<<<blocks, threads>>>(out, in, axes);
+ gTransposeND<<<blocks, threads>>>(out, in, axes, beta);
}
}
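
On the GPU side the beta term is simply threaded through the kernel
arguments, and each thread applies out[index] = in[pDims] + beta *
out[index] elementwise, so no extra synchronization is needed. The block
count uses integer ceiling division, an idiom worth noting (sketch, assuming
the MAX_THREADS/MAX_BLOCKS caps from the patch):

// blocks = ceil(length / threads) without floating point:
int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
// e.g. length = 1000, threads = 512  ->  1 + 1 = 2 blocks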
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index 44d2e193..b7422e0c 100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -78,7 +78,7 @@ void Reduce(Functor functor, marian::Tensor out, Tensors... tensors) {
DISPATCH3(CrossEntropyPick, marian::Tensor, marian::Tensor, marian::Tensor)
DISPATCH4(CrossEntropyPickBackward, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor)
- DISPATCH3(TransposeND, marian::Tensor, marian::Tensor, const std::vector<int>&)
+ DISPATCH4(TransposeND, marian::Tensor, marian::Tensor, const std::vector<int>&, float)
DISPATCH4(Shift, marian::Tensor, marian::Tensor, marian::Shape, bool)
DISPATCH3(Concatenate, marian::Tensor, const std::vector<marian::Tensor>&, int)
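
The DISPATCH3 -> DISPATCH4 bump mirrors the new float parameter in the
TransposeND signature. These macros generate a device-dispatching wrapper;
a hypothetical expansion is sketched below (the real macro is defined
elsewhere in tensor_operators.h and its device test may differ; the
getDevice().type field and DeviceType enum are assumptions here):

#define DISPATCH4(Function, Arg1, Arg2, Arg3, Arg4)                 \
  namespace gpu { void Function(Arg1, Arg2, Arg3, Arg4); }          \
  namespace cpu { void Function(Arg1, Arg2, Arg3, Arg4); }          \
  static inline void Function(Arg1 a1, Arg2 a2, Arg3 a3, Arg4 a4) { \
    if(a1->getDevice().type == DeviceType::gpu) /* assumed field */ \
      gpu::Function(a1, a2, a3, a4);                                \
    else                                                            \
      cpu::Function(a1, a2, a3, a4);                                \
  }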