Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2017-02-03 21:18:54 +0300
committerGregory Chanan <gchanan@fb.com>2017-02-03 21:27:22 +0300
commit57114b97f839c04e7063f8f3db0fb89f73ffcfcc (patch)
treed107047374fa664ae239592371a3faab6579e75d /Min.lua
parentdb4244e6ee30b0ec689815a00fcc0c8c45f91b12 (diff)
Support index parameters with Cuda*Tensor.
These modules check the type explicitly for CudaTensor rather than the regular expression. There are other modules that have this identical check, but do operations on the resulting type which may or may not be support (i.e. require more checking), so I didn't change them. These should be safe because the index type is independent of the type being checking.
Diffstat (limited to 'Min.lua')
-rw-r--r--Min.lua2
1 files changed, 1 insertions, 1 deletions
diff --git a/Min.lua b/Min.lua
index 252f52e..3a3e4a8 100644
--- a/Min.lua
+++ b/Min.lua
@@ -21,7 +21,7 @@ end
function Min:_lazyInit()
self._output = self._output or self.output.new()
if not self._indices then
- if torch.type(self.output) == 'torch.CudaTensor' then
+ if torch.typename(self.output):find('torch%.Cuda.*Tensor') then
self._indices = torch.CudaLongTensor and torch.CudaLongTensor() or torch.CudaTensor()
else
self._indices = torch.LongTensor()