Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2016-10-13 18:33:15 +0300
committerGregory Chanan <gchanan@fb.com>2016-10-20 00:53:10 +0300
commit09b9966cb3aafc4852806a2a4f5b50dc0711a3ea (patch)
treea287caec59cafbab848bf0be605bbb2be769a178 /SpatialAdaptiveMaxPooling.lua
parent1864f6502caf81bbd0e551959dfe0a803162ecbf (diff)
Use index types for SpatialAdaptiveMaxPooling indices.
Diffstat (limited to 'SpatialAdaptiveMaxPooling.lua')
-rw-r--r--SpatialAdaptiveMaxPooling.lua9
1 files changed, 7 insertions, 2 deletions
diff --git a/SpatialAdaptiveMaxPooling.lua b/SpatialAdaptiveMaxPooling.lua
index 74d4cd6..a2cf104 100644
--- a/SpatialAdaptiveMaxPooling.lua
+++ b/SpatialAdaptiveMaxPooling.lua
@@ -2,13 +2,18 @@ local SpatialAdaptiveMaxPooling, parent = torch.class('nn.SpatialAdaptiveMaxPool
function SpatialAdaptiveMaxPooling:__init(W, H)
parent.__init(self)
-
+
self.W = W
self.H = H
end
function SpatialAdaptiveMaxPooling:updateOutput(input)
- self.indices = self.indices or input.new()
+ self.indices = self.indices or torch.LongTensor()
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
+ self.indices = torch.CudaLongTensor and self.indices:cudaLong() or self.indices
+ else
+ self.indices = self.indices:long()
+ end
input.THNN.SpatialAdaptiveMaxPooling_updateOutput(
input:cdata(),
self.output:cdata(),