/
typeinfer.py
1805 lines (1552 loc) · 70.5 KB
/
typeinfer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Type inference based on CPA (the Cartesian Product Algorithm).
The algorithm guarantees monotonic growth of type-sets for each variable.
Steps:
1. seed initial types
2. build constraints
3. propagate constraints
4. unify types
Constraint propagation is precise and never retracts a decision
(no backtracking).
Constraints push types forward following the dataflow.
"""
import logging
import operator
import contextlib
import itertools
from pprint import pprint
from collections import OrderedDict, defaultdict
from functools import reduce
from numba.core import types, utils, typing, ir, config
from numba.core.typing.templates import Signature
from numba.core.errors import (TypingError, UntypedAttributeError,
new_error_context, termcolor, UnsupportedError,
ForceLiteralArg, CompilerError, NumbaValueError)
from numba.core.funcdesc import qualifying_prefix
from numba.core.typeconv import Conversion
# Module-level logger; errors captured during constraint propagation are
# logged here at debug level.
_logger = logging.getLogger(__name__)
class NOTSET:
    """Sentinel meaning "no literal value recorded" (distinct from None)."""
# Terminal color markup helper (``numba.core.errors.termcolor``),
# instantiated once at import time.
_termcolor = termcolor()
class TypeVar(object):
    """Type variable holding the inferred type of a single IR variable.

    The type only grows: ``add_type`` unifies each incoming type with the
    current one.  Once ``lock``-ed, the type is fixed and incoming types only
    need to be convertible to it.
    """

    def __init__(self, context, var):
        self.context = context
        self.var = var
        self.type = None
        self.locked = False
        # Stores source location of first definition
        self.define_loc = None
        # Qualifiers
        self.literal_value = NOTSET

    def add_type(self, tp, loc):
        """Merge *tp* into this variable's type and return the new type.

        Raises TypingError if the variable is locked and *tp* cannot be
        converted to the locked type, or if *tp* cannot be unified with the
        current type.
        """
        assert isinstance(tp, types.Type), type(tp)
        # Special case for _undef_var.
        # If the typevar is the _undef_var, use the incoming type directly.
        if self.type is types._undef_var:
            self.type = tp
            return self.type

        if self.locked:
            # A locked type never changes; only check convertibility.
            if tp != self.type:
                if self.context.can_convert(tp, self.type) is None:
                    msg = ("No conversion from %s to %s for '%s', "
                           "defined at %s")
                    raise TypingError(msg % (tp, self.type, self.var,
                                             self.define_loc),
                                      loc=loc)
        else:
            if self.type is not None:
                unified = self.context.unify_pairs(self.type, tp)
                if unified is None:
                    msg = "Cannot unify %s and %s for '%s', defined at %s"
                    raise TypingError(msg % (self.type, tp, self.var,
                                             self.define_loc),
                                      loc=self.define_loc)
            else:
                # First time definition
                unified = tp
                self.define_loc = loc

            self.type = unified

        return self.type

    def lock(self, tp, loc, literal_value=NOTSET):
        """Fix this variable's type to *tp* (e.g. from a user signature or
        an ``ir.Const``).

        Raises CompilerError if already locked, and TypingError if an
        existing type cannot be converted to *tp*.
        """
        assert isinstance(tp, types.Type), type(tp)

        if self.locked:
            msg = ("Invalid reassignment of a type-variable detected, type "
                   "variables are locked according to the user provided "
                   "function signature or from an ir.Const node. This is a "
                   "bug! Type={}. {}").format(tp, self.type)
            raise CompilerError(msg, loc)

        # If there is already a type, ensure we can convert it to the
        # locked type.
        if (self.type is not None and
                self.context.can_convert(self.type, tp) is None):
            # The check converts *from* the existing type *to* ``tp``; report
            # the operands in that same order.
            raise TypingError("No conversion from %s to %s for "
                              "'%s'" % (self.type, tp, self.var), loc=loc)

        self.type = tp
        self.locked = True
        if self.define_loc is None:
            self.define_loc = loc
        self.literal_value = literal_value

    def union(self, other, loc):
        """Merge the type of another TypeVar (if it has one) into this one."""
        if other.type is not None:
            self.add_type(other.type, loc=loc)

        return self.type

    def __repr__(self):
        return '%s := %s' % (self.var, self.type or "<undecided>")

    @property
    def defined(self):
        """True once a type has been assigned."""
        return self.type is not None

    def get(self):
        """Return the type as a 0- or 1-element tuple."""
        return (self.type,) if self.type is not None else ()

    def getone(self):
        """Return the type; raise TypingError if it is still undecided."""
        if self.type is None:
            raise TypingError("Undecided type {}".format(self))
        return self.type

    def __len__(self):
        return 1 if self.type is not None else 0
class ConstraintNetwork(object):
    """
    TODO: It is possible to optimize constraint propagation to consider only
    dirty type variables.
    """

    def __init__(self):
        self.constraints = []

    def append(self, constraint):
        """Add a callable constraint; it must expose a ``loc`` attribute."""
        self.constraints.append(constraint)

    def propagate(self, typeinfer):
        """
        Execute all constraints.  Errors are caught and returned as a list.
        This allows progressing even though some constraints may fail
        due to lack of information
        (e.g. imprecise types such as List(undefined)).

        ForceLiteralArg and TypingError instances are collected.  Other
        exceptions are collected (old-style errors) or re-raised (new-style
        errors), depending on the CAPTURED_ERRORS configuration.
        """
        errors = []
        for constraint in self.constraints:
            loc = constraint.loc
            with typeinfer.warnings.catch_warnings(filename=loc.filename,
                                                   lineno=loc.line):
                try:
                    constraint(typeinfer)
                except ForceLiteralArg as e:
                    errors.append(e)
                except TypingError as e:
                    _logger.debug("captured error", exc_info=e)
                    new_exc = TypingError(
                        str(e), loc=constraint.loc,
                        highlighting=False,
                    )
                    errors.append(utils.chain_exception(new_exc, e))
                except Exception as e:
                    if utils.use_old_style_errors():
                        _logger.debug("captured error", exc_info=e)
                        msg = ("Internal error at {con}.\n{err}\n"
                               "Enable logging at debug level for details.")
                        new_exc = TypingError(
                            msg.format(con=constraint, err=str(e)),
                            loc=constraint.loc,
                            highlighting=False,
                        )
                        errors.append(utils.chain_exception(new_exc, e))
                    elif utils.use_new_style_errors():
                        raise
                    else:
                        msg = ("Unknown CAPTURED_ERRORS style: "
                               f"'{config.CAPTURED_ERRORS}'.")
                        # Explicit raise: an ``assert`` statement would be
                        # stripped under ``python -O`` and silently drop the
                        # original exception.
                        raise AssertionError(msg) from e
        return errors
class Propagate(object):
    """Assignment constraint: copy the type of ``src`` into ``dst``.

    The constraint registers itself in ``typeinfer.refine_map`` under
    ``dst`` so that ``refine`` can push a refined type back into ``src``.
    """

    def __init__(self, dst, src, loc):
        self.dst, self.src, self.loc = dst, src, loc

    def __call__(self, typeinfer):
        context_msg = "typing of assignment at {0}"
        with new_error_context(context_msg, self.loc, loc=self.loc):
            typeinfer.copy_type(self.src, self.dst, loc=self.loc)
            # Ask to be notified (through ``refine``) if ``dst`` is refined.
            typeinfer.refine_map[self.dst] = self

    def refine(self, typeinfer, target_type):
        """Back-propagate a precise *target_type* into ``src``."""
        assert target_type.is_precise()
        # unless_locked: never back-propagate into locked variables
        # (e.g. constants).
        typeinfer.add_type(self.src, target_type, unless_locked=True,
                           loc=self.loc)
class ArgConstraint(object):
    """Propagate the type of a function argument into its IR variable."""

    def __init__(self, dst, src, loc):
        self.dst = dst
        self.src = src
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of argument at {0}", self.loc):
            source = typeinfer.typevars[self.src]
            if source.defined:
                argty = source.getone()
                if isinstance(argty, types.Omitted):
                    # Omitted arguments are typed from their default value.
                    ctx = typeinfer.context
                    argty = ctx.resolve_value_type_prefer_literal(argty.value)
                if not argty.is_precise():
                    raise TypingError('non-precise type {}'.format(argty))
                typeinfer.add_type(self.dst, argty, loc=self.loc)
class BuildTupleConstraint(object):
    """Type a tuple display from every combination of its item types."""

    def __init__(self, target, items, loc):
        self.target = target
        self.items = items
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of tuple at {0}", self.loc):
            typevars = typeinfer.typevars
            tsets = [typevars[item.name].get() for item in self.items]
            for combo in itertools.product(*tsets):
                uniform = bool(combo) and all(combo[0] == t for t in combo)
                if uniform:
                    tup = types.UniTuple(dtype=combo[0], count=len(combo))
                else:
                    # heterogeneous items -- and the empty tuple -- land here
                    tup = types.Tuple(combo)
                assert tup.is_precise()
                typeinfer.add_type(self.target, tup, loc=self.loc)
class _BuildContainerConstraint(object):
def __init__(self, target, items, loc):
self.target = target
self.items = items
self.loc = loc
def __call__(self, typeinfer):
with new_error_context("typing of {0} at {1}",
self.container_type, self.loc):
typevars = typeinfer.typevars
tsets = [typevars[i.name].get() for i in self.items]
if not tsets:
typeinfer.add_type(self.target,
self.container_type(types.undefined),
loc=self.loc)
else:
for typs in itertools.product(*tsets):
unified = typeinfer.context.unify_types(*typs)
if unified is not None:
typeinfer.add_type(self.target,
self.container_type(unified),
loc=self.loc)
class BuildListConstraint(_BuildContainerConstraint):
    """Type a list display.

    Item types that unify give ``List(unified)``, carrying the literal
    values as ``initial_value`` when every item is a literal.  Item types
    that do not unify give a ``LiteralList`` of the individual types.
    """
    # ``__init__`` (target, items, loc) is inherited unchanged from
    # _BuildContainerConstraint.

    def __call__(self, typeinfer):
        with new_error_context("typing of {0} at {1}",
                               types.List, self.loc):
            typevars = typeinfer.typevars
            tsets = [typevars[item.name].get() for item in self.items]
            if not tsets:
                typeinfer.add_type(self.target,
                                   types.List(types.undefined),
                                   loc=self.loc)
                return
            for typs in itertools.product(*tsets):
                unified = typeinfer.context.unify_types(*typs)
                if unified is None:
                    typeinfer.add_type(self.target,
                                       types.LiteralList(typs),
                                       loc=self.loc)
                    continue
                # pull out literals if available
                if all(isinstance(t, types.Literal) for t in typs):
                    initial = [t.literal_value for t in typs]
                else:
                    initial = None
                typeinfer.add_type(self.target,
                                   types.List(unified, initial_value=initial),
                                   loc=self.loc)
class BuildSetConstraint(_BuildContainerConstraint):
    """Type a set display as ``types.Set`` of the unified item type."""
    container_type = types.Set
class BuildMapConstraint(object):
    """Type a dict display.

    Produces either a ``DictType`` whose key/value types come from the first
    (key, value) pair, or a ``LiteralStrKeyDict`` when every key is a string
    literal and the values are not homogeneous.
    """

    def __init__(self, target, items, special_value, value_indexes, loc):
        self.target = target
        self.items = items
        # Forwarded to the DictType as its initial value (see below).
        self.special_value = special_value
        # Forwarded to LiteralStrKeyDict.
        self.value_indexes = value_indexes
        self.loc = loc

    def __call__(self, typeinfer):

        with new_error_context("typing of dict at {0}", self.loc):
            typevars = typeinfer.typevars

            # figure out what sort of dict is being dealt with
            # (getone() raises TypingError if any key/value is undecided)
            tsets = [(typevars[k.name].getone(), typevars[v.name].getone())
                     for k, v in self.items]

            if not tsets:
                # Empty display: key and value types are still undefined.
                typeinfer.add_type(self.target,
                                   types.DictType(types.undefined,
                                                  types.undefined,
                                                  self.special_value),
                                   loc=self.loc)
            else:
                # All key/value types are known.  If every key is a string
                # literal and the values are heterogeneous, treat the
                # display as a LiteralStrKeyDict.
                ktys = [x[0] for x in tsets]
                vtys = [x[1] for x in tsets]
                strkey = all([isinstance(x, types.StringLiteral) for x in ktys])
                literalvty = all([isinstance(x, types.Literal) for x in vtys])
                vt0 = types.unliteral(vtys[0])

                # Values are homogeneous when every value can be cast
                # (better than an unsafe conversion) to the type of the
                # first value.  The order is important as `typed.Dict` takes
                # its type from the first element.
                def check(other):
                    conv = typeinfer.context.can_convert(other, vt0)
                    return conv is not None and conv < Conversion.unsafe

                homogeneous = all([check(types.unliteral(x)) for x in vtys])

                # Special cases:
                # Single key:value in ctor, key is str, value is an otherwise
                # illegal container type, e.g. LiteralStrKeyDict or
                # List, there's no way to put this into a typed.Dict, so make it
                # a LiteralStrKeyDict, same goes for LiteralList.
                if len(vtys) == 1:
                    valty = vtys[0]
                    if isinstance(valty, (types.LiteralStrKeyDict,
                                          types.List,
                                          types.LiteralList)):
                        homogeneous = False

                if strkey and not homogeneous:
                    resolved_dict = {x: y for x, y in zip(ktys, vtys)}
                    ty = types.LiteralStrKeyDict(resolved_dict,
                                                 self.value_indexes)
                    typeinfer.add_type(self.target, ty, loc=self.loc)
                else:
                    # The initial value is only kept when all values are
                    # literals.
                    init_value = self.special_value if literalvty else None
                    key_type, value_type = tsets[0]
                    typeinfer.add_type(self.target,
                                       types.DictType(key_type,
                                                      value_type,
                                                      init_value),
                                       loc=self.loc)
class ExhaustIterConstraint(object):
    """Type the tuple obtained by exhausting an iterable into ``count`` items.

    A tuple (optionally wrapped in ``Optional``) must have exactly ``count``
    elements.  Any other iterable gives ``UniTuple(yield_type, count)``.
    """

    def __init__(self, target, count, iterator, loc):
        self.target = target
        self.count = count
        self.iterator = iterator
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of exhaust iter at {0}", self.loc):
            typevars = typeinfer.typevars
            for tp in typevars[self.iterator.name].get():
                # unpack optional
                tp = tp.type if isinstance(tp, types.Optional) else tp
                if isinstance(tp, types.BaseTuple):
                    if len(tp) == self.count:
                        assert tp.is_precise()
                        typeinfer.add_type(self.target, tp, loc=self.loc)
                        break
                    else:
                        # Adjacent f-strings are concatenated implicitly; a
                        # separating comma would turn ``msg`` into a tuple.
                        msg = (f"wrong tuple length for {self.iterator.name}: "
                               f"expected {self.count}, got {len(tp)}")
                        raise NumbaValueError(msg)
                elif isinstance(tp, types.IterableType):
                    tup = types.UniTuple(dtype=tp.iterator_type.yield_type,
                                         count=self.count)
                    assert tup.is_precise()
                    typeinfer.add_type(self.target, tup, loc=self.loc)
                    break
                else:
                    raise TypingError("failed to unpack {}".format(tp),
                                      loc=self.loc)
class PairFirstConstraint(object):
    """Propagate the first element type of a ``types.Pair``."""

    def __init__(self, target, pair, loc):
        self.target = target
        self.pair = pair
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of pair-first at {0}", self.loc):
            candidates = typeinfer.typevars[self.pair.name].get()
            for pairty in candidates:
                # Non-Pair types are skipped.  XXX is this an error?
                if isinstance(pairty, types.Pair):
                    first = pairty.first_type
                    assert (isinstance(first, types.UndefinedFunctionType)
                            or first.is_precise())
                    typeinfer.add_type(self.target, first, loc=self.loc)
class PairSecondConstraint(object):
    """Propagate the second element type of a ``types.Pair``."""

    def __init__(self, target, pair, loc):
        self.target = target
        self.pair = pair
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of pair-second at {0}", self.loc):
            candidates = typeinfer.typevars[self.pair.name].get()
            for pairty in candidates:
                # Non-Pair types are skipped.  XXX is this an error?
                if isinstance(pairty, types.Pair):
                    second = pairty.second_type
                    assert second.is_precise()
                    typeinfer.add_type(self.target, second, loc=self.loc)
class StaticGetItemConstraint(object):
    """Type ``value[index]`` where ``index`` is a compile-time constant.

    When static resolution fails and an index variable is available, the
    constraint falls back to a regular ``operator.getitem`` call.
    """

    def __init__(self, target, value, index, index_var, loc):
        self.target = target
        self.value = value
        self.index = index
        self.fallback = (None if index_var is None else
                         IntrinsicCallConstraint(target, operator.getitem,
                                                 (value, index_var), {},
                                                 None, loc))
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of static-get-item at {0}", self.loc):
            for valty in typeinfer.typevars[self.value.name].get():
                sig = typeinfer.context.resolve_static_getitem(
                    value=valty, index=self.index,
                )
                if sig is not None:
                    # Imprecise item types are let through on purpose:
                    # unification catches them with a better error message.
                    typeinfer.add_type(self.target, sig.return_type,
                                       loc=self.loc)
                elif self.fallback is not None:
                    self.fallback(typeinfer)

    def get_call_signature(self):
        # The signature is only needed for the fallback case in lowering
        if self.fallback:
            return self.fallback.get_call_signature()
        return self.fallback
class TypedGetItemConstraint(object):
    """getitem whose result type (``dtype``) is already known."""

    def __init__(self, target, value, dtype, index, loc):
        self.target = target
        self.value = value
        self.dtype = dtype
        self.index = index
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of typed-get-item at {0}", self.loc):
            typevars = typeinfer.typevars
            index_types = typevars[self.index.name].get()
            value_types = typevars[self.value.name].get()
            # Signature argument order: value type(s) first, then index.
            self.signature = Signature(self.dtype, value_types + index_types,
                                       None)
            typeinfer.add_type(self.target, self.dtype, loc=self.loc)

    def get_call_signature(self):
        return self.signature
def fold_arg_vars(typevars, args, vararg, kws):
    """
    Fold and resolve the argument variables of a function call.

    Returns ``(pos_args, kw_args)`` -- a tuple of positional argument types
    and a dict of keyword argument types -- or None while any argument type
    is still undefined.  The types of a ``*args`` tuple are appended to
    ``pos_args``.  Raises TypeError if ``*args`` is not typed as a tuple.
    """
    n_pos_args = len(args)
    kwds = [kw for kw, _ in kws]
    argvars = ([typevars[a.name] for a in args]
               + [typevars[var.name] for _, var in kws])
    if vararg is not None:
        argvars.append(typevars[vararg.name])
    # Bail out if any argument type is still unknown
    if any(not tv.defined for tv in argvars):
        return None

    argtys = tuple(tv.getone() for tv in argvars)
    pos_args = argtys[:n_pos_args]

    if vararg is not None:
        star_ty = argtys[-1]
        errmsg = "*args in function call should be a tuple, got %s"
        if isinstance(star_ty, types.Literal):
            # Constant literal used for `*args`: it must hold a tuple
            const_val = star_ty.literal_value
            if not isinstance(const_val, tuple):
                raise TypeError(errmsg % (star_ty,))
            pos_args += const_val
        elif isinstance(star_ty, types.BaseTuple):
            pos_args += star_ty.types
        else:
            # Unsuitable for *args
            # (Python is more lenient and accepts all iterables)
            raise TypeError(errmsg % (star_ty,))
        # Drop the trailing *args entry before pairing up keywords
        argtys = argtys[:-1]

    kw_args = dict(zip(kwds, argtys[n_pos_args:]))
    return pos_args, kw_args
def _is_array_not_precise(arrty):
    """Return True when *arrty* is a ``types.Array`` that is not precise
    (for example, one whose dtype is still undefined)."""
    if not isinstance(arrty, types.Array):
        return False
    return not arrty.is_precise()
class CallConstraint(object):
    """Constraint for calling functions.
    Perform case analysis for each combination of argument types.

    ``func`` names the variable holding the callee; ``signature`` holds the
    resolved call signature once typing succeeds.
    """
    signature = None

    def __init__(self, target, func, args, kws, vararg, loc):
        self.target = target
        self.func = func
        self.args = args
        self.kws = kws or {}
        self.vararg = vararg
        self.loc = loc

    def __call__(self, typeinfer):
        msg = "typing of call at {0}\n".format(self.loc)
        with new_error_context(msg):
            typevars = typeinfer.typevars
            with new_error_context(
                    "resolving caller type: {}".format(self.func)):
                fnty = typevars[self.func].getone()
            with new_error_context("resolving callee type: {0}", fnty):
                self.resolve(typeinfer, typevars, fnty)

    def resolve(self, typeinfer, typevars, fnty):
        """Resolve the call for callee type *fnty* and type the target.

        Returns early (without typing) while argument types are unknown or
        imprecise.  Raises TypingError if no signature matches, and
        ForceLiteralArg when the callee requests literal arguments that map
        to function arguments.
        """
        assert fnty
        context = typeinfer.context

        r = fold_arg_vars(typevars, self.args, self.vararg, self.kws)
        if r is None:
            # Cannot resolve call type until all argument types are known
            return
        pos_args, kw_args = r
        # Check argument to be precise
        for a in itertools.chain(pos_args, kw_args.values()):
            # Forbids imprecise type except array of undefined dtype
            if not a.is_precise() and (not isinstance(a, types.Array)):
                return

        # Resolve call type
        if isinstance(fnty, types.TypeRef):
            # Unwrap TypeRef
            fnty = fnty.instance_type
        try:
            sig = typeinfer.resolve_call(fnty, pos_args, kw_args)
        except ForceLiteralArg as e:
            # Adjust for bound methods
            folding_args = ((fnty.this,) + tuple(self.args)
                            if isinstance(fnty, types.BoundFunction)
                            else self.args)
            folded = e.fold_arguments(folding_args, self.kws)
            requested = set()
            unsatisfied = set()
            for idx in e.requested_args:
                maybe_arg = typeinfer.func_ir.get_definition(folded[idx])
                if isinstance(maybe_arg, ir.Arg):
                    requested.add(maybe_arg.index)
                else:
                    unsatisfied.add(idx)
            # ``not requested`` also covers an empty request list; without
            # it, execution would fall through with ``sig`` unbound.
            if unsatisfied or not requested:
                raise TypingError("Cannot request literal type.", loc=self.loc)
            raise ForceLiteralArg(requested, loc=self.loc)

        if sig is None:
            # Note: duplicated error checking.
            #       See types.BaseFunction.get_call_type
            # Arguments are invalid => explain why
            headtemp = "Invalid use of {0} with parameters ({1})"
            args = [str(a) for a in pos_args]
            args += ["%s=%s" % (k, v) for k, v in sorted(kw_args.items())]
            head = headtemp.format(fnty, ', '.join(map(str, args)))
            desc = context.explain_function_type(fnty)
            msg = '\n'.join([head, desc])
            raise TypingError(msg)

        typeinfer.add_type(self.target, sig.return_type, loc=self.loc)

        # If the function is a bound function and its receiver type
        # was refined, propagate it.
        if (isinstance(fnty, types.BoundFunction)
                and sig.recvr is not None
                and sig.recvr != fnty.this):
            refined_this = context.unify_pairs(sig.recvr, fnty.this)
            if (refined_this is None and
                    fnty.this.is_precise() and
                    sig.recvr.is_precise()):
                msg = "Cannot refine type {} to {}".format(
                    sig.recvr, fnty.this,
                )
                raise TypingError(msg, loc=self.loc)
            if refined_this is not None and refined_this.is_precise():
                refined_fnty = fnty.copy(this=refined_this)
                typeinfer.propagate_refined_type(self.func, refined_fnty)

        # If the return type is imprecise but can be unified with the
        # target variable's inferred type, use the latter.
        # Useful for code such as::
        #    s = set()
        #    s.add(1)
        # (the set() call must be typed as int64(), not undefined())
        if not sig.return_type.is_precise():
            target = typevars[self.target]
            if target.defined:
                targetty = target.getone()
                if context.unify_pairs(targetty, sig.return_type) == targetty:
                    sig = sig.replace(return_type=targetty)

        self.signature = sig
        self._add_refine_map(typeinfer, typevars, sig)

    def _add_refine_map(self, typeinfer, typevars, sig):
        """Register this constraint in ``typeinfer.refine_map`` when the
        target type may still be refined: an Array whose returned dtype is
        undefined, or an imprecise DictType.
        """
        target_type = typevars[self.target].getone()
        # Array
        if (isinstance(target_type, types.Array)
                and isinstance(sig.return_type.dtype, types.Undefined)):
            typeinfer.refine_map[self.target] = self
        # DictType
        if (isinstance(target_type, types.DictType) and
                not target_type.is_precise()):
            typeinfer.refine_map[self.target] = self

    def refine(self, typeinfer, updated_type):
        """Refinement hook; only ``operator.getitem`` on an imprecise array
        is supported (its dtype is taken from *updated_type*)."""
        # Is getitem?
        if self.func == operator.getitem:
            aryty = typeinfer.typevars[self.args[0].name].getone()
            # is array not precise?
            if _is_array_not_precise(aryty):
                # allow refinement of dtype
                assert updated_type.is_precise()
                newtype = aryty.copy(dtype=updated_type.dtype)
                typeinfer.add_type(self.args[0].name, newtype, loc=self.loc)
        else:
            m = 'no type refinement implemented for function {} updating to {}'
            raise TypingError(m.format(self.func, updated_type))

    def get_call_signature(self):
        return self.signature
class IntrinsicCallConstraint(CallConstraint):
    """Call constraint whose ``func`` is the callee object itself (for
    example an ``operator`` function) rather than a variable name."""

    def __call__(self, typeinfer):
        with new_error_context("typing of intrinsic-call at {0}", self.loc):
            if self.func in utils.OPERATORS_TO_BUILTINS:
                # Operators are resolved to their function type first.
                fnty = typeinfer.resolve_value_type(None, self.func)
            else:
                fnty = self.func
            self.resolve(typeinfer, typeinfer.typevars, fnty=fnty)
class GetAttrConstraint(object):
    """Type ``value.attr`` for every candidate type of ``value``."""

    def __init__(self, target, attr, value, loc, inst):
        self.target = target
        self.attr = attr
        self.value = value
        self.loc = loc
        self.inst = inst

    def __call__(self, typeinfer):
        with new_error_context("typing of get attribute at {0}", self.loc):
            for valty in typeinfer.typevars[self.value.name].get():
                attrty = typeinfer.context.resolve_getattr(valty, self.attr)
                if attrty is None:
                    raise UntypedAttributeError(valty, self.attr,
                                                loc=self.inst.loc)
                assert attrty.is_precise()
                typeinfer.add_type(self.target, attrty, loc=self.loc)
            # Be notified if the attribute's type is refined later.
            typeinfer.refine_map[self.target] = self

    def refine(self, typeinfer, target_type):
        """Push a refined bound-method receiver back into ``value``."""
        if not isinstance(target_type, types.BoundFunction):
            return
        receiver = target_type.this
        assert receiver.is_precise()
        typeinfer.add_type(self.value.name, receiver, loc=self.loc)
        upstream = typeinfer.refine_map.get(self.value.name)
        if upstream is not None:
            upstream.refine(typeinfer, receiver)

    def __repr__(self):
        return 'resolving type of attribute "{attr}" of "{value}"'.format(
            value=self.value, attr=self.attr)
class SetItemRefinement(object):
    """Mixin holding the refinement logic shared by setitem and static
    setitem constraints.
    """

    def _refine_target_type(self, typeinfer, targetty, idxty, valty, sig):
        """Refine the target type from the known index and value types."""
        target_name = self.target.name
        if _is_array_not_precise(targetty):
            # Array setitem: adopt the array type from the resolved signature.
            typeinfer.add_type(target_name, sig.args[0], loc=self.loc)
        if isinstance(targetty, types.DictType):
            # Imprecise typed dict: refine it with this store's types.
            if not targetty.is_precise():
                typeinfer.add_type(target_name,
                                   targetty.refine(idxty, valty),
                                   loc=self.loc)
        elif isinstance(targetty, types.LiteralStrKeyDict):
            # Storing into a literal-str-key dict adds DictType(idx, val).
            typeinfer.add_type(target_name, types.DictType(idxty, valty),
                               loc=self.loc)
class SetItemConstraint(SetItemRefinement):
    """Resolve ``target[index] = value`` once all three types are known."""

    def __init__(self, target, index, value, loc):
        self.target = target
        self.index = index
        self.value = value
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of setitem at {0}", self.loc):
            typevars = typeinfer.typevars
            operands = (self.target, self.index, self.value)
            if not all(typevars[var.name].defined for var in operands):
                return
            targetty, idxty, valty = (typevars[var.name].getone()
                                      for var in operands)
            sig = typeinfer.context.resolve_setitem(targetty, idxty, valty)
            if sig is None:
                raise TypingError("Cannot resolve setitem: %s[%s] = %s" %
                                  (targetty, idxty, valty), loc=self.loc)
            self.signature = sig
            self._refine_target_type(typeinfer, targetty, idxty, valty, sig)

    def get_call_signature(self):
        return self.signature
class StaticSetItemConstraint(SetItemRefinement):
    """Resolve ``target[index] = value`` with a compile-time constant index,
    falling back to the index variable's type."""

    def __init__(self, target, index, index_var, value, loc):
        self.target = target
        self.index = index
        self.index_var = index_var
        self.value = value
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of staticsetitem at {0}", self.loc):
            typevars = typeinfer.typevars
            operands = (self.target, self.index_var, self.value)
            if not all(typevars[var.name].defined for var in operands):
                return
            targetty, idxty, valty = (typevars[var.name].getone()
                                      for var in operands)
            context = typeinfer.context
            # Try the constant index first, then the generic setitem.
            sig = context.resolve_static_setitem(targetty, self.index, valty)
            if sig is None:
                sig = context.resolve_setitem(targetty, idxty, valty)
            if sig is None:
                raise TypingError("Cannot resolve setitem: %s[%r] = %s" %
                                  (targetty, self.index, valty), loc=self.loc)
            self.signature = sig
            self._refine_target_type(typeinfer, targetty, idxty, valty, sig)

    def get_call_signature(self):
        return self.signature
