Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAngela Fan <angelafan@fb.com>2017-01-27 09:56:33 +0300
committerAngela Fan <angelafan@fb.com>2017-02-03 08:38:25 +0300
commit32a0ada82bc266e36517c34406f13d5b67117b5c (patch)
tree9669aa75af5c635faf78e28b8845d1f476288242 /test.lua
parentb1370068e31f8d1fd5a8cadd91effe9ea77d5820 (diff)
cuda implementation of Gated Linear Unit, fixed issues with genericization
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua11
1 files changed, 11 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 14d072d..1fb1205 100644
--- a/test.lua
+++ b/test.lua
@@ -84,6 +84,7 @@ end
local function pointwise_forward(proto_module, name, max_error)
local size = math.random(1,100)
+ if name == 'GatedLinearUnit' then size = size*2 end
for k, typename in ipairs(typenames) do
local input = torch.randn(size):type(typename)
@@ -105,10 +106,12 @@ end
local function pointwise_backward(proto_module, name, max_error)
local size = math.random(1,100)
+ if name == 'GatedLinearUnit' then size = size*2 end
for k, typename in ipairs(typenames) do
local input = torch.randn(size):type(typename)
local gradOutput = torch.randn(size):type(typename)
+ if name == 'GatedLinearUnit' then gradOutput = torch.randn(size/2) end
local ctype = t2cpu[typename]
input = makeNonContiguous(input:type(ctype))
@@ -267,6 +270,14 @@ function cunntest.LogSigmoid_transposed()
pointwise_transposed(nn.LogSigmoid(), 'LogSigmoid', 1e-6)
end
+function cunntest.GatedLinearUnit_forward()
+ pointwise_forward(nn.GatedLinearUnit(), 'GatedLinearUnit', precision_forward)
+end
+
+function cunntest.GatedLinearUnit_backward()
+ pointwise_backward(nn.GatedLinearUnit(), 'GatedLinearUnit', precision_backward)
+end
+
function cunntest.Threshold_forward()
pointwise_forward(nn.Threshold(), 'Threshold', precision_forward)
pointwise_forward(nn.Threshold(nil, nil, true), 'Threshold_inplace', precision_forward)