-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_huffman_decompressor.cc
1705 lines (1577 loc) · 59.2 KB
/
gen_huffman_decompressor.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright 2022 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <atomic>
#include <cstdint>
#include <fstream>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <queue>
#include <set>
#include <string>
#include <thread>
#include <vector>
#include <openssl/sha.h>
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "src/core/ext/transport/chttp2/transport/huffsyms.h"
#include "src/core/lib/gprpp/match.h"
///////////////////////////////////////////////////////////////////////////////
// SHA256 hash handling
// We need strong uniqueness checks of some very long strings - so we hash
// them with SHA256 and compare.
struct Hash {
uint8_t bytes[SHA256_DIGEST_LENGTH];
bool operator==(const Hash& other) const {
return memcmp(bytes, other.bytes, SHA256_DIGEST_LENGTH) == 0;
}
bool operator<(const Hash& other) const {
return memcmp(bytes, other.bytes, SHA256_DIGEST_LENGTH) < 0;
}
std::string ToString() const {
std::string result;
for (int i = 0; i < SHA256_DIGEST_LENGTH; i++) {
absl::StrAppend(&result, absl::Hex(bytes[i], absl::kZeroPad2));
}
return result;
}
};
// Given a vector of ints (T), return a Hash object holding the SHA256 digest
// of the text "<type>:<comma separated values>".
// Hash the text "<type>:<v[0]>,<v[1]>,..." with SHA256.
// The type name is part of the hashed text, so equal values tagged with
// different type names produce different hashes.
template <typename T>
Hash HashVec(absl::string_view type, const std::vector<T>& v) {
  const std::string joined_values = absl::StrJoin(v, ",");
  const std::string payload = absl::StrCat(type, ":", joined_values);
  Hash digest;
  SHA256(reinterpret_cast<const uint8_t*>(payload.data()), payload.size(),
         digest.bytes);
  return digest;
}
///////////////////////////////////////////////////////////////////////////////
// BitQueue
// A utility that treats a sequence of bits like a queue
// A fixed sequence of up to 32 bits consumed from the most significant end.
// The low len_ bits of mask_ hold the queue; bit (len_ - 1) is the front.
class BitQueue {
 public:
  // mask holds the bits (right aligned), len says how many of them are valid.
  BitQueue(unsigned mask, int len) : mask_(mask), len_(len) {}
  // An empty queue.
  BitQueue() : BitQueue(0, 0) {}

  // Return the most significant bit (the front of the queue).
  // Precondition: the queue is not empty (otherwise the shift count is -1).
  int Front() const { return (mask_ >> (len_ - 1)) & 1; }

  // Pop one bit off the front of the queue.
  // Precondition: the queue is not empty.
  void Pop() {
    // Shift an unsigned one: with a signed literal a 32 bit queue would shift
    // into the sign bit of an int.
    mask_ &= ~(1u << (len_ - 1));
    len_--;
  }

  bool Empty() const { return len_ == 0; }
  int length() const { return len_; }
  unsigned mask() const { return mask_; }

  // Text representation of the queue: "<hex mask>/<length>".
  std::string ToString() const {
    return absl::StrCat(absl::Hex(mask_), "/", len_);
  }

  // Ordering (mask first, then length) so that BitQueue can be a std::map key.
  bool operator<(const BitQueue& other) const {
    return std::tie(mask_, len_) < std::tie(other.mask_, other.len_);
  }

 private:
  // The bits, right aligned.
  unsigned mask_;
  // How many bits are currently in the queue.
  int len_;
};
///////////////////////////////////////////////////////////////////////////////
// Symbol sets for the huffman tree
// A Sym is one symbol in the tree, and the bits that we need to read to decode
// that symbol. As we progress through decoding we remove bits from the symbol,
// but also condense the number of symbols we're considering.
// One candidate symbol together with the bits that still have to be read
// before it is fully decoded.
struct Sym {
  // Bits not yet matched for this symbol.
  BitQueue bits;
  // The symbol value.
  int symbol;

