Welcome to mirror list, hosted at ThFree Co, Russian Federation.

vocab_base.h « data « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 8c214c97e48d912e2cd685e8c104db4309065283 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#pragma once

#include "data/types.h"
#include "common/definitions.h"
#include "common/utils.h"
#include "common/file_stream.h"

namespace marian {

// Abstract vocabulary interface: maps between surface strings and Word ids and
// knows how to load or create a vocabulary from files. Most members are pure
// virtual; only the few helpers that need no knowledge of the concrete
// tokenization have default implementations here.
class IVocab {
public:
  // Loads the vocabulary stored at vocabPath. The meaning of maxSize and of the
  // returned size_t is defined by the implementation (presumably 0 means "no
  // limit" and the result is the number of entries -- TODO confirm).
  virtual size_t load(const std::string& vocabPath, size_t maxSize = 0) = 0;

  // Creates a vocabulary at vocabPath from the files listed in trainPaths.
  virtual void create(const std::string& vocabPath,
                      const std::vector<std::string>& trainPaths,
                      size_t maxSize) = 0;

  // return canonical suffix for given type of vocabulary
  virtual const std::string& canonicalExtension() const = 0;
  // all file suffixes that findAndLoad() probes for this type of vocabulary
  virtual const std::vector<std::string>& suffixes() const = 0;

  // Probes path + suffix for every entry of suffixes(), in order, and loads the
  // first file that exists. Returns whatever load() returns for that file, or 0
  // if none of the candidate files exists.
  size_t findAndLoad(const std::string& path, size_t maxSize) { // @TODO: Only used in one place; just inline it there -> true interface
    for(const auto& suffix : suffixes()) {  // by reference: no string copy per iteration
      const std::string candidate = path + suffix;  // build the file name only once
      if(filesystem::exists(candidate))
        return load(candidate, maxSize);
    }
    return 0;
  }

  // Looks up the Word id for a surface string.
  virtual Word operator[](const std::string& word) const = 0;

  // Converts a line of text into a sequence of Word ids.
  virtual Words encode(const std::string& line,
                       bool addEOS = true,
                       bool inference = false) const = 0;

  // Converts a sequence of Word ids back into a string.
  virtual std::string decode(const Words& sentence,
                             bool ignoreEos = true) const = 0;
  virtual std::string surfaceForm(const Words& sentence) const = 0;

  // Looks up the surface string for a Word id.
  virtual const std::string& operator[](Word id) const = 0;

  virtual size_t size() const = 0;
  virtual std::string type() const = 0;

  virtual Word getEosId() const = 0;
  virtual Word getUnkId() const = 0;

  // Without specific knowledge of the tokenization these two functions can do
  // nothing, so the defaults return the line unchanged. The original note names
  // SentencePieceVocab and FactoredSegmenterVocab here -- presumably the
  // implementations that override them (not visible in this header).
  virtual std::string toUpper(const std::string& line) const { return line; }
  virtual std::string toEnglishTitleCase(const std::string& line) const { return line; }

  // This function is an identity mapping for default vocabularies, hence the
  // default does nothing. The casts only mark the parameters as intentionally
  // unused.
  virtual void transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const {
    (void)ptr;
    (void)num;
  }

  virtual void createFake() = 0;

  // Returns a Word built from a pseudo-random index in [0, size()).
  // Uses the C library rand(), so the sequence depends on the global srand()
  // state. Precondition: size() > 0 (the modulo is undefined for an empty
  // vocabulary).
  virtual Word randWord() const {
    return Word::fromWordIndex(rand() % size());
  }

  // Virtual so that implementations can be destroyed through an IVocab pointer.
  virtual ~IVocab() = default;
};

// Forward declaration: in this header Options is only used through Ptr<Options>.
class Options;
// Factory functions, each returning an IVocab implementation behind the interface.
// Only the declarations are visible here; judging by the names they create the
// default, class-label, SentencePiece and factored vocabularies respectively
// (definitions live elsewhere -- TODO confirm against the implementation files).
Ptr<IVocab> createDefaultVocab();
Ptr<IVocab> createClassVocab();
// NOTE(review): the Ptr<Options> parameter is unnamed and the role of batchIndex
// is not visible from this header -- confirm against the definition.
Ptr<IVocab> createSentencePieceVocab(const std::string& vocabPath, Ptr<Options>, size_t batchIndex);
Ptr<IVocab> createFactoredVocab(const std::string& vocabPath);

}