Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-10-07 17:43:56 +0300
committerGitHub <noreply@github.com>2016-10-07 17:43:56 +0300
commit04e1d1dce0f02aea82dc433c4f39e42650c4390f (patch)
treebae3773325263d43ef35eeb6a25412ed9fb2f20f
parenteb397ad0bd77cfe67a5e27f05d25217cba95ffef (diff)
parent3bf1bc2f428ac625a8bcf4150bb4ed0591241827 (diff)
Merge pull request #785 from BTNC/replace-long
replace long with ptrdiff_t for memory size/offset, element counts etc
-rw-r--r--FFI.lua8
-rw-r--r--File.c10
-rw-r--r--TensorMath.lua54
-rw-r--r--general.h1
-rw-r--r--generic/Storage.c33
-rw-r--r--generic/Tensor.c26
-rw-r--r--lib/TH/CMakeLists.txt2
-rw-r--r--lib/TH/THAllocator.c26
-rw-r--r--lib/TH/THAllocator.h6
-rw-r--r--lib/TH/THAtomic.c86
-rw-r--r--lib/TH/THAtomic.h29
-rw-r--r--lib/TH/THDiskFile.c24
-rw-r--r--lib/TH/THGeneral.c54
-rw-r--r--lib/TH/THGeneral.h.in7
-rw-r--r--lib/TH/THLogAdd.c2
-rw-r--r--lib/TH/THMemoryFile.c12
-rw-r--r--lib/TH/generic/THStorage.c28
-rw-r--r--lib/TH/generic/THStorage.h22
-rw-r--r--lib/TH/generic/THStorageCopy.c4
-rw-r--r--lib/TH/generic/THTensor.c43
-rw-r--r--lib/TH/generic/THTensor.h26
-rw-r--r--lib/TH/generic/THTensorConv.c28
-rw-r--r--lib/TH/generic/THTensorMath.c91
-rw-r--r--lib/TH/generic/THTensorMath.h2
-rw-r--r--lib/TH/generic/THVector.h10
-rw-r--r--lib/TH/generic/THVectorDefault.c20
-rw-r--r--lib/TH/generic/THVectorDispatch.c20
-rw-r--r--lib/TH/vector/NEON.c10
-rw-r--r--lib/TH/vector/SSE.c56
-rw-r--r--lib/luaT/luaT.c45
-rw-r--r--lib/luaT/luaT.h7
31 files changed, 521 insertions, 271 deletions
diff --git a/FFI.lua b/FFI.lua
index 904302a..3cc0b21 100644
--- a/FFI.lua
+++ b/FFI.lua
@@ -28,8 +28,8 @@ if ok then
-- Allocator
ffi.cdef[[
typedef struct THAllocator {
- void* (*malloc)(void*, long);
- void* (*realloc)(void*, void*, long);
+ void* (*malloc)(void*, ptrdiff_t);
+ void* (*realloc)(void*, void*, ptrdiff_t);
void (*free)(void*, void*);
} THAllocator;
]]
@@ -41,7 +41,7 @@ typedef struct THAllocator {
typedef struct THRealStorage
{
real *data;
- long size;
+ ptrdiff_t size;
int refcount;
char flag;
THAllocator *allocator;
@@ -78,7 +78,7 @@ typedef struct THRealTensor
int nDimension;
THRealStorage *storage;
- long storageOffset;
+ ptrdiff_t storageOffset;
int refcount;
char flag;
diff --git a/File.c b/File.c
index 586efed..e07bc46 100644
--- a/File.c
+++ b/File.c
@@ -39,7 +39,7 @@ IMPLEMENT_TORCH_FILE_FUNC(synchronize)
static int torch_File_seek(lua_State *L)
{
THFile *self = luaT_checkudata(L, 1, "torch.File");
- long position = luaL_checklong(L, 2)-1;
+ ptrdiff_t position = luaL_checkinteger(L, 2)-1;
// >= 0 because it has 1 already subtracted
THArgCheck(position >= 0, 2, "position has to be greater than 0!");
THFile_seek(self, (size_t)position);
@@ -73,8 +73,8 @@ IMPLEMENT_TORCH_FILE_FUNC(close)
{ \
if(lua_isnumber(L, 2)) \
{ \
- long size = lua_tonumber(L, 2); \
- long nread; \
+ ptrdiff_t size = lua_tonumber(L, 2); \
+ ptrdiff_t nread; \
\
TH##TYPEC##Storage *storage = TH##TYPEC##Storage_newWithSize(size); \
luaT_pushudata(L, storage, "torch." #TYPEC "Storage"); \
@@ -134,7 +134,7 @@ static int torch_File_readString(lua_State *L)
THFile *self = luaT_checkudata(L, 1, "torch.File");
const char *format = luaL_checkstring(L, 2);
char *str;
- long size;
+ ptrdiff_t size;
size = THFile_readStringRaw(self, format, &str);
lua_pushlstring(L, str, size);
@@ -151,7 +151,7 @@ static int torch_File_writeString(lua_State *L)
luaL_checktype(L, 2, LUA_TSTRING);
str = lua_tolstring(L, 2, &size);
- lua_pushnumber(L, THFile_writeStringRaw(self, str, (long)size));
+ lua_pushnumber(L, THFile_writeStringRaw(self, str, size));
return 1;
}
diff --git a/TensorMath.lua b/TensorMath.lua
index 6b79237..682de23 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -4,6 +4,58 @@ require 'torchcwrap'
local interface = wrap.CInterface.new()
local method = wrap.CInterface.new()
+local argtypes = wrap.CInterface.argtypes
+
+argtypes['ptrdiff_t'] = {
+
+ helpname = function(arg)
+ return 'ptrdiff_t'
+ end,
+
+ declare = function(arg)
+ -- if it is a number we initialize here
+ local default = tonumber(tostring(arg.default)) or 0
+ return string.format("%s arg%d = %g;", 'ptrdiff_t', arg.i, default)
+ end,
+
+ check = function(arg, idx)
+ return string.format("lua_isnumber(L, %d)", idx)
+ end,
+
+ read = function(arg, idx)
+ return string.format("arg%d = (%s)lua_tonumber(L, %d);", arg.i, 'ptrdiff_t', idx)
+ end,
+
+ init = function(arg)
+ -- otherwise do it here
+ if arg.default then
+ local default = tostring(arg.default)
+ if not tonumber(default) then
+ return string.format("arg%d = %s;", arg.i, default)
+ end
+ end
+ end,
+
+ carg = function(arg)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ if arg.returned then
+ return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
+ end
+ end,
+
+ postcall = function(arg)
+ if arg.creturned then
+ return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
+ end
+ end
+}
interface:print([[
#include "TH.h"
@@ -533,7 +585,7 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
wrap("numel",
cname("numel"),
{{name=Tensor},
- {name="long", creturned=true}})
+ {name="ptrdiff_t", creturned=true}})
for _,name in ipairs({"cumsum", "cumprod"}) do
wrap(name,
diff --git a/general.h b/general.h
index 4896adf..3ccf4bd 100644
--- a/general.h
+++ b/general.h
@@ -3,6 +3,7 @@
#include <stdlib.h>
#include <string.h>
+#include <stddef.h>
#include "luaT.h"
#include "TH.h"
diff --git a/generic/Storage.c b/generic/Storage.c
index 287796b..134dc63 100644
--- a/generic/Storage.c
+++ b/generic/Storage.c
@@ -20,15 +20,15 @@ static int torch_Storage_(new)(lua_State *L)
int isShared = 0;
if(luaT_optboolean(L, index + 1, 0))
isShared = TH_ALLOCATOR_MAPPED_SHARED;
- long size = luaL_optlong(L, index + 2, 0);
+ ptrdiff_t size = luaL_optinteger(L, index + 2, 0);
if (isShared && luaT_optboolean(L, index + 3, 0))
isShared = TH_ALLOCATOR_MAPPED_SHAREDMEM;
storage = THStorage_(newWithMapping)(fileName, size, isShared);
}
else if(lua_type(L, index) == LUA_TTABLE)
{
- long size = lua_objlen(L, index);
- long i;
+ ptrdiff_t size = lua_objlen(L, index);
+ ptrdiff_t i;
if (allocator)
storage = THStorage_(newWithAllocator)(size, allocator, NULL);
else
@@ -52,11 +52,11 @@ static int torch_Storage_(new)(lua_State *L)
THStorage *src = luaT_checkudata(L, index, torch_Storage);
real *ptr = src->data;
- long offset = luaL_optlong(L, index + 1, 1) - 1;
+ ptrdiff_t offset = luaL_optinteger(L, index + 1, 1) - 1;
if (offset < 0 || offset >= src->size) {
luaL_error(L, "offset out of bounds");
}
- long size = luaL_optlong(L, index + 2, src->size - offset);
+ ptrdiff_t size = luaL_optinteger(L, index + 2, src->size - offset);
if (size < 1 || size > (src->size - offset)) {
luaL_error(L, "size out of bounds");
}
@@ -67,8 +67,8 @@ static int torch_Storage_(new)(lua_State *L)
}
else if(lua_type(L, index + 1) == LUA_TNUMBER)
{
- long size = luaL_optlong(L, index, 0);
- real *ptr = (real *)luaL_optlong(L, index + 1, 0);
+ ptrdiff_t size = luaL_optinteger(L, index, 0);
+ real *ptr = (real *)luaL_optinteger(L, index + 1, 0);
if (allocator)
storage = THStorage_(newWithDataAndAllocator)(ptr, size, allocator, NULL);
else
@@ -77,7 +77,7 @@ static int torch_Storage_(new)(lua_State *L)
}
else
{
- long size = luaL_optlong(L, index, 0);
+ ptrdiff_t size = luaL_optinteger(L, index, 0);
if (allocator)
storage = THStorage_(newWithAllocator)(size, allocator, NULL);
else
@@ -104,7 +104,7 @@ static int torch_Storage_(free)(lua_State *L)
static int torch_Storage_(resize)(lua_State *L)
{
THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
- long size = luaL_checklong(L, 2);
+ ptrdiff_t size = luaL_checkinteger(L, 2);
/* int keepContent = luaT_optboolean(L, 3, 0); */
THStorage_(resize)(storage, size);/*, keepContent); */
lua_settop(L, 1);
@@ -148,14 +148,14 @@ static int torch_Storage_(fill)(lua_State *L)
static int torch_Storage_(elementSize)(lua_State *L)
{
- luaT_pushlong(L, THStorage_(elementSize)());
+ luaT_pushinteger(L, THStorage_(elementSize)());
return 1;
}
static int torch_Storage_(__len__)(lua_State *L)
{
THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
- luaT_pushlong(L, storage->size);
+ luaT_pushinteger(L, storage->size);
return 1;
}
@@ -164,7 +164,7 @@ static int torch_Storage_(__newindex__)(lua_State *L)
if(lua_isnumber(L, 2))
{
THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
- long index = luaL_checklong(L, 2) - 1;
+ ptrdiff_t index = luaL_checkinteger(L, 2) - 1;
real number = luaG_(checkreal)(L, 3);
THStorage_(set)(storage, index, number);
lua_pushboolean(L, 1);
@@ -180,7 +180,7 @@ static int torch_Storage_(__index__)(lua_State *L)
if(lua_isnumber(L, 2))
{
THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
- long index = luaL_checklong(L, 2) - 1;
+ ptrdiff_t index = luaL_checkinteger(L, 2) - 1;
luaG_(pushreal)(L, THStorage_(get)(storage, index));
lua_pushboolean(L, 1);
return 2;
@@ -214,7 +214,7 @@ static int torch_Storage_(string)(lua_State *L)
static int torch_Storage_(totable)(lua_State *L)
{
THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
- long i;
+ ptrdiff_t i;
lua_newtable(L);
for(i = 0; i < storage->size; i++)
@@ -237,6 +237,9 @@ static int torch_Storage_(write)(lua_State *L)
THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
THFile *file = luaT_checkudata(L, 2, "torch.File");
+#ifdef DEBUG
+ THAssert(storage->size < LONG_MAX);
+#endif
THFile_writeLongScalar(file, storage->size);
THFile_writeRealRaw(file, storage->data, storage->size);
@@ -247,7 +250,7 @@ static int torch_Storage_(read)(lua_State *L)
{
THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
THFile *file = luaT_checkudata(L, 2, "torch.File");
- long size = THFile_readLongScalar(file);
+ ptrdiff_t size = THFile_readLongScalar(file);
THStorage_(resize)(storage, size);
THFile_readRealRaw(file, storage->data, storage->size);
diff --git a/generic/Tensor.c b/generic/Tensor.c
index 3067213..abb7819 100644
--- a/generic/Tensor.c
+++ b/generic/Tensor.c
@@ -5,7 +5,7 @@
#include "luaG.h"
static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index, int allowNone, int allowTensor, int allowStorage, int allowStride,
- THStorage **storage_, long *storageOffset_, THLongStorage **size_, THLongStorage **stride_);
+ THStorage **storage_, ptrdiff_t *storageOffset_, THLongStorage **size_, THLongStorage **stride_);
static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowStride, THLongStorage **size_, THLongStorage **stride_);
@@ -29,7 +29,7 @@ static int torch_Tensor_(size)(lua_State *L)
static int torch_Tensor_(elementSize)(lua_State *L)
{
- luaT_pushlong(L, THStorage_(elementSize)());
+ luaT_pushinteger(L, THStorage_(elementSize)());
return 1;
}
@@ -55,7 +55,7 @@ static int torch_Tensor_(stride)(lua_State *L)
static int torch_Tensor_(nDimension)(lua_State *L)
{
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
- luaT_pushlong(L, tensor->nDimension);
+ luaT_pushinteger(L, tensor->nDimension);
return 1;
}
@@ -76,21 +76,21 @@ static int torch_Tensor_(storage)(lua_State *L)
static int torch_Tensor_(storageOffset)(lua_State *L)
{
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
- luaT_pushlong(L, tensor->storageOffset+1);
+ luaT_pushinteger(L, tensor->storageOffset+1);
return 1;
}
static int torch_Tensor_(new)(lua_State *L)
{
THTensor *tensor;
- long storageOffset;
+ ptrdiff_t storageOffset;
THLongStorage *size, *stride;
if(lua_type(L, 1) == LUA_TTABLE)
{
- long i, j;
+ ptrdiff_t i, j;
THLongStorage *counter;
- long si = 0;
+ ptrdiff_t si = 0;
int dimension = 0;
int is_finished = 0;
@@ -214,7 +214,7 @@ static int torch_Tensor_(set)(lua_State *L)
{
THTensor *self = luaT_checkudata(L, 1, torch_Tensor);
THStorage *storage;
- long storageOffset;
+ ptrdiff_t storageOffset;
THLongStorage *size, *stride;
torch_Tensor_(c_readTensorStorageSizeStride)(L, 2, 1, 1, 1, 1,
@@ -651,7 +651,7 @@ static int torch_Tensor_(isSetTo)(lua_State *L)
static int torch_Tensor_(nElement)(lua_State *L)
{
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
- luaT_pushlong(L, THTensor_(nElement)(tensor));
+ luaT_pushinteger(L, THTensor_(nElement)(tensor));
return 1;
}
@@ -752,7 +752,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
}
else if((idx = luaT_toudata(L, 2, "torch.LongStorage")))
{
- long index = THTensor_(storageOffset)(tensor);
+ ptrdiff_t index = THTensor_(storageOffset)(tensor);
real value = luaG_(checkreal)(L,3);
int dim;
@@ -904,7 +904,7 @@ static int torch_Tensor_(__index__)(lua_State *L)
}
else if((idx = luaT_toudata(L, 2, "torch.LongStorage")))
{
- long index = THTensor_(storageOffset)(tensor);
+ ptrdiff_t index = THTensor_(storageOffset)(tensor);
int dim;
THArgCheck(idx->size == tensor->nDimension, 2, "invalid size");
@@ -1071,7 +1071,7 @@ static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowSt
}
static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index, int allowNone, int allowTensor, int allowStorage, int allowStride,
- THStorage **storage_, long *storageOffset_, THLongStorage **size_, THLongStorage **stride_)
+ THStorage **storage_, ptrdiff_t *storageOffset_, THLongStorage **size_, THLongStorage **stride_)
{
THTensor *src = NULL;
THStorage *storage = NULL;
@@ -1105,7 +1105,7 @@ static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index
}
else
{
- *storageOffset_ = luaL_checklong(L, index+1)-1;
+ *storageOffset_ = luaL_checkinteger(L, index+1)-1;
torch_Tensor_(c_readSizeStride)(L, index+2, allowStride, size_, stride_);
}
return;
diff --git a/lib/TH/CMakeLists.txt b/lib/TH/CMakeLists.txt
index a6da933..29343c7 100644
--- a/lib/TH/CMakeLists.txt
+++ b/lib/TH/CMakeLists.txt
@@ -111,7 +111,7 @@ ENDIF(C_SSE4_1_FOUND OR C_SSE4_2_FOUND)
IF(C_AVX_FOUND)
SET(CMAKE_C_FLAGS "-DUSE_AVX ${CMAKE_C_FLAGS}")
IF(MSVC)
- SET_SOURCE_FILES_PROPERTIES(generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "/Ox /fp:fast /arch:AVX /std:c99")
+ SET_SOURCE_FILES_PROPERTIES(generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "/Ox /fp:fast /arch:AVX /std:c99")
ELSE(MSVC)
SET_SOURCE_FILES_PROPERTIES(generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "-O3 -ffast-math -mavx -std=c99")
ENDIF(MSVC)
diff --git a/lib/TH/THAllocator.c b/lib/TH/THAllocator.c
index 7730ba5..5b06502 100644
--- a/lib/TH/THAllocator.c
+++ b/lib/TH/THAllocator.c
@@ -15,11 +15,11 @@
#endif
/* end of stuff for mapped files */
-static void *THDefaultAllocator_alloc(void* ctx, long size) {
+static void *THDefaultAllocator_alloc(void* ctx, ptrdiff_t size) {
return THAlloc(size);
}
-static void *THDefaultAllocator_realloc(void* ctx, void* ptr, long size) {
+static void *THDefaultAllocator_realloc(void* ctx, void* ptr, ptrdiff_t size) {
return THRealloc(ptr, size);
}
@@ -38,7 +38,7 @@ THAllocator THDefaultAllocator = {
struct THMapAllocatorContext_ {
char *filename; /* file name */
int flags;
- long size; /* mapped size */
+ ptrdiff_t size; /* mapped size */
int fd;
};
@@ -91,7 +91,7 @@ int THMapAllocatorContext_fd(THMapAllocatorContext *ctx)
return ctx->fd;
}
-long THMapAllocatorContext_size(THMapAllocatorContext *ctx)
+ptrdiff_t THMapAllocatorContext_size(THMapAllocatorContext *ctx)
{
return ctx->size;
}
@@ -103,7 +103,7 @@ void THMapAllocatorContext_free(THMapAllocatorContext *ctx)
THFree(ctx);
}
-static void *_map_alloc(void* ctx_, long size)
+static void *_map_alloc(void* ctx_, ptrdiff_t size)
{
THMapAllocatorContext *ctx = ctx_;
void *data = NULL;
@@ -326,11 +326,11 @@ static void *_map_alloc(void* ctx_, long size)
return data;
}
-static void * THMapAllocator_alloc(void *ctx, long size) {
+static void * THMapAllocator_alloc(void *ctx, ptrdiff_t size) {
return _map_alloc(ctx, size);
}
-static void *THMapAllocator_realloc(void* ctx, void* ptr, long size) {
+static void *THMapAllocator_realloc(void* ctx, void* ptr, ptrdiff_t size) {
THError("cannot realloc mapped data");
return NULL;
}
@@ -378,12 +378,12 @@ void THMapAllocatorContext_free(THMapAllocatorContext *ctx) {
THError("file mapping not supported on your system");
}
-static void *THMapAllocator_alloc(void* ctx_, long size) {
+static void *THMapAllocator_alloc(void* ctx_, ptrdiff_t size) {
THError("file mapping not supported on your system");
return NULL;
}
-static void *THMapAllocator_realloc(void* ctx, void* ptr, long size) {
+static void *THMapAllocator_realloc(void* ctx, void* ptr, ptrdiff_t size) {
THError("file mapping not supported on your system");
return NULL;
}
@@ -396,7 +396,7 @@ static void THMapAllocator_free(void* ctx, void* data) {
#if (defined(_WIN32) || defined(HAVE_MMAP)) && defined(TH_ATOMIC_IPC_REFCOUNT)
-static void * THRefcountedMapAllocator_alloc(void *_ctx, long size) {
+static void * THRefcountedMapAllocator_alloc(void *_ctx, ptrdiff_t size) {
THMapAllocatorContext *ctx = _ctx;
if (ctx->flags & TH_ALLOCATOR_MAPPED_FROMFD)
@@ -421,7 +421,7 @@ static void * THRefcountedMapAllocator_alloc(void *_ctx, long size) {
return (void*)data;
}
-static void *THRefcountedMapAllocator_realloc(void* ctx, void* ptr, long size) {
+static void *THRefcountedMapAllocator_realloc(void* ctx, void* ptr, ptrdiff_t size) {
THError("cannot realloc mapped data");
return NULL;
}
@@ -464,12 +464,12 @@ int THRefcountedMapAllocator_decref(THMapAllocatorContext *ctx, void *data)
#else
-static void * THRefcountedMapAllocator_alloc(void *ctx, long size) {
+static void * THRefcountedMapAllocator_alloc(void *ctx, ptrdiff_t size) {
THError("refcounted file mapping not supported on your system");
return NULL;
}
-static void *THRefcountedMapAllocator_realloc(void* ctx, void* ptr, long size) {
+static void *THRefcountedMapAllocator_realloc(void* ctx, void* ptr, ptrdiff_t size) {
THError("refcounted file mapping not supported on your system");
return NULL;
}
diff --git a/lib/TH/THAllocator.h b/lib/TH/THAllocator.h
index 14c433a..18fc9ec 100644
--- a/lib/TH/THAllocator.h
+++ b/lib/TH/THAllocator.h
@@ -14,8 +14,8 @@
/* Custom allocator
*/
typedef struct THAllocator {
- void* (*malloc)(void*, long);
- void* (*realloc)(void*, void*, long);
+ void* (*malloc)(void*, ptrdiff_t);
+ void* (*realloc)(void*, void*, ptrdiff_t);
void (*free)(void*, void*);
} THAllocator;
@@ -32,7 +32,7 @@ TH_API THMapAllocatorContext *THMapAllocatorContext_newWithFd(const char *filena
int fd, int flags);
TH_API char * THMapAllocatorContext_filename(THMapAllocatorContext *ctx);
TH_API int THMapAllocatorContext_fd(THMapAllocatorContext *ctx);
-TH_API long THMapAllocatorContext_size(THMapAllocatorContext *ctx);
+TH_API ptrdiff_t THMapAllocatorContext_size(THMapAllocatorContext *ctx);
TH_API void THMapAllocatorContext_free(THMapAllocatorContext *ctx);
TH_API void THRefcountedMapAllocator_incref(THMapAllocatorContext *ctx, void *data);
TH_API int THRefcountedMapAllocator_decref(THMapAllocatorContext *ctx, void *data);
diff --git a/lib/TH/THAtomic.c b/lib/TH/THAtomic.c
index aa70d93..714fc52 100644
--- a/lib/TH/THAtomic.c
+++ b/lib/TH/THAtomic.c
@@ -179,3 +179,89 @@ long THAtomicCompareAndSwapLong(long volatile *a, long oldvalue, long newvalue)
return 0;
#endif
}
+
+void THAtomicSetPtrdiff(ptrdiff_t volatile *a, ptrdiff_t newvalue)
+{
+#if defined(USE_C11_ATOMICS)
+ atomic_store(a, newvalue);
+#elif defined(USE_MSC_ATOMICS)
+#ifdef _WIN64
+ _InterlockedExchange64(a, newvalue);
+#else
+ _InterlockedExchange(a, newvalue);
+#endif
+#elif defined(USE_GCC_ATOMICS)
+ __sync_lock_test_and_set(a, newvalue);
+#else
+ ptrdiff_t oldvalue;
+ do {
+ oldvalue = *a;
+ } while (!THAtomicCompareAndSwapPtrdiff(a, oldvalue, newvalue));
+#endif
+}
+
+ptrdiff_t THAtomicGetPtrdiff(ptrdiff_t volatile *a)
+{
+#if defined(USE_C11_ATOMICS)
+ return atomic_load(a);
+#else
+ ptrdiff_t value;
+ do {
+ value = *a;
+ } while (!THAtomicCompareAndSwapPtrdiff(a, value, value));
+ return value;
+#endif
+}
+
+ptrdiff_t THAtomicAddPtrdiff(ptrdiff_t volatile *a, ptrdiff_t value)
+{
+#if defined(USE_C11_ATOMICS)
+ return atomic_fetch_add(a, value);
+#elif defined(USE_MSC_ATOMICS)
+#ifdef _WIN64
+ return _InterlockedExchangeAdd64(a, value);
+#else
+ return _InterlockedExchangeAdd(a, value);
+#endif
+#elif defined(USE_GCC_ATOMICS)
+ return __sync_fetch_and_add(a, value);
+#else
+ ptrdiff_t oldvalue;
+ do {
+ oldvalue = *a;
+ } while (!THAtomicCompareAndSwapPtrdiff(a, oldvalue, (oldvalue + value)));
+ return oldvalue;
+#endif
+}
+
+ptrdiff_t THAtomicCompareAndSwapPtrdiff(ptrdiff_t volatile *a, ptrdiff_t oldvalue, ptrdiff_t newvalue)
+{
+#if defined(USE_C11_ATOMICS)
+ return atomic_compare_exchange_strong(a, &oldvalue, newvalue);
+#elif defined(USE_MSC_ATOMICS)
+#ifdef _WIN64
+ return (_InterlockedCompareExchange64(a, newvalue, oldvalue) == oldvalue);
+#else
+ return (_InterlockedCompareExchange(a, newvalue, oldvalue) == oldvalue);
+#endif
+#elif defined(USE_GCC_ATOMICS)
+ return __sync_bool_compare_and_swap(a, oldvalue, newvalue);
+#elif defined(USE_PTHREAD_ATOMICS)
+ ptrdiff_t ret = 0;
+ pthread_mutex_lock(&ptm);
+ if(*a == oldvalue) {
+ *a = newvalue;
+ ret = 1;
+ }
+ pthread_mutex_unlock(&ptm);
+ return ret;
+#else
+#warning THAtomic is not thread safe
+ if(*a == oldvalue) {
+ *a = newvalue;
+ return 1;
+ }
+ else
+ return 0;
+#endif
+}
diff --git a/lib/TH/THAtomic.h b/lib/TH/THAtomic.h
index 3a0b6fa..d77b20b 100644
--- a/lib/TH/THAtomic.h
+++ b/lib/TH/THAtomic.h
@@ -86,6 +86,35 @@ TH_API long THAtomicAddLong(long volatile *a, long value);
*/
TH_API long THAtomicCompareAndSwapLong(long volatile *a, long oldvalue, long newvalue);
+
+
+/******************************************************************************
+ * functions for ptrdiff_t type
+ ******************************************************************************/
+
+/*
+ * *a = newvalue
+*/
+TH_API void THAtomicSetPtrdiff(ptrdiff_t volatile *a, ptrdiff_t newvalue);
+
+/*
+ * return *a
+*/
+TH_API ptrdiff_t THAtomicGetPtrdiff(ptrdiff_t volatile *a);
+
+/*
+ * *a += value,
+ * return previous *a
+*/
+TH_API ptrdiff_t THAtomicAddPtrdiff(ptrdiff_t volatile *a, ptrdiff_t value);
+
+/*
+ * check if (*a == oldvalue)
+ * if true: set *a to newvalue, return 1
+ * if false: return 0
+*/
+TH_API ptrdiff_t THAtomicCompareAndSwapPtrdiff(ptrdiff_t volatile *a, ptrdiff_t oldvalue, ptrdiff_t newvalue);
+
#if defined(USE_C11_ATOMICS) && defined(ATOMIC_INT_LOCK_FREE) && \
ATOMIC_INT_LOCK_FREE == 2
#define TH_ATOMIC_IPC_REFCOUNT 1
diff --git a/lib/TH/THDiskFile.c b/lib/TH/THDiskFile.c
index 50b006f..2ded7bd 100644
--- a/lib/TH/THDiskFile.c
+++ b/lib/TH/THDiskFile.c
@@ -381,20 +381,21 @@ static size_t THDiskFile_readLong(THFile *self, long *data, size_t n)
THDiskFile_reverseMemory(data, data, sizeof(long), nread);
} else if(dfself->longSize == 4)
{
- int i;
nread = fread__(data, 4, n, dfself->handle);
if(!dfself->isNativeEncoding && (nread > 0))
THDiskFile_reverseMemory(data, data, 4, nread);
- for(i = nread-1; i >= 0; i--)
- data[i] = ((int *)data)[i];
+ size_t i;
+ for(i = nread; i > 0; i--)
+ data[i-1] = ((int *)data)[i-1];
}
else /* if(dfself->longSize == 8) */
{
- int i, big_endian = !THDiskFile_isLittleEndianCPU();
- long *buffer = THAlloc(8*n);
+ int big_endian = !THDiskFile_isLittleEndianCPU();
+ int32_t *buffer = THAlloc(8*n);
nread = fread__(buffer, 8, n, dfself->handle);
- for(i = nread-1; i >= 0; i--)
- data[i] = buffer[2*i + big_endian];
+ size_t i;
+ for(i = nread; i > 0; i--)
+ data[i-1] = buffer[2*(i-1) + big_endian];
THFree(buffer);
if(!dfself->isNativeEncoding && (nread > 0))
THDiskFile_reverseMemory(data, data, 4, nread);
@@ -450,8 +451,8 @@ static size_t THDiskFile_writeLong(THFile *self, long *data, size_t n)
}
} else if(dfself->longSize == 4)
{
- int i;
- int *buffer = THAlloc(4*n);
+ int32_t *buffer = THAlloc(4*n);
+ size_t i;
for(i = 0; i < n; i++)
buffer[i] = data[i];
if(!dfself->isNativeEncoding)
@@ -461,8 +462,9 @@ static size_t THDiskFile_writeLong(THFile *self, long *data, size_t n)
}
else /* if(dfself->longSize == 8) */
{
- int i, big_endian = !THDiskFile_isLittleEndianCPU();
- long *buffer = THAlloc(8*n);
+ int big_endian = !THDiskFile_isLittleEndianCPU();
+ int32_t *buffer = THAlloc(8*n);
+ size_t i;
for(i = 0; i < n; i++)
{
buffer[2*i + !big_endian] = 0;
diff --git a/lib/TH/THGeneral.c b/lib/TH/THGeneral.c
index d26af0a..399403b 100644
--- a/lib/TH/THGeneral.c
+++ b/lib/TH/THGeneral.c
@@ -129,10 +129,11 @@ void THSetDefaultArgErrorHandler(THArgErrorHandlerFunction new_handler, void *da
static __thread void (*torchGCFunction)(void *data) = NULL;
static __thread void *torchGCData;
-static long heapSize = 0;
-static __thread long heapDelta = 0;
-static const long heapMaxDelta = 1e6; // limit to +/- 1MB before updating heapSize
-static __thread long heapSoftmax = 3e8; // 300MB, adjusted upward dynamically
+static ptrdiff_t heapSize = 0;
+static __thread ptrdiff_t heapDelta = 0;
+static const ptrdiff_t heapMaxDelta = (ptrdiff_t)1e6; // limit to +/- 1MB before updating heapSize
+static const ptrdiff_t heapMinDelta = (ptrdiff_t)-1e6;
+static __thread ptrdiff_t heapSoftmax = (ptrdiff_t)3e8; // 300MB, adjusted upward dynamically
static const double heapSoftmaxGrowthThresh = 0.8; // grow softmax if >80% max after GC
static const double heapSoftmaxGrowthFactor = 1.4; // grow softmax by 40%
@@ -152,7 +153,8 @@ void THSetGCHandler( void (*torchGCFunction_)(void *data), void *data )
torchGCData = data;
}
-static long getAllocSize(void *ptr) {
+/* it is guaranteed the allocated size is not bigger than PTRDIFF_MAX */
+static ptrdiff_t getAllocSize(void *ptr) {
#if defined(__unix) && defined(HAVE_MALLOC_USABLE_SIZE)
return malloc_usable_size(ptr);
#elif defined(__APPLE__)
@@ -164,8 +166,15 @@ static long getAllocSize(void *ptr) {
#endif
}
-static long applyHeapDelta() {
- long newHeapSize = THAtomicAddLong(&heapSize, heapDelta) + heapDelta;
+static ptrdiff_t applyHeapDelta() {
+ ptrdiff_t oldHeapSize = THAtomicAddPtrdiff(&heapSize, heapDelta);
+#ifdef DEBUG
+ if (heapDelta > 0 && oldHeapSize > PTRDIFF_MAX - heapDelta)
+ THError("applyHeapDelta: heapSize(%td) + increased(%td) > PTRDIFF_MAX, heapSize overflow!", oldHeapSize, heapDelta);
+ if (heapDelta < 0 && oldHeapSize < PTRDIFF_MIN - heapDelta)
+ THError("applyHeapDelta: heapSize(%td) + decreased(%td) < PTRDIFF_MIN, heapSize underflow!", oldHeapSize, heapDelta);
+#endif
+ ptrdiff_t newHeapSize = oldHeapSize + heapDelta;
heapDelta = 0;
return newHeapSize;
}
@@ -174,36 +183,43 @@ static long applyHeapDelta() {
* (2) if post-GC heap size exceeds 80% of the soft max, increase the
* soft max by 40%
*/
-static void maybeTriggerGC(long curHeapSize) {
+static void maybeTriggerGC(ptrdiff_t curHeapSize) {
if (torchGCFunction && curHeapSize > heapSoftmax) {
torchGCFunction(torchGCData);
// ensure heapSize is accurate before updating heapSoftmax
- long newHeapSize = applyHeapDelta();
+ ptrdiff_t newHeapSize = applyHeapDelta();
if (newHeapSize > heapSoftmax * heapSoftmaxGrowthThresh) {
- heapSoftmax = heapSoftmax * heapSoftmaxGrowthFactor;
+ heapSoftmax = (ptrdiff_t)(heapSoftmax * heapSoftmaxGrowthFactor);
}
}
}
// hooks into the TH heap tracking
-void THHeapUpdate(long size) {
+void THHeapUpdate(ptrdiff_t size) {
+#ifdef DEBUG
+ if (size > 0 && heapDelta > PTRDIFF_MAX - size)
+ THError("THHeapUpdate: heapDelta(%td) + increased(%td) > PTRDIFF_MAX, heapDelta overflow!", heapDelta, size);
+ if (size < 0 && heapDelta < PTRDIFF_MIN - size)
+ THError("THHeapUpdate: heapDelta(%td) + decreased(%td) < PTRDIFF_MIN, heapDelta underflow!", heapDelta, size);
+#endif
+
heapDelta += size;
// batch updates to global heapSize to minimize thread contention
- if (labs(heapDelta) < heapMaxDelta) {
+ if (heapDelta < heapMaxDelta && heapDelta > heapMinDelta) {
return;
}
- long newHeapSize = applyHeapDelta();
+ ptrdiff_t newHeapSize = applyHeapDelta();
if (size > 0) {
maybeTriggerGC(newHeapSize);
}
}
-static void* THAllocInternal(long size)
+static void* THAllocInternal(ptrdiff_t size)
{
void *ptr;
@@ -229,7 +245,7 @@ static void* THAllocInternal(long size)
return ptr;
}
-void* THAlloc(long size)
+void* THAlloc(ptrdiff_t size)
{
void *ptr;
@@ -252,7 +268,7 @@ void* THAlloc(long size)
return ptr;
}
-void* THRealloc(void *ptr, long size)
+void* THRealloc(void *ptr, ptrdiff_t size)
{
if(!ptr)
return(THAlloc(size));
@@ -266,18 +282,20 @@ void* THRealloc(void *ptr, long size)
if(size < 0)
THError("$ Torch: invalid memory size -- maybe an overflow?");
- THHeapUpdate(-getAllocSize(ptr));
+ ptrdiff_t oldSize = -getAllocSize(ptr);
void *newptr = realloc(ptr, size);
if(!newptr && torchGCFunction) {
torchGCFunction(torchGCData);
newptr = realloc(ptr, size);
}
- THHeapUpdate(getAllocSize(newptr ? newptr : ptr));
if(!newptr)
THError("$ Torch: not enough memory: you tried to reallocate %dGB. Buy new RAM!", size/1073741824);
+ // update heapSize only after successfully reallocated
+ THHeapUpdate(oldSize + getAllocSize(newptr));
+
return newptr;
}
diff --git a/lib/TH/THGeneral.h.in b/lib/TH/THGeneral.h.in
index e52ba34..ff41159 100644
--- a/lib/TH/THGeneral.h.in
+++ b/lib/TH/THGeneral.h.in
@@ -9,6 +9,7 @@
#include <float.h>
#include <time.h>
#include <string.h>
+#include <stddef.h>
#cmakedefine USE_BLAS
#cmakedefine USE_LAPACK
@@ -57,12 +58,12 @@ TH_API void THSetDefaultErrorHandler(THErrorHandlerFunction new_handler, void *d
TH_API void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...);
TH_API void THSetArgErrorHandler(THArgErrorHandlerFunction new_handler, void *data);
TH_API void THSetDefaultArgErrorHandler(THArgErrorHandlerFunction new_handler, void *data);
-TH_API void* THAlloc(long size);
-TH_API void* THRealloc(void *ptr, long size);
+TH_API void* THAlloc(ptrdiff_t size);
+TH_API void* THRealloc(void *ptr, ptrdiff_t size);
TH_API void THFree(void *ptr);
TH_API void THSetGCHandler( void (*torchGCHandlerFunction)(void *data), void *data );
// this hook should only be called by custom allocator functions
-TH_API void THHeapUpdate(long size);
+TH_API void THHeapUpdate(ptrdiff_t size);
#define THError(...) _THError(__FILE__, __LINE__, __VA_ARGS__)
diff --git a/lib/TH/THLogAdd.c b/lib/TH/THLogAdd.c
index a503d7d..4b14f85 100644
--- a/lib/TH/THLogAdd.c
+++ b/lib/TH/THLogAdd.c
@@ -55,7 +55,7 @@ double THLogSub(double log_a, double log_b)
}
/* Credits to Leon Bottou */
-double THExpMinusApprox(double x)
+double THExpMinusApprox(const double x)
{
#define EXACT_EXPONENTIAL 0
#if EXACT_EXPONENTIAL
diff --git a/lib/TH/THMemoryFile.c b/lib/TH/THMemoryFile.c
index 453e11e..8d97621 100644
--- a/lib/TH/THMemoryFile.c
+++ b/lib/TH/THMemoryFile.c
@@ -371,22 +371,23 @@ static size_t THMemoryFile_readLong(THFile *self, long *data, size_t n)
mfself->position += nread*sizeof(long);
} else if(mfself->longSize == 4)
{
- size_t i;
size_t nByte = 4*n;
size_t nByteRemaining = (mfself->position + nByte <= mfself->size ? nByte : mfself->size-mfself->position);
int32_t *storage = (int32_t *)(mfself->storage->data + mfself->position);
nread = nByteRemaining/4;
+ size_t i;
for(i = 0; i < nread; i++)
data[i] = storage[i];
mfself->position += nread*4;
}
else /* if(mfself->longSize == 8) */
{
- int i, big_endian = !THDiskFile_isLittleEndianCPU();
+ int big_endian = !THDiskFile_isLittleEndianCPU();
size_t nByte = 8*n;
int32_t *storage = (int32_t *)(mfself->storage->data + mfself->position);
size_t nByteRemaining = (mfself->position + nByte <= mfself->size ? nByte : mfself->size-mfself->position);
nread = nByteRemaining/8;
+ size_t i;
for(i = 0; i < nread; i++)
data[i] = storage[2*i + big_endian];
mfself->position += nread*8;
@@ -448,20 +449,21 @@ static size_t THMemoryFile_writeLong(THFile *self, long *data, size_t n)
mfself->position += nByte;
} else if(mfself->longSize == 4)
{
- int i;
size_t nByte = 4*n;
THMemoryFile_grow(mfself, mfself->position+nByte);
int32_t *storage = (int32_t *)(mfself->storage->data + mfself->position);
+ size_t i;
for(i = 0; i < n; i++)
storage[i] = data[i];
mfself->position += nByte;
}
else /* if(mfself->longSize == 8) */
{
- int i, big_endian = !THDiskFile_isLittleEndianCPU();
+ int big_endian = !THDiskFile_isLittleEndianCPU();
size_t nByte = 8*n;
THMemoryFile_grow(mfself, mfself->position+nByte);
int32_t *storage = (int32_t *)(mfself->storage->data + mfself->position);
+ size_t i;
for(i = 0; i < n; i++)
{
storage[2*i + !big_endian] = 0;
@@ -517,7 +519,7 @@ static size_t THMemoryFile_writeLong(THFile *self, long *data, size_t n)
return n;
}
-static char* THMemoryFile_cloneString(const char *str, long size)
+static char* THMemoryFile_cloneString(const char *str, ptrdiff_t size)
{
char *cstr = THAlloc(size);
memcpy(cstr, str, size);
diff --git a/lib/TH/generic/THStorage.c b/lib/TH/generic/THStorage.c
index 788f6c7..a592cfb 100644
--- a/lib/TH/generic/THStorage.c
+++ b/lib/TH/generic/THStorage.c
@@ -7,12 +7,12 @@ real* THStorage_(data)(const THStorage *self)
return self->data;
}
-long THStorage_(size)(const THStorage *self)
+ptrdiff_t THStorage_(size)(const THStorage *self)
{
return self->size;
}
-int THStorage_(elementSize)()
+size_t THStorage_(elementSize)()
{
return sizeof(real);
}
@@ -22,12 +22,12 @@ THStorage* THStorage_(new)(void)
return THStorage_(newWithSize)(0);
}
-THStorage* THStorage_(newWithSize)(long size)
+THStorage* THStorage_(newWithSize)(ptrdiff_t size)
{
return THStorage_(newWithAllocator)(size, &THDefaultAllocator, NULL);
}
-THStorage* THStorage_(newWithAllocator)(long size,
+THStorage* THStorage_(newWithAllocator)(ptrdiff_t size,
THAllocator *allocator,
void *allocatorContext)
{
@@ -41,7 +41,7 @@ THStorage* THStorage_(newWithAllocator)(long size,
return storage;
}
-THStorage* THStorage_(newWithMapping)(const char *filename, long size, int flags)
+THStorage* THStorage_(newWithMapping)(const char *filename, ptrdiff_t size, int flags)
{
THMapAllocatorContext *ctx = THMapAllocatorContext_new(filename, flags);
@@ -127,13 +127,13 @@ void THStorage_(free)(THStorage *storage)
}
}
-THStorage* THStorage_(newWithData)(real *data, long size)
+THStorage* THStorage_(newWithData)(real *data, ptrdiff_t size)
{
return THStorage_(newWithDataAndAllocator)(data, size,
&THDefaultAllocator, NULL);
}
-THStorage* THStorage_(newWithDataAndAllocator)(real* data, long size,
+THStorage* THStorage_(newWithDataAndAllocator)(real* data, ptrdiff_t size,
THAllocator* allocator,
void* allocatorContext) {
THStorage *storage = THAlloc(sizeof(THStorage));
@@ -146,14 +146,14 @@ THStorage* THStorage_(newWithDataAndAllocator)(real* data, long size,
return storage;
}
-void THStorage_(resize)(THStorage *storage, long size)
+void THStorage_(resize)(THStorage *storage, ptrdiff_t size)
{
if(storage->flag & TH_STORAGE_RESIZABLE)
{
if(storage->allocator->realloc == NULL) {
/* case when the allocator does not have a realloc defined */
real *old_data = storage->data;
- long old_size = storage->size;
+ ptrdiff_t old_size = storage->size;
if (size == 0) {
storage->data = NULL;
} else {
@@ -163,7 +163,7 @@ void THStorage_(resize)(THStorage *storage, long size)
}
storage->size = size;
if (old_data != NULL) {
- long copy_size = old_size;
+ ptrdiff_t copy_size = old_size;
if (storage->size < copy_size) {
copy_size = storage->size;
}
@@ -186,18 +186,18 @@ void THStorage_(resize)(THStorage *storage, long size)
void THStorage_(fill)(THStorage *storage, real value)
{
- long i;
+ ptrdiff_t i;
for(i = 0; i < storage->size; i++)
storage->data[i] = value;
}
-void THStorage_(set)(THStorage *self, long idx, real value)
+void THStorage_(set)(THStorage *self, ptrdiff_t idx, real value)
{
THArgCheck((idx >= 0) && (idx < self->size), 2, "out of bounds");
self->data[idx] = value;
}
-real THStorage_(get)(const THStorage *self, long idx)
+real THStorage_(get)(const THStorage *self, ptrdiff_t idx)
{
THArgCheck((idx >= 0) && (idx < self->size), 2, "out of bounds");
return self->data[idx];
@@ -207,7 +207,7 @@ void THStorage_(swap)(THStorage *storage1, THStorage *storage2)
{
#define SWAP(val) { val = storage1->val; storage1->val = storage2->val; storage2->val = val; }
real *data;
- long size;
+ ptrdiff_t size;
char flag;
THAllocator *allocator;
void *allocatorContext;
diff --git a/lib/TH/generic/THStorage.h b/lib/TH/generic/THStorage.h
index 0f6dcca..3dd214b 100644
--- a/lib/TH/generic/THStorage.h
+++ b/lib/TH/generic/THStorage.h
@@ -24,7 +24,7 @@
typedef struct THStorage
{
real *data;
- long size;
+ ptrdiff_t size;
int refcount;
char flag;
THAllocator *allocator;
@@ -33,29 +33,29 @@ typedef struct THStorage
} THStorage;
TH_API real* THStorage_(data)(const THStorage*);
-TH_API long THStorage_(size)(const THStorage*);
-TH_API int THStorage_(elementSize)(void);
+TH_API ptrdiff_t THStorage_(size)(const THStorage*);
+TH_API size_t THStorage_(elementSize)(void);
/* slow access -- checks everything */
-TH_API void THStorage_(set)(THStorage*, long, real);
-TH_API real THStorage_(get)(const THStorage*, long);
+TH_API void THStorage_(set)(THStorage*, ptrdiff_t, real);
+TH_API real THStorage_(get)(const THStorage*, ptrdiff_t);
TH_API THStorage* THStorage_(new)(void);
-TH_API THStorage* THStorage_(newWithSize)(long size);
+TH_API THStorage* THStorage_(newWithSize)(ptrdiff_t size);
TH_API THStorage* THStorage_(newWithSize1)(real);
TH_API THStorage* THStorage_(newWithSize2)(real, real);
TH_API THStorage* THStorage_(newWithSize3)(real, real, real);
TH_API THStorage* THStorage_(newWithSize4)(real, real, real, real);
-TH_API THStorage* THStorage_(newWithMapping)(const char *filename, long size, int flags);
+TH_API THStorage* THStorage_(newWithMapping)(const char *filename, ptrdiff_t size, int flags);
/* takes ownership of data */
-TH_API THStorage* THStorage_(newWithData)(real *data, long size);
+TH_API THStorage* THStorage_(newWithData)(real *data, ptrdiff_t size);
-TH_API THStorage* THStorage_(newWithAllocator)(long size,
+TH_API THStorage* THStorage_(newWithAllocator)(ptrdiff_t size,
THAllocator* allocator,
void *allocatorContext);
TH_API THStorage* THStorage_(newWithDataAndAllocator)(
- real* data, long size, THAllocator* allocator, void *allocatorContext);
+ real* data, ptrdiff_t size, THAllocator* allocator, void *allocatorContext);
/* should not differ with API */
TH_API void THStorage_(setFlag)(THStorage *storage, const char flag);
@@ -65,7 +65,7 @@ TH_API void THStorage_(swap)(THStorage *storage1, THStorage *storage2);
/* might differ with other API (like CUDA) */
TH_API void THStorage_(free)(THStorage *storage);
-TH_API void THStorage_(resize)(THStorage *storage, long size);
+TH_API void THStorage_(resize)(THStorage *storage, ptrdiff_t size);
TH_API void THStorage_(fill)(THStorage *storage, real value);
#endif
diff --git a/lib/TH/generic/THStorageCopy.c b/lib/TH/generic/THStorageCopy.c
index 63a26dc..583e088 100644
--- a/lib/TH/generic/THStorageCopy.c
+++ b/lib/TH/generic/THStorageCopy.c
@@ -4,7 +4,7 @@
void THStorage_(rawCopy)(THStorage *storage, real *src)
{
- long i;
+ ptrdiff_t i;
for(i = 0; i < storage->size; i++)
storage->data[i] = src[i];
}
@@ -19,7 +19,7 @@ void THStorage_(copy)(THStorage *storage, THStorage *src)
#define IMPLEMENT_THStorage_COPY(TYPENAMESRC) \
void THStorage_(copy##TYPENAMESRC)(THStorage *storage, TH##TYPENAMESRC##Storage *src) \
{ \
- long i; \
+ ptrdiff_t i; \
THArgCheck(storage->size == src->size, 2, "size mismatch"); \
for(i = 0; i < storage->size; i++) \
storage->data[i] = (real)src->data[i]; \
diff --git a/lib/TH/generic/THTensor.c b/lib/TH/generic/THTensor.c
index 26bbb01..42e9e6d 100644
--- a/lib/TH/generic/THTensor.c
+++ b/lib/TH/generic/THTensor.c
@@ -8,7 +8,7 @@ THStorage *THTensor_(storage)(const THTensor *self)
return self->storage;
}
-long THTensor_(storageOffset)(const THTensor *self)
+ptrdiff_t THTensor_(storageOffset)(const THTensor *self)
{
return self->storageOffset;
}
@@ -67,7 +67,7 @@ void THTensor_(clearFlag)(THTensor *self, const char flag)
/**** creation methods ****/
static void THTensor_(rawInit)(THTensor *self);
-static void THTensor_(rawSet)(THTensor *self, THStorage *storage, long storageOffset, int nDimension, long *size, long *stride);
+static void THTensor_(rawSet)(THTensor *self, THStorage *storage, ptrdiff_t storageOffset, int nDimension, long *size, long *stride);
static void THTensor_(rawResize)(THTensor *self, int nDimension, long *size, long *stride);
@@ -94,13 +94,16 @@ THTensor *THTensor_(newWithTensor)(THTensor *tensor)
}
/* Storage init */
-THTensor *THTensor_(newWithStorage)(THStorage *storage, long storageOffset, THLongStorage *size, THLongStorage *stride)
+THTensor *THTensor_(newWithStorage)(THStorage *storage, ptrdiff_t storageOffset, THLongStorage *size, THLongStorage *stride)
{
THTensor *self = THAlloc(sizeof(THTensor));
if(size && stride)
THArgCheck(size->size == stride->size, 4, "inconsistent size");
THTensor_(rawInit)(self);
+#ifdef DEBUG
+ THAssert((size ? size->size : (stride ? stride->size : 0)) <= INT_MAX);
+#endif
THTensor_(rawSet)(self,
storage,
storageOffset,
@@ -110,20 +113,20 @@ THTensor *THTensor_(newWithStorage)(THStorage *storage, long storageOffset, THLo
return self;
}
-THTensor *THTensor_(newWithStorage1d)(THStorage *storage, long storageOffset,
+THTensor *THTensor_(newWithStorage1d)(THStorage *storage, ptrdiff_t storageOffset,
long size0, long stride0)
{
return THTensor_(newWithStorage4d)(storage, storageOffset, size0, stride0, -1, -1, -1, -1, -1, -1);
}
-THTensor *THTensor_(newWithStorage2d)(THStorage *storage, long storageOffset,
+THTensor *THTensor_(newWithStorage2d)(THStorage *storage, ptrdiff_t storageOffset,
long size0, long stride0,
long size1, long stride1)
{
return THTensor_(newWithStorage4d)(storage, storageOffset, size0, stride0, size1, stride1, -1, -1, -1, -1);
}
-THTensor *THTensor_(newWithStorage3d)(THStorage *storage, long storageOffset,
+THTensor *THTensor_(newWithStorage3d)(THStorage *storage, ptrdiff_t storageOffset,
long size0, long stride0,
long size1, long stride1,
long size2, long stride2)
@@ -131,7 +134,7 @@ THTensor *THTensor_(newWithStorage3d)(THStorage *storage, long storageOffset,
return THTensor_(newWithStorage4d)(storage, storageOffset, size0, stride0, size1, stride1, size2, stride2, -1, -1);
}
-THTensor *THTensor_(newWithStorage4d)(THStorage *storage, long storageOffset,
+THTensor *THTensor_(newWithStorage4d)(THStorage *storage, ptrdiff_t storageOffset,
long size0, long stride0,
long size1, long stride1,
long size2, long stride2,
@@ -232,6 +235,9 @@ void THTensor_(resize)(THTensor *self, THLongStorage *size, THLongStorage *strid
if(stride)
THArgCheck(stride->size == size->size, 3, "invalid stride");
+#ifdef DEBUG
+ THAssert(size->size <= INT_MAX);
+#endif
THTensor_(rawResize)(self, size->size, size->data, (stride ? stride->data : NULL));
}
@@ -281,11 +287,14 @@ void THTensor_(set)(THTensor *self, THTensor *src)
src->stride);
}
-void THTensor_(setStorage)(THTensor *self, THStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_)
+void THTensor_(setStorage)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_)
{
if(size_ && stride_)
THArgCheck(size_->size == stride_->size, 5, "inconsistent size/stride sizes");
+#ifdef DEBUG
+ THAssert((size_ ? size_->size : (stride_ ? stride_->size : 0)) <= INT_MAX);
+#endif
THTensor_(rawSet)(self,
storage_,
storageOffset_,
@@ -294,7 +303,7 @@ void THTensor_(setStorage)(THTensor *self, THStorage *storage_, long storageOffs
(stride_ ? stride_->data : NULL));
}
-void THTensor_(setStorage1d)(THTensor *self, THStorage *storage_, long storageOffset_,
+void THTensor_(setStorage1d)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
long size0_, long stride0_)
{
THTensor_(setStorage4d)(self, storage_, storageOffset_,
@@ -304,7 +313,7 @@ void THTensor_(setStorage1d)(THTensor *self, THStorage *storage_, long storageOf
-1, -1);
}
-void THTensor_(setStorage2d)(THTensor *self, THStorage *storage_, long storageOffset_,
+void THTensor_(setStorage2d)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
long size0_, long stride0_,
long size1_, long stride1_)
{
@@ -315,7 +324,7 @@ void THTensor_(setStorage2d)(THTensor *self, THStorage *storage_, long storageOf
-1, -1);
}
-void THTensor_(setStorage3d)(THTensor *self, THStorage *storage_, long storageOffset_,
+void THTensor_(setStorage3d)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
long size0_, long stride0_,
long size1_, long stride1_,
long size2_, long stride2_)
@@ -327,7 +336,7 @@ void THTensor_(setStorage3d)(THTensor *self, THStorage *storage_, long storageOf
-1, -1);
}
-void THTensor_(setStorage4d)(THTensor *self, THStorage *storage_, long storageOffset_,
+void THTensor_(setStorage4d)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
long size0_, long stride0_,
long size1_, long stride1_,
long size2_, long stride2_,
@@ -348,7 +357,7 @@ void THTensor_(narrow)(THTensor *self, THTensor *src, int dimension, long firstI
THArgCheck( (dimension >= 0) && (dimension < src->nDimension), 2, "out of range");
THArgCheck( (firstIndex >= 0) && (firstIndex < src->size[dimension]), 3, "out of range");
- THArgCheck( (size > 0) && (firstIndex+size <= src->size[dimension]), 4, "out of range");
+ THArgCheck( (size > 0) && (firstIndex <= src->size[dimension] - size), 4, "out of range");
THTensor_(set)(self, src);
@@ -564,13 +573,13 @@ int THTensor_(isSetTo)(const THTensor *self, const THTensor* src)
return 0;
}
-long THTensor_(nElement)(const THTensor *self)
+ptrdiff_t THTensor_(nElement)(const THTensor *self)
{
if(self->nDimension == 0)
return 0;
else
{
- long nElement = 1;
+ ptrdiff_t nElement = 1;
int d;
for(d = 0; d < self->nDimension; d++)
nElement *= self->size[d];
@@ -623,7 +632,7 @@ static void THTensor_(rawInit)(THTensor *self)
self->flag = TH_TENSOR_REFCOUNTED;
}
-static void THTensor_(rawSet)(THTensor *self, THStorage *storage, long storageOffset, int nDimension, long *size, long *stride)
+static void THTensor_(rawSet)(THTensor *self, THStorage *storage, ptrdiff_t storageOffset, int nDimension, long *size, long *stride)
{
/* storage */
if(self->storage != storage)
@@ -653,7 +662,7 @@ static void THTensor_(rawResize)(THTensor *self, int nDimension, long *size, lon
{
int d;
int nDimension_;
- long totalSize;
+ ptrdiff_t totalSize;
int hascorrectsize = 1;
nDimension_ = 0;
diff --git a/lib/TH/generic/THTensor.h b/lib/TH/generic/THTensor.h
index 7a3d585..81e3cb0 100644
--- a/lib/TH/generic/THTensor.h
+++ b/lib/TH/generic/THTensor.h
@@ -13,7 +13,7 @@ typedef struct THTensor
int nDimension;
THStorage *storage;
- long storageOffset;
+ ptrdiff_t storageOffset;
int refcount;
char flag;
@@ -23,7 +23,7 @@ typedef struct THTensor
/**** access methods ****/
TH_API THStorage* THTensor_(storage)(const THTensor *self);
-TH_API long THTensor_(storageOffset)(const THTensor *self);
+TH_API ptrdiff_t THTensor_(storageOffset)(const THTensor *self);
TH_API int THTensor_(nDimension)(const THTensor *self);
TH_API long THTensor_(size)(const THTensor *self, int dim);
TH_API long THTensor_(stride)(const THTensor *self, int dim);
@@ -39,17 +39,17 @@ TH_API void THTensor_(clearFlag)(THTensor *self, const char flag);
TH_API THTensor *THTensor_(new)(void);
TH_API THTensor *THTensor_(newWithTensor)(THTensor *tensor);
/* stride might be NULL */
-TH_API THTensor *THTensor_(newWithStorage)(THStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_);
-TH_API THTensor *THTensor_(newWithStorage1d)(THStorage *storage_, long storageOffset_,
+TH_API THTensor *THTensor_(newWithStorage)(THStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_);
+TH_API THTensor *THTensor_(newWithStorage1d)(THStorage *storage_, ptrdiff_t storageOffset_,
long size0_, long stride0_);
-TH_API THTensor *THTensor_(newWithStorage2d)(THStorage *storage_, long storageOffset_,
+TH_API THTensor *THTensor_(newWithStorage2d)(THStorage *storage_, ptrdiff_t storageOffset_,
long size0_, long stride0_,
long size1_, long stride1_);
-TH_API THTensor *THTensor_(newWithStorage3d)(THStorage *storage_, long storageOffset_,
+TH_API THTensor *THTensor_(newWithStorage3d)(THStorage *storage_, ptrdiff_t storageOffset_,
long size0_, long stride0_,
long size1_, long stride1_,
long size2_, long stride2_);
-TH_API THTensor *THTensor_(newWithStorage4d)(THStorage *storage_, long storageOffset_,
+TH_API THTensor *THTensor_(newWithStorage4d)(THStorage *storage_, ptrdiff_t storageOffset_,
long size0_, long stride0_,
long size1_, long stride1_,
long size2_, long stride2_,
@@ -78,17 +78,17 @@ TH_API void THTensor_(resize4d)(THTensor *tensor, long size0_, long size1_, long
TH_API void THTensor_(resize5d)(THTensor *tensor, long size0_, long size1_, long size2_, long size3_, long size4_);
TH_API void THTensor_(set)(THTensor *self, THTensor *src);
-TH_API void THTensor_(setStorage)(THTensor *self, THStorage *storage_, long storageOffset_, THLongStorage *size_, THLongStorage *stride_);
-TH_API void THTensor_(setStorage1d)(THTensor *self, THStorage *storage_, long storageOffset_,
+TH_API void THTensor_(setStorage)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_);
+TH_API void THTensor_(setStorage1d)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
long size0_, long stride0_);
-TH_API void THTensor_(setStorage2d)(THTensor *self, THStorage *storage_, long storageOffset_,
+TH_API void THTensor_(setStorage2d)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
long size0_, long stride0_,
long size1_, long stride1_);
-TH_API void THTensor_(setStorage3d)(THTensor *self, THStorage *storage_, long storageOffset_,
+TH_API void THTensor_(setStorage3d)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
long size0_, long stride0_,
long size1_, long stride1_,
long size2_, long stride2_);
-TH_API void THTensor_(setStorage4d)(THTensor *self, THStorage *storage_, long storageOffset_,
+TH_API void THTensor_(setStorage4d)(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_,
long size0_, long stride0_,
long size1_, long stride1_,
long size2_, long stride2_,
@@ -106,7 +106,7 @@ TH_API int THTensor_(isContiguous)(const THTensor *self);
TH_API int THTensor_(isSameSizeAs)(const THTensor *self, const THTensor *src);
TH_API int THTensor_(isSetTo)(const THTensor *self, const THTensor *src);
TH_API int THTensor_(isSize)(const THTensor *self, const THLongStorage *dims);
-TH_API long THTensor_(nElement)(const THTensor *self);
+TH_API ptrdiff_t THTensor_(nElement)(const THTensor *self);
TH_API void THTensor_(retain)(THTensor *self);
TH_API void THTensor_(free)(THTensor *self);
diff --git a/lib/TH/generic/THTensorConv.c b/lib/TH/generic/THTensorConv.c
index da37989..d98a2aa 100644
--- a/lib/TH/generic/THTensorConv.c
+++ b/lib/TH/generic/THTensorConv.c
@@ -590,7 +590,7 @@ void THTensor_(conv2DRevger)(THTensor *r_, real beta, real alpha, THTensor *t_,
real *input_data;
real *weight_data;
real *output_data;
- long nelem;
+ ptrdiff_t nelem;
long k;
THArgCheck(t_->nDimension == 3 , 3, "input: 3D Tensor expected");
@@ -697,7 +697,7 @@ void THTensor_(conv2DRevgerm)(THTensor *r_, real beta, real alpha, THTensor *t_,
real *input_data;
real *weight_data;
real *output_data;
- long nelem;
+ ptrdiff_t nelem;
long k;
THArgCheck(t_->nDimension == 4 , 3, "input: 4D Tensor expected");
@@ -809,7 +809,7 @@ void THTensor_(conv2Dger)(THTensor *r_, real beta, real alpha, THTensor *t_, THT
real *input_data;
real *weight_data;
real *output_data;
- long nelem;
+ ptrdiff_t nelem;
long k;
THArgCheck(t_->nDimension == 3 , 3, "input: 3D Tensor expected");
@@ -941,7 +941,7 @@ void THTensor_(conv2Dmv)(THTensor *r_, real beta, real alpha, THTensor *t_, THTe
real *input_data;
real *weight_data;
real *output_data;
- long nelem;
+ ptrdiff_t nelem;
long k;
THArgCheck(t_->nDimension == 3 , 3, "input: 3D Tensor expected");
@@ -1076,7 +1076,7 @@ void THTensor_(conv2Dmm)(THTensor *r_, real beta, real alpha, THTensor *t_, THTe
THTensor *input;
THTensor* kernel;
long nbatch;
- long nelem;
+ ptrdiff_t nelem;
real *input_data;
real *weight_data;
real *output_data;
@@ -1229,7 +1229,7 @@ void THTensor_(conv2Dmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THT
real *ptr_input;
real *ptr_weight;
real *output_data;
- long nelem;
+ ptrdiff_t nelem;
THArgCheck(t_->nDimension == 2 , 3, "input: 2D Tensor expected");
THArgCheck(k_->nDimension == 2 , 4, "kernel: 2D Tensor expected");
@@ -1287,7 +1287,7 @@ void THTensor_(conv2Dcmul)(THTensor *r_, real beta, real alpha, THTensor *t_, TH
real *input_data;
real *weight_data;
real *output_data;
- long nelem;
+ ptrdiff_t nelem;
long k;
THArgCheck(t_->nDimension == 3 , 3, "input: 3D Tensor expected");
@@ -1365,7 +1365,7 @@ void THTensor_(conv2Dmap)(THTensor *r_, real beta, real alpha, THTensor *t_, THT
real *weight_data;
real *output_data;
long nmaps;
- long nelem;
+ ptrdiff_t nelem;
long k;
THArgCheck(t_->nDimension == 3 , 3, "input: 3D Tensor expected");
@@ -1453,7 +1453,7 @@ void THTensor_(conv3DRevger)(THTensor *r_, real beta, real alpha, THTensor *t_,
real *input_data;
real *weight_data;
real *output_data;
- long nelem;
+ ptrdiff_t nelem;
long k, i;
THArgCheck(t_->nDimension == 4 , 3, "input: 4D Tensor expected");
@@ -1540,7 +1540,7 @@ void THTensor_(conv3Dger)(THTensor *r_, real beta, real alpha, THTensor *t_, THT
real *input_data;
real *weight_data;
real *output_data;
- long nelem;
+ ptrdiff_t nelem;
long k, i;
THArgCheck(t_->nDimension == 4 , 3, "input: 4D Tensor expected");
@@ -1632,7 +1632,7 @@ void THTensor_(conv3Dmv)(THTensor *r_, real beta, real alpha, THTensor *t_, THTe
real *input_data;
real *weight_data;
real *output_data;
- long nelem;
+ ptrdiff_t nelem;
long k, i;
THArgCheck(t_->nDimension == 4 , 3, "input: 4D Tensor expected");
@@ -1728,7 +1728,7 @@ void THTensor_(conv3Dmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THT
real *ptr_input;
real *ptr_weight;
real *output_data;
- long nelem;
+ ptrdiff_t nelem;
THArgCheck(t_->nDimension == 3 , 3, "input: 3D Tensor expected");
THArgCheck(k_->nDimension == 3 , 4, "kernel: 3D Tensor expected");
@@ -1794,7 +1794,7 @@ void THTensor_(conv3Dcmul)(THTensor *r_, real beta, real alpha, THTensor *t_, TH
real *input_data;
real *weight_data;
real *output_data;
- long nelem;
+ ptrdiff_t nelem;
long k;
THArgCheck(t_->nDimension == 4 , 3, "input: 3D Tensor expected");
@@ -1876,7 +1876,7 @@ void THTensor_(conv3Dmap)(THTensor *r_, real beta, real alpha, THTensor *t_, THT
THTensor *input;
THTensor *kernel;
- long nelem;
+ ptrdiff_t nelem;
real *input_data;
real *weight_data;
real *output_data;
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c
index a324191..b275d8f 100644
--- a/lib/TH/generic/THTensorMath.c
+++ b/lib/TH/generic/THTensorMath.c
@@ -35,8 +35,8 @@ void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src )
{
THTensor *srct = THTensor_(newContiguous)(src);
real *src_data = THTensor_(data)(srct);
- long cntr = 0;
- long nelem = THTensor_(nElement)(srct);
+ ptrdiff_t cntr = 0;
+ ptrdiff_t nelem = THTensor_(nElement)(srct);
if (THTensor_(nElement)(tensor) != THByteTensor_nElement(mask))
{
THTensor_(free)(srct);
@@ -68,9 +68,12 @@ void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src )
void THTensor_(maskedSelect)(THTensor *tensor, THTensor *src, THByteTensor *mask)
{
- long numel = THByteTensor_sumall(mask);
+ ptrdiff_t numel = THByteTensor_sumall(mask);
real *tensor_data;
+#ifdef DEBUG
+ THAssert(numel <= LONG_MAX);
+#endif
THTensor_(resize1d)(tensor,numel);
tensor_data = THTensor_(data)(tensor);
TH_TENSOR_APPLY2(real, src, unsigned char, mask,
@@ -90,7 +93,7 @@ void THTensor_(maskedSelect)(THTensor *tensor, THTensor *src, THByteTensor *mask
// Finds non-zero elements of a tensor and returns their subscripts
void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
{
- long numel = 0;
+ ptrdiff_t numel = 0;
long *subscript_data;
long i = 0;
long dim;
@@ -101,6 +104,9 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
if (*tensor_data != 0) {
++numel;
});
+#ifdef DEBUG
+ THAssert(numel <= LONG_MAX);
+#endif
THLongTensor_resize2d(subscript, numel, tensor->nDimension);
/* Second pass populates subscripts */
@@ -121,7 +127,7 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index)
{
- long i, numel;
+ ptrdiff_t i, numel;
THLongStorage *newSize;
THTensor *tSlice, *sSlice;
long *index_data;
@@ -135,6 +141,9 @@ void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTens
newSize = THLongStorage_newWithSize(src->nDimension);
THLongStorage_rawCopy(newSize,src->size);
+#ifdef DEBUG
+ THAssert(numel <= LONG_MAX);
+#endif
newSize->data[dim] = numel;
THTensor_(resize)(tensor,newSize,NULL);
THLongStorage_free(newSize);
@@ -146,7 +155,7 @@ void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTens
{
tensor_data = THTensor_(data)(tensor);
src_data = THTensor_(data)(src);
- long rowsize = THTensor_(nElement)(src) / src->size[0];
+ ptrdiff_t rowsize = THTensor_(nElement)(src) / src->size[0];
// check that the indices are within range
long max = src->size[0] - 1 + TH_INDEX_BASE;
@@ -191,7 +200,7 @@ void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTens
void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src)
{
- long i, numel;
+ ptrdiff_t i, numel;
THTensor *tSlice, *sSlice;
long *index_data;
@@ -230,7 +239,7 @@ void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTens
void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src)
{
- long i, numel;
+ ptrdiff_t i, numel;
THTensor *tSlice, *sSlice;
long *index_data;
@@ -271,7 +280,7 @@ void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTenso
void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, real val)
{
- long i, numel;
+ ptrdiff_t i, numel;
THTensor *tSlice;
long *index_data;
@@ -451,8 +460,8 @@ void THTensor_(add)(THTensor *r_, THTensor *t, real value)
if (THTensor_(isContiguous)(r_) && THTensor_(isContiguous)(t) && THTensor_(nElement)(r_) == THTensor_(nElement)(t)) {
real *tp = THTensor_(data)(t);
real *rp = THTensor_(data)(r_);
- long sz = THTensor_(nElement)(t);
- long i;
+ ptrdiff_t sz = THTensor_(nElement)(t);
+ ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
rp[i] = tp[i] + value;
@@ -472,8 +481,8 @@ void THTensor_(mul)(THTensor *r_, THTensor *t, real value)
if (THTensor_(isContiguous)(r_) && THTensor_(isContiguous)(t) && THTensor_(nElement)(r_) == THTensor_(nElement)(t)) {
real *tp = THTensor_(data)(t);
real *rp = THTensor_(data)(r_);
- long sz = THTensor_(nElement)(t);
- long i;
+ ptrdiff_t sz = THTensor_(nElement)(t);
+ ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
rp[i] = tp[i] * value;
@@ -488,8 +497,8 @@ void THTensor_(div)(THTensor *r_, THTensor *t, real value)
if (THTensor_(isContiguous)(r_) && THTensor_(isContiguous)(t) && THTensor_(nElement)(r_) == THTensor_(nElement)(t)) {
real *tp = THTensor_(data)(t);
real *rp = THTensor_(data)(r_);
- long sz = THTensor_(nElement)(t);
- long i;
+ ptrdiff_t sz = THTensor_(nElement)(t);
+ ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
rp[i] = tp[i] / value;
@@ -504,8 +513,8 @@ void THTensor_(fmod)(THTensor *r_, THTensor *t, real value)
if (THTensor_(isContiguous)(r_) && THTensor_(isContiguous)(t) && THTensor_(nElement)(r_) == THTensor_(nElement)(t)) {
real *tp = THTensor_(data)(t);
real *rp = THTensor_(data)(r_);
- long sz = THTensor_(nElement)(t);
- long i;
+ ptrdiff_t sz = THTensor_(nElement)(t);
+ ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
rp[i] = fmod(tp[i], value);
@@ -520,8 +529,8 @@ void THTensor_(remainder)(THTensor *r_, THTensor *t, real value)
if (THTensor_(isContiguous)(r_) && THTensor_(isContiguous)(t) && THTensor_(nElement)(r_) == THTensor_(nElement)(t)) {
real *tp = THTensor_(data)(t);
real *rp = THTensor_(data)(r_);
- long sz = THTensor_(nElement)(t);
- long i;
+ ptrdiff_t sz = THTensor_(nElement)(t);
+ ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
rp[i] = (value == 0)? NAN : tp[i] - value * floor(tp[i] / value);
@@ -536,9 +545,9 @@ void THTensor_(clamp)(THTensor *r_, THTensor *t, real min_value, real max_value)
if (THTensor_(isContiguous)(r_) && THTensor_(isContiguous)(t) && THTensor_(nElement)(r_) == THTensor_(nElement)(t)) {
real *tp = THTensor_(data)(t);
real *rp = THTensor_(data)(r_);
- real t_val;
- long sz = THTensor_(nElement)(t);
- long i;
+ /* real t_val; */
+ ptrdiff_t sz = THTensor_(nElement)(t);
+ ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
rp[i] = (tp[i] < min_value) ? min_value : (tp[i] > max_value ? max_value : tp[i]);
@@ -557,8 +566,8 @@ void THTensor_(cadd)(THTensor *r_, THTensor *t, real value, THTensor *src)
real *tp = THTensor_(data)(t);
real *sp = THTensor_(data)(src);
real *rp = THTensor_(data)(r_);
- long sz = THTensor_(nElement)(t);
- long i;
+ ptrdiff_t sz = THTensor_(nElement)(t);
+ ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i< sz; i++)
rp[i] = tp[i] + value * sp[i];
@@ -580,8 +589,8 @@ void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src)
real *tp = THTensor_(data)(t);
real *sp = THTensor_(data)(src);
real *rp = THTensor_(data)(r_);
- long sz = THTensor_(nElement)(t);
- long i;
+ ptrdiff_t sz = THTensor_(nElement)(t);
+ ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
rp[i] = tp[i] * sp[i];
@@ -597,8 +606,8 @@ void THTensor_(cpow)(THTensor *r_, THTensor *t, THTensor *src)
real *tp = THTensor_(data)(t);
real *sp = THTensor_(data)(src);
real *rp = THTensor_(data)(r_);
- long sz = THTensor_(nElement)(t);
- long i;
+ ptrdiff_t sz = THTensor_(nElement)(t);
+ ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
rp[i] = pow(tp[i], sp[i]);
@@ -614,8 +623,8 @@ void THTensor_(cdiv)(THTensor *r_, THTensor *t, THTensor *src)
real *tp = THTensor_(data)(t);
real *sp = THTensor_(data)(src);
real *rp = THTensor_(data)(r_);
- long sz = THTensor_(nElement)(t);
- long i;
+ ptrdiff_t sz = THTensor_(nElement)(t);
+ ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
rp[i] = tp[i] / sp[i];
@@ -631,8 +640,8 @@ void THTensor_(cfmod)(THTensor *r_, THTensor *t, THTensor *src)
real *tp = THTensor_(data)(t);
real *sp = THTensor_(data)(src);
real *rp = THTensor_(data)(r_);
- long sz = THTensor_(nElement)(t);
- long i;
+ ptrdiff_t sz = THTensor_(nElement)(t);
+ ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
rp[i] = fmod(tp[i], sp[i]);
@@ -648,8 +657,8 @@ void THTensor_(cremainder)(THTensor *r_, THTensor *t, THTensor *src)
real *tp = THTensor_(data)(t);
real *sp = THTensor_(data)(src);
real *rp = THTensor_(data)(r_);
- long sz = THTensor_(nElement)(t);
- long i;
+ ptrdiff_t sz = THTensor_(nElement)(t);
+ ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
rp[i] = (sp[i] == 0)? NAN : tp[i] - sp[i] * floor(tp[i] / sp[i]);
@@ -664,8 +673,8 @@ void THTensor_(tpow)(THTensor *r_, real value, THTensor *t)
if (THTensor_(isContiguous)(r_) && THTensor_(isContiguous)(t) && THTensor_(nElement)(r_) == THTensor_(nElement)(t)) {
real *tp = THTensor_(data)(t);
real *rp = THTensor_(data)(r_);
- long sz = THTensor_(nElement)(t);
- long i;
+ ptrdiff_t sz = THTensor_(nElement)(t);
+ ptrdiff_t i;
#pragma omp parallel for if(sz > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<sz; i++)
rp[i] = pow(value, tp[i]);
@@ -1047,7 +1056,7 @@ void THTensor_(baddbmm)(THTensor *result, real beta, THTensor *t, real alpha, TH
THTensor_(free)(result_matrix);
}
-long THTensor_(numel)(THTensor *t)
+ptrdiff_t THTensor_(numel)(THTensor *t)
{
return THTensor_(nElement)(t);
}
@@ -1393,14 +1402,14 @@ void THTensor_(eye)(THTensor *r_, long n, long m)
void THTensor_(range)(THTensor *r_, accreal xmin, accreal xmax, accreal step)
{
- long size;
+ ptrdiff_t size;
real i = 0;
THArgCheck(step > 0 || step < 0, 3, "step must be a non-null number");
THArgCheck(((step > 0) && (xmax >= xmin)) || ((step < 0) && (xmax <= xmin))
, 2, "upper bound and larger bound incoherent with step sign");
- size = (long) (((xmax - xmin) / step) + 1);
+ size = (ptrdiff_t) (((xmax - xmin) / step) + 1);
if (THTensor_(nElement)(r_) != size) {
THTensor_(resize1d)(r_, size);
@@ -2044,8 +2053,8 @@ int THTensor_(equal)(THTensor *ta, THTensor* tb)
if (THTensor_(isContiguous)(ta) && THTensor_(isContiguous)(tb)) {
real *tap = THTensor_(data)(ta);
real *tbp = THTensor_(data)(tb);
- long sz = THTensor_(nElement)(ta);
- long i;
+ ptrdiff_t sz = THTensor_(nElement)(ta);
+ ptrdiff_t i;
for (i=0; i<sz; ++i){
if(tap[i] != tbp[i]) return 0;
}
diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h
index d33406f..87f1616 100644
--- a/lib/TH/generic/THTensorMath.h
+++ b/lib/TH/generic/THTensorMath.h
@@ -60,7 +60,7 @@ TH_API void THTensor_(baddbmm)(THTensor *r_, real beta, THTensor *t, real alpha,
TH_API void THTensor_(match)(THTensor *r_, THTensor *m1, THTensor *m2, real gain);
-TH_API long THTensor_(numel)(THTensor *t);
+TH_API ptrdiff_t THTensor_(numel)(THTensor *t);
TH_API void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension);
TH_API void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension);
TH_API void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, long k, int dimension);
diff --git a/lib/TH/generic/THVector.h b/lib/TH/generic/THVector.h
index 09067e5..5326b16 100644
--- a/lib/TH/generic/THVector.h
+++ b/lib/TH/generic/THVector.h
@@ -2,11 +2,11 @@
#define TH_GENERIC_FILE "generic/THVector.h"
#else
-TH_API void THVector_(fill)(real *x, const real c, const long n);
-TH_API void THVector_(add)(real *y, const real *x, const real c, const long n);
-TH_API void THVector_(diff)(real *z, const real *x, const real *y, const long n);
-TH_API void THVector_(scale)(real *y, const real c, const long n);
-TH_API void THVector_(mul)(real *y, const real *x, const long n);
+TH_API void THVector_(fill)(real *x, const real c, const ptrdiff_t n);
+TH_API void THVector_(add)(real *y, const real *x, const real c, const ptrdiff_t n);
+TH_API void THVector_(diff)(real *z, const real *x, const real *y, const ptrdiff_t n);
+TH_API void THVector_(scale)(real *y, const real c, const ptrdiff_t n);
+TH_API void THVector_(mul)(real *y, const real *x, const ptrdiff_t n);
/* Initialize the dispatch pointers */
TH_API void THVector_(vectorDispatchInit)();
diff --git a/lib/TH/generic/THVectorDefault.c b/lib/TH/generic/THVectorDefault.c
index d51be03..aabc16c 100644
--- a/lib/TH/generic/THVectorDefault.c
+++ b/lib/TH/generic/THVectorDefault.c
@@ -2,8 +2,8 @@
#define TH_GENERIC_FILE "generic/THVectorDefault.c"
#else
-void THVector_(fill_DEFAULT)(real *x, const real c, const long n) {
- long i = 0;
+void THVector_(fill_DEFAULT)(real *x, const real c, const ptrdiff_t n) {
+ ptrdiff_t i = 0;
for(; i < n-4; i += 4)
{
@@ -17,9 +17,9 @@ void THVector_(fill_DEFAULT)(real *x, const real c, const long n) {
x[i] = c;
}
-void THVector_(add_DEFAULT)(real *y, const real *x, const real c, const long n)
+void THVector_(add_DEFAULT)(real *y, const real *x, const real c, const ptrdiff_t n)
{
- long i = 0;
+ ptrdiff_t i = 0;
for(;i < n-4; i += 4)
{
@@ -33,9 +33,9 @@ void THVector_(add_DEFAULT)(real *y, const real *x, const real c, const long n)
y[i] += c * x[i];
}
-void THVector_(diff_DEFAULT)(real *z, const real *x, const real *y, const long n)
+void THVector_(diff_DEFAULT)(real *z, const real *x, const real *y, const ptrdiff_t n)
{
- long i = 0;
+ ptrdiff_t i = 0;
for(; i < n-4; i += 4)
{
@@ -49,9 +49,9 @@ void THVector_(diff_DEFAULT)(real *z, const real *x, const real *y, const long n
z[i] = x[i] - y[i];
}
-void THVector_(scale_DEFAULT)(real *y, const real c, const long n)
+void THVector_(scale_DEFAULT)(real *y, const real c, const ptrdiff_t n)
{
- long i = 0;
+ ptrdiff_t i = 0;
for(; i < n-4; i +=4)
{
@@ -65,9 +65,9 @@ void THVector_(scale_DEFAULT)(real *y, const real c, const long n)
y[i] *= c;
}
-void THVector_(mul_DEFAULT)(real *y, const real *x, const long n)
+void THVector_(mul_DEFAULT)(real *y, const real *x, const ptrdiff_t n)
{
- long i = 0;
+ ptrdiff_t i = 0;
for(; i < n-4; i += 4)
{
diff --git a/lib/TH/generic/THVectorDispatch.c b/lib/TH/generic/THVectorDispatch.c
index f16bcda..2f1a556 100644
--- a/lib/TH/generic/THVectorDispatch.c
+++ b/lib/TH/generic/THVectorDispatch.c
@@ -12,7 +12,7 @@
* 3. A dispatch stub, which is what is actually called by clients, that simply wraps the dispatch pointer.
*/
-static void (*THVector_(fill_DISPATCHPTR))(real *, const real, const long) = &THVector_(fill_DEFAULT);
+static void (*THVector_(fill_DISPATCHPTR))(real *, const real, const ptrdiff_t) = &THVector_(fill_DEFAULT);
static FunctionDescription THVector_(fill_DISPATCHTABLE)[] = {
#if defined(__NEON__)
#if defined(TH_REAL_IS_FLOAT)
@@ -28,12 +28,12 @@ static FunctionDescription THVector_(fill_DISPATCHTABLE)[] = {
#endif
FUNCTION_IMPL(THVector_(fill_DEFAULT), SIMDExtension_DEFAULT)
};
-void THVector_(fill)(real *x, const real c, const long n) {
+void THVector_(fill)(real *x, const real c, const ptrdiff_t n) {
THVector_(fill_DISPATCHPTR)(x, c, n);
}
-static void (*THVector_(add_DISPATCHPTR))(real *, const real *, const real, const long) = &THVector_(add_DEFAULT);
+static void (*THVector_(add_DISPATCHPTR))(real *, const real *, const real, const ptrdiff_t) = &THVector_(add_DEFAULT);
static FunctionDescription THVector_(add_DISPATCHTABLE)[] = {
#if defined(__NEON__)
#if defined(TH_REAL_IS_FLOAT)
@@ -50,12 +50,12 @@ static FunctionDescription THVector_(add_DISPATCHTABLE)[] = {
FUNCTION_IMPL(THVector_(add_DEFAULT), SIMDExtension_DEFAULT)
};
-void THVector_(add)(real *y, const real *x, const real c, const long n) {
+void THVector_(add)(real *y, const real *x, const real c, const ptrdiff_t n) {
THVector_(add_DISPATCHPTR)(y, x, c, n);
}
-static void (*THVector_(diff_DISPATCHPTR))(real *, const real *, const real *, const long) = &THVector_(diff_DEFAULT);
+static void (*THVector_(diff_DISPATCHPTR))(real *, const real *, const real *, const ptrdiff_t) = &THVector_(diff_DEFAULT);
static FunctionDescription THVector_(diff_DISPATCHTABLE)[] = {
#if defined(__NEON__)
#if defined(TH_REAL_IS_FLOAT)
@@ -72,12 +72,12 @@ static FunctionDescription THVector_(diff_DISPATCHTABLE)[] = {
FUNCTION_IMPL(THVector_(diff_DEFAULT), SIMDExtension_DEFAULT)
};
-void THVector_(diff)(real *z, const real *x, const real *y, const long n) {
+void THVector_(diff)(real *z, const real *x, const real *y, const ptrdiff_t n) {
THVector_(diff_DISPATCHPTR)(z, x, y, n);
}
-static void (*THVector_(scale_DISPATCHPTR))(real *, const real, const long) = &THVector_(scale_DEFAULT);
+static void (*THVector_(scale_DISPATCHPTR))(real *, const real, const ptrdiff_t) = &THVector_(scale_DEFAULT);
static FunctionDescription THVector_(scale_DISPATCHTABLE)[] = {
#if defined(__NEON__)
#if defined(TH_REAL_IS_FLOAT)
@@ -94,12 +94,12 @@ static FunctionDescription THVector_(scale_DISPATCHTABLE)[] = {
FUNCTION_IMPL(THVector_(scale_DEFAULT), SIMDExtension_DEFAULT)
};
-TH_API void THVector_(scale)(real *y, const real c, const long n) {
+TH_API void THVector_(scale)(real *y, const real c, const ptrdiff_t n) {
THVector_(scale_DISPATCHPTR)(y, c, n);
}
-static void (*THVector_(mul_DISPATCHPTR))(real *, const real *, const long) = &THVector_(mul_DEFAULT);
+static void (*THVector_(mul_DISPATCHPTR))(real *, const real *, const ptrdiff_t) = &THVector_(mul_DEFAULT);
static FunctionDescription THVector_(mul_DISPATCHTABLE)[] = {
#if defined(__NEON__)
#if defined(TH_REAL_IS_FLOAT)
@@ -116,7 +116,7 @@ static FunctionDescription THVector_(mul_DISPATCHTABLE)[] = {
FUNCTION_IMPL(THVector_(mul_DEFAULT), SIMDExtension_DEFAULT)
};
-void THVector_(mul)(real *y, const real *x, const long n) {
+void THVector_(mul)(real *y, const real *x, const ptrdiff_t n) {
THVector_(mul_DISPATCHPTR);
}
diff --git a/lib/TH/vector/NEON.c b/lib/TH/vector/NEON.c
index ee4eb81..bc7cb2b 100644
--- a/lib/TH/vector/NEON.c
+++ b/lib/TH/vector/NEON.c
@@ -1,4 +1,4 @@
-static void THFloatVector_fill_NEON(float *x, const float c, const long n) {
+static void THFloatVector_fill_NEON(float *x, const float c, const ptrdiff_t n) {
float ctemp = c;
float * caddr = &ctemp;
__asm__ __volatile__ (
@@ -29,7 +29,7 @@ static void THFloatVector_fill_NEON(float *x, const float c, const long n) {
}
-static void THFloatVector_diff_NEON(float *z, const float *x, const float *y, const long n) {
+static void THFloatVector_diff_NEON(float *z, const float *x, const float *y, const ptrdiff_t n) {
__asm__ __volatile__ (
"mov r0, %2 @ \n\t"
"mov r1, %1 @ \n\t"
@@ -70,7 +70,7 @@ static void THFloatVector_diff_NEON(float *z, const float *x, const float *y, co
}
-static void THFloatVector_scale_NEON(float *y, const float c, const long n) {
+static void THFloatVector_scale_NEON(float *y, const float c, const ptrdiff_t n) {
float ctemp = c;
float * caddr = &ctemp;
__asm__ __volatile__ (
@@ -150,7 +150,7 @@ static void THFloatVector_scale_NEON(float *y, const float c, const long n) {
}
-static void THFloatVector_mul_NEON(float *y, const float *x, const long n) {
+static void THFloatVector_mul_NEON(float *y, const float *x, const ptrdiff_t n) {
__asm__ __volatile__ (
"mov r0, %0 @ \n\t"
"mov r1, %1 @ \n\t"
@@ -190,7 +190,7 @@ static void THFloatVector_mul_NEON(float *y, const float *x, const long n) {
);
}
-static void THFloatVector_add_NEON(float *y, const float *x, const float c, const long n) {
+static void THFloatVector_add_NEON(float *y, const float *x, const float c, const ptrdiff_t n) {
float ctemp = c;
float * caddr = &ctemp;
__asm__ __volatile__ (
diff --git a/lib/TH/vector/SSE.c b/lib/TH/vector/SSE.c
index c47e28d..781b037 100644
--- a/lib/TH/vector/SSE.c
+++ b/lib/TH/vector/SSE.c
@@ -5,9 +5,9 @@
#endif
-static void THDoubleVector_fill_SSE(double *x, const double c, const long n) {
- long i;
- long off;
+static void THDoubleVector_fill_SSE(double *x, const double c, const ptrdiff_t n) {
+ ptrdiff_t i;
+ ptrdiff_t off;
__m128d XMM0 = _mm_set1_pd(c);
for (i=0; i<=((n)-8); i+=8) {
_mm_storeu_pd((x)+i , XMM0);
@@ -22,8 +22,8 @@ static void THDoubleVector_fill_SSE(double *x, const double c, const long n) {
}
-static void THDoubleVector_add_SSE(double *y, const double *x, const double c, const long n) {
- long i = 0;
+static void THDoubleVector_add_SSE(double *y, const double *x, const double c, const ptrdiff_t n) {
+ ptrdiff_t i = 0;
__m128d XMM7 = _mm_set1_pd(c);
__m128d XMM0,XMM2;
for (; i<=((n)-2); i+=2) {
@@ -39,8 +39,8 @@ static void THDoubleVector_add_SSE(double *y, const double *x, const double c, c
}
-static void THDoubleVector_diff_SSE(double *z, const double *x, const double *y, const long n) {
- long i;
+static void THDoubleVector_diff_SSE(double *z, const double *x, const double *y, const ptrdiff_t n) {
+ ptrdiff_t i;
for (i=0; i<=((n)-8); i+=8) {
__m128d XMM0 = _mm_loadu_pd((x)+i );
__m128d XMM1 = _mm_loadu_pd((x)+i+2);
@@ -59,15 +59,15 @@ static void THDoubleVector_diff_SSE(double *z, const double *x, const double *y,
_mm_storeu_pd((z)+i+4, XMM2);
_mm_storeu_pd((z)+i+6, XMM3);
}
- long off = (n) - ((n)%8);
+ ptrdiff_t off = (n) - ((n)%8);
for (i=0; i<((n)%8); i++) {
z[off+i] = x[off+i] - y[off+i];
}
}
-static void THDoubleVector_scale_SSE(double *y, const double c, const long n) {
- long i;
+static void THDoubleVector_scale_SSE(double *y, const double c, const ptrdiff_t n) {
+ ptrdiff_t i;
__m128d XMM7 = _mm_set1_pd(c);
for (i=0; i<=((n)-4); i+=4) {
__m128d XMM0 = _mm_loadu_pd((y)+i );
@@ -77,15 +77,15 @@ static void THDoubleVector_scale_SSE(double *y, const double c, const long n) {
_mm_storeu_pd((y)+i , XMM0);
_mm_storeu_pd((y)+i+2, XMM1);
}
- long off = (n) - ((n)%4);
+ ptrdiff_t off = (n) - ((n)%4);
for (i=0; i<((n)%4); i++) {
y[off+i] *= c;
}
}
-static void THDoubleVector_mul_SSE(double *y, const double *x, const long n) {
- long i;
+static void THDoubleVector_mul_SSE(double *y, const double *x, const ptrdiff_t n) {
+ ptrdiff_t i;
for (i=0; i<=((n)-8); i+=8) {
__m128d XMM0 = _mm_loadu_pd((x)+i );
__m128d XMM1 = _mm_loadu_pd((x)+i+2);
@@ -104,17 +104,17 @@ static void THDoubleVector_mul_SSE(double *y, const double *x, const long n) {
_mm_storeu_pd((y)+i+4, XMM6);
_mm_storeu_pd((y)+i+6, XMM7);
}
- long off = (n) - ((n)%8);
+ ptrdiff_t off = (n) - ((n)%8);
for (i=0; i<((n)%8); i++) {
y[off+i] *= x[off+i];
}
}
-static void THFloatVector_fill_SSE(float *x, const float c, const long n) {
- long i;
+static void THFloatVector_fill_SSE(float *x, const float c, const ptrdiff_t n) {
+ ptrdiff_t i;
__m128 XMM0 = _mm_set_ps1(c);
- long off;
+ ptrdiff_t off;
for (i=0; i<=((n)-16); i+=16) {
_mm_storeu_ps((x)+i , XMM0);
_mm_storeu_ps((x)+i+4, XMM0);
@@ -128,8 +128,8 @@ static void THFloatVector_fill_SSE(float *x, const float c, const long n) {
}
-static void THFloatVector_add_SSE(float *y, const float *x, const float c, const long n) {
- long i = 0;
+static void THFloatVector_add_SSE(float *y, const float *x, const float c, const ptrdiff_t n) {
+ ptrdiff_t i = 0;
__m128 XMM7 = _mm_set_ps1(c);
__m128 XMM0,XMM2;
for (; i<=((n)-4); i+=4) {
@@ -145,8 +145,8 @@ static void THFloatVector_add_SSE(float *y, const float *x, const float c, const
}
-static void THFloatVector_diff_SSE(float *z, const float *x, const float *y, const long n) {
- long i;
+static void THFloatVector_diff_SSE(float *z, const float *x, const float *y, const ptrdiff_t n) {
+ ptrdiff_t i;
for (i=0; i<=((n)-16); i+=16) {
__m128 XMM0 = _mm_loadu_ps((x)+i );
__m128 XMM1 = _mm_loadu_ps((x)+i+ 4);
@@ -165,15 +165,15 @@ static void THFloatVector_diff_SSE(float *z, const float *x, const float *y, con
_mm_storeu_ps((z)+i+ 8, XMM2);
_mm_storeu_ps((z)+i+12, XMM3);
}
- long off = (n) - ((n)%16);
+ ptrdiff_t off = (n) - ((n)%16);
for (i=0; i<((n)%16); i++) {
z[off+i] = x[off+i] - y[off+i];
}
}
-static void THFloatVector_scale_SSE(float *y, const float c, const long n) {
- long i;
+static void THFloatVector_scale_SSE(float *y, const float c, const ptrdiff_t n) {
+ ptrdiff_t i;
__m128 XMM7 = _mm_set_ps1(c);
for (i=0; i<=((n)-8); i+=8) {
__m128 XMM0 = _mm_loadu_ps((y)+i );
@@ -183,15 +183,15 @@ static void THFloatVector_scale_SSE(float *y, const float c, const long n) {
_mm_storeu_ps((y)+i , XMM0);
_mm_storeu_ps((y)+i+4, XMM1);
}
- long off = (n) - ((n)%8);
+ ptrdiff_t off = (n) - ((n)%8);
for (i=0; i<((n)%8); i++) {
y[off+i] *= c;
}
}
-static void THFloatVector_mul_SSE(float *y, const float *x, const long n) {
- long i;
+static void THFloatVector_mul_SSE(float *y, const float *x, const ptrdiff_t n) {
+ ptrdiff_t i;
for (i=0; i<=((n)-16); i+=16) {
__m128 XMM0 = _mm_loadu_ps((x)+i );
__m128 XMM1 = _mm_loadu_ps((x)+i+ 4);
@@ -210,7 +210,7 @@ static void THFloatVector_mul_SSE(float *y, const float *x, const long n) {
_mm_storeu_ps((y)+i+ 8, XMM6);
_mm_storeu_ps((y)+i+12, XMM7);
}
- long off = (n) - ((n)%16);
+ ptrdiff_t off = (n) - ((n)%16);
for (i=0; i<((n)%16); i++) {
y[off+i] *= x[off+i];
}
diff --git a/lib/luaT/luaT.c b/lib/luaT/luaT.c
index 657cca2..95166ed 100644
--- a/lib/luaT/luaT.c
+++ b/lib/luaT/luaT.c
@@ -4,7 +4,7 @@
#include "luaT.h"
-void* luaT_alloc(lua_State *L, long size)
+void* luaT_alloc(lua_State *L, ptrdiff_t size)
{
void *ptr;
@@ -21,7 +21,7 @@ void* luaT_alloc(lua_State *L, long size)
return ptr;
}
-void* luaT_realloc(lua_State *L, void *ptr, long size)
+void* luaT_realloc(lua_State *L, void *ptr, ptrdiff_t size)
{
if(!ptr)
return(luaT_alloc(L, size));
@@ -354,7 +354,7 @@ void luaT_pushlong(lua_State *L, long n)
#if LUA_VERSION_NUM >= 503
/* Only push the value as an integer if it fits in lua_Integer,
or if the lua_Number representation will be even worse */
- if (sizeof(lua_Integer) >= sizeof(long) || sizeof(lua_Number) == sizeof(lua_Integer)) {
+ if (sizeof(lua_Integer) >= sizeof(long) || sizeof(lua_Number) <= sizeof(lua_Integer)) {
lua_pushinteger(L, n);
} else {
lua_pushnumber(L, (lua_Number)n);
@@ -367,7 +367,7 @@ void luaT_pushlong(lua_State *L, long n)
long luaT_checklong(lua_State *L, int idx)
{
#if LUA_VERSION_NUM >= 503
- if (sizeof(lua_Integer) >= sizeof(long) || sizeof(lua_Number) == sizeof(lua_Integer)) {
+ if (sizeof(lua_Integer) >= sizeof(long) || sizeof(lua_Number) <= sizeof(lua_Integer)) {
return (long)luaL_checkinteger(L, idx);
} else {
return (long)luaL_checknumber(L, idx);
@@ -380,7 +380,7 @@ long luaT_checklong(lua_State *L, int idx)
long luaT_tolong(lua_State *L, int idx)
{
#if LUA_VERSION_NUM == 503
- if (sizeof(lua_Integer) >= sizeof(long) || sizeof(lua_Number) == sizeof(lua_Integer)) {
+ if (sizeof(lua_Integer) >= sizeof(long) || sizeof(lua_Number) <= sizeof(lua_Integer)) {
return (long)lua_tointeger(L, idx);
} else {
return (long)lua_tonumber(L, idx);
@@ -390,6 +390,34 @@ long luaT_tolong(lua_State *L, int idx)
#endif
}
+void luaT_pushinteger(lua_State *L, ptrdiff_t n)
+{
+#if LUA_VERSION_NUM >= 503
+ /* Only push the value as an integer if it fits in lua_Integer,
+ or if the lua_Number representation will be even worse */
+ if (sizeof(lua_Integer) >= sizeof(ptrdiff_t) || sizeof(lua_Number) <= sizeof(lua_Integer)) {
+ lua_pushinteger(L, n);
+ } else {
+ lua_pushnumber(L, (lua_Number)n);
+ }
+#else
+ lua_pushnumber(L, (lua_Number)n);
+#endif
+}
+
+ptrdiff_t luaT_checkinteger(lua_State *L, int idx)
+{
+#if LUA_VERSION_NUM >= 503
+ if (sizeof(lua_Integer) >= sizeof(ptrdiff_t) || sizeof(lua_Number) <= sizeof(lua_Integer)) {
+ return (ptrdiff_t)luaL_checkinteger(L, idx);
+ } else {
+ return (ptrdiff_t)luaL_checknumber(L, idx);
+ }
+#else
+ return (ptrdiff_t)luaL_checknumber(L, idx);
+#endif
+}
+
void *luaT_getfieldcheckudata(lua_State *L, int ud, const char *field, const char *tname)
{
void *p;
@@ -980,10 +1008,17 @@ int luaT_lua_isequal(lua_State *L)
static void luaT_pushpointer(lua_State *L, const void *ptr)
{
+#if LUA_VERSION_NUM >= 503
+ // this assumes that lua_Integer is a ptrdiff_t
+ if (sizeof(void *) > sizeof(lua_Integer))
+ luaL_error(L, "Pointer value can't be represented as a Lua integer (an overflow would occur)");
+ lua_pushinteger(L, (uintptr_t)(ptr));
+#else
// 2^53 - this assumes that lua_Number is a double
if ((uintptr_t)ptr > 9007199254740992LLU)
luaL_error(L, "Pointer value can't be represented as a Lua number (an overflow would occur)");
lua_pushnumber(L, (uintptr_t)(ptr));
+#endif
}
int luaT_lua_pointer(lua_State *L)
diff --git a/lib/luaT/luaT.h b/lib/luaT/luaT.h
index b1b6cd9..2479a1d 100644
--- a/lib/luaT/luaT.h
+++ b/lib/luaT/luaT.h
@@ -47,8 +47,8 @@ static int luaL_typerror(lua_State *L, int narg, const char *tname)
/* C functions */
-LUAT_API void* luaT_alloc(lua_State *L, long size);
-LUAT_API void* luaT_realloc(lua_State *L, void *ptr, long size);
+LUAT_API void* luaT_alloc(lua_State *L, ptrdiff_t size);
+LUAT_API void* luaT_realloc(lua_State *L, void *ptr, ptrdiff_t size);
LUAT_API void luaT_free(lua_State *L, void *ptr);
LUAT_API void luaT_setfuncs(lua_State *L, const luaL_Reg *l, int nup);
@@ -73,6 +73,9 @@ LUAT_API void luaT_pushlong(lua_State *L, long n);
LUAT_API long luaT_checklong(lua_State *L, int idx);
LUAT_API long luaT_tolong(lua_State *L, int idx);
+LUAT_API void luaT_pushinteger(lua_State *L, ptrdiff_t n);
+LUAT_API ptrdiff_t luaT_checkinteger(lua_State *L, int idx);
+
LUAT_API void *luaT_getfieldcheckudata(lua_State *L, int ud, const char *field, const char *tname);
LUAT_API void *luaT_getfieldchecklightudata(lua_State *L, int ud, const char *field);
LUAT_API double luaT_getfieldchecknumber(lua_State *L, int ud, const char *field);