diff options
-rw-r--r-- | Narrow.lua | 35 | ||||
-rw-r--r-- | init.lua | 6 | ||||
-rw-r--r-- | nnx-1.0-1.rockspec | 55 |
3 files changed, 96 insertions, 0 deletions
diff --git a/Narrow.lua b/Narrow.lua new file mode 100644 index 0000000..c9b7b89 --- /dev/null +++ b/Narrow.lua @@ -0,0 +1,35 @@ +local Narrow, parent = torch.class('nn.Narrow', 'nn.Module') + +function Narrow:__init(dimension,offset,length) + parent.__init(self) + self.dimension=dimension + self.index=offset + self.length=length or 1 +end + +function Narrow:forward(input) + local output=input:narrow(self.dimension,self.index,self.length); + self.output:resizeAs(output) + return self.output:copy(output) +end + +function Narrow:backward(input, gradOutput) + self.gradInput:resizeAs(input) + self.gradInput:zero(); + self.gradInput:narrow(self.dimension,self.index,self.length):copy(gradOutput) + return self.gradInput +end + +function Narrow:write(file) + parent.write(self, file) + file:writeInt(self.dimension) + file:writeLong(self.index) + file:writeLong(self.length) +end + +function Narrow:read(file, version) + parent.read(self, file) + self.dimension = file:readInt() + self.index = file:readLong() + self.length = file:readLong() +end diff --git a/init.lua b/init.lua new file mode 100644 index 0000000..ff92c6b --- /dev/null +++ b/init.lua @@ -0,0 +1,6 @@ + +require 'torch' +require 'nn' +require 'nnx' + +torch.include('nnx', 'Narrow.lua') diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec new file mode 100644 index 0000000..2d4073a --- /dev/null +++ b/nnx-1.0-1.rockspec @@ -0,0 +1,55 @@ + +package = "nnx" +version = "1.0-1" + +source = { + url = "nnx-1.0-1.tgz" +} + +description = { + summary = "An extension to Torch7's nn package.", + detailed = [[ + This package provides extra trainable modules, + which naturally extend the nn package. + Some of those might get marged into the original + nn package, at some point. For this reason, + all the modules from nnx are appended to nn. + ]], + homepage = "", + license = "MIT/X11" -- or whatever you like +} + +dependencies = { + "lua >= 5.1", + "xlua" +} + +build = { + type = "cmake", + + cmake = [[ + cmake_minimum_required(VERSION 2.8) + + set (CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}) + + # infer path for Torch7 + string (REGEX REPLACE "(.*)lib/luarocks/rocks.*" "\\1" TORCH_PREFIX "${CMAKE_INSTALL_PREFIX}" ) + message (STATUS "Found Torch7, installed in: " ${TORCH_PREFIX}) + + find_package (Torch REQUIRED) + + SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) + + include_directories (${TORCH_INCLUDE_DIR}) + #add_library (nnx SHARED init.c) + #link_directories (${TORCH_LIBRARY_DIR}) + #target_link_libraries (nnx ${TORCH_LIBRARIES}) + + install_files(/lua/nnx init.lua) + #install_targets(/lib nnx) + ]], + + variables = { + CMAKE_INSTALL_PREFIX = "$(PREFIX)" + } +} |