Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Lerer <alerer@fb.com>2015-04-30 19:28:54 +0300
committerAdam Lerer <alerer@fb.com>2015-04-30 20:43:06 +0300
commitd88ac24c712e3a40d4aaf3ac2d043bd79ba4280e (patch)
tree0fc0015fdbf8d532d986bbc738cb753dd06cbd46 /Tensor.lua
parent911f1fa0db6b00b82b705c3bf0761e26382fc2f9 (diff)
Auto device: API changes, bug fixes, README.md
- Change :cuda(device) overload to :cudaOn(device) - Add :cloneOn(device) - Fix bug in +,-,*,/ metamethods: checkGPU wasn't being called on these metamethods. - Add description of auto-device mode to README.md
Diffstat (limited to 'Tensor.lua')
-rw-r--r--Tensor.lua62
1 files changed, 36 insertions, 26 deletions
diff --git a/Tensor.lua b/Tensor.lua
index d78e236..4d19a90 100644
--- a/Tensor.lua
+++ b/Tensor.lua
@@ -30,20 +30,21 @@ end
local function Tensor__typeAs(self,tensor)
return self:type(tensor:type())
end
-local function Tensor__cuda(self,device)
- if device ~= nil then
- local curDev = cutorch.getDevice()
- cutorch.setDevice(device)
- local res = self:type('torch.CudaTensor')
- if res:nElement() == 0 then
- res:setDevice(device)
- end
- cutorch.setDevice(curDev)
- return res
- else
- return self:type('torch.CudaTensor')
+local function Tensor__cuda(self)
+ return self:type('torch.CudaTensor')
+end
+
+local function Tensor__cudaOn(self, device)
+ local curDev = cutorch.getDevice()
+ cutorch.setDevice(device)
+ local res = self:type('torch.CudaTensor')
+ if res:nElement() == 0 then
+ res:setDevice(device)
end
+ cutorch.setDevice(curDev)
+ return res
end
+
local function Tensor__double(self)
return self:type('torch.DoubleTensor')
end
@@ -72,23 +73,32 @@ local function Tensor__long(self)
end
rawset(torch.getmetatable('torch.DoubleTensor'), 'cuda', Tensor__cuda)
-rawset(torch.getmetatable('torch.FloatTensor'), 'cuda', Tensor__cuda)
-rawset(torch.getmetatable('torch.ByteTensor'), 'cuda', Tensor__cuda)
-rawset(torch.getmetatable('torch.CharTensor'), 'cuda', Tensor__cuda)
-rawset(torch.getmetatable('torch.IntTensor'), 'cuda', Tensor__cuda)
-rawset(torch.getmetatable('torch.ShortTensor'), 'cuda', Tensor__cuda)
-rawset(torch.getmetatable('torch.LongTensor'), 'cuda', Tensor__cuda)
-rawset(torch.getmetatable('torch.CudaTensor'), 'cuda', Tensor__cuda)
+rawset(torch.getmetatable('torch.FloatTensor'), 'cuda', Tensor__cuda)
+rawset(torch.getmetatable('torch.ByteTensor'), 'cuda', Tensor__cuda)
+rawset(torch.getmetatable('torch.CharTensor'), 'cuda', Tensor__cuda)
+rawset(torch.getmetatable('torch.IntTensor'), 'cuda', Tensor__cuda)
+rawset(torch.getmetatable('torch.ShortTensor'), 'cuda', Tensor__cuda)
+rawset(torch.getmetatable('torch.LongTensor'), 'cuda', Tensor__cuda)
+rawset(torch.getmetatable('torch.CudaTensor'), 'cuda', Tensor__cuda)
+
+rawset(torch.getmetatable('torch.DoubleTensor'), 'cudaOn', Tensor__cudaOn)
+rawset(torch.getmetatable('torch.FloatTensor'), 'cudaOn', Tensor__cudaOn)
+rawset(torch.getmetatable('torch.ByteTensor'), 'cudaOn', Tensor__cudaOn)
+rawset(torch.getmetatable('torch.CharTensor'), 'cudaOn', Tensor__cudaOn)
+rawset(torch.getmetatable('torch.IntTensor'), 'cudaOn', Tensor__cudaOn)
+rawset(torch.getmetatable('torch.ShortTensor'), 'cudaOn', Tensor__cudaOn)
+rawset(torch.getmetatable('torch.LongTensor'), 'cudaOn', Tensor__cudaOn)
+rawset(torch.getmetatable('torch.CudaTensor'), 'cudaOn', Tensor__cudaOn)
-rawset(torch.getmetatable('torch.CudaTensor'), 'type', Tensor__type)
+rawset(torch.getmetatable('torch.CudaTensor'), 'type', Tensor__type)
rawset(torch.getmetatable('torch.CudaTensor'), 'typeAs', Tensor__typeAs)
rawset(torch.getmetatable('torch.CudaTensor'), 'double', Tensor__double)
-rawset(torch.getmetatable('torch.CudaTensor'), 'float', Tensor__float)
-rawset(torch.getmetatable('torch.CudaTensor'), 'byte', Tensor__byte)
-rawset(torch.getmetatable('torch.CudaTensor'), 'char', Tensor__char)
-rawset(torch.getmetatable('torch.CudaTensor'), 'int', Tensor__int)
-rawset(torch.getmetatable('torch.CudaTensor'), 'short', Tensor__short)
-rawset(torch.getmetatable('torch.CudaTensor'), 'long', Tensor__long)
+rawset(torch.getmetatable('torch.CudaTensor'), 'float', Tensor__float)
+rawset(torch.getmetatable('torch.CudaTensor'), 'byte', Tensor__byte)
+rawset(torch.getmetatable('torch.CudaTensor'), 'char', Tensor__char)
+rawset(torch.getmetatable('torch.CudaTensor'), 'int', Tensor__int)
+rawset(torch.getmetatable('torch.CudaTensor'), 'short', Tensor__short)
+rawset(torch.getmetatable('torch.CudaTensor'), 'long', Tensor__long)
do
local metatable = torch.getmetatable('torch.CudaTensor')