Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2017-02-02 05:46:35 +0300
committerGitHub <noreply@github.com>2017-02-02 05:46:35 +0300
commit2a8d3962d326a3cece411746495209e1a60bc6c8 (patch)
tree65fa5f93aba381423ea35e871f665eba7a140316 /test.lua
parentfc0dc59a434eb4d212891bf21fcee52e014c4b09 (diff)
parentd253af054affc75d42efc6c33bae5e8673f8a563 (diff)
Merge pull request #430 from huihuifan/newCudaGLU
cuda implementation of Gated Linear Unit, fixed issues with genericization
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua11
1 files changed, 11 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 14d072d..1fb1205 100644
--- a/test.lua
+++ b/test.lua
@@ -84,6 +84,7 @@ end
local function pointwise_forward(proto_module, name, max_error)
local size = math.random(1,100)
+ if name == 'GatedLinearUnit' then size = size*2 end
for k, typename in ipairs(typenames) do
local input = torch.randn(size):type(typename)
@@ -105,10 +106,12 @@ end
local function pointwise_backward(proto_module, name, max_error)
local size = math.random(1,100)
+ if name == 'GatedLinearUnit' then size = size*2 end
for k, typename in ipairs(typenames) do
local input = torch.randn(size):type(typename)
local gradOutput = torch.randn(size):type(typename)
+ if name == 'GatedLinearUnit' then gradOutput = torch.randn(size/2) end
local ctype = t2cpu[typename]
input = makeNonContiguous(input:type(ctype))
@@ -267,6 +270,14 @@ function cunntest.LogSigmoid_transposed()
pointwise_transposed(nn.LogSigmoid(), 'LogSigmoid', 1e-6)
end
+function cunntest.GatedLinearUnit_forward()
+ pointwise_forward(nn.GatedLinearUnit(), 'GatedLinearUnit', precision_forward)
+end
+
+function cunntest.GatedLinearUnit_backward()
+ pointwise_backward(nn.GatedLinearUnit(), 'GatedLinearUnit', precision_backward)
+end
+
function cunntest.Threshold_forward()
pointwise_forward(nn.Threshold(), 'Threshold', precision_forward)
pointwise_forward(nn.Threshold(nil, nil, true), 'Threshold_inplace', precision_forward)