Skip to content

Commit

Permalink
Implement modules and require
Browse files Browse the repository at this point in the history
  • Loading branch information
VictorNogueiraRio committed Jun 21, 2021
1 parent 4f3f559 commit 3b02c44
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 30 deletions.
3 changes: 2 additions & 1 deletion pallene/builtins.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ local ipairs_itertype = T.Function({T.Any(), T.Any()}, {T.Any(), T.Any()})
builtins.functions = {
type = T.Function({ T.Any() }, { T.String() }),
tostring = T.Function({ T.Any() }, { T.String() }),
ipairs = T.Function({T.Array(T.Any())}, {ipairs_itertype, T.Any(), T.Any()})
ipairs = T.Function({T.Array(T.Any())}, {ipairs_itertype, T.Any(), T.Any()}),
require = T.Function({ T.String() }, { T.Any() })
}

builtins.modules = {
Expand Down
78 changes: 74 additions & 4 deletions pallene/checker.lua
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,13 @@ local util = require "pallene.util"
local checker = {}

local Checker = util.Class()
local driver = {}

-- Type-check a Pallene module
-- On success, returns the typechecked module for the program
-- On failure, returns false and a list of compilation errors
function checker.check(prog_ast)
function checker.check(prog_ast, driver_passed)
driver = driver_passed
local co = coroutine.create(function()
return Checker.new():check_program(prog_ast)
end)
Expand Down Expand Up @@ -96,6 +98,7 @@ function Checker:init()
self.module_symbol = false -- checker.Symbol.Module
self.symbol_table = symtab.new() -- string => checker.Symbol
self.ret_types_stack = {} -- stack of types.T
self.imported_modules = {} -- Imported modules
return self
end

Expand Down Expand Up @@ -316,13 +319,51 @@ function Checker:expand_function_returns(rhs)
end
end

function Checker:is_the_module_variable(exp)
function Checker:is_the_module_variable(var_name)
return self.module_symbol == self.symbol_table:find_symbol(var_name)
end

function Checker:exp_is_the_module_variable(exp)
-- Check if the expression is the module variable without calling check_exp.
-- Doing that would have raised an exception because it is not a value.
return (
exp._tag == "ast.Exp.Var" and
exp.var._tag == "ast.Var.Name" and
(self.module_symbol == self.symbol_table:find_symbol(exp.var.name)))
self:is_the_module_variable(exp.var.name))
end

function Checker:init_imported_module(var_name, mod_name)
local mod_ast = self.imported_modules[mod_name]
local symbols = {}
for _, tls in ipairs(mod_ast.tls) do
if tls._tag == "ast.Toplevel.Stats" then
local stats = tls.stats
for _, stat in ipairs(stats) do
if stat._tag == "ast.Stat.Functions" then
for _, func_stat in ipairs(stat.funcs) do
if func_stat.module then
local typ = func_stat._type
local def = checker.Def.Function(func_stat)
symbols[func_stat.name] = checker.Symbol.Value(typ, def)
end
end
elseif stat._tag == "ast.Stat.Assign" then
for _, var in ipairs(stat.vars) do
if var._exported_as then
symbols[var.name] = checker.Symbol.Value(var._type, var._def)
end
end
end
end
end
end
self:add_module_symbol(var_name, types.T.String(), symbols)
end

local function exp_is_require(exp)
assert(exp._tag == "ast.Exp.Var")
local def = exp.var._def
return def and def._tag == "checker.Def.Builtin" and def.id == "require"
end

function Checker:check_stat(stat, is_toplevel)
Expand All @@ -348,6 +389,11 @@ function Checker:check_stat(stat, is_toplevel)
stat.exps[i] = self:check_initializer_exp(
stat.decls[i], stat.exps[i],
"declaration of local variable '%s'", stat.decls[i].name)
if stat.exps[i]._tag == "ast.Exp.CallFunc" and exp_is_require(stat.exps[i].exp) then
self:init_imported_module(stat.decls[i].name, stat.exps[i].args[i].value)
table.remove(stat.decls, i)
table.remove(stat.exps, i)
end
end
for i = m + 1, n do
stat.exps[i] = self:check_exp_synthesize(stat.exps[i])
Expand Down Expand Up @@ -486,7 +532,7 @@ function Checker:check_stat(stat, is_toplevel)
elseif tag == "ast.Stat.Assign" then

for i, var in ipairs(stat.vars) do
if var._tag == "ast.Var.Dot" and self:is_the_module_variable(var.exp) then
if var._tag == "ast.Var.Dot" and self:exp_is_the_module_variable(var.exp) then
-- Declaring a module field
if not is_toplevel then
type_error(var.loc, "module fields can only be set at the toplevel")
Expand Down Expand Up @@ -670,6 +716,10 @@ function Checker:try_flatten_to_qualified_name(outer_var)
local q = ast.Var.Name(var.loc, table.concat(components, "."))
q._type = sym.typ
q._def = sym.def

