Welcome to mirror list, hosted at ThFree Co, Russian Federation.

perfectHash.h « DynSAInclude « src « moses - github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 2e3bfbc8fa61d84221285bf812a5a8852eb82b6e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
/* NO OVERLAY VALUES STORED IN SEPARATE FILTER */
#ifndef INC_PERFECTHASH_H
#define INC_PERFECTHASH_H

#include <map>
#include <stdint.h>
#include "hash.h"
#include "RandLMFilter.h"
#include "quantizer.h"
/*
 * PerfectHash handles setting up hash functions and storage 
 * for LM data. 
 */ 
using randlm::Filter;
using randlm::BitFilter;
// Exact-storage fallback dictionary: stringified ngram key -> count.
typedef std::map<string, count_t> hpDict_t;
typedef hpDict_t::iterator hpdEntry_t;
// NOTE(review): header-scope 'static' gives every translation unit that
// includes this header its OWN collision counter — confirm the per-TU
// duplication is intended rather than a single shared tally.
static count_t collisions_ = 0;
/* Based on Mortenson et. al. 2006 */ 
template<typename T>
class PerfectHash {
public:
  // Construct with a memory bound of 'MBs' megabytes, 'width'-bit
  // fingerprints, 'bucketRange' rows per bucket and quantizer base 'qBase'.
  PerfectHash(uint16_t MBs, int width, int bucketRange, float qBase);
  // Deserializing constructor: members are left uninitialized here and are
  // expected to be populated by a subsequent load(fin) call.
  PerfectHash(FileHandler* fin) {
    CHECK(fin != 0);
  }
  virtual ~PerfectHash();
  void analyze();          // dump occupancy/collision diagnostics to stderr
  count_t hpDictMemUse();  // approximate dictionary memory use, in MBs
  count_t bucketsMemUse(); // combined filter_ + values_ size (MBs, per Filter::size)
protected:
  Filter<T>* filter_;      // fingerprint cells (cells_ rows of cellWidth_ bits)
  Filter<T>* values_;      // quantized value codes, parallel to filter_
  hpDict_t dict_;          // exact store for collisions and overflowed buckets
  uint64_t cells_;         // total rows in filter_/values_; cells_ + 1 is the "not in filter" sentinel
  count_t hitMask_;        // bit-31 mask (set in ctor; semantics defined by subclasses — see markQueried)
  int totBuckets_;         // number of buckets
  uint8_t bucketRange_;    // rows per bucket (<= 255 so it fits idxTracker_)
  uint8_t* idxTracker_;    // per-bucket occupancy counts
  // Insert an ngram/value pair; returns the filter row used, or cells_ + 1
  // when the pair was diverted to the exact dictionary.
  uint64_t insert(const wordID_t* IDs, const int len, const count_t value);
  bool update(const wordID_t* IDs, const int len, const count_t value,
    hpdEntry_t& hpdAddr, uint64_t& filterIdx);
  bool update2(const wordID_t* IDs, const int len, const count_t value,
    hpdEntry_t& hpdAddr, uint64_t& filterIdx);
  int query(const wordID_t* IDs, const int len, 
    hpdEntry_t& hpdAddr, uint64_t& filterIdx);
  virtual void remove(const wordID_t* IDs, const int len);
  void remove(uint64_t index);
  void save(FileHandler* fout);
  void load(FileHandler* fin);
  virtual void markQueried(const uint64_t&)=0;
  //pointer to a specific entry in a hpDict_t
  virtual void markQueried(hpdEntry_t&)=0;
private:
  // Find a non-zero fingerprint for the ngram (0 marks an empty cell).
  T nonZeroSignature(const wordID_t* IDs, const int len, count_t bucket); 
  // Build the exact-dictionary key string for an ngram.
  string hpDictKeyValue(const wordID_t* IDs, const int len);
  uint64_t memBound_; // total memory bound in bytes
  uint16_t cellWidth_; // in bits
  UnivHash_linear<count_t>* bucketHash_; 
  UnivHash_linear<T>* fingerHash_;
  LogQtizer* qtizer_;
};
template<typename T>
// Sizes the fingerprint filter and values array to fit inside a memory bound
// of 'MBs' megabytes, then allocates the hash functions and bucket tracker.
PerfectHash<T>::PerfectHash(uint16_t MBs, int width, int bucketRange, 
  float qBase): hitMask_(1U << 31), memBound_(MBs * (1ULL << 20)), 
  cellWidth_(width) {
  // (1U << 31: shifting a plain signed 1 into the sign bit is undefined
  //  behavior in C++; the unsigned literal yields the intended bit-31 mask.)
  // Validate BEFORE narrowing to uint8_t so the stored range is never garbage.
  if(bucketRange > 255 || bucketRange < 1) {
    cerr << "ERROR: Bucket range must be between 1 and 255\n"; 
    exit(1);
  }
  bucketRange_ = static_cast<uint8_t>(bucketRange);
  qtizer_ = new LogQtizer(qBase);
  // bits needed to store the largest quantizer code
  int valBits = (int)ceil(log2((float)qtizer_->maxcode()));
  cerr << "BITS FOR VALUES ARRAY = " << valBits << endl;
  uint64_t totalBits = memBound_ << 3; // memory bound in bits
  // upper bound on cells: each cell costs one fingerprint plus one value code
  cells_ = (uint64_t) ceil((float)totalBits / (float)(cellWidth_ + valBits));
  // Round cells_ UP to an exact multiple of the bucket range. The previous
  // form 'cells_ += cells_ % bucketRange_' only produced a multiple when the
  // remainder happened to divide evenly (e.g. 11 % 4 = 3 -> 14, not 12).
  if(cells_ % bucketRange_ != 0)
    cells_ += bucketRange_ - (cells_ % bucketRange_);
  totBuckets_ = (cells_ / bucketRange_) - 1; // minus 1 so totBuckets * bucksize + bucksize = cells
  filter_ = new Filter<T>(cells_, cellWidth_);
  values_ = new Filter<T>(cells_, valBits); 
  idxTracker_ = new uint8_t[totBuckets_];
  for(int i=0; i < totBuckets_; ++i) idxTracker_[i] = 0;
  // initialize ranges for each hash function
  bucketHash_ = new UnivHash_linear<count_t>(totBuckets_, 1, PRIME);
  fingerHash_ = new UnivHash_linear<T>(pow(2.0f, cellWidth_), MAX_HASH_FUNCS, PRIME);
}
template<typename T>
// Releases every heap-owned member; the objects are independent so the
// destruction order is irrelevant.
PerfectHash<T>::~PerfectHash() {
  delete values_;
  values_ = NULL;
  delete qtizer_;
  qtizer_ = NULL;
  delete bucketHash_;
  bucketHash_ = NULL;
  delete fingerHash_;
  fingerHash_ = NULL;
  delete filter_;
  filter_ = NULL;
  delete[] idxTracker_;
  idxTracker_ = NULL;
}
template<typename T>  
// Insert an ngram with its value. Returns the filter row written, or
// cells_ + 1 when the entry was diverted to the exact dictionary (either a
// fingerprint collision inside the bucket or a full bucket).
uint64_t PerfectHash<T>::insert(const wordID_t* IDs, const int len, 
                            const count_t value) {
  // two-argument bucketHash_ form selects among multiple hash functions by len
  count_t bucket = (bucketHash_->size() > 1 ? bucketHash_->hash(IDs, len, len) : bucketHash_->hash(IDs, len));
  if(idxTracker_[bucket] < (int)bucketRange_) {  // if empty rows 
    // restriction on fprint value is non-zero
    T fp = nonZeroSignature(IDs, len, (bucket % MAX_HASH_FUNCS));
    uint64_t emptyidx = cells_ + 1;  // sentinel: no empty row seen yet
    uint64_t index = bucket * bucketRange_,  // starting bucket row
            lastrow = index + bucketRange_;  // ending row 
    while(index < lastrow) { // unique so check each row for "matching" signature 
      T filterVal = filter_->read(index);
      if((filterVal == 0) && (emptyidx == cells_ + 1)) {  // record first empty row 
        emptyidx = index;
      }
      else if(filterVal == fp) { 
        // same fingerprint already present: treat as a collision and store
        // this ngram exactly instead of corrupting the existing cell
        ++collisions_;
        dict_[hpDictKeyValue(IDs, len)] = value; // store exact in hpd
        return cells_ + 1;  // finished
      }
      ++index;
    }
    CHECK((emptyidx < index) && (filter_->read(emptyidx) == 0)); // should have found empty index if it gets here
    T code = (T)qtizer_->code(value);  // quantize value before storing
    filter_->write(emptyidx, fp); // insert the fprint
    values_->write(emptyidx, code);
    ++idxTracker_[bucket]; // keep track of bucket size
    return emptyidx;
  }
  else { // bucket is full
    dict_[hpDictKeyValue(IDs, len)] = value; // add to hpd
    return cells_ + 1;
  }
}
template<typename T>  
// Overwrite the stored value for an existing ngram. Returns true when the
// ngram was found (dictionary or filter); false when absent everywhere.
// On a filter hit, filterIdx receives the row; on a dictionary hit it stays
// at the cells_ + 1 sentinel and hpdAddr points at the entry.
bool PerfectHash<T>::update(const wordID_t* IDs, const int len, 
  const count_t value, hpdEntry_t& hpdAddr, uint64_t& filterIdx) {
  filterIdx = cells_ + 1;
  string skey = hpDictKeyValue(IDs, len);
  hpdAddr = dict_.find(skey);
  if(hpdAddr != dict_.end()) {
    // exact entry lives in the high-performance dictionary
    hpdAddr->second = value;
    return true;
  }
  // otherwise locate the ngram's bucket in the filter
  count_t bucket = (bucketHash_->size() > 1 ? bucketHash_->hash(IDs, len, len) : bucketHash_->hash(IDs, len));
  // fingerprints are non-zero by construction (zero marks an empty cell)
  T fp = nonZeroSignature(IDs, len, (bucket % MAX_HASH_FUNCS));
  uint64_t row = bucket * bucketRange_;      // first row of the bucket
  const uint64_t bucketEnd = row + bucketRange_;
  for(; row < bucketEnd; ++row) {
    if(filter_->read(row) == fp) {  // matching signature => same ngram w.h.p.
      values_->write(row, (T)qtizer_->code(value));
      filterIdx = row;
      return true;
    }
  }
  // ngram not present; the caller may choose to insert it
  return false;
}
template<typename T>   
// Look up an ngram's value. Returns the stored count, or -1 when the ngram is
// in neither the exact dictionary nor the filter. On a dictionary hit,
// filterIdx is set to the cells_ + 1 sentinel; on a filter hit, filterIdx is
// the matching row and hpdAddr is dict_.end().
int PerfectHash<T>::query(const wordID_t* IDs, const int len, 
  hpdEntry_t& hpdAddr, uint64_t& filterIdx) {
  // exact dictionary takes precedence over the lossy filter
  string skey = hpDictKeyValue(IDs, len);
  hpdAddr = dict_.find(skey);
  if(hpdAddr != dict_.end()) {
    filterIdx = cells_ + 1;       // flag: value came from the dictionary
    return(hpdAddr->second);      // returns copy of value
  }
  // fall back to scanning the ngram's bucket for its fingerprint
  count_t bucket = (bucketHash_->size() > 1 ? bucketHash_->hash(IDs, len, len) : bucketHash_->hash(IDs, len));
  // fingerprints are non-zero by construction (zero marks an empty cell)
  T fp = nonZeroSignature(IDs, len, (bucket % MAX_HASH_FUNCS));
  uint64_t row = bucket * bucketRange_;
  const uint64_t bucketEnd = row + bucketRange_;
  while(row < bucketEnd) {
    if(filter_->read(row) == fp) {  // signature match => hit w.h.p.
      filterIdx = row;
      hpdAddr = dict_.end();
      return (int)qtizer_->value(values_->read(row));
    }
    ++row;
  }
  return -1;  // not found anywhere
}
template<typename T>
// Remove an ngram from whichever store holds it: erase from the exact
// dictionary if present, otherwise clear its fingerprint/value cells in the
// filter and decrement the owning bucket's occupancy count.
void PerfectHash<T>::remove(const wordID_t* IDs, const int len) {
  // delete key if in high perf. dictionary
  string skey = hpDictKeyValue(IDs, len);
  if(dict_.find(skey) != dict_.end())
    dict_.erase(skey);
  else {  // check if key is in filter
    // get small representation for ngrams 
    count_t bucket = (bucketHash_->size() > 1? bucketHash_->hash(IDs, len, len) : bucketHash_->hash(IDs, len));
    // retrieve non zero fingerprint for ngram
    T fp = nonZeroSignature(IDs, len, (bucket % MAX_HASH_FUNCS)); 
    // scan the bucket's rows; clear the first (unique) matching cell
    uint64_t index = bucket * bucketRange_,
            lastrow = index + bucketRange_; 
    for(; index < lastrow; ++index) {
      if(filter_->read(index) == fp) { 
        filter_->write(index, 0);   // 0 marks the cell empty again
        values_->write(index, 0);
        --idxTracker_[bucket]; // track bucket size reduction
        break;
      }
    }
  }
}
template<typename T>
// Clear a single filter/value cell by absolute row index and shrink the
// occupancy count of the bucket that owns it.
void PerfectHash<T>::remove(uint64_t index) {
  CHECK(index < cells_);
  CHECK(filter_->read(index) != 0); // slow sanity check: cell must be occupied
  filter_->write(index, 0);
  values_->write(index, 0);
  // each bucket spans bucketRange_ consecutive rows
  --idxTracker_[index / bucketRange_];
}
template<typename T>
// Produce a non-zero fingerprint for the ngram (zero is reserved to mark an
// empty filter cell). Starts with hash function 'bucket' and, when the hash
// comes out zero, advances through the other functions, wrapping around to 0
// after the last one. Gives up (with a warning) after trying every function.
T PerfectHash<T>::nonZeroSignature(const wordID_t* IDs, const int len,
  count_t bucket) {
  count_t h = bucket;
  T fingerprint(0);
  do {
    fingerprint = fingerHash_->hash(IDs, len, h);
    // advance to the next hash function index; 'h += -h' resets h to 0 once
    // the last function (size() - 1) has been tried, i.e. wrap around
    h += (h < fingerHash_->size() - 1 ? 1 : -h); // wrap around 
  } while((fingerprint == 0) && (h != bucket)); // stop after a full cycle
  if(fingerprint == 0) 
    cerr << "WARNING: Unable to find non-zero signature for ngram\n" << endl;
  return fingerprint;
}
template<typename T>
// Build the exact-dictionary key for an ngram: the word IDs joined with the
// '¬' sentinel character (trailing separator removed by trim).
string PerfectHash<T>::hpDictKeyValue(const wordID_t* IDs, const int len) {
  string skey(" ");  // seeded with a space that trim strips at the end
  int pos = 0;
  while(pos < len) {
    skey += Utils::IntToStr(IDs[pos]) + "¬";
    ++pos;
  }
  Utils::trim(skey);
  return skey;
}
template<typename T>
// Approximate memory consumed by the exact dictionary, in MBs (entry size
// times entry count, shifted down by 20 bits).
count_t PerfectHash<T>::hpDictMemUse() {
  uint64_t bytes = (uint64_t) sizeof(hpDict_t::value_type) * dict_.size();
  return (count_t)(bytes >> 20);
}
template<typename T>
// Combined footprint of the fingerprint filter and the values array, in MBs
// (assuming Filter::size() reports MBs, as the analyze() output suggests).
count_t PerfectHash<T>::bucketsMemUse() {
  count_t total = (count_t) filter_->size();
  total += (count_t) values_->size();
  return total;
}
template<typename T>
// Serialize all state to 'fout': scalar parameters, bucket tracker, quantizer,
// hash functions, both filters, then the exact dictionary as text lines.
// The field order here must mirror load() exactly.
void PerfectHash<T>::save(FileHandler* fout) {
  CHECK(fout != 0);
  cerr << "\tSaving perfect hash parameters...\n";
  fout->write((char*)&hitMask_, sizeof(hitMask_));
  fout->write((char*)&memBound_, sizeof(memBound_));
  fout->write((char*)&cellWidth_, sizeof(cellWidth_));
  fout->write((char*)&cells_, sizeof(cells_));
  fout->write((char*)&totBuckets_, sizeof(totBuckets_));
  fout->write((char*)&bucketRange_, sizeof(bucketRange_));
  fout->write((char*)idxTracker_, totBuckets_ * sizeof(idxTracker_[0]));
  qtizer_->save(fout);
  cerr << "\tSaving hash functions...\n";
  fingerHash_->save(fout);
  bucketHash_->save(fout);
  cerr << "\tSaving bit filter...\n";
  filter_->save(fout);
  values_->save(fout);
  cerr << "\tSaving high performance dictionary...\n";
  count_t size = dict_.size();
  fout->write((char*)&size, sizeof(count_t));
  *fout << endl;  // newline separator; load() skips it with ignore()
  // 'iterate' is a project macro looping over dict_ with iterator 't';
  // entries are written as "key<TAB>value" lines
  iterate(dict_, t) 
    *fout << t->first << "\t" << t->second << "\n";
}
template<typename T>
// Deserialize all state from 'fin', allocating the tracker, quantizer, hash
// functions and filters. Field order must mirror save() exactly. Intended for
// objects built with the PerfectHash(FileHandler*) constructor (members are
// overwritten without freeing — assumes they were never allocated).
void PerfectHash<T>::load(FileHandler* fin) {
  CHECK(fin != 0);
  cerr << "\tLoading perfect hash parameters...\n";
  fin->read((char*)&hitMask_, sizeof(hitMask_));
  fin->read((char*)&memBound_, sizeof(memBound_));
  fin->read((char*)&cellWidth_, sizeof(cellWidth_));
  fin->read((char*)&cells_, sizeof(cells_));
  fin->read((char*)&totBuckets_, sizeof(totBuckets_));
  fin->read((char*)&bucketRange_, sizeof(bucketRange_));
  idxTracker_ = new uint8_t[totBuckets_];
  fin->read((char*)idxTracker_, totBuckets_ * sizeof(idxTracker_[0]));
  qtizer_ = new LogQtizer(fin);
  cerr << "\tLoading hash functions...\n";
  fingerHash_ = new UnivHash_linear<T>(fin);
  bucketHash_ = new UnivHash_linear<count_t>(fin);
  cerr << "\tLoading bit filter...\n";
  filter_ = new Filter<T>(fin);
  values_ = new Filter<T>(fin);
  cerr << "\tLoading HPD...\n";
  count_t size = 0;
  fin->read((char*)&size, sizeof(count_t));
  fin->ignore(256, '\n');  // skip the newline save() emitted after the count
  string line;
  hpDict_t::key_type key;
  hpDict_t::mapped_type val;
  for(count_t i=0; i < size; ++i) {
    getline(*fin, line);
    Utils::trim(line);
    std::istringstream ss(line.c_str());
    // comma operator: runs both extractions left to right ("key<TAB>value")
    ss >> key, ss >> val;
    dict_[key] = val;
  }
  cerr << "\tHPD size=" << dict_.size() << endl;
  cerr << "Finished loading ORLM." << endl;
}
template<typename T>
// Diagnostic dump: recount cell occupancy directly from the filter, compare it
// against the idxTracker_ bookkeeping, and print bucket/size/collision
// statistics to stderr. Read-only apart from the temporary count array.
void PerfectHash<T>::analyze() {
  cerr << "Analyzing Dynamic Bloomier Filter...\n";
  // see how many items in each bucket
  uint8_t* bucketCnt = new uint8_t[totBuckets_];
  unsigned largestBucket = 0, totalCellsSet = 0, 
    smallestBucket = bucketRange_, totalZeroes = 0;
  int curBucket = -1, fullBuckets(0);
  for(int i = 0; i < totBuckets_; ++i) bucketCnt[i] = 0;
  // first pass: recount occupied cells straight from the filter
  for(uint64_t i =0; i < cells_; ++i) {
    if(i % bucketRange_ == 0) ++curBucket;
    if(filter_->read(i) != 0) {
      ++bucketCnt[curBucket];
      ++totalCellsSet;
    }
    else ++totalZeroes;
  }
  // find the largest and smallest buckets (bi/si record their indices)
  count_t bi = 0, si = 0;
  for(int i = 0; i < totBuckets_; ++i) {
    if(bucketCnt[i] > largestBucket) {
      largestBucket = bucketCnt[i];
      bi = i;
    }
    else if(bucketCnt[i] < smallestBucket) {
      smallestBucket = bucketCnt[i];
      si = i;
    }
  }
  // tally the tracker's view of occupancy and count completely full buckets
  count_t trackerCells(0);
  for(int i = 0; i < totBuckets_; i++) {
    trackerCells += idxTracker_[i];
    if(idxTracker_[i] == bucketRange_)
      ++fullBuckets;
  }
  // report any divergence between the recount and the tracker (should match)
  for(int i = 0; i < totBuckets_; ++i) {
    if(bucketCnt[i] != idxTracker_[i])
      cerr << "bucketCnt[" << i << "] = " << (int)bucketCnt[i] << 
        "\tidxTracker_[" << i << "] = " << (int)idxTracker_[i] << endl;
  }
  cerr << "total cells= " << cells_ << endl;
  cerr << "total buckets= " << totBuckets_ << endl;
  cerr << "bucket range= " << (int)bucketRange_ << endl;
  cerr << "fingerprint bits= " << cellWidth_ << endl;
  cerr << "total cells set= " << totalCellsSet;
  cerr << " (idxTracker set = " << trackerCells << ")" << endl;
  cerr << "total zeroes=" << totalZeroes;
  cerr << " (idxTracker zeros = " << cells_ - trackerCells << ")" << endl;
  cerr << "largest bucket (" << bi << ") size= " << largestBucket << endl;
  cerr << "smallest bucket (" << si << ") size= " << smallestBucket << endl;
  cerr << "last bucket size= " << (int)bucketCnt[totBuckets_ - 1] <<
    " (idxTracker last bucket size = " << (int)idxTracker_[totBuckets_ - 1] << ")" << endl;
  cerr << "total buckets full = " << fullBuckets << endl;
  cerr << "total collision errors= " << collisions_ << endl;
  cerr << "high performance dictionary size= " << dict_.size() << endl;
  cerr << "high performance dictionary MBs= " << hpDictMemUse() << endl;
  cerr << "filter MBs= " << filter_->size() << endl;
  cerr << "values MBs= " << values_->size() << endl;
  delete[] bucketCnt;
}
template<typename T>  
// Accumulating variant of update(): ADDS 'value' to the stored count instead
// of overwriting it, and inserts the ngram when it is absent. Returns true on
// an in-place update, false when the entry had to be freshly inserted.
bool PerfectHash<T>::update2(const wordID_t* IDs, const int len, 
  const count_t value, hpdEntry_t& hpdAddr, uint64_t& filterIdx) {
  // check if key is in high perf. dictionary
  filterIdx = cells_ + 1;
  string skey = hpDictKeyValue(IDs, len);
  if((hpdAddr = dict_.find(skey)) != dict_.end()) {
    hpdAddr->second += value;  // accumulate exactly in the dictionary
    return true;
  }
  // else hash ngram
  count_t bucket = (bucketHash_->size() > 1 ? bucketHash_->hash(IDs, len, len) : bucketHash_->hash(IDs, len));
  // restriction on fprint value is non-zero
  T fp = nonZeroSignature(IDs, len, (bucket % MAX_HASH_FUNCS));
  uint64_t index = bucket * bucketRange_,  // starting bucket row
          lastrow = index + bucketRange_; 
  while(index < lastrow) { // must check each row for matching fp event
    T filterVal = filter_->read(index);
    if(filterVal == fp) { // found event w.h.p.
      // read-modify-write: dequantize, add, requantize (lossy round trip)
      int oldval = (int)qtizer_->value(values_->read(index)); 
      values_->write(index, (T)qtizer_->code(oldval + value)); 
      filterIdx = index;
      return true;
    }
    ++index;
  }
  // ngram absent everywhere: insert it fresh (filterIdx stays at sentinel)
  insert(IDs, len, value);
  return false;
}
#endif