Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2016-05-07 23:56:23 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2016-05-07 23:56:23 +0300
commit7fd950fbda443066f5c5ca24db2de35254de9164 (patch)
treee093dd14d7dba436e98191e9ed55d47a318bbe84 /src
parent4291c918aefefd52da1c4334743e0a2960a9abcc (diff)
working basic training
Diffstat (limited to 'src')
-rw-r--r--src/CMakeLists.txt18
-rw-r--r--src/cudnn_tensor.h400
-rw-r--r--src/exception.cpp108
-rw-r--r--src/exception.h156
-rw-r--r--src/marian.h248
-rw-r--r--src/operators.h370
-rw-r--r--src/tensor.h117
-rw-r--r--src/test.cpp55
-rw-r--r--src/test.cu71
-rw-r--r--src/thrust_functions.h95
10 files changed, 1409 insertions, 229 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
new file mode 100644
index 00000000..3d751b51
--- /dev/null
+++ b/src/CMakeLists.txt
@@ -0,0 +1,18 @@
+
+include_directories(.) # add this directory to the include search path
+
+add_library(libcommon OBJECT # object library; its objects are pulled into the executable below
+ exception.cpp
+)
+
+cuda_add_executable( # executable `marian` built from test.cu plus the libcommon objects
+ marian
+ test.cu
+ $<TARGET_OBJECTS:libcommon>
+)
+
+foreach(exec marian) # the loop currently covers only the `marian` target
+ target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn)
+ cuda_add_cublas_to_target(${exec})
+ set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") # binary goes to the top of the build tree
+endforeach(exec)
diff --git a/src/cudnn_tensor.h b/src/cudnn_tensor.h
new file mode 100644
index 00000000..cd71d942
--- /dev/null
+++ b/src/cudnn_tensor.h
@@ -0,0 +1,400 @@
+#pragma once
+
+#include <memory>
+#include <functional>
+#include <vector>
+#include <cmath>
+
+#include <cudnn.h>
+#include <cublas_v2.h>
+#include <thrust/device_vector.h>
+#include <thrust/functional.h>
+
+#include "exception.h"
+#include "thrust_functions.h"
+
+namespace marian {
+
+// Owner of the cuDNN and cuBLAS library handles plus an element-wise "add"
+// op descriptor. The constructor creates all three, the destructor
+// releases them.
+struct Handles {
+  cudnnHandle_t cudnnHandle;
+  cublasHandle_t cublasHandle;
+
+  cudnnOpTensorDescriptor_t add;
+
+  Handles() {
+    cudnnCreate(&cudnnHandle);
+    cublasCreate(&cublasHandle);
+    cudnnCreateOpTensorDescriptor(&add);
+    cudnnSetOpTensorDescriptor(add, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN);
+  }
+
+  // The destructor destroys the raw handles, so a copy would destroy the
+  // same handles a second time. Copying is therefore disabled.
+  Handles(const Handles&) = delete;
+  Handles& operator=(const Handles&) = delete;
+
+  ~Handles() {
+    // Release in reverse order of creation.
+    cudnnDestroyOpTensorDescriptor(add);
+    cublasDestroy(cublasHandle);
+    cudnnDestroy(cudnnHandle);
+  }
+};
+
+Handles handles;
+
+typedef std::vector<int> Shape;
+
+// Device-side tensor storage: a thrust::device_vector holding the elements
+// plus the cuDNN descriptor describing their layout. Instances are neither
+// copyable nor movable.
+template<class Float>
+class TensorImpl {
+  private:
+    Shape shape_;
+    thrust::device_vector<Float> data_;
+    cudnnTensorDescriptor_t desc_;
+    size_t tno_;                  // id taken from tensorCounter at construction
+    static size_t tensorCounter;
+
+    // Picks the cuDNN element type from the byte width of Float
+    // (2 -> half, 8 -> double, anything else -> float).
+    cudnnDataType_t dataType() const {
+      switch(sizeof(Float)) {
+        case 2: return CUDNN_DATA_HALF;
+        case 8: return CUDNN_DATA_DOUBLE;
+        default: return CUDNN_DATA_FLOAT;
+      }
+    }
+
+  public:
+    typedef Float value_type;
+
+    // Allocates storage for `shape` with every element set to `value` and
+    // sets up an NCHW descriptor in which missing trailing dimensions are 1.
+    // Throws util::Exception for an unsupported number of dimensions or a
+    // negative dimension.
+    TensorImpl(const Shape& shape, value_type value = 0)
+    : shape_(shape), tno_(tensorCounter++)
+    {
+      // @TODO:
+      UTIL_THROW_IF2(shape_.size() != 2,
+                     "For now, only 2D Tensors, will be fixed later.");
+
+      UTIL_THROW_IF2(shape_.size() < 1 || shape_.size() > 4,
+                     "Wrong number of dimensions: " << shape_.size());
+
+      // Element count is accumulated in size_t so that large shapes do not
+      // overflow an int; a negative dimension would otherwise turn into a
+      // huge allocation request.
+      size_t size = 1;
+      for(size_t i = 0; i < shape_.size(); ++i) {
+        UTIL_THROW_IF2(shape_[i] < 0,
+                       "Negative tensor dimension: " << shape_[i]);
+        size *= shape_[i];
+      }
+      data_.resize(size, value);
+
+      cudnnCreateTensorDescriptor(&desc_);
+      switch (shape_.size()) {
+        case 1:
+          cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
+                                     shape_[0], 1, 1, 1); break;
+        case 2:
+          cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
+                                     shape_[0], shape_[1], 1, 1); break;
+        case 3:
+          cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
+                                     shape_[0], shape_[1], shape_[2], 1); break;
+        case 4:
+          cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
+                                     shape_[0], shape_[1], shape_[2], shape_[3]); break;
+      }
+    }
+
+    TensorImpl(const TensorImpl&) = delete;
+    TensorImpl(TensorImpl&&) = delete;
+
+    ~TensorImpl() {
+      cudnnDestroyTensorDescriptor(desc_);
+    }
+
+    // Reads a single element back from the device vector.
+    value_type operator[](size_t i) const {
+      return data_[i];
+    }
+
+    auto begin() -> decltype( data_.begin() ) {
+      return data_.begin();
+    }
+
+    auto begin() const -> decltype( data_.begin() ) {
+      return data_.begin();
+    }
+
+    auto end() -> decltype( data_.end() ) {
+      return data_.end();
+    }
+
+    auto end() const -> decltype( data_.end() ) {
+      return data_.end();
+    }
+
+    const Shape& shape() const {
+      return shape_;
+    }
+
+    // Number of elements.
+    size_t size() const {
+      return data_.size();
+    }
+
+    // Raw device pointer to the first element.
+    value_type* data() {
+      return thrust::raw_pointer_cast(data_.data());
+    }
+
+    const value_type* data() const {
+      return thrust::raw_pointer_cast(data_.data());
+    }
+
+    cudnnTensorDescriptor_t desc() const {
+      return desc_;
+    }
+
+    size_t id() const {
+      return tno_;
+    }
+
+    // Overwrites every element with `value`.
+    void set(value_type value) {
+      thrust::fill(data_.begin(), data_.end(), value);
+    }
+};
+
+template <typename Type>
+size_t TensorImpl<Type>::tensorCounter = 0;
+
+// Cheap-to-copy handle around a shared TensorImpl<float>; copies refer to
+// the same storage. A default-constructed Tensor holds no implementation:
+// it converts to false and none of the accessors may be called on it.
+class Tensor {
+  private:
+    std::shared_ptr<TensorImpl<float>> pimpl_;
+
+  public:
+    typedef TensorImpl<float>::value_type value_type;
+
+    // Creates a tensor of the given shape filled with `value`.
+    Tensor(const Shape& shape, value_type value = 0)
+    : pimpl_(new TensorImpl<value_type>(shape, value)) {}
+
+    // Single value with broadcasting super powers. Might be
+    // worth getting rid of this performance-wise, but it saves
+    // so much typing when defining operators.
+    Tensor(value_type value)
+    : pimpl_(new TensorImpl<value_type>({1, 1}, value)) {}
+
+    Tensor() {}
+
+    ~Tensor() {}
+
+    value_type operator[](size_t i) const {
+      return (*pimpl_)[i];
+    }
+
+    size_t size() const {
+      return pimpl_->size();
+    }
+
+    value_type* data() {
+      return pimpl_->data();
+    }
+
+    const value_type* data() const {
+      return pimpl_->data();
+    }
+
+    auto begin() -> decltype( pimpl_->begin() ) {
+      return pimpl_->begin();
+    }
+
+    auto begin() const -> decltype( pimpl_->begin() ) {
+      return pimpl_->begin();
+    }
+
+    // end() must forward to the implementation's end(); forwarding to
+    // begin() would make every [begin, end) range empty.
+    auto end() -> decltype( pimpl_->end() ) {
+      return pimpl_->end();
+    }
+
+    auto end() const -> decltype( pimpl_->end() ) {
+      return pimpl_->end();
+    }
+
+    const Shape& shape() const {
+      return pimpl_->shape();
+    }
+
+    cudnnTensorDescriptor_t desc() const {
+      return pimpl_->desc();
+    }
+
+    // Overwrites every element with `value`.
+    void set(value_type value) {
+      pimpl_->set(value);
+    }
+
+    size_t id() const {
+      return pimpl_->id();
+    }
+
+    // True when the handle refers to an implementation.
+    operator bool() const {
+      return pimpl_ != nullptr;
+    }
+};
+
+// Fills `t` with pseudo-random values drawn from [a, b) in 2000 equal steps
+// using rand(), and returns the same tensor handle. With the default bounds
+// this yields the range -0.1 .. 0.0999 in steps of 0.0001.
+// The values are generated on the host and then copied into the tensor.
+inline Tensor uniform(Tensor t, float a=-0.1, float b=0.1) {
+  std::vector<float> r(t.size());
+  for(size_t i = 0; i < r.size(); ++i)
+    r[i] = a + (b - a) * (float(rand() % 2000) / 2000.0f);
+  thrust::copy(r.begin(), r.end(), t.begin());
+  return t;
+}
+
+using namespace thrust::placeholders;
+#define MAX_THREADS 512
+#define MAX_BLOCKS 65535
+
+// In-place element-wise kernel over a rows x cols row-major buffer:
+// out[r][c] = functor(out[r][c]). Blocks step over rows in strides of
+// gridDim.x, threads step over columns in strides of blockDim.x.
+template <class Functor>
+__global__ void gElement(Functor functor, float* out,
+                         size_t rows, size_t cols) {
+  for(int rowBase = 0; rowBase < rows; rowBase += gridDim.x) {
+    const int row = rowBase + blockIdx.x;
+    if(row >= rows)
+      continue;
+
+    float* outRow = out + row * cols;
+    for(int colBase = 0; colBase < cols; colBase += blockDim.x) {
+      const int col = colBase + threadIdx.x;
+      if(col < cols)
+        outRow[col] = functor(outRow[col]);
+    }
+  }
+}
+
+// Element-wise kernel with one input over rows x cols row-major buffers:
+// out[r][c] = functor(out[r][c], in[r][c]). Blocks step over rows in
+// strides of gridDim.x, threads step over columns in strides of blockDim.x.
+template <class Functor>
+__global__ void gElement(Functor functor,
+                         float* out, const float* in,
+                         size_t rows, size_t cols) {
+  for(int rowBase = 0; rowBase < rows; rowBase += gridDim.x) {
+    const int row = rowBase + blockIdx.x;
+    if(row >= rows)
+      continue;
+
+    float* outRow = out + row * cols;
+    const float* inRow = in + row * cols;
+    for(int colBase = 0; colBase < cols; colBase += blockDim.x) {
+      const int col = colBase + threadIdx.x;
+      if(col < cols)
+        outRow[col] = functor(outRow[col], inRow[col]);
+    }
+  }
+}
+
+// Element-wise kernel with two inputs over rows x cols row-major buffers:
+// out[r][c] = functor(out[r][c], in1[r][c], in2[r][c]). Blocks step over
+// rows in strides of gridDim.x, threads over columns in strides of
+// blockDim.x.
+template <class Functor>
+__global__ void gElement(Functor functor,
+                         float* out, const float* in1, const float* in2,
+                         size_t rows, size_t cols) {
+  for(int rowBase = 0; rowBase < rows; rowBase += gridDim.x) {
+    const int row = rowBase + blockIdx.x;
+    if(row >= rows)
+      continue;
+
+    float* outRow = out + row * cols;
+    const float* in1Row = in1 + row * cols;
+    const float* in2Row = in2 + row * cols;
+    for(int colBase = 0; colBase < cols; colBase += blockDim.x) {
+      const int col = colBase + threadIdx.x;
+      if(col < cols)
+        outRow[col] = functor(outRow[col], in1Row[col], in2Row[col]);
+    }
+  }
+}
+
+// Element-wise kernel with three inputs over rows x cols row-major buffers:
+// out[r][c] = functor(out[r][c], in1[r][c], in2[r][c], in3[r][c]). Blocks
+// step over rows in strides of gridDim.x, threads over columns in strides
+// of blockDim.x.
+template <class Functor>
+__global__ void gElement(Functor functor,
+                         float* out, const float* in1,
+                         const float* in2, const float* in3,
+                         size_t rows, size_t cols) {
+  for(int rowBase = 0; rowBase < rows; rowBase += gridDim.x) {
+    const int row = rowBase + blockIdx.x;
+    if(row >= rows)
+      continue;
+
+    float* outRow = out + row * cols;
+    const float* in1Row = in1 + row * cols;
+    const float* in2Row = in2 + row * cols;
+    const float* in3Row = in3 + row * cols;
+    for(int colBase = 0; colBase < cols; colBase += blockDim.x) {
+      const int col = colBase + threadIdx.x;
+      if(col < cols)
+        outRow[col] = functor(outRow[col], in1Row[col], in2Row[col], in3Row[col]);
+    }
+  }
+}
+
+// @TODO add broadcasting
+
+// Applies `functor` in place to every element of the 2D tensor Out and
+// waits for the kernel to finish. Uses one block per row (at most
+// MAX_BLOCKS) and one thread per column (at most MAX_THREADS).
+// Throws util::Exception if the launch or the execution reports an error.
+template <class Functor>
+void Element(Functor functor, Tensor Out) {
+  const int rows = Out.shape()[0];
+  const int cols = Out.shape()[1];
+  // A grid or block size of 0 is an invalid launch configuration, and an
+  // empty tensor has nothing to compute.
+  if(rows <= 0 || cols <= 0)
+    return;
+
+  float* d_out = Out.data();
+  int blocks = std::min(MAX_BLOCKS, rows);
+  int threads = std::min(MAX_THREADS, cols);
+  gElement<<<blocks, threads>>>(functor, d_out, rows, cols);
+
+  // Launch errors are only visible through cudaGetLastError().
+  cudaError_t err = cudaGetLastError();
+  UTIL_THROW_IF2(err != cudaSuccess,
+                 "Element kernel launch failed: " << cudaGetErrorString(err));
+  err = cudaStreamSynchronize(0);
+  UTIL_THROW_IF2(err != cudaSuccess,
+                 "Element kernel execution failed: " << cudaGetErrorString(err));
+}
+
+// Computes Out = functor(Out, In) element-wise over the 2D shape of Out and
+// waits for the kernel to finish. In is indexed with Out's shape; no
+// broadcasting is performed. Uses one block per row (at most MAX_BLOCKS)
+// and one thread per column (at most MAX_THREADS).
+// Throws util::Exception if the launch or the execution reports an error.
+template <class Functor>
+void Element(Functor functor,
+             Tensor Out, const Tensor In) {
+  const int rows = Out.shape()[0];
+  const int cols = Out.shape()[1];
+  // A grid or block size of 0 is an invalid launch configuration, and an
+  // empty tensor has nothing to compute.
+  if(rows <= 0 || cols <= 0)
+    return;
+
+  float* d_out = Out.data();
+  const float* d_in = In.data();
+
+  int blocks = std::min(MAX_BLOCKS, rows);
+  int threads = std::min(MAX_THREADS, cols);
+  gElement<<<blocks, threads>>>(functor, d_out, d_in, rows, cols);
+
+  // Launch errors are only visible through cudaGetLastError().
+  cudaError_t err = cudaGetLastError();
+  UTIL_THROW_IF2(err != cudaSuccess,
+                 "Element kernel launch failed: " << cudaGetErrorString(err));
+  err = cudaStreamSynchronize(0);
+  UTIL_THROW_IF2(err != cudaSuccess,
+                 "Element kernel execution failed: " << cudaGetErrorString(err));
+}
+
+// Computes Out = functor(Out, In1, In2) element-wise over the 2D shape of
+// Out and waits for the kernel to finish. The inputs are indexed with Out's
+// shape; no broadcasting is performed. Uses one block per row (at most
+// MAX_BLOCKS) and one thread per column (at most MAX_THREADS).
+// Throws util::Exception if the launch or the execution reports an error.
+template <class Functor>
+void Element(Functor functor,
+             Tensor Out, const Tensor In1, const Tensor In2) {
+  const int rows = Out.shape()[0];
+  const int cols = Out.shape()[1];
+  // A grid or block size of 0 is an invalid launch configuration, and an
+  // empty tensor has nothing to compute.
+  if(rows <= 0 || cols <= 0)
+    return;
+
+  float* d_out = Out.data();
+  const float* d_in1 = In1.data();
+  const float* d_in2 = In2.data();
+
+  int blocks = std::min(MAX_BLOCKS, rows);
+  int threads = std::min(MAX_THREADS, cols);
+  gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2, rows, cols);
+
+  // Launch errors are only visible through cudaGetLastError().
+  cudaError_t err = cudaGetLastError();
+  UTIL_THROW_IF2(err != cudaSuccess,
+                 "Element kernel launch failed: " << cudaGetErrorString(err));
+  err = cudaStreamSynchronize(0);
+  UTIL_THROW_IF2(err != cudaSuccess,
+                 "Element kernel execution failed: " << cudaGetErrorString(err));
+}
+
+// Computes Out = functor(Out, In1, In2, In3) element-wise over the 2D shape
+// of Out and waits for the kernel to finish. The inputs are indexed with
+// Out's shape; no broadcasting is performed. Uses one block per row (at
+// most MAX_BLOCKS) and one thread per column (at most MAX_THREADS).
+// Throws util::Exception if the launch or the execution reports an error.
+template <class Functor>
+void Element(Functor functor,
+             Tensor Out, const Tensor In1,
+             const Tensor In2, const Tensor In3) {
+  const int rows = Out.shape()[0];
+  const int cols = Out.shape()[1];
+  // A grid or block size of 0 is an invalid launch configuration, and an
+  // empty tensor has nothing to compute.
+  if(rows <= 0 || cols <= 0)
+    return;
+
+  float* d_out = Out.data();
+  const float* d_in1 = In1.data();
+  const float* d_in2 = In2.data();
+  const float* d_in3 = In3.data();
+
+  int blocks = std::min(MAX_BLOCKS, rows);
+  int threads = std::min(MAX_THREADS, cols);
+  gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2, d_in3, rows, cols);
+
+  // Launch errors are only visible through cudaGetLastError().
+  cudaError_t err = cudaGetLastError();
+  UTIL_THROW_IF2(err != cudaSuccess,
+                 "Element kernel launch failed: " << cudaGetErrorString(err));
+  err = cudaStreamSynchronize(0);
+  UTIL_THROW_IF2(err != cudaSuccess,
+                 "Element kernel execution failed: " << cudaGetErrorString(err));
+}
+
+// Computes C = op(A) * op(B) + beta * C with cublasSgemm, where op(X) is X
+// or its transpose depending on transA / transB, and returns C.
+// The leading dimensions passed are the column counts of the tensors, i.e.
+// the data is treated as row-major; because cuBLAS works column-major the
+// operands are handed over in swapped order (B first, then A).
+// Throws util::Exception if the inner dimensions of op(A) and op(B) differ
+// or if cuBLAS reports an error.
+inline Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
+                   bool transA, bool transB, float beta) {
+  float alpha = 1.0f;
+
+  size_t m = A.shape()[0];
+  size_t k = A.shape()[1];
+  if(transA)
+    std::swap(m, k);
+
+  size_t l = B.shape()[0];
+  size_t n = B.shape()[1];
+  if(transB)
+    std::swap(l, n);
+
+  // op(A) is m x k and op(B) is l x n; the product needs k == l.
+  UTIL_THROW_IF2(k != l,
+                 "Prod: inner dimensions do not match: " << k << " != " << l);
+
+  size_t lda = A.shape()[1];
+  size_t ldb = B.shape()[1];
+  // C has n columns: B.shape()[1], or B.shape()[0] when B is transposed.
+  size_t ldc = n;
+
+  cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+  cublasStatus_t status = cublasSgemm(handle, opB, opA,
+      n, m, k, &alpha, B.data(), ldb, A.data(), lda, &beta, C.data(), ldc);
+  UTIL_THROW_IF2(status != CUBLAS_STATUS_SUCCESS,
+                 "cublasSgemm failed with status " << status);
+  return C;
+}
+
+Tensor Prod(Tensor C, const Tensor A, const Tensor B, // convenience overload: same product, beta defaults to 0
+ bool transA, bool transB, float beta = 0) {
+
+ return Prod(handles.cublasHandle, C, A, B, transA, transB, beta); // uses the cuBLAS handle of the global `handles` object
+}
+
+} \ No newline at end of file
diff --git a/src/exception.cpp b/src/exception.cpp
new file mode 100644
index 00000000..453fcf66
--- /dev/null
+++ b/src/exception.cpp
@@ -0,0 +1,108 @@
+#include "exception.h"
+
+#ifdef __GXX_RTTI
+#include <typeinfo>
+#endif
+
+#include <cerrno>
+#include <cstring>
+
+#if defined(_WIN32) || defined(_WIN64)
+#include <windows.h>
+#include <io.h>
+#endif
+
+namespace util {
+
+Exception::Exception() throw() {}
+Exception::~Exception() throw() {}
+
+// Copies the message text of `o`. The text is streamed in rather than
+// installed with str(): str(s) leaves the write position at the start of
+// the buffer, so a later operator<< on the copy (for instance on the object
+// received in a catch block) would overwrite the beginning of the message
+// instead of appending to it.
+Exception::Exception(const Exception& o) throw() {
+  what_ << o.what_.str();
+}
+
+// Rewrites the message so that it starts with a location header of the form
+//   <file>:<line> in <func> threw <name> because `<condition>'.
+// followed by whatever text the exception already contained.
+// `func`, `child_name` and `condition` may be NULL; when `child_name` is
+// NULL the type name is taken from RTTI if available.
+void Exception::SetLocation(const char *file, unsigned int line, const char *func, const char *child_name, const char *condition) {
+  /* The child class might have set some text, but we want this to come first.
+   * Another option would be passing this information to the constructor, but
+   * then child classes would have to accept constructor arguments and pass
+   * them down.
+   */
+  std::string old_text = what_.str();
+  what_.str(std::string());
+  what_ << file << ':' << line;
+  if (func) what_ << " in " << func << " threw ";
+  if (child_name) {
+    what_ << child_name;
+  } else {
+#ifdef __GXX_RTTI
+    // Dereference so that typeid reports the dynamic type of the exception;
+    // typeid(this) would always name the pointer type Exception*.
+    what_ << typeid(*this).name();
+#else
+    what_ << "an exception";
+#endif
+  }
+  if (condition) {
+    what_ << " because `" << condition << '\'';
+  }
+  what_ << ".\n";
+  what_ << old_text;
+}
+
+namespace { // file-local helpers that normalise the result of strerror_r to a message pointer
+
+#ifdef __GNUC__
+const char *HandleStrerror(int ret, const char *buf) __attribute__ ((unused));
+const char *HandleStrerror(const char *ret, const char * /*buf*/) __attribute__ ((unused));
+#endif
+// At least one of these functions will not be called.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunused-function"
+#endif
+// The XOPEN version: strerror_r returned an int status.
+const char *HandleStrerror(int ret, const char *buf) {
+ if (!ret) return buf; // status 0: use the text in buf
+ return NULL; // non-zero status: no message
+}
+
+// The GNU version: strerror_r returned a char pointer.
+const char *HandleStrerror(const char *ret, const char * /*buf*/) {
+ return ret; // the returned pointer is used as the message; buf is ignored
+}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
+} // namespace
+
+ErrnoException::ErrnoException() throw() : errno_(errno) { // records errno at construction time
+ char buf[200];
+ buf[0] = 0; // start with an empty string
+#if defined(sun) || defined(_WIN32) || defined(_WIN64)
+ const char *add = strerror(errno);
+#else
+ const char *add = HandleStrerror(strerror_r(errno, buf, 200), buf); // overload chosen by strerror_r's return type
+#endif
+
+ if (add) {
+ *this << add << ' '; // prefix the message with the error text and a space
+ }
+}
+
+ErrnoException::~ErrnoException() throw() {}
+
+OverflowException::OverflowException() throw() {}
+OverflowException::~OverflowException() throw() {}
+
+#if defined(_WIN32) || defined(_WIN64)
+WindowsException::WindowsException() throw() {
+ unsigned int last_error = GetLastError(); // capture the error code before any further API call
+ char error_msg[256] = "";
+ if (!FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, last_error, LANG_NEUTRAL, error_msg, sizeof(error_msg), NULL)) {
+ *this << "Windows error " << GetLastError() << " while formatting Windows error " << last_error << ". "; // formatting failed: report both error codes
+ } else {
+ *this << "Windows error " << last_error << ": " << error_msg; // code plus the formatted text
+ }
+}
+WindowsException::~WindowsException() throw() {}
+#endif
+
+} // namespace util
diff --git a/src/exception.h b/src/exception.h
new file mode 100644
index 00000000..85827d8c
--- /dev/null
+++ b/src/exception.h
@@ -0,0 +1,156 @@
+#pragma once
+
+#include <sstream>
+#include <exception>
+#include <limits>
+#include <string>
+#include <stdint.h>
+
+namespace util {
+
+template <class Except, class Data> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data);
+
+// Base class of the util exceptions. The message is assembled incrementally
+// through operator<< and handed out by what().
+class Exception : public std::exception {
+  public:
+    Exception() throw();
+    virtual ~Exception() throw();
+    Exception(const Exception& o) throw();
+
+    // Returns the message assembled so far. The text is stored in a member
+    // before its c_str() is returned, so the pointer stays valid until the
+    // next call to what() or the destruction of the exception. Returning
+    // what_.str().c_str() directly would hand out a pointer into a
+    // temporary string that is destroyed before the caller can read it.
+    const char *what() const throw() {
+      what_cache_ = what_.str();
+      return what_cache_.c_str();
+    }
+
+    // For use by the UTIL_THROW macros.
+    void SetLocation(
+        const char *file,
+        unsigned int line,
+        const char *func,
+        const char *child_name,
+        const char *condition);
+
+  private:
+    template <class Except, class Data> friend typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data);
+
+    // This helps restrict operator<< defined below.
+    template <class T> struct ExceptionTag {
+      typedef T Identity;
+    };
+
+    std::stringstream what_;
+    // Backing storage for the pointer returned by what().
+    mutable std::string what_cache_;
+};
+
+/* This implements the normal operator<< for Exception and all its children.
+ * SFINAE means it only applies to Exception. Think of this as an ersatz
+ * boost::enable_if.
+ */
+template <class Except, class Data> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data) {
+ e.what_ << data; // append to the exception's message stream
+ return e; // return the same exception so calls can be chained
+}
+
+#ifdef __GNUC__
+#define UTIL_FUNC_NAME __PRETTY_FUNCTION__
+#else
+#ifdef _WIN32
+#define UTIL_FUNC_NAME __FUNCTION__
+#else
+#define UTIL_FUNC_NAME NULL
+#endif
+#endif
+
+/* Create an instance of Exception, add the message Modify, and throw it.
+ * Modify is appended to the what() message and can contain << for ostream
+ * operations.
+ *
+ * do .. while kludge to swallow trailing ; character
+ * http://gcc.gnu.org/onlinedocs/cpp/Swallowing-the-Semicolon.html .
+ * Arg can be a constructor argument to the exception.
+ */
+#define UTIL_THROW_BACKEND(Condition, Exception, Arg, Modify) do { \
+ Exception UTIL_e Arg; \
+ UTIL_e.SetLocation(__FILE__, __LINE__, UTIL_FUNC_NAME, #Exception, Condition); \
+ UTIL_e << Modify; \
+ throw UTIL_e; \
+} while (0)
+
+#define UTIL_THROW_ARG(Exception, Arg, Modify) \
+ UTIL_THROW_BACKEND(NULL, Exception, Arg, Modify)
+
+#define UTIL_THROW(Exception, Modify) \
+ UTIL_THROW_BACKEND(NULL, Exception, , Modify);
+
+#define UTIL_THROW2(Modify) \
+ UTIL_THROW_BACKEND(NULL, util::Exception, , Modify);
+
+#if __GNUC__ >= 3
+#define UTIL_UNLIKELY(x) __builtin_expect (!!(x), 0)
+#else
+#define UTIL_UNLIKELY(x) (x)
+#endif
+
+#if __GNUC__ >= 3
+#define UTIL_LIKELY(x) __builtin_expect (!!(x), 1)
+#else
+#define UTIL_LIKELY(x) (x)
+#endif
+
+#define UTIL_THROW_IF_ARG(Condition, Exception, Arg, Modify) do { \
+ if (UTIL_UNLIKELY(Condition)) { \
+ UTIL_THROW_BACKEND(#Condition, Exception, Arg, Modify); \
+ } \
+} while (0)
+
+#define UTIL_THROW_IF(Condition, Exception, Modify) \
+ UTIL_THROW_IF_ARG(Condition, Exception, , Modify)
+
+#define UTIL_THROW_IF2(Condition, Modify) \
+ UTIL_THROW_IF_ARG(Condition, util::Exception, , Modify)
+
+// Exception that records errno and adds it to the message.
+class ErrnoException : public Exception {
+ public:
+ ErrnoException() throw();
+
+ virtual ~ErrnoException() throw();
+
+ int Error() const throw() { return errno_; } // the stored error number
+
+ private:
+ int errno_; // error number kept with the exception
+};
+
+// file wasn't there, or couldn't be open for some reason
+class FileOpenException : public Exception { // adds no state; exists as a distinct type
+ public:
+ FileOpenException() throw() {}
+ ~FileOpenException() throw() {}
+};
+
+// Utilities for overflow checking.
+class OverflowException : public Exception { // thrown by CheckOverflowInternal when a value exceeds size_t
+ public:
+ OverflowException() throw();
+ ~OverflowException() throw();
+};
+
+template <unsigned len> inline std::size_t CheckOverflowInternal(uint64_t value) { // generic case: checked conversion from uint64_t to size_t
+ UTIL_THROW_IF(value > static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), OverflowException, "Integer overflow detected. This model is too big for 32-bit code.");
+ return value; // fits into size_t at this point
+}
+
+template <> inline std::size_t CheckOverflowInternal<8>(uint64_t value) { // specialisation for len == 8: no check is performed
+ return value;
+}
+
+inline std::size_t CheckOverflow(uint64_t value) { // converts to size_t, dispatching on sizeof(std::size_t)
+ return CheckOverflowInternal<sizeof(std::size_t)>(value);
+}
+
+#if defined(_WIN32) || defined(_WIN64)
+/* Thrown for Windows specific operations. */
+class WindowsException : public Exception { // only declared when _WIN32 or _WIN64 is defined
+ public:
+ WindowsException() throw();
+ ~WindowsException() throw();
+};
+#endif
+
+} // namespace util
diff --git a/src/marian.h b/src/marian.h
index 18d561c4..b8320d91 100644
--- a/src/marian.h
+++ b/src/marian.h
@@ -5,211 +5,111 @@
#include <vector>
#include <cmath>
-#include <boost/pool/pool.hpp>
+#include "exception.h"
+#include "cudnn_tensor.h"
namespace marian {
-typedef float Tensor; // Now do this for cuDNN tensors!
-struct Chainable;
-
-boost::pool<> p(sizeof(char));
-std::vector<Chainable*> stack;
-
-struct Chainable {
+template <class DataType>
+struct Chainable : public std::enable_shared_from_this<Chainable<DataType>> {
Chainable() { }
virtual ~Chainable() { }
-
- virtual void chain() { }
+ virtual void forward() { }
+ virtual void backward() { }
virtual void init_dependent() { }
virtual void set_zero_adjoint() { }
-
- static inline void* operator new(size_t nbytes) {
- // thread_local variable
- return p.ordered_malloc(nbytes);
- }
+
+ virtual DataType val() = 0;
+ virtual DataType grad() = 0;
};
-class Vimpl : public Chainable {
+typedef std::vector<Chainable<Tensor>*> ChainableStack;
+typedef std::shared_ptr<Chainable<Tensor>> ChainPtr;
+
+ChainableStack stack;
+
+class Node : public Chainable<Tensor> {
public:
- Vimpl(const Tensor& t) : val_{std::move(t)}, adj_{0} {
+ Node(const Tensor t) : val_(t) {
+ //std::cerr << "Putting node with tensor " << t.id() << " on stack" << std::endl;
stack.push_back(this);
}
- ~Vimpl() {};
+ virtual ~Node() {};
+
+ virtual void init_dependent() {
+ if(adj_) {
+ adj_.set(1);
+ }
+ else {
+ adj_ = Tensor(val_.shape(), 1);
+ }
+ }
- virtual void init_dependent() { adj_ = 1; }
- virtual void set_zero_adjoint() { adj_ = 0; }
+ virtual void set_zero_adjoint() {
+ if(adj_) {
+ adj_.set(0);
+ }
+ else {
+ adj_ = Tensor(val_.shape(), 0);
+ }
+ }
- const Tensor& val() const { return val_; };
- Tensor& grad() { return adj_; };
+ virtual Tensor val() { return val_; };
+ virtual Tensor grad() { return adj_; };
protected:
- const Tensor val_;
- Tensor adj_;
+ Tensor val_;
+ Tensor adj_;
};
-typedef Vimpl* VimplPtr;
-
-static void set_zero_all_adjoints() {
- for(auto&& v : stack)
- v->set_zero_adjoint();
-}
-
-static void grad(Chainable* v) {
- typedef std::vector<Chainable*>::reverse_iterator It;
- v->init_dependent();
- for(It it = stack.rbegin(); it != stack.rend(); ++it) {
- (*it)->chain();
- }
-}
-
class Var {
public:
- Var() : vimpl_{nullptr} {}
- Var(const Tensor& t) : vimpl_{new Vimpl{t}} {}
- Var(const VimplPtr& vimpl) : vimpl_{vimpl} {}
+ Var() : pimpl_(nullptr) {}
+ Var(const Tensor t) : pimpl_(new Node(t)) {}
+ Var(const Tensor::value_type v) : pimpl_(new Node(Tensor(v))) {}
+ Var(const ChainPtr chainable) : pimpl_(chainable) {}
+ Var(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
- const Tensor& val() const {
- return vimpl_->val();
+ Tensor val() {
+ return pimpl_->val();
}
- Tensor& grad() {
- return vimpl_->grad();
+ Tensor grad() {
+ return pimpl_->grad();
}
- VimplPtr vimpl() const {
- return vimpl_;
+ ChainPtr pimpl() {
+ return pimpl_;
}
- void calc_gradients() {
- marian::grad(vimpl_);
+ void forward() {
+ UTIL_THROW_IF2(pimpl_.get() != stack.back(),
+ "Trying to call forward on non-root of computation graph");
+
+ for(auto&& v : stack)
+ v->forward();
}
- private:
- VimplPtr vimpl_;
-};
-
-///////////////////////////////////////////////////
-
-struct OpVimpl : public Vimpl {
- OpVimpl(const Tensor& t, VimplPtr a) : Vimpl(t), a_(a) { }
-
- VimplPtr a_;
-};
-
-
-struct LogVimpl : public OpVimpl {
- LogVimpl(VimplPtr a) : OpVimpl(std::log(a->val()), a) { }
-
- void chain() {
- a_->grad() += adj_ / a_->val();
- }
-};
-
-inline Var log(const Var& a) {
- return Var(VimplPtr(new LogVimpl(a.vimpl())));
-}
-
-struct ExpVimpl : public OpVimpl {
- ExpVimpl(VimplPtr a) : OpVimpl(std::exp(a->val()), a) { }
-
- void chain() {
- a_->grad() += adj_ * std::exp(a_->val());
- }
-};
-
-inline Var exp(const Var& a) {
- return Var(VimplPtr(new ExpVimpl(a.vimpl())));
-}
-
-struct NegVimpl : public OpVimpl {
- NegVimpl(VimplPtr a) : OpVimpl(-a->val(), a) { }
-
- void chain() {
- a_->grad() -= adj_;
- }
-};
-
-inline Var operator-(const Var& a) {
- return Var(VimplPtr(new NegVimpl(a.vimpl())));
-}
-
-// @TODO: take care of large exponents
-struct SigmaVimpl : public OpVimpl {
- SigmaVimpl(VimplPtr a) : OpVimpl(1.f / (1.f + std::exp(-a->val())), a) { }
-
- void chain() {
- Tensor l = 1.f / (1.f + std::exp(-a_->val()));
- a_->grad() += adj_ * l * (1 - l);
- }
-};
-
-inline Var sigma(const Var& a) {
- return Var(VimplPtr(new SigmaVimpl(a.vimpl())));
-}
-
-///////////////////////////////////////////////////
-
-
-struct OpVimplVV : public Vimpl {
- VimplPtr a_;
- VimplPtr b_;
+ void backward() {
+ UTIL_THROW_IF2(pimpl_.get() != stack.back(),
+ "Trying to call backward on non-root of computation graph");
+
+ for(auto&& v : stack)
+ v->set_zero_adjoint();
- OpVimplVV(Tensor t, VimplPtr a, VimplPtr b)
- : Vimpl(t), a_(a), b_(b) { }
-};
-
-struct PlusVimplVV : public OpVimplVV {
- PlusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() + b->val(), a, b) { }
-
- void chain() {
- a_->grad() += adj_;
- b_->grad() += adj_;
- }
-};
-
-inline Var operator+(const Var& a, const Var& b) {
- return Var(VimplPtr(new PlusVimplVV(a.vimpl(), b.vimpl())));
-}
-
-struct MinusVimplVV : public OpVimplVV {
- MinusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() - b->val(), a, b) { }
-
- void chain() {
- a_->grad() -= adj_;
- b_->grad() -= adj_;
- }
-};
-
-inline Var operator-(const Var& a, const Var& b) {
- return Var(VimplPtr(new MinusVimplVV(a.vimpl(), b.vimpl())));
-}
-
-struct MultVimplVV : public OpVimplVV {
- MultVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() * b->val(), a, b) { }
-
- void chain() {
- a_->grad() += adj_ * b_->val();
- b_->grad() += adj_ * a_->val();
- }
-};
-
-inline Var operator*(const Var& a, const Var& b) {
- return Var(VimplPtr(new MultVimplVV(a.vimpl(), b.vimpl())));
-}
-
-struct DivVimplVV : public OpVimplVV {
- DivVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() / b->val(), a, b) { }
-
- void chain() {
- a_->grad() += adj_ / b_->val();
- b_->grad() += adj_ * (a_->val() / (b_->val() * b_->val()));
- }
+ typedef ChainableStack::reverse_iterator It;
+ pimpl_->init_dependent();
+ for(It it = stack.rbegin(); it != stack.rend(); ++it)
+ (*it)->backward();
+ }
+
+ operator ChainPtr() {
+ return pimpl_;
+ }
+
+ private:
+ ChainPtr pimpl_;
};
-inline Var operator/(const Var& a, const Var& b) {
- return Var(VimplPtr(new DivVimplVV(a.vimpl(), b.vimpl())));
-}
-
-
} \ No newline at end of file
diff --git a/src/operators.h b/src/operators.h
new file mode 100644
index 00000000..340e5188
--- /dev/null
+++ b/src/operators.h
@@ -0,0 +1,370 @@
+#pragma once
+
+#include <memory>
+#include <functional>
+#include <vector>
+#include <cmath>
+
+#include "marian.h"
+#include "cudnn_tensor.h"
+
+namespace marian {
+
+/*** Unary operators ***/
+
+struct UnaryNodeOp : public Node { // base for nodes with one input
+ ChainPtr a_; // the input node
+
+ UnaryNodeOp(const Tensor t, ChainPtr a) // t: tensor that will hold this node's value
+ : Node(t), a_(a) {}
+};
+
+// Node applying the Sigma functor element-wise to its input.
+struct SigmaNodeOp : public UnaryNodeOp {
+  SigmaNodeOp(ChainPtr a)
+  : UnaryNodeOp(Tensor(a->val().shape()), a) { }
+
+  // val = Sigma(a)
+  virtual void forward() {
+    Tensor in = a_->val();
+    Element(_1 = Sigma(_2), val_, in);
+  }
+
+  // a.grad += adj * Sigma(a) * (1 - Sigma(a))
+  virtual void backward() {
+    Tensor inGrad = a_->grad();
+    Tensor in = a_->val();
+    Element(_1 += _2 * Sigma(_3) * (1 - Sigma(_3)), inGrad, adj_, in);
+  }
+};
+
+inline Var sigma(Var a) { // wraps `a` in a newly allocated SigmaNodeOp
+ return Var(new SigmaNodeOp(a));
+}
+
+// Node applying the Tanh functor element-wise to its input.
+struct TanhNodeOp : public UnaryNodeOp {
+  TanhNodeOp(ChainPtr a)
+  : UnaryNodeOp(Tensor(a->val().shape()), a) { }
+
+  // val = Tanh(a)
+  virtual void forward() {
+    Tensor in = a_->val();
+    Element(_1 = Tanh(_2), val_, in);
+  }
+
+  // a.grad += adj * (1 - Tanh(a)^2)
+  virtual void backward() {
+    Tensor inGrad = a_->grad();
+    Tensor in = a_->val();
+    Element(_1 += _2 * (1 - Tanh(_3) * Tanh(_3)), inGrad, adj_, in);
+  }
+};
+
+inline Var tanh(Var a) { // wraps `a` in a newly allocated TanhNodeOp
+ return Var(new TanhNodeOp(a));
+}
+
+// Node applying the Log functor element-wise to its input.
+struct LogNodeOp : public UnaryNodeOp {
+  LogNodeOp(ChainPtr a)
+  : UnaryNodeOp(Tensor(a->val().shape()), a) { }
+
+  // val = Log(a)
+  virtual void forward() {
+    Tensor in = a_->val();
+    Element(_1 = Log(_2), val_, in);
+  }
+
+  // a.grad += adj / a
+  virtual void backward() {
+    Tensor inGrad = a_->grad();
+    Tensor in = a_->val();
+    Element(_1 += _2 * 1.f / _3, inGrad, adj_, in);
+  }
+};
+
+inline Var log(Var a) { // wraps `a` in a newly allocated LogNodeOp
+ return Var(new LogNodeOp(a));
+};
+
+// Node applying the Exp functor element-wise to its input.
+struct ExpNodeOp : public UnaryNodeOp {
+  ExpNodeOp(ChainPtr a)
+  : UnaryNodeOp(Tensor(a->val().shape()), a) { }
+
+  // val = Exp(a)
+  virtual void forward() {
+    Tensor in = a_->val();
+    Element(_1 = Exp(_2), val_, in);
+  }
+
+  // a.grad += adj * Exp(a)
+  virtual void backward() {
+    Tensor inGrad = a_->grad();
+    Tensor in = a_->val();
+    Element(_1 += _2 * Exp(_3), inGrad, adj_, in);
+  }
+};
+
+inline Var exp(Var a) { // wraps `a` in a newly allocated ExpNodeOp
+ return Var(new ExpNodeOp(a));
+};
+
+// Node negating its input element-wise.
+struct NegNodeOp : public UnaryNodeOp {
+  NegNodeOp(ChainPtr a)
+  : UnaryNodeOp(Tensor(a->val().shape()), a) { }
+
+  // val = -a
+  virtual void forward() {
+    Tensor in = a_->val();
+    Element(_1 = -_2, val_, in);
+  }
+
+  // a.grad += -adj
+  virtual void backward() {
+    Tensor inGrad = a_->grad();
+    Element(_1 += -_2, inGrad, adj_);
+  }
+};
+
+inline Var operator-(Var a) { // unary minus: wraps `a` in a newly allocated NegNodeOp
+ return Var(new NegNodeOp(a));
+};
+
+/******************************************************/
+
+struct BinaryNodeOp : public Node { // base for nodes with two inputs
+ ChainPtr a_; // first input node
+ ChainPtr b_; // second input node
+
+ BinaryNodeOp(const Tensor t, ChainPtr a, ChainPtr b) // t: tensor that will hold this node's value
+ : Node(t), a_(a), b_(b) {}
+};
+
+/*** Matrix Product ***/
+
+// Matrix product node: val = a * b for 2D inputs.
+struct DotNodeOp : public BinaryNodeOp {
+  DotNodeOp(ChainPtr a, ChainPtr b) : BinaryNodeOp(Tensor(shape(a, b)), a, b) { }
+
+  // Result shape of a * b: the shape of a with its second dimension replaced
+  // by the second dimension of b. Throws util::Exception when the second
+  // dimension of a differs from the first dimension of b.
+  // Static because the constructor calls it while initialising the base
+  // class, i.e. before the object exists.
+  static Shape shape(ChainPtr a, ChainPtr b) {
+    UTIL_THROW_IF2(a->val().shape()[1] != b->val().shape()[0],
+                   "matrix product requires dimensions to match");
+    Shape shape1 = a->val().shape();
+    Shape shape2 = b->val().shape();
+    shape1[1] = shape2[1];
+    return shape1;
+  }
+
+  void forward() {
+    // C = A*B
+    Prod(val_, a_->val(), b_->val(), false, false);
+  }
+
+  void backward() {
+    // D is the adjoint, the matrix of derivatives
+    // df/dA += D*B.T
+    // df/dB += A.T*D
+    // beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C
+    // to sum gradients from different graph parts
+    Prod(a_->grad(), adj_, b_->val(), false, true, 1.0);
+    Prod(b_->grad(), a_->val(), adj_, true, false, 1.0);
+  }
+};
+
+inline Var dot(Var a, Var b) { // wraps `a` and `b` in a newly allocated DotNodeOp
+ return Var(new DotNodeOp(a, b));
+}
+
+/******************************************************/
+
+// Expands `a` to `shape`. Returns `a` unchanged when the shapes already
+// agree. Otherwise every dimension of size 1 that has to grow is expanded
+// by a matrix product with a tensor of ones (left product for dimension 0,
+// right product for dimension 1), so the expansion becomes part of the
+// computation graph. Throws util::Exception when the number of dimensions
+// differs, when a dimension is neither equal nor 1, or when a dimension
+// beyond the second would have to be expanded.
+// Declared inline because it is defined in a header.
+inline Var broadcast(Shape shape, Var a) {
+  if(a.val().shape() == shape) {
+    return a;
+  }
+  else {
+    size_t dimsA = a.val().shape().size();
+    size_t dimsB = shape.size();
+    UTIL_THROW_IF2(dimsA != dimsB,
+                   "Tensor and shape have different number of dimensions");
+    for(size_t i = 0; i < dimsA; ++i) {
+      int dimA = a.val().shape()[i];
+      int dimB = shape[i];
+      bool broadcastable = (dimA == dimB || dimA == 1);
+      UTIL_THROW_IF2(!broadcastable,
+                     "Cannot broadcast tensor dimension "
+                     << dimA << " to " << dimB);
+      if(dimA == 1 && dimB > 1) {
+        std::cerr << "Broadcasting dim " << i << " from " << dimA << " to " << dimB << std::endl;
+        if(i == 0) {
+          // (shape[0] x 1) * (1 x cols) repeats the single row.
+          Var one = Tensor({shape[0], 1}, 1);
+          a = dot(one, a);
+        }
+        else if(i == 1) {
+          // (rows x 1) * (1 x shape[1]) repeats the single column.
+          Var one = Tensor({1, shape[1]}, 1);
+          a = dot(a, one);
+        }
+        else {
+          UTIL_THROW2("Not inplemented");
+        }
+      }
+    }
+    return a;
+  }
+}
+
+struct BroadcastingNodeOp : public BinaryNodeOp { // base for binary nodes whose inputs are first broadcast to a common shape
+ BroadcastingNodeOp(Var a, Var b) // delegates to the private constructor with both inputs broadcast to shape(a, b)
+ : BroadcastingNodeOp(Tensor(shape(a ,b)), broadcast(shape(a ,b), a), broadcast(shape(a ,b), b)) {}
+
+ static Shape shape(ChainPtr a, ChainPtr b) { // common shape: per-dimension maximum of both inputs
+ size_t dimsA = a->val().shape().size();
+ size_t dimsB = b->val().shape().size();
+ UTIL_THROW_IF2(dimsA != dimsB,
+ "Tensors have different numbers of dimensions");
+ Shape shape(dimsA);
+ for(size_t i = 0; i < dimsA; ++i) {
+ int dimA = a->val().shape()[i];
+ int dimB = b->val().shape()[i];
+ bool broadcastable = (dimA == dimB || dimA == 1 || dimB == 1); // sizes must match unless one of them is 1
+ UTIL_THROW_IF2(!broadcastable, "Different dimensions in elementwise "
+ << "operation cannot be broadcasted: " << dimA << " != " << dimB);
+ shape[i] = std::max(dimA, dimB);
+ }
+ return shape;
+ }
+
+ private:
+ BroadcastingNodeOp(const Tensor t, ChainPtr a, ChainPtr b) // receives the already broadcast inputs
+ : BinaryNodeOp(t, a, b) {}
+};
+
+/*** Binary arithmetic ***/
+
+/*** Plus ***/
+
+// Element-wise addition of two broadcast inputs.
+struct PlusNodeOp : public BroadcastingNodeOp {
+  PlusNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
+
+  // val = a + b
+  virtual void forward() {
+    Tensor lhs = a_->val();
+    Tensor rhs = b_->val();
+    Element(_1 = _2 + _3, val_, lhs, rhs);
+  }
+
+  // a.grad += adj; b.grad += adj
+  virtual void backward() {
+    Tensor lhsGrad = a_->grad();
+    Element(_1 += _2, lhsGrad, adj_);
+
+    Tensor rhsGrad = b_->grad();
+    Element(_1 += _2, rhsGrad, adj_);
+  }
+};
+
+inline Var operator+(Var a, Var b) { // wraps `a` and `b` in a newly allocated PlusNodeOp
+ return Var(new PlusNodeOp(a, b));
+}
+
+/*** Minus ***/
+
+// Element-wise subtraction of two broadcast inputs.
+struct MinusNodeOp : public BroadcastingNodeOp {
+  MinusNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
+
+  // val = a - b
+  virtual void forward() {
+    Tensor lhs = a_->val();
+    Tensor rhs = b_->val();
+    Element(_1 = _2 - _3, val_, lhs, rhs);
+  }
+
+  // a.grad += adj; b.grad -= adj
+  virtual void backward() {
+    Tensor lhsGrad = a_->grad();
+    Element(_1 += _2, lhsGrad, adj_);
+
+    Tensor rhsGrad = b_->grad();
+    Element(_1 -= _2, rhsGrad, adj_);
+  }
+};
+
+inline Var operator-(Var a, Var b) { // wraps `a` and `b` in a newly allocated MinusNodeOp
+ return Var(new MinusNodeOp(a, b));
+}
+
+/*** Mult ***/
+
+// Element-wise multiplication of two broadcast inputs.
+struct MultNodeOp : public BroadcastingNodeOp {
+  MultNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
+
+  // val = a * b
+  virtual void forward() {
+    Tensor lhs = a_->val();
+    Tensor rhs = b_->val();
+    Element(_1 = _2 * _3, val_, lhs, rhs);
+  }
+
+  // a.grad += adj * b; b.grad += adj * a
+  virtual void backward() {
+    Tensor lhsGrad = a_->grad();
+    Tensor rhs = b_->val();
+    Element(_1 += _2 * _3, lhsGrad, adj_, rhs);
+
+    Tensor rhsGrad = b_->grad();
+    Tensor lhs = a_->val();
+    Element(_1 += _2 * _3, rhsGrad, adj_, lhs);
+  }
+};
+
+inline Var operator*(Var a, Var b) { // wraps `a` and `b` in a newly allocated MultNodeOp
+ return Var(new MultNodeOp(a, b));
+}
+
+/*** Division ***/
+
+// Element-wise division of two broadcast inputs.
+struct DivNodeOp : public BroadcastingNodeOp {
+  DivNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
+
+  // val = a / b
+  virtual void forward() {
+    Tensor lhs = a_->val();
+    Tensor rhs = b_->val();
+    Element(_1 = _2 / _3, val_, lhs, rhs);
+  }
+
+  // a.grad += adj / b; b.grad -= adj * a / (b * b)
+  virtual void backward() {
+    Tensor lhsGrad = a_->grad();
+    Tensor divisor = b_->val();
+    Element(_1 += _2 * 1.0f / _3, lhsGrad, adj_, divisor);
+
+    Tensor rhsGrad = b_->grad();
+    Tensor lhs = a_->val();
+    Tensor rhs = b_->val();
+    Element(_1 -= _2 * _3 / (_4 * _4), rhsGrad, adj_, lhs, rhs);
+  }
+};
+
+inline Var operator/(Var a, Var b) { // wraps `a` and `b` in a newly allocated DivNodeOp
+ return Var(new DivNodeOp(a, b));
+}
+
+
+/*** Reductions ***/
+
+enum Axis { undef, axis0, axis1, axis2, axis3 };
+
+// inefficient
+// Sums a 2D tensor along `axis` by multiplying with a tensor of ones:
+//   axis0 -> (1 x rows) * a, result has one row;
+//   axis1 -> a * (cols x 1), result has one column;
+//   undef -> both reductions in sequence, result is 1 x 1.
+// axis2 and axis3 throw util::Exception.
+inline Var sum(Var a, Axis axis = Axis::undef) {
+  if(axis == Axis::axis0) {
+    int rows = a.val().shape()[0];
+    Var one = Tensor({1, rows}, 1);
+    return dot(one, a);
+  }
+  else if(axis == Axis::axis1) {
+    int cols = a.val().shape()[1];
+    Var one = Tensor({cols, 1}, 1);
+    return dot(a, one);
+  }
+  else if(axis == Axis::axis2) {
+    UTIL_THROW2("Not inplemented");
+  }
+  else if(axis == Axis::axis3) {
+    UTIL_THROW2("Not inplemented");
+  }
+  return sum(sum(a, Axis::axis0), Axis::axis1);
+}
+
+// inefficient
+inline Var softmax(Var a, Axis axis = Axis::undef) { // exp(a) divided by its sum along `axis`
+ Var e = exp(a);
+ return e / sum(e, axis); // the division node broadcasts the sum back to e's shape
+}
+
+// inefficient
+inline Var mean(Var a, Axis axis = Axis::undef) { // sum along `axis` divided by the number of summed elements
+ switch (axis) {
+ case Axis::axis0:
+ return sum(a, axis) / a.val().shape()[0]; // divide by the size of dimension 0
+ case Axis::axis1:
+ return sum(a, axis) / a.val().shape()[1]; // divide by the size of dimension 1
+ case Axis::axis2:
+ UTIL_THROW2("Not implemented");
+ case Axis::axis3:
+ UTIL_THROW2("Not implemented");
+ case Axis::undef:
+ default:
+ return sum(a) / a.val().size(); // total sum divided by the element count
+ }
+}
+
+// FAKE
+inline Var input(const std::string& name, Var v) { // placeholder: `name` is unused and `v` is returned as is
+ return v;
+}
+
+inline Var forsave(const std::string& name, Var v) { // placeholder: `name` is unused and `v` is returned as is
+ return v;
+}
+
+} \ No newline at end of file
diff --git a/src/tensor.h b/src/tensor.h
new file mode 100644
index 00000000..932278a1
--- /dev/null
+++ b/src/tensor.h
@@ -0,0 +1,117 @@
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <functional>
+#include <iostream>
+#include <memory>
+#include <vector>
+
+namespace marian {
+
+// Owning storage for a flat array of floats. Every instance receives a
+// sequence number that is printed to stderr on allocation, copy and
+// destruction.
+class TensorImpl {
+  public:
+    typedef float value_type;
+
+    // Allocates `size` elements, each initialised to `value`.
+    TensorImpl(size_t size, value_type value)
+    : data_(size, value), tno_(nextId())
+    {
+      std::cerr << "Allocating tensor " << tno_ << std::endl;
+    }
+
+    // Deep-copies the elements of `t`. The copy is given its own sequence
+    // number; previously tno_ was left uninitialised here, so the log line
+    // and id() read an indeterminate value.
+    TensorImpl(const TensorImpl& t)
+    : data_(t.data_.begin(), t.data_.end()), tno_(nextId())
+    {
+      std::cerr << "Copying tensor " << tno_ << std::endl;
+    }
+
+    ~TensorImpl() {
+      std::cerr << "Destroying tensor " << tno_ << std::endl;
+    }
+
+    // Number of stored elements.
+    size_t size() const {
+      return data_.size();
+    }
+
+    // Pointer to the first element (mutable access).
+    value_type* data() {
+      return data_.data();
+    }
+
+    // Pointer to the first element (read-only access).
+    const value_type* data() const {
+      return data_.data();
+    }
+
+    // Sequence number assigned at construction.
+    size_t id() const {
+      return tno_;
+    }
+
+    // Overwrites every element with `value`.
+    void set(value_type value) {
+      std::fill(data_.begin(), data_.end(), value);
+    }
+
+  private:
+    // Returns the next unused sequence number. The counter is a
+    // function-local static so the header needs no out-of-class definition
+    // and can be included from several translation units.
+    static size_t nextId() {
+      static size_t counter = 0;
+      return counter++;
+    }
+
+    std::vector<value_type> data_;
+    size_t tno_;
+};
+
+// Lightweight handle around a shared TensorImpl. Copies of a Tensor refer to
+// the same underlying storage.
+class Tensor {
+  public:
+    typedef TensorImpl::value_type value_type;
+
+    // Creates storage for `size` elements, all initialised to `value`.
+    Tensor(size_t size, float value)
+    : pimpl_(std::make_shared<TensorImpl>(size, value)) {}
+
+    // Creates a handle with no storage attached; the accessors below must
+    // not be called on such a handle.
+    Tensor() {}
+
+    ~Tensor() {}
+
+    // Number of elements in the underlying storage.
+    size_t size() const {
+      const TensorImpl& impl = *pimpl_;
+      return impl.size();
+    }
+
+    // Mutable pointer to the first element.
+    value_type* data() {
+      TensorImpl& impl = *pimpl_;
+      return impl.data();
+    }
+
+    // Read-only pointer to the first element.
+    const value_type* data() const {
+      const TensorImpl& impl = *pimpl_;
+      return impl.data();
+    }
+
+    // Overwrites every element with `value`.
+    void set(float value) {
+      TensorImpl& impl = *pimpl_;
+      impl.set(value);
+    }
+
+    // Sequence number of the underlying storage.
+    size_t id() const {
+      const TensorImpl& impl = *pimpl_;
+      return impl.id();
+    }
+
+  private:
+    std::shared_ptr<TensorImpl> pimpl_;
+};
+
+// Element-wise sum of `a` and `b`, returned in a freshly allocated tensor of
+// a.size() elements. `b` is read at every index below a.size().
+// Marked inline because the definition lives in a header; without it every
+// translation unit including the header would emit its own definition.
+inline Tensor operator+(const Tensor a, const Tensor b) {
+  Tensor c(a.size(), 0);
+  for(size_t i = 0; i < a.size(); ++i) {
+    c.data()[i] = a.data()[i] + b.data()[i];
+  }
+  return c;
+}
+
+// Element-wise product of `a` and `b`, returned in a freshly allocated tensor
+// of a.size() elements. `b` is read at every index below a.size().
+// Marked inline because the definition lives in a header; without it every
+// translation unit including the header would emit its own definition.
+inline Tensor operator*(const Tensor a, const Tensor b) {
+  Tensor c(a.size(), 0);
+  for(size_t i = 0; i < a.size(); ++i) {
+    c.data()[i] = a.data()[i] * b.data()[i];
+  }
+  return c;
+}
+
+// Adds `b` to `a` element by element and returns `a`. Although `a` is taken
+// by value, a Tensor copy shares its storage with the original through the
+// shared_ptr inside Tensor, so the caller's tensor is the one updated.
+// Marked inline because the definition lives in a header; without it every
+// translation unit including the header would emit its own definition.
+inline Tensor operator+=(Tensor a, const Tensor b) {
+  for(size_t i = 0; i < a.size(); ++i) {
+    a.data()[i] += b.data()[i];
+  }
+  return a;
+}
+
+} \ No newline at end of file
diff --git a/src/test.cpp b/src/test.cpp
deleted file mode 100644
index 8d1e380c..00000000
--- a/src/test.cpp
+++ /dev/null
@@ -1,55 +0,0 @@
-#include <iostream>
-#include <ctime>
-
-#include <cuda_runtime.h>
-#include <device_launch_parameters.h>
-
-#include <cublas_v2.h>
-#include <cudnn.h>
-
-#include "marian.h"
-
-using namespace marian;
-
-Var layer(size_t max, std::vector<Var>& x) {
- Var x0 = rand() % 100, x1 = rand() % 100, x2 = rand() % 100;
- x = { x0, x1, x2 };
-
- Var y = 0.0;
- for(int i = 0; i < max; i++) {
- Var xi = i;
- x.push_back(xi);
- y = y + x0 + log(x2) + x1;
- for(int j = 0; j < i; ++j) {
- y = y + xi;
- }
- }
-
- return y;
-}
-
-int main(int argc, char** argv) {
- srand(time(NULL));
-
- std::vector<Var> x1, x2;
- Var y1 = layer(10, x1);
- Var y2 = layer(rand() % 20 + 1, x2);
-
- Var y = sigma(log(y1) / log(y2));
-
- set_zero_all_adjoints();
- y.calc_gradients();
-
- std::cerr << "y1 = " << y1.val() << std::endl;
- std::cerr << "y2 = " << y2.val() << std::endl;
- std::cerr << "y = " << y.val() << std::endl;
-
- std::cerr << "dy/dy1 = " << y1.grad() << std::endl;
- std::cerr << "dy/dy2 = " << y2.grad() << std::endl;
-
- for(size_t i = 0; i < x1.size(); ++i)
- std::cerr << "x1_" << i << " = " << x1[i].val() << " : dy/dx1_" << i << " = " << x1[i].grad() << std::endl;
- for(size_t i = 0; i < x2.size(); ++i)
- std::cerr << "x2_" << i << " = " << x2[i].val() << " : dy/dx2_" << i << " = " << x2[i].grad() << std::endl;
-
-} \ No newline at end of file
diff --git a/src/test.cu b/src/test.cu
new file mode 100644
index 00000000..db57cdc4
--- /dev/null
+++ b/src/test.cu
@@ -0,0 +1,71 @@
+#include <iostream>
+#include <ctime>
+#include <vector>
+#include <algorithm>
+#include <random>
+#include <boost/timer/timer.hpp>
+
+#include "marian.h"
+#include "operators.h"
+
+using namespace marian;
+
+int main(int argc, char** argv) { // Builds a small two-layer network on four hard-coded samples and trains it with plain gradient descent, logging progress to stderr.
+ boost::timer::auto_cpu_timer t; // Boost.Timer object; presumably reports the elapsed time when main returns.
+ // Graph inputs: X and Y are both 4x2.
+ Var x = input("X", Tensor({4, 2}));
+ Var y = input("Y", Tensor({4, 2}));
+ // Input values, two per sample: all four combinations of 0 and 1.
+ std::vector<float> vx = {
+ 0, 0,
+ 0, 1,
+ 1, 0,
+ 1, 1
+ };
+ // Target values, two per sample, exactly one of them set to 1.
+ std::vector<float> vy = {
+ 1, 0,
+ 1, 0,
+ 0, 1,
+ 1, 0
+ };
+ // Copy the host vectors into the tensors held by x and y.
+ thrust::copy(vx.begin(), vx.end(), x.val().begin());
+ thrust::copy(vy.begin(), vy.end(), y.val().begin());
+ // First layer parameters (uniform() is defined elsewhere; presumably random initialisation).
+ Var w0 = forsave("W0", uniform(Tensor({2, 2})));
+ Var b0 = forsave("b0", uniform(Tensor({1, 2})));
+ // Second layer parameters.
+ Var w1 = forsave("W1", uniform(Tensor({2, 2})));
+ Var b1 = forsave("b1", uniform(Tensor({1, 2})));
+ // Everything updated by the training loop below.
+ std::vector<Var> params = { w0, w1, b0, b1 };
+ // Model: sigma(x*w0 + b0) followed by softmax(.*w1 + b1) along axis1.
+ Var ry = sigma(dot(x, w0) + b0);
+ ry = softmax(dot(ry, w1) + b1, Axis::axis1);
+ Var cost = -mean(sum(y * log(ry), Axis::axis1), Axis::axis0); // Negated mean over axis0 of the axis1 sums of y*log(ry): a cross-entropy style cost.
+ // Step size used for every parameter update.
+ float alpha = 0.1;
+ for(size_t i = 0; i < 30000; ++i) {
+ cost.forward();
+ // Every 100 iterations print elements 0, 2, 4, 6 of ry and the current cost.
+ if(i % 100 == 0) {
+ for(size_t j = 0; j < 4; ++j) {
+ std::cerr << ry.val()[j*2] << std::endl; // presumably the first output of each sample, assuming row-major layout
+ }
+ std::cerr << i << " ct: " << cost.val()[0] << std::endl;
+ // alpha = alpha * 0.9;
+ }
+ // Backward pass, then one gradient-descent step per parameter.
+ cost.backward();
+ for(auto p : params) {
+ //std::cerr << p.grad()[0] << std::endl;
+ auto update =
+ _1 -= alpha * _2; // value -= alpha * gradient
+
+ Element(update, p.val(), p.grad());
+ }
+ }
+
+ return 0;
+} \ No newline at end of file
diff --git a/src/thrust_functions.h b/src/thrust_functions.h
new file mode 100644
index 00000000..a3013423
--- /dev/null
+++ b/src/thrust_functions.h
@@ -0,0 +1,95 @@
+#pragma once
+
+#include <cmath>
+#include <cublas_v2.h>
+#include <thrust/device_vector.h>
+#include <thrust/functional.h>
+
+namespace thrust
+{
+ namespace detail
+ {
+ namespace functional
+ {
+
+ // Ugly hacks, but it seems this is necessary.
+ // Exponential with the argument clamped to [-16, 16] before expf is called.
+ // Declared inline because the function is defined in a header; otherwise
+ // each translation unit including it would emit a separate definition.
+ __host__ __device__
+ inline float expf2(float x) {
+   const float clip = 16;
+   if(x > clip)
+     x = clip;
+   if(x < -clip)
+     x = -clip;
+   return expf(x);
+ }
+
+ // Natural logarithm with the argument raised to at least 1e-9 first, so
+ // zero and negative inputs give a finite result. The original bound was
+ // written as the double literal 10e-10 (which equals 1e-9); the float
+ // literal keeps the same value without a double-precision comparison.
+ // Declared inline because the function is defined in a header.
+ __host__ __device__
+ inline float logf2(float x) {
+   const float lowerBound = 1e-9f;
+   if(x < lowerBound)
+     x = lowerBound;
+   return logf(x);
+ }
+
+ template<typename T>
+ struct unary_exp : public thrust::unary_function<T,T> { // Functor returning expf2(x), the clipped exponential.
+ __host__ __device__
+ T operator()(const T &x) const { return expf2(x); }
+ };
+
+ template<typename Eval>
+ __host__ __device__
+ actor<composite<unary_operator<unary_exp>, actor<Eval>>>
+ Exp(const actor<Eval> &_1) { // Wraps the placeholder expression `_1` in a composite actor that applies unary_exp to it.
+ return compose(unary_operator<unary_exp>(), _1);
+ }
+
+ template<typename T>
+ struct unary_log : public thrust::unary_function<T,T> { // Functor returning logf2(x), the logarithm with a lower-bounded argument.
+ __host__ __device__
+ T operator()(const T &x) const { return logf2(x); }
+ };
+
+ template<typename Eval>
+ __host__ __device__
+ actor<composite<unary_operator<unary_log>, actor<Eval>>>
+ Log(const actor<Eval> &_1) { // Wraps the placeholder expression `_1` in a composite actor that applies unary_log to it.
+ return compose(unary_operator<unary_log>(), _1);
+ }
+
+ // Logistic sigmoid functor: 1 / (1 + exp(-x)), with the exponential taken
+ // through the clipped expf2. The constant is spelled T(1) so that for
+ // T = float the expression is evaluated in single precision rather than
+ // being promoted to double by a 1.0 literal.
+ template<typename T>
+ struct unary_sigma : public thrust::unary_function<T,T> {
+   __host__ __device__
+   T operator()(const T &x) const { return T(1) / (T(1) + expf2(-x)); }
+ };
+
+ template<typename Eval>
+ __host__ __device__
+ actor<composite<unary_operator<unary_sigma>, actor<Eval>>>
+ Sigma(const actor<Eval> &_1) { // Wraps the placeholder expression `_1` in a composite actor that applies unary_sigma to it.
+ return compose(unary_operator<unary_sigma>(), _1);
+ }
+
+ template<typename T>
+ struct unary_tanh : public thrust::unary_function<T,T> { // Functor returning tanhf(x).
+ __host__ __device__
+ T operator()(const T &x) const { return tanhf(x); }
+ };
+
+ template<typename Eval>
+ __host__ __device__
+ actor<composite<unary_operator<unary_tanh>, actor<Eval>>>
+ Tanh(const actor<Eval> &_1) { // Wraps the placeholder expression `_1` in a composite actor that applies unary_tanh to it.
+ return compose(unary_operator<unary_tanh>(), _1);
+ }
+
+ template<typename T1, typename T2>
+ __host__ __device__
+ actor<composite<binary_operator<thrust::maximum>, actor<T1>, actor<T2>>>
+ Max(const actor<T1> &_1, const actor<T2> &_2) { // Combines the two placeholder expressions into a composite actor that applies thrust::maximum to their results.
+ return compose(binary_operator<thrust::maximum>(),
+ make_actor(_1),
+ make_actor(_2));
+ }
+ }
+ }
+} \ No newline at end of file