local is_builtin = sym.def._tag == "checker.Def.Builtin"
q._mod_name = not is_builtin and not self:is_the_module_variable(root) and root

return q
end

Expand Down Expand Up @@ -747,6 +797,22 @@ function Checker:coerce_numeric_exp_to_float(exp)
end
end

function Checker:check_require(exp)
local args = exp.args
local arg = args[1]
local filename = string.format("%s.pln", arg.value)
local input, err = driver.load_input(filename)
if err then
type_error(exp.loc, "Can't find module %s\n", arg.value)
end

local module_ast, err = driver.compile_internal(filename, input, "checker", 0)
if not module_ast then
type_error(exp.loc, "Error loading module %s: %s", arg.value, err[1])
end
self.imported_modules[arg.value] = module_ast
end

-- Check (synthesize) the type of a function call expression.
-- If the function returns 0 arguments, it is only allowed in a statement context.
-- Void functions in an expression context are a constant source of headaches.
Expand Down Expand Up @@ -786,6 +852,10 @@ function Checker:check_fun_call(exp, is_stat)
end
exp._types = f_type.ret_types

if exp_is_require(exp.exp) then
self:check_require(exp)
end

return exp
end

Expand Down
85 changes: 79 additions & 6 deletions pallene/coder.lua
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ function Coder:init(module, modname, filename)
self.upvalue_of_string = {} -- str => integer
self.upvalue_of_function = {} -- f_id => integer
self.upvalue_of_global = {} -- g_id => integer
self.upvalue_of_imported_function = {} -- f_id => integer
self.upvalue_of_imported_var = {} -- v_id => integer
self:init_upvalues()

self.record_ids = {} -- types.T.Record => integer
Expand Down Expand Up @@ -332,6 +334,14 @@ function Coder:c_value(value)
local f_id = value.id
local typ = self.module.functions[f_id].typ
return lua_value(typ, self:function_upvalue_slot(f_id))
elseif tag == "ir.Value.ImportedFunction" then
local f_id = value.id
local typ = self.module.imported_functions[f_id].typ
return lua_value(typ, self:imported_function_upvalue_slot(f_id))
elseif tag == "ir.Value.ImportedVar" then
local v_id = value.id
local typ = self.module.imported_vars[v_id].typ
return lua_value(typ, self:imported_var_upvalue_slot(v_id))
elseif typedecl.match_tag(tag, "ir.Value") then
typedecl.tag_error(tag, "unable to get C expression for this value type.")
else
Expand Down Expand Up @@ -627,7 +637,9 @@ typedecl.declare(coder, "coder", "Upvalue", {
Metatable = {"typ"},
String = {"str"},
Function = {"f_id"},
ImportedFunction = {"f_id"},
Global = {"g_id"},
ImportedVar = {"v_id"},
})

function Coder:init_upvalues()
Expand Down Expand Up @@ -683,6 +695,16 @@ function Coder:init_upvalues()
table.insert(self.upvalues, coder.Upvalue.Global(g_id))
self.upvalue_of_global[g_id] = #self.upvalues
end

for f_id = 1, #self.module.imported_functions do
table.insert(self.upvalues, coder.Upvalue.ImportedFunction(f_id))
self.upvalue_of_imported_function[f_id] = #self.upvalues
end

for v_id = 1, #self.module.imported_vars do
table.insert(self.upvalues, coder.Upvalue.ImportedVar(v_id))
self.upvalue_of_imported_var[v_id] = #self.upvalues
end
end

local function upvalue_slot(ix)
Expand All @@ -704,6 +726,16 @@ function Coder:function_upvalue_slot(f_id)
return upvalue_slot(ix)
end

function Coder:imported_function_upvalue_slot(f_id)
local ix = assert(self.upvalue_of_imported_function[f_id])
return upvalue_slot(ix)
end

function Coder:imported_var_upvalue_slot(v_id)
local ix = assert(self.upvalue_of_imported_var[v_id])
return upvalue_slot(ix)
end

function Coder:global_upvalue_slot(g_id)
local ix = assert(self.upvalue_of_global[g_id])
return upvalue_slot(ix)
Expand Down Expand Up @@ -1661,6 +1693,7 @@ end
function Coder:generate_luaopen_function()

local init_constants = {}
local modules = {}
for ix, upv in ipairs(self.upvalues) do
local tag = upv._tag
if tag ~= "coder.Upvalue.Global" then
Expand All @@ -1683,16 +1716,56 @@ function Coder:generate_luaopen_function()
entry_point = self:lua_entry_point_name(upv.f_id),
ix = C.integer(self.upvalue_of_function[upv.f_id]),
}))
elseif tag == "coder.Upvalue.ImportedFunction" then
local imported_func = self.module.imported_functions[upv.f_id]
modules[imported_func.mod] = modules[imported_func.mod] or {}
imported_func.id = upv.f_id
imported_func._tag = "func"
table.insert(modules[imported_func.mod], imported_func)
elseif tag == "coder.Upvalue.ImportedVar" then
local imported_var = self.module.imported_vars[upv.v_id]
modules[imported_var.mod] = modules[imported_var.mod] or {}
imported_var.id = upv.v_id
imported_var._tag = "var"
table.insert(modules[imported_var.mod], imported_var)
else
typedecl.tag_error(tag)
end

