Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-12-20 04:56:39 +0300
committerGitHub <noreply@github.com>2016-12-20 04:56:39 +0300
commite839c51c64c563621df1519ef98860e250ad6bdf (patch)
tree24b7da2070b18825ff38eff2ff9b0eeeebb60344
parentf9b37b03b5044ea52a9faabb0820ba42e2afa465 (diff)
Half type (#870)fp16
* torch.HalfTensor
-rw-r--r--CMakeLists.txt2
-rw-r--r--FFI.lua11
-rw-r--r--Tensor.lua8
-rw-r--r--TensorMath.lua61
-rw-r--r--Tester.lua1
-rw-r--r--generic/Storage.c4
-rw-r--r--generic/Tensor.c24
-rw-r--r--generic/TensorOperator.c14
-rw-r--r--generic/luaG.h30
-rw-r--r--init.c31
-rw-r--r--lib/TH/CMakeLists.txt7
-rw-r--r--lib/TH/TH.h4
-rw-r--r--lib/TH/THDiskFile.c10
-rw-r--r--lib/TH/THFile.c5
-rw-r--r--lib/TH/THFile.h9
-rw-r--r--lib/TH/THFilePrivate.h13
-rw-r--r--lib/TH/THGeneral.h.in8
-rw-r--r--lib/TH/THGenerateAllTypes.h30
-rw-r--r--lib/TH/THHalf.c141
-rw-r--r--lib/TH/THHalf.h56
-rw-r--r--lib/TH/THMemoryFile.c11
-rw-r--r--lib/TH/THTensor.c4
-rw-r--r--lib/TH/THVector.c1
-rw-r--r--lib/TH/generic/THBlas.c14
-rw-r--r--lib/TH/generic/THBlas.h4
-rw-r--r--lib/TH/generic/THStorageCopy.c45
-rw-r--r--lib/TH/generic/THStorageCopy.h3
-rw-r--r--lib/TH/generic/THTensorConv.c8
-rw-r--r--lib/TH/generic/THTensorCopy.c28
-rw-r--r--lib/TH/generic/THTensorCopy.h4
-rw-r--r--lib/TH/generic/THTensorMath.c13
-rw-r--r--lib/TH/generic/THTensorMath.h13
-rw-r--r--lib/TH/generic/THTensorRandom.c5
-rw-r--r--lib/TH/generic/THVector.h2
-rw-r--r--lib/TH/generic/THVectorDefault.c4
-rw-r--r--lib/TH/generic/THVectorDispatch.c5
-rw-r--r--test/test_half.lua124
-rw-r--r--torchcwrap.lua55
38 files changed, 703 insertions, 109 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 611258b..fb2de09 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,6 +25,8 @@ IF(MSVC)
ADD_DEFINITIONS(-D_CRT_SECURE_NO_DEPRECATE=1)
ENDIF(MSVC)
+ADD_DEFINITIONS(-DTH_GENERIC_USE_HALF=1)
+
# OpenMP support?
SET(WITH_OPENMP ON CACHE BOOL "OpenMP support if available?")
IF (APPLE AND CMAKE_COMPILER_IS_GNUCC)
diff --git a/FFI.lua b/FFI.lua
index 3cc0b21..365b248 100644
--- a/FFI.lua
+++ b/FFI.lua
@@ -15,6 +15,7 @@ local function checkArgumentType(expected, actual, fn, ud, level)
end
if ok then
+
local Real2real = {
Byte='unsigned char',
Char='char',
@@ -22,7 +23,8 @@ if ok then
Int='int',
Long='long',
Float='float',
- Double='double'
+ Double='double',
+ Half='half'
}
-- Allocator
@@ -33,11 +35,14 @@ typedef struct THAllocator {
void (*free)(void*, void*);
} THAllocator;
]]
-
-- Storage
for Real, real in pairs(Real2real) do
local cdefs = [[
+typedef struct {
+ unsigned short x;
+} half;
+
typedef struct THRealStorage
{
real *data;
@@ -76,7 +81,7 @@ typedef struct THRealTensor
long *size;
long *stride;
int nDimension;
-
+
THRealStorage *storage;
ptrdiff_t storageOffset;
int refcount;
diff --git a/Tensor.lua b/Tensor.lua
index b4b3e95..a5dee40 100644
--- a/Tensor.lua
+++ b/Tensor.lua
@@ -5,14 +5,14 @@ local Storage = {}
local Tensor = {}
-- types
-local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'}
+local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Half', 'Double'}
-- Lua 5.2 compatibility
local log10 = math.log10 or function(x) return math.log(x, 10) end
-- tostring() functions for Tensor and Storage
local function Storage__printformat(self)
- if self:size() == 0 then
+ if self:size() == 0 then
return "", nil, 0
end
local intMode = true
@@ -277,6 +277,10 @@ function Tensor.double(self)
return self:type('torch.DoubleTensor')
end
+function Tensor.half(self)
+ return self:type('torch.HalfTensor')
+end
+
function Tensor.real(self)
return self:type(torch.getdefaulttensortype())
end
diff --git a/TensorMath.lua b/TensorMath.lua
index 682de23..6890ce8 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -6,56 +6,8 @@ local interface = wrap.CInterface.new()
local method = wrap.CInterface.new()
local argtypes = wrap.CInterface.argtypes
-argtypes['ptrdiff_t'] = {
-
- helpname = function(arg)
- return 'ptrdiff_t'
- end,
-
- declare = function(arg)
- -- if it is a number we initialize here
- local default = tonumber(tostring(arg.default)) or 0
- return string.format("%s arg%d = %g;", 'ptrdiff_t', arg.i, default)
- end,
-
- check = function(arg, idx)
- return string.format("lua_isnumber(L, %d)", idx)
- end,
-
- read = function(arg, idx)
- return string.format("arg%d = (%s)lua_tonumber(L, %d);", arg.i, 'ptrdiff_t', idx)
- end,
-
- init = function(arg)
- -- otherwise do it here
- if arg.default then
- local default = tostring(arg.default)
- if not tonumber(default) then
- return string.format("arg%d = %s;", arg.i, default)
- end
- end
- end,
-
- carg = function(arg)
- return string.format('arg%d', arg.i)
- end,
-
- creturn = function(arg)
- return string.format('arg%d', arg.i)
- end,
-
- precall = function(arg)
- if arg.returned then
- return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
- end
- end,
-
- postcall = function(arg)
- if arg.creturned then
- return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
- end
- end
-}
+argtypes['ptrdiff_t'] = wrap.types.ptrdiff_t
+argtypes['half'] = wrap.types.half
interface:print([[
#include "TH.h"
@@ -216,6 +168,7 @@ local reals = {ByteTensor='unsigned char',
IntTensor='int',
LongTensor='long',
FloatTensor='float',
+ HalfTensor='half',
DoubleTensor='double'}
local accreals = {ByteTensor='long',
@@ -224,11 +177,12 @@ local accreals = {ByteTensor='long',
IntTensor='long',
LongTensor='long',
FloatTensor='double',
+ HalfTensor='float',
DoubleTensor='double'}
for _,Tensor in ipairs({"ByteTensor", "CharTensor",
"ShortTensor", "IntTensor", "LongTensor",
- "FloatTensor", "DoubleTensor"}) do
+ "FloatTensor", "HalfTensor", "DoubleTensor"}) do
local real = reals[Tensor]
local accreal = accreals[Tensor]
@@ -257,6 +211,7 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
end
end
+ if Tensor ~= 'HalfTensor' then
wrap("zero",
cname("zero"),
{{name=Tensor, returned=true}})
@@ -1030,6 +985,7 @@ static void THTensor_random1__(THTensor *self, THGenerator *gen, long b)
cname("nonzero"),
{{name="IndexTensor", default=true, returned=true},
{name=Tensor}})
+ end -- ~= HalfTensor
if Tensor == 'ByteTensor' then
-- Logical accumulators only apply to ByteTensor
@@ -1483,6 +1439,9 @@ void torch_TensorMath_init(lua_State *L)
torch_IntTensorMath_init(L);
torch_LongTensorMath_init(L);
torch_FloatTensorMath_init(L);
+ #if TH_NATIVE_HALF
+ torch_HalfTensorMath_init(L);
+ #endif
torch_DoubleTensorMath_init(L);
luaT_setfuncs(L, torch_TensorMath__, 0);
}
diff --git a/Tester.lua b/Tester.lua
index f512edb..6509413 100644
--- a/Tester.lua
+++ b/Tester.lua
@@ -687,6 +687,7 @@ local typesMatching = {
['torch.LongStorage'] = torch.LongTensor,
['torch.FloatStorage'] = torch.FloatTensor,
['torch.DoubleStorage'] = torch.DoubleTensor,
+ ['torch.HalfStorage'] = torch.HalfTensor,
}
--[[ Tests for storage equality.
diff --git a/generic/Storage.c b/generic/Storage.c
index 134dc63..a6652a5 100644
--- a/generic/Storage.c
+++ b/generic/Storage.c
@@ -41,7 +41,7 @@ static int torch_Storage_(new)(lua_State *L)
THStorage_(free)(storage);
luaL_error(L, "element at index %d is not a number", i);
}
- THStorage_(set)(storage, i-1, (real)lua_tonumber(L, -1));
+ THStorage_(set)(storage, i-1, LUA_NUMBER_TO_REAL(lua_tonumber(L, -1)));
lua_pop(L, 1);
}
}
@@ -131,6 +131,8 @@ static int torch_Storage_(copy)(lua_State *L)
THStorage_(copyFloat)(storage, src);
else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) )
THStorage_(copyDouble)(storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.HalfStorage")) )
+ THStorage_(copyHalf)(storage, src);
else
luaL_typerror(L, 2, "torch.*Storage");
lua_settop(L, 1);
diff --git a/generic/Tensor.c b/generic/Tensor.c
index abb7819..b1dca69 100644
--- a/generic/Tensor.c
+++ b/generic/Tensor.c
@@ -142,7 +142,7 @@ static int torch_Tensor_(new)(lua_State *L)
THTensor_(free)(tensor);
THError("invalid element (not a number)");
}
- THStorage_(set)(THTensor_(storage)(tensor), si++, (real)lua_tonumber(L, -1));
+ THStorage_(set)(THTensor_(storage)(tensor), si++, LUA_NUMBER_TO_REAL(lua_tonumber(L, -1)));
lua_pop(L, 1);
}
@@ -675,6 +675,8 @@ static int torch_Tensor_(copy)(lua_State *L)
THTensor_(copyFloat)(tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) )
THTensor_(copyDouble)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.HalfTensor")) )
+ THTensor_(copyHalf)(tensor, src);
else
luaL_typerror(L, 2, "torch.*Tensor");
lua_settop(L, 1);
@@ -745,6 +747,11 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyDouble)(tensor, src);
THTensor_(free)(tensor);
+ } else if( (src = luaT_toudata(L, 3, "torch.HalfTensor")) ) {
+ tensor = THTensor_(newWithTensor)(tensor);
+ THTensor_(narrow)(tensor, NULL, 0, index, 1);
+ THTensor_(copyHalf)(tensor, src);
+ THTensor_(free)(tensor);
} else {
luaL_typerror(L, 3, "torch.*Tensor");
}
@@ -829,7 +836,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
/* doing a copy */
void *src;
if (lua_isnumber(L,3)) {
- THTensor_(fill)(tensor, lua_tonumber(L,3));
+ THTensor_(fill)(tensor, LUA_NUMBER_TO_REAL(lua_tonumber(L,3)));
} else if( (src = luaT_toudata(L, 3, torch_Tensor)) ) {
THTensor_(copy)(tensor, src);
} else if( (src = luaT_toudata(L, 3, "torch.ByteTensor")) ) {
@@ -846,6 +853,8 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
THTensor_(copyFloat)(tensor, src);
} else if( (src = luaT_toudata(L, 3, "torch.DoubleTensor")) ) {
THTensor_(copyDouble)(tensor, src);
+ } else if( (src = luaT_toudata(L, 3, "torch.HalfTensor")) ) {
+ THTensor_(copyHalf)(tensor, src);
} else {
luaL_typerror(L, 3, "torch.*Tensor");
}
@@ -916,7 +925,7 @@ static int torch_Tensor_(__index__)(lua_State *L)
THArgCheck((z >= 0) && (z < tensor->size[dim]), 2, "index out of bound");
index += z*tensor->stride[dim];
}
- luaG_(pushreal)(L, (double)THStorage_(get)(THTensor_(storage)(tensor), index));
+ luaG_(pushreal)(L, THStorage_(get)(THTensor_(storage)(tensor), index));
lua_pushboolean(L, 1);
return 2;
}
@@ -1143,7 +1152,7 @@ static int torch_Tensor_(apply)(lua_State *L)
lua_call(L, 1, 1);
if(lua_isnumber(L, 3))
{
- *tensor_data = (real)lua_tonumber(L, 3);
+ *tensor_data = LUA_NUMBER_TO_REAL(lua_tonumber(L, 3));
lua_pop(L, 1);
}
else if(lua_isnil(L, 3))
@@ -1169,7 +1178,7 @@ static int torch_Tensor_(map)(lua_State *L)
lua_call(L, 2, 1);
if(lua_isnumber(L, 4))
{
- *tensor_data = (real)lua_tonumber(L, 4);
+ *tensor_data = LUA_NUMBER_TO_REAL(lua_tonumber(L, 4));
lua_pop(L, 1);
}
else if(lua_isnil(L, 4))
@@ -1197,7 +1206,7 @@ static int torch_Tensor_(map2)(lua_State *L)
lua_call(L, 3, 1);
if(lua_isnumber(L, 5))
{
- *tensor_data = (real)lua_tonumber(L, 5);
+ *tensor_data = LUA_NUMBER_TO_REAL(lua_tonumber(L, 5));
lua_pop(L, 1);
}
else if(lua_isnil(L, 5))
@@ -1318,7 +1327,10 @@ void torch_Tensor_(init)(lua_State *L)
torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory));
luaT_setfuncs(L, torch_Tensor_(_), 0);
lua_pop(L, 1);
+#ifndef TH_GENERIC_NO_MATH
THVector_(vectorDispatchInit)();
+#endif
+
}
#endif
diff --git a/generic/TensorOperator.c b/generic/TensorOperator.c
index eba6c81..c722e88 100644
--- a/generic/TensorOperator.c
+++ b/generic/TensorOperator.c
@@ -2,6 +2,9 @@
#define TH_GENERIC_FILE "generic/TensorOperator.c"
#else
+/* Tensor math may be disabled for certain types, e.g. 'half' */
+#ifndef TH_GENERIC_NO_MATH
+
static int torch_TensorOperator_(__add__)(lua_State *L)
{
THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor);
@@ -14,7 +17,7 @@ static int torch_TensorOperator_(__add__)(lua_State *L)
{
r = THTensor_(new)();
luaT_pushudata(L, r, torch_Tensor);
-
+
if(!tensor1 && tensor2)
{
THTensor_(resizeAs)(r, tensor2);
@@ -49,7 +52,7 @@ static int torch_TensorOperator_(__sub__)(lua_State *L)
{
r = THTensor_(new)();
luaT_pushudata(L, r, torch_Tensor);
-
+
if(!tensor1 && tensor2)
{
THTensor_(resizeAs)(r, tensor2);
@@ -98,7 +101,7 @@ static int torch_TensorOperator_(__mul__)(lua_State *L)
{
r = THTensor_(new)();
luaT_pushudata(L, r, torch_Tensor);
-
+
if(!tensor1 && tensor2)
{
THTensor_(resizeAs)(r, tensor2);
@@ -115,7 +118,7 @@ static int torch_TensorOperator_(__mul__)(lua_State *L)
{
int dimt = tensor1->nDimension;
int dims = tensor2->nDimension;
-
+
if(dimt == 1 && dims == 1)
lua_pushnumber(L, THTensor_(dot)(tensor1, tensor2)); /* ok, we wasted r, but who cares */
else if(dimt == 2 && dims == 1)
@@ -131,7 +134,7 @@ static int torch_TensorOperator_(__mul__)(lua_State *L)
THTensor_(addmm)(r, 1, r, 1, tensor1, tensor2);
}
else
- luaL_error(L, "multiplication between %dD and %dD tensors not yet supported", tensor1->nDimension, tensor2->nDimension);
+ luaL_error(L, "multiplication between %dD and %dD tensors not yet supported", tensor1->nDimension, tensor2->nDimension);
}
}
return 1;
@@ -187,5 +190,6 @@ void torch_TensorOperator_(init)(lua_State *L)
luaT_setfuncs(L, torch_TensorOperator_(_), 0);
lua_pop(L, 1);
}
+#endif
#endif
diff --git a/generic/luaG.h b/generic/luaG.h
index 950eae9..f1ffce2 100644
--- a/generic/luaG.h
+++ b/generic/luaG.h
@@ -6,10 +6,24 @@
#define luaG_(NAME) TH_CONCAT_3(luaG_,Real,NAME)
#endif
-static void luaG_(pushreal)(lua_State *L, accreal n) {
-#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || LUA_VERSION_NUM < 503
- lua_pushnumber(L, (lua_Number)n);
-#elif defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_SHORT) || defined(TH_REAL_IS_INT) || defined(TH_REAL_IS_LONG)
+#undef REAL_TO_LUA_NUMBER
+#undef LUA_NUMBER_TO_REAL
+
+#if defined(TH_REAL_IS_HALF)
+# define REAL_TO_LUA_NUMBER(n) (lua_Number)TH_half2float(n)
+# define LUA_NUMBER_TO_REAL(n) TH_float2half((lua_Number)n)
+#else
+# define REAL_TO_LUA_NUMBER(n) (lua_Number)(n)
+# define LUA_NUMBER_TO_REAL(n) (real)n
+#endif
+
+
+
+static void luaG_(pushreal)(lua_State *L, real n) {
+#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) || LUA_VERSION_NUM < 503
+ lua_pushnumber(L, REAL_TO_LUA_NUMBER(n));
+#elif defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_SHORT) \
+ || defined(TH_REAL_IS_INT) || defined(TH_REAL_IS_LONG)
lua_pushinteger(L, (lua_Integer)n);
#else
#error "unhandled real type in luaG_pushreal"
@@ -17,8 +31,8 @@ static void luaG_(pushreal)(lua_State *L, accreal n) {
}
static real luaG_(checkreal)(lua_State *L, int idx) {
-#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
- return (lua_Number)luaL_checknumber(L, idx);
+#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF)
+ return LUA_NUMBER_TO_REAL(luaL_checknumber(L, idx));
#elif defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_SHORT) || defined(TH_REAL_IS_INT) || defined(TH_REAL_IS_LONG)
int type = lua_type(L, idx);
if (type == LUA_TSTRING) {
@@ -38,8 +52,8 @@ static real luaG_(checkreal)(lua_State *L, int idx) {
}
static real luaG_(optreal)(lua_State *L, int idx, real n) {
-#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || LUA_VERSION_NUM < 503
- return (lua_Number)luaL_optnumber(L, idx, (lua_Number)n);
+#if defined(TH_REAL_IS_HALF) || defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || LUA_VERSION_NUM < 503
+ return LUA_NUMBER_TO_REAL(luaL_optnumber(L, idx, REAL_TO_LUA_NUMBER(n)));
#elif defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_SHORT) || defined(TH_REAL_IS_INT) || defined(TH_REAL_IS_LONG)
return (lua_Integer)luaL_optinteger(L, idx, (lua_Integer)n);
#else
diff --git a/init.c b/init.c
index 08eedba..0c413f9 100644
--- a/init.c
+++ b/init.c
@@ -16,6 +16,7 @@ extern void torch_IntStorage_init(lua_State *L);
extern void torch_LongStorage_init(lua_State *L);
extern void torch_FloatStorage_init(lua_State *L);
extern void torch_DoubleStorage_init(lua_State *L);
+extern void torch_HalfStorage_init(lua_State *L);
extern void torch_ByteTensor_init(lua_State *L);
extern void torch_CharTensor_init(lua_State *L);
@@ -24,6 +25,7 @@ extern void torch_IntTensor_init(lua_State *L);
extern void torch_LongTensor_init(lua_State *L);
extern void torch_FloatTensor_init(lua_State *L);
extern void torch_DoubleTensor_init(lua_State *L);
+extern void torch_HalfTensor_init(lua_State *L);
extern void torch_ByteTensorOperator_init(lua_State *L);
extern void torch_CharTensorOperator_init(lua_State *L);
@@ -33,8 +35,29 @@ extern void torch_LongTensorOperator_init(lua_State *L);
extern void torch_FloatTensorOperator_init(lua_State *L);
extern void torch_DoubleTensorOperator_init(lua_State *L);
+#if TH_NATIVE_HALF
+extern void torch_HalfTensorOperator_init(lua_State *L);
+#endif
+
extern void torch_TensorMath_init(lua_State *L);
+static int torch_hashalfmath(lua_State *L) {
+ lua_pushboolean(L, TH_NATIVE_HALF);
+ return 1;
+}
+
+static void torch_half_init(lua_State *L)
+{
+ const struct luaL_Reg half_funcs__ [] = {
+ {"hashalfmath", torch_hashalfmath},
+ {NULL, NULL}
+ };
+ luaT_setfuncs(L, half_funcs__, 0);
+
+ lua_pushboolean(L, 1);
+ lua_setfield(L, -2, "hasHalf");
+}
+
LUA_EXTERNC DLL_EXPORT int luaopen_libtorch(lua_State *L);
int luaopen_libtorch(lua_State *L)
@@ -45,8 +68,8 @@ int luaopen_libtorch(lua_State *L)
lua_setglobal(L, "torch");
torch_utils_init(L);
-
torch_File_init(L);
+ torch_half_init(L);
torch_ByteStorage_init(L);
torch_CharStorage_init(L);
@@ -55,6 +78,7 @@ int luaopen_libtorch(lua_State *L)
torch_LongStorage_init(L);
torch_FloatStorage_init(L);
torch_DoubleStorage_init(L);
+ torch_HalfStorage_init(L);
torch_ByteTensor_init(L);
torch_CharTensor_init(L);
@@ -63,6 +87,7 @@ int luaopen_libtorch(lua_State *L)
torch_LongTensor_init(L);
torch_FloatTensor_init(L);
torch_DoubleTensor_init(L);
+ torch_HalfTensor_init(L);
torch_ByteTensorOperator_init(L);
torch_CharTensorOperator_init(L);
@@ -72,6 +97,10 @@ int luaopen_libtorch(lua_State *L)
torch_FloatTensorOperator_init(L);
torch_DoubleTensorOperator_init(L);
+#if TH_NATIVE_HALF
+ torch_HalfTensorOperator_init(L);
+#endif
+
torch_Timer_init(L);
torch_DiskFile_init(L);
torch_PipeFile_init(L);
diff --git a/lib/TH/CMakeLists.txt b/lib/TH/CMakeLists.txt
index e6cf91d..3912bf2 100644
--- a/lib/TH/CMakeLists.txt
+++ b/lib/TH/CMakeLists.txt
@@ -122,11 +122,11 @@ IF(C_AVX_FOUND)
ENDIF(C_AVX_FOUND)
SET(hdr
- THGeneral.h THAllocator.h THStorage.h THTensor.h THTensorApply.h THBlas.h THMath.h
- THLapack.h THLogAdd.h THRandom.h THVector.h THAtomic.h)
+ THGeneral.h THHalf.h THAllocator.h THStorage.h THTensor.h THTensorApply.h THBlas.h THMath.h
+ THLapack.h THLogAdd.h THRandom.h THVector.h THAtomic.h )
SET(src
- THGeneral.c THAllocator.c THStorage.c THTensor.c THBlas.c THLapack.c
+ THGeneral.c THHalf.c THAllocator.c THStorage.c THTensor.c THBlas.c THLapack.c
THLogAdd.c THRandom.c THFile.c THDiskFile.c THMemoryFile.c THAtomic.c THVector.c)
SET(src ${src} ${hdr} ${simd})
@@ -333,6 +333,7 @@ INSTALL(FILES
THTensorMacros.h
THVector.h
THAtomic.h
+ THHalf.h
DESTINATION "${TH_INSTALL_INCLUDE_SUBDIR}/TH")
INSTALL(FILES
diff --git a/lib/TH/TH.h b/lib/TH/TH.h
index cdf331d..8b676cf 100644
--- a/lib/TH/TH.h
+++ b/lib/TH/TH.h
@@ -3,6 +3,10 @@
#include "THGeneral.h"
+#if TH_GENERIC_USE_HALF
+# include "THHalf.h"
+#endif
+
#include "THBlas.h"
#ifdef USE_LAPACK
#include "THLapack.h"
diff --git a/lib/TH/THDiskFile.c b/lib/TH/THDiskFile.c
index 9d9cbae..a397027 100644
--- a/lib/TH/THDiskFile.c
+++ b/lib/TH/THDiskFile.c
@@ -355,7 +355,11 @@ READ_WRITE_METHODS(int, Int,
READ_WRITE_METHODS(float, Float,
int ret = fscanf(dfself->handle, "%g", &data[i]); if(ret <= 0) break; else nread++,
int ret = fprintf(dfself->handle, "%.9g", data[i]); if(ret <= 0) break; else nwrite++)
-
+#if TH_GENERIC_USE_HALF
+READ_WRITE_METHODS(half, Half,
+ float buf; int ret = fscanf(dfself->handle, "%g", &buf); if(ret <= 0) break; else { data[i]= TH_float2half(buf); nread++; },
+ int ret = fprintf(dfself->handle, "%.9g", TH_half2float(data[i])); if(ret <= 0) break; else nwrite++)
+#endif
READ_WRITE_METHODS(double, Double,
int ret = fscanf(dfself->handle, "%lg", &data[i]); if(ret <= 0) break; else nread++,
int ret = fprintf(dfself->handle, "%.17g", data[i]); if(ret <= 0) break; else nwrite++)
@@ -618,6 +622,7 @@ THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet)
THDiskFile_readLong,
THDiskFile_readFloat,
THDiskFile_readDouble,
+ THDiskFile_readHalf,
THDiskFile_readString,
THDiskFile_writeByte,
@@ -627,6 +632,7 @@ THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet)
THDiskFile_writeLong,
THDiskFile_writeFloat,
THDiskFile_writeDouble,
+ THDiskFile_writeHalf,
THDiskFile_writeString,
THDiskFile_synchronize,
@@ -730,6 +736,7 @@ THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet)
THDiskFile_readLong,
THDiskFile_readFloat,
THDiskFile_readDouble,
+ THDiskFile_readHalf,
THDiskFile_readString,
THDiskFile_writeByte,
@@ -739,6 +746,7 @@ THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet)
THDiskFile_writeLong,
THDiskFile_writeFloat,
THDiskFile_writeDouble,
+ THDiskFile_writeHalf,
THDiskFile_writeString,
THDiskFile_synchronize,
diff --git a/lib/TH/THFile.c b/lib/TH/THFile.c
index c8913af..f2bc8d7 100644
--- a/lib/TH/THFile.c
+++ b/lib/TH/THFile.c
@@ -19,6 +19,9 @@ IMPLEMENT_THFILE_RW(Int, int)
IMPLEMENT_THFILE_RW(Long, long)
IMPLEMENT_THFILE_RW(Float, float)
IMPLEMENT_THFILE_RW(Double, double)
+#if TH_GENERIC_USE_HALF
+IMPLEMENT_THFILE_RW(Half, half)
+#endif
size_t THFile_readStringRaw(THFile *self, const char *format, char **str_)
{
@@ -133,6 +136,7 @@ IMPLEMENT_THFILE_SCALAR(Int, int)
IMPLEMENT_THFILE_SCALAR(Long, long)
IMPLEMENT_THFILE_SCALAR(Float, float)
IMPLEMENT_THFILE_SCALAR(Double, double)
+IMPLEMENT_THFILE_SCALAR(Half, half)
#define IMPLEMENT_THFILE_STORAGE(TYPEC, TYPE) \
size_t THFile_read##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \
@@ -152,3 +156,4 @@ IMPLEMENT_THFILE_STORAGE(Int, int)
IMPLEMENT_THFILE_STORAGE(Long, long)
IMPLEMENT_THFILE_STORAGE(Float, float)
IMPLEMENT_THFILE_STORAGE(Double, double)
+IMPLEMENT_THFILE_STORAGE(Half, half)
diff --git a/lib/TH/THFile.h b/lib/TH/THFile.h
index 64dd2da..7727196 100644
--- a/lib/TH/THFile.h
+++ b/lib/TH/THFile.h
@@ -74,6 +74,15 @@ TH_API size_t THFile_writeFloatRaw(THFile *self, float *data, size_t n);
TH_API size_t THFile_writeDoubleRaw(THFile *self, double *data, size_t n);
TH_API size_t THFile_writeStringRaw(THFile *self, const char *str, size_t size);
+#if TH_GENERIC_USE_HALF
+TH_API half THFile_readHalfScalar(THFile *self);
+TH_API void THFile_writeHalfScalar(THFile *self, half scalar);
+TH_API size_t THFile_readHalf(THFile *self, THHalfStorage *storage);
+TH_API size_t THFile_writeHalf(THFile *self, THHalfStorage *storage);
+TH_API size_t THFile_readHalfRaw(THFile *self, half* data, size_t size);
+TH_API size_t THFile_writeHalfRaw(THFile *self, half* data, size_t size);
+#endif
+
TH_API void THFile_synchronize(THFile *self);
TH_API void THFile_seek(THFile *self, size_t position);
TH_API void THFile_seekEnd(THFile *self);
diff --git a/lib/TH/THFilePrivate.h b/lib/TH/THFilePrivate.h
index d268041..f0b0885 100644
--- a/lib/TH/THFilePrivate.h
+++ b/lib/TH/THFilePrivate.h
@@ -1,3 +1,10 @@
+#include "THGeneral.h"
+
+#if TH_GENERIC_USE_HALF
+# include "THHalf.h"
+#endif
+
+
struct THFile__
{
struct THFileVTable *vtable;
@@ -23,6 +30,9 @@ struct THFileVTable
size_t (*readLong)(THFile *self, long *data, size_t n);
size_t (*readFloat)(THFile *self, float *data, size_t n);
size_t (*readDouble)(THFile *self, double *data, size_t n);
+#if TH_GENERIC_USE_HALF
+ size_t (*readHalf)(THFile *self, half *data, size_t n);
+#endif
size_t (*readString)(THFile *self, const char *format, char **str_);
size_t (*writeByte)(THFile *self, unsigned char *data, size_t n);
@@ -32,6 +42,9 @@ struct THFileVTable
size_t (*writeLong)(THFile *self, long *data, size_t n);
size_t (*writeFloat)(THFile *self, float *data, size_t n);
size_t (*writeDouble)(THFile *self, double *data, size_t n);
+#if TH_GENERIC_USE_HALF
+ size_t (*writeHalf)(THFile *self, half *data, size_t n);
+#endif
size_t (*writeString)(THFile *self, const char *str, size_t size);
void (*synchronize)(THFile *self);
diff --git a/lib/TH/THGeneral.h.in b/lib/TH/THGeneral.h.in
index bc7e448..6873db6 100644
--- a/lib/TH/THGeneral.h.in
+++ b/lib/TH/THGeneral.h.in
@@ -46,6 +46,14 @@
#define TH_INDEX_BASE 1
#endif
+#ifndef TH_GENERIC_USE_HALF
+# define TH_GENERIC_USE_HALF 0
+#endif
+
+#ifndef TH_NATIVE_HALF
+# define TH_NATIVE_HALF 0
+#endif
+
typedef void (*THErrorHandlerFunction)(const char *msg, void *data);
typedef void (*THArgErrorHandlerFunction)(int argNumber, const char *msg, void *data);
diff --git a/lib/TH/THGenerateAllTypes.h b/lib/TH/THGenerateAllTypes.h
index 539629b..093bfd9 100644
--- a/lib/TH/THGenerateAllTypes.h
+++ b/lib/TH/THGenerateAllTypes.h
@@ -2,6 +2,16 @@
#error "You must define TH_GENERIC_FILE before including THGenerateAllTypes.h"
#endif
+#define THTypeIdxByte 1
+#define THTypeIdxChar 2
+#define THTypeIdxShort 3
+#define THTypeIdxInt 4
+#define THTypeIdxLong 5
+#define THTypeIdxFloat 6
+#define THTypeIdxDouble 7
+#define THTypeIdxHalf 8
+#define THTypeIdx_(T) TH_CONCAT_2(THTypeIdx,T)
+
#define real unsigned char
#define accreal long
#define Real Byte
@@ -94,4 +104,24 @@
#undef THInf
#undef TH_REAL_IS_DOUBLE
+#if TH_GENERIC_USE_HALF
+#include "THHalf.h"
+#define real half
+#define accreal float
+#define Real Half
+#define THInf TH_HALF_MAX
+#define TH_REAL_IS_HALF
+#if !TH_NATIVE_HALF
+# define TH_GENERIC_NO_MATH 1
+#endif
+#line 1 TH_GENERIC_FILE
+#include TH_GENERIC_FILE
+#undef real
+#undef accreal
+#undef Real
+#undef THInf
+#undef TH_REAL_IS_HALF
+#undef TH_GENERIC_NO_MATH
+#endif
+
#undef TH_GENERIC_FILE
diff --git a/lib/TH/THHalf.c b/lib/TH/THHalf.c
new file mode 100644
index 0000000..fbc3e15
--- /dev/null
+++ b/lib/TH/THHalf.c
@@ -0,0 +1,141 @@
+#include "THHalf.h"
+#include "TH.h"
+
+static half half_max = TH_HALF_MAX;
+
+/*
+ * Copyright 1993-2014 NVIDIA Corporation. All rights reserved.
+ *
+ * NOTICE TO LICENSEE:
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ *
+ * These Licensed Deliverables contained herein is PROPRIETARY and
+ * CONFIDENTIAL to NVIDIA and is being provided under the terms and
+ * conditions of a form of NVIDIA software license agreement by and
+ * between NVIDIA and Licensee ("License Agreement") or electronically
+ * accepted by Licensee. Notwithstanding any terms or conditions to
+ * the contrary in the License Agreement, reproduction or disclosure
+ * of the Licensed Deliverables to any third party without the express
+ * written consent of NVIDIA is prohibited.
+ *
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
+ * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
+ * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
+ * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
+ * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
+ * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
+ * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
+ * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
+ * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
+ * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
+ * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
+ * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
+ * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
+ * OF THESE LICENSED DELIVERABLES.
+ *
+ * U.S. Government End Users. These Licensed Deliverables are a
+ * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
+ * 1995), consisting of "commercial computer software" and "commercial
+ * computer software documentation" as such terms are used in 48
+ * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
+ * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
+ * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
+ * U.S. Government End Users acquire the Licensed Deliverables with
+ * only those rights set forth herein.
+ *
+ * Any use of the Licensed Deliverables in individual and commercial
+ * software must include, in the user documentation and internal
+ * comments to the code, the above Disclaimer and U.S. Government End
+ * Users Notice.
+ */
+
+// Host functions for converting between FP32 and FP16 formats
+// Paulius Micikevicius (pauliusm@nvidia.com)
+
+float TH_half2float(half h)
+{
+ unsigned sign = ((h.x >> 15) & 1);
+ unsigned exponent = ((h.x >> 10) & 0x1f);
+ unsigned mantissa = ((h.x & 0x3ff) << 13);
+
+ if (exponent == 0x1f) { /* NaN or Inf */
+ mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
+ exponent = 0xff;
+ } else if (!exponent) { /* Denorm or Zero */
+ if (mantissa) {
+ unsigned int msb;
+ exponent = 0x71;
+ do {
+ msb = (mantissa & 0x400000);
+ mantissa <<= 1; /* normalize */
+ --exponent;
+ } while (!msb);
+ mantissa &= 0x7fffff; /* 1.mantissa is implicit */
+ }
+ } else {
+ exponent += 0x70;
+ }
+
+ int temp = ((sign << 31) | (exponent << 23) | mantissa);
+
+ return *((float*)((void*)&temp));
+}
+
+half TH_float2half(float f)
+{
+ half ret;
+
+ unsigned x = *((int*)(void*)(&f));
+ unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
+ unsigned sign, exponent, mantissa;
+
+ // Get rid of +NaN/-NaN case first.
+ if (u > 0x7f800000) {
+ ret.x = 0x7fffU;
+ return ret;
+ }
+
+ sign = ((x >> 16) & 0x8000);
+
+ // Get rid of +Inf/-Inf, +0/-0.
+ if (u > 0x477fefff) {
+ ret.x = sign | 0x7c00U;
+ return ret;
+ }
+ if (u < 0x33000001) {
+ ret.x = (sign | 0x0000);
+ return ret;
+ }
+
+ exponent = ((u >> 23) & 0xff);
+ mantissa = (u & 0x7fffff);
+
+ if (exponent > 0x70) {
+ shift = 13;
+ exponent -= 0x70;
+ } else {
+ shift = 0x7e - exponent;
+ exponent = 0;
+ mantissa |= 0x800000;
+ }
+ lsb = (1 << shift);
+ lsb_s1 = (lsb >> 1);
+ lsb_m1 = (lsb - 1);
+
+ // Round to nearest even.
+ remainder = (mantissa & lsb_m1);
+ mantissa >>= shift;
+ if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
+ ++mantissa;
+ if (!(mantissa & 0x3ff)) {
+ ++exponent;
+ mantissa = 0;
+ }
+ }
+
+ ret.x = (sign | (exponent << 10) | mantissa);
+ return ret;
+}
diff --git a/lib/TH/THHalf.h b/lib/TH/THHalf.h
new file mode 100644
index 0000000..01dab4e
--- /dev/null
+++ b/lib/TH/THHalf.h
@@ -0,0 +1,56 @@
+#ifndef _THHALF_H
+# define _THHALF_H
+
+#include <stdint.h>
+
+#include "THGeneral.h"
+
+# if defined (TH_HALF_TYPE)
+typedef TH_HALF_TYPE half;
+# else
+/* Neither built-in nor included from Cutorch, use our definition lifted from CUDA */
+#if defined(__GNUC__)
+#define __align__(n) __attribute__((aligned(n)))
+#elif defined(_WIN32)
+#define __align__(n) __declspec(align(n))
+#else
+#define __align__(n)
+#endif
+
+typedef struct __align__(2){
+ unsigned short x;
+} __half;
+
+typedef struct __align__(4) {
+ unsigned int x;
+} __half2;
+
+typedef __half half;
+typedef __half2 half2;
+# endif
+
+/* numeric limits */
+
+
+TH_API half TH_float2half(float a);
+TH_API float TH_half2float(half a);
+
+#ifndef TH_HALF_BITS_TO_LITERAL
+# define TH_HALF_BITS_TO_LITERAL(n) { n }
+#endif
+
+#define TH_HALF_ZERO TH_HALF_BITS_TO_LITERAL(0x0)
+#define TH_HALF_MIN TH_HALF_BITS_TO_LITERAL(0x0400)
+#define TH_HALF_MAX TH_HALF_BITS_TO_LITERAL(0x7BFF)
+#define TH_HALF_EPSILON TH_HALF_BITS_TO_LITERAL(0x1400)
+#define TH_HALF_INF TH_HALF_BITS_TO_LITERAL(0x7C00)
+#define TH_HALF_QNAN TH_HALF_BITS_TO_LITERAL(0x7FFF)
+#define TH_HALF_SNAN TH_HALF_BITS_TO_LITERAL(0x7DFF)
+#define TH_HALF_DENORM_MIN TH_HALF_BITS_TO_LITERAL(0x0001)
+#define TH_HALF_DIGITS 11
+#define TH_HALF_DIGITS10 3
+#define TH_HALF_DIGITS10_MAX 5
+#define TH_HALF_MAX_EXPONENT 16
+#define TH_HALF_MAX_EXPONENT10 4
+
+#endif
diff --git a/lib/TH/THMemoryFile.c b/lib/TH/THMemoryFile.c
index 8d97621..9079d6c 100644
--- a/lib/TH/THMemoryFile.c
+++ b/lib/TH/THMemoryFile.c
@@ -341,7 +341,14 @@ READ_WRITE_METHODS(float, Float,
int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%g%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++,
nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%.9g", data[i]),
1)
-
+#if TH_GENERIC_USE_HALF
+READ_WRITE_METHODS(half, Half,
+ int nByteRead_; float buf; \
+ int ret = sscanf(mfself->storage->data+mfself->position, "%g%n", &buf, &nByteRead_); \
+ data[i] = TH_float2half(buf); nByteRead = nByteRead_; if(ret <= 0) break; else nread++,
+ nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%.9g", TH_half2float(data[i])),
+ 1)
+#endif
READ_WRITE_METHODS(double, Double,
int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%lg%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++,
nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%.17g", data[i]),
@@ -621,6 +628,7 @@ THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode)
THMemoryFile_readLong,
THMemoryFile_readFloat,
THMemoryFile_readDouble,
+ THMemoryFile_readHalf,
THMemoryFile_readString,
THMemoryFile_writeByte,
@@ -630,6 +638,7 @@ THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode)
THMemoryFile_writeLong,
THMemoryFile_writeFloat,
THMemoryFile_writeDouble,
+ THMemoryFile_writeHalf,
THMemoryFile_writeString,
THMemoryFile_synchronize,
diff --git a/lib/TH/THTensor.c b/lib/TH/THTensor.c
index 2878fc9..6305668 100644
--- a/lib/TH/THTensor.c
+++ b/lib/TH/THTensor.c
@@ -14,10 +14,10 @@
#include "generic/THTensorCopy.c"
#include "THGenerateAllTypes.h"
-#include "generic/THTensorRandom.c"
+#include "generic/THTensorMath.c"
#include "THGenerateAllTypes.h"
-#include "generic/THTensorMath.c"
+#include "generic/THTensorRandom.c"
#include "THGenerateAllTypes.h"
#include "generic/THTensorConv.c"
diff --git a/lib/TH/THVector.c b/lib/TH/THVector.c
index 6179d89..f530a84 100644
--- a/lib/TH/THVector.c
+++ b/lib/TH/THVector.c
@@ -1,4 +1,5 @@
#include "THVector.h"
+
#include "generic/simd/simd.h"
#ifdef __NEON__
diff --git a/lib/TH/generic/THBlas.c b/lib/TH/generic/THBlas.c
index 6452f94..8b3a403 100644
--- a/lib/TH/generic/THBlas.c
+++ b/lib/TH/generic/THBlas.c
@@ -2,6 +2,8 @@
#define TH_GENERIC_FILE "generic/THBlas.c"
#else
+# ifndef TH_GENERIC_NO_MATH
+
#ifdef BLAS_F2C
# define ffloat double
#else
@@ -24,8 +26,8 @@ TH_EXTERNC void dger_(int *m, int *n, double *alpha, double *x, int *incx, doubl
TH_EXTERNC void sger_(int *m, int *n, float *alpha, float *x, int *incx, float *y, int *incy, float *a, int *lda);
TH_EXTERNC void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, double *a, int *lda, double *b, int *ldb, double *beta, double *c, int *ldc);
TH_EXTERNC void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, float *a, int *lda, float *b, int *ldb, float *beta, float *c, int *ldc);
-
-
+
+
void THBlas_(swap)(long n, real *x, long incx, real *y, long incy)
{
@@ -182,9 +184,9 @@ void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, re
{
if(n == 1)
lda = m;
-
+
#if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
- if( (m <= INT_MAX) && (n <= INT_MAX) &&
+ if( (m <= INT_MAX) && (n <= INT_MAX) &&
(lda > 0) && (lda <= INT_MAX) &&
(incx > 0) && (incx <= INT_MAX) &&
(incy > 0) && (incy <= INT_MAX) )
@@ -224,7 +226,7 @@ void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, re
{
if(beta != 1)
THBlas_(scal)(m, beta, y, incy);
-
+
for(j = 0; j < n; j++)
{
real *column_ = a+lda*j;
@@ -402,5 +404,5 @@ void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha,
}
}
}
-
+# endif /* TH_GENERIC_NO_MATH */
#endif
diff --git a/lib/TH/generic/THBlas.h b/lib/TH/generic/THBlas.h
index 9e14f5a..a49d79c 100644
--- a/lib/TH/generic/THBlas.h
+++ b/lib/TH/generic/THBlas.h
@@ -1,7 +1,7 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THBlas.h"
#else
-
+# ifndef TH_GENERIC_NO_MATH
/* Level 1 */
TH_API void THBlas_(swap)(long n, real *x, long incx, real *y, long incy);
TH_API void THBlas_(scal)(long n, real a, real *x, long incx);
@@ -15,5 +15,5 @@ TH_API void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y
/* Level 3 */
TH_API void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha, real *a, long lda, real *b, long ldb, real beta, real *c, long ldc);
-
+# endif
#endif
diff --git a/lib/TH/generic/THStorageCopy.c b/lib/TH/generic/THStorageCopy.c
index 583e088..a0b2c95 100644
--- a/lib/TH/generic/THStorageCopy.c
+++ b/lib/TH/generic/THStorageCopy.c
@@ -15,16 +15,38 @@ void THStorage_(copy)(THStorage *storage, THStorage *src)
THStorage_(rawCopy)(storage, src->data);
}
-
#define IMPLEMENT_THStorage_COPY(TYPENAMESRC) \
void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \
{ \
- ptrdiff_t i; \
THArgCheck(storage->size == src->size, 2, "size mismatch"); \
- for(i = 0; i < storage->size; i++) \
- storage->data[i] = (real)src->data[i]; \
+ if(THTypeIdx_(Real) == THTypeIdx_(TYPENAMESRC)) { \
+ memcpy(storage->data, src->data, sizeof(real)*storage->size); /* cast just removes compiler warning */ \
+ } else { \
+ ptrdiff_t i; \
+ for(i = 0; i < storage->size; i++) \
+ storage->data[i] = (real)src->data[i]; \
+ } \
+}
+
+#define IMPLEMENT_THStorage_COPY_FROM_HALF(TYPENAMESRC) \
+void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \
+{ \
+ THArgCheck(storage->size == src->size, 2, "size mismatch"); \
+ ptrdiff_t i; \
+ for(i = 0; i < storage->size; i++) \
+ storage->data[i] = (real)TH_half2float(src->data[i]); \
}
+#define IMPLEMENT_THStorage_COPY_TO_HALF(TYPENAMESRC) \
+void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \
+{ \
+ THArgCheck(storage->size == src->size, 2, "size mismatch"); \
+ ptrdiff_t i; \
+ for(i = 0; i < storage->size; i++) \
+ storage->data[i] = TH_float2half((float)(src->data[i])); \
+}
+
+#ifndef TH_REAL_IS_HALF
IMPLEMENT_THStorage_COPY(Byte)
IMPLEMENT_THStorage_COPY(Char)
IMPLEMENT_THStorage_COPY(Short)
@@ -32,5 +54,20 @@ IMPLEMENT_THStorage_COPY(Int)
IMPLEMENT_THStorage_COPY(Long)
IMPLEMENT_THStorage_COPY(Float)
IMPLEMENT_THStorage_COPY(Double)
+#if TH_GENERIC_USE_HALF
+IMPLEMENT_THStorage_COPY_FROM_HALF(Half)
+#endif
+#else
+/* only allow pass-through for Half */
+IMPLEMENT_THStorage_COPY(Half)
+IMPLEMENT_THStorage_COPY_TO_HALF(Byte)
+IMPLEMENT_THStorage_COPY_TO_HALF(Char)
+IMPLEMENT_THStorage_COPY_TO_HALF(Short)
+IMPLEMENT_THStorage_COPY_TO_HALF(Int)
+IMPLEMENT_THStorage_COPY_TO_HALF(Long)
+IMPLEMENT_THStorage_COPY_TO_HALF(Float)
+IMPLEMENT_THStorage_COPY_TO_HALF(Double)
+#endif
+
#endif
diff --git a/lib/TH/generic/THStorageCopy.h b/lib/TH/generic/THStorageCopy.h
index f853a82..bb2f406 100644
--- a/lib/TH/generic/THStorageCopy.h
+++ b/lib/TH/generic/THStorageCopy.h
@@ -13,5 +13,8 @@ TH_API void THStorage_(copyInt)(THStorage *storage, struct THIntStorage *src);
TH_API void THStorage_(copyLong)(THStorage *storage, struct THLongStorage *src);
TH_API void THStorage_(copyFloat)(THStorage *storage, struct THFloatStorage *src);
TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *src);
+#if TH_GENERIC_USE_HALF
+TH_API void THStorage_(copyHalf)(THStorage *storage, struct THHalfStorage *src);
+#endif
#endif
diff --git a/lib/TH/generic/THTensorConv.c b/lib/TH/generic/THTensorConv.c
index d98a2aa..aa864ad 100644
--- a/lib/TH/generic/THTensorConv.c
+++ b/lib/TH/generic/THTensorConv.c
@@ -2,6 +2,8 @@
#define TH_GENERIC_FILE "generic/THTensorConv.c"
#else
+/* Tensor math may be disabled for certain types, e.g. 'TH_half' */
+#ifndef TH_GENERIC_NO_MATH
/*
2D Input, 2D kernel : convolve given image with the given kernel.
@@ -775,7 +777,7 @@ void THTensor_(conv2DRevgerm)(THTensor *r_, real beta, real alpha, THTensor *t_,
real *ptr_output = output_data + k*nInputPlane*nOutputCols*nOutputRows + i*nOutputCols*nOutputRows;
/* get input */
real *ptr_input = input_data + p*istride0 + i*istride1;
-
+
/* do image, kernel convolution */
THTensor_(validXCorr2DRevptr)(ptr_output,
alpha,
@@ -1174,7 +1176,7 @@ void THTensor_(conv2Dmm)(THTensor *r_, real beta, real alpha, THTensor *t_, THTe
real *ptr_weight = weight_data + k*kstride0 + i*kstride1;
/* get input */
real *ptr_input = input_data + p*nInputPlane*nInputRows*nInputCols + i*nInputRows*nInputCols;
-
+
/* do image, kernel convolution */
if (*vf == 'F')
if (*xc == 'X')
@@ -1955,5 +1957,5 @@ void THTensor_(conv3Dmap)(THTensor *r_, real beta, real alpha, THTensor *t_, THT
THTensor_(free)(input);
THTensor_(free)(kernel);
}
-
+#endif
#endif
diff --git a/lib/TH/generic/THTensorCopy.c b/lib/TH/generic/THTensorCopy.c
index ea6d6f1..6c7daa8 100644
--- a/lib/TH/generic/THTensorCopy.c
+++ b/lib/TH/generic/THTensorCopy.c
@@ -13,6 +13,19 @@ void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src
TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = (real)(*src_data);) \
}
+#define IMPLEMENT_THTensor_COPY_TO_HALF(TYPENAMESRC, TYPE_SRC) \
+void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src) \
+{ \
+ TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = TH_float2half((float)*src_data);) \
+}
+
+#define IMPLEMENT_THTensor_COPY_FROM_HALF(TYPENAMESRC, TYPE_SRC) \
+void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src) \
+{ \
+ TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = (real)TH_half2float(*src_data);) \
+}
+
+#ifndef TH_REAL_IS_HALF
IMPLEMENT_THTensor_COPY(Byte, unsigned char)
IMPLEMENT_THTensor_COPY(Char, char)
IMPLEMENT_THTensor_COPY(Short, short)
@@ -20,5 +33,20 @@ IMPLEMENT_THTensor_COPY(Int, int)
IMPLEMENT_THTensor_COPY(Long, long)
IMPLEMENT_THTensor_COPY(Float, float)
IMPLEMENT_THTensor_COPY(Double, double)
+#if TH_GENERIC_USE_HALF
+IMPLEMENT_THTensor_COPY_FROM_HALF(Half, half)
+#endif
+#else
+/* only allow pass-through for Half */
+IMPLEMENT_THTensor_COPY(Half, half)
+IMPLEMENT_THTensor_COPY_TO_HALF(Byte, unsigned char)
+IMPLEMENT_THTensor_COPY_TO_HALF(Char, char)
+IMPLEMENT_THTensor_COPY_TO_HALF(Short, short)
+IMPLEMENT_THTensor_COPY_TO_HALF(Int, int)
+IMPLEMENT_THTensor_COPY_TO_HALF(Long, long)
+IMPLEMENT_THTensor_COPY_TO_HALF(Float, float)
+IMPLEMENT_THTensor_COPY_TO_HALF(Double, double)
+
+#endif /* TH_REAL_IS_HALF */
#endif
diff --git a/lib/TH/generic/THTensorCopy.h b/lib/TH/generic/THTensorCopy.h
index 8d03b22..8a0abcf 100644
--- a/lib/TH/generic/THTensorCopy.h
+++ b/lib/TH/generic/THTensorCopy.h
@@ -12,5 +12,7 @@ TH_API void THTensor_(copyInt)(THTensor *tensor, struct THIntTensor *src);
TH_API void THTensor_(copyLong)(THTensor *tensor, struct THLongTensor *src);
TH_API void THTensor_(copyFloat)(THTensor *tensor, struct THFloatTensor *src);
TH_API void THTensor_(copyDouble)(THTensor *tensor, struct THDoubleTensor *src);
-
+#if TH_GENERIC_USE_HALF
+TH_API void THTensor_(copyHalf)(THTensor *tensor, struct THHalfTensor *src);
+#endif
#endif
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c
index b275d8f..5570dc2 100644
--- a/lib/TH/generic/THTensorMath.c
+++ b/lib/TH/generic/THTensorMath.c
@@ -2,6 +2,7 @@
#define TH_GENERIC_FILE "generic/THTensorMath.c"
#else
+#ifndef TH_GENERIC_NO_MATH
#define TH_OMP_OVERHEAD_THRESHOLD 100000
void THTensor_(fill)(THTensor *r_, real value)
@@ -98,10 +99,15 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
long i = 0;
long dim;
long div = 1;
+#ifdef TH_REAL_IS_HALF
+#define IS_NONZERO(val) ((val).x!=0)
+#else
+#define IS_NONZERO(val) ((val)!=0)
+#endif
/* First Pass to determine size of subscripts */
TH_TENSOR_APPLY(real, tensor,
- if (*tensor_data != 0) {
+ if IS_NONZERO(*tensor_data) {
++numel;
});
#ifdef DEBUG
@@ -112,7 +118,7 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
/* Second pass populates subscripts */
subscript_data = THLongTensor_data(subscript);
TH_TENSOR_APPLY(real, tensor,
- if (*tensor_data != 0) {
+ if IS_NONZERO(*tensor_data) {
div = 1;
for (dim = tensor->nDimension - 1; dim >= 0; dim--) {
@@ -396,6 +402,7 @@ accreal THTensor_(dot)(THTensor *tensor, THTensor *src)
return sum;
}
+
#undef th_isnan
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
#define th_isnan(val) \
@@ -2499,4 +2506,6 @@ void THTensor_(histc)(THTensor *hist, THTensor *tensor, long nbins, real minvalu
}
#endif /* floating point only part */
+#endif /* TH_GENERIC_NO_MATH */
+#undef IS_NONZERO
#endif
diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h
index 87f1616..781318c 100644
--- a/lib/TH/generic/THTensorMath.h
+++ b/lib/TH/generic/THTensorMath.h
@@ -128,6 +128,11 @@ TH_API void THTensor_(eqTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
 TH_API void THTensor_(abs)(THTensor *r_, THTensor *t);
 #endif
+#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF)
+TH_API void THTensor_(rand)(THTensor *r_, THGenerator *_generator, THLongStorage *size);
+TH_API void THTensor_(randn)(THTensor *r_, THGenerator *_generator, THLongStorage *size);
+#endif
+
 #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
 TH_API void THTensor_(sigmoid)(THTensor *r_, THTensor *t);
@@ -171,9 +176,6 @@ TH_API accreal THTensor_(normall)(THTensor *t, real value);
 TH_API void THTensor_(linspace)(THTensor *r_, real a, real b, long n);
 TH_API void THTensor_(logspace)(THTensor *r_, real a, real b, long n);
-TH_API void THTensor_(rand)(THTensor *r_, THGenerator *_generator, THLongStorage *size);
-TH_API void THTensor_(randn)(THTensor *r_, THGenerator *_generator, THLongStorage *size);
-
 #endif
#if defined(TH_REAL_IS_BYTE)
diff --git a/lib/TH/generic/THTensorRandom.c b/lib/TH/generic/THTensorRandom.c
index 514d3dd..18b0471 100644
--- a/lib/TH/generic/THTensorRandom.c
+++ b/lib/TH/generic/THTensorRandom.c
@@ -2,6 +2,9 @@
#define TH_GENERIC_FILE "generic/THTensorRandom.c"
#else
+/* Tensor math may be disabled for certain types, e.g. 'half' */
+#ifndef TH_GENERIC_NO_MATH
+
void THTensor_(random)(THTensor *self, THGenerator *_generator)
{
#if defined(TH_REAL_IS_BYTE)
@@ -247,4 +250,6 @@ void THTensor_(setRNGState)(THGenerator *_generator, THTensor *self)
}
#endif
+#endif /* TH_GENERIC_NO_MATH */
+
#endif
diff --git a/lib/TH/generic/THVector.h b/lib/TH/generic/THVector.h
index 67fdcfa..eaeb008 100644
--- a/lib/TH/generic/THVector.h
+++ b/lib/TH/generic/THVector.h
@@ -2,11 +2,13 @@
#define TH_GENERIC_FILE "generic/THVector.h"
#else
+#ifndef TH_GENERIC_NO_MATH
TH_API void THVector_(fill)(real *x, const real c, const ptrdiff_t n);
TH_API void THVector_(add)(real *y, const real *x, const real c, const ptrdiff_t n);
TH_API void THVector_(diff)(real *z, const real *x, const real *y, const ptrdiff_t n);
TH_API void THVector_(scale)(real *y, const real c, const ptrdiff_t n);
TH_API void THVector_(mul)(real *y, const real *x, const ptrdiff_t n);
+#endif
/* Initialize the dispatch pointers */
TH_API void THVector_(vectorDispatchInit)(void);
diff --git a/lib/TH/generic/THVectorDefault.c b/lib/TH/generic/THVectorDefault.c
index aabc16c..7554d45 100644
--- a/lib/TH/generic/THVectorDefault.c
+++ b/lib/TH/generic/THVectorDefault.c
@@ -2,6 +2,8 @@
#define TH_GENERIC_FILE "generic/THVectorDefault.c"
#else
+#ifndef TH_GENERIC_NO_MATH
+
void THVector_(fill_DEFAULT)(real *x, const real c, const ptrdiff_t n) {
ptrdiff_t i = 0;
@@ -80,5 +82,5 @@ void THVector_(mul_DEFAULT)(real *y, const real *x, const ptrdiff_t n)
for(; i < n; i++)
y[i] *= x[i];
}
-
+# endif
#endif
diff --git a/lib/TH/generic/THVectorDispatch.c b/lib/TH/generic/THVectorDispatch.c
index 6fd1d68..3624af0 100644
--- a/lib/TH/generic/THVectorDispatch.c
+++ b/lib/TH/generic/THVectorDispatch.c
@@ -2,6 +2,8 @@
#define TH_GENERIC_FILE "generic/THVectorDispatch.c"
#else
+#ifndef TH_GENERIC_NO_MATH
+
/* For now there are only SIMD implementations for FLOAT and DOUBLE.
* Hopefully in the future this can be made totally generic (e.g, there are SIMD implementations
* for a lot of functions */
@@ -32,7 +34,6 @@ void THVector_(fill)(real *x, const real c, const ptrdiff_t n) {
THVector_(fill_DISPATCHPTR)(x, c, n);
}
-
static void (*THVector_(add_DISPATCHPTR))(real *, const real *, const real, const ptrdiff_t) = &THVector_(add_DEFAULT);
static FunctionDescription THVector_(add_DISPATCHTABLE)[] = {
#if defined(__NEON__)
@@ -136,5 +137,5 @@ void THVector_(vectorDispatchInit)(void)
INIT_DISPATCH_PTR(scale);
INIT_DISPATCH_PTR(mul);
}
-
+#endif /* TH_GENERIC_NO_MATH */
#endif
diff --git a/test/test_half.lua b/test/test_half.lua
new file mode 100644
index 0000000..480fae8
--- /dev/null
+++ b/test/test_half.lua
@@ -0,0 +1,124 @@
+local mytester
+local torchtest = torch.TestSuite()
+local msize = 100
+local precision
+
+-- Lua 5.2 compatibility
+local loadstring = loadstring or load
+local unpack = unpack or table.unpack
+
+local function maxdiff(x,y)
+ local d = x-y
+ if x:type() == 'torch.DoubleTensor' or x:type() == 'torch.FloatTensor' or x:type() == 'torch.HalfTensor' then
+ return d:abs():max()
+ else
+ local dd = torch.Tensor():resize(d:size()):copy(d)
+ return dd:abs():max()
+ end
+end
+
+
+function torchtest.elementSize()
+ local byte = torch.ByteStorage():elementSize()
+ local char = torch.CharStorage():elementSize()
+ local short = torch.ShortStorage():elementSize()
+ local int = torch.IntStorage():elementSize()
+ local long = torch.LongStorage():elementSize()
+ local float = torch.FloatStorage():elementSize()
+ local double = torch.DoubleStorage():elementSize()
+ local half = torch.HalfStorage():elementSize()
+
+ mytester:asserteq(byte, torch.ByteTensor():elementSize())
+ mytester:asserteq(char, torch.CharTensor():elementSize())
+ mytester:asserteq(short, torch.ShortTensor():elementSize())
+ mytester:asserteq(int, torch.IntTensor():elementSize())
+ mytester:asserteq(long, torch.LongTensor():elementSize())
+ mytester:asserteq(float, torch.FloatTensor():elementSize())
+ mytester:asserteq(double, torch.DoubleTensor():elementSize())
+ mytester:asserteq(half, torch.HalfTensor():elementSize())
+
+ mytester:assertne(byte, 0)
+ mytester:assertne(char, 0)
+ mytester:assertne(short, 0)
+ mytester:assertne(int, 0)
+ mytester:assertne(long, 0)
+ mytester:assertne(float, 0)
+ mytester:assertne(double, 0)
+ mytester:assertne(half, 0)
+
+ -- These tests are portable, not necessarily strict for your system.
+ mytester:asserteq(byte, 1)
+ mytester:asserteq(char, 1)
+ mytester:assert(short >= 2)
+ mytester:assert(int >= 2)
+ mytester:assert(int >= short)
+ mytester:assert(long >= 4)
+ mytester:assert(long >= int)
+ mytester:assert(double >= float)
+ mytester:assert(half <= float)
+end
+
+function torchtest.isTensor()
+ local t = torch.randn(3,4):half()
+ -- (removed debug print of the half tensor; tests should not write to stdout)
+
+ mytester:assert(torch.isTensor(t), 'error in isTensor')
+ mytester:assert(torch.isTensor(t[1]), 'error in isTensor for subTensor')
+ mytester:assert(not torch.isTensor(t[1][2]), 'false positive in isTensor')
+ mytester:assert(torch.Tensor.isTensor(t), 'alias not working')
+end
+function torchtest.isStorage()
+ local t = torch.randn(3,4):half()
+ mytester:assert(torch.isStorage(t:storage()), 'error in isStorage')
+ mytester:assert(not torch.isStorage(t), 'false positive in isStorage')
+end
+
+function torchtest.expand()
+ local result = torch.Tensor():half()
+ local tensor = torch.rand(8,1)
+ local template = torch.rand(8,5)
+ local target = template:size():totable()
+ mytester:assertTableEq(tensor:expandAs(template):size():totable(), target, 'Error in expandAs')
+ mytester:assertTableEq(tensor:expand(8,5):size():totable(), target, 'Error in expand')
+ mytester:assertTableEq(tensor:expand(torch.LongStorage{8,5}):size():totable(), target, 'Error in expand using LongStorage')
+ result:expandAs(tensor,template)
+ mytester:assertTableEq(result:size():totable(), target, 'Error in expandAs using result')
+ result:expand(tensor,8,5)
+ mytester:assertTableEq(result:size():totable(), target, 'Error in expand using result')
+ result:expand(tensor,torch.LongStorage{8,5})
+ mytester:assertTableEq(result:size():totable(), target, 'Error in expand using result and LongStorage')
+ mytester:asserteq((result:mean(2):view(8,1)-tensor):abs():max(), 0, 'Error in expand (not equal)')
+end
+
+function torchtest.repeatTensor()
+ local result = torch.Tensor():half()
+ local tensor = torch.rand(8,4)
+ local size = {3,1,1}
+ local sizeStorage = torch.LongStorage(size)
+ local target = {3,8,4}
+ mytester:assertTableEq(tensor:repeatTensor(unpack(size)):size():totable(), target, 'Error in repeatTensor')
+ mytester:assertTableEq(tensor:repeatTensor(sizeStorage):size():totable(), target, 'Error in repeatTensor using LongStorage')
+ result:repeatTensor(tensor,unpack(size))
+ mytester:assertTableEq(result:size():totable(), target, 'Error in repeatTensor using result')
+ result:repeatTensor(tensor,sizeStorage)
+ mytester:assertTableEq(result:size():totable(), target, 'Error in repeatTensor using result and LongStorage')
+ mytester:asserteq((result:mean(1):view(8,4)-tensor):abs():max(), 0, 'Error in repeatTensor (not equal)')
+end
+
+function torchtest.isSameSizeAs()
+ local t1 = torch.Tensor(3, 4, 9, 10):half()
+ local t2 = torch.Tensor(3, 4):half()
+ local t3 = torch.Tensor(1, 9, 3, 3):half()
+ local t4 = torch.Tensor(3, 4, 9, 10):half()
+
+ mytester:assert(t1:isSameSizeAs(t2) == false, "wrong answer ")
+ mytester:assert(t1:isSameSizeAs(t3) == false, "wrong answer ")
+ mytester:assert(t1:isSameSizeAs(t4) == true, "wrong answer ")
+end
+
+ torch.setheaptracking(true)
+ math.randomseed(os.time())
+ precision = 1e-4
+ mytester = torch.Tester()
+ mytester:add(torchtest)
+ mytester:run()
diff --git a/torchcwrap.lua b/torchcwrap.lua
index ab0df43..551bd05 100644
--- a/torchcwrap.lua
+++ b/torchcwrap.lua
@@ -202,7 +202,7 @@ types.IndexTensor = {
}
for _,typename in ipairs({"ByteTensor", "CharTensor", "ShortTensor", "IntTensor", "LongTensor",
- "FloatTensor", "DoubleTensor"}) do
+ "FloatTensor", "HalfTensor", "DoubleTensor"}) do
types[typename] = {
@@ -460,3 +460,56 @@ types.charoption = {
postcall = function(arg)
end
}
+
+for _,typename in ipairs({"ptrdiff_t", "size_t"}) do
+ types[typename] = {
+
+ helpname = function(arg)
+ return typename
+ end,
+
+ declare = function(arg)
+ -- if it is a number we initialize here
+ local default = tonumber(tostring(arg.default)) or 0
+ return string.format("%s arg%d = %g;", typename, arg.i, default)
+ end,
+
+ check = function(arg, idx)
+ return string.format("lua_isnumber(L, %d)", idx)
+ end,
+
+ read = function(arg, idx)
+ return string.format("arg%d = (%s)lua_tonumber(L, %d);", arg.i, typename, idx)
+ end,
+
+ init = function(arg)
+ -- otherwise do it here
+ if arg.default then
+ local default = tostring(arg.default)
+ if not tonumber(default) then
+ return string.format("arg%d = %s;", arg.i, default)
+ end
+ end
+ end,
+
+ carg = function(arg)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ if arg.returned then
+ return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
+ end
+ end,
+
+ postcall = function(arg)
+ if arg.creturned then
+ return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
+ end
+ end
+ }
+end