Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgchanan <gregchanan@gmail.com>2016-12-29 22:23:26 +0300
committerSoumith Chintala <soumith@gmail.com>2016-12-29 22:23:26 +0300
commita0c0b78471df5f4507791e870cf7df9607a64400 (patch)
treef91b908684ba7ff4727df1331d3a5c9b4b3b9cb8 /Tensor.lua
parent7ca7ec9d08f1ef2c753e72cbd014397736d6b5af (diff)
Add support for torch.HalfTensor (#874)
* Add support for torch.HalfTensor. * Improvements/Simplifications for torch.HalfTensor. Improvements/Simplifications: 1) Defines half type as TH_Half, so as to not conflict with cutorch version. Previously, these were defined as the same "half" type and required proper ordering of includes to ensure type was only defined once, which would have affected all downstream projects. 2) No longer generates math functions that are not actually defined on torch.HalfTensor, e.g. maskedFill, map, etc. 3) Adds tests for all available torch.HalfTensor functions 4) Allows compiling without TH_GENERIC_USE_HALF (so if there's a problem can just unset that in CMakeLists rather than backing out) 5) Some simplifications: removes a new copy optimization and some TH_HALF literal definitions Limitations: Because match functions are not defined, some "non-math" operators on torch.HalfTensor give an error message, e.g. __index__/__newindex__ with a ByteTensor apply a mask, but masks aren't implemented. These limitations aren't always obvious, (e.g. for documentation purposes), but they should always give an error message. * Rename TH_HALF to THHalf.
Diffstat (limited to 'Tensor.lua')
-rw-r--r--Tensor.lua18
1 files changed, 15 insertions, 3 deletions
diff --git a/Tensor.lua b/Tensor.lua
index b4b3e95..36307bd 100644
--- a/Tensor.lua
+++ b/Tensor.lua
@@ -5,14 +5,14 @@ local Storage = {}
local Tensor = {}
-- types
-local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'}
+local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Half', 'Double'}
-- Lua 5.2 compatibility
local log10 = math.log10 or function(x) return math.log(x, 10) end
-- tostring() functions for Tensor and Storage
local function Storage__printformat(self)
- if self:size() == 0 then
+ if self:size() == 0 then
return "", nil, 0
end
local intMode = true
@@ -277,6 +277,10 @@ function Tensor.double(self)
return self:type('torch.DoubleTensor')
end
+function Tensor.half(self)
+ return self:type('torch.HalfTensor')
+end
+
function Tensor.real(self)
return self:type(torch.getdefaulttensortype())
end
@@ -556,6 +560,14 @@ torch.permute = Tensor.permute
for _,type in ipairs(types) do
local metatable = torch.getmetatable('torch.' .. type .. 'Tensor')
for funcname, func in pairs(Tensor) do
- rawset(metatable, funcname, func)
+ if funcname ~= 'totable' or type ~='Half' or torch.hashalfmath() then
+ rawset(metatable, funcname, func)
+ else
+ local function Tensor__totable(self)
+ local host_tensor = self:float()
+ return self:float():totable()
+ end
+ rawset(torch.getmetatable('torch.HalfTensor'), 'totable', Tensor__totable)
+ end
end
end