diff options
author | Angela Fan <angelafan@fb.com> | 2017-01-27 09:56:33 +0300 |
---|---|---|
committer | Angela Fan <angelafan@fb.com> | 2017-02-03 08:38:25 +0300 |
commit | 32a0ada82bc266e36517c34406f13d5b67117b5c (patch) | |
tree | 9669aa75af5c635faf78e28b8845d1f476288242 /test.lua | |
parent | b1370068e31f8d1fd5a8cadd91effe9ea77d5820 (diff) |
cuda implementation of Gated Linear Unit, fixed issues with genericization
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 11 |
1 files changed, 11 insertions, 0 deletions
@@ -84,6 +84,7 @@ end

 local function pointwise_forward(proto_module, name, max_error)
    local size = math.random(1,100)
+   if name == 'GatedLinearUnit' then size = size*2 end

    for k, typename in ipairs(typenames) do
       local input = torch.randn(size):type(typename)
@@ -105,10 +106,12 @@ end

 local function pointwise_backward(proto_module, name, max_error)
    local size = math.random(1,100)
+   if name == 'GatedLinearUnit' then size = size*2 end

    for k, typename in ipairs(typenames) do
       local input = torch.randn(size):type(typename)
       local gradOutput = torch.randn(size):type(typename)
+      if name == 'GatedLinearUnit' then gradOutput = torch.randn(size/2) end

       local ctype = t2cpu[typename]
       input = makeNonContiguous(input:type(ctype))
@@ -267,6 +270,14 @@ function cunntest.LogSigmoid_transposed()
    pointwise_transposed(nn.LogSigmoid(), 'LogSigmoid', 1e-6)
 end

+function cunntest.GatedLinearUnit_forward()
+   pointwise_forward(nn.GatedLinearUnit(), 'GatedLinearUnit', precision_forward)
+end
+
+function cunntest.GatedLinearUnit_backward()
+   pointwise_backward(nn.GatedLinearUnit(), 'GatedLinearUnit', precision_backward)
+end
+
 function cunntest.Threshold_forward()
    pointwise_forward(nn.Threshold(), 'Threshold', precision_forward)
    pointwise_forward(nn.Threshold(nil, nil, true), 'Threshold_inplace', precision_forward)