Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Gross <sgross@fb.com>2017-06-07 00:15:31 +0300
committerSoumith Chintala <soumith@gmail.com>2017-06-14 03:47:55 +0300
commitd1a89001f1b71885c692e3a9ecbfbc9ede545cc9 (patch)
tree4470d9516da87a2cf0cca8cb7f86333a381543b2
parentd578b37dcb62f8df4bcc0a3e004b81eaefc6e21f (diff)
Added GLU (gated linear unit)
From https://arxiv.org/abs/1612.08083
-rw-r--r--lib/THNN/generic/GatedLinearUnit.c10
1 files changed, 6 insertions, 4 deletions
diff --git a/lib/THNN/generic/GatedLinearUnit.c b/lib/THNN/generic/GatedLinearUnit.c
index d412a7b..274a27e 100644
--- a/lib/THNN/generic/GatedLinearUnit.c
+++ b/lib/THNN/generic/GatedLinearUnit.c
@@ -9,9 +9,10 @@ void THNN_(GatedLinear_updateOutput)(
int dim)
{
// size output to half of input
- dim = dim - 1;
+ dim = dim - TH_INDEX_BASE;
const long nIn = THTensor_(size)(input, dim);
- THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", dim+1, nIn);
+ THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld",
+ dim + TH_INDEX_BASE, nIn);
const long inputSize = THTensor_(size)(input, dim) / 2;
THLongStorage *newSizes = THTensor_(newSizeOf)(input);
@@ -39,9 +40,10 @@ void THNN_(GatedLinear_updateGradInput)(
int dim)
{
// set up tensors
- dim = dim - 1;
+ dim = dim - TH_INDEX_BASE;
const long nIn = THTensor_(size)(input, dim);
- THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", dim+1, nIn);
+ THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld",
+ dim + TH_INDEX_BASE, nIn);
THTensor_(resizeAs)(gradInput, input);
const long inputSize = THTensor_(size)(input, dim) / 2;