Skip to content

Commit

Permalink
Type narrowing proof of concept
Browse files Browse the repository at this point in the history
  • Loading branch information
fgaz committed Nov 15, 2022
1 parent eeaaa72 commit 3a15d99
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 12 deletions.
59 changes: 59 additions & 0 deletions spec/is/is.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
local util = require("spec.util")

describe("Is<T>:", function()

it("is_ function", util.check [[
local record Is<T> end
local record MyRecord
a: number
end
local record OtherRecord
a: boolean
end
local r : MyRecord | OtherRecord = { a = 1 }
local n : number
local function is_myrecord(x: any): Is<MyRecord>
if x is table then
local a = x.a
return (a is number)
else return false end
end
if is_myrecord(r) then
n = r.a
end
]])

it("is_ method", util.check [[
local record Is<T> end
local record A
is_b : function(self : A | B) : Is<B>
end
local record B
is_b : function(self : A | B) : Is<B>
b_field : string
end
local b1 : B = {
is_b = function(self : A | B) : Is<B>
-- In a real program this would actually do something
return true
end,
b_field = "yes",
}
local ab : A | B = b1
if ab:is_b() then
local b2 : B = ab
local s : string = b2.b_field
end
]])

end)
53 changes: 47 additions & 6 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5585,7 +5585,6 @@ tl.type_check = function(ast, opts)
local function is_valid_union(typ)


local n_table_types = 0
local n_function_types = 0
local n_userdata_types = 0
local n_string_enum = 0
Expand All @@ -5597,11 +5596,6 @@ tl.type_check = function(ast, opts)
if n_userdata_types > 1 then
return false, "cannot discriminate a union between multiple userdata types: %s"
end
elseif ut == "table" then
n_table_types = n_table_types + 1
if n_table_types > 1 then
return false, "cannot discriminate a union between multiple table types: %s"
end
elseif ut == "function" then
n_function_types = n_function_types + 1
if n_function_types > 1 then
Expand Down Expand Up @@ -6711,6 +6705,10 @@ tl.type_check = function(ast, opts)
return false, terr(t1, "enum is incompatible with %s", t2)
end
elseif t1.typename == "integer" and t2.typename == "number" then
return true
elseif t1.typename == "boolean" and t2.typename == "nominal" and t2.tk == "Is" then


return true
elseif t1.typename == "string" and t2.typename == "enum" then
local ok = t1.tk and t2.enumset[unquote(t1.tk)]
Expand Down Expand Up @@ -7190,6 +7188,28 @@ tl.type_check = function(ast, opts)
else
return nil, "invalid key '" .. key .. "' in type %s"
end

elseif tbl.typename == "union" then
assert(tbl.types[1], "Union has no members")
local field
for _, t in ipairs(tbl.types) do

t = resolve_tuple_and_nominal(t)
t = resolve_typetype(t)


if not is_record_type(t) then
return nil, "cannot index key '" .. key .. "' in '" .. t.tk .. "' from union %s (not a record)"
end
assert(t.fields, "record has no fields!?")

if not t.fields[key] then
return nil, "invalid key '" .. key .. "' in type '" .. t.tk .. "' from union %s"
else
field = t.fields[key]
end
end
return field
elseif tbl.typename == "emptytable" or is_unknown(tbl) then
if lax then
return INVALID
Expand Down Expand Up @@ -9114,6 +9134,27 @@ node.exps[3] and node.exps[3].type, }
end
end
elseif node.op.op == "@funcall" then


local is_is_function = a.typename == "function" and a.rets[1] and a.rets[1].tk == "Is"
local is_method = node.e1.op and node.e1.op.op == ":"
local first_arg = node.e2[1]
if is_is_function and (first_arg or is_method) then
local refined_var


if is_method then
refined_var = node.e1.e1.tk
else
refined_var = first_arg.tk
end
node.known = Fact({
fact = "is",
var = refined_var,
typ = a.rets[1].typevals[1],
where = node,
})
end
if lax and is_unknown(a) then
if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then
add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk)
Expand Down
53 changes: 47 additions & 6 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -5585,7 +5585,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
local function is_valid_union(typ: Type): boolean, string
-- check for limitations in our union support
-- due to codegen limitations (we only check with type() so far)
local n_table_types = 0
local n_function_types = 0
local n_userdata_types = 0
local n_string_enum = 0
Expand All @@ -5597,11 +5596,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
if n_userdata_types > 1 then
return false, "cannot discriminate a union between multiple userdata types: %s"
end
elseif ut == "table" then
n_table_types = n_table_types + 1
if n_table_types > 1 then
return false, "cannot discriminate a union between multiple table types: %s"
end
elseif ut == "function" then
n_function_types = n_function_types + 1
if n_function_types > 1 then
Expand Down Expand Up @@ -6712,6 +6706,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
end
elseif t1.typename == "integer" and t2.typename == "number" then
return true
elseif t1.typename == "boolean" and t2.typename == "nominal" and t2.tk == "Is" then
-- Treat booleans as Is<>, so that is_() functions don't have to cast
-- their return values.
return true
elseif t1.typename == "string" and t2.typename == "enum" then
local ok = t1.tk and t2.enumset[unquote(t1.tk)]
if ok then
Expand Down Expand Up @@ -7190,6 +7188,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
else
return nil, "invalid key '" .. key .. "' in type %s"
end
-- Union of records
elseif tbl.typename == "union" then
assert(tbl.types[1], "Union has no members")
local field : Type
for _,t in ipairs(tbl.types) do
-- TODO probably doing too much stuff
t = resolve_tuple_and_nominal(t)
t = resolve_typetype(t)
-- TODO support unions of unions, recursively
-- (properly, so that eg. we do all those extra checks outside of this if)
if not is_record_type(t) then
return nil, "cannot index key '" .. key .. "' in '" .. t.tk .. "' from union %s (not a record)"
end
assert(t.fields, "record has no fields!?")
-- key should be in all records
if not t.fields[key] then
return nil, "invalid key '" .. key .. "' in type '" .. t.tk .. "' from union %s"
else
field = t.fields[key]
end
end
return field
elseif tbl.typename == "emptytable" or is_unknown(tbl) then
if lax then
return INVALID
Expand Down Expand Up @@ -9114,6 +9134,27 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string
end
end
elseif node.op.op == "@funcall" then
-- Huge hack, Is should be its own type, or its definition should
-- be provided by teal.
local is_is_function = a.typename == "function" and a.rets[1] and a.rets[1].tk == "Is"
local is_method = node.e1.op and node.e1.op.op == ":"
local first_arg = node.e2[1]
if is_is_function and (first_arg or is_method) then
local refined_var : string
-- If it's a method call, we refine self, otherwise the first
-- argument of the function.
if is_method then
refined_var = node.e1.e1.tk
else
refined_var = first_arg.tk
end
node.known = Fact {
fact = "is",
var = refined_var,
typ = a.rets[1].typevals[1], -- type argument of Is<>
where = node,
}
end
if lax and is_unknown(a) then
if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then
add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk)
Expand Down

0 comments on commit 3a15d99

Please sign in to comment.