class DelItemConstraint(object):
    """Resolve ``del target[index]`` once both types are known."""

    def __init__(self, target, index, loc):
        self.target = target
        self.index = index
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of delitem at {0}", self.loc):
            typevars = typeinfer.typevars
            operands = (self.target, self.index)
            if not all(typevars[var.name].defined for var in operands):
                return
            targetty, idxty = (typevars[var.name].getone()
                               for var in operands)
            sig = typeinfer.context.resolve_delitem(targetty, idxty)
            if sig is None:
                raise TypingError("Cannot resolve delitem: %s[%s]" %
                                  (targetty, idxty), loc=self.loc)
            self.signature = sig

    def get_call_signature(self):
        return self.signature
class SetAttrConstraint(object):
    """Resolve ``target.attr = value`` once both types are known."""

    def __init__(self, target, attr, value, loc):
        self.target = target
        self.attr = attr
        self.value = value
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of set attribute {0!r} at {1}",
                               self.attr, self.loc):
            typevars = typeinfer.typevars
            operands = (self.target, self.value)
            if not all(typevars[var.name].defined for var in operands):
                return
            targetty, valty = (typevars[var.name].getone()
                               for var in operands)
            sig = typeinfer.context.resolve_setattr(targetty, self.attr,
                                                    valty)
            if sig is None:
                raise TypingError("Cannot resolve setattr: (%s).%s = %s" %
                                  (targetty, self.attr, valty),
                                  loc=self.loc)
            self.signature = sig

    def get_call_signature(self):
        return self.signature
