diff options
author | Kai Sheng Tai <kaishengtai@gmail.com> | 2014-11-09 22:39:27 +0300 |
---|---|---|
committer | Kai Sheng Tai <kaishengtai@gmail.com> | 2014-11-10 21:38:53 +0300 |
commit | dbbd8798625149fccea0320e19f72ecc438b03d0 (patch) | |
tree | 4558c6ccf4538c9efdfcb50c93c41e24b7b0c317 /LookupTable.lua | |
parent | 704684a27efd82da3f4ac05cc9ecb6f44aa6d510 (diff) |
Performance improvement to LookupTable
Diffstat (limited to 'LookupTable.lua')
-rw-r--r-- | LookupTable.lua | 23 |
1 files changed, 11 insertions, 12 deletions
diff --git a/LookupTable.lua b/LookupTable.lua index c286383..fa8febf 100644 --- a/LookupTable.lua +++ b/LookupTable.lua @@ -53,27 +53,26 @@ function LookupTable:reset(stdv) end function LookupTable:updateOutput(input) + input = input:long() if input:dim() == 1 then local nIndex = input:size(1) self.size[1] = nIndex - self.output:resize(self.size) - for i=1,nIndex do - self.output:select(1, i):copy(self.weight:select(1, input[i])) - end + self.output:index(self.weight, 1, input) elseif input:dim() == 2 then local nExample = input:size(1) local nIndex = input:size(2) self.batchSize[1] = nExample self.batchSize[2] = nIndex - self.output:resize(self.batchSize) - - for i=1,nExample do - local output = self.output:select(1, i) - local input = input:select(1, i) - for j=1,nIndex do - output:select(1, j):copy(self.weight:select(1, input[j])) - end + local indices + if input:isContiguous() then + indices = input:view(-1) + else + self._indices = self._indices or torch.LongTensor() + self._indices:resizeAs(input):copy(input) + indices = self._indices:view(-1) end + self.output:index(self.weight, 1, indices) + self.output = self.output:view(nExample, nIndex, self.size[2]) end return self.output |