Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2016-10-13 19:45:22 +0300
committerGregory Chanan <gchanan@fb.com>2016-10-20 00:53:19 +0300
commit915bd8711b224467262e2e7bfb1f5ace3f7b99ad (patch)
tree4015454631538c734faff0c41fd9772411bb25fd /TemporalMaxPooling.lua
parent09b9966cb3aafc4852806a2a4f5b50dc0711a3ea (diff)
Indices for nn.
Diffstat (limited to 'TemporalMaxPooling.lua')
-rw-r--r--TemporalMaxPooling.lua7
1 files changed, 6 insertions, 1 deletions
diff --git a/TemporalMaxPooling.lua b/TemporalMaxPooling.lua
index 91723e6..894f4a9 100644
--- a/TemporalMaxPooling.lua
+++ b/TemporalMaxPooling.lua
@@ -10,7 +10,12 @@ function TemporalMaxPooling:__init(kW, dW)
end
function TemporalMaxPooling:updateOutput(input)
- self.indices = self.indices or input.new()
+ self.indices = self.indices or torch.LongTensor()
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
+ self.indices = torch.CudaLongTensor and self.indices:cudaLong() or self.indices
+ else
+ self.indices = self.indices:long()
+ end
input.THNN.TemporalMaxPooling_updateOutput(
input:cdata(), self.output:cdata(),
self.indices:cdata(), self.kW, self.dW