Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2014-04-10 00:49:43 +0400
committerNicholas Leonard <nick@nikopia.org>2014-04-10 00:49:43 +0400
commit6783768eb6e06ebbf0c1bd2d6b5f4bc9f709b32f (patch)
treee78ffe0e0f7651ef6767a3de09443e9a19ffd474 /LookupTable.lua
parent2781f4bee6c6f725800c863c3c2e96a551295328 (diff)
unit test for 1D LookupTable
Diffstat (limited to 'LookupTable.lua')
-rw-r--r--LookupTable.lua4
1 files changed, 2 insertions, 2 deletions
diff --git a/LookupTable.lua b/LookupTable.lua
index 8c23e82..7db20f8 100644
--- a/LookupTable.lua
+++ b/LookupTable.lua
@@ -22,7 +22,7 @@ function LookupTable:__init(nIndex, ...)
self.size[1] = nIndex
batchSize = torch.LongTensor(#self.size + 1)
- batchSize:narrow(1, 2,#self.size):copy(self.size)
+ batchSize:narrow(1, 2,#self.size):copy(torch.LongTensor(self.size))
batchSize[1] = 1
self.batchSize = batchSize:storage()
@@ -60,7 +60,6 @@ function LookupTable:updateOutput(input)
local nIndex = input:size(1)
self.size[1] = nIndex
self.output:resize(self.size)
-
for i=1,nIndex do
self.output:select(1, i):copy(self.weight:select(1, input[i]))
end
@@ -91,6 +90,7 @@ function LookupTable:zeroGradParameters()
end
function LookupTable:accGradParameters(input, gradOutput, scale)
+ scale = scale or 1
if input:dim() == 1 then
self.nBackward = self.nBackward + 1
for i=1,input:size(1) do