  // Order by remaining bits first, then by symbol value.
  bool operator<(const Sym& other) const {
    if (bits < other.bits) return true;
    if (other.bits < bits) return false;
    return symbol < other.symbol;
  }
};
// A SymSet is all the symbols we're considering at some time
using SymSet = std::vector<Sym>;
// Debug utility to turn a SymSet into a string
// Render a SymSet as "symbol:bits,symbol:bits,..." for debugging.
std::string SymSetString(const SymSet& syms) {
  std::string text;
  const char* separator = "";
  for (const Sym& entry : syms) {
    absl::StrAppend(&text, separator, entry.symbol, ":",
                    entry.bits.ToString());
    separator = ",";
  }
  return text;
}
// Initial SymSet - all the symbols [0..256] with their bits initialized from
// the http2 static huffman tree.
// Build the starting SymSet: one Sym per entry of grpc_chttp2_huffsyms, in
// table order, with the symbol value equal to the table index.
SymSet AllSyms() {
  SymSet all;
  all.reserve(GRPC_CHTTP2_NUM_HUFFSYMS);
  for (int index = 0; index < GRPC_CHTTP2_NUM_HUFFSYMS; ++index) {
    const auto& entry = grpc_chttp2_huffsyms[index];
    all.push_back(Sym{BitQueue(entry.bits, entry.length), index});
  }
  return all;
}
// What should we do after reading a set of bits?
// Outcome of running a sequence of index bits through the decoder, as
// produced by ActionsFor. Field order matters: ActionsFor brace-initializes
// this struct positionally.
struct ReadActions {
// Emit these symbols (in decode order; empty if no symbol was completed)
std::vector<int> emit;
// Number of bits that were consumed by the read: counted up to and including
// the final bit of the last emitted symbol, 0 when nothing was emitted
int consumed;
// Remaining SymSet that we need to consider on the next read action
SymSet remaining;
};
// Given a SymSet \a pending, read through the bits in \a index and determine
// what actions the decoder should take.
// allow_multiple controls the behavior should we get to the last bit in pending
// and hence know which symbol to emit, but we still have bits in index.
// We could either start decoding the next symbol (allow_multiple == true), or
// we could stop (allow_multiple == false).
// If allow_multiple is true we tend to emit more per read op, but generate
// bigger tables.
// Walk the bits of index from the front, narrowing pending to the symbols
// whose next bit matches, and record what a decoder should do.
//  - index: the bits being looked up.
//  - pending: candidate symbols at the start of the lookup.
//  - allow_multiple: after a symbol completes, keep decoding further symbols
//    from the bits that are left (restarting from AllSyms()) instead of
//    stopping.
// Aborts if a bit matches no candidate, or if a single candidate is left that
// still has unread bits.
ReadActions ActionsFor(BitQueue index, SymSet pending, bool allow_multiple) {
  const int initial_bits = index.length();
  // Bits still unread in index right after the most recent emit; stays at
  // initial_bits (i.e. zero consumed) until something is emitted.
  int bits_left_after_emit = initial_bits;
  std::vector<int> emitted;
  for (; !index.Empty(); index.Pop()) {
    const int bit = index.Front();
    // Keep the candidates whose leading bit agrees, with that bit removed.
    SymSet survivors;
    for (const Sym& candidate : pending) {
      if (candidate.bits.Front() == bit) {
        Sym advanced = candidate;
        advanced.bits.Pop();
        survivors.push_back(advanced);
      }
    }
    // There should be no bit patterns that are undecodable.
    if (survivors.empty()) abort();
    if (survivors.size() > 1) {
      // Still ambiguous: carry the narrowed set into the next bit.
      pending = std::move(survivors);
      continue;
    }
    // Exactly one candidate left: it must now be completely decoded.
    const Sym& decoded = survivors.front();
    if (!decoded.bits.Empty()) abort();
    emitted.push_back(decoded.symbol);
    bits_left_after_emit = index.length() - 1;
    // When stopping here, pending deliberately keeps the set it had before
    // this bit was read.
    if (!allow_multiple) break;
    pending = AllSyms();
  }
  return ReadActions{std::move(emitted), initial_bits - bits_left_after_emit,
                     pending};
}
///////////////////////////////////////////////////////////////////////////////
// MatchCase
// A variant that helps us bunch together related ReadActions
// A Matched in a MatchCase indicates that we need to emit some number of
// symbols
// Match case meaning "emit this many symbols".
struct Matched {
  // Number of symbols to emit.
  int emits;

  // Ordered by emit count so that Matched can live in ordered containers.
  bool operator<(const Matched& other) const {
    return other.emits > emits;
  }
};
// Unmatched says we didn't emit anything and we need to keep decoding
struct Unmatched {
SymSet syms;
bool operator<(const Unmatched& other) const { return syms < other.syms; }
};
// Emit end of stream
// Match case meaning "emit end of stream". Carries no data, so no End value
// ever orders before another.
struct End {
  bool operator<(End /*other*/) const {
    return false;
  }
};
// The kind of action attached to a decode step: exactly one of Matched,
// Unmatched or End.
using MatchCase = absl::variant<Matched, Unmatched, End>;
///////////////////////////////////////////////////////////////////////////////
// Text & numeric helper functions
// Given a vector of lines, indent those lines by some number of indents
// (2 spaces) and return that.
// Prefix every line with n indents (two spaces each) and return the result.
// The input is taken by value and modified in place.
std::vector<std::string> IndentLines(std::vector<std::string> lines,
                                     int n = 1) {
  const std::string padding(2 * n, ' ');
  for (size_t i = 0; i < lines.size(); ++i) {
    lines[i] = absl::StrCat(padding, lines[i]);
  }
  return lines;
}
// Given a snake_case_name return a PascalCaseName
// Convert a snake_case_name into a PascalCaseName.
// Underscores are dropped; the first character and every character that
// follows an underscore is upper-cased, all other characters are copied
// unchanged.
std::string ToPascalCase(const std::string& in) {
  std::string out;
  out.reserve(in.size());
  bool next_upper = true;
  for (const char c : in) {
    if (c == '_') {
      next_upper = true;
      continue;
    }
    if (next_upper) {
      // toupper requires a value representable as unsigned char (or EOF);
      // passing a negative plain char is undefined behaviour.
      out.push_back(
          static_cast<char>(toupper(static_cast<unsigned char>(c))));
      next_upper = false;
    } else {
      out.push_back(c);
    }
  }
  return out;
}
// Return a uint type for some number of bits (16 -> uint16_t, 32 -> uint32_t)
// Name of the unsigned integer type with the given width,
// e.g. 16 -> "uint16_t".
std::string Uint(int bits) {
  std::string type_name = "uint";
  absl::StrAppend(&type_name, bits, "_t");
  return type_name;
}
// Given a maximum value, how many bits to store it in a uint
// Width (8, 16 or 32) of the smallest unsigned type that can hold max.
int TypeBitsForMax(int max) {
  if (max > 65535) return 32;
  if (max > 255) return 16;
  return 8;
}
// Combine Uint & TypeBitsForMax to make for more concise code
// Name of the smallest unsigned type that can hold max (Uint applied to
// TypeBitsForMax).
std::string TypeForMax(int max) {
  const int width = TypeBitsForMax(max);
  return Uint(width);
}
// How many bits are needed to encode a value
// Number of bits needed to represent x: the smallest n with x < 2**n.
// Returns 0 for x <= 0.
int BitsForMaxValue(int x) {
  if (x <= 0) return 0;
  // Count on an unsigned copy by shifting right. Comparing against 1 << n
  // would overflow int once x >= 2**30.
  int n = 0;
  for (unsigned remaining = static_cast<unsigned>(x); remaining != 0;
       remaining >>= 1) {
    ++n;
  }
  return n;
}
///////////////////////////////////////////////////////////////////////////////
// Codegen framework
// Some helpers so we don't need to generate all the code linearly, which helps
// organize this a little more nicely.
// An Item is our primitive for code generation, it can generate some lines
// that it would like to emit - those lines are fed to a parent item that might
// generate more lines or mutate the ones we return, and so on until codegen
// is complete.
// Base class of everything that can contribute lines to the generated output.
class Item {
 public:
  virtual ~Item() = default;