class PrintConstraint(object):
    """Resolve a call to the builtin ``print`` with the given arguments."""

    def __init__(self, args, vararg, loc):
        self.args = args
        self.vararg = vararg
        self.loc = loc

    def __call__(self, typeinfer):
        folded = fold_arg_vars(typeinfer.typevars, self.args, self.vararg, {})
        if folded is None:
            # Cannot resolve call type until all argument types are known
            return
        pos_args, kw_args = folded
        print_fnty = typeinfer.context.resolve_value_type(print)
        assert print_fnty is not None
        self.signature = typeinfer.resolve_call(print_fnty, pos_args, kw_args)

    def get_call_signature(self):
        return self.signature
class TypeVarMap(dict):
    """Mapping of variable name -> TypeVar.

    Looking up a missing name creates (and stores) a fresh TypeVar.  Each
    name can be assigned only once.
    """

    def set_context(self, context):
        """Set the typing context passed to newly created TypeVars."""
        self.context = context

    def __getitem__(self, name):
        try:
            return super().__getitem__(name)
        except KeyError:
            typevar = TypeVar(self.context, name)
            self[name] = typevar
            return typevar

    def __setitem__(self, name, value):
        assert isinstance(name, str)
        if name in self:
            raise KeyError("Cannot redefine typevar %s" % name)
        super().__setitem__(name, value)
