Commit 5d715e25 authored by André Pacak's avatar André Pacak

base implementation of variant type in haskell with tests

parent cf45fb42
......@@ -2,6 +2,8 @@ module SumTypes.Base where
import Prelude hiding ((>=), (<=), lookup)
import Data.List(find)
-- import Data.Map
import qualified Data.Map as Map
import SumTypes.Language
import Util.ErrorMessages
......@@ -48,6 +50,14 @@ matchSum :: Type -> String -> Infer (Type, Type)
matchSum (Sum ty1 ty2) _ = return (ty1, ty2)
matchSum ty err = fail $ sumError ty err
matchVariant :: Type -> String -> Infer (Map.Map Name Type)
matchVariant (Variant types) _ = return types
matchVariant ty err = fail $ variantError ty err
liftMaybe :: Monad m => Maybe a -> String -> m a
liftMaybe (Just a) _ = return a
liftMaybe Nothing err = fail err
lookup :: Ctx -> Name -> Infer Type
lookup Empty x = fail $ "Unbound variable " ++ show x
lookup (Bind c x t) y
......@@ -96,6 +106,21 @@ checkType ctx p@(Case e n1 t1 n2 t2 _) ty = do
(ty1, ty2) <- matchSum ety (show e)
checkType (Bind ctx n1 ty1) t1 ty
checkType (Bind ctx n2 ty2) t2 ty
checkType ctx p@(Tag n t _) ty = do
types <- matchVariant ty (show p)
let lty = Map.lookup n types
(maybe (fail "") (checkType ctx t) lty)
checkType ctx p@(Match m cases _) ty = do
ety <- inferType ctx m
types <- matchVariant ety (show m)
let subchecks =
map (\(l, x, t) -> do
lty <- liftMaybe (Map.lookup l types) "Could not find labeled type"
checkType (Bind ctx x lty) t ty
) cases
foldl (>>) (return ()) subchecks
checkType ctx t ty = do
ty' <- inferType ctx t
......
......@@ -2,6 +2,9 @@ module SumTypes.Language where
import Prelude hiding (Ord, (<=), (>=))
import Util.PartialOrd as PO
import Data.List (zip)
import Data.Map
import qualified Data.Map as Map
type Name = String
......@@ -21,7 +24,11 @@ data Term =
-- sum types
InL Term Parent |
InR Term Parent |
Case Term Name Term Name Term Parent
Case Term Name Term Name Term Parent |
-- variant types
Tag Name Term Parent |
Match Term [(Name, Name, Term)] Parent
instance Eq Term where
(Var n1 _) == (Var n2 _) = n1 == n2
......@@ -36,6 +43,8 @@ instance Eq Term where
(InL t _) == (InL t' _) = t == t'
(InR t _) == (InR t' _) = t == t'
(Case e n1 t1 n2 t2 _) == (Case e' n1' t1' n2' t2' _) = e == e' && n1 == n1' && t1 == t1' && n2 == n2' && t2 == t2'
(Tag n1 t1 _) == (Tag n2 t2 _) = n1 == n2 && t1 == t2
(Match m1 cases1 _) == (Match m2 cases2 _) = m1 == m2 && (all (\((l1, x1, t1),(l2, x2, t2)) -> l1 == l2 && x1 == x2 && t1 == t2) (zip cases1 cases2))
_ == _ = False
type Parent = Maybe Term
......@@ -53,6 +62,8 @@ parent (Let _ _ _ p) = p
parent (InL _ p) = p
parent (InR _ p) = p
parent (Case _ _ _ _ _ p) = p
parent (Tag _ _ p) = p
parent (Match _ _ p) = p
-- DSL for constructing terms
......@@ -91,6 +102,14 @@ inr = cons1 InR
case' e n1 t1 n2 t2 parent = let res = Case (e (Just res)) n1 (t1 (Just res)) n2 (t2 (Just res)) parent in res
tag n = cons1 (Tag n)
match0 m parent = let res = Match (m (Just res)) [] parent in res
match1 m (l1, x1, t1) p = let res = Match (m (Just res)) [(l1, x1, t1 (Just res))] p in res
match2 m (l1, x1, t1) (l2, x2, t2) p = let res = Match (m (Just res)) [(l1, x1, t1 (Just res)), (l2, x2, t2 (Just res))] p in res
match3 m (l1, x1, t1) (l2, x2, t2) (l3, x3, t3) p = let res = Match (m (Just res)) [(l1, x1, t1 (Just res)), (l2, x2, t2 (Just res)), (l3, x3, t3 (Just res))] p in res
instance Show Term where
showsPrec _ (Var x _) = showString x
......@@ -111,14 +130,17 @@ instance Show Term where
showString "let " . showString n . showString " = " . showsPrec (p + 1) e . showString " in " . showsPrec (p + 1) b
showsPrec p (InL t _) = showString "InL " . showsPrec (p + 1) t
showsPrec p (InR t _) = showString "InR " . showsPrec (p + 1) t
showsPrec p (Tag n t _) = showString "Tag " . showString n . showsPrec (p + 1) t
showsPrec p (Match n cases _) = showString "Match"
data Type = Nat | Fun Type Type | Sum Type Type | AnyType
data Type = Nat | Fun Type Type | Sum Type Type | Variant (Map Name Type) | AnyType
deriving (Show, Eq)
instance PO.PartialOrd Type where
_ <= AnyType = True
(Fun t1 t2) <= (Fun u1 u2) = t1 <= u1 && t2 <= u2
(Sum t1 t2) <= (Sum u1 u2) = t1 <= u1 && t2 <= u2
(Variant pairs1) <= (Variant pairs2) = (length pairs1 == length pairs2) && (all (\(l1, t1) -> let t2 = (Map.lookup l1 pairs2) in (maybe False (==t1) t2)) $ toList pairs1)
ty1 <= ty2 = ty1 == ty2
data Ctx = Empty | Bind Ctx Name Type
......
......@@ -13,9 +13,8 @@ prodError = generalError "Prod"
funError :: Show a => a -> String -> String
funError = generalError "Fun"
recordError :: Show a => a -> String -> String
recordError = generalError "Record"
sumError :: Show a => a -> String -> String
sumError = generalError "Sum"
variantError :: Show a => a -> String -> String
variantError = generalError "Variant"
......@@ -49,6 +49,17 @@ sharedSpec inferType = do
let res = convert $ inferType tOkCaseRight in res `shouldBe` B.Inferred Nat
it "should infer case with shadowing variable name" $ do
let res = convert $ inferType tOkCaseLeftShadowBinding in res `shouldBe` B.Inferred Nat
it "should infer injection into variant with single element " $ do
let res = convert $ inferType tOkSingleVariant in res `shouldBe` B.Inferred (Variant $ fromList [("a", Nat)])
it "should infer match against variant with single element" $ do
let res = convert $ inferType tOkMatchSingle in res `shouldBe` B.Inferred Nat
it "should infer match against variant with two elements" $ do
let res = convert $ inferType tOkMatchTwo in res `shouldBe` B.Inferred Nat
it "should infer match against variant with two elements where order of cases is different from order of variant" $ do
let res = convert $ inferType tOkMatchWithDifferentOrder in res `shouldBe` B.Inferred Nat
it "should infer match where case shadows outer binding"$ do
let res = convert $ inferType tOkMatchShadowBinding in res `shouldBe` B.Inferred Nat
it "should fail while inferring an arithmetic expression with an unapplied lambda expression" $ do
let res = isInferred $ convert $ inferType tFailArithmetic in res `shouldBe` False
......@@ -64,3 +75,12 @@ sharedSpec inferType = do
let res = isInferred $ convert $ inferType tFailCaseUnequalReturnTypes in res `shouldBe` False
it "should fail while matching on right injection where binding should shadow" $ do
let res = isInferred $ convert $ inferType tFailCaseRightShadowBinding in res `shouldBe` False
it "should fail while tagging wrong type with label" $ do
let res = isInferred $ convert $ inferType tFailWrongTypeForTaggedLabel in res `shouldBe` False
it "should fail while matching against non variant type" $ do
let res = isInferred $ convert $ inferType tFailMatchNotVariant in res `shouldBe` False
it "should fail while tagging wrong label" $ do
let res = isInferred $ convert $ inferType tFailWrongLabel in res `shouldBe` False
it "should fail while matching where return types of cases are not equal" $ do
let res = isInferred $ convert $ inferType tFailMatchUnequalTypesForCases in res `shouldBe` False
{-# OPTIONS_GHC -Wno-orphans #-}
module SumTypes.TestCases where
import Prelude hiding (lookup,(*), (**), succ, fst, snd)
import SumTypes.Language
import Data.Map
import qualified Data.Map as Map
tOkZero = root zero
tOkSucc = root $ succ $ succ zero
......@@ -19,6 +22,15 @@ tOkCaseLeft = root $ (anno Nat (case' (anno (Sum Nat Nat) (inl $ zero)) "a" (var
tOkCaseRight = root $ (anno Nat (case' (anno (Sum Nat (Fun Nat Nat)) (inr $ lam "x" $ add zero $ var "x")) "a" zero "b" (app (var "b") (succ $ succ zero))))
tOkCaseLeftShadowBinding = root $ (let' "a" zero (anno Nat (case' (anno (Sum (Fun Nat Nat) Nat) (inl $ lam "x" zero)) "a" (app (var "a") zero) "b" (var "a"))))
tOkSingleVariant = root $ (anno (Variant (fromList [("a", Nat)])) (tag "a" (succ zero)))
tOkTwoVariant = root $ (anno (Variant (fromList [("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" zero)))
tOkThreeVariant = root $ (anno (Variant (fromList [("c", Nat), ("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" zero)))
tOkMatchSingle = root $ (anno Nat (match1 (anno (Variant (fromList [("a", Nat)])) (tag "a" zero)) ("a", "x", succ $ var "x")))
tOkMatchTwo = root $ (anno Nat (match2 (anno (Variant (fromList [("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (app (var "x") zero))))
tOkMatchWithDifferentOrder = root $ (anno Nat (match2 (anno (Variant (fromList [("b", (Fun Nat Nat)), ("a", Nat)])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (app (var "x") zero))))
tOkMatchShadowBinding = root $ (let' "x" zero (anno Nat (match2 (anno (Variant (fromList [("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (app (var "x") zero)))))
tFailArithmetic = root $ (add (succ $ succ zero) (mult (anno (Fun Nat Nat) (lam "f" (var "f"))) (succ $ succ $ succ $ succ zero)))
tFailLambdaNoAnno = root $ (anno Nat (app (lam "b" (add (var "b") zero)) (succ $ succ zero)))
tFailLambdaNotAnnoInBindingLet = root $ (let' "e" (lam "n" (add (var "n") (succ zero))) (mult (app (var "e") zero) (succ zero)))
......@@ -27,3 +39,14 @@ tFailInL = root $ (anno (Sum Nat Nat) (inl (lam "b" zero)))
tFailInR = root $ (anno (Sum Nat Nat) (inr (lam "x" $ var "x")))
tFailCaseUnequalReturnTypes = root $ (anno (Fun Nat Nat) (case' (anno Nat (inl $ succ $ succ zero)) "a" (lam "x" (add (var "a") zero)) "b" zero))
tFailCaseRightShadowBinding = root $ (let' "a" zero (anno Nat (case' (anno (Sum Nat (Fun Nat Nat)) (inr $ lam "x" zero)) "a" zero "b" (var "b"))))
tFailWrongTypeForTaggedLabel = root $ (anno (Variant (fromList [("c", Nat), ("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" zero))
tFailWrongLabel = root $ (anno (Variant (fromList [("c", Nat), ("a", Nat), ("b", (Fun Nat Nat))])) (tag "d" zero))
tFailMatchNotVariant = root $ (anno Nat (match1 (anno (Variant (fromList [("a", Nat)])) zero) ("a", "x", succ $ var "x")))
tFailMatchUnequalTypesForCases = root $ (anno Nat (match2 (anno (Variant (fromList [("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (var "x") )))
-- tOkMatchSingle = root $ (anno Nat (match1 (anno (Variant (fromList [("a", Nat)])) (tag "a" zero)) ("a", "x", succ $ var "x")))
-- tOkMatchTwo = root $ (anno Nat (match2 (anno (Variant (fromList [("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (app (var "x") zero))))
-- tOkMatchWithDifferentOrder = root $ (anno Nat (match2 (anno (Variant (fromList [("b", (Fun Nat Nat)), ("a", Nat)])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (app (var "x") zero))))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment