Welcome to mirror list, hosted at ThFree Co, Russian Federation.

corpus_sqlite.cpp « data « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 297847c04a6b52e77876592160f4e59853ef3d02 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#include <random>

#include "data/corpus_sqlite.h"

namespace marian {
namespace data {

// Builds the corpus from `options` via CorpusBase, stores the seed that
// shuffle()/restore() pass to the random_seed() SQL function, and then creates
// (or reuses) the SQLite line store through fillSQLite().
// NOTE(review): fillSQLite() reads files_, which is expected to be set up by
// CorpusBase — confirm against corpus_base.
CorpusSQLite::CorpusSQLite(Ptr<Options> options, bool translate /*= false*/, size_t seed /*= Config::seed*/)
    : CorpusBase(options, translate, seed),
      seed_(seed) {
  this->fillSQLite();
}

// Builds the corpus from explicit input `paths` and their `vocabs` (handed to
// CorpusBase), stores the shuffling seed, and then creates or reuses the
// SQLite line store through fillSQLite().
CorpusSQLite::CorpusSQLite(const std::vector<std::string>& paths,
                           const std::vector<Ptr<Vocab>>& vocabs,
                           Ptr<Options> options, size_t seed)
    : CorpusBase(paths, vocabs, options, seed),
      seed_(seed) {
  this->fillSQLite();
}

// Returns the PRAGMA statement that points SQLite's temporary files at `dir`.
// PRAGMA values cannot be bound as statement parameters, so the directory is
// embedded as an SQL string literal. Every single quote is doubled (SQL's
// escape for a quote inside '...'), so a path such as "/tmp/o'brien" cannot
// break or alter the statement.
// NOTE(review): SQLite documents temp_store_directory as deprecated; it is
// kept here to preserve existing behaviour.
static std::string tempStoreDirectoryPragma(const std::string& dir) {
  std::string quoted;
  quoted.reserve(dir.size());
  for(char c : dir) {
    quoted += c;
    if(c == '\'')
      quoted += '\'';
  }
  return "PRAGMA temp_store_directory = '" + quoted + "';";
}

// Opens the SQLite database selected by the "sqlite" option and, if needed,
// fills the `lines` table from the input files.
//
// Database selection:
//   - "temporary": an unnamed database; the table is always filled.
//   - an existing path: the database is reused. It is refilled only when
//     "sqlite-drop" is set.
//   - a new path: the database is created and filled.
//
// Table layout: column `_id` holds the 0-based line number, and column
// `line<i>` holds the text read from files_[i]. Filling stops at the first
// input that runs out of lines. createRandomFunction() is always called at
// the end.
void CorpusSQLite::fillSQLite() {
  const auto tempDir = options_->get<std::string>("tempdir");
  const auto dbPath = options_->get<std::string>("sqlite");
  bool fill = false;

  // Opens a database handle and sends its temporary files to tempDir.
  auto openDatabase = [&](const std::string& path, int flags) {
    db_.reset(new SQLite::Database(path, flags));
    db_->exec(tempStoreDirectoryPragma(tempDir));
  };

  // create a temporary or persistent SQLite database
  if(dbPath == "temporary") {
    LOG(info, "[sqlite] Creating temporary database in {}", tempDir);
    // Per the sqlite3_open documentation, an empty filename yields a private,
    // temporary on-disk database.
    openDatabase("", SQLite::OPEN_READWRITE | SQLite::OPEN_CREATE);
    fill = true;
  } else if(filesystem::exists(dbPath)) {
    LOG(info, "[sqlite] Reusing persistent database {}", dbPath);
    openDatabase(dbPath, SQLite::OPEN_READWRITE);

    if(options_->get<bool>("sqlite-drop")) {
      LOG(info, "[sqlite] Dropping previous data");
      db_->exec("drop table if exists lines");
      fill = true;
    }
  } else {
    LOG(info, "[sqlite] Creating persistent database {}", dbPath);
    openDatabase(dbPath, SQLite::OPEN_READWRITE | SQLite::OPEN_CREATE);
    fill = true;
  }

  // populate tables with lines from text files
  if(fill) {
    std::string createStr = "create table lines (_id integer";
    std::string insertStr = "insert into lines values (?";
    for(size_t i = 0; i < files_.size(); ++i) {
      createStr += ", line" + std::to_string(i) + " text";
      insertStr += ", ?";
    }
    createStr += ");";
    insertStr += ");";

    db_->exec(createStr);

    SQLite::Statement ps(*db_, insertStr);

    // size_t counters: an int `report` that keeps doubling would overflow
    // (signed overflow is UB) on very large corpora.
    size_t inserted = 0;
    size_t report = 1000000;
    std::string line;

    db_->exec("begin;");
    // With no input files nothing would ever end the loop, so insert nothing.
    bool more = !files_.empty();
    while(more) {
      // Read one line from every input; stop at the shortest input.
      for(size_t i = 0; more && i < files_.size(); ++i) {
        if(io::getline(*files_[i], line))
          ps.bind(static_cast<int>(i + 2), line);  // parameter 1 is _id
        else
          more = false;
      }
      if(!more)
        break;

      ps.bind(1, static_cast<int>(inserted));
      ps.exec();
      ps.reset();
      ++inserted;

      // Commit in exponentially growing batches. Only rows actually
      // inserted are counted, so the progress log is never off by one.
      if(inserted % report == 0) {
        LOG(info, "[sqlite] Inserted {} lines", inserted);
        db_->exec("commit;");
        db_->exec("begin;");
        report *= 2;
      }
    }
    db_->exec("commit;");
    LOG(info, "[sqlite] Inserted {} lines", inserted);
    LOG(info, "[sqlite] Creating primary index");
    db_->exec("create unique index idx_line on lines (_id);");
  }

  createRandomFunction();
}

// Returns the next row of the current select_ result set as a SentenceTuple.
// Rows are skipped when any of their word sequences is empty or longer than
// maxLength_. Returns SentenceTuple(0) once the result set is exhausted.
// Column 0 is the row id; column i+1 holds the data for files_[i]. File 0 is
// always read as words. Later files may instead be the alignment stream
// (alignFileIdx_) or the weight stream (weightFileIdx_).
SentenceTuple CorpusSQLite::next() {
  auto fitsLengthLimit = [this](const Words& words) {
    return !words.empty() && words.size() <= maxLength_;
  };

  for(;;) {
    if(!select_->executeStep())
      return SentenceTuple(0);  // no rows left

    const size_t rowId = select_->getColumn(0).getInt();
    SentenceTuple tup(rowId);

    for(size_t f = 0; f < files_.size(); ++f) {
      auto column = select_->getColumn(static_cast<int>(f + 1));
      const bool auxiliaryCandidate = f != 0;

      if(auxiliaryCandidate && f == alignFileIdx_)
        addAlignmentToSentenceTuple(column, tup);
      else if(auxiliaryCandidate && f == weightFileIdx_)
        addWeightsToSentenceTuple(column, tup);
      else
        addWordsToSentenceTuple(column, f, tup);
    }

    if(std::all_of(tup.begin(), tup.end(), fitsLengthLimit))
      return tup;
  }
}

void CorpusSQLite::shuffle() {
  LOG(info, "[sqlite] Selecting shuffled data");
  select_.reset(new SQLite::Statement(
      *db_, "select * from lines order by random_seed(" + std::to_string(seed_) + ");"));
}

// Replaces select_ with a query returning all rows of `lines` in their
// original insertion order (ascending _id).
void CorpusSQLite::reset() {
  const std::string inOrder = "select * from lines order by _id;";
  select_.reset(new SQLite::Statement(*db_, inOrder));
}

// Replays past shuffles when resuming from a training state.
// For each epoch before ts->epochs, it prepares the shuffled query, steps it
// once, and then calls reset().
// NOTE(review): the intent appears to be advancing the state behind the
// random_seed() SQL function so later shuffles match an uninterrupted run.
// Confirm against createRandomFunction().
// Counting from 1 (instead of `i < ts->epochs - 1`) gives the same iteration
// count for epochs >= 1. It also avoids an unsigned wrap-around, and thus a
// near-endless loop, when epochs is 0.
void CorpusSQLite::restore(Ptr<TrainingState> ts) {
  for(size_t epoch = 1; epoch < ts->epochs; ++epoch) {
    select_.reset(new SQLite::Statement(
        *db_, "select _id from lines order by random_seed(" + std::to_string(seed_) + ");"));
    select_->executeStep();
    reset();
  }
}
}  // namespace data
}  // namespace marian