Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
diff options
author Marcin Junczys-Dowmunt <junczys@amu.edu.pl> 2016-04-28 21:00:43 +0300
committer Marcin Junczys-Dowmunt <junczys@amu.edu.pl> 2016-04-28 21:00:43 +0300
commit 7e76b61f88fcce9f1ef1c704ff646326ffebf43d (patch)
parent 273487f496665c29af072bafcacd2a1773b31f8b (diff)
Towards YAML configurations
11 files changed, 379 insertions, 157 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1effd778..5823aa41 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -18,6 +18,13 @@ else(Boost_FOUND)
message(SEND_ERROR "Cannot find Boost libraries. Terminating." )
+find_package (YamlCpp)
+ include_directories(${YAMLCPP_INCLUDE_DIRS})
set(KENLM CACHE STRING "Path to compiled kenlm directory")
if (NOT EXISTS "${KENLM}/build/lib/libkenlm.a")
message(FATAL_ERROR "Could not find ${KENLM}/build/lib/libkenlm.a")
diff --git a/cmake/FindYamlCpp.cmake b/cmake/FindYamlCpp.cmake
new file mode 100644
index 00000000..c099afae
--- /dev/null
+++ b/cmake/FindYamlCpp.cmake
@@ -0,0 +1,98 @@
+# Locate yaml-cpp
+# This module defines
+# YAMLCPP_FOUND, if false, do not try to link to yaml-cpp
+# YAMLCPP_LIBNAME, name of yaml library
+# YAMLCPP_LIBRARY, where to find yaml-cpp
+# YAMLCPP_LIBRARY_RELEASE, where to find Release or RelWithDebInfo yaml-cpp
+# YAMLCPP_LIBRARY_DEBUG, where to find Debug yaml-cpp
+# YAMLCPP_INCLUDE_DIR, where to find yaml.h
+# YAMLCPP_LIBRARY_DIR, the directories to find YAMLCPP_LIBRARY
+# By default, the dynamic libraries of yaml-cpp will be found. To find the static ones instead,
+# you must set the YAMLCPP_USE_STATIC_LIBS variable to TRUE before calling find_package(YamlCpp ...)
+# attempt to find static library first if this is set
+ set(YAMLCPP_STATIC libyaml-cpp.a)
+ set(YAMLCPP_STATIC_DEBUG libyaml-cpp-dbg.a)
+if(${CMAKE_SYSTEM_NAME} MATCHES "Windows") ### Set Yaml library name for Windows
+ set(YAMLCPP_LIBNAME "libyaml-cppmd" CACHE STRING "Name of YAML library")
+else() ### Set Yaml library name for Unix, Linux, OS X, etc
+ set(YAMLCPP_LIBNAME "yaml-cpp" CACHE STRING "Name of YAML library")
+# find the yaml-cpp include directory
+ NAMES yaml-cpp/yaml.h
+ ${PROJECT_SOURCE_DIR}/dependencies/yaml-cpp-0.5.1/include
+ ~/Library/Frameworks/yaml-cpp/include/
+ /Library/Frameworks/yaml-cpp/include/
+ /usr/local/include/
+ /usr/include/
+ /sw/yaml-cpp/ # Fink
+ /opt/local/yaml-cpp/ # DarwinPorts
+ /opt/csw/yaml-cpp/ # Blastwave
+ /opt/yaml-cpp/)
+# find the release yaml-cpp library
+ NAMES ${YAMLCPP_STATIC} yaml-cpp libyaml-cppmd.lib
+ PATH_SUFFIXES lib64 lib Release RelWithDebInfo
+ ${PROJECT_SOURCE_DIR}/dependencies/yaml-cpp-0.5.1/
+ ${PROJECT_SOURCE_DIR}/dependencies/yaml-cpp-0.5.1/build
+ ~/Library/Frameworks
+ /Library/Frameworks
+ /usr/local
+ /usr
+ /sw
+ /opt/local
+ /opt/csw
+ /opt)
+# find the debug yaml-cpp library
+ NAMES ${YAMLCPP_STATIC_DEBUG} yaml-cpp-dbg libyaml-cppmdd.lib
+ PATH_SUFFIXES lib64 lib Debug
+ ${PROJECT_SOURCE_DIR}/dependencies/yaml-cpp-0.5.1/
+ ${PROJECT_SOURCE_DIR}/dependencies/yaml-cpp-0.5.1/build
+ ~/Library/Frameworks
+ /Library/Frameworks
+ /usr/local
+ /usr
+ /sw
+ /opt/local
+ /opt/csw
+ /opt)
+# set library vars
+# handle the QUIETLY and REQUIRED arguments and set YAMLCPP_FOUND to TRUE if all listed variables are TRUE
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 4578e5e5..7b2e4c42 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -15,6 +15,7 @@ add_library(librescorer OBJECT
add_library(libamunn OBJECT
+ decoder/config.cpp
@@ -39,7 +40,7 @@ cuda_add_executable(
foreach(exec amunn rescorer)
- target_link_libraries(${exec} ${EXT_LIBS})
+ target_link_libraries(${exec} ${EXT_LIBS} cuda)
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
diff --git a/src/decoder/config.cpp b/src/decoder/config.cpp
new file mode 100644
index 00000000..f67ba428
--- /dev/null
+++ b/src/decoder/config.cpp
@@ -0,0 +1,173 @@
+#include <set>
+#include "config.h"
+#define SET_OPTION(key, type) \
+if(!vm_[key].defaulted() || !config_[key]) { \
+ config_[key] = vm_[key].as<type>(); \
+#define SET_OPTION_NONDEFAULT(key, type) \
+if(vm_.count(key) > 0) { \
+ config_[key] = vm_[key].as<type>(); \
+bool Config::Has(const std::string& key) {
+ return config_[key];
+YAML::Node& Config::Get() {
+ return config_;
+void Config::AddOptions(size_t argc, char** argv) {
+ namespace po = boost::program_options;
+ po::options_description general("General options");
+ std::string configPath;
+ std::vector<size_t> devices;
+ std::vector<size_t> tabMap;
+ std::vector<float> weights;
+ std::vector<std::string> modelPaths;
+ std::vector<std::string> lmPaths;
+ std::vector<std::string> sourceVocabPaths;
+ std::string targetVocabPath;
+ general.add_options()
+ ("config,c", po::value(&configPath),
+ "Configuration file")
+ ("model,m", po::value(&modelPaths)->multitoken()->required(),
+ "Path to neural translation model(s)")
+ ("source,s", po::value(&sourceVocabPaths)->multitoken()->required(),
+ "Path to source vocabulary file.")
+ ("target,t", po::value(&targetVocabPath)->required(),
+ "Path to target vocabulary file.")
+ ("ape", po::value<bool>()->zero_tokens()->default_value(false),
+ "Add APE-penalty")
+ ("lm,l", po::value(&lmPaths)->multitoken(),
+ "Path to KenLM language model(s)")
+ ("tab-map", po::value(&tabMap)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
+ "tab map")
+ ("devices,d", po::value(&devices)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
+ "CUDA device(s) to use, set to 0 by default, "
+ "e.g. set to 0 1 to use gpu0 and gpu1. "
+ "Implicitly sets minimal number of threads to number of devices.")
+ ("threads-per-device", po::value<size_t>()->default_value(1),
+ "Number of threads per device, total thread count equals threads x devices")
+ ("help,h", po::value<bool>()->zero_tokens()->default_value(false),
+ "Print this help message and exit")
+ ;
+ po::options_description search("Search options");
+ search.add_options()
+ ("beam-size,b", po::value<size_t>()->default_value(12),
+ "Decoding beam-size")
+ ("normalize,n", po::value<bool>()->zero_tokens()->default_value(false),
+ "Normalize scores by translation length after decoding")
+ ("n-best", po::value<bool>()->zero_tokens()->default_value(false),
+ "Output n-best list with n = beam-size")
+ ("weights,w", po::value(&weights)->multitoken()->default_value(std::vector<float>(1, 1.0), "1.0"),
+ "Model weights (for neural models and KenLM models)")
+ ("show-weights", po::value<bool>()->zero_tokens()->default_value(false),
+ "Output used weights to stdout and exit")
+ ("load-weights", po::value<std::string>(),
+ "Load scorer weights from this file")
+ ;
+ po::options_description kenlm("KenLM specific options");
+ kenlm.add_options()
+ ("kenlm-batch-size", po::value<size_t>()->default_value(1000),
+ "Batch size for batched queries to KenLM")
+ ("kenlm-batch-threads", po::value<size_t>()->default_value(4),
+ "Concurrent worker threads for batch processing")
+ ;
+ po::options_description cmdline_options("Allowed options");
+ cmdline_options.add(general);
+ cmdline_options.add(search);
+ cmdline_options.add(kenlm);
+ po::variables_map vm_;
+ try {
+ po::store(po::command_line_parser(argc,argv)
+ .options(cmdline_options).run(), vm_);
+ po::notify(vm_);
+ }
+ catch (std::exception& e) {
+ std::cerr << "Error: " << e.what() << std::endl << std::endl;
+ std::cerr << "Usage: " + std::string(argv[0]) + " [options]" << std::endl;
+ std::cerr << cmdline_options << std::endl;
+ exit(1);
+ }
+ if (vm_["help"].as<bool>()) {
+ std::cerr << "Usage: " + std::string(argv[0]) + " [options]" << std::endl;
+ std::cerr << cmdline_options << std::endl;
+ exit(0);
+ }
+ if(configPath.size())
+ config_ = YAML::LoadFile(configPath);
+ SET_OPTION("model", std::vector<std::string>)
+ SET_OPTION_NONDEFAULT("lm", std::vector<std::string>)
+ SET_OPTION("ape", bool)
+ SET_OPTION("source", std::vector<std::string>)
+ SET_OPTION("target", std::string)
+ SET_OPTION("n-best", bool)
+ SET_OPTION("normalize", bool)
+ SET_OPTION("beam-size", size_t)
+ SET_OPTION("threads-per-device", size_t)
+ SET_OPTION("devices", std::vector<size_t>)
+ SET_OPTION("tab-map", std::vector<size_t>)
+ SET_OPTION("weights", std::vector<float>)
+ SET_OPTION("show-weights", bool)
+ SET_OPTION_NONDEFAULT("load-weights", std::string)
+ SET_OPTION("kenlm-batch-size", size_t)
+ SET_OPTION("kenlm-batch-threads", size_t)
+void OutputRec(const YAML::Node node, YAML::Emitter& out) {
+ std::set<std::string> flow = { "weights", "devices", "tab-map" };
+ std::set<std::string> sorter;
+ switch (node.Type()) {
+ case YAML::NodeType::Null:
+ out << node; break;
+ case YAML::NodeType::Scalar:
+ out << node; break;
+ case YAML::NodeType::Sequence:
+ out << YAML::BeginSeq;
+ for(auto&& n : node)
+ OutputRec(n, out);
+ out << YAML::EndSeq;
+ break;
+ case YAML::NodeType::Map:
+ for(auto& n : node)
+ sorter.insert(n.first.as<std::string>());
+ out << YAML::BeginMap;
+ for(auto& key : sorter) {
+ out << YAML::Key;
+ out << key;
+ out << YAML::Value;
+ if(flow.count(key))
+ out << YAML::Flow;
+ OutputRec(node[key], out);
+ }
+ out << YAML::EndMap;
+ break;
+ case YAML::NodeType::Undefined:
+ out << node; break;
+ }
+void Config::LogOptions() {
+ std::stringstream ss;
+ YAML::Emitter out;
+ OutputRec(config_, out);
+ LOG(info) << "Options: \n" << out.c_str();
diff --git a/src/decoder/config.h b/src/decoder/config.h
new file mode 100644
index 00000000..2a973ff9
--- /dev/null
+++ b/src/decoder/config.h
@@ -0,0 +1,31 @@
+#pragma once
+#include <yaml-cpp/yaml.h>
+#include <boost/program_options.hpp>
+#include "logging.h"
+class Config {
+ private:
+ YAML::Node config_;
+ public:
+ bool Has(const std::string& key);
+ template <typename T>
+ T Get(const std::string& key) {
+ return config_[key].as<T>();
+ }
+ YAML::Node& Get();
+ void AddOptions(size_t argc, char** argv);
+ template <class OStream>
+ friend OStream& operator<<(OStream& out, const Config& config) {
+ out << config.config_;
+ return out;
+ }
+ void LogOptions();
diff --git a/src/decoder/god.cu b/src/decoder/god.cu
index 28b49e0e..ce8323dc 100644
--- a/src/decoder/god.cu
+++ b/src/decoder/god.cu
@@ -1,7 +1,10 @@
#include <vector>
#include <sstream>
+#include <yaml-cpp/yaml.h>
#include "god.h"
+#include "config.h"
#include "scorer.h"
#include "threadpool.h"
#include "encoder_decoder.h"
@@ -10,16 +13,6 @@
God God::instance_;
-God& God::Init(const std::string& initString) {
- std::vector<std::string> args = po::split_unix(initString);
- int argc = args.size() + 1;
- char* argv[argc];
- argv[0] = const_cast<char*>("dummy");
- for(int i = 1; i < argc; i++)
- argv[i] = const_cast<char*>(args[i-1].c_str());
- return Init(argc, argv);
God& God::Init(int argc, char** argv) {
return Summon().NonStaticInit(argc, argv);
@@ -31,103 +24,24 @@ God& God::NonStaticInit(int argc, char** argv) {
progress_ = spdlog::stderr_logger_mt("progress");
- po::options_description general("General options");
- std::vector<size_t> devices;
- std::vector<std::string> modelPaths;
- std::vector<std::string> lmPaths;
- std::vector<std::string> sourceVocabPaths;
- std::string targetVocabPath;
- general.add_options()
- ("model,m", po::value(&modelPaths)->multitoken()->required(),
- "Path to neural translation model(s)")
- ("source,s", po::value(&sourceVocabPaths)->multitoken()->required(),
- "Path to source vocabulary file.")
- ("target,t", po::value(&targetVocabPath)->required(),
- "Path to target vocabulary file.")
- ("ape", po::value<bool>()->zero_tokens()->default_value(false),
- "Add APE-penalty")
- ("lm,l", po::value(&lmPaths)->multitoken(),
- "Path to KenLM language model(s)")
- ("tab-map", po::value(&tabMap_)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
- "tab map")
- ("devices,d", po::value(&devices)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
- "CUDA device(s) to use, set to 0 by default, "
- "e.g. set to 0 1 to use gpu0 and gpu1. "
- "Implicitly sets minimal number of threads to number of devices.")
- ("threads-per-device", po::value<size_t>()->default_value(1),
- "Number of threads per device, total thread count equals threads x devices")
- ("help,h", po::value<bool>()->zero_tokens()->default_value(false),
- "Print this help message and exit")
- ;
- po::options_description search("Search options");
- search.add_options()
- ("beam-size,b", po::value<size_t>()->default_value(12),
- "Decoding beam-size")
- ("normalize,n", po::value<bool>()->zero_tokens()->default_value(false),
- "Normalize scores by translation length after decoding")
- ("n-best", po::value<bool>()->zero_tokens()->default_value(false),
- "Output n-best list with n = beam-size")
- ("weights,w", po::value(&weights_)->multitoken()->default_value(std::vector<float>(1, 1.0), "1.0"),
- "Model weights (for neural models and KenLM models)")
- ("show-weights", po::value<bool>()->zero_tokens()->default_value(false),
- "Output used weights to stdout and exit")
- ("load-weights", po::value<std::string>(),
- "Load scorer weights from this file")
- ;
- po::options_description kenlm("KenLM specific options");
- kenlm.add_options()
- ("kenlm-batch-size", po::value<size_t>()->default_value(1000),
- "Batch size for batched queries to KenLM")
- ("kenlm-batch-threads", po::value<size_t>()->default_value(4),
- "Concurrent worker threads for batch processing")
- ;
- po::options_description cmdline_options("Allowed options");
- cmdline_options.add(general);
- cmdline_options.add(search);
- cmdline_options.add(kenlm);
- try {
- po::store(po::command_line_parser(argc,argv)
- .options(cmdline_options).run(), vm_);
- po::notify(vm_);
- }
- catch (std::exception& e) {
- std::cerr << "Error: " << e.what() << std::endl << std::endl;
- std::cerr << "Usage: " + std::string(argv[0]) + " [options]" << std::endl;
- std::cerr << cmdline_options << std::endl;
- exit(1);
- }
- if (Get<bool>("help")) {
- std::cerr << "Usage: " + std::string(argv[0]) + " [options]" << std::endl;
- std::cerr << cmdline_options << std::endl;
- exit(0);
- }
- PrintConfig();
+ config_.AddOptions(argc, argv);
+ config_.LogOptions();
- for(auto& sourceVocabPath : sourceVocabPaths)
+ for(auto sourceVocabPath : Get<std::vector<std::string>>("source"))
sourceVocabs_.emplace_back(new Vocab(sourceVocabPath));
- targetVocab_.reset(new Vocab(targetVocabPath));
- if(devices.empty()) {
- LOG(info) << "empty";
- devices.push_back(0);
- }
+ targetVocab_.reset(new Vocab(Get<std::string>("target")));
+ auto modelPaths = Get<std::vector<std::string>>("model");
+ tabMap_ = Get<std::vector<size_t>>("tab-map");
if(tabMap_.size() < modelPaths.size()) {
// this should be a warning
- LOG(info) << "More neural models than weights, setting missing tabs to 0";
+ LOG(info) << "More neural models than tabs, setting missing tabs to 0";
tabMap_.resize(modelPaths.size(), 0);
// @TODO: handle this better!
+ weights_ = Get<std::vector<float>>("weights");
if(weights_.size() < modelPaths.size()) {
// this should be a warning
LOG(info) << "More neural models than weights, setting weights to 1.0";
@@ -139,11 +53,11 @@ God& God::NonStaticInit(int argc, char** argv) {
weights_.resize(modelPaths.size(), 1.0);
- if(weights_.size() < modelPaths.size() + lmPaths.size()) {
- // this should be a warning
- LOG(info) << "More KenLM models than weights, setting weights to 0.0";
- weights_.resize(weights_.size() + lmPaths.size(), 0.0);
- }
+ //if(weights_.size() < modelPaths.size() + lmPaths.size()) {
+ // // this should be a warning
+ // LOG(info) << "More KenLM models than weights, setting weights to 0.0";
+ // weights_.resize(weights_.size() + lmPaths.size(), 0.0);
+ //}
if(Has("load-weights")) {
@@ -157,6 +71,7 @@ God& God::NonStaticInit(int argc, char** argv) {
+ auto devices = Get<std::vector<size_t>>("devices");
ThreadPool devicePool(devices.size());
@@ -171,10 +86,10 @@ God& God::NonStaticInit(int argc, char** argv) {
- for(auto& lmPath : lmPaths) {
- LOG(info) << "Loading lm " << lmPath;
- lms_.emplace_back(lmPath, *targetVocab_);
- }
+ //for(auto& lmPath : lmPaths) {
+ // LOG(info) << "Loading lm " << lmPath;
+ // lms_.emplace_back(lmPath, *targetVocab_);
+ //}
return *this;
@@ -230,34 +145,3 @@ void God::LoadWeights(const std::string& path) {
-void God::PrintConfig() {
- LOG(info) << "Options set: ";
- for(auto& entry: instance_.vm_) {
- std::stringstream ss;
- ss << "\t" << entry.first << " = ";
- try {
- for(auto& v : entry.second.as<std::vector<std::string>>())
- ss << v << " ";
- } catch(...) { }
- try {
- for(auto& v : entry.second.as<std::vector<float>>())
- ss << v << " ";
- } catch(...) { }
- try {
- for(auto& v : entry.second.as<std::vector<size_t>>())
- ss << v << " ";
- } catch(...) { }
- try {
- ss << entry.second.as<std::string>();
- } catch(...) { }
- try {
- ss << entry.second.as<bool>() ? "true" : "false";
- } catch(...) { }
- try {
- ss << entry.second.as<size_t>();
- } catch(...) { }
- LOG(info) << ss.str();
- }
-} \ No newline at end of file
diff --git a/src/decoder/god.h b/src/decoder/god.h
index 517b5d75..714075ef 100644
--- a/src/decoder/god.h
+++ b/src/decoder/god.h
@@ -1,7 +1,7 @@
#pragma once
-#include <boost/program_options.hpp>
+#include "config.h"
#include "types.h"
#include "vocab.h"
#include "scorer.h"
@@ -9,8 +9,6 @@
// this should not be here
#include "kenlm.h"
-namespace po = boost::program_options;
class Weights;
@@ -25,12 +23,12 @@ class God {
static bool Has(const std::string& key) {
- return instance_.vm_.count(key) > 0;
+ return Summon().config_.Has(key);
template <typename T>
static T Get(const std::string& key) {
- return instance_.vm_[key].as<T>();
+ return Summon().config_.Get<T>(key);
static Vocab& GetSourceVocab(size_t i = 0);
@@ -40,7 +38,6 @@ class God {
static std::vector<size_t>& GetTabMap();
static void CleanUp();
- static void PrintConfig();
void LoadWeights(const std::string& path);
@@ -48,7 +45,7 @@ class God {
God& NonStaticInit(int argc, char** argv);
static God instance_;
- po::variables_map vm_;
+ Config config_;
std::vector<std::unique_ptr<Vocab>> sourceVocabs_;
std::unique_ptr<Vocab> targetVocab_;
diff --git a/src/dl4mt/encoder.h b/src/dl4mt/encoder.h
index a95853d3..0484518b 100644
--- a/src/dl4mt/encoder.h
+++ b/src/dl4mt/encoder.h
@@ -15,7 +15,10 @@ class Encoder {
void Lookup(mblas::Matrix& Row, size_t i) {
using namespace mblas;
- CopyRow(Row, w_.E_, i);
+ if(i < w_.E_.Rows())
+ CopyRow(Row, w_.E_, i);
+ else
+ CopyRow(Row, w_.E_, 1); // UNK
diff --git a/src/dl4mt/gru.h b/src/dl4mt/gru.h
index 5986ca48..dcb4d240 100644
--- a/src/dl4mt/gru.h
+++ b/src/dl4mt/gru.h
@@ -13,7 +13,7 @@ class SlowGRU {
const mblas::Matrix& Context) const {
using namespace mblas;
- const size_t cols = State.Cols();
+ const size_t cols = GetStateLength();
// @TODO: Optimization
// @TODO: Launch streams to perform GEMMs in parallel
@@ -76,11 +76,11 @@ class FastGRU {
FastGRU(const Weights& model)
: w_(model) {
- for(int i = 0; i < 4; ++i) {
+ /*for(int i = 0; i < 4; ++i) {
cublasSetStream(h_[i], s_[i]);
- }
+ }*/
void GetNextState(mblas::Matrix& NextState,
@@ -88,7 +88,7 @@ class FastGRU {
const mblas::Matrix& Context) const {
using namespace mblas;
- const size_t cols = State.Cols();
+ const size_t cols = GetStateLength();
// @TODO: Optimization
// @TODO: Launch streams to perform GEMMs in parallel
diff --git a/src/mblas/matrix.cu b/src/mblas/matrix.cu
index 11b6beb6..91b8a95c 100644
--- a/src/mblas/matrix.cu
+++ b/src/mblas/matrix.cu
@@ -171,28 +171,53 @@ Matrix& Prod(cublasHandle_t handle, Matrix& C, const Matrix& A, const Matrix& B,
Matrix::value_type alpha = 1.0;
Matrix::value_type beta = 0.0;
+ //size_t m = A.Rows();
+ //size_t k = A.Cols();
+ ////if(transA)
+ //// std::swap(m, k);
+ //
+ //size_t l = B.Rows();
+ //size_t n = B.Cols();
+ ////if(transB)
+ //// std::swap(l, n);
+ //
+ //C.Resize(m, n);
+ //
+ //size_t lda = A.Cols();
+ //size_t ldb = B.Cols();
+ //size_t ldc = C.Cols();
+ //
+ //nervana_sgemm(const_cast<float*>(A.data()),
+ // const_cast<float*>(B.data()),
+ // C.data(),
+ // transA, transB,
+ // m, n, k,
+ // lda, ldb, ldc,
+ // alpha, beta,
+ // 0, false, false, 0);
size_t m = A.Rows();
size_t k = A.Cols();
std::swap(m, k);
size_t l = B.Rows();
size_t n = B.Cols();
std::swap(l, n);
size_t lda = A.Cols();
size_t ldb = B.Cols();
size_t ldc = B.Cols();
ldc = B.Rows();
C.Resize(m, n);
cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasSgemm(handle, opB, opA,
n, m, k, &alpha, B.data(), ldb, A.data(), lda, &beta, C.data(), ldc);
return C;
diff --git a/src/mblas/matrix.h b/src/mblas/matrix.h
index 6159d7ad..9cd601ec 100644
--- a/src/mblas/matrix.h
+++ b/src/mblas/matrix.h
@@ -11,6 +11,9 @@
#include <thrust/device_vector.h>
#include <thrust/functional.h>
+//#include "nervana_c_api.h"
#include "thrust_functions.h"
namespace lib = thrust;