Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2016-10-10 20:09:43 +0300
committerTrevor Killeen <killeentm@gmail.com>2016-10-10 20:09:43 +0300
commite4ebbd3a2d668ae51d49d9deaf2c62fc6de81a8e (patch)
treec0c3cbf24139a1504ff232476dd7b64987605821 /TensorMath.lua
parent7860a76e1cc50e5c679a965c95cdca2501cac9bc (diff)
parent9efd392d4974e3fe3b1469809971d5d4f360ec7b (diff)
Merge remote-tracking branch 'upstream/master' into more-generic-functions
Diffstat (limited to 'TensorMath.lua')
-rw-r--r--TensorMath.lua54
1 files changed, 53 insertions, 1 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index abdda0b..e917f8c 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -2,6 +2,58 @@ local wrap = require 'cwrap'
local interface = wrap.CInterface.new()
local method = wrap.CInterface.new()
+local argtypes = wrap.CInterface.argtypes
+
+argtypes['ptrdiff_t'] = {
+
+ helpname = function(arg)
+ return 'ptrdiff_t'
+ end,
+
+ declare = function(arg)
+ -- if it is a number we initialize here
+ local default = tonumber(tostring(arg.default)) or 0
+ return string.format("%s arg%d = %g;", 'ptrdiff_t', arg.i, default)
+ end,
+
+ check = function(arg, idx)
+ return string.format("lua_isinteger(L, %d)", idx)
+ end,
+
+ read = function(arg, idx)
+ return string.format("arg%d = (%s)lua_tointeger(L, %d);", arg.i, 'ptrdiff_t', idx)
+ end,
+
+ init = function(arg)
+ -- otherwise do it here
+ if arg.default then
+ local default = tostring(arg.default)
+ if not tonumber(default) then
+ return string.format("arg%d = %s;", arg.i, default)
+ end
+ end
+ end,
+
+ carg = function(arg)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ if arg.returned then
+ return string.format('lua_pushinteger(L, (lua_Integer)arg%d);', arg.i)
+ end
+ end,
+
+ postcall = function(arg)
+ if arg.creturned then
+ return string.format('lua_pushinteger(L, (lua_Integer)arg%d);', arg.i)
+ end
+ end
+}
interface:print('/* WARNING: autogenerated file */')
interface:print('')
@@ -559,7 +611,7 @@ for k, Tensor_ in pairs(handledTypenames) do
wrap("numel",
cname("numel"),
{{name=Tensor},
- {name="long", creturned=true}})
+ {name="ptrdiff_t", creturned=true}})
wrap("add",
cname("add"),