diff options
author | Adam Lerer <alerer@fb.com> | 2015-04-30 19:28:54 +0300 |
---|---|---|
committer | Adam Lerer <alerer@fb.com> | 2015-04-30 20:43:06 +0300 |
commit | d88ac24c712e3a40d4aaf3ac2d043bd79ba4280e (patch) | |
tree | 0fc0015fdbf8d532d986bbc738cb753dd06cbd46 /Tensor.lua | |
parent | 911f1fa0db6b00b82b705c3bf0761e26382fc2f9 (diff) |
Auto device: API changes, bug fixes, README.md
- Change :cuda(device) overload to :cudaOn(device)
- Add :cloneOn(device)
- Fix bug in +,-,*,/ metamethods: checkGPU wasn't being called on these
metamethods.
- Add description of auto-device mode to README.md
Diffstat (limited to 'Tensor.lua')
-rw-r--r-- | Tensor.lua | 62 |
1 files changed, 36 insertions, 26 deletions
@@ -30,20 +30,21 @@ end local function Tensor__typeAs(self,tensor) return self:type(tensor:type()) end -local function Tensor__cuda(self,device) - if device ~= nil then - local curDev = cutorch.getDevice() - cutorch.setDevice(device) - local res = self:type('torch.CudaTensor') - if res:nElement() == 0 then - res:setDevice(device) - end - cutorch.setDevice(curDev) - return res - else - return self:type('torch.CudaTensor') +local function Tensor__cuda(self) + return self:type('torch.CudaTensor') +end + +local function Tensor__cudaOn(self, device) + local curDev = cutorch.getDevice() + cutorch.setDevice(device) + local res = self:type('torch.CudaTensor') + if res:nElement() == 0 then + res:setDevice(device) end + cutorch.setDevice(curDev) + return res end + local function Tensor__double(self) return self:type('torch.DoubleTensor') end @@ -72,23 +73,32 @@ local function Tensor__long(self) end rawset(torch.getmetatable('torch.DoubleTensor'), 'cuda', Tensor__cuda) -rawset(torch.getmetatable('torch.FloatTensor'), 'cuda', Tensor__cuda) -rawset(torch.getmetatable('torch.ByteTensor'), 'cuda', Tensor__cuda) -rawset(torch.getmetatable('torch.CharTensor'), 'cuda', Tensor__cuda) -rawset(torch.getmetatable('torch.IntTensor'), 'cuda', Tensor__cuda) -rawset(torch.getmetatable('torch.ShortTensor'), 'cuda', Tensor__cuda) -rawset(torch.getmetatable('torch.LongTensor'), 'cuda', Tensor__cuda) -rawset(torch.getmetatable('torch.CudaTensor'), 'cuda', Tensor__cuda) +rawset(torch.getmetatable('torch.FloatTensor'), 'cuda', Tensor__cuda) +rawset(torch.getmetatable('torch.ByteTensor'), 'cuda', Tensor__cuda) +rawset(torch.getmetatable('torch.CharTensor'), 'cuda', Tensor__cuda) +rawset(torch.getmetatable('torch.IntTensor'), 'cuda', Tensor__cuda) +rawset(torch.getmetatable('torch.ShortTensor'), 'cuda', Tensor__cuda) +rawset(torch.getmetatable('torch.LongTensor'), 'cuda', Tensor__cuda) +rawset(torch.getmetatable('torch.CudaTensor'), 'cuda', Tensor__cuda) + +rawset(torch.getmetatable('torch.DoubleTensor'), 'cudaOn', Tensor__cudaOn) +rawset(torch.getmetatable('torch.FloatTensor'), 'cudaOn', Tensor__cudaOn) +rawset(torch.getmetatable('torch.ByteTensor'), 'cudaOn', Tensor__cudaOn) +rawset(torch.getmetatable('torch.CharTensor'), 'cudaOn', Tensor__cudaOn) +rawset(torch.getmetatable('torch.IntTensor'), 'cudaOn', Tensor__cudaOn) +rawset(torch.getmetatable('torch.ShortTensor'), 'cudaOn', Tensor__cudaOn) +rawset(torch.getmetatable('torch.LongTensor'), 'cudaOn', Tensor__cudaOn) +rawset(torch.getmetatable('torch.CudaTensor'), 'cudaOn', Tensor__cudaOn) -rawset(torch.getmetatable('torch.CudaTensor'), 'type', Tensor__type) +rawset(torch.getmetatable('torch.CudaTensor'), 'type', Tensor__type) rawset(torch.getmetatable('torch.CudaTensor'), 'typeAs', Tensor__typeAs) rawset(torch.getmetatable('torch.CudaTensor'), 'double', Tensor__double) -rawset(torch.getmetatable('torch.CudaTensor'), 'float', Tensor__float) -rawset(torch.getmetatable('torch.CudaTensor'), 'byte', Tensor__byte) -rawset(torch.getmetatable('torch.CudaTensor'), 'char', Tensor__char) -rawset(torch.getmetatable('torch.CudaTensor'), 'int', Tensor__int) -rawset(torch.getmetatable('torch.CudaTensor'), 'short', Tensor__short) -rawset(torch.getmetatable('torch.CudaTensor'), 'long', Tensor__long) +rawset(torch.getmetatable('torch.CudaTensor'), 'float', Tensor__float) +rawset(torch.getmetatable('torch.CudaTensor'), 'byte', Tensor__byte) +rawset(torch.getmetatable('torch.CudaTensor'), 'char', Tensor__char) +rawset(torch.getmetatable('torch.CudaTensor'), 'int', Tensor__int) +rawset(torch.getmetatable('torch.CudaTensor'), 'short', Tensor__short) +rawset(torch.getmetatable('torch.CudaTensor'), 'long', Tensor__long) do local metatable = torch.getmetatable('torch.CudaTensor') |