Commit 5d715e25 authored by André Pacak's avatar André Pacak
Browse files

base implementation of variant type in haskell with tests

parent cf45fb42
...@@ -2,6 +2,8 @@ module SumTypes.Base where ...@@ -2,6 +2,8 @@ module SumTypes.Base where
import Prelude hiding ((>=), (<=), lookup) import Prelude hiding ((>=), (<=), lookup)
import Data.List(find) import Data.List(find)
-- import Data.Map
import qualified Data.Map as Map
import SumTypes.Language import SumTypes.Language
import Util.ErrorMessages import Util.ErrorMessages
...@@ -48,6 +50,14 @@ matchSum :: Type -> String -> Infer (Type, Type) ...@@ -48,6 +50,14 @@ matchSum :: Type -> String -> Infer (Type, Type)
matchSum (Sum ty1 ty2) _ = return (ty1, ty2) matchSum (Sum ty1 ty2) _ = return (ty1, ty2)
matchSum ty err = fail $ sumError ty err matchSum ty err = fail $ sumError ty err
matchVariant :: Type -> String -> Infer (Map.Map Name Type)
matchVariant (Variant types) _ = return types
matchVariant ty err = fail $ variantError ty err
liftMaybe :: Monad m => Maybe a -> String -> m a
liftMaybe (Just a) _ = return a
liftMaybe Nothing err = fail err
lookup :: Ctx -> Name -> Infer Type lookup :: Ctx -> Name -> Infer Type
lookup Empty x = fail $ "Unbound variable " ++ show x lookup Empty x = fail $ "Unbound variable " ++ show x
lookup (Bind c x t) y lookup (Bind c x t) y
...@@ -96,6 +106,21 @@ checkType ctx p@(Case e n1 t1 n2 t2 _) ty = do ...@@ -96,6 +106,21 @@ checkType ctx p@(Case e n1 t1 n2 t2 _) ty = do
(ty1, ty2) <- matchSum ety (show e) (ty1, ty2) <- matchSum ety (show e)
checkType (Bind ctx n1 ty1) t1 ty checkType (Bind ctx n1 ty1) t1 ty
checkType (Bind ctx n2 ty2) t2 ty checkType (Bind ctx n2 ty2) t2 ty
checkType ctx p@(Tag n t _) ty = do
types <- matchVariant ty (show p)
let lty = Map.lookup n types
(maybe (fail "") (checkType ctx t) lty)
checkType ctx p@(Match m cases _) ty = do
ety <- inferType ctx m
types <- matchVariant ety (show m)
let subchecks =
map (\(l, x, t) -> do
lty <- liftMaybe (Map.lookup l types) "Could not find labeled type"
checkType (Bind ctx x lty) t ty
) cases
foldl (>>) (return ()) subchecks
checkType ctx t ty = do checkType ctx t ty = do
ty' <- inferType ctx t ty' <- inferType ctx t
......
...@@ -2,6 +2,9 @@ module SumTypes.Language where ...@@ -2,6 +2,9 @@ module SumTypes.Language where
import Prelude hiding (Ord, (<=), (>=)) import Prelude hiding (Ord, (<=), (>=))
import Util.PartialOrd as PO import Util.PartialOrd as PO
import Data.List (zip)
import Data.Map
import qualified Data.Map as Map
type Name = String type Name = String
...@@ -21,7 +24,11 @@ data Term = ...@@ -21,7 +24,11 @@ data Term =
-- sum types -- sum types
InL Term Parent | InL Term Parent |
InR Term Parent | InR Term Parent |
Case Term Name Term Name Term Parent Case Term Name Term Name Term Parent |
-- variant types
Tag Name Term Parent |
Match Term [(Name, Name, Term)] Parent
instance Eq Term where instance Eq Term where
(Var n1 _) == (Var n2 _) = n1 == n2 (Var n1 _) == (Var n2 _) = n1 == n2
...@@ -36,6 +43,8 @@ instance Eq Term where ...@@ -36,6 +43,8 @@ instance Eq Term where
(InL t _) == (InL t' _) = t == t' (InL t _) == (InL t' _) = t == t'
(InR t _) == (InR t' _) = t == t' (InR t _) == (InR t' _) = t == t'
(Case e n1 t1 n2 t2 _) == (Case e' n1' t1' n2' t2' _) = e == e' && n1 == n1' && t1 == t1' && n2 == n2' && t2 == t2' (Case e n1 t1 n2 t2 _) == (Case e' n1' t1' n2' t2' _) = e == e' && n1 == n1' && t1 == t1' && n2 == n2' && t2 == t2'
(Tag n1 t1 _) == (Tag n2 t2 _) = n1 == n2 && t1 == t2
(Match m1 cases1 _) == (Match m2 cases2 _) = m1 == m2 && (all (\((l1, x1, t1),(l2, x2, t2)) -> l1 == l2 && x1 == x2 && t1 == t2) (zip cases1 cases2))
_ == _ = False _ == _ = False
type Parent = Maybe Term type Parent = Maybe Term
...@@ -53,6 +62,8 @@ parent (Let _ _ _ p) = p ...@@ -53,6 +62,8 @@ parent (Let _ _ _ p) = p
parent (InL _ p) = p parent (InL _ p) = p
parent (InR _ p) = p parent (InR _ p) = p
parent (Case _ _ _ _ _ p) = p parent (Case _ _ _ _ _ p) = p
parent (Tag _ _ p) = p
parent (Match _ _ p) = p
-- DSL for constructing terms -- DSL for constructing terms
...@@ -91,6 +102,14 @@ inr = cons1 InR ...@@ -91,6 +102,14 @@ inr = cons1 InR
case' e n1 t1 n2 t2 parent = let res = Case (e (Just res)) n1 (t1 (Just res)) n2 (t2 (Just res)) parent in res case' e n1 t1 n2 t2 parent = let res = Case (e (Just res)) n1 (t1 (Just res)) n2 (t2 (Just res)) parent in res
tag n = cons1 (Tag n)
match0 m parent = let res = Match (m (Just res)) [] parent in res
match1 m (l1, x1, t1) p = let res = Match (m (Just res)) [(l1, x1, t1 (Just res))] p in res
match2 m (l1, x1, t1) (l2, x2, t2) p = let res = Match (m (Just res)) [(l1, x1, t1 (Just res)), (l2, x2, t2 (Just res))] p in res
match3 m (l1, x1, t1) (l2, x2, t2) (l3, x3, t3) p = let res = Match (m (Just res)) [(l1, x1, t1 (Just res)), (l2, x2, t2 (Just res)), (l3, x3, t3 (Just res))] p in res
instance Show Term where instance Show Term where
showsPrec _ (Var x _) = showString x showsPrec _ (Var x _) = showString x
...@@ -111,14 +130,17 @@ instance Show Term where ...@@ -111,14 +130,17 @@ instance Show Term where
showString "let " . showString n . showString " = " . showsPrec (p + 1) e . showString " in " . showsPrec (p + 1) b showString "let " . showString n . showString " = " . showsPrec (p + 1) e . showString " in " . showsPrec (p + 1) b
showsPrec p (InL t _) = showString "InL " . showsPrec (p + 1) t showsPrec p (InL t _) = showString "InL " . showsPrec (p + 1) t
showsPrec p (InR t _) = showString "InR " . showsPrec (p + 1) t showsPrec p (InR t _) = showString "InR " . showsPrec (p + 1) t
showsPrec p (Tag n t _) = showString "Tag " . showString n . showsPrec (p + 1) t
showsPrec p (Match n cases _) = showString "Match"
data Type = Nat | Fun Type Type | Sum Type Type | AnyType data Type = Nat | Fun Type Type | Sum Type Type | Variant (Map Name Type) | AnyType
deriving (Show, Eq) deriving (Show, Eq)
instance PO.PartialOrd Type where instance PO.PartialOrd Type where
_ <= AnyType = True _ <= AnyType = True
(Fun t1 t2) <= (Fun u1 u2) = t1 <= u1 && t2 <= u2 (Fun t1 t2) <= (Fun u1 u2) = t1 <= u1 && t2 <= u2
(Sum t1 t2) <= (Sum u1 u2) = t1 <= u1 && t2 <= u2 (Sum t1 t2) <= (Sum u1 u2) = t1 <= u1 && t2 <= u2
(Variant pairs1) <= (Variant pairs2) = (length pairs1 == length pairs2) && (all (\(l1, t1) -> let t2 = (Map.lookup l1 pairs2) in (maybe False (==t1) t2)) $ toList pairs1)
ty1 <= ty2 = ty1 == ty2 ty1 <= ty2 = ty1 == ty2
data Ctx = Empty | Bind Ctx Name Type data Ctx = Empty | Bind Ctx Name Type
......
...@@ -13,9 +13,8 @@ prodError = generalError "Prod" ...@@ -13,9 +13,8 @@ prodError = generalError "Prod"
funError :: Show a => a -> String -> String funError :: Show a => a -> String -> String
funError = generalError "Fun" funError = generalError "Fun"
recordError :: Show a => a -> String -> String
recordError = generalError "Record"
sumError :: Show a => a -> String -> String sumError :: Show a => a -> String -> String
sumError = generalError "Sum" sumError = generalError "Sum"
variantError :: Show a => a -> String -> String
variantError = generalError "Variant"
...@@ -49,6 +49,17 @@ sharedSpec inferType = do ...@@ -49,6 +49,17 @@ sharedSpec inferType = do
let res = convert $ inferType tOkCaseRight in res `shouldBe` B.Inferred Nat let res = convert $ inferType tOkCaseRight in res `shouldBe` B.Inferred Nat
it "should infer case with shadowing variable name" $ do it "should infer case with shadowing variable name" $ do
let res = convert $ inferType tOkCaseLeftShadowBinding in res `shouldBe` B.Inferred Nat let res = convert $ inferType tOkCaseLeftShadowBinding in res `shouldBe` B.Inferred Nat
it "should infer injection into variant with single element " $ do
let res = convert $ inferType tOkSingleVariant in res `shouldBe` B.Inferred (Variant $ fromList [("a", Nat)])
it "should infer match against variant with single element" $ do
let res = convert $ inferType tOkMatchSingle in res `shouldBe` B.Inferred Nat
it "should infer match against variant with two elements" $ do
let res = convert $ inferType tOkMatchTwo in res `shouldBe` B.Inferred Nat
it "should infer match against variant with two elements where order of cases is different from order of variant" $ do
let res = convert $ inferType tOkMatchWithDifferentOrder in res `shouldBe` B.Inferred Nat
it "should infer match where case shadows outer binding"$ do
let res = convert $ inferType tOkMatchShadowBinding in res `shouldBe` B.Inferred Nat
it "should fail while inferring an arithmetic expression with an unapplied lambda expression" $ do it "should fail while inferring an arithmetic expression with an unapplied lambda expression" $ do
let res = isInferred $ convert $ inferType tFailArithmetic in res `shouldBe` False let res = isInferred $ convert $ inferType tFailArithmetic in res `shouldBe` False
...@@ -64,3 +75,12 @@ sharedSpec inferType = do ...@@ -64,3 +75,12 @@ sharedSpec inferType = do
let res = isInferred $ convert $ inferType tFailCaseUnequalReturnTypes in res `shouldBe` False let res = isInferred $ convert $ inferType tFailCaseUnequalReturnTypes in res `shouldBe` False
it "should fail while matching on right injection where binding should shadow" $ do it "should fail while matching on right injection where binding should shadow" $ do
let res = isInferred $ convert $ inferType tFailCaseRightShadowBinding in res `shouldBe` False let res = isInferred $ convert $ inferType tFailCaseRightShadowBinding in res `shouldBe` False
it "should fail while tagging wrong type with label" $ do
let res = isInferred $ convert $ inferType tFailWrongTypeForTaggedLabel in res `shouldBe` False
it "should fail while matching against non variant type" $ do
let res = isInferred $ convert $ inferType tFailMatchNotVariant in res `shouldBe` False
it "should fail while tagging wrong label" $ do
let res = isInferred $ convert $ inferType tFailWrongLabel in res `shouldBe` False
it "should fail while matching where return types of cases are not equal" $ do
let res = isInferred $ convert $ inferType tFailMatchUnequalTypesForCases in res `shouldBe` False
{-# OPTIONS_GHC -Wno-orphans #-}
module SumTypes.TestCases where module SumTypes.TestCases where
import Prelude hiding (lookup,(*), (**), succ, fst, snd) import Prelude hiding (lookup,(*), (**), succ, fst, snd)
import SumTypes.Language import SumTypes.Language
import Data.Map
import qualified Data.Map as Map
tOkZero = root zero tOkZero = root zero
tOkSucc = root $ succ $ succ zero tOkSucc = root $ succ $ succ zero
...@@ -19,6 +22,15 @@ tOkCaseLeft = root $ (anno Nat (case' (anno (Sum Nat Nat) (inl $ zero)) "a" (var ...@@ -19,6 +22,15 @@ tOkCaseLeft = root $ (anno Nat (case' (anno (Sum Nat Nat) (inl $ zero)) "a" (var
tOkCaseRight = root $ (anno Nat (case' (anno (Sum Nat (Fun Nat Nat)) (inr $ lam "x" $ add zero $ var "x")) "a" zero "b" (app (var "b") (succ $ succ zero)))) tOkCaseRight = root $ (anno Nat (case' (anno (Sum Nat (Fun Nat Nat)) (inr $ lam "x" $ add zero $ var "x")) "a" zero "b" (app (var "b") (succ $ succ zero))))
tOkCaseLeftShadowBinding = root $ (let' "a" zero (anno Nat (case' (anno (Sum (Fun Nat Nat) Nat) (inl $ lam "x" zero)) "a" (app (var "a") zero) "b" (var "a")))) tOkCaseLeftShadowBinding = root $ (let' "a" zero (anno Nat (case' (anno (Sum (Fun Nat Nat) Nat) (inl $ lam "x" zero)) "a" (app (var "a") zero) "b" (var "a"))))
tOkSingleVariant = root $ (anno (Variant (fromList [("a", Nat)])) (tag "a" (succ zero)))
tOkTwoVariant = root $ (anno (Variant (fromList [("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" zero)))
tOkThreeVariant = root $ (anno (Variant (fromList [("c", Nat), ("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" zero)))
tOkMatchSingle = root $ (anno Nat (match1 (anno (Variant (fromList [("a", Nat)])) (tag "a" zero)) ("a", "x", succ $ var "x")))
tOkMatchTwo = root $ (anno Nat (match2 (anno (Variant (fromList [("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (app (var "x") zero))))
tOkMatchWithDifferentOrder = root $ (anno Nat (match2 (anno (Variant (fromList [("b", (Fun Nat Nat)), ("a", Nat)])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (app (var "x") zero))))
tOkMatchShadowBinding = root $ (let' "x" zero (anno Nat (match2 (anno (Variant (fromList [("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (app (var "x") zero)))))
tFailArithmetic = root $ (add (succ $ succ zero) (mult (anno (Fun Nat Nat) (lam "f" (var "f"))) (succ $ succ $ succ $ succ zero))) tFailArithmetic = root $ (add (succ $ succ zero) (mult (anno (Fun Nat Nat) (lam "f" (var "f"))) (succ $ succ $ succ $ succ zero)))
tFailLambdaNoAnno = root $ (anno Nat (app (lam "b" (add (var "b") zero)) (succ $ succ zero))) tFailLambdaNoAnno = root $ (anno Nat (app (lam "b" (add (var "b") zero)) (succ $ succ zero)))
tFailLambdaNotAnnoInBindingLet = root $ (let' "e" (lam "n" (add (var "n") (succ zero))) (mult (app (var "e") zero) (succ zero))) tFailLambdaNotAnnoInBindingLet = root $ (let' "e" (lam "n" (add (var "n") (succ zero))) (mult (app (var "e") zero) (succ zero)))
...@@ -27,3 +39,14 @@ tFailInL = root $ (anno (Sum Nat Nat) (inl (lam "b" zero))) ...@@ -27,3 +39,14 @@ tFailInL = root $ (anno (Sum Nat Nat) (inl (lam "b" zero)))
tFailInR = root $ (anno (Sum Nat Nat) (inr (lam "x" $ var "x"))) tFailInR = root $ (anno (Sum Nat Nat) (inr (lam "x" $ var "x")))
tFailCaseUnequalReturnTypes = root $ (anno (Fun Nat Nat) (case' (anno Nat (inl $ succ $ succ zero)) "a" (lam "x" (add (var "a") zero)) "b" zero)) tFailCaseUnequalReturnTypes = root $ (anno (Fun Nat Nat) (case' (anno Nat (inl $ succ $ succ zero)) "a" (lam "x" (add (var "a") zero)) "b" zero))
tFailCaseRightShadowBinding = root $ (let' "a" zero (anno Nat (case' (anno (Sum Nat (Fun Nat Nat)) (inr $ lam "x" zero)) "a" zero "b" (var "b")))) tFailCaseRightShadowBinding = root $ (let' "a" zero (anno Nat (case' (anno (Sum Nat (Fun Nat Nat)) (inr $ lam "x" zero)) "a" zero "b" (var "b"))))
tFailWrongTypeForTaggedLabel = root $ (anno (Variant (fromList [("c", Nat), ("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" zero))
tFailWrongLabel = root $ (anno (Variant (fromList [("c", Nat), ("a", Nat), ("b", (Fun Nat Nat))])) (tag "d" zero))
tFailMatchNotVariant = root $ (anno Nat (match1 (anno (Variant (fromList [("a", Nat)])) zero) ("a", "x", succ $ var "x")))
tFailMatchUnequalTypesForCases = root $ (anno Nat (match2 (anno (Variant (fromList [("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (var "x") )))
-- tOkMatchSingle = root $ (anno Nat (match1 (anno (Variant (fromList [("a", Nat)])) (tag "a" zero)) ("a", "x", succ $ var "x")))
-- tOkMatchTwo = root $ (anno Nat (match2 (anno (Variant (fromList [("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (app (var "x") zero))))
-- tOkMatchWithDifferentOrder = root $ (anno Nat (match2 (anno (Variant (fromList [("b", (Fun Nat Nat)), ("a", Nat)])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (app (var "x") zero))))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment