diff options
author | Adam Lerer <alerer@fb.com> | 2015-08-21 10:12:22 +0300 |
---|---|---|
committer | Adam Lerer <alerer@fb.com> | 2015-12-26 01:01:04 +0300 |
commit | df463695f0cd387736917e02abaafa63c00bbed3 (patch) | |
tree | ded8f3bd463e47a4a9c4b13f744e4010e6345703 /Tensor.lua | |
parent | d3c6fa5e2648f902f497fe81d0fa30c62552e4e9 (diff) |
Add generic CudaTensor types to cutorch
Diffstat (limited to 'Tensor.lua')
-rw-r--r-- | Tensor.lua | 67 |
1 files changed, 27 insertions, 40 deletions
@@ -21,55 +21,42 @@ end local function Tensor__typeAs(self,tensor) return self:type(tensor:type()) end -local function Tensor__cuda(self) - return self:type('torch.CudaTensor') -end -local function Tensor__double(self) - return self:type('torch.DoubleTensor') -end -local function Tensor__float(self) - return self:type('torch.FloatTensor') -end -local function Tensor__byte(self) - return self:type('torch.ByteTensor') -end +local TensorTypes = { + float = 'torch.FloatTensor', + double = 'torch.DoubleTensor', + byte = 'torch.ByteTensor', + char = 'torch.CharTensor', + int = 'torch.IntTensor', + short = 'torch.ShortTensor', + long = 'torch.LongTensor', + cuda = 'torch.CudaTensor', + cudaDouble = 'torch.CudaDoubleTensor', + cudaByte = 'torch.CudaByteTensor', + cudaChar = 'torch.CudaCharTensor', + cudaInt = 'torch.CudaIntTensor', + cudaShort = 'torch.CudaShortTensor', + cudaLong = 'torch.CudaLongTensor' +} -local function Tensor__char(self) - return self:type('torch.CharTensor') -end -local function Tensor__int(self) - return self:type('torch.IntTensor') +local function Tensor__converter(type) + return function(self) + return self:type(type) + end end -local function Tensor__short(self) - return self:type('torch.ShortTensor') +for _, SrcType in pairs(TensorTypes) do + for FuncName, DstType in pairs(TensorTypes) do + rawset(torch.getmetatable(SrcType), FuncName, Tensor__converter(DstType)) + end end -local function Tensor__long(self) - return self:type('torch.LongTensor') +for _, CudaTensorType in pairs(TensorTypes) do + rawset(torch.getmetatable(CudaTensorType), 'type', Tensor__type) + rawset(torch.getmetatable(CudaTensorType), 'typeAs', Tensor__typeAs) end -rawset(torch.getmetatable('torch.DoubleTensor'), 'cuda', Tensor__cuda) -rawset(torch.getmetatable('torch.FloatTensor'), 'cuda', Tensor__cuda) -rawset(torch.getmetatable('torch.ByteTensor'), 'cuda', Tensor__cuda) -rawset(torch.getmetatable('torch.CharTensor'), 'cuda', Tensor__cuda) -rawset(torch.getmetatable('torch.IntTensor'), 'cuda', Tensor__cuda) -rawset(torch.getmetatable('torch.ShortTensor'), 'cuda', Tensor__cuda) -rawset(torch.getmetatable('torch.LongTensor'), 'cuda', Tensor__cuda) -rawset(torch.getmetatable('torch.CudaTensor'), 'cuda', Tensor__cuda) - -rawset(torch.getmetatable('torch.CudaTensor'), 'type', Tensor__type) -rawset(torch.getmetatable('torch.CudaTensor'), 'typeAs', Tensor__typeAs) -rawset(torch.getmetatable('torch.CudaTensor'), 'double', Tensor__double) -rawset(torch.getmetatable('torch.CudaTensor'), 'float', Tensor__float) -rawset(torch.getmetatable('torch.CudaTensor'), 'byte', Tensor__byte) -rawset(torch.getmetatable('torch.CudaTensor'), 'char', Tensor__char) -rawset(torch.getmetatable('torch.CudaTensor'), 'int', Tensor__int) -rawset(torch.getmetatable('torch.CudaTensor'), 'short', Tensor__short) -rawset(torch.getmetatable('torch.CudaTensor'), 'long', Tensor__long) - do local metatable = torch.getmetatable('torch.CudaTensor') for _,func in pairs{'expand', 'expandAs', 'view', 'viewAs', 'repeatTensor', |