  // The lines this item wants to emit, one string per line.
  virtual std::vector<std::string> ToLines() const = 0;

  // All lines joined with newlines, followed by one trailing newline.
  std::string ToString() const {
    std::string text = absl::StrJoin(ToLines(), "\n");
    text.append("\n");
    return text;
  }
};
using ItemPtr = std::unique_ptr<Item>;
// An item that emits one line (the one given as an argument!)
// An Item that emits exactly one line: the string it was constructed with.
class String : public Item {
 public:
  explicit String(std::string s) : s_(std::move(s)) {}

  std::vector<std::string> ToLines() const override {
    return std::vector<std::string>(1, s_);
  }

 private:
  // The single line to emit.
  std::string s_;
};
// An item that returns a fixed copyright notice and autogenerated note text.
// An item that returns a fixed copyright notice and autogenerated note text.
// Every notice line is prefixed with comment_prefix and a space; the two
// separator lines around the "autogenerated" note are left completely empty.
class Prelude final : public Item {
 public:
  // comment_prefix: the comment leader to put in front of each line.
  // copyright_year: the year printed in the copyright line.
  explicit Prelude(absl::string_view comment_prefix, int copyright_year)
      : comment_prefix_(comment_prefix), copyright_year_(copyright_year) {}

  std::vector<std::string> ToLines() const override {
    auto line = [this](absl::string_view text) {
      return absl::StrCat(comment_prefix_, " ", text);
    };
    return {
        line(absl::StrCat("Copyright ", copyright_year_, " gRPC authors.")),
        line(""),
        line("Licensed under the Apache License, Version 2.0 (the "
             "\"License\");"),
        line(
            "you may not use this file except in compliance with the License."),
        line("You may obtain a copy of the License at"),
        line(""),
        line(" http://www.apache.org/licenses/LICENSE-2.0"),
        line(""),
        line("Unless required by applicable law or agreed to in writing, "
             "software"),
        line("distributed under the License is distributed on an \"AS IS\" "
             "BASIS,"),
        line("WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or "
             "implied."),
        line("See the License for the specific language governing permissions "
             "and"),
        line("limitations under the License."),
        "",
        line("This file is autogenerated: see "
             "tools/codegen/core/gen_huffman_decompressor.cc"),
        ""};
  }

 private:
  // Owned copy of the prefix. ToLines() runs after construction, so keeping
  // only a view would dangle if the caller passed a temporary string.
  std::string comment_prefix_;
  int copyright_year_;
};
class Switch;
// A Sink is an Item that we can add more Items to.
// At codegen time it calls each of its children in turn and concatenates
// their results together.
class Sink : public Item {
public:
std::vector<std::string> ToLines() const override {
std::vector<std::string> lines;
for (const auto& item : children_) {
for (const auto& line : item->ToLines()) {
lines.push_back(line);
}
}
return lines;
}
// Add one string to our output.
void Add(std::string s) {
children_.push_back(std::make_unique<String>(std::move(s)));
}
// Add an item of type T to our output (constructing it with args).
template <typename T, typename... Args>
T* Add(Args&&... args) {
auto v = std::make_unique<T>(std::forward<Args>(args)...);
auto* r = v.get();
children_.push_back(std::move(v));
return r;
}
private:
std::vector<ItemPtr> children_;
};
// A sink that indents its lines by one indent (2 spaces)
// A Sink whose output is shifted right by one indent (2 spaces).
class Indent : public Sink {
 public:
  std::vector<std::string> ToLines() const override {
    std::vector<std::string> body = Sink::ToLines();
    return IndentLines(std::move(body), 1);
  }
};
// A Sink that wraps its lines in a while block
// A Sink that wraps its children in "while (<cond>) { ... }", indenting the
// body by one level.
class While : public Sink {
 public:
  // cond is pasted verbatim between the parentheses of the while statement.
  explicit While(std::string cond) : cond_(std::move(cond)) {}

  std::vector<std::string> ToLines() const override {
    std::vector<std::string> out;
    out.push_back(absl::StrCat("while (", cond_, ") {"));
    const std::vector<std::string> body = IndentLines(Sink::ToLines());
    out.insert(out.end(), body.begin(), body.end());
    out.push_back("}");
    return out;
  }

