Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Gross <sgross@fb.com>2017-06-07 00:15:31 +0300
committerSoumith Chintala <soumith@gmail.com>2017-06-14 03:48:03 +0300
commit42c92bfe456981de02393df7836daeb23998f497 (patch)
tree100ae9254d84a9c4514569749eb59859512e032e
parent3d484ecc002a1876e577ba90d326d1b417f54c8d (diff)
Added GLU (gated linear unit)
From https://arxiv.org/abs/1612.08083
-rw-r--r--lib/THCUNN/generic/GatedLinearUnit.cu12
1 files changed, 7 insertions, 5 deletions
diff --git a/lib/THCUNN/generic/GatedLinearUnit.cu b/lib/THCUNN/generic/GatedLinearUnit.cu
index 0684878..f6f09bd 100644
--- a/lib/THCUNN/generic/GatedLinearUnit.cu
+++ b/lib/THCUNN/generic/GatedLinearUnit.cu
@@ -11,9 +11,10 @@ void THNN_(GatedLinear_updateOutput)(
THCUNN_assertSameGPU(state, 2, input, output);
// size output to half of input
- dim = dim - 1;
+ dim = dim - TH_INDEX_BASE;
const long nIn = THCTensor_(size)(state, input, dim);
- THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", dim+1, nIn);
+ THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld",
+ dim + TH_INDEX_BASE, nIn);
const long inputSize = THCTensor_(size)(state, input, dim) / 2;
THLongStorage *newSizes = THCTensor_(newSizeOf)(state, input);
THLongStorage_set(newSizes, dim, inputSize);
@@ -39,9 +40,10 @@ void THNN_(GatedLinear_updateGradInput)(
int dim)
{
THCUNN_assertSameGPU(state, 2, gradOutput, gradInput);
- dim = dim - 1;
+ dim = dim - TH_INDEX_BASE;
const long nIn = THCTensor_(size)(state, input, dim);
- THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", dim+1, nIn);
+ THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld",
+ dim + TH_INDEX_BASE, nIn);
THCTensor_(resizeAs)(state, gradInput, input);
const long inputSize = THCTensor_(size)(state, input, dim) / 2;
@@ -61,4 +63,4 @@ void THNN_(GatedLinear_updateGradInput)(
THCTensor_(free)(state, gradInputsecondHalf);
}
-#endif \ No newline at end of file
+#endif