# A temporary mapping of {function name: dispatcher object}
# (entries are added and removed by register_dispatcher()).
_temporary_dispatcher_map = {}
# A temporary mapping of {function name: dispatcher object reference count}
# so that nested registrations under the same name are counted.
# Reference: https://github.com/numba/numba/issues/3658
_temporary_dispatcher_map_ref_count = defaultdict(int)
@contextlib.contextmanager
def register_dispatcher(disp):
    """
    Register a Dispatcher for inference while it is not yet stored
    as global or closure variable (e.g. during execution of the @jit()
    call).  This allows resolution of recursive calls with eager
    compilation.

    Registrations are reference counted per function name.  When a nested
    registration under the same name exits, the dispatcher that was
    registered before it is put back.  Once the last registration for a name
    exits, the name is removed from both module-level maps.
    """
    assert callable(disp)
    assert callable(disp.py_func)
    name = disp.py_func.__name__
    # Remember any outer registration so it can be restored on exit.
    previous = _temporary_dispatcher_map.get(name)
    _temporary_dispatcher_map[name] = disp
    _temporary_dispatcher_map_ref_count[name] += 1
    try:
        yield
    finally:
        _temporary_dispatcher_map_ref_count[name] -= 1
        if not _temporary_dispatcher_map_ref_count[name]:
            del _temporary_dispatcher_map[name]
            # Drop the zeroed counter so the defaultdict does not keep an
            # entry for every name ever registered.
            del _temporary_dispatcher_map_ref_count[name]
        elif (previous is not None
                and _temporary_dispatcher_map.get(name) is disp):
            # This was the innermost registration: restore the outer one
            # instead of leaving this (now finished) dispatcher in place.
            _temporary_dispatcher_map[name] = previous
