diff options
author | gchanan <gregchanan@gmail.com> | 2016-12-29 22:23:26 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-12-29 22:23:26 +0300 |
commit | a0c0b78471df5f4507791e870cf7df9607a64400 (patch) | |
tree | f91b908684ba7ff4727df1331d3a5c9b4b3b9cb8 /Tensor.lua | |
parent | 7ca7ec9d08f1ef2c753e72cbd014397736d6b5af (diff) |
Add support for torch.HalfTensor (#874)
* Add support for torch.HalfTensor.
* Improvements/Simplifications for torch.HalfTensor.
Improvements/Simplifications:
1) Defines half type as TH_Half, so as to not conflict with cutorch
version. Previously, these were defined as the same "half" type and
required proper ordering of includes to ensure type was only defined
once, which would have affected all downstream projects.
2) No longer generates math functions that are not actually defined
on torch.HalfTensor, e.g. maskedFill, map, etc.
3) Adds tests for all available torch.HalfTensor functions
4) Allows compiling without TH_GENERIC_USE_HALF (so if there's a
problem can just unset that in CMakeLists rather than backing out)
5) Some simplifications: removes a new copy optimization and
some TH_HALF literal definitions
Limitations:
Because match functions are not defined, some "non-math" operators
on torch.HalfTensor give an error message, e.g. __index__/__newindex__
with a ByteTensor apply a mask, but masks aren't implemented. These
limitations aren't always obvious, (e.g. for documentation purposes),
but they should always give an error message.
* Rename TH_HALF to THHalf.
Diffstat (limited to 'Tensor.lua')
-rw-r--r-- | Tensor.lua | 18 |
1 files changed, 15 insertions, 3 deletions
@@ -5,14 +5,14 @@ local Storage = {} local Tensor = {} -- types -local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'} +local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Half', 'Double'} -- Lua 5.2 compatibility local log10 = math.log10 or function(x) return math.log(x, 10) end -- tostring() functions for Tensor and Storage local function Storage__printformat(self) - if self:size() == 0 then + if self:size() == 0 then return "", nil, 0 end local intMode = true @@ -277,6 +277,10 @@ function Tensor.double(self) return self:type('torch.DoubleTensor') end +function Tensor.half(self) + return self:type('torch.HalfTensor') +end + function Tensor.real(self) return self:type(torch.getdefaulttensortype()) end @@ -556,6 +560,14 @@ torch.permute = Tensor.permute for _,type in ipairs(types) do local metatable = torch.getmetatable('torch.' .. type .. 'Tensor') for funcname, func in pairs(Tensor) do - rawset(metatable, funcname, func) + if funcname ~= 'totable' or type ~='Half' or torch.hashalfmath() then + rawset(metatable, funcname, func) + else + local function Tensor__totable(self) + local host_tensor = self:float() + return self:float():totable() + end + rawset(torch.getmetatable('torch.HalfTensor'), 'totable', Tensor__totable) + end end end |