Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKai Sheng Tai <kaishengtai@gmail.com>2014-11-09 22:39:27 +0300
committerKai Sheng Tai <kaishengtai@gmail.com>2014-11-10 21:38:53 +0300
commitdbbd8798625149fccea0320e19f72ecc438b03d0 (patch)
tree4558c6ccf4538c9efdfcb50c93c41e24b7b0c317 /LookupTable.lua
parent704684a27efd82da3f4ac05cc9ecb6f44aa6d510 (diff)
Performance improvement to LookupTable
Diffstat (limited to 'LookupTable.lua')
-rw-r--r--LookupTable.lua23
1 files changed, 11 insertions, 12 deletions
diff --git a/LookupTable.lua b/LookupTable.lua
index c286383..fa8febf 100644
--- a/LookupTable.lua
+++ b/LookupTable.lua
@@ -53,27 +53,26 @@ function LookupTable:reset(stdv)
end
function LookupTable:updateOutput(input)
+ input = input:long()
if input:dim() == 1 then
local nIndex = input:size(1)
self.size[1] = nIndex
- self.output:resize(self.size)
- for i=1,nIndex do
- self.output:select(1, i):copy(self.weight:select(1, input[i]))
- end
+ self.output:index(self.weight, 1, input)
elseif input:dim() == 2 then
local nExample = input:size(1)
local nIndex = input:size(2)
self.batchSize[1] = nExample
self.batchSize[2] = nIndex
- self.output:resize(self.batchSize)
-
- for i=1,nExample do
- local output = self.output:select(1, i)
- local input = input:select(1, i)
- for j=1,nIndex do
- output:select(1, j):copy(self.weight:select(1, input[j]))
- end
+ local indices
+ if input:isContiguous() then
+ indices = input:view(-1)
+ else
+ self._indices = self._indices or torch.LongTensor()
+ self._indices:resizeAs(input):copy(input)
+ indices = self._indices:view(-1)
end
+ self.output:index(self.weight, 1, indices)
+ self.output = self.output:view(nExample, nIndex, self.size[2])
end
return self.output