 private:
  // Loop condition text.
  std::string cond_;
};
// A switch statement.
// Cases can be modified by calling the Case member.
// Identical cases are collapsed into 'case X: case Y:' type blocks.
class Switch : public Item {
public:
struct Default {
bool operator<(const Default&) const { return false; }
bool operator==(const Default&) const { return true; }
};
using CaseLabel = absl::variant<int, std::string, Default>;
// \a cond is the condition to place at the head of the switch statement.
// eg. "switch (cond) {".
explicit Switch(std::string cond) : cond_(std::move(cond)) {}
std::vector<std::string> ToLines() const override {
std::map<std::string, std::vector<CaseLabel>> reverse_map;
for (const auto& kv : cases_) {
reverse_map[kv.second.ToString()].push_back(kv.first);
}
std::vector<std::pair<std::string, std::vector<CaseLabel>>>
sorted_reverse_map;
sorted_reverse_map.reserve(reverse_map.size());
for (auto& kv : reverse_map) {
sorted_reverse_map.push_back(kv);
}
for (auto& e : sorted_reverse_map) {
std::sort(e.second.begin(), e.second.end());
}
std::sort(sorted_reverse_map.begin(), sorted_reverse_map.end(),
[](const auto& a, const auto& b) { return a.second < b.second; });
std::vector<std::string> lines;
lines.push_back(absl::StrCat("switch (", cond_, ") {"));
for (const auto& kv : sorted_reverse_map) {
for (const auto& cond : kv.second) {
lines.push_back(absl::StrCat(
" ",
grpc_core::Match(
cond, [](Default) -> std::string { return "default"; },
[](int i) { return absl::StrCat("case ", i); },
[](const std::string& s) { return absl::StrCat("case ", s); }),
":"));
}
lines.back().append(" {");
for (const auto& case_line :
IndentLines(cases_.find(kv.second[0])->second.ToLines(), 2)) {
lines.push_back(case_line);
}
lines.push_back(" }");
}
lines.push_back("}");
return lines;
}
Sink* Case(CaseLabel cond) { return &cases_[cond]; }
private:
std::string cond_;
std::map<CaseLabel, Sink> cases_;
};
///////////////////////////////////////////////////////////////////////////////
// BuildCtx declaration
// Shared state for one code gen attempt
class TableBuilder;
class FunMaker;
// State shared by everything taking part in one code generation attempt:
// the output sinks, an id allocator, and a registry of already generated
// artifacts keyed by content hash.
class BuildCtx {
 public:
  // The sinks and fun_maker are stored as given, without taking ownership.
  BuildCtx(std::vector<int> max_bits_for_depth, Sink* global_fns,
           Sink* global_decls, Sink* global_values, FunMaker* fun_maker)
      : max_bits_for_depth_(std::move(max_bits_for_depth)),
        global_fns_(global_fns),
        global_decls_(global_decls),
        global_values_(global_values),
        fun_maker_(fun_maker) {}

  // Defined out of line.
  void AddStep(SymSet start_syms, int num_bits, bool is_top, bool refill,
               int depth, Sink* out);
  void AddMatchBody(TableBuilder* table_builder, std::string index,
                    std::string ofs, const MatchCase& match_case, bool is_top,
                    bool refill, int depth, Sink* out);
  void AddDone(SymSet start_syms, int num_bits, bool all_ones_so_far,
               Sink* out);

  // Hand out consecutive ids, starting at 1.
  int NewId() {
    const int id = next_id_;
    ++next_id_;
    return id;
  }

  // First entry of max_bits_for_depth (which must therefore be non-empty).
  int MaxBitsForTop() const { return max_bits_for_depth_.front(); }

  // If an artifact with this hash was registered before, return the name it
  // was registered under. Otherwise register proposed_name for the hash and
  // return nullopt.
  absl::optional<std::string> PreviousNameForArtifact(std::string proposed_name,
                                                      Hash hash) {
    const auto inserted = arrays_.emplace(hash, proposed_name);
    if (inserted.second) return absl::nullopt;
    return inserted.first->second;
  }

  Sink* global_fns() const { return global_fns_; }
  Sink* global_decls() const { return global_decls_; }
  Sink* global_values() const { return global_values_; }

 private:
  // Defined out of line.
  void AddDoneCase(size_t n, size_t n_bits, bool all_ones_so_far, SymSet syms,
                   std::vector<uint8_t> emit, TableBuilder* table_builder,
                   std::map<absl::optional<int>, int>* cases);

