
github.com/torch/cunn.git
author     Soumith Chintala <soumith@gmail.com>  2017-02-02 08:20:47 +0300
committer  GitHub <noreply@github.com>           2017-02-02 08:20:47 +0300
commit     dc71d25df659c3f1eb6f7cc6e9c67213c9a166ee (patch)
tree       6bfdecf012829b34a32fb4d68db035bd8cb19cab
parent     2a8d3962d326a3cece411746495209e1a60bc6c8 (diff)

Revert "cuda implementation of Gated Linear Unit, fixed issues with genericization" (revert-430-newCudaGLU)
-rw-r--r--  lib/THCUNN/GatedLinearUnit.cu          29
-rw-r--r--  lib/THCUNN/generic/GatedLinearUnit.cu  64
-rw-r--r--  lib/THCUNN/generic/THCUNN.h            13
-rw-r--r--  test.lua                               11
4 files changed, 0 insertions, 117 deletions
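
For context, the reverted code implements the Gated Linear Unit (Dauphin et al., 2016): the input is split into halves x1 and x2 along a chosen dimension and the output is x1 * sigmoid(x2), so the gated dimension shrinks by half. A minimal standalone CUDA sketch of that forward rule, independent of the THCUNN machinery in the diff below (kernel and variable names are illustrative, not from this repository):

// Hedged sketch: out[i] = a[i] * sigmoid(b[i]), where a and b are the two
// halves of the gated dimension, each of length n.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void gluForward(float *out, const float *a, const float *b, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float sig = 1.0f / (1.0f + expf(-b[i]));
    out[i] = a[i] * sig;
  }
}

int main()
{
  const int n = 4;
  float hA[n] = {1.f, 2.f, 3.f, 4.f};
  float hB[n] = {0.f, 0.f, 0.f, 0.f};   // sigmoid(0) = 0.5
  float hOut[n];
  float *dA, *dB, *dOut;
  cudaMalloc(&dA, n * sizeof(float));
  cudaMalloc(&dB, n * sizeof(float));
  cudaMalloc(&dOut, n * sizeof(float));
  cudaMemcpy(dA, hA, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB, n * sizeof(float), cudaMemcpyHostToDevice);
  gluForward<<<1, 32>>>(dOut, dA, dB, n);
  cudaMemcpy(hOut, dOut, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i)
    printf("out[%d] = %f\n", i, hOut[i]);   // expect a[i] * 0.5
  cudaFree(dA); cudaFree(dB); cudaFree(dOut);
  return 0;
}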
diff --git a/lib/THCUNN/GatedLinearUnit.cu b/lib/THCUNN/GatedLinearUnit.cu
deleted file mode 100644
index d9f3471..0000000
--- a/lib/THCUNN/GatedLinearUnit.cu
+++ /dev/null
@@ -1,29 +0,0 @@
-#include "THCUNN.h"
-#include "THCHalf.h"
-#include "THCHalfAutoNumerics.cuh"
-#include "common.h"
-
-template <typename Dtype, typename Acctype>
-struct gatedLinearCSigMul_functor
-{
- __device__ void operator()(Dtype *target, const Dtype *sigTensor, const Dtype *mulTensor) const
- {
- const Acctype sigNum = Acctype(1)/(Acctype(1)+ exp(ScalarConvert<Dtype, Acctype>::to(-*sigTensor)));
- const Dtype mulNum = *mulTensor;
- *target = ScalarConvert<Acctype, Dtype>::to(sigNum * mulNum);
- }
-};
-
-template <typename Dtype, typename Acctype>
-struct gatedLinearDerivativeSecondHalf_functor
-{
- __device__ void operator()(Dtype *target, const Dtype *sigTensor, const Dtype *mulTensor) const
- {
- const Acctype sigNum = Acctype(1)/(Acctype(1)+ exp(ScalarConvert<Dtype, Acctype>::to(-*sigTensor)));
- const Dtype mulNum = *mulTensor;
- *target *= ScalarConvert<Acctype, Dtype>::to((Acctype(1) - sigNum) * sigNum * mulNum);
- }
-};
-
-#include "generic/GatedLinearUnit.cu"
-#include "THCGenerateFloatTypes.h"
\ No newline at end of file
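
The two functors above carry all of the arithmetic. gatedLinearCSigMul_functor computes target = sigmoid(s) * m; the forward pass calls it with s = x2 and m = x1, and the backward pass reuses it with s = x2 and m = gradOutput for the first-half gradient. gatedLinearDerivativeSecondHalf_functor multiplies its target in place by sigmoid(s) * (1 - sigmoid(s)) * m, the chain-rule factor for the gate, since d/dx2 [x1 * sigmoid(x2)] = x1 * sigmoid(x2) * (1 - sigmoid(x2)). A fused single-kernel sketch of that backward math, under the same assumptions as the forward sketch above (hypothetical names, not the THCUNN pointwise-apply path):

// Hedged sketch: given upstream gradOut and the input halves a, b (length n),
//   dA[i] = gradOut[i] * sigmoid(b[i])                               (first half)
//   dB[i] = gradOut[i] * a[i] * sigmoid(b[i]) * (1 - sigmoid(b[i]))  (gate half)
__global__ void gluBackward(float *dA, float *dB, const float *gradOut,
                            const float *a, const float *b, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float sig = 1.0f / (1.0f + expf(-b[i]));
    dA[i] = gradOut[i] * sig;
    dB[i] = gradOut[i] * a[i] * sig * (1.0f - sig);
  }
}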
diff --git a/lib/THCUNN/generic/GatedLinearUnit.cu b/lib/THCUNN/generic/GatedLinearUnit.cu
deleted file mode 100644
index 0684878..0000000
--- a/lib/THCUNN/generic/GatedLinearUnit.cu
+++ /dev/null
@@ -1,64 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "generic/GatedLinearUnit.cu"
-#else
-
-void THNN_(GatedLinear_updateOutput)(
- THCState *state,
- THCTensor *input,
- THCTensor *output,
- int dim)
-{
- THCUNN_assertSameGPU(state, 2, input, output);
-
- // size output to half of input
- dim = dim - 1;
- const long nIn = THCTensor_(size)(state, input, dim);
- THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", dim+1, nIn);
- const long inputSize = THCTensor_(size)(state, input, dim) / 2;
- THLongStorage *newSizes = THCTensor_(newSizeOf)(state, input);
- THLongStorage_set(newSizes, dim, inputSize);
- THCTensor_(resize)(state, output, newSizes, NULL);
-
- // halve tensor
- THCTensor *firstHalf = THCTensor_(newNarrow)(state, input, dim, 0, inputSize);
- THCTensor *secondHalf = THCTensor_(newNarrow)(state, input, dim, inputSize, inputSize);
-
- // x = x1:cmul( sigmoid(x2) )
- THC_pointwiseApply3(state, output, secondHalf, firstHalf, gatedLinearCSigMul_functor<real, accreal>());
-
- THLongStorage_free(newSizes);
- THCTensor_(free)(state, firstHalf);
- THCTensor_(free)(state, secondHalf);
-}
-
-void THNN_(GatedLinear_updateGradInput)(
- THCState *state,
- THCTensor *input,
- THCTensor *gradOutput,
- THCTensor *gradInput,
- int dim)
-{
- THCUNN_assertSameGPU(state, 2, gradOutput, gradInput);
- dim = dim - 1;
- const long nIn = THCTensor_(size)(state, input, dim);
- THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", dim+1, nIn);
-
- THCTensor_(resizeAs)(state, gradInput, input);
- const long inputSize = THCTensor_(size)(state, input, dim) / 2;
- THCTensor *firstHalf = THCTensor_(newNarrow)(state, input, dim, 0, inputSize);
- THCTensor *secondHalf = THCTensor_(newNarrow)(state, input, dim, inputSize, inputSize);
- THCTensor *gradInputfirstHalf = THCTensor_(newNarrow)(state, gradInput, dim, 0, inputSize);
- THCTensor *gradInputsecondHalf = THCTensor_(newNarrow)(state, gradInput, dim, inputSize, inputSize);
- // first half of derivative
- THC_pointwiseApply3(state, gradInputfirstHalf, secondHalf, gradOutput, gatedLinearCSigMul_functor<real, accreal>());
- // second half of derivative
- THCTensor_(copy)(state, gradInputsecondHalf, firstHalf);
- THC_pointwiseApply3(state, gradInputsecondHalf, secondHalf, gradOutput, gatedLinearDerivativeSecondHalf_functor<real, accreal>());
-
- THCTensor_(free)(state, firstHalf);
- THCTensor_(free)(state, secondHalf);
- THCTensor_(free)(state, gradInputfirstHalf);
- THCTensor_(free)(state, gradInputsecondHalf);
-}
-
-#endif
\ No newline at end of file
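
Host-side, the generic file above does the bookkeeping: dim arrives 1-indexed from Lua and is decremented, the gated dimension is checked to be even, THCTensor_(newNarrow) creates zero-copy views of the two halves, and THC_pointwiseApply3 walks three tensors elementwise. Assuming a contiguous tensor whose gated dimension is innermost, that narrowing reduces to plain row offsets; a hedged sketch of this special case (names illustrative):

// Each row of in has length 2*m: first half a, second half b. out has rows
// of length m. The explicit row offset stands in for the newNarrow views.
__global__ void gluForwardRows(float *out, const float *in, int rows, int m)
{
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < rows * m) {
    int r = idx / m, c = idx % m;
    const float *row = in + (size_t)r * 2 * m;
    float sig = 1.0f / (1.0f + expf(-row[m + c]));  // sigmoid of half b
    out[idx] = row[c] * sig;                        // a * sigmoid(b)
  }
}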
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index de05294..ec3d287 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -138,19 +138,6 @@ TH_API void THNN_(HardTanh_updateGradInput)(
real max_val,
bool inplace);

-TH_API void THNN_(GatedLinear_updateOutput)(
- THCState *state,
- THCudaTensor *input,
- THCudaTensor *output,
- int dim);
-
-TH_API void THNN_(GatedLinear_updateGradInput)(
- THCState *state,
- THCudaTensor *input,
- THCudaTensor *gradOutput,
- THCudaTensor *gradInput,
- int dim);
-
TH_API void THNN_(LeakyReLU_updateOutput)(
THCState *state,
THCTensor *input,
diff --git a/test.lua b/test.lua
index 1fb1205..14d072d 100644
--- a/test.lua
+++ b/test.lua
@@ -84,7 +84,6 @@ end
local function pointwise_forward(proto_module, name, max_error)
local size = math.random(1,100)
- if name == 'GatedLinearUnit' then size = size*2 end

for k, typename in ipairs(typenames) do
local input = torch.randn(size):type(typename)
@@ -106,12 +105,10 @@ end
local function pointwise_backward(proto_module, name, max_error)
local size = math.random(1,100)
- if name == 'GatedLinearUnit' then size = size*2 end

for k, typename in ipairs(typenames) do
local input = torch.randn(size):type(typename)
local gradOutput = torch.randn(size):type(typename)
- if name == 'GatedLinearUnit' then gradOutput = torch.randn(size/2) end

local ctype = t2cpu[typename]
input = makeNonContiguous(input:type(ctype))
@@ -270,14 +267,6 @@ function cunntest.LogSigmoid_transposed()
pointwise_transposed(nn.LogSigmoid(), 'LogSigmoid', 1e-6)
end

-function cunntest.GatedLinearUnit_forward()
- pointwise_forward(nn.GatedLinearUnit(), 'GatedLinearUnit', precision_forward)
-end
-
-function cunntest.GatedLinearUnit_backward()
- pointwise_backward(nn.GatedLinearUnit(), 'GatedLinearUnit', precision_backward)
-end
-
function cunntest.Threshold_forward()
pointwise_forward(nn.Threshold(), 'Threshold', precision_forward)
pointwise_forward(nn.Threshold(nil, nil, true), 'Threshold_inplace', precision_forward)
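
The test.lua changes follow from the shape contract: because GLU halves the gated dimension, the harness had to double the random input size and draw gradOutput at half that size, and with the module reverted those special cases go too. The property the removed backward test exercised, analytic gradient versus a numerical estimate, can be sanity-checked in isolation; a hedged host-only sketch of that check for a single element (illustrative, not the cunn harness):

#include <cstdio>
#include <cmath>

int main()
{
  const float a = 1.3f, b = -0.7f, eps = 1e-3f;
  auto glu = [](float x1, float x2) { return x1 / (1.0f + std::exp(-x2)); };

  float sig      = 1.0f / (1.0f + std::exp(-b));
  float analytic = a * sig * (1.0f - sig);  // d/db of a * sigmoid(b)
  float numeric  = (glu(a, b + eps) - glu(a, b - eps)) / (2.0f * eps);

  std::printf("analytic %f  numeric %f\n", analytic, numeric);
  return std::fabs(analytic - numeric) < 1e-4f ? 0 : 1;  // 0 on agreement
}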