# Module-level registry for type-inference extensions (empty by default).
# NOTE(review): the key/value schema and the code that reads this registry
# are not visible in this part of the file -- confirm before relying on it.
typeinfer_extensions = {}
class TypeInferer(object):
"""
Operates on block that shares the same ir.Scope.
"""
def __init__(self, context, func_ir, warnings):
self.context = context
# sort based on label, ensure iteration order!
self.blocks = OrderedDict()
for k in sorted(func_ir.blocks.keys()):
self.blocks[k] = func_ir.blocks[k]
self.generator_info = func_ir.generator_info
self.func_id = func_ir.func_id
self.func_ir = func_ir
self.typevars = TypeVarMap()
self.typevars.set_context(context)
self.constraints = ConstraintNetwork()
self.warnings = warnings
# { index: mangled name }
self.arg_names = {}
# self.return_type = None
# Set of assumed immutable globals
self.assumed_immutables = set()
# Track all calls and associated constraints
self.calls = []
# The inference result of the above calls
self.calltypes = utils.UniqueDict()
# Target var -> constraint with refine hook
self.refine_map = {}
if config.DEBUG or config.DEBUG_TYPEINFER:
self.debug = TypeInferDebug(self)
else:
self.debug = NullDebug()
self._skip_recursion = False
def copy(self, skip_recursion=False):
clone = TypeInferer(self.context, self.func_ir, self.warnings)