table.insert(init_constants, util.render([[
lua_setiuservalue(L, globals, $ix);
/**/
]], {
ix = C.integer(ix),
}))
if tag ~= "coder.Upvalue.ImportedFunction" and tag ~= "coder.Upvalue.ImportedVar" then
table.insert(init_constants, util.render([[
lua_setiuservalue(L, globals, $ix);
/**/
]], {
ix = C.integer(ix),
}))
end
end
end

for mod_name, fields in pairs(modules) do
table.insert(init_constants, util.render([[
lua_getglobal(L, "require");
lua_pushstring(L, "$mod_name");
lua_call(L, 1, 1);
if (PALLENE_UNLIKELY(lua_type(L, -1) != LUA_TTABLE))
luaL_error(L, "Module not found");
]], { mod_name = mod_name }))
for _, field in ipairs(fields) do
local id = field.id
local ix
if field._tag == "func" then
ix = C.integer(self.upvalue_of_imported_function[id])
else
ix = C.integer(self.upvalue_of_imported_var[id])
end

table.insert(init_constants, util.render([[
lua_getfield(L, -1, "$field_name");
if (PALLENE_UNLIKELY(lua_type(L, -1) == LUA_TNIL))
luaL_error(L, "field %s is nil", "$field_name");
lua_setiuservalue(L, globals, $ix);
]], { field_name = field.name, ix = ix }))
end
end

Expand Down
15 changes: 8 additions & 7 deletions pallene/constant_propagation.lua
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@ local constant_propagation = {}

local function is_constant_value(v)
local tag = v._tag
if tag == "ir.Value.Nil" then return true
elseif tag == "ir.Value.Bool" then return true
elseif tag == "ir.Value.Integer" then return true
elseif tag == "ir.Value.Float" then return true
elseif tag == "ir.Value.String" then return true
elseif tag == "ir.Value.LocalVar" then return false
elseif tag == "ir.Value.Function" then return true
if tag == "ir.Value.Nil" then return true
elseif tag == "ir.Value.Bool" then return true
elseif tag == "ir.Value.Integer" then return true
elseif tag == "ir.Value.Float" then return true
elseif tag == "ir.Value.String" then return true
elseif tag == "ir.Value.LocalVar" then return false
elseif tag == "ir.Value.Function" then return true
elseif tag == "ir.Value.ImportedVar" then return true
else
typedecl.tag_error(tag)
end
Expand Down
2 changes: 1 addition & 1 deletion pallene/driver.lua
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ function driver.compile_internal(filename, input, stop_after, opt_level)
return prog_ast, errs
end

prog_ast, errs = checker.check(prog_ast)
prog_ast, errs = checker.check(prog_ast, driver)
if stop_after == "checker" or not prog_ast then
return prog_ast, errs
end
Expand Down
34 changes: 26 additions & 8 deletions pallene/ir.lua
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ function ir.Module()
globals = {}, -- list of ir.VarDecl
exported_functions = {}, -- list of function ids
exported_globals = {}, -- list of variable ids
imported_functions = {}, -- list of imported functions
imported_vars = {}, -- list of imported variables
}
end

Expand Down Expand Up @@ -61,6 +63,14 @@ function ir.Function(loc, name, typ)
}
end

function ir.ImportedFunction(name, typ, mod)
return {
name = name, -- string
typ = typ, -- Type
mod = mod, -- Module name
}
end

---
--- Mutate modules
--
Expand Down Expand Up @@ -88,6 +98,12 @@ function ir.add_exported_global(module, g_id)
table.insert(module.exported_globals, g_id)
end

function ir.add_imported_function(module, mod_name, name, typ)
table.insert(module.imported_functions, ir.ImportedFunction(name, typ, mod_name))

return #module.imported_functions
end

--
-- Function variables
--
Expand All @@ -114,14 +130,16 @@ end
--

declare_type("Value", {
Nil = {},
Bool = {"value"},
Integer = {"value"},
Float = {"value"},
String = {"value"},
LocalVar = {"id"},
Upvalue = {"id"},
Function = {"id"},
Nil = {},
Bool = {"value"},
Integer = {"value"},
Float = {"value"},
String = {"value"},
LocalVar = {"id"},
Upvalue = {"id"},
Function = {"id"},
ImportedFunction = {"id"},
ImportedVar = {"id"},
})

-- declare_type("Cmd"
Expand Down

0 comments on commit 3b02c44

Please sign in to comment.