Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Lerer <alerer@fb.com>2015-08-21 10:12:22 +0300
committerAdam Lerer <alerer@fb.com>2015-12-26 01:01:04 +0300
commitdf463695f0cd387736917e02abaafa63c00bbed3 (patch)
treeded8f3bd463e47a4a9c4b13f744e4010e6345703 /Tensor.lua
parentd3c6fa5e2648f902f497fe81d0fa30c62552e4e9 (diff)
Add generic CudaTensor types to cutorch
Diffstat (limited to 'Tensor.lua')
-rw-r--r--Tensor.lua67
1 files changed, 27 insertions, 40 deletions
diff --git a/Tensor.lua b/Tensor.lua
index f9c1ca4..32be5c7 100644
--- a/Tensor.lua
+++ b/Tensor.lua
@@ -21,55 +21,42 @@ end
local function Tensor__typeAs(self,tensor)
return self:type(tensor:type())
end
-local function Tensor__cuda(self)
- return self:type('torch.CudaTensor')
-end
-local function Tensor__double(self)
- return self:type('torch.DoubleTensor')
-end
-local function Tensor__float(self)
- return self:type('torch.FloatTensor')
-end
-local function Tensor__byte(self)
- return self:type('torch.ByteTensor')
-end
+local TensorTypes = {
+ float = 'torch.FloatTensor',
+ double = 'torch.DoubleTensor',
+ byte = 'torch.ByteTensor',
+ char = 'torch.CharTensor',
+ int = 'torch.IntTensor',
+ short = 'torch.ShortTensor',
+ long = 'torch.LongTensor',
+ cuda = 'torch.CudaTensor',
+ cudaDouble = 'torch.CudaDoubleTensor',
+ cudaByte = 'torch.CudaByteTensor',
+ cudaChar = 'torch.CudaCharTensor',
+ cudaInt = 'torch.CudaIntTensor',
+ cudaShort = 'torch.CudaShortTensor',
+ cudaLong = 'torch.CudaLongTensor'
+}
-local function Tensor__char(self)
- return self:type('torch.CharTensor')
-end
-local function Tensor__int(self)
- return self:type('torch.IntTensor')
+local function Tensor__converter(type)
+ return function(self)
+ return self:type(type)
+ end
end
-local function Tensor__short(self)
- return self:type('torch.ShortTensor')
+for _, SrcType in pairs(TensorTypes) do
+ for FuncName, DstType in pairs(TensorTypes) do
+ rawset(torch.getmetatable(SrcType), FuncName, Tensor__converter(DstType))
+ end
end
-local function Tensor__long(self)
- return self:type('torch.LongTensor')
+for _, CudaTensorType in pairs(TensorTypes) do
+ rawset(torch.getmetatable(CudaTensorType), 'type', Tensor__type)
+ rawset(torch.getmetatable(CudaTensorType), 'typeAs', Tensor__typeAs)
end
-rawset(torch.getmetatable('torch.DoubleTensor'), 'cuda', Tensor__cuda)
-rawset(torch.getmetatable('torch.FloatTensor'), 'cuda', Tensor__cuda)
-rawset(torch.getmetatable('torch.ByteTensor'), 'cuda', Tensor__cuda)
-rawset(torch.getmetatable('torch.CharTensor'), 'cuda', Tensor__cuda)
-rawset(torch.getmetatable('torch.IntTensor'), 'cuda', Tensor__cuda)
-rawset(torch.getmetatable('torch.ShortTensor'), 'cuda', Tensor__cuda)
-rawset(torch.getmetatable('torch.LongTensor'), 'cuda', Tensor__cuda)
-rawset(torch.getmetatable('torch.CudaTensor'), 'cuda', Tensor__cuda)
-
-rawset(torch.getmetatable('torch.CudaTensor'), 'type', Tensor__type)
-rawset(torch.getmetatable('torch.CudaTensor'), 'typeAs', Tensor__typeAs)
-rawset(torch.getmetatable('torch.CudaTensor'), 'double', Tensor__double)
-rawset(torch.getmetatable('torch.CudaTensor'), 'float', Tensor__float)
-rawset(torch.getmetatable('torch.CudaTensor'), 'byte', Tensor__byte)
-rawset(torch.getmetatable('torch.CudaTensor'), 'char', Tensor__char)
-rawset(torch.getmetatable('torch.CudaTensor'), 'int', Tensor__int)
-rawset(torch.getmetatable('torch.CudaTensor'), 'short', Tensor__short)
-rawset(torch.getmetatable('torch.CudaTensor'), 'long', Tensor__long)
-
do
local metatable = torch.getmetatable('torch.CudaTensor')
for _,func in pairs{'expand', 'expandAs', 'view', 'viewAs', 'repeatTensor',