Skip to content

Commit

Permalink
Sync to upstream/release/624 (#1245)
Browse files Browse the repository at this point in the history
# What's changed?

* Optimize table.maxn.  This function is now 5-14x faster
* Reserve Luau stack space for error message.

## New Solver

* Globals can be type-stated, but only if they are already in scope
* Fix a stack overflow that could occur when normalizing certain kinds
of recursive unions of intersections (of unions of intersections...)
* Fix an assertion failure that would trigger when the __iter metamethod
has a bad signature

## Native Codegen

* Type propagation and temporary register type hints
* Direct vector property access should only happen for names of right
length
* BytecodeAnalysis will only predict that some of the vector value
fields are numbers

---

## Internal Contributors

Co-authored-by: Alexander McCord <amccord@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
  • Loading branch information
4 people committed May 3, 2024
1 parent 7edd58a commit 8a64cb8
Show file tree
Hide file tree
Showing 30 changed files with 1,655 additions and 114 deletions.
1 change: 1 addition & 0 deletions Analysis/include/Luau/Normalize.h
Expand Up @@ -395,6 +395,7 @@ class Normalizer
TypeId negate(TypeId there);
void subtractPrimitive(NormalizedType& here, TypeId ty);
void subtractSingleton(NormalizedType& here, TypeId ty);
NormalizationResult intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect, bool useDeprecated = false);

// ------- Normalizing intersections
TypeId intersectionOfTops(TypeId here, TypeId there);
Expand Down
10 changes: 6 additions & 4 deletions Analysis/src/ConstraintGenerator.cpp
Expand Up @@ -1900,10 +1900,7 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprGlobal* globa
return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)};
}
else
{
reportError(global->location, UnknownSymbol{global->name.value, UnknownSymbol::Binding});
return Inference{builtinTypes->errorRecoveryType()};
}
}

Inference ConstraintGenerator::checkIndexName(const ScopePtr& scope, const RefinementKey* key, AstExpr* indexee, const std::string& index, Location indexLocation)
Expand Down Expand Up @@ -2453,7 +2450,12 @@ ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePt
{
std::optional<TypeId> annotatedTy = scope->lookup(Symbol{global->name});
if (annotatedTy)
return {annotatedTy, arena->addType(BlockedType{})};
{
DefId def = dfg->getDef(global);
TypeId assignedTy = arena->addType(BlockedType{});
rootScope->lvalueTypes[def] = assignedTy;
return {annotatedTy, assignedTy};
}
else
return {annotatedTy, std::nullopt};
}
Expand Down
14 changes: 8 additions & 6 deletions Analysis/src/ConstraintSolver.cpp
Expand Up @@ -1619,9 +1619,7 @@ std::pair<bool, std::optional<TypeId>> ConstraintSolver::tryDispatchSetIndexer(
{
if (tt->indexer)
{
if (isBlocked(tt->indexer->indexType))
return {block(tt->indexer->indexType, constraint), std::nullopt};
else if (isBlocked(tt->indexer->indexResultType))
if (isBlocked(tt->indexer->indexResultType))
return {block(tt->indexer->indexResultType, constraint), std::nullopt};

unify(constraint, indexType, tt->indexer->indexType);
Expand Down Expand Up @@ -2014,10 +2012,14 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl
if (std::optional<TypeId> instantiatedNextFn = instantiate(builtinTypes, arena, NotNull{&limits}, constraint->scope, nextFn))
{
const FunctionType* nextFn = get<FunctionType>(*instantiatedNextFn);
LUAU_ASSERT(nextFn);
const TypePackId nextRetPack = nextFn->retTypes;

pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, nextRetPack, /* resultIsLValue=*/true});
// If nextFn is nullptr, then the iterator function has an improper signature.
if (nextFn)
{
const TypePackId nextRetPack = nextFn->retTypes;
pushConstraint(constraint->scope, constraint->location, UnpackConstraint{c.variables, nextRetPack, /* resultIsLValue=*/true});
}

return true;
}
else
Expand Down
121 changes: 87 additions & 34 deletions Analysis/src/Normalize.cpp
Expand Up @@ -20,6 +20,7 @@ LUAU_FASTFLAGVARIABLE(LuauNormalizeAwayUninhabitableTables, false)
LUAU_FASTFLAGVARIABLE(LuauFixNormalizeCaching, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeNotUnknownIntersection, false);
LUAU_FASTFLAGVARIABLE(LuauFixCyclicUnionsOfIntersections, false);
LUAU_FASTFLAGVARIABLE(LuauFixReduceStackPressure, false);

// This could theoretically be 2000 on amd64, but x86 requires this.
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
Expand All @@ -36,6 +37,11 @@ static bool fixCyclicUnionsOfIntersections()
return FFlag::LuauFixCyclicUnionsOfIntersections || FFlag::DebugLuauDeferredConstraintResolution;
}

static bool fixReduceStackPressure()
{
return FFlag::LuauFixReduceStackPressure || FFlag::DebugLuauDeferredConstraintResolution;
}

