diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-12-20 04:56:39 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-12-20 04:56:39 +0300 |
commit | e839c51c64c563621df1519ef98860e250ad6bdf (patch) | |
tree | 24b7da2070b18825ff38eff2ff9b0eeeebb60344 | |
parent | f9b37b03b5044ea52a9faabb0820ba42e2afa465 (diff) |
Half type (#870)fp16
* torch.HalfTensor
38 files changed, 703 insertions, 109 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 611258b..fb2de09 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,8 @@ IF(MSVC) ADD_DEFINITIONS(-D_CRT_SECURE_NO_DEPRECATE=1) ENDIF(MSVC) +ADD_DEFINITIONS(-DTH_GENERIC_USE_HALF=1) + # OpenMP support? SET(WITH_OPENMP ON CACHE BOOL "OpenMP support if available?") IF (APPLE AND CMAKE_COMPILER_IS_GNUCC) @@ -15,6 +15,7 @@ local function checkArgumentType(expected, actual, fn, ud, level) end if ok then + local Real2real = { Byte='unsigned char', Char='char', @@ -22,7 +23,8 @@ if ok then Int='int', Long='long', Float='float', - Double='double' + Double='double', + Half='half' } -- Allocator @@ -33,11 +35,14 @@ typedef struct THAllocator { void (*free)(void*, void*); } THAllocator; ]] - -- Storage for Real, real in pairs(Real2real) do local cdefs = [[ +typedef struct { + unsigned short x; +} half; + typedef struct THRealStorage { real *data; @@ -76,7 +81,7 @@ typedef struct THRealTensor long *size; long *stride; int nDimension; - + THRealStorage *storage; ptrdiff_t storageOffset; int refcount; @@ -5,14 +5,14 @@ local Storage = {} local Tensor = {} -- types -local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'} +local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Half', 'Double'} -- Lua 5.2 compatibility local log10 = math.log10 or function(x) return math.log(x, 10) end -- tostring() functions for Tensor and Storage local function Storage__printformat(self) - if self:size() == 0 then + if self:size() == 0 then return "", nil, 0 end local intMode = true @@ -277,6 +277,10 @@ function Tensor.double(self) return self:type('torch.DoubleTensor') end +function Tensor.half(self) + return self:type('torch.HalfTensor') +end + function Tensor.real(self) return self:type(torch.getdefaulttensortype()) end diff --git a/TensorMath.lua b/TensorMath.lua index 682de23..6890ce8 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -6,56 +6,8 @@ local interface = wrap.CInterface.new() local method = wrap.CInterface.new() local argtypes = wrap.CInterface.argtypes -argtypes['ptrdiff_t'] = { - - helpname = function(arg) - return 'ptrdiff_t' - end, - - declare = function(arg) - -- if it is a number we initialize here - local default = tonumber(tostring(arg.default)) or 0 - return string.format("%s arg%d = %g;", 'ptrdiff_t', arg.i, default) - end, - - check = function(arg, idx) - return string.format("lua_isnumber(L, %d)", idx) - end, - - read = function(arg, idx) - return string.format("arg%d = (%s)lua_tonumber(L, %d);", arg.i, 'ptrdiff_t', idx) - end, - - init = function(arg) - -- otherwise do it here - if arg.default then - local default = tostring(arg.default) - if not tonumber(default) then - return string.format("arg%d = %s;", arg.i, default) - end - end - end, - - carg = function(arg) - return string.format('arg%d', arg.i) - end, - - creturn = function(arg) - return string.format('arg%d', arg.i) - end, - - precall = function(arg) - if arg.returned then - return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i) - end - end, - - postcall = function(arg) - if arg.creturned then - return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i) - end - end -} +argtypes['ptrdiff_t'] = wrap.types.ptrdiff_t +argtypes['half'] = wrap.types.half interface:print([[ #include "TH.h" @@ -216,6 +168,7 @@ local reals = {ByteTensor='unsigned char', IntTensor='int', LongTensor='long', FloatTensor='float', + HalfTensor='half', DoubleTensor='double'} local accreals = {ByteTensor='long', @@ -224,11 +177,12 @@ local accreals = {ByteTensor='long', IntTensor='long', LongTensor='long', FloatTensor='double', + HalfTensor='float', DoubleTensor='double'} for _,Tensor in ipairs({"ByteTensor", "CharTensor", "ShortTensor", "IntTensor", "LongTensor", - "FloatTensor", "DoubleTensor"}) do + "FloatTensor", "HalfTensor", "DoubleTensor"}) do local real = reals[Tensor] local accreal = accreals[Tensor] @@ -257,6 +211,7 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor", end end + if Tensor ~= 'HalfTensor' then wrap("zero", cname("zero"), {{name=Tensor, returned=true}}) @@ -1030,6 +985,7 @@ static void THTensor_random1__(THTensor *self, THGenerator *gen, long b) cname("nonzero"), {{name="IndexTensor", default=true, returned=true}, {name=Tensor}}) + end -- ~= HalfTensor if Tensor == 'ByteTensor' then -- Logical accumulators only apply to ByteTensor @@ -1483,6 +1439,9 @@ void torch_TensorMath_init(lua_State *L) torch_IntTensorMath_init(L); torch_LongTensorMath_init(L); torch_FloatTensorMath_init(L); + #if TH_NATIVE_HALF + torch_HalfTensorMath_init(L); + #endif torch_DoubleTensorMath_init(L); luaT_setfuncs(L, torch_TensorMath__, 0); } @@ -687,6 +687,7 @@ local typesMatching = { ['torch.LongStorage'] = torch.LongTensor, ['torch.FloatStorage'] = torch.FloatTensor, ['torch.DoubleStorage'] = torch.DoubleTensor, + ['torch.HalfStorage'] = torch.HalfTensor, } --[[ Tests for storage equality. diff --git a/generic/Storage.c b/generic/Storage.c index 134dc63..a6652a5 100644 --- a/generic/Storage.c +++ b/generic/Storage.c @@ -41,7 +41,7 @@ static int torch_Storage_(new)(lua_State *L) THStorage_(free)(storage); luaL_error(L, "element at index %d is not a number", i); } - THStorage_(set)(storage, i-1, (real)lua_tonumber(L, -1)); + THStorage_(set)(storage, i-1, LUA_NUMBER_TO_REAL(lua_tonumber(L, -1))); lua_pop(L, 1); } } @@ -131,6 +131,8 @@ static int torch_Storage_(copy)(lua_State *L) THStorage_(copyFloat)(storage, src); else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) ) THStorage_(copyDouble)(storage, src); + else if( (src = luaT_toudata(L, 2, "torch.HalfStorage")) ) + THStorage_(copyHalf)(storage, src); else luaL_typerror(L, 2, "torch.*Storage"); lua_settop(L, 1); diff --git a/generic/Tensor.c b/generic/Tensor.c index abb7819..b1dca69 100644 --- a/generic/Tensor.c +++ b/generic/Tensor.c @@ -142,7 +142,7 @@ static int torch_Tensor_(new)(lua_State *L) THTensor_(free)(tensor); THError("invalid element (not a number)"); } - THStorage_(set)(THTensor_(storage)(tensor), si++, (real)lua_tonumber(L, -1)); + THStorage_(set)(THTensor_(storage)(tensor), si++, LUA_NUMBER_TO_REAL(lua_tonumber(L, -1))); lua_pop(L, 1); } @@ -675,6 +675,8 @@ static int torch_Tensor_(copy)(lua_State *L) THTensor_(copyFloat)(tensor, src); else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) ) THTensor_(copyDouble)(tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.HalfTensor")) ) + THTensor_(copyHalf)(tensor, src); else luaL_typerror(L, 2, "torch.*Tensor"); lua_settop(L, 1); @@ -745,6 +747,11 @@ static int torch_Tensor_(__newindex__)(lua_State *L) THTensor_(narrow)(tensor, NULL, 0, index, 1); THTensor_(copyDouble)(tensor, src); THTensor_(free)(tensor); + } else if( (src = luaT_toudata(L, 3, "torch.HalfTensor")) ) { + tensor = THTensor_(newWithTensor)(tensor); + THTensor_(narrow)(tensor, NULL, 0, index, 1); + THTensor_(copyHalf)(tensor, src); + THTensor_(free)(tensor); } else { luaL_typerror(L, 3, "torch.*Tensor"); } @@ -829,7 +836,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L) /* doing a copy */ void *src; if (lua_isnumber(L,3)) { - THTensor_(fill)(tensor, lua_tonumber(L,3)); + THTensor_(fill)(tensor, LUA_NUMBER_TO_REAL(lua_tonumber(L,3))); } else if( (src = luaT_toudata(L, 3, torch_Tensor)) ) { THTensor_(copy)(tensor, src); } else if( (src = luaT_toudata(L, 3, "torch.ByteTensor")) ) { @@ -846,6 +853,8 @@ static int torch_Tensor_(__newindex__)(lua_State *L) THTensor_(copyFloat)(tensor, src); } else if( (src = luaT_toudata(L, 3, "torch.DoubleTensor")) ) { THTensor_(copyDouble)(tensor, src); + } else if( (src = luaT_toudata(L, 3, "torch.HalfTensor")) ) { + THTensor_(copyHalf)(tensor, src); } else { luaL_typerror(L, 3, "torch.*Tensor"); } @@ -916,7 +925,7 @@ static int torch_Tensor_(__index__)(lua_State *L) THArgCheck((z >= 0) && (z < tensor->size[dim]), 2, "index out of bound"); index += z*tensor->stride[dim]; } - luaG_(pushreal)(L, (double)THStorage_(get)(THTensor_(storage)(tensor), index)); + luaG_(pushreal)(L, THStorage_(get)(THTensor_(storage)(tensor), index)); lua_pushboolean(L, 1); return 2; } @@ -1143,7 +1152,7 @@ static int torch_Tensor_(apply)(lua_State *L) lua_call(L, 1, 1); if(lua_isnumber(L, 3)) { - *tensor_data = (real)lua_tonumber(L, 3); + *tensor_data = LUA_NUMBER_TO_REAL(lua_tonumber(L, 3)); lua_pop(L, 1); } else if(lua_isnil(L, 3)) @@ -1169,7 +1178,7 @@ static int torch_Tensor_(map)(lua_State *L) lua_call(L, 2, 1); if(lua_isnumber(L, 4)) { - *tensor_data = (real)lua_tonumber(L, 4); + *tensor_data = LUA_NUMBER_TO_REAL(lua_tonumber(L, 4)); lua_pop(L, 1); } else if(lua_isnil(L, 4)) @@ -1197,7 +1206,7 @@ static int torch_Tensor_(map2)(lua_State *L) lua_call(L, 3, 1); if(lua_isnumber(L, 5)) { - *tensor_data = (real)lua_tonumber(L, 5); + *tensor_data = LUA_NUMBER_TO_REAL(lua_tonumber(L, 5)); lua_pop(L, 1); } else if(lua_isnil(L, 5)) @@ -1318,7 +1327,10 @@ void torch_Tensor_(init)(lua_State *L) torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory)); luaT_setfuncs(L, torch_Tensor_(_), 0); lua_pop(L, 1); +#ifndef TH_GENERIC_NO_MATH THVector_(vectorDispatchInit)(); +#endif + } #endif diff --git a/generic/TensorOperator.c b/generic/TensorOperator.c index eba6c81..c722e88 100644 --- a/generic/TensorOperator.c +++ b/generic/TensorOperator.c @@ -2,6 +2,9 @@ #define TH_GENERIC_FILE "generic/TensorOperator.c" #else +/* Tensor math may be disabled for certain types, e.g. 'half' */ +#ifndef TH_GENERIC_NO_MATH + static int torch_TensorOperator_(__add__)(lua_State *L) { THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor); @@ -14,7 +17,7 @@ static int torch_TensorOperator_(__add__)(lua_State *L) { r = THTensor_(new)(); luaT_pushudata(L, r, torch_Tensor); - + if(!tensor1 && tensor2) { THTensor_(resizeAs)(r, tensor2); @@ -49,7 +52,7 @@ static int torch_TensorOperator_(__sub__)(lua_State *L) { r = THTensor_(new)(); luaT_pushudata(L, r, torch_Tensor); - + if(!tensor1 && tensor2) { THTensor_(resizeAs)(r, tensor2); @@ -98,7 +101,7 @@ static int torch_TensorOperator_(__mul__)(lua_State *L) { r = THTensor_(new)(); luaT_pushudata(L, r, torch_Tensor); - + if(!tensor1 && tensor2) { THTensor_(resizeAs)(r, tensor2); @@ -115,7 +118,7 @@ static int torch_TensorOperator_(__mul__)(lua_State *L) { int dimt = tensor1->nDimension; int dims = tensor2->nDimension; - + if(dimt == 1 && dims == 1) lua_pushnumber(L, THTensor_(dot)(tensor1, tensor2)); /* ok, we wasted r, but who cares */ else if(dimt == 2 && dims == 1) @@ -131,7 +134,7 @@ static int torch_TensorOperator_(__mul__)(lua_State *L) THTensor_(addmm)(r, 1, r, 1, tensor1, tensor2); } else - luaL_error(L, "multiplication between %dD and %dD tensors not yet supported", tensor1->nDimension, tensor2->nDimension); + luaL_error(L, "multiplication between %dD and %dD tensors not yet supported", tensor1->nDimension, tensor2->nDimension); } } return 1; @@ -187,5 +190,6 @@ void torch_TensorOperator_(init)(lua_State *L) luaT_setfuncs(L, torch_TensorOperator_(_), 0); lua_pop(L, 1); } +#endif #endif diff --git a/generic/luaG.h b/generic/luaG.h index 950eae9..f1ffce2 100644 --- a/generic/luaG.h +++ b/generic/luaG.h @@ -6,10 +6,24 @@ #define luaG_(NAME) TH_CONCAT_3(luaG_,Real,NAME) #endif -static void luaG_(pushreal)(lua_State *L, accreal n) { -#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || LUA_VERSION_NUM < 503 - lua_pushnumber(L, (lua_Number)n); -#elif defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_SHORT) || defined(TH_REAL_IS_INT) || defined(TH_REAL_IS_LONG) +#undef REAL_TO_LUA_NUMBER +#undef LUA_NUMBER_TO_REAL + +#if defined(TH_REAL_IS_HALF) +# define REAL_TO_LUA_NUMBER(n) (lua_Number)TH_half2float(n) +# define LUA_NUMBER_TO_REAL(n) TH_float2half((lua_Number)n) +#else +# define REAL_TO_LUA_NUMBER(n) (lua_Number)(n) +# define LUA_NUMBER_TO_REAL(n) (real)n +#endif + + + +static void luaG_(pushreal)(lua_State *L, real n) { +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) || LUA_VERSION_NUM < 503 + lua_pushnumber(L, REAL_TO_LUA_NUMBER(n)); +#elif defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_SHORT) \ + || defined(TH_REAL_IS_INT) || defined(TH_REAL_IS_LONG) lua_pushinteger(L, (lua_Integer)n); #else #error "unhandled real type in luaG_pushreal" @@ -17,8 +31,8 @@ static void luaG_(pushreal)(lua_State *L, accreal n) { } static real luaG_(checkreal)(lua_State *L, int idx) { -#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) - return (lua_Number)luaL_checknumber(L, idx); +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) + return LUA_NUMBER_TO_REAL(luaL_checknumber(L, idx)); #elif defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_SHORT) || defined(TH_REAL_IS_INT) || defined(TH_REAL_IS_LONG) int type = lua_type(L, idx); if (type == LUA_TSTRING) { @@ -38,8 +52,8 @@ static real luaG_(checkreal)(lua_State *L, int idx) { } static real luaG_(optreal)(lua_State *L, int idx, real n) { -#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || LUA_VERSION_NUM < 503 - return (lua_Number)luaL_optnumber(L, idx, (lua_Number)n); +#if defined(TH_REAL_IS_HALF) || defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || LUA_VERSION_NUM < 503 + return LUA_NUMBER_TO_REAL(luaL_optnumber(L, idx, REAL_TO_LUA_NUMBER(n))); #elif defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_SHORT) || defined(TH_REAL_IS_INT) || defined(TH_REAL_IS_LONG) return (lua_Integer)luaL_optinteger(L, idx, (lua_Integer)n); #else @@ -16,6 +16,7 @@ extern void torch_IntStorage_init(lua_State *L); extern void torch_LongStorage_init(lua_State *L); extern void torch_FloatStorage_init(lua_State *L); extern void torch_DoubleStorage_init(lua_State *L); +extern void torch_HalfStorage_init(lua_State *L); extern void torch_ByteTensor_init(lua_State *L); extern void torch_CharTensor_init(lua_State *L); @@ -24,6 +25,7 @@ extern void torch_IntTensor_init(lua_State *L); extern void torch_LongTensor_init(lua_State *L); extern void torch_FloatTensor_init(lua_State *L); extern void torch_DoubleTensor_init(lua_State *L); +extern void torch_HalfTensor_init(lua_State *L); extern void torch_ByteTensorOperator_init(lua_State *L); extern void torch_CharTensorOperator_init(lua_State *L); @@ -33,8 +35,29 @@ extern void torch_LongTensorOperator_init(lua_State *L); extern void torch_FloatTensorOperator_init(lua_State *L); extern void torch_DoubleTensorOperator_init(lua_State *L); +#if TH_NATIVE_HALF +extern void torch_HalfTensorOperator_init(lua_State *L); +#endif + extern void torch_TensorMath_init(lua_State *L); +static int torch_hashalfmath(lua_State *L) { + lua_pushboolean(L, TH_NATIVE_HALF); + return 1; +} + +static void torch_half_init(lua_State *L) +{ + const struct luaL_Reg half_funcs__ [] = { + {"hashalfmath", torch_hashalfmath}, + {NULL, NULL} + }; + luaT_setfuncs(L, half_funcs__, 0); + + lua_pushboolean(L, 1); + lua_setfield(L, -2, "hasHalf"); +} + LUA_EXTERNC DLL_EXPORT int luaopen_libtorch(lua_State *L); int luaopen_libtorch(lua_State *L) @@ -45,8 +68,8 @@ int luaopen_libtorch(lua_State *L) lua_setglobal(L, "torch"); torch_utils_init(L); - torch_File_init(L); + torch_half_init(L); torch_ByteStorage_init(L); torch_CharStorage_init(L); @@ -55,6 +78,7 @@ int luaopen_libtorch(lua_State *L) torch_LongStorage_init(L); torch_FloatStorage_init(L); torch_DoubleStorage_init(L); + torch_HalfStorage_init(L); torch_ByteTensor_init(L); torch_CharTensor_init(L); @@ -63,6 +87,7 @@ int luaopen_libtorch(lua_State *L) torch_LongTensor_init(L); torch_FloatTensor_init(L); torch_DoubleTensor_init(L); + torch_HalfTensor_init(L); torch_ByteTensorOperator_init(L); torch_CharTensorOperator_init(L); @@ -72,6 +97,10 @@ int luaopen_libtorch(lua_State *L) torch_FloatTensorOperator_init(L); torch_DoubleTensorOperator_init(L); +#if TH_NATIVE_HALF + torch_HalfTensorOperator_init(L); +#endif + torch_Timer_init(L); torch_DiskFile_init(L); torch_PipeFile_init(L); diff --git a/lib/TH/CMakeLists.txt b/lib/TH/CMakeLists.txt index e6cf91d..3912bf2 100644 --- a/lib/TH/CMakeLists.txt +++ b/lib/TH/CMakeLists.txt @@ -122,11 +122,11 @@ IF(C_AVX_FOUND) ENDIF(C_AVX_FOUND) SET(hdr - THGeneral.h THAllocator.h THStorage.h THTensor.h THTensorApply.h THBlas.h THMath.h - THLapack.h THLogAdd.h THRandom.h THVector.h THAtomic.h) + THGeneral.h THHalf.h THAllocator.h THStorage.h THTensor.h THTensorApply.h THBlas.h THMath.h + THLapack.h THLogAdd.h THRandom.h THVector.h THAtomic.h ) SET(src - THGeneral.c THAllocator.c THStorage.c THTensor.c THBlas.c THLapack.c + THGeneral.c THHalf.c THAllocator.c THStorage.c THTensor.c THBlas.c THLapack.c THLogAdd.c THRandom.c THFile.c THDiskFile.c THMemoryFile.c THAtomic.c THVector.c) SET(src ${src} ${hdr} ${simd}) @@ -333,6 +333,7 @@ INSTALL(FILES THTensorMacros.h THVector.h THAtomic.h + THHalf.h DESTINATION "${TH_INSTALL_INCLUDE_SUBDIR}/TH") INSTALL(FILES diff --git a/lib/TH/TH.h b/lib/TH/TH.h index cdf331d..8b676cf 100644 --- a/lib/TH/TH.h +++ b/lib/TH/TH.h @@ -3,6 +3,10 @@ #include "THGeneral.h" +#if TH_GENERIC_USE_HALF +# include "THHalf.h" +#endif + #include "THBlas.h" #ifdef USE_LAPACK #include "THLapack.h" diff --git a/lib/TH/THDiskFile.c b/lib/TH/THDiskFile.c index 9d9cbae..a397027 100644 --- a/lib/TH/THDiskFile.c +++ b/lib/TH/THDiskFile.c @@ -355,7 +355,11 @@ READ_WRITE_METHODS(int, Int, READ_WRITE_METHODS(float, Float, int ret = fscanf(dfself->handle, "%g", &data[i]); if(ret <= 0) break; else nread++, int ret = fprintf(dfself->handle, "%.9g", data[i]); if(ret <= 0) break; else nwrite++) - +#if TH_GENERIC_USE_HALF +READ_WRITE_METHODS(half, Half, + float buf; int ret = fscanf(dfself->handle, "%g", &buf); if(ret <= 0) break; else { data[i]= TH_float2half(buf); nread++; }, + int ret = fprintf(dfself->handle, "%.9g", TH_half2float(data[i])); if(ret <= 0) break; else nwrite++) +#endif READ_WRITE_METHODS(double, Double, int ret = fscanf(dfself->handle, "%lg", &data[i]); if(ret <= 0) break; else nread++, int ret = fprintf(dfself->handle, "%.17g", data[i]); if(ret <= 0) break; else nwrite++) @@ -618,6 +622,7 @@ THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet) THDiskFile_readLong, THDiskFile_readFloat, THDiskFile_readDouble, + THDiskFile_readHalf, THDiskFile_readString, THDiskFile_writeByte, @@ -627,6 +632,7 @@ THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet) THDiskFile_writeLong, THDiskFile_writeFloat, THDiskFile_writeDouble, + THDiskFile_writeHalf, THDiskFile_writeString, THDiskFile_synchronize, @@ -730,6 +736,7 @@ THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet) THDiskFile_readLong, THDiskFile_readFloat, THDiskFile_readDouble, + THDiskFile_readHalf, THDiskFile_readString, THDiskFile_writeByte, @@ -739,6 +746,7 @@ THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet) THDiskFile_writeLong, THDiskFile_writeFloat, THDiskFile_writeDouble, + THDiskFile_writeHalf, THDiskFile_writeString, THDiskFile_synchronize, diff --git a/lib/TH/THFile.c b/lib/TH/THFile.c index c8913af..f2bc8d7 100644 --- a/lib/TH/THFile.c +++ b/lib/TH/THFile.c @@ -19,6 +19,9 @@ IMPLEMENT_THFILE_RW(Int, int) IMPLEMENT_THFILE_RW(Long, long) IMPLEMENT_THFILE_RW(Float, float) IMPLEMENT_THFILE_RW(Double, double) +#if TH_GENERIC_USE_HALF +IMPLEMENT_THFILE_RW(Half, half) +#endif size_t THFile_readStringRaw(THFile *self, const char *format, char **str_) { @@ -133,6 +136,7 @@ IMPLEMENT_THFILE_SCALAR(Int, int) IMPLEMENT_THFILE_SCALAR(Long, long) IMPLEMENT_THFILE_SCALAR(Float, float) IMPLEMENT_THFILE_SCALAR(Double, double) +IMPLEMENT_THFILE_SCALAR(Half, half) #define IMPLEMENT_THFILE_STORAGE(TYPEC, TYPE) \ size_t THFile_read##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \ @@ -152,3 +156,4 @@ IMPLEMENT_THFILE_STORAGE(Int, int) IMPLEMENT_THFILE_STORAGE(Long, long) IMPLEMENT_THFILE_STORAGE(Float, float) IMPLEMENT_THFILE_STORAGE(Double, double) +IMPLEMENT_THFILE_STORAGE(Half, half) diff --git a/lib/TH/THFile.h b/lib/TH/THFile.h index 64dd2da..7727196 100644 --- a/lib/TH/THFile.h +++ b/lib/TH/THFile.h @@ -74,6 +74,15 @@ TH_API size_t THFile_writeFloatRaw(THFile *self, float *data, size_t n); TH_API size_t THFile_writeDoubleRaw(THFile *self, double *data, size_t n); TH_API size_t THFile_writeStringRaw(THFile *self, const char *str, size_t size); +#if TH_GENERIC_USE_HALF +TH_API half THFile_readHalfScalar(THFile *self); +TH_API void THFile_writeHalfScalar(THFile *self, half scalar); +TH_API size_t THFile_readHalf(THFile *self, THHalfStorage *storage); +TH_API size_t THFile_writeHalf(THFile *self, THHalfStorage *storage); +TH_API size_t THFile_readHalfRaw(THFile *self, half* data, size_t size); +TH_API size_t THFile_writeHalfRaw(THFile *self, half* data, size_t size); +#endif + TH_API void THFile_synchronize(THFile *self); TH_API void THFile_seek(THFile *self, size_t position); TH_API void THFile_seekEnd(THFile *self); diff --git a/lib/TH/THFilePrivate.h b/lib/TH/THFilePrivate.h index d268041..f0b0885 100644 --- a/lib/TH/THFilePrivate.h +++ b/lib/TH/THFilePrivate.h @@ -1,3 +1,10 @@ +#include "THGeneral.h" + +#if TH_GENERIC_USE_HALF +# include "THHalf.h" +#endif + + struct THFile__ { struct THFileVTable *vtable; @@ -23,6 +30,9 @@ struct THFileVTable size_t (*readLong)(THFile *self, long *data, size_t n); size_t (*readFloat)(THFile *self, float *data, size_t n); size_t (*readDouble)(THFile *self, double *data, size_t n); +#if TH_GENERIC_USE_HALF + size_t (*readHalf)(THFile *self, half *data, size_t n); +#endif size_t (*readString)(THFile *self, const char *format, char **str_); size_t (*writeByte)(THFile *self, unsigned char *data, size_t n); @@ -32,6 +42,9 @@ struct THFileVTable size_t (*writeLong)(THFile *self, long *data, size_t n); size_t (*writeFloat)(THFile *self, float *data, size_t n); size_t (*writeDouble)(THFile *self, double *data, size_t n); +#if TH_GENERIC_USE_HALF + size_t (*writeHalf)(THFile *self, half *data, size_t n); +#endif size_t (*writeString)(THFile *self, const char *str, size_t size); void (*synchronize)(THFile *self); diff --git a/lib/TH/THGeneral.h.in b/lib/TH/THGeneral.h.in index bc7e448..6873db6 100644 --- a/lib/TH/THGeneral.h.in +++ b/lib/TH/THGeneral.h.in @@ -46,6 +46,14 @@ #define TH_INDEX_BASE 1 #endif +#ifndef TH_GENERIC_USE_HALF +# define TH_GENERIC_USE_HALF 0 +#endif + +#ifndef TH_NATIVE_HALF +# define TH_NATIVE_HALF 0 +#endif + typedef void (*THErrorHandlerFunction)(const char *msg, void *data); typedef void (*THArgErrorHandlerFunction)(int argNumber, const char *msg, void *data); diff --git a/lib/TH/THGenerateAllTypes.h b/lib/TH/THGenerateAllTypes.h index 539629b..093bfd9 100644 --- a/lib/TH/THGenerateAllTypes.h +++ b/lib/TH/THGenerateAllTypes.h @@ -2,6 +2,16 @@ #error "You must define TH_GENERIC_FILE before including THGenerateAllTypes.h" #endif +#define THTypeIdxByte 1 +#define THTypeIdxChar 2 +#define THTypeIdxShort 3 +#define THTypeIdxInt 4 +#define THTypeIdxLong 5 +#define THTypeIdxFloat 6 +#define THTypeIdxDouble 7 +#define THTypeIdxHalf 8 +#define THTypeIdx_(T) TH_CONCAT_2(THTypeIdx,T) + #define real unsigned char #define accreal long #define Real Byte @@ -94,4 +104,24 @@ #undef THInf #undef TH_REAL_IS_DOUBLE +#if TH_GENERIC_USE_HALF +#include "THHalf.h" +#define real half +#define accreal float +#define Real Half +#define THInf TH_HALF_MAX +#define TH_REAL_IS_HALF +#if !TH_NATIVE_HALF +# define TH_GENERIC_NO_MATH 1 +#endif +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef real +#undef accreal +#undef Real +#undef THInf +#undef TH_REAL_IS_HALF +#undef TH_GENERIC_NO_MATH +#endif + #undef TH_GENERIC_FILE diff --git a/lib/TH/THHalf.c b/lib/TH/THHalf.c new file mode 100644 index 0000000..fbc3e15 --- /dev/null +++ b/lib/TH/THHalf.c @@ -0,0 +1,141 @@ +#include "THHalf.h" +#include "TH.h" + +static half half_max = TH_HALF_MAX; + +/* + * Copyright 1993-2014 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +// Host functions for converting between FP32 and FP16 formats +// Paulius Micikevicius (pauliusm@nvidia.com) + +float TH_half2float(half h) +{ + unsigned sign = ((h.x >> 15) & 1); + unsigned exponent = ((h.x >> 10) & 0x1f); + unsigned mantissa = ((h.x & 0x3ff) << 13); + + if (exponent == 0x1f) { /* NaN or Inf */ + mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); + exponent = 0xff; + } else if (!exponent) { /* Denorm or Zero */ + if (mantissa) { + unsigned int msb; + exponent = 0x71; + do { + msb = (mantissa & 0x400000); + mantissa <<= 1; /* normalize */ + --exponent; + } while (!msb); + mantissa &= 0x7fffff; /* 1.mantissa is implicit */ + } + } else { + exponent += 0x70; + } + + int temp = ((sign << 31) | (exponent << 23) | mantissa); + + return *((float*)((void*)&temp)); +} + +half TH_float2half(float f) +{ + half ret; + + unsigned x = *((int*)(void*)(&f)); + unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; + unsigned sign, exponent, mantissa; + + // Get rid of +NaN/-NaN case first. + if (u > 0x7f800000) { + ret.x = 0x7fffU; + return ret; + } + + sign = ((x >> 16) & 0x8000); + + // Get rid of +Inf/-Inf, +0/-0. + if (u > 0x477fefff) { + ret.x = sign | 0x7c00U; + return ret; + } + if (u < 0x33000001) { + ret.x = (sign | 0x0000); + return ret; + } + + exponent = ((u >> 23) & 0xff); + mantissa = (u & 0x7fffff); + + if (exponent > 0x70) { + shift = 13; + exponent -= 0x70; + } else { + shift = 0x7e - exponent; + exponent = 0; + mantissa |= 0x800000; + } + lsb = (1 << shift); + lsb_s1 = (lsb >> 1); + lsb_m1 = (lsb - 1); + + // Round to nearest even. + remainder = (mantissa & lsb_m1); + mantissa >>= shift; + if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { + ++mantissa; + if (!(mantissa & 0x3ff)) { + ++exponent; + mantissa = 0; + } + } + + ret.x = (sign | (exponent << 10) | mantissa); + return ret; +} diff --git a/lib/TH/THHalf.h b/lib/TH/THHalf.h new file mode 100644 index 0000000..01dab4e --- /dev/null +++ b/lib/TH/THHalf.h @@ -0,0 +1,56 @@ +#ifndef _THHALF_H +# define _THHALF_H + +#include <stdint.h> + +#include "THGeneral.h" + +# if defined (TH_HALF_TYPE) +typedef TH_HALF_TYPE half; +# else +/* Neither built-in nor included from Cutorch, use our definition lifted from CUDA */ +#if defined(__GNUC__) +#define __align__(n) __attribute__((aligned(n))) +#elif defined(_WIN32) +#define __align__(n) __declspec(align(n)) +#else +#define __align__(n) +#endif + +typedef struct __align__(2){ + unsigned short x; +} __half; + +typedef struct __align__(4) { + unsigned int x; +} __half2; + +typedef __half half; +typedef __half2 half2; +# endif + +/* numeric limits */ + + +TH_API half TH_float2half(float a); +TH_API float TH_half2float(half a); + +#ifndef TH_HALF_BITS_TO_LITERAL +# define TH_HALF_BITS_TO_LITERAL(n) { n } +#endif + +#define TH_HALF_ZERO TH_HALF_BITS_TO_LITERAL(0x0) +#define TH_HALF_MIN TH_HALF_BITS_TO_LITERAL(0x0400) +#define TH_HALF_MAX TH_HALF_BITS_TO_LITERAL(0x7BFF) +#define TH_HALF_EPSILON TH_HALF_BITS_TO_LITERAL(0x1400) +#define TH_HALF_INF TH_HALF_BITS_TO_LITERAL(0x7C00) +#define TH_HALF_QNAN TH_HALF_BITS_TO_LITERAL(0x7FFF) +#define TH_HALF_SNAN TH_HALF_BITS_TO_LITERAL(0x7DFF) +#define TH_HALF_DENORM_MIN TH_HALF_BITS_TO_LITERAL(0x0001) +#define TH_HALF_DIGITS 11 +#define TH_HALF_DIGITS10 3 +#define TH_HALF_DIGITS10_MAX 5 +#define TH_HALF_MAX_EXPONENT 16 +#define TH_HALF_MAX_EXPONENT10 4 + +#endif diff --git a/lib/TH/THMemoryFile.c b/lib/TH/THMemoryFile.c index 8d97621..9079d6c 100644 --- a/lib/TH/THMemoryFile.c +++ b/lib/TH/THMemoryFile.c @@ -341,7 +341,14 @@ READ_WRITE_METHODS(float, Float, int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%g%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%.9g", data[i]), 1) - +#if TH_GENERIC_USE_HALF +READ_WRITE_METHODS(half, Half, + int nByteRead_; float buf; \ + int ret = sscanf(mfself->storage->data+mfself->position, "%g%n", &buf, &nByteRead_); \ + data[i] = TH_float2half(buf); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, + nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%.9g", TH_half2float(data[i])), + 1) +#endif READ_WRITE_METHODS(double, Double, int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%lg%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%.17g", data[i]), @@ -621,6 +628,7 @@ THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode) THMemoryFile_readLong, THMemoryFile_readFloat, THMemoryFile_readDouble, + THMemoryFile_readHalf, THMemoryFile_readString, THMemoryFile_writeByte, @@ -630,6 +638,7 @@ THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode) THMemoryFile_writeLong, THMemoryFile_writeFloat, THMemoryFile_writeDouble, + THMemoryFile_writeHalf, THMemoryFile_writeString, THMemoryFile_synchronize, diff --git a/lib/TH/THTensor.c b/lib/TH/THTensor.c index 2878fc9..6305668 100644 --- a/lib/TH/THTensor.c +++ b/lib/TH/THTensor.c @@ -14,10 +14,10 @@ #include "generic/THTensorCopy.c" #include "THGenerateAllTypes.h" -#include "generic/THTensorRandom.c" +#include "generic/THTensorMath.c" #include "THGenerateAllTypes.h" -#include "generic/THTensorMath.c" +#include "generic/THTensorRandom.c" #include "THGenerateAllTypes.h" #include "generic/THTensorConv.c" diff --git a/lib/TH/THVector.c b/lib/TH/THVector.c index 6179d89..f530a84 100644 --- a/lib/TH/THVector.c +++ b/lib/TH/THVector.c @@ -1,4 +1,5 @@ #include "THVector.h" + #include "generic/simd/simd.h" #ifdef __NEON__ diff --git a/lib/TH/generic/THBlas.c b/lib/TH/generic/THBlas.c index 6452f94..8b3a403 100644 --- a/lib/TH/generic/THBlas.c +++ b/lib/TH/generic/THBlas.c @@ -2,6 +2,8 @@ #define TH_GENERIC_FILE "generic/THBlas.c" #else +# ifndef TH_GENERIC_NO_MATH + #ifdef BLAS_F2C # define ffloat double #else @@ -24,8 +26,8 @@ TH_EXTERNC void dger_(int *m, int *n, double *alpha, double *x, int *incx, doubl TH_EXTERNC void sger_(int *m, int *n, float *alpha, float *x, int *incx, float *y, int *incy, float *a, int *lda); TH_EXTERNC void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, double *a, int *lda, double *b, int *ldb, double *beta, double *c, int *ldc); TH_EXTERNC void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, float *a, int *lda, float *b, int *ldb, float *beta, float *c, int *ldc); - - + + void THBlas_(swap)(long n, real *x, long incx, real *y, long incy) { @@ -182,9 +184,9 @@ void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, re { if(n == 1) lda = m; - + #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) - if( (m <= INT_MAX) && (n <= INT_MAX) && + if( (m <= INT_MAX) && (n <= INT_MAX) && (lda > 0) && (lda <= INT_MAX) && (incx > 0) && (incx <= INT_MAX) && (incy > 0) && (incy <= INT_MAX) ) @@ -224,7 +226,7 @@ void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, re { if(beta != 1) THBlas_(scal)(m, beta, y, incy); - + for(j = 0; j < n; j++) { real *column_ = a+lda*j; @@ -402,5 +404,5 @@ void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha, } } } - +# endif /* TH_GENERIC_NO_MATH */ #endif diff --git a/lib/TH/generic/THBlas.h b/lib/TH/generic/THBlas.h index 9e14f5a..a49d79c 100644 --- a/lib/TH/generic/THBlas.h +++ b/lib/TH/generic/THBlas.h @@ -1,7 +1,7 @@ #ifndef TH_GENERIC_FILE #define TH_GENERIC_FILE "generic/THBlas.h" #else - +# ifndef TH_GENERIC_NO_MATH /* Level 1 */ TH_API void THBlas_(swap)(long n, real *x, long incx, real *y, long incy); TH_API void THBlas_(scal)(long n, real a, real *x, long incx); @@ -15,5 +15,5 @@ TH_API void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y /* Level 3 */ TH_API void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha, real *a, long lda, real *b, long ldb, real beta, real *c, long ldc); - +# endif #endif diff --git a/lib/TH/generic/THStorageCopy.c b/lib/TH/generic/THStorageCopy.c index 583e088..a0b2c95 100644 --- a/lib/TH/generic/THStorageCopy.c +++ b/lib/TH/generic/THStorageCopy.c @@ -15,16 +15,38 @@ void THStorage_(copy)(THStorage *storage, THStorage *src) THStorage_(rawCopy)(storage, src->data); } - #define IMPLEMENT_THStorage_COPY(TYPENAMESRC) \ void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \ { \ - ptrdiff_t i; \ THArgCheck(storage->size == src->size, 2, "size mismatch"); \ - for(i = 0; i < storage->size; i++) \ - storage->data[i] = (real)src->data[i]; \ + if(THTypeIdx_(Real) == THTypeIdx_(TYPENAMESRC)) { \ + memcpy(storage->data, src->data, sizeof(real)*storage->size); /* cast just removes compiler warning */ \ + } else { \ + ptrdiff_t i; \ + for(i = 0; i < storage->size; i++) \ + storage->data[i] = (real)src->data[i]; \ + } \ +} + +#define IMPLEMENT_THStorage_COPY_FROM_HALF(TYPENAMESRC) \ +void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \ +{ \ + THArgCheck(storage->size == src->size, 2, "size mismatch"); \ + ptrdiff_t i; \ + for(i = 0; i < storage->size; i++) \ + storage->data[i] = (real)TH_half2float(src->data[i]); \ } +#define IMPLEMENT_THStorage_COPY_TO_HALF(TYPENAMESRC) \ +void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \ +{ \ + THArgCheck(storage->size == src->size, 2, "size mismatch"); \ + ptrdiff_t i; \ + for(i = 0; i < storage->size; i++) \ + storage->data[i] = TH_float2half((float)(src->data[i])); \ +} + +#ifndef TH_REAL_IS_HALF IMPLEMENT_THStorage_COPY(Byte) IMPLEMENT_THStorage_COPY(Char) IMPLEMENT_THStorage_COPY(Short) @@ -32,5 +54,20 @@ IMPLEMENT_THStorage_COPY(Int) IMPLEMENT_THStorage_COPY(Long) IMPLEMENT_THStorage_COPY(Float) IMPLEMENT_THStorage_COPY(Double) +#if TH_GENERIC_USE_HALF +IMPLEMENT_THStorage_COPY_FROM_HALF(Half) +#endif +#else +/* only allow pass-through for Half */ +IMPLEMENT_THStorage_COPY(Half) +IMPLEMENT_THStorage_COPY_TO_HALF(Byte) +IMPLEMENT_THStorage_COPY_TO_HALF(Char) +IMPLEMENT_THStorage_COPY_TO_HALF(Short) +IMPLEMENT_THStorage_COPY_TO_HALF(Int) +IMPLEMENT_THStorage_COPY_TO_HALF(Long) +IMPLEMENT_THStorage_COPY_TO_HALF(Float) +IMPLEMENT_THStorage_COPY_TO_HALF(Double) +#endif + #endif diff --git a/lib/TH/generic/THStorageCopy.h b/lib/TH/generic/THStorageCopy.h index f853a82..bb2f406 100644 --- a/lib/TH/generic/THStorageCopy.h +++ b/lib/TH/generic/THStorageCopy.h @@ -13,5 +13,8 @@ TH_API void THStorage_(copyInt)(THStorage *storage, struct THIntStorage *src); TH_API void THStorage_(copyLong)(THStorage *storage, struct THLongStorage *src); TH_API void THStorage_(copyFloat)(THStorage *storage, struct THFloatStorage *src); TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *src); +#if TH_GENERIC_USE_HALF +TH_API void THStorage_(copyHalf)(THStorage *storage, struct THHalfStorage *src); +#endif #endif diff --git a/lib/TH/generic/THTensorConv.c b/lib/TH/generic/THTensorConv.c index d98a2aa..aa864ad 100644 --- a/lib/TH/generic/THTensorConv.c +++ b/lib/TH/generic/THTensorConv.c @@ -2,6 +2,8 @@ #define TH_GENERIC_FILE "generic/THTensorConv.c" #else +/* Tensor math may be disabled for certain types, e.g. 'TH_half' */ +#ifndef TH_GENERIC_NO_MATH /* 2D Input, 2D kernel : convolve given image with the given kernel. @@ -775,7 +777,7 @@ void THTensor_(conv2DRevgerm)(THTensor *r_, real beta, real alpha, THTensor *t_, real *ptr_output = output_data + k*nInputPlane*nOutputCols*nOutputRows + i*nOutputCols*nOutputRows; /* get input */ real *ptr_input = input_data + p*istride0 + i*istride1; - + /* do image, kernel convolution */ THTensor_(validXCorr2DRevptr)(ptr_output, alpha, @@ -1174,7 +1176,7 @@ void THTensor_(conv2Dmm)(THTensor *r_, real beta, real alpha, THTensor *t_, THTe real *ptr_weight = weight_data + k*kstride0 + i*kstride1; /* get input */ real *ptr_input = input_data + p*nInputPlane*nInputRows*nInputCols + i*nInputRows*nInputCols; - + /* do image, kernel convolution */ if (*vf == 'F') if (*xc == 'X') @@ -1955,5 +1957,5 @@ void THTensor_(conv3Dmap)(THTensor *r_, real beta, real alpha, THTensor *t_, THT THTensor_(free)(input); THTensor_(free)(kernel); } - +#endif #endif diff --git a/lib/TH/generic/THTensorCopy.c b/lib/TH/generic/THTensorCopy.c index ea6d6f1..6c7daa8 100644 --- a/lib/TH/generic/THTensorCopy.c +++ b/lib/TH/generic/THTensorCopy.c @@ -13,6 +13,19 @@ void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = (real)(*src_data);) \ } +#define IMPLEMENT_THTensor_COPY_TO_HALF(TYPENAMESRC, TYPE_SRC) \ +void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src) \ +{ \ + TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = TH_float2half((float)*src_data);) \ +} + +#define IMPLEMENT_THTensor_COPY_FROM_HALF(TYPENAMESRC, TYPE_SRC) \ +void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src) \ +{ \ + TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = (real)TH_half2float(*src_data);) \ +} + +#ifndef TH_REAL_IS_HALF IMPLEMENT_THTensor_COPY(Byte, unsigned char) IMPLEMENT_THTensor_COPY(Char, char) IMPLEMENT_THTensor_COPY(Short, short) @@ -20,5 +33,20 @@ IMPLEMENT_THTensor_COPY(Int, int) IMPLEMENT_THTensor_COPY(Long, long) IMPLEMENT_THTensor_COPY(Float, float) IMPLEMENT_THTensor_COPY(Double, double) +#if TH_GENERIC_USE_HALF +IMPLEMENT_THTensor_COPY_FROM_HALF(Half, half) +#endif +#else +/* only allow pass-through for Half */ +IMPLEMENT_THTensor_COPY(Half, half) +IMPLEMENT_THTensor_COPY_TO_HALF(Byte, unsigned char) +IMPLEMENT_THTensor_COPY_TO_HALF(Char, char) +IMPLEMENT_THTensor_COPY_TO_HALF(Short, short) +IMPLEMENT_THTensor_COPY_TO_HALF(Int, int) +IMPLEMENT_THTensor_COPY_TO_HALF(Long, long) +IMPLEMENT_THTensor_COPY_TO_HALF(Float, float) +IMPLEMENT_THTensor_COPY_TO_HALF(Double, double) + +#endif /* REAL_IS_HALF */ #endif diff --git a/lib/TH/generic/THTensorCopy.h b/lib/TH/generic/THTensorCopy.h index 8d03b22..8a0abcf 100644 --- a/lib/TH/generic/THTensorCopy.h +++ b/lib/TH/generic/THTensorCopy.h @@ -12,5 +12,7 @@ TH_API void THTensor_(copyInt)(THTensor *tensor, struct THIntTensor *src); TH_API void THTensor_(copyLong)(THTensor *tensor, struct THLongTensor *src); TH_API void THTensor_(copyFloat)(THTensor *tensor, struct THFloatTensor *src); TH_API void THTensor_(copyDouble)(THTensor *tensor, struct THDoubleTensor *src); - +#if TH_GENERIC_USE_HALF +TH_API void THTensor_(copyHalf)(THTensor *tensor, struct THHalfTensor *src); +#endif #endif diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index b275d8f..5570dc2 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -2,6 +2,7 @@ #define TH_GENERIC_FILE "generic/THTensorMath.c" #else +#ifndef TH_GENERIC_NO_MATH #define TH_OMP_OVERHEAD_THRESHOLD 100000 void THTensor_(fill)(THTensor *r_, real value) @@ -98,10 +99,15 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor) long i = 0; long dim; long div = 1; +#ifdef TH_REAL_IS_HALF +#define IS_NONZERO(val) ((val).x!=0) +#else +#define IS_NONZERO(val) ((val)!=0) +#endif /* First Pass to determine size of subscripts */ TH_TENSOR_APPLY(real, tensor, - if (*tensor_data != 0) { + if IS_NONZERO(*tensor_data) { ++numel; }); #ifdef DEBUG @@ -112,7 +118,7 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor) /* Second pass populates subscripts */ subscript_data = THLongTensor_data(subscript); TH_TENSOR_APPLY(real, tensor, - if (*tensor_data != 0) { + if IS_NONZERO(*tensor_data) { div = 1; for (dim = tensor->nDimension - 1; dim >= 0; dim--) { @@ -396,6 +402,7 @@ accreal THTensor_(dot)(THTensor *tensor, THTensor *src) return sum; } + #undef th_isnan #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) #define th_isnan(val) \ @@ -2499,4 +2506,6 @@ void THTensor_(histc)(THTensor *hist, THTensor *tensor, long nbins, real minvalu } #endif /* floating point only part */ +#endif /* TH_GENERIC_NO_MATH */ +#undef IS_NONZERO #endif diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h index 87f1616..781318c 100644 --- a/lib/TH/generic/THTensorMath.h +++ b/lib/TH/generic/THTensorMath.h @@ -128,6 +128,16 @@ TH_API void THTensor_(eqTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); TH_API void THTensor_(abs)(THTensor *r_, THTensor *t); #endif +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) +TH_API void THTensor_(rand)(THTensor *r_, THGenerator *_generator, THLongStorage *size); +TH_API void THTensor_(randn)(THTensor *r_, THGenerator *_generator, THLongStorage *size); +#endif + +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) +TH_API void THTensor_(rand)(THTensor *r_, THGenerator *_generator, THLongStorage *size); +TH_API void THTensor_(randn)(THTensor *r_, THGenerator *_generator, THLongStorage *size); +#endif + #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) TH_API void THTensor_(sigmoid)(THTensor *r_, THTensor *t); @@ -171,9 +181,6 @@ TH_API accreal THTensor_(normall)(THTensor *t, real value); TH_API void THTensor_(linspace)(THTensor *r_, real a, real b, long n); TH_API void THTensor_(logspace)(THTensor *r_, real a, real b, long n); -TH_API void THTensor_(rand)(THTensor *r_, THGenerator *_generator, THLongStorage *size); -TH_API void THTensor_(randn)(THTensor *r_, THGenerator *_generator, THLongStorage *size); - #endif #if defined(TH_REAL_IS_BYTE) diff --git a/lib/TH/generic/THTensorRandom.c b/lib/TH/generic/THTensorRandom.c index 514d3dd..18b0471 100644 --- a/lib/TH/generic/THTensorRandom.c +++ b/lib/TH/generic/THTensorRandom.c @@ -2,6 +2,9 @@ #define TH_GENERIC_FILE "generic/THTensorRandom.c" #else +/* Tensor math may be disabled for certain types, e.g. 'half' */ +#ifndef TH_GENERIC_NO_MATH + void THTensor_(random)(THTensor *self, THGenerator *_generator) { #if defined(TH_REAL_IS_BYTE) @@ -247,4 +250,6 @@ void THTensor_(setRNGState)(THGenerator *_generator, THTensor *self) } #endif +#endif /* TH_GENERIC_NO_MATH */ + #endif diff --git a/lib/TH/generic/THVector.h b/lib/TH/generic/THVector.h index 67fdcfa..eaeb008 100644 --- a/lib/TH/generic/THVector.h +++ b/lib/TH/generic/THVector.h @@ -2,11 +2,13 @@ #define TH_GENERIC_FILE "generic/THVector.h" #else +#ifndef TH_GENERIC_NO_MATH TH_API void THVector_(fill)(real *x, const real c, const ptrdiff_t n); TH_API void THVector_(add)(real *y, const real *x, const real c, const ptrdiff_t n); TH_API void THVector_(diff)(real *z, const real *x, const real *y, const ptrdiff_t n); TH_API void THVector_(scale)(real *y, const real c, const ptrdiff_t n); TH_API void THVector_(mul)(real *y, const real *x, const ptrdiff_t n); +#endif /* Initialize the dispatch pointers */ TH_API void THVector_(vectorDispatchInit)(void); diff --git a/lib/TH/generic/THVectorDefault.c b/lib/TH/generic/THVectorDefault.c index aabc16c..7554d45 100644 --- a/lib/TH/generic/THVectorDefault.c +++ b/lib/TH/generic/THVectorDefault.c @@ -2,6 +2,8 @@ #define TH_GENERIC_FILE "generic/THVectorDefault.c" #else +#ifndef TH_GENERIC_NO_MATH + void THVector_(fill_DEFAULT)(real *x, const real c, const ptrdiff_t n) { ptrdiff_t i = 0; @@ -80,5 +82,5 @@ void THVector_(mul_DEFAULT)(real *y, const real *x, const ptrdiff_t n) for(; i < n; i++) y[i] *= x[i]; } - +# endif #endif diff --git a/lib/TH/generic/THVectorDispatch.c b/lib/TH/generic/THVectorDispatch.c index 6fd1d68..3624af0 100644 --- a/lib/TH/generic/THVectorDispatch.c +++ b/lib/TH/generic/THVectorDispatch.c @@ -2,6 +2,8 @@ #define TH_GENERIC_FILE "generic/THVectorDispatch.c" #else +#ifndef TH_GENERIC_NO_MATH + /* For now there are only SIMD implementations for FLOAT and DOUBLE. * Hopefully in the future this can be made totally generic (e.g, there are SIMD implementations * for a lot of functions */ @@ -32,7 +34,6 @@ void THVector_(fill)(real *x, const real c, const ptrdiff_t n) { THVector_(fill_DISPATCHPTR)(x, c, n); } - static void (*THVector_(add_DISPATCHPTR))(real *, const real *, const real, const ptrdiff_t) = &THVector_(add_DEFAULT); static FunctionDescription THVector_(add_DISPATCHTABLE)[] = { #if defined(__NEON__) @@ -136,5 +137,5 @@ void THVector_(vectorDispatchInit)(void) INIT_DISPATCH_PTR(scale); INIT_DISPATCH_PTR(mul); } - +#endif /* TH_GENERIC_NO_MATH */ #endif diff --git a/test/test_half.lua b/test/test_half.lua new file mode 100644 index 0000000..480fae8 --- /dev/null +++ b/test/test_half.lua @@ -0,0 +1,124 @@ +local mytester +local torchtest = torch.TestSuite() +local msize = 100 +local precision + +-- Lua 5.2 compatibility +local loadstring = loadstring or load +local unpack = unpack or table.unpack + +local function maxdiff(x,y) + local d = x-y + if x:type() == 'torch.DoubleTensor' or x:type() == 'torch.FloatTensor' or x:type() == 'torch.HalfTensor' then + return d:abs():max() + else + local dd = torch.Tensor():resize(d:size()):copy(d) + return dd:abs():max() + end +end + + +function torchtest.elementSize() + local byte = torch.ByteStorage():elementSize() + local char = torch.CharStorage():elementSize() + local short = torch.ShortStorage():elementSize() + local int = torch.IntStorage():elementSize() + local long = torch.LongStorage():elementSize() + local float = torch.FloatStorage():elementSize() + local double = torch.DoubleStorage():elementSize() + local half = torch.HalfStorage():elementSize() + + mytester:asserteq(byte, torch.ByteTensor():elementSize()) + mytester:asserteq(char, torch.CharTensor():elementSize()) + mytester:asserteq(short, torch.ShortTensor():elementSize()) + mytester:asserteq(int, torch.IntTensor():elementSize()) + mytester:asserteq(long, torch.LongTensor():elementSize()) + mytester:asserteq(float, torch.FloatTensor():elementSize()) + mytester:asserteq(double, torch.DoubleTensor():elementSize()) + mytester:asserteq(half, torch.HalfTensor():elementSize()) + + mytester:assertne(byte, 0) + mytester:assertne(char, 0) + mytester:assertne(short, 0) + mytester:assertne(int, 0) + mytester:assertne(long, 0) + mytester:assertne(float, 0) + mytester:assertne(double, 0) + mytester:assertne(half, 0) + + -- These tests are portable, not necessarily strict for your system. + mytester:asserteq(byte, 1) + mytester:asserteq(char, 1) + mytester:assert(short >= 2) + mytester:assert(int >= 2) + mytester:assert(int >= short) + mytester:assert(long >= 4) + mytester:assert(long >= int) + mytester:assert(double >= float) + mytester:assert(half <= float) +end + +function torchtest.isTensor() + local t = torch.randn(3,4):half() + print("\n Tensor:half() result: ", t) + + mytester:assert(torch.isTensor(t), 'error in isTensor') + mytester:assert(torch.isTensor(t[1]), 'error in isTensor for subTensor') + mytester:assert(not torch.isTensor(t[1][2]), 'false positive in isTensor') + mytester:assert(torch.Tensor.isTensor(t), 'alias not working') +end +function torchtest.isStorage() + local t = torch.randn(3,4):half() + mytester:assert(torch.isStorage(t:storage()), 'error in isStorage') + mytester:assert(not torch.isStorage(t), 'false positive in isStorage') +end + +function torchtest.expand() + local result = torch.Tensor():half() + local tensor = torch.rand(8,1) + local template = torch.rand(8,5) + local target = template:size():totable() + mytester:assertTableEq(tensor:expandAs(template):size():totable(), target, 'Error in expandAs') + mytester:assertTableEq(tensor:expand(8,5):size():totable(), target, 'Error in expand') + mytester:assertTableEq(tensor:expand(torch.LongStorage{8,5}):size():totable(), target, 'Error in expand using LongStorage') + result:expandAs(tensor,template) + mytester:assertTableEq(result:size():totable(), target, 'Error in expandAs using result') + result:expand(tensor,8,5) + mytester:assertTableEq(result:size():totable(), target, 'Error in expand using result') + result:expand(tensor,torch.LongStorage{8,5}) + mytester:assertTableEq(result:size():totable(), target, 'Error in expand using result and LongStorage') + mytester:asserteq((result:mean(2):view(8,1)-tensor):abs():max(), 0, 'Error in expand (not equal)') +end + +function torchtest.repeatTensor() + local result = torch.Tensor():half() + local tensor = torch.rand(8,4) + local size = {3,1,1} + local sizeStorage = torch.LongStorage(size) + local target = {3,8,4} + mytester:assertTableEq(tensor:repeatTensor(unpack(size)):size():totable(), target, 'Error in repeatTensor') + mytester:assertTableEq(tensor:repeatTensor(sizeStorage):size():totable(), target, 'Error in repeatTensor using LongStorage') + result:repeatTensor(tensor,unpack(size)) + mytester:assertTableEq(result:size():totable(), target, 'Error in repeatTensor using result') + result:repeatTensor(tensor,sizeStorage) + mytester:assertTableEq(result:size():totable(), target, 'Error in repeatTensor using result and LongStorage') + mytester:asserteq((result:mean(1):view(8,4)-tensor):abs():max(), 0, 'Error in repeatTensor (not equal)') +end + +function torchtest.isSameSizeAs() + local t1 = torch.Tensor(3, 4, 9, 10):half() + local t2 = torch.Tensor(3, 4):half() + local t3 = torch.Tensor(1, 9, 3, 3):half() + local t4 = torch.Tensor(3, 4, 9, 10):half() + + mytester:assert(t1:isSameSizeAs(t2) == false, "wrong answer ") + mytester:assert(t1:isSameSizeAs(t3) == false, "wrong answer ") + mytester:assert(t1:isSameSizeAs(t4) == true, "wrong answer ") +end + + torch.setheaptracking(true) + math.randomseed(os.time()) + precision = 1e-4 + mytester = torch.Tester() + mytester:add(torchtest) + mytester:run(tests) diff --git a/torchcwrap.lua b/torchcwrap.lua index ab0df43..551bd05 100644 --- a/torchcwrap.lua +++ b/torchcwrap.lua @@ -202,7 +202,7 @@ types.IndexTensor = { } for _,typename in ipairs({"ByteTensor", "CharTensor", "ShortTensor", "IntTensor", "LongTensor", - "FloatTensor", "DoubleTensor"}) do + "FloatTensor", "HalfTensor", "DoubleTensor"}) do types[typename] = { @@ -460,3 +460,56 @@ types.charoption = { postcall = function(arg) end } + +for _,typename in ipairs({"ptrdiff_t", "size_t"}) do + types[typename] = { + + helpname = function(arg) + return typename + end, + + declare = function(arg) + -- if it is a number we initialize here + local default = tonumber(tostring(arg.default)) or 0 + return string.format("%s arg%d = %g;", typename, arg.i, default) + end, + + check = function(arg, idx) + return string.format("lua_isnumber(L, %d)", idx) + end, + + read = function(arg, idx) + return string.format("arg%d = (%s)lua_tonumber(L, %d);", arg.i, typename, idx) + end, + + init = function(arg) + -- otherwise do it here + if arg.default then + local default = tostring(arg.default) + if not tonumber(default) then + return string.format("arg%d = %s;", arg.i, default) + end + end + end, + + carg = function(arg) + return string.format('arg%d', arg.i) + end, + + creturn = function(arg) + return string.format('arg%d', arg.i) + end, + + precall = function(arg) + if arg.returned then + return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i) + end + end, + + postcall = function(arg) + if arg.creturned then + return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i) + end + end + } +end |