  const std::vector<int> max_bits_for_depth_;
  // Content hash -> name of the artifact first registered with that hash.
  std::map<Hash, std::string> arrays_;
  // Next value NewId() will return.
  int next_id_ = 1;
  Sink* const global_fns_;
  Sink* const global_decls_;
  Sink* const global_values_;
  FunMaker* const fun_maker_;
};
///////////////////////////////////////////////////////////////////////////////
// TableBuilder
// All our magic for building decode tables.
// We have two kinds of tables to generate:
// 1. op tables that translate a bit sequence to which decode case we should
// execute (and arguments to it), and
// 2. emit tables that translate an index given by the op table and tell us
// which symbols to emit
// Op table format
// Our opcodes contain an offset into an emit table, a number of bits consumed
// and an operation. The consumed bits are how many of the presented to us bits
// we actually took. The operation tells whether to emit some symbols (and how
// many) or to keep decoding.
// Optimization 1:
// op tables are essentially dense maps of bits -> opcode, and it turns out
// that *many* of the opcodes repeat across index bits for some of our tables
// so for those we split the table into two levels: first level indexes into
// a child table, and the child table contains the deduped opcodes.
// Optimization 2:
// Emit tables are a big list of uint8_ts, and are indexed into by the op
// table (with an offset and length) - since many symbols get repeated, we try
// to overlay the symbols in the emit table to reduce the size.
// Optimization 3:
// We shard the table into some number of slices and use the top bits of the
// incoming lookup to select the shard. This tends to allow us to use smaller
// types to represent the table, saving on footprint.
class TableBuilder {
public:
// Attach the builder to the shared build context and take a fresh id from it;
// the id is embedded in the names of the generated accessors.
explicit TableBuilder(BuildCtx* ctx)
    : ctx_(ctx),
      id_(ctx->NewId()) {}
// Append one case to the table
void Add(int match_case, std::vector<uint8_t> emit, int consumed_bits) {
elems_.push_back({match_case, std::move(emit), consumed_bits});
max_consumed_bits_ = std::max(max_consumed_bits_, consumed_bits);
max_match_case_ = std::max(max_match_case_, match_case);
}
// Build the table
// Pick an encoding with Choose() and let it generate the code. The second
// argument is the number of bits needed to index every added element.
void Build() const {
  const auto option = Choose();
  option->Build(this, BitsForMaxValue(elems_.size() - 1));
}
// Generate a call to the accessor function for the emit table
// Text of a call to this table's emit accessor:
// "GetEmit<id>(<index>, <offset>)".
std::string EmitAccessor(std::string index, std::string offset) {
  std::string call = absl::StrCat("GetEmit", id_);
  absl::StrAppend(&call, "(", index, ", ", offset, ")");
  return call;
}
// Generate a call to the accessor function for the op table
// Text of a call to this table's op accessor: "GetOp<id>(<index>)".
std::string OpAccessor(std::string index) {
  std::string call = absl::StrCat("GetOp", id_);
  absl::StrAppend(&call, "(", index, ")");
  return call;
}
// Width of the "bits consumed" field of a packed op.
int ConsumeBits() const {
  const int largest = max_consumed_bits_;
  return BitsForMaxValue(largest);
}
// Width of the "match case" field of a packed op.
int MatchBits() const {
  const int largest = max_match_case_;
  return BitsForMaxValue(largest);
}
private:
// One element in the op table.
// One entry added through Add(). Field order matters: Add() brace-initializes
// this struct positionally.
struct Elem {
// Which match case to execute for this entry
int match_case;
// Symbols to emit for this entry (may be empty)
std::vector<uint8_t> emit;
// How many of the presented bits this entry consumes
int consumed_bits;
};
// A nested slice is one slice of a table using two level lookup
// - i.e. we look at an outer table to get an index into the inner table,
// and then fetch the result from there.
// One slice of a table that uses two level lookup: outer maps an index to a
// position in inner, and inner holds the op values.
struct NestedSlice {
  std::vector<uint8_t> emit;
  std::vector<uint64_t> inner;
  std::vector<int> outer;

  // The sizes below are in bits of generated table data. InnerSize and
  // OuterSize look at the largest element and so require a non-empty vector.
  size_t InnerSize() const {
    const uint64_t largest = *std::max_element(inner.begin(), inner.end());
    return inner.size() * TypeBitsForMax(largest);
  }
  size_t OuterSize() const {
    const int largest = *std::max_element(outer.begin(), outer.end());
    return outer.size() * TypeBitsForMax(largest);
  }
  size_t EmitSize() const {
    const size_t bits_per_symbol = 8;
    return emit.size() * bits_per_symbol;
  }
};
// A slice is one part of a larger table.
// One part of a larger table: the packed ops plus the emit bytes they refer
// to.
struct Slice {
  std::vector<uint8_t> emit;
  std::vector<uint64_t> ops;

  // The sizes below are in bits of generated table data. OpsSize looks at
  // the largest op and so requires ops to be non-empty.
  size_t OpsSize() const {
    const uint64_t largest = *std::max_element(ops.begin(), ops.end());
    return ops.size() * TypeBitsForMax(largest);
  }
  size_t EmitSize() const {
    const size_t bits_per_symbol = 8;
    return emit.size() * bits_per_symbol;
  }

  // Return the offset in the emit table at which the symbols x can be read,
  // extending the table when necessary. In order of preference:
  //  1. x already occurs somewhere in the table;
  //  2. the end of the table matches the start of x: append only the rest;
  //  3. append all of x.
  // An empty x always maps to offset 0.
  int OffsetOf(const std::vector<uint8_t>& x) {
    if (x.empty()) return 0;
    const auto found =
        std::search(emit.begin(), emit.end(), x.begin(), x.end());
    if (found != emit.end()) return found - emit.begin();
    // Longest proper prefix of x that is also a suffix of the table (0 when
    // there is none).
    size_t overlap = std::min(x.size() - 1, emit.size());
    while (overlap > 0 &&
           !std::equal(x.begin(), x.begin() + overlap, emit.end() - overlap)) {
      --overlap;
    }
    const int offset = emit.size() - overlap;
    emit.insert(emit.end(), x.begin() + overlap, x.end());
    // Sanity check: x must now be readable at offset.
    if (!std::equal(x.begin(), x.end(), emit.begin() + offset)) abort();
    return offset;
  }

