diff options
author | Nicholas Leonard <nick@nikopia.org> | 2014-04-10 00:49:43 +0400 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2014-04-10 00:49:43 +0400 |
commit | 6783768eb6e06ebbf0c1bd2d6b5f4bc9f709b32f (patch) | |
tree | e78ffe0e0f7651ef6767a3de09443e9a19ffd474 /LookupTable.lua | |
parent | 2781f4bee6c6f725800c863c3c2e96a551295328 (diff) |
unit test for 1D LookupTable
Diffstat (limited to 'LookupTable.lua')
-rw-r--r-- | LookupTable.lua | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/LookupTable.lua b/LookupTable.lua index 8c23e82..7db20f8 100644 --- a/LookupTable.lua +++ b/LookupTable.lua @@ -22,7 +22,7 @@ function LookupTable:__init(nIndex, ...) self.size[1] = nIndex batchSize = torch.LongTensor(#self.size + 1) - batchSize:narrow(1, 2,#self.size):copy(self.size) + batchSize:narrow(1, 2,#self.size):copy(torch.LongTensor(self.size)) batchSize[1] = 1 self.batchSize = batchSize:storage() @@ -60,7 +60,6 @@ function LookupTable:updateOutput(input) local nIndex = input:size(1) self.size[1] = nIndex self.output:resize(self.size) - for i=1,nIndex do self.output:select(1, i):copy(self.weight:select(1, input[i])) end @@ -91,6 +90,7 @@ function LookupTable:zeroGradParameters() end function LookupTable:accGradParameters(input, gradOutput, scale) + scale = scale or 1 if input:dim() == 1 then self.nBackward = self.nBackward + 1 for i=1,input:size(1) do |