diff options
author | Gregory Chanan <gchanan@fb.com> | 2016-10-13 19:45:22 +0300 |
---|---|---|
committer | Gregory Chanan <gchanan@fb.com> | 2016-10-20 00:53:19 +0300 |
commit | 915bd8711b224467262e2e7bfb1f5ace3f7b99ad (patch) | |
tree | 4015454631538c734faff0c41fd9772411bb25fd /TemporalMaxPooling.lua | |
parent | 09b9966cb3aafc4852806a2a4f5b50dc0711a3ea (diff) |
Indices for nn.
Diffstat (limited to 'TemporalMaxPooling.lua')
-rw-r--r-- | TemporalMaxPooling.lua | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/TemporalMaxPooling.lua b/TemporalMaxPooling.lua index 91723e6..894f4a9 100644 --- a/TemporalMaxPooling.lua +++ b/TemporalMaxPooling.lua @@ -10,7 +10,12 @@ function TemporalMaxPooling:__init(kW, dW) end function TemporalMaxPooling:updateOutput(input) - self.indices = self.indices or input.new() + self.indices = self.indices or torch.LongTensor() + if torch.typename(input):find('torch%.Cuda.*Tensor') then + self.indices = torch.CudaLongTensor and self.indices:cudaLong() or self.indices + else + self.indices = self.indices:long() + end input.THNN.TemporalMaxPooling_updateOutput( input:cdata(), self.output:cdata(), self.indices:cdata(), self.kW, self.dW |