Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJin-Hwa Kim <jnhwkim@gmail.com>2017-02-12 11:20:02 +0300
committerSoumith Chintala <soumith@gmail.com>2017-02-12 11:20:02 +0300
commitec06db00ac38d8210e25b4773d792c86b2c2a64a (patch)
tree13b36aa81d4ed94f58afa9124be74674355f3135 /test.lua
parent9bb440d1b95a28532a1e327a8976b00540294013 (diff)
Support Tensor constant (#1129)
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua9
1 files changed, 9 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index e5c92ab..39fc7f3 100644
--- a/test.lua
+++ b/test.lua
@@ -5585,6 +5585,15 @@ function nntest.AddConstant()
local err = (input1-input2):abs():max()
mytester:asserteq(err, 0, torch.typename(module1) ..
' - inplace input change err ')
+
+ local module3 = nn.AddConstant(torch.Tensor{1,2,3})
+ local out3 = module3:forward(torch.Tensor{-1,-2,-3})
+ mytester:asserteq(0, out3:abs():max(), torch.typename(module3) ..
+ ' - tensor constant forward err ')
+ local module4 = nn.AddConstant(torch.Tensor{1,2,3})
+ local out4 = module3:forward(torch.Tensor{{-1,-2,-3},{-1,-2,-3}})
+ mytester:asserteq(0, out4:abs():max(), torch.typename(module4) ..
+ ' - batch tensor constant forward err ')
end
function nntest.MulConstant()