namespace Luau
{

Expand All @@ -45,6 +51,14 @@ static bool normalizeAwayUninhabitableTables()
return FFlag::LuauNormalizeAwayUninhabitableTables || FFlag::DebugLuauDeferredConstraintResolution;
}

static bool shouldEarlyExit(NormalizationResult res)
{
// if res is hit limits, return control flow
if (res == NormalizationResult::HitLimits || res == NormalizationResult::False)
return true;
return false;
}

TypeIds::TypeIds(std::initializer_list<TypeId> tys)
{
for (TypeId ty : tys)
Expand Down Expand Up @@ -1729,6 +1743,27 @@ bool Normalizer::withinResourceLimits()
return true;
}

NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, NormalizedType& intersect, bool useDeprecated)
{

std::optional<NormalizedType> negated;
if (useDeprecated)
{
const NormalizedType* normal = DEPRECATED_normalize(toNegate);
negated = negateNormal(*normal);
}
else
{
std::shared_ptr<const NormalizedType> normal = normalize(toNegate);
negated = negateNormal(*normal);
}

if (!negated)
return NormalizationResult::False;
intersectNormals(intersect, *negated);
return NormalizationResult::True;
}

// See above for an explaination of `ignoreSmallerTyvars`.
NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes, int ignoreSmallerTyvars)
{
Expand Down Expand Up @@ -2541,8 +2576,8 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
state = tttv->state;

TypeLevel level = max(httv->level, tttv->level);
TableType result{state, level};

std::unique_ptr<TableType> result = nullptr;
bool hereSubThere = true;
bool thereSubHere = true;

Expand All @@ -2563,8 +2598,18 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
if (tprop.readTy.has_value())
{
// if the intersection of the read types of a property is uninhabited, the whole table is `never`.
if (normalizeAwayUninhabitableTables() && NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy))
return {builtinTypes->neverType};
if (fixReduceStackPressure())
{
if (normalizeAwayUninhabitableTables() &&
NormalizationResult::True != isIntersectionInhabited(*hprop.readTy, *tprop.readTy))
return {builtinTypes->neverType};
}
else
{
if (normalizeAwayUninhabitableTables() &&
NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy))
return {builtinTypes->neverType};
}

TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
prop.readTy = ty;
Expand Down Expand Up @@ -2614,14 +2659,21 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
// TODO: string indexers

if (prop.readTy || prop.writeTy)
result.props[name] = prop;
{
if (!result.get())
result = std::make_unique<TableType>(TableType{state, level});
result->props[name] = prop;
}
}

for (const auto& [name, tprop] : tttv->props)
{
if (httv->props.count(name) == 0)
{
result.props[name] = tprop;
if (!result.get())
result = std::make_unique<TableType>(TableType{state, level});

result->props[name] = tprop;
hereSubThere = false;
}
}
Expand All @@ -2631,18 +2683,24 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
// TODO: What should intersection of indexes be?
TypeId index = unionType(httv->indexer->indexType, tttv->indexer->indexType);
TypeId indexResult = intersectionType(httv->indexer->indexResultType, tttv->indexer->indexResultType);
result.indexer = {index, indexResult};
if (!result.get())
result = std::make_unique<TableType>(TableType{state, level});
result->indexer = {index, indexResult};
hereSubThere &= (httv->indexer->indexType == index) && (httv->indexer->indexResultType == indexResult);
thereSubHere &= (tttv->indexer->indexType == index) && (tttv->indexer->indexResultType == indexResult);
}
else if (httv->indexer)
{
result.indexer = httv->indexer;
if (!result.get())
result = std::make_unique<TableType>(TableType{state, level});
result->indexer = httv->indexer;
thereSubHere = false;
}
else if (tttv->indexer)
{
result.indexer = tttv->indexer;
if (!result.get())
result = std::make_unique<TableType>(TableType{state, level});
result->indexer = tttv->indexer;
hereSubThere = false;
}

Expand All @@ -2652,7 +2710,12 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
else if (thereSubHere)
table = ttable;
else
table = arena->addType(std::move(result));
{
if (result.get())
table = arena->addType(std::move(*result));
else
table = arena->addType(TableType{state, level});
}

if (tmtable && hmtable)
{
Expand Down Expand Up @@ -3150,19 +3213,15 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
{
if (fixNormalizeCaching())
{
std::shared_ptr<const NormalizedType> normal = normalize(t);
std::optional<NormalizedType> negated = negateNormal(*normal);
if (!negated)
return NormalizationResult::False;
intersectNormals(here, *negated);
NormalizationResult res = intersectNormalWithNegationTy(t, here);
if (shouldEarlyExit(res))
return res;
}
else
{
const NormalizedType* normal = DEPRECATED_normalize(t);
std::optional<NormalizedType> negated = negateNormal(*normal);
if (!negated)
return NormalizationResult::False;
intersectNormals(here, *negated);
NormalizationResult res = intersectNormalWithNegationTy(t, here, /* useDeprecated */ true);
if (shouldEarlyExit(res))
return res;
}
}
else if (const UnionType* itv = get<UnionType>(t))
Expand All @@ -3171,11 +3230,9 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
{
for (TypeId part : itv->options)
{
std::shared_ptr<const NormalizedType> normalPart = normalize(part);
std::optional<NormalizedType> negated = negateNormal(*normalPart);
if (!negated)
return NormalizationResult::False;
intersectNormals(here, *negated);
NormalizationResult res = intersectNormalWithNegationTy(part, here);
if (shouldEarlyExit(res))
return res;
}
}
else
Expand All @@ -3184,22 +3241,18 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type
{
for (TypeId part : itv->options)
{
std::shared_ptr<const NormalizedType> normalPart = normalize(part);
std::optional<NormalizedType> negated = negateNormal(*normalPart);
if (!negated)
return NormalizationResult::False;
intersectNormals(here, *negated);
NormalizationResult res = intersectNormalWithNegationTy(part, here);
if (shouldEarlyExit(res))
return res;
}
}
else
{
for (TypeId part : itv->options)
{
const NormalizedType* normalPart = DEPRECATED_normalize(part);
std::optional<NormalizedType> negated = negateNormal(*normalPart);
if (!negated)
return NormalizationResult::False;
intersectNormals(here, *negated);
NormalizationResult res = intersectNormalWithNegationTy(part, here, /* useDeprecated */ true);
if (shouldEarlyExit(res))
return res;
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion Analysis/src/TypeChecker2.cpp
Expand Up @@ -1280,7 +1280,9 @@ struct TypeChecker2

void visit(AstExprGlobal* expr)
{
// TODO!
NotNull<Scope> scope = stack.back();
if (!scope->lookup(expr->name))
reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location);
}

void visit(AstExprVarargs* expr)
Expand Down
1 change: 1 addition & 0 deletions CLI/Compile.cpp
Expand Up @@ -317,6 +317,7 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A
{
options.includeAssembly = format != CompileFormat::CodegenIr;
options.includeIr = format != CompileFormat::CodegenAsm;
options.includeIrTypes = format != CompileFormat::CodegenAsm;
options.includeOutlinedCode = format == CompileFormat::CodegenVerbose;
}

Expand Down
1 change: 1 addition & 0 deletions CMakeLists.txt
Expand Up @@ -229,6 +229,7 @@ if(LUAU_BUILD_TESTS)
target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen)

target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS})
target_compile_definitions(Luau.Conformance PRIVATE DOCTEST_CONFIG_DOUBLE_STRINGIFY)
target_include_directories(Luau.Conformance PRIVATE extern)
target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen Luau.VM)
if(CMAKE_SYSTEM_NAME MATCHES "Android|iOS")
Expand Down
4 changes: 4 additions & 0 deletions CodeGen/include/Luau/CodeGen.h
Expand Up @@ -40,8 +40,12 @@ enum class CodeGenCompilationResult
CodeGenAssemblerFinalizationFailure = 7, // Failure during assembler finalization
CodeGenLoweringFailure = 8, // Lowering failed
AllocationFailed = 9, // Native codegen failed due to an allocation error

Count = 10,
};

std::string toString(const CodeGenCompilationResult& result);

struct ProtoCompilationFailure
{
CodeGenCompilationResult result = CodeGenCompilationResult::Success;
Expand Down
32 changes: 28 additions & 4 deletions CodeGen/src/BytecodeAnalysis.cpp
Expand Up @@ -6,13 +6,17 @@
#include "Luau/IrUtils.h"

#include "lobject.h"
#include "lstate.h"

#include <algorithm>

#include <algorithm>

LUAU_FASTFLAG(LuauCodegenDirectUserdataFlow)
LUAU_FASTFLAG(LuauLoadTypeInfo) // Because new VM typeinfo load changes the format used by Codegen, same flag is used
LUAU_FASTFLAGVARIABLE(LuauCodegenTypeInfo, false) // New analysis is flagged separately
LUAU_FASTFLAG(LuauTypeInfoLookupImprovement)
LUAU_FASTFLAGVARIABLE(LuauCodegenVectorMispredictFix, false)

namespace Luau
{
Expand Down Expand Up @@ -771,10 +775,30 @@ void analyzeBytecodeTypes(IrFunction& function)

regTags[ra] = LBC_TYPE_ANY;

// Assuming that vector component is being indexed
// TODO: check what key is used
if (bcType.a == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_NUMBER;
if (FFlag::LuauCodegenVectorMispredictFix)
{
if (bcType.a == LBC_TYPE_VECTOR)
{
TString* str = gco2ts(function.proto->k[kc].value.gc);
const char* field = getstr(str);

if (str->len == 1)
{
// Same handling as LOP_GETTABLEKS block in lvmexecute.cpp - case-insensitive comparison with "X" / "Y" / "Z"
char ch = field[0] | ' ';

if (ch == 'x' || ch == 'y' || ch == 'z')
regTags[ra] = LBC_TYPE_NUMBER;
}
}
}
else
{
// Assuming that vector component is being indexed
// TODO: check what key is used
if (bcType.a == LBC_TYPE_VECTOR)
regTags[ra] = LBC_TYPE_NUMBER;
}

bcType.result = regTags[ra];
break;
Expand Down

0 comments on commit 8a64cb8

Please sign in to comment.