  // Convert this slice to a nested slice: inner receives each distinct op
  // once (in order of first appearance) and outer receives, per op, its
  // position in inner.
  NestedSlice MakeNestedSlice() const {
    NestedSlice nested;
    nested.emit = emit;
    std::map<uint64_t, int> position_of_op;
    for (const uint64_t op : ops) {
      const auto inserted =
          position_of_op.emplace(op, static_cast<int>(nested.inner.size()));
      if (inserted.second) nested.inner.push_back(op);
      nested.outer.push_back(inserted.first->second);
    }
    return nested;
  }
};
// An EncodeOption is a potential way of encoding a table.
// Interface of one candidate encoding of the table. Each candidate reports
// its size and can generate its code.
struct EncodeOption {
  // Overall size (in bits) of the table encoding.
  virtual size_t Size() const = 0;
  // Generate the code for builder; op_bits is the width of the op index.
  virtual void Build(const TableBuilder* builder, int op_bits) const = 0;
  virtual ~EncodeOption() = default;
};
// NestedTable is a table that uses two level lookup for each slice
// NestedTable is a table that uses two level lookup for each slice: the op
// index selects an entry of the slice's outer array, which in turn selects
// the op in the slice's inner array.
struct NestedTable : public EncodeOption {
  std::vector<NestedSlice> slices;
  // log2 of the number of slices; 0 means a single, unsliced table.
  int slice_bits;

  // Estimated size in bits of everything Build() generates. Arrays with
  // identical type and content are only counted once.
  size_t Size() const override {
    size_t sum = 0;
    std::vector<Hash> h_emit;
    std::vector<Hash> h_inner;
    std::vector<Hash> h_outer;
    for (size_t i = 0; i < slices.size(); i++) {
      h_emit.push_back(HashVec("uint8_t", slices[i].emit));
      h_inner.push_back(HashVec(TypeForMax(MaxInner()), slices[i].inner));
      h_outer.push_back(HashVec(TypeForMax(MaxOuter()), slices[i].outer));
    }
    std::set<Hash> seen;
    for (size_t i = 0; i < slices.size(); i++) {
      // Try to account for deduplication in the size calculation.
      if (seen.count(h_emit[i]) == 0) sum += slices[i].EmitSize();
      if (seen.count(h_outer[i]) == 0) sum += slices[i].OuterSize();
      // An unseen inner array costs its own size, not the outer array's.
      if (seen.count(h_inner[i]) == 0) sum += slices[i].InnerSize();
      seen.insert(h_emit[i]);
      seen.insert(h_outer[i]);
      seen.insert(h_inner[i]);
    }
    // A sliced table also pays for three compound arrays with one entry per
    // slice, counted at 64 bits per entry.
    if (slice_bits != 0) sum += 3 * 64 * slices.size();
    return sum;
  }

