diff options
author | Gregory Chanan <gchanan@fb.com> | 2016-10-13 18:33:15 +0300 |
---|---|---|
committer | Gregory Chanan <gchanan@fb.com> | 2016-10-20 00:53:10 +0300 |
commit | 09b9966cb3aafc4852806a2a4f5b50dc0711a3ea (patch) | |
tree | a287caec59cafbab848bf0be605bbb2be769a178 /SpatialAdaptiveMaxPooling.lua | |
parent | 1864f6502caf81bbd0e551959dfe0a803162ecbf (diff) |
Use index types for SpatialAdaptiveMaxPooling indices.
Diffstat (limited to 'SpatialAdaptiveMaxPooling.lua')
-rw-r--r-- | SpatialAdaptiveMaxPooling.lua | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/SpatialAdaptiveMaxPooling.lua b/SpatialAdaptiveMaxPooling.lua index 74d4cd6..a2cf104 100644 --- a/SpatialAdaptiveMaxPooling.lua +++ b/SpatialAdaptiveMaxPooling.lua @@ -2,13 +2,18 @@ local SpatialAdaptiveMaxPooling, parent = torch.class('nn.SpatialAdaptiveMaxPool function SpatialAdaptiveMaxPooling:__init(W, H) parent.__init(self) - + self.W = W self.H = H end function SpatialAdaptiveMaxPooling:updateOutput(input) - self.indices = self.indices or input.new() + self.indices = self.indices or torch.LongTensor() + if torch.typename(input):find('torch%.Cuda.*Tensor') then + self.indices = torch.CudaLongTensor and self.indices:cudaLong() or self.indices + else + self.indices = self.indices:long() + end input.THNN.SpatialAdaptiveMaxPooling_updateOutput( input:cdata(), self.output:cdata(), |