Skip to content

Commit

Permalink
Add pattern matching for records (#108)
Browse files Browse the repository at this point in the history
This PR follows up on #103 by adding pattern matching for records. E.g.
```ocaml
let f = fun r -> match r with { | {x = x; y = y} -> x + y } in f {x = 3.3; y = 5.1}
```

It also adds a type error when record literals or patterns use duplicate
field names.
  • Loading branch information
siddharth-krishna committed Mar 26, 2024
1 parent c2bee8a commit 02bb23a
Show file tree
Hide file tree
Showing 19 changed files with 150 additions and 6 deletions.
3 changes: 3 additions & 0 deletions inferno-core/CHANGELOG.md
@@ -1,6 +1,9 @@
# Revision History for inferno-core
*Note*: we use https://pvp.haskell.org/ (MAJOR.MAJOR.MINOR.PATCH)

## 0.11.2.0 -- 2024-03-26
* Add record pattern matching

## 0.11.1.0 -- 2024-03-18
* HLint everything

Expand Down
2 changes: 1 addition & 1 deletion inferno-core/inferno-core.cabal
@@ -1,6 +1,6 @@
cabal-version: 2.4
name: inferno-core
version: 0.11.1.0
version: 0.11.2.0
synopsis: A statically-typed functional scripting language
description: Parser, type inference, and interpreter for a statically-typed functional scripting language
category: DSL,Scripting
Expand Down
9 changes: 9 additions & 0 deletions inferno-core/src/Inferno/Eval.hs
Expand Up @@ -8,10 +8,12 @@ import Control.Monad.Except (forM)
import Control.Monad.Reader (ask, local)
import Data.Foldable (foldrM)
import Data.Functor ((<&>))
import Data.List (sortOn)
import Data.List.NonEmpty (NonEmpty (..), toList)
import qualified Data.Map as Map
import Data.Maybe (catMaybes)
import qualified Data.Text as Text
import Data.Tuple.Extra (fst3)
import Inferno.Eval.Error
( EvalError (AssertionFailed, RuntimeError),
)
Expand Down Expand Up @@ -261,6 +263,13 @@ eval env@(localEnv, pinnedEnv) expr = case expr of
(VEmpty, PEmpty _) -> Just mempty
(VArray vs, PArray _ ps _) -> matchElems vs ps
(VTuple vs, PTuple _ ps _) -> matchElems vs $ tListToList ps
(VRecord vs, PRecord _ ps _) ->
if fs == fs'
then matchElems vs' ps'
else Nothing
where
(fs, vs') = unzip $ Map.toAscList vs

Check warning on line 271 in inferno-core/src/Inferno/Eval.hs

View workflow job for this annotation

GitHub Actions / build

Warning in eval in module Inferno.Eval: Avoid NonEmpty.unzip ▫︎ Found: "unzip" ▫︎ Perhaps: "Data.Functor.unzip" ▫︎ Note: The function is being deprecated
(fs', ps') = unzip $ map (\(f, p', l) -> (f, (p', l))) $ sortOn fst3 ps

Check warning on line 272 in inferno-core/src/Inferno/Eval.hs

View workflow job for this annotation

GitHub Actions / build

Warning in eval in module Inferno.Eval: Avoid NonEmpty.unzip ▫︎ Found: "unzip" ▫︎ Perhaps: "Data.Functor.unzip" ▫︎ Note: The function is being deprecated
_ -> Nothing

matchElems [] [] = Just mempty
Expand Down
31 changes: 31 additions & 0 deletions inferno-core/src/Inferno/Infer.hs
Expand Up @@ -51,6 +51,7 @@ import qualified Data.Map.Merge.Lazy as Map
import Data.Maybe (catMaybes, fromJust, mapMaybe)
import qualified Data.Set as Set
import qualified Data.Text as Text
import Data.Tuple.Extra (snd3)
import Debug.Trace (trace)
import Inferno.Infer.Env (Env (..), TypeMetadata (..), closeOver, closeOverType)
import qualified Inferno.Infer.Env as Env
Expand All @@ -61,6 +62,7 @@ import Inferno.Infer.Exhaustiveness
cEnum,
cInf,
cOne,
cRecord,
cTuple,
checkUsefullness,
exhaustive,
Expand Down Expand Up @@ -185,6 +187,7 @@ mkPattern = \case
PEmpty _ -> cEmpty
PArray _ ps _ -> cInf $ mkEnumArrayPat ps
PTuple _ ps _ -> cTuple $ map (mkPattern . fst) $ tListToList ps
PRecord _ ps _ -> let (fs, ps', _) = unzip3 ps in cRecord (Set.fromList fs) $ map mkPattern ps'
PCommentAbove _ p -> mkPattern p
PCommentAfter p _ -> mkPattern p
PCommentBelow p _ -> mkPattern p
Expand Down Expand Up @@ -615,6 +618,7 @@ infer expr =
let (isMerged, ics) = mergeImplicitMaps (blockPosition expr) is
return (InterpolatedString p1 (fromEitherList xs') p2, ImplType isMerged typeText, Set.unions css `Set.union` Set.fromList ics)
Record p1 fes p2 -> do
checkDuplicateFields exprLoc fes
let (fs, es) = unzip $ map (\(f, e, p) -> (f, (e, p))) fes
(es', impls, tys, cs) <- go es
let (isMerged, ics) = mergeImplicitMaps (blockPosition expr) impls
Expand Down Expand Up @@ -1182,6 +1186,24 @@ infer expr =
(t, vars1, cs1) <- mkPatConstraint p'
(ts, vars2, cs2) <- aux ps'
return (t : ts, vars1 ++ vars2, cs1 `Set.union` cs2)
PRecord _ fs _ -> do
checkDuplicateFields patLoc fs
(ts, vars, cs) <- aux fs
let inferredTy = TRecord ts RowAbsent
attachTypeToPosition
patLoc
TypeMetadata
{ identExpr = patternToExpr $ bimap (const ()) (const ()) pat,
ty = (Set.empty, ImplType Map.empty inferredTy),
docs = Nothing
}
return (inferredTy, vars, cs)
where
aux [] = return (mempty, [], Set.empty)
aux ((f, p', _l) : ps') = do
(t, vars1, cs1) <- mkPatConstraint p'
(ts, vars2, cs2) <- aux ps'
return (Map.insert f t ts, vars1 ++ vars2, cs1 `Set.union` cs2)
PVar _ Nothing -> do
tv <- fresh
let meta =
Expand All @@ -1206,6 +1228,7 @@ infer expr =
POne _ p -> checkVariableOverlap vars p
PArray _ ps _ -> foldM checkVariableOverlap vars $ map fst ps
PTuple _ ps _ -> foldM checkVariableOverlap vars $ map fst $ tListToList ps
PRecord _ ps _ -> foldM checkVariableOverlap vars $ map snd3 ps
_ -> return vars
CommentAbove p e -> do
(e', ty, cs) <- infer e
Expand Down Expand Up @@ -1236,6 +1259,14 @@ infer expr =
Just _openMod -> do
(e', ty, cs) <- infer e
return (OpenModule l1 mHash modNm imports p e', ty, cs)
where
-- Check if a record expr/pat has a duplicate field name
checkDuplicateFields l fs = aux mempty fs

Check warning on line 1264 in inferno-core/src/Inferno/Infer.hs

View workflow job for this annotation

GitHub Actions / build

Warning in infer in module Inferno.Infer: Eta reduce ▫︎ Found: "checkDuplicateFields l fs = aux mempty fs" ▫︎ Perhaps: "checkDuplicateFields l = aux mempty"
where
aux _seen [] = pure ()
aux seen ((f, _, _) : fs')
| Set.member f seen = throwError [DuplicateRecordField f l]
| otherwise = aux (Set.insert f seen) fs'

inferPatLit :: Location SourcePos -> Lit -> InfernoType -> Infer (InfernoType, [b], Set.Set c)
inferPatLit loc n t =
Expand Down
1 change: 1 addition & 0 deletions inferno-core/src/Inferno/Infer/Error.hs
Expand Up @@ -53,6 +53,7 @@ data TypeError a
| ModuleDoesNotExist ModuleName (Location a)
| NameInModuleDoesNotExist ModuleName Ident (Location a)
| AmbiguousName ModuleName Namespace (Location a)
| DuplicateRecordField Ident (Location a)
deriving (Show, Eq, Ord, Foldable)

makeBaseFunctor ''TypeError
Expand Down
32 changes: 29 additions & 3 deletions inferno-core/src/Inferno/Infer/Exhaustiveness.hs
Expand Up @@ -16,6 +16,7 @@ module Inferno.Infer.Exhaustiveness
cOne,
cEmpty,
cTuple,
cRecord,
)
where

Expand All @@ -26,17 +27,29 @@ import Data.Set (Set)
import qualified Data.Set as Set
import Data.Text (Text)
import qualified Data.Text as Text
import Inferno.Types.Syntax (Pat (PArray, PVar))
import Inferno.Types.Syntax (Ident (..), Pat (PArray, PVar))
import Inferno.Types.VersionControl (VCObjectHash)
import Prettyprinter (Pretty (pretty), align, tupled, (<+>))
import Prettyprinter (Pretty (pretty), align, encloseSep, tupled, (<+>))
import Text.Megaparsec (SourcePos, initialPos)

data Con = COne | CEmpty | CTuple Int | forall a. (Show a, Pretty a, Enum a) => CInf a | CEnum VCObjectHash Text
-- | Constructors, for the purposes of pattern matching.
-- This is an abstraction of the actual type constructors. For instance, all n-tuples
-- have the same constructor, all n-length arrays are represented by
-- @CInf (EnumArrayPat n)@, and all integer constants @n@ are considered as separate
-- constructors @CInf n@. Records are represented by the set of field names.
data Con
= COne
| CEmpty
| CTuple Int
| forall a. (Show a, Pretty a, Enum a) => CInf a
| CEnum VCObjectHash Text
| CRecord (Set.Set Ident)

instance Eq Con where
COne == COne = True
CEmpty == CEmpty = True
(CTuple i) == (CTuple j) = i == j
(CRecord i) == (CRecord j) = i == j
(CEnum e _) == (CEnum f _) = e == f
(CInf a) == (CInf b) = show a == show b
_ == _ = False
Expand All @@ -51,6 +64,7 @@ instance Ord Con where
CTuple n -> show n
CInf v -> show v
CEnum _ e -> show e
CRecord fs -> Text.unpack $ Text.intercalate "," $ map unIdent $ Set.toAscList fs

-- | We define a more abstract type of a pattern here, which only deals with (C)onstructors and
-- holes/(W)ildcards, as we do not need to make a distinction between a variable and a wildcard
Expand All @@ -65,6 +79,9 @@ instance Show Pattern where
C (CTuple _) xs -> "(" <> intercalate "," (map show xs) <> ")"
C (CInf x) _ -> show x
C (CEnum _ x) _ -> "#" <> show x
C (CRecord fs) _ -> "{" <> intercalate "," fields <> "}"
where
fields = map (\(Ident f) -> show f <> " = _") $ Set.toAscList fs
C _ _ -> "undefined"

instance Pretty Pattern where
Expand All @@ -75,6 +92,10 @@ instance Pretty Pattern where
C (CTuple _) xs -> tupled (map pretty xs)
C (CInf x) _ -> pretty x
C (CEnum _ x) _ -> "#" <> pretty x
C (CRecord fs) xs -> encloseSep "{" "}" "," fields
where
fields = map (\(Ident f, p) -> pretty f <+> "=" <+> pretty p) fps
fps = zip (Set.toAscList fs) xs
C _ _ -> "undefined"

type PMatrix = [[Pattern]]
Expand All @@ -86,6 +107,7 @@ cSize = \case
CTuple s -> s
CInf _ -> 0
CEnum _ _ -> 0
CRecord fs -> Set.size fs

specialize :: Con -> PMatrix -> PMatrix
specialize _ [] = []
Expand Down Expand Up @@ -133,6 +155,7 @@ isCompleteSignature enum_sigs s =
CEmpty -> if s == Set.fromList [COne, CEmpty] then Complete else Incomplete $ C COne [W]
COne -> if s == Set.fromList [COne, CEmpty] then Complete else Incomplete $ C CEmpty []
CTuple _ -> Complete
CRecord _ -> Complete
CEnum e _ ->
let e_sig = Set.map (uncurry CEnum) $ enum_sigs Map.! e
in if s == e_sig
Expand Down Expand Up @@ -176,6 +199,9 @@ defaultMatrix _ = error "malformed PMatrix"
cTuple :: [Pattern] -> Pattern
cTuple xs = C (CTuple (length xs)) xs

cRecord :: Set.Set Ident -> [Pattern] -> Pattern
cRecord fs = C (CRecord fs)

cOne :: Pattern -> Pattern
cOne x = C COne [x]

Expand Down
3 changes: 3 additions & 0 deletions inferno-core/src/Inferno/Infer/Pinned.hs
Expand Up @@ -97,6 +97,9 @@ pinPat m pat =
PTuple p1 es p2 -> do
es' <- mapM (\(e, p3) -> (,p3) <$> pinPat m e) es
pure $ PTuple p1 es' p2
PRecord p1 es p2 -> do
es' <- mapM (\(f, e, p3) -> (f,,p3) <$> pinPat m e) es
pure $ PRecord p1 es' p2
PCommentAbove c e -> PCommentAbove c <$> pinPat m e
PCommentAfter e c -> (`PCommentAfter` c) <$> pinPat m e
PCommentBelow e c -> (`PCommentBelow` c) <$> pinPat m e
Expand Down
7 changes: 7 additions & 0 deletions inferno-core/src/Inferno/Instances/Arbitrary.hs
Expand Up @@ -590,6 +590,13 @@ arbitrarySizedPat n =
sequence [(,Nothing) <$> arbitrarySizedPat (n `div` 3) | _ <- [1 .. k]]
)
`suchThat` (\xs -> length xs /= 1)
<*> arbitrary,
PRecord
<$> arbitrary
<*> ( do
k <- choose (0, n)
sequence [(,,Nothing) <$> arbitrary <*> arbitrarySizedPat (n `div` 3) | _ <- [1 .. k]]
)
<*> arbitrary
]

Expand Down
1 change: 1 addition & 0 deletions inferno-core/src/Inferno/Parse.hs
Expand Up @@ -529,6 +529,7 @@ letE = label ("a 'let' expression" ++ example "x") $
pat :: Parser (Pat () SourcePos)
pat =
uncurry3 PArray <$> array pat
<|> uncurry3 PRecord <$> record pat
<|> try (uncurry3 PTuple <$> tuple pat)
<|> parens pat
<|> try (hexadecimal PLit)
Expand Down
5 changes: 5 additions & 0 deletions inferno-core/src/Inferno/Parse/Commented.hs
Expand Up @@ -48,6 +48,11 @@ insertCommentIntoPat comment e =
else -- if the comment is neither before nor after the block, it must be within the expression
case e of
PTuple p1 es1 p2 -> PTuple p1 (tListFromList $ insertTuple $ tListToList es1) p2
PRecord p1 fps p2 -> PRecord p1 fps' p2
where
(fs, ps) = unzip $ map (\(f, p, mp) -> (f, (p, mp))) fps
ps' = insertTuple ps
fps' = zipWith (\f (p, mp) -> (f, p, mp)) fs ps'
POne p e1 -> POne p $ insertCommentIntoPat comment e1
PCommentAfter e1 c -> PCommentAfter (insertCommentIntoPat comment e1) c
PCommentBelow e1 c -> PCommentBelow (insertCommentIntoPat comment e1) c
Expand Down
4 changes: 4 additions & 0 deletions inferno-core/test/Eval/Spec.hs
Expand Up @@ -393,6 +393,10 @@ evalTests = describe "evaluate" $
shouldEvaluateTo "let r = {x = 2; y = 3} in r.y" $ VDouble 3
shouldEvaluateTo "let Array = {x = 2} in Array.x" $ VDouble 2
shouldEvaluateTo "let module r = Array in r.length []" $ VInt 0
shouldEvaluateTo "let f = fun r -> match r with { | {x = x; y = y} -> x + y } in f {x = 3; y = 5}" $ VDouble 8
shouldEvaluateTo "let f = fun r -> match r with { | {x = x; y = [y, z]} -> x + y + z | {x = x; y = t} -> x } in f {x = 3.3; y = [1.2]}" $ VDouble 3.3
shouldEvaluateTo "let f = fun r -> match r with { | {x = x; y = [y, z]} -> x + y + z | {x = x; y = t} -> x } in f {x = 3.3; y = [1.2, 3.4]}" $ VDouble 7.9
shouldEvaluateTo "let f = fun r -> match r with { | {x = x; y = [y, z]} -> x + y + z | {x = x; y = t} -> x } in f {x = 3.3; y = [1.2, 3.4, 5.6]}" $ VDouble 3.3
-- Type annotations
shouldEvaluateTo "let x : int = 2 in x" $ VInt 2
shouldEvaluateTo "let x : double = 2 in x" $ VDouble 2
Expand Down
16 changes: 16 additions & 0 deletions inferno-core/test/Infer/Spec.hs
Expand Up @@ -140,6 +140,7 @@ inferTests = describe "infer" $
shouldInferTypeFor "let r = {name = \"Zaphod\"; age = 391.4} in r.age" $ simpleType typeDouble
shouldInferTypeFor "let r = {name = \"Zaphod\"; age = 391.4} in let f = fun r -> r.age in f r + 1" $ simpleType typeDouble
shouldFailToInferTypeFor "let r = {name = \"Zaphod\"; age = 391.4} in r.age + \" is too old\""
-- Record field access vs Module.variable
shouldFailToInferTypeFor "rec.foo"
shouldInferTypeFor "Array.length []" $ simpleType typeInt
shouldFailToInferTypeFor "let r = {} in r.x"
Expand All @@ -150,12 +151,22 @@ inferTests = describe "infer" $
shouldFailToInferTypeFor "let module r = Array in r.x"
shouldInferTypeFor "let module r = Array in r.length []" $ simpleType typeInt
shouldInferTypeFor "let f = fun r -> r.age in f {age = 21.1; x = 5.4}" $ simpleType typeDouble
-- Record polymorphism
shouldFailToInferTypeFor "let f = fun r -> if #true then r else {age = 1.1} in f {age = 2; ht = 3}"
shouldInferTypeFor "let f = fun r -> truncateTo 2 r.ht + truncateTo 2 r.wt in f" $
makeType 0 [] (TArr (TRecord (Map.fromList [(Ident {unIdent = "ht"}, typeDouble), (Ident {unIdent = "wt"}, typeDouble)]) (RowVar (TV {unTV = 0}))) typeDouble)
shouldFailToInferTypeFor "let f = fun r -> if #true then r else {age = 1.1} in fun r -> let x = r.ht + r.age + 1.1 in f r"
shouldFailToInferTypeFor "let f = fun r -> r.age in let x = f {age = 21.1} in let y = f {age = \"t\"} in 1"
shouldFailToInferTypeFor "let f = fun r -> truncateTo 2 r.age in f {age = \"t\"}"
-- Record patterns
shouldInferTypeFor "let f = fun r -> match r with { | {x = x; y = y} -> x + y } in f {x = 3.3; y = 5.1}" $ simpleType typeDouble
shouldFailToInferTypeFor "let f = fun r -> match r with { | {x = x; y = [y, z]} -> x + y | {x = x; y = t} -> x } in f {x = 3.3; y = 5.1}"
shouldInferTypeFor "let f = fun r -> match r with { | {x = x; y = [y, z]} -> x + y | {x = x; y = t} -> x } in f {x = 3.3; y = [1.2]}" $ simpleType typeDouble
shouldFailToInferTypeFor "let f = fun r -> match r with { | {x = x; y = (y, z)} -> x + y | {x = x; y = t} -> x } in f {x = 3.3; y = 5.1}"
-- Duplicate fields
shouldFailToInferTypeFor "{x = 3.3; y = 5.1; x = 4}"
shouldFailToInferTypeFor "let f = fun r -> match r with { | {x = x; y = y} -> x + y } in f {x = 3.3; y = 5.1; x = 4}"
shouldFailToInferTypeFor "let f = fun r -> match r with { | {x = x; y = y; x = z} -> x + y } in f {x = 3.3; y = 5.1}"

-- Type annotations:
shouldInferTypeFor "let xBoo : double = 1 in truncateTo 2 xBoo" $ simpleType typeDouble
Expand Down Expand Up @@ -247,6 +258,11 @@ inferTests = describe "infer" $
["int", "double"]
"double"
[]
typeRepsShouldBe
"forall 'a. {requires numeric on 'a, requires rep on 'a} ⇒ series of 'a → 'a"
["series of double"]
"double"
[typeDouble]

-- Some tests with records:

Expand Down
1 change: 1 addition & 0 deletions inferno-core/test/Parse/Spec.hs
Expand Up @@ -47,6 +47,7 @@ prelude = builtinModules
normalizePat :: Pat h a -> Pat h a
normalizePat = ana $ \case
PTuple p1 xs p2 -> project $ PTuple p1 (fmap (\(e, _) -> (normalizePat e, Nothing)) xs) p2
PRecord p1 xs p2 -> project $ PRecord p1 (fmap (\(f, e, _) -> (f, normalizePat e, Nothing)) xs) p2
x -> project x

normalizeExpr :: Expr () a -> Expr () a
Expand Down
3 changes: 3 additions & 0 deletions inferno-lsp/CHANGELOG.md
@@ -1,6 +1,9 @@
# Revision History for inferno-lsp
*Note*: we use https://pvp.haskell.org/ (MAJOR.MAJOR.MINOR.PATCH)

## 0.2.5.0 -- 2024-03-26
* Add duplicate record field error

## 0.2.4.0 -- 2024-03-18
* HLint everything

Expand Down
2 changes: 1 addition & 1 deletion inferno-lsp/inferno-lsp.cabal
@@ -1,6 +1,6 @@
cabal-version: >=1.10
name: inferno-lsp
version: 0.2.4.0
version: 0.2.5.0
synopsis: LSP for Inferno
description: A language server protocol implementation for the Inferno language
category: IDE,DSL,Scripting
Expand Down
10 changes: 10 additions & 0 deletions inferno-lsp/src/Inferno/LSP/ParseInfer.hs
Expand Up @@ -494,6 +494,16 @@ inferErrorDiagnostic = \case
: [ indent 2 (pretty c) | c <- Set.toList tyCls
]
]
DuplicateRecordField (Ident f) (s, e) ->
[ errorDiagnosticInfer
(unPos $ sourceLine s)
(unPos $ sourceColumn s)
(unPos $ sourceLine e)
(unPos $ sourceColumn e)
$ renderDoc
$ vsep
$ ["Duplicate record field name:", indent 2 (pretty f)]
]

parseAndInferDiagnostics ::
forall m c.
Expand Down
3 changes: 3 additions & 0 deletions inferno-types/CHANGELOG.md
@@ -1,6 +1,9 @@
# Revision History for inferno-types
*Note*: we use https://pvp.haskell.org/ (MAJOR.MAJOR.MINOR.PATCH)

## 0.4.3.0 -- 2024-03-26
* Add record pattern matching

## 0.4.2.0 -- 2024-03-18
* Re-order `TRecord` in `InfernoType` so that existing serialization doesn't break

Expand Down
2 changes: 1 addition & 1 deletion inferno-types/inferno-types.cabal
@@ -1,6 +1,6 @@
cabal-version: >=1.10
name: inferno-types
version: 0.4.2.0
version: 0.4.3.0
synopsis: Core types for Inferno
description: Core types for the Inferno language
category: DSL,Scripting
Expand Down

0 comments on commit 02bb23a

Please sign in to comment.