  // Generate the arrays for every slice plus the GetOp<id>/GetEmit<id>
  // accessor functions. op_bits is the width of the op index; its top
  // slice_bits bits select the slice.
  void Build(const TableBuilder* builder, int op_bits) const override {
    Sink* const global_fns = builder->ctx_->global_fns();
    Sink* const global_decls = builder->ctx_->global_decls();
    Sink* const global_values = builder->ctx_->global_values();
    const int id = builder->id_;
    const uint64_t max_inner = MaxInner();
    const uint64_t max_outer = MaxOuter();
    std::vector<std::unique_ptr<Array>> emit_names;
    std::vector<std::unique_ptr<Array>> inner_names;
    std::vector<std::unique_ptr<Array>> outer_names;
    for (size_t i = 0; i < slices.size(); i++) {
      emit_names.push_back(builder->GenArray(
          slice_bits != 0, absl::StrCat("table", id, "_", i, "_emit"),
          "uint8_t", slices[i].emit, true, global_decls, global_values));
      inner_names.push_back(builder->GenArray(
          slice_bits != 0, absl::StrCat("table", id, "_", i, "_inner"),
          TypeForMax(max_inner), slices[i].inner, true, global_decls,
          global_values));
      outer_names.push_back(builder->GenArray(
          slice_bits != 0, absl::StrCat("table", id, "_", i, "_outer"),
          TypeForMax(max_outer), slices[i].outer, false, global_decls,
          global_values));
    }
    if (slice_bits == 0) {
      // Single slice: index its arrays directly.
      global_fns->Add(absl::StrCat(
          "static inline uint64_t GetOp", id, "(size_t i) { return ",
          inner_names[0]->Index(outer_names[0]->Index("i")), "; }"));
      global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
                                   "(size_t, size_t emit) { return ",
                                   emit_names[0]->Index("emit"), "; }"));
    } else {
      // Several slices: go through compound arrays indexed by the top bits
      // of i, and use the low bits to index within the slice.
      GenCompound(id, emit_names, "emit", "uint8_t", global_decls,
                  global_values);
      GenCompound(id, inner_names, "inner", TypeForMax(max_inner),
                  global_decls, global_values);
      GenCompound(id, outer_names, "outer", TypeForMax(max_outer),
                  global_decls, global_values);
      global_fns->Add(absl::StrCat(
          "static inline uint64_t GetOp", id, "(size_t i) { return table", id,
          "_inner_[i >> ", op_bits - slice_bits, "][table", id,
          "_outer_[i >> ", op_bits - slice_bits, "][i & 0x",
          absl::Hex((1 << (op_bits - slice_bits)) - 1), "]]; }"));
      global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
                                   "(size_t i, size_t emit) { return table",
                                   id, "_emit_[i >> ", op_bits - slice_bits,
                                   "][emit]; }"));
    }
  }

  // Largest value in any slice's inner array. Cached after the first call; a
  // cached value of 0 is treated as "not yet computed".
  uint64_t MaxInner() const {
    if (max_inner == 0) {
      for (size_t i = 0; i < slices.size(); i++) {
        max_inner =
            std::max(max_inner, *std::max_element(slices[i].inner.begin(),
                                                  slices[i].inner.end()));
      }
    }
    return max_inner;
  }

  // Largest value in any slice's outer array, cached like MaxInner().
  int MaxOuter() const {
    if (max_outer == 0) {
      for (size_t i = 0; i < slices.size(); i++) {
        max_outer =
            std::max(max_outer, *std::max_element(slices[i].outer.begin(),
                                                  slices[i].outer.end()));
      }
    }
    return max_outer;
  }

  // Caches for MaxInner() / MaxOuter().
  mutable uint64_t max_inner = 0;
  mutable int max_outer = 0;
};
// Encoding that uses single level lookup for each slice.
struct Table : public EncodeOption {
std::vector<Slice> slices;
int slice_bits;
size_t Size() const override {
size_t sum = 0;
std::vector<Hash> h_emit;
std::vector<Hash> h_ops;
for (size_t i = 0; i < slices.size(); i++) {
h_emit.push_back(HashVec("uint8_t", slices[i].emit));
h_ops.push_back(HashVec(TypeForMax(MaxOp()), slices[i].ops));
}
std::set<Hash> seen;
for (size_t i = 0; i < slices.size(); i++) {
if (seen.count(h_emit[i]) == 0) sum += slices[i].EmitSize();
if (seen.count(h_ops[i]) == 0) sum += slices[i].OpsSize();
seen.insert(h_emit[i]);
seen.insert(h_ops[i]);
}
return sum + 3 * 64 * slices.size();
}
void Build(const TableBuilder* builder, int op_bits) const override {
Sink* const global_fns = builder->ctx_->global_fns();
Sink* const global_decls = builder->ctx_->global_decls();
Sink* const global_values = builder->ctx_->global_values();
uint64_t max_op = MaxOp();
const int id = builder->id_;
std::vector<std::unique_ptr<Array>> emit_names;
std::vector<std::unique_ptr<Array>> ops_names;
for (size_t i = 0; i < slices.size(); i++) {
emit_names.push_back(builder->GenArray(
slice_bits != 0, absl::StrCat("table", id, "_", i, "_emit"),
"uint8_t", slices[i].emit, true, global_decls, global_values));
ops_names.push_back(builder->GenArray(
slice_bits != 0, absl::StrCat("table", id, "_", i, "_ops"),
TypeForMax(max_op), slices[i].ops, true, global_decls,
global_values));
}
if (slice_bits == 0) {
global_fns->Add(absl::StrCat("static inline uint64_t GetOp", id,
"(size_t i) { return ",
ops_names[0]->Index("i"), "; }"));
global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
"(size_t, size_t emit) { return ",
emit_names[0]->Index("emit"), "; }"));
} else {
GenCompound(id, emit_names, "emit", "uint8_t", global_decls,
global_values);
GenCompound(id, ops_names, "ops", TypeForMax(max_op), global_decls,
global_values);
global_fns->Add(absl::StrCat(
"static inline uint64_t GetOp", id, "(size_t i) { return table", id,
"_ops_[i >> ", op_bits - slice_bits, "][i & 0x",
absl::Hex((1 << (op_bits - slice_bits)) - 1), "]; }"));
global_fns->Add(absl::StrCat("static inline uint64_t GetEmit", id,
"(size_t i, size_t emit) { return table",
id, "_emit_[i >> ", op_bits - slice_bits,
"][emit]; }"));
}
}
uint64_t MaxOp() const {
if (max_op == 0) {
for (size_t i = 0; i < slices.size(); i++) {
max_op = std::max(max_op, *std::max_element(slices[i].ops.begin(),
slices[i].ops.end()));
}
}
return max_op;
}
mutable uint64_t max_op = 0;
// Convert to a two-level lookup
std::unique_ptr<NestedTable> MakeNestedTable() {
std::unique_ptr<NestedTable> result(new NestedTable);
result->slice_bits = slice_bits;
for (const auto& slice : slices) {
result->slices.push_back(slice.MakeNestedSlice());
}
return result;
}
};
// Given a number of slices (2**slice_bits), generate a table that uses a
// single level lookup for each slice based on our input.
// Split the added elements into 2**slice_bits consecutive, equally sized
// slices and build a single level Table from them.
// Each op packs, from the low bits up: the consumed bit count (ConsumeBits()
// wide), the match case (MatchBits() wide), then the element's offset in the
// slice's emit table.
std::unique_ptr<Table> MakeTable(size_t slice_bits) const {
  std::unique_ptr<Table> table = std::make_unique<Table>();
  // Slice counts are kept as size_t so they compare cleanly with the loop
  // indices and with elems_.size().
  const size_t num_slices = size_t{1} << slice_bits;
  table->slices.resize(num_slices);
  table->slice_bits = static_cast<int>(slice_bits);
  const int pack_consume_bits = ConsumeBits();
  const int pack_match_bits = MatchBits();
  const size_t elems_per_slice = elems_.size() / num_slices;
  for (size_t i = 0; i < num_slices; i++) {
    auto& slice = table->slices[i];
    slice.ops.reserve(elems_per_slice);
    for (size_t j = 0; j < elems_per_slice; j++) {
      const auto& elem = elems_[i * elems_.size() / num_slices + j];
      // Widen every field before shifting so the packing is done in 64 bits
      // instead of overflowing int for large offsets.
      const uint64_t consumed = static_cast<uint64_t>(elem.consumed_bits);
      const uint64_t match = static_cast<uint64_t>(elem.match_case);
      const uint64_t offset = static_cast<uint64_t>(slice.OffsetOf(elem.emit));
      slice.ops.push_back(consumed | (match << pack_consume_bits) |
                          (offset << (pack_consume_bits + pack_match_bits)));
    }
  }
  return table;
}
// Interface for something that can be indexed in generated code. Concrete
// implementations either name a real array or compute the element from the
// index expression.
class Array {
public:
virtual ~Array() = default;
// Text of an expression yielding the element at index expression `value`.
virtual std::string Index(absl::string_view value) = 0;
// Name of the backing array; implementations without one abort.
virtual std::string ArrayName() = 0;
// Relative cost figure for this representation (composite arrays add up the
// costs of their parts); NOTE(review): the unit is not visible in this part
// of the file.
virtual int Cost() = 0;
};
// An Array backed by a real, named array in the generated code.
class NamedArray : public Array {
 public:
  explicit NamedArray(std::string name) : name_(std::move(name)) {}

  // "<name>[<value>]"
  std::string Index(absl::string_view value) override {
    std::string expr = name_;
    absl::StrAppend(&expr, "[", value, "]");
    return expr;
  }

  std::string ArrayName() override { return name_; }

  // No cost is defined for a named array; asking for it aborts.
  int Cost() override { abort(); }

 private:
  std::string name_;
};
// An Array whose element equals its index: indexing yields the index
// expression itself.
class IdentityArray : public Array {
 public:
  std::string Index(absl::string_view value) override {
    std::string expr(value);
    return expr;
  }

  // There is no backing array.
  std::string ArrayName() override { abort(); }

  int Cost() override {
    return 0;
  }
};
// An Array in which every element has the same value.
class ConstantArray : public Array {
 public:
  explicit ConstantArray(std::string value) : value_(std::move(value)) {}

  // "((void)<index>, <value>)": the index expression is still evaluated but
  // its result is discarded.
  std::string Index(absl::string_view index) override {
    std::string expr = absl::StrCat("((void)", index);
    absl::StrAppend(&expr, ", ", value_, ")");
    return expr;
  }

  // There is no backing array.
  std::string ArrayName() override { abort(); }

  int Cost() override {
    return 0;
  }

 private:
  // Text of the constant.
  std::string value_;
};
// An Array whose element is the index plus a fixed offset.
class OffsetArray : public Array {
 public:
  explicit OffsetArray(int offset) : offset_(offset) {}

  // "<value> + <offset>"
  // NOTE(review): neither <value> nor the result is parenthesized here, so
  // correct precedence depends on what is passed in and on how the caller
  // embeds the result.
  std::string Index(absl::string_view value) override {
    std::string expr(value);
    absl::StrAppend(&expr, " + ", offset_);
    return expr;
  }

  // There is no backing array.
  std::string ArrayName() override { abort(); }

  int Cost() override {
    return 10;
  }

 private:
  int offset_;
};
// An Array whose element is the index divided by a constant, plus an offset.
class LinearDivideArray : public Array {
 public:
  LinearDivideArray(int offset, int divisor)
      : offset_(offset), divisor_(divisor) {}

  // "<value>/<divisor> + <offset>"
  // NOTE(review): neither <value> nor the result is parenthesized here, so
  // correct precedence depends on what is passed in and on how the caller
  // embeds the result.
  std::string Index(absl::string_view value) override {
    std::string expr(value);
    absl::StrAppend(&expr, "/", divisor_, " + ", offset_);
    return expr;
  }

  // There is no backing array.
  std::string ArrayName() override { abort(); }

  // 20 for the division, plus 10 when an offset actually has to be added.
  int Cost() override {
    int cost = 20;
    if (offset_ != 0) cost += 10;
    return cost;
  }

 private:
  int offset_;
  int divisor_;
};
// An Array with exactly two elements, selected with a conditional expression.
class TwoElemArray : public Array {
 public:
  // value0 is the element for index 0, value1 the element for any non-zero
  // index.
  TwoElemArray(std::string value0, std::string value1)
      : value0_(std::move(value0)), value1_(std::move(value1)) {}

  // "<value> ? <value1> : <value0>"
  // NOTE(review): neither <value> nor the result is parenthesized here, so
  // correct precedence depends on what is passed in and on how the caller
  // embeds the result.
  std::string Index(absl::string_view value) override {
    std::string expr(value);
    absl::StrAppend(&expr, " ? ", value1_, " : ", value0_);
    return expr;
  }

  // There is no backing array.
  std::string ArrayName() override { abort(); }

  int Cost() override {
    return 40;
  }

 private:
  std::string value0_;
  std::string value1_;
};
class Composite2Array : public Array {
public:
Composite2Array(std::unique_ptr<Array> a, std::unique_ptr<Array> b,
int split)
: a_(std::move(a)), b_(std::move(b)), split_(split) {}
std::string Index(absl::string_view value) override {
return absl::StrCat(
"(", value, " < ", split_, " ? (", a_->Index(value), ") : (",
b_->Index(absl::StrCat("(", value, "-", split_, ")")), "))");
}
std::string ArrayName() override { abort(); }
int Cost() override { return 40 + a_->Cost() + b_->Cost(); }