diff --git a/haskell/src/FORecursiveTypes/Base.hs b/haskell/src/FORecursiveTypes/Base.hs new file mode 100644 index 0000000000000000000000000000000000000000..0c05f8f23a8f26e509124d395f52543988430023 --- /dev/null +++ b/haskell/src/FORecursiveTypes/Base.hs @@ -0,0 +1,139 @@ +module FORecursiveTypes.Base where + +import Prelude hiding ((>=), (<=), lookup) +import Data.List(find) +-- import Data.Map +import qualified Data.Map as Map + +import FORecursiveTypes.Language +import Util.ErrorMessages + +type Error = String + +data Infer a + = Inferred a + | NotInferred Error deriving (Eq, Show) + +type Check = Infer () + +instance Functor Infer where + fmap _ (NotInferred err) = NotInferred err + fmap f (Inferred ty) = Inferred $ f ty + +instance Applicative Infer where + pure = Inferred + (NotInferred err) <*> _ = NotInferred err + (Inferred a) <*> something = fmap a something + +instance Monad Infer where + return = Inferred + (Inferred ty) >>= f = f ty + (NotInferred err) >>= _ = NotInferred err + fail = NotInferred + +-- matching functions that extract the inner types if possible +-- one problem is that we do not get as good error messages, because term is not known in these functions +matchNat :: Type -> String -> Check +matchNat Nat _ = return () +matchNat ty err = fail $ natError ty err + +matchFun :: Type -> String -> Infer (Type, Type) +matchFun (Fun ty1 ty2) _ = return (ty1, ty2) +matchFun ty err = fail $ funError ty err + +matchType :: Type -> Type -> String -> Check +matchType ty1 ty2 _ + | ty1 == ty2 = return () +matchType ty1 ty2 err = fail $ generalError (show ty1) ty2 err + +matchSum :: Type -> String -> Infer (Type, Type) +matchSum (Sum ty1 ty2) _ = return (ty1, ty2) +matchSum ty err = fail $ sumError ty err + +matchVariant :: Type -> String -> Infer (Map.Map Name Type) +matchVariant (Variant types) _ = return types +matchVariant ty err = fail $ variantError ty err + +matchTypeVar :: TypeMap -> Type -> Infer Type +matchTypeVar tymap (TypeVar x) = liftMaybe (Map.lookup x tymap) $ "Could not found type definition " ++ show x +matchTypeVar _ ty = return ty + +liftMaybe :: Monad m => Maybe a -> String -> m a +liftMaybe (Just a) _ = return a +liftMaybe Nothing err = fail err + +lookup :: Ctx -> Name -> Infer Type +lookup Empty x = fail $ "Unbound variable " ++ show x +lookup (Bind c x t) y + | x == y = return t + | otherwise = lookup c y + +inferType :: Ctx -> TypeMap -> Term -> Infer Type +inferType _ _ (Unit _) = return UnitT +inferType _ _ (Zero _) = return Nat +inferType ctx tymap (Succ t _) = do + checkType ctx tymap t Nat + return Nat +inferType ctx tymap (Add t1 t2 _) = do + checkType ctx tymap t1 Nat + checkType ctx tymap t2 Nat + return Nat +inferType ctx tymap (Mult t1 t2 _) = do + checkType ctx tymap t1 Nat + checkType ctx tymap t2 Nat + return Nat +inferType ctx _ (Var name _) = lookup ctx name +inferType ctx tymap (Let name t body _) = do + tyt <- inferType ctx tymap t + inferType (Bind ctx name tyt) tymap body +inferType ctx tymap (Anno term ty _) = do + rty <- matchTypeVar tymap ty + checkType ctx tymap term rty + return ty +inferType ctx tymap (App t1 t2 _) = do + ty <- inferType ctx tymap t1 + rty <- matchTypeVar tymap ty + (ty1, ty2) <- matchFun rty (show t1) + checkType ctx tymap t2 ty1 + return ty2 +inferType _ _ t = fail $ "Cannot infer type of term " ++ show t + + +checkType :: Ctx -> TypeMap -> Term -> Type -> Check +checkType ctx tymap p@(Lam name t _) ty = do + rty <- matchTypeVar tymap ty + (ty1, ty2) <- matchFun rty (show p) + checkType (Bind ctx name ty1) tymap t ty2 +checkType ctx tymap p@(InL t _) ty = do + rty <- matchTypeVar tymap ty + (ty1, ty2) <- matchSum rty (show p) + checkType ctx tymap t ty1 +checkType ctx tymap p@(InR t _) ty = do + rty <- matchTypeVar tymap ty + (ty1, ty2) <- matchSum rty (show p) + checkType ctx tymap t ty2 +checkType ctx tymap p@(Case e n1 t1 n2 t2 _) ty = do + ety <- inferType ctx tymap e + rty <- matchTypeVar tymap ety + (ty1, ty2) <- matchSum rty (show e) + checkType (Bind ctx n1 ty1) tymap t1 ty + checkType (Bind ctx n2 ty2) tymap t2 ty +checkType ctx tymap p@(Tag n t _) ty = do + rty <- matchTypeVar tymap ty + types <- matchVariant rty (show p) + let lty = Map.lookup n types + (maybe (fail "") (checkType ctx tymap t) lty) +checkType ctx tymap p@(Match m cases _) ty = do + ety <- inferType ctx tymap m + rty <- matchTypeVar tymap ety + types <- matchVariant rty (show m) + let subchecks = + map (\(l, x, t) -> do + lty <- liftMaybe (Map.lookup l types) "Could not find labeled type" + checkType (Bind ctx x lty) tymap t ty + ) cases + foldl (>>) (return ()) subchecks + +checkType ctx tymap t ty = do + ty' <- inferType ctx tymap t + matchType ty ty' (show t) diff --git a/haskell/src/FORecursiveTypes/ErrorList.hs b/haskell/src/FORecursiveTypes/ErrorList.hs new file mode 100644 index 0000000000000000000000000000000000000000..9e988021855b28ca681b9286b63a93fd3685fa86 --- /dev/null +++ b/haskell/src/FORecursiveTypes/ErrorList.hs @@ -0,0 +1,139 @@ +module FORecursiveTypes.ErrorList where + +import Prelude hiding ((>=), (<=), lookup) +import Data.List(find) +-- import Data.Map +import qualified Data.Map as Map + +import FORecursiveTypes.Language +import Util.ErrorMessages + +type Error = [String] + +data Infer a + = Inferred a + | NotInferred Error deriving (Eq, Show) + +type Check = Infer () + +instance Functor Infer where + fmap _ (NotInferred err) = NotInferred err + fmap f (Inferred ty) = Inferred $ f ty + +instance Applicative Infer where + pure = Inferred + (NotInferred err) <*> _ = NotInferred err + (Inferred a) <*> something = fmap a something + +instance Monad Infer where + return = Inferred + (Inferred ty) >>= f = f ty + (NotInferred err) >>= _ = NotInferred err + fail msg = NotInferred [msg] + +-- matching functions that extract the inner types if possible +-- one problem is that we do not get as good error messages, because term is not known in these functions +matchNat :: Type -> String -> Check +matchNat Nat _ = return () +matchNat ty err = fail $ natError ty err + +matchFun :: Type -> String -> Infer (Type, Type) +matchFun (Fun ty1 ty2) _ = return (ty1, ty2) +matchFun ty err = fail $ funError ty err + +matchType :: Type -> Type -> String -> Check +matchType ty1 ty2 _ + | ty1 == ty2 = return () +matchType ty1 ty2 err = fail $ generalError (show ty1) ty2 err + +matchSum :: Type -> String -> Infer (Type, Type) +matchSum (Sum ty1 ty2) _ = return (ty1, ty2) +matchSum ty err = fail $ sumError ty err + +matchVariant :: Type -> String -> Infer (Map.Map Name Type) +matchVariant (Variant types) _ = return types +matchVariant ty err = fail $ variantError ty err + +matchTypeVar :: TypeMap -> Type -> Infer Type +matchTypeVar tymap (TypeVar x) = liftMaybe (Map.lookup x tymap) $ "Could not found type definition " ++ show x +matchTypeVar _ ty = return ty + +liftMaybe :: Monad m => Maybe a -> String -> m a +liftMaybe (Just a) _ = return a +liftMaybe Nothing err = fail err + +lookup :: Ctx -> Name -> Infer Type +lookup Empty x = fail $ "Unbound variable " ++ show x +lookup (Bind c x t) y + | x == y = return t + | otherwise = lookup c y + +inferType :: Ctx -> TypeMap -> Term -> Infer Type +inferType _ _ (Unit _) = return UnitT +inferType _ _ (Zero _) = return Nat +inferType ctx tymap (Succ t _) = do + checkType ctx tymap t Nat + return Nat +inferType ctx tymap (Add t1 t2 _) = do + checkType ctx tymap t1 Nat + checkType ctx tymap t2 Nat + return Nat +inferType ctx tymap (Mult t1 t2 _) = do + checkType ctx tymap t1 Nat + checkType ctx tymap t2 Nat + return Nat +inferType ctx _ (Var name _) = lookup ctx name +inferType ctx tymap (Let name t body _) = do + tyt <- inferType ctx tymap t + inferType (Bind ctx name tyt) tymap body +inferType ctx tymap (Anno term ty _) = do + rty <- matchTypeVar tymap ty + checkType ctx tymap term rty + return ty +inferType ctx tymap (App t1 t2 _) = do + ty <- inferType ctx tymap t1 + rty <- matchTypeVar tymap ty + (ty1, ty2) <- matchFun rty (show t1) + checkType ctx tymap t2 ty1 + return ty2 +inferType _ _ t = fail $ "Cannot infer type of term " ++ show t + + +checkType :: Ctx -> TypeMap -> Term -> Type -> Check +checkType ctx tymap p@(Lam name t _) ty = do + rty <- matchTypeVar tymap ty + (ty1, ty2) <- matchFun rty (show p) + checkType (Bind ctx name ty1) tymap t ty2 +checkType ctx tymap p@(InL t _) ty = do + rty <- matchTypeVar tymap ty + (ty1, ty2) <- matchSum rty (show p) + checkType ctx tymap t ty1 +checkType ctx tymap p@(InR t _) ty = do + rty <- matchTypeVar tymap ty + (ty1, ty2) <- matchSum rty (show p) + checkType ctx tymap t ty2 +checkType ctx tymap p@(Case e n1 t1 n2 t2 _) ty = do + ety <- inferType ctx tymap e + rty <- matchTypeVar tymap ety + (ty1, ty2) <- matchSum rty (show e) + checkType (Bind ctx n1 ty1) tymap t1 ty + checkType (Bind ctx n2 ty2) tymap t2 ty +checkType ctx tymap p@(Tag n t _) ty = do + rty <- matchTypeVar tymap ty + types <- matchVariant rty (show p) + let lty = Map.lookup n types + (maybe (fail "") (checkType ctx tymap t) lty) +checkType ctx tymap p@(Match m cases _) ty = do + ety <- inferType ctx tymap m + rty <- matchTypeVar tymap ety + types <- matchVariant rty (show m) + let subchecks = + map (\(l, x, t) -> do + lty <- liftMaybe (Map.lookup l types) "Could not find labeled type" + checkType (Bind ctx x lty) tymap t ty + ) cases + foldl (>>) (return ()) subchecks + +checkType ctx tymap t ty = do + ty' <- inferType ctx tymap t + matchType ty ty' (show t) diff --git a/haskell/src/FORecursiveTypes/Language.hs b/haskell/src/FORecursiveTypes/Language.hs new file mode 100644 index 0000000000000000000000000000000000000000..425086a9cee821ac3bbb38119ba0b6f3f8a07e6e --- /dev/null +++ b/haskell/src/FORecursiveTypes/Language.hs @@ -0,0 +1,153 @@ +module FORecursiveTypes.Language where + +import Prelude hiding (Ord, (<=), (>=)) +import Util.PartialOrd as PO +import Data.List (zip) +import Data.Map +import qualified Data.Map as Map + + +type Name = String +data Term = + -- base constructs + Var Name Parent | + App Term Term Parent | + Lam Name Term Parent | + Anno Term Type Parent | + -- arithmetic + Zero Parent | + Succ Term Parent | + Add Term Term Parent | + Mult Term Term Parent | + -- let binding + Let Name Term Term Parent | + -- sum types + InL Term Parent | + InR Term Parent | + Case Term Name Term Name Term Parent | + -- variant types + Tag Name Term Parent | + Match Term [(Name, Name, Term)] Parent | + Unit Parent + + +instance Eq Term where + (Var n1 _) == (Var n2 _) = n1 == n2 + (App t1 t2 _) == (App t1' t2' _) = t1 == t1' && t2 == t2' + (Lam x t _) == (Lam x' t' _) = x == x' && t == t' + (Anno t ty _) == (Anno t' ty' _) = t == t' && ty == ty' + (Zero _) == (Zero _) = True + (Succ t _) == (Succ t' _) = t == t' + (Add t1 t2 _) == (Add t1' t2' _) = t1 == t1' && t2 == t2' + (Mult t1 t2 _) == (Mult t1' t2' _) = t1 == t1' && t2 == t2' + (Let n t1 t2 _) == (Let n' t1' t2' _) = n == n' && t1 == t1' && t2 == t2' + (InL t _) == (InL t' _) = t == t' + (InR t _) == (InR t' _) = t == t' + (Case e n1 t1 n2 t2 _) == (Case e' n1' t1' n2' t2' _) = e == e' && n1 == n1' && t1 == t1' && n2 == n2' && t2 == t2' + (Tag n1 t1 _) == (Tag n2 t2 _) = n1 == n2 && t1 == t2 + (Match m1 cases1 _) == (Match m2 cases2 _) = m1 == m2 && (all (\((l1, x1, t1),(l2, x2, t2)) -> l1 == l2 && x1 == x2 && t1 == t2) (zip cases1 cases2)) + (Unit _) == (Unit _) = True + _ == _ = False + +type Parent = Maybe Term + +parent :: Term -> Parent +parent (Var _ p) = p +parent (App _ _ p) = p +parent (Lam _ _ p) = p +parent (Anno _ _ p) = p +parent (Zero p) = p +parent (Succ _ p) = p +parent (Add _ _ p) = p +parent (Mult _ _ p) = p +parent (Let _ _ _ p) = p +parent (InL _ p) = p +parent (InR _ p) = p +parent (Case _ _ _ _ _ p) = p +parent (Tag _ _ p) = p +parent (Match _ _ p) = p + + +-- DSL for constructing terms +type TermBuild = Parent -> Term + +cons0 :: (Parent -> Term) -> TermBuild +cons0 cons parent = let res = cons parent in res + +cons1 :: (Term -> Parent -> Term) -> TermBuild -> TermBuild +cons1 cons e parent = let res = cons (e (Just res)) parent in res + +cons2 :: (Term -> Term -> Parent -> Term) -> TermBuild -> TermBuild -> TermBuild +cons2 cons e1 e2 parent = let res = cons (e1 (Just res)) (e2 (Just res)) parent in res + +cons3 :: (Term -> Term -> Term -> Parent -> Term) -> TermBuild -> TermBuild -> TermBuild -> TermBuild +cons3 cons e1 e2 e3 parent = let res = cons (e1 (Just res)) (e2 (Just res)) (e3 (Just res)) parent in res + +root :: TermBuild -> Term +root t = t Nothing + +var n = cons0 (Var n) +app = cons2 App +lam n = cons1 (Lam n) +anno ty = cons1 (flip Anno ty) + +unit = cons0 Unit +zero = cons0 Zero +succ = cons1 Succ +add = cons2 Add +mult = cons2 Mult + + +let' n = cons2 (Let n) + +inl = cons1 InL +inr = cons1 InR + +case' e n1 t1 n2 t2 parent = let res = Case (e (Just res)) n1 (t1 (Just res)) n2 (t2 (Just res)) parent in res + +tag n = cons1 (Tag n) + +match0 m parent = let res = Match (m (Just res)) [] parent in res +match1 m (l1, x1, t1) p = let res = Match (m (Just res)) [(l1, x1, t1 (Just res))] p in res +match2 m (l1, x1, t1) (l2, x2, t2) p = let res = Match (m (Just res)) [(l1, x1, t1 (Just res)), (l2, x2, t2 (Just res))] p in res +match3 m (l1, x1, t1) (l2, x2, t2) (l3, x3, t3) p = let res = Match (m (Just res)) [(l1, x1, t1 (Just res)), (l2, x2, t2 (Just res)), (l3, x3, t3 (Just res))] p in res + + + +instance Show Term where + showsPrec _ (Var x _) = showString x + showsPrec p (App e1 e2 _) = showParen (p > app) (showsPrec app' e1 . showString " " . showsPrec app' e2) + where app = 10 + app' = app+1 + showsPrec p (Lam x e parent) = showParen (not parentIsLam && p > lam) (showString "\\" . showString x . showString " -> " . showsPrec lam' e) + where lam = 5 + lam' = lam+1 + parentIsLam = case parent of Just Lam{} -> True; _ -> False + showsPrec p (Anno e t _) = showParen True (showsPrec p e . showString " : " . showsPrec p t) + showsPrec _ (Zero _) = showString "Zero" + showsPrec p (Succ n _) = showString "Succ " . showsPrec (p + 1) n + showsPrec p (Add l r _) = showString "Add " . showsPrec (p + 1) l . showString " " . showsPrec (p + 1) r + showsPrec p (Mult l r _) = showString "Mult " . showsPrec (p + 1) l . showString " " . showsPrec (p + 1) r + + showsPrec p (Let n e b _) = + showString "let " . showString n . showString " = " . showsPrec (p + 1) e . showString " in " . showsPrec (p + 1) b + showsPrec p (InL t _) = showString "InL " . showsPrec (p + 1) t + showsPrec p (InR t _) = showString "InR " . showsPrec (p + 1) t + showsPrec p (Tag n t _) = showString "Tag " . showString n . showsPrec (p + 1) t + showsPrec p (Match n cases _) = showString "Match" + showsPrec p (Unit _) = showString "unit" + +data Type = UnitT | Nat | Fun Type Type | Sum Type Type | Variant (Map Name Type) | TypeVar Name | AnyType + deriving (Show, Eq) + +type TypeMap = Map.Map Name Type + +instance PO.PartialOrd Type where + _ <= AnyType = True + (Fun t1 t2) <= (Fun u1 u2) = t1 <= u1 && t2 <= u2 + (Sum t1 t2) <= (Sum u1 u2) = t1 <= u1 && t2 <= u2 + (Variant pairs1) <= (Variant pairs2) = (length pairs1 == length pairs2) && (all (\(l1, t1) -> let t2 = (Map.lookup l1 pairs2) in (maybe False (==t1) t2)) $ toList pairs1) + ty1 <= ty2 = ty1 == ty2 + +data Ctx = Empty | Bind Ctx Name Type + deriving (Show, Eq) diff --git a/haskell/test/FORecursiveTypes/BaseSpec.hs b/haskell/test/FORecursiveTypes/BaseSpec.hs new file mode 100644 index 0000000000000000000000000000000000000000..ea2771de2f81d2dcd8912ae0f5ab664be607ba86 --- /dev/null +++ b/haskell/test/FORecursiveTypes/BaseSpec.hs @@ -0,0 +1,22 @@ +module FORecursiveTypes.BaseSpec where + +import Prelude hiding (lookup,(*), (**)) +import Test.Hspec +import Data.Map + +import FORecursiveTypes.Base as B +import FORecursiveTypes.SharedSpecs +import FORecursiveTypes.Language + + +instance ConvertToBInfer B.Infer where + convert x = x + +main :: IO () +main = hspec spec + +typeMap :: TypeMap +typeMap = fromList [("Nat", Variant (fromList [("zero", UnitT), ("succ", TypeVar "Nat")]))] + +spec :: Spec +spec = sharedSpec $ B.inferType Empty typeMap diff --git a/haskell/test/FORecursiveTypes/ErrorListSpec.hs b/haskell/test/FORecursiveTypes/ErrorListSpec.hs new file mode 100644 index 0000000000000000000000000000000000000000..73d36066f8f8967012c1b595318ace895de9cedb --- /dev/null +++ b/haskell/test/FORecursiveTypes/ErrorListSpec.hs @@ -0,0 +1,27 @@ +{-# LANGUAGE FlexibleInstances #-} +module FORecursiveTypes.ErrorListSpec where + +import Prelude hiding (lookup,(*), (**)) +import Test.Hspec +import Data.Map + +import FORecursiveTypes.Base as B +import FORecursiveTypes.SharedSpecs +import FORecursiveTypes.ErrorList as E +import FORecursiveTypes.Language + + +instance ConvertToBInfer E.Infer where + convert (E.Inferred ty) = B.Inferred ty + convert (E.NotInferred err) = B.NotInferred $ head err + + +main :: IO () +main = hspec spec + +typeMap :: TypeMap +typeMap = fromList [("Nat", Variant (fromList [("zero", UnitT), ("succ", TypeVar "Nat")]))] + +spec :: Spec +spec = sharedSpec $ E.inferType Empty typeMap + diff --git a/haskell/test/FORecursiveTypes/SharedSpecs.hs b/haskell/test/FORecursiveTypes/SharedSpecs.hs new file mode 100644 index 0000000000000000000000000000000000000000..060a9212bf1ebaf406379180dfb6156d957420a3 --- /dev/null +++ b/haskell/test/FORecursiveTypes/SharedSpecs.hs @@ -0,0 +1,94 @@ +module FORecursiveTypes.SharedSpecs where + +import Prelude hiding (lookup,(*), (**)) +import Test.Hspec + +import Data.Map +import qualified Data.Map as Map +import FORecursiveTypes.Base as B +import FORecursiveTypes.TestCases +import FORecursiveTypes.Language + + +class ConvertToBInfer m where + convert :: m Type -> B.Infer Type + +isInferred :: B.Infer Type -> Bool +isInferred (B.Inferred _) = True +isInferred (B.NotInferred _) = False + +class ConvertInferredToStr a where + convertToStr :: a -> String + +sharedSpec :: (ConvertToBInfer m) => (Term -> m Type) -> Spec +sharedSpec inferType = do + describe "inferType" $ do + it "should infer Zero to be of type Nat" $ do + let res = convert $ inferType tOkZero in res `shouldBe` B.Inferred Nat + it "should infer (Succ Zero) to be of type Nat" $ do + let res = convert $ inferType tOkSucc in res `shouldBe` B.Inferred Nat + it "should infer arithmetic expression to be of type Nat" $ do + let res = convert $ inferType tOkArithmetic in res `shouldBe` B.Inferred Nat + it "should infer let binding with arithmetic expression to be of type Nat" $ do + let res = convert $ inferType tOkLetBindingWithArith in res `shouldBe` B.Inferred Nat + it "should infer nested lambda expressions" $ do + let res = convert $ inferType tOkAppLambdaAnno in res `shouldBe` B.Inferred Nat + it "should infer annotated lambda to be of type (Fun Nat Nat)" $ do + let res = convert $ inferType tOkLambdaAnno in res `shouldBe` B.Inferred (Fun Nat Nat) + it "should infer application of lambda expression" $ do + let res = convert $ inferType tOkAnno in res `shouldBe` B.Inferred (Fun (Fun Nat Nat) (Fun (Fun Nat Nat) (Fun Nat Nat))) + it "should infer let binding with annotation in named expression" $ do + let res = convert $ inferType tOkAnnoInBindingLet in res `shouldBe` B.Inferred Nat + it "should infer left injection " $ do + let res = convert $ inferType tOkInL in res `shouldBe` B.Inferred (Sum Nat (Fun Nat Nat)) + it "should infer right injection " $ do + let res = convert $ inferType tOkInR in res `shouldBe` B.Inferred (Sum Nat (Fun Nat Nat)) + it "should infer case for left injection " $ do + let res = convert $ inferType tOkCaseLeft in res `shouldBe` B.Inferred Nat + it "should infer case for right injection " $ do + let res = convert $ inferType tOkCaseRight in res `shouldBe` B.Inferred Nat + it "should infer case with shadowing variable name" $ do + let res = convert $ inferType tOkCaseLeftShadowBinding in res `shouldBe` B.Inferred Nat + it "should infer injection into variant with single element " $ do + let res = convert $ inferType tOkSingleVariant in res `shouldBe` B.Inferred (Variant $ fromList [("a", Nat)]) + + it "should infer match against variant with single element" $ do + let res = convert $ inferType tOkMatchSingle in res `shouldBe` B.Inferred Nat + it "should infer match against variant with two elements" $ do + let res = convert $ inferType tOkMatchTwo in res `shouldBe` B.Inferred Nat + it "should infer match against variant with two elements where order of cases is different from order of variant" $ do + let res = convert $ inferType tOkMatchWithDifferentOrder in res `shouldBe` B.Inferred Nat + it "should infer match where case shadows outer binding"$ do + let res = convert $ inferType tOkMatchShadowBinding in res `shouldBe` B.Inferred Nat + it "should infer match where bindings are and terms are the same" $ do + let res = convert $ inferType tOkCasesSameBinding in res `shouldBe` B.Inferred (Fun Nat Nat) + it "should infer recursive zero of recursive nat type" $ do + let res = convert $ inferType tOkZeroRecNatType in res `shouldBe` B.Inferred (TypeVar "Nat") + it "should infer recursive nested succ of recursive nat type" $ do + let res = convert $ inferType tOkNestedSuccRecNatType in res `shouldBe` B.Inferred (TypeVar "Nat") + + it "should fail while inferring an arithmetic expression with an unapplied lambda expression" $ do + let res = isInferred $ convert $ inferType tFailArithmetic in res `shouldBe` False + it "should fail while inferring an application of a lambda expression without type annotation" $ do + let res = isInferred $ convert $ inferType tFailLambdaNoAnno in res `shouldBe` False + it "should fail while trying to infer a let binding without annotation of lambda in named expression" $ do + let res = isInferred $ convert $ inferType tFailLambdaNotAnnoInBindingLet in res `shouldBe` False + it "should fail while infering left injection" $ do + let res = isInferred $ convert $ inferType tFailInL in res `shouldBe` False + it "should fail while infering right injection" $ do + let res = isInferred $ convert $ inferType tFailInR in res `shouldBe` False + it "should fail while matching on left injection where returntypes are unequal" $ do + let res = isInferred $ convert $ inferType tFailCaseUnequalReturnTypes in res `shouldBe` False + it "should fail while matching on right injection where binding should shadow" $ do + let res = isInferred $ convert $ inferType tFailCaseRightShadowBinding in res `shouldBe` False + + it "should fail while tagging wrong type with label" $ do + let res = isInferred $ convert $ inferType tFailWrongTypeForTaggedLabel in res `shouldBe` False + it "should fail while matching against non variant type" $ do + let res = isInferred $ convert $ inferType tFailMatchNotVariant in res `shouldBe` False + it "should fail while tagging wrong label" $ do + let res = isInferred $ convert $ inferType tFailWrongLabel in res `shouldBe` False + it "should fail while matching where return types of cases are not equal" $ do + let res = isInferred $ convert $ inferType tFailMatchUnequalTypesForCases in res `shouldBe` False + it "should fail while infering a undefined type" $ do + let res = isInferred $ convert $ inferType tFailRecTypeNotDefined in res `shouldBe` False diff --git a/haskell/test/FORecursiveTypes/TestCases.hs b/haskell/test/FORecursiveTypes/TestCases.hs new file mode 100644 index 0000000000000000000000000000000000000000..69f9862c219427d6c23b2aaaef9232edf70fd05f --- /dev/null +++ b/haskell/test/FORecursiveTypes/TestCases.hs @@ -0,0 +1,55 @@ +{-# OPTIONS_GHC -Wno-orphans #-} +module FORecursiveTypes.TestCases where + +import Prelude hiding (lookup,(*), (**), succ, fst, snd) + +import FORecursiveTypes.Language +import Data.Map +import qualified Data.Map as Map + +tOkZero = root zero +tOkSucc = root $ succ $ succ zero +tOkArithmetic = root $ (add (succ $ succ zero) (mult zero (succ $ succ $ succ $ succ zero))) +tOkLetBindingWithArith = root $ (let' "x" (add zero (succ zero)) (mult (var "x") (succ $ succ zero))) +tOkAnno = root $ (anno (Fun (Fun Nat Nat) (Fun (Fun Nat Nat) (Fun Nat Nat))) (lam "f" $ lam "g" $ lam "b" (app (var "g") (app (var "f") (var "b"))))) +tOkLambdaAnno = root $ (anno (Fun Nat Nat) (lam "a" (add (var "a") zero))) +tOkAppLambdaAnno = root $ (app (anno (Fun Nat Nat) (lam "b" (add (var "b") zero))) (succ $ succ zero)) +tOkAnnoInBindingLet = root $ (let' "e" (anno Nat (add zero (succ zero))) (mult (var "e") (succ zero))) + +tOkInL = root $ (anno (Sum Nat (Fun Nat Nat)) (inl $ succ zero)) +tOkInR = root $ (anno (Sum Nat (Fun Nat Nat)) (inr (anno (Fun Nat Nat) (lam "b" zero)))) +tOkCaseLeft = root $ (anno Nat (case' (anno (Sum Nat Nat) (inl $ zero)) "a" (var "a") "b" zero)) +tOkCaseRight = root $ (anno Nat (case' (anno (Sum Nat (Fun Nat Nat)) (inr $ lam "x" $ add zero $ var "x")) "a" zero "b" (app (var "b") (succ $ succ zero)))) +tOkCaseLeftShadowBinding = root $ (let' "a" zero (anno Nat (case' (anno (Sum (Fun Nat Nat) Nat) (inl $ lam "x" zero)) "a" (app (var "a") zero) "b" (var "a")))) + +tOkSingleVariant = root $ (anno (Variant (fromList [("a", Nat)])) (tag "a" (succ zero))) +tOkTwoVariant = root $ (anno (Variant (fromList [("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" zero))) +tOkThreeVariant = root $ (anno (Variant (fromList [("c", Nat), ("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" zero))) +tOkMatchSingle = root $ (anno Nat (match1 (anno (Variant (fromList [("a", Nat)])) (tag "a" zero)) ("a", "x", succ $ var "x"))) +tOkMatchTwo = root $ (anno Nat (match2 (anno (Variant (fromList [("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (app (var "x") zero)))) +tOkMatchWithDifferentOrder = root $ (anno Nat (match2 (anno (Variant (fromList [("b", (Fun Nat Nat)), ("a", Nat)])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (app (var "x") zero)))) +tOkMatchShadowBinding = root $ (let' "x" zero (anno Nat (match2 (anno (Variant (fromList [("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (app (var "x") zero))))) +tOkCasesSameBinding = root $ (anno (Fun Nat Nat) (match2 (anno (Variant (fromList [("a", Nat), ("b", Nat)])) (tag "b" zero)) ("a", "x", (lam "y" (mult (var "y") (var "x")))) ("b", "x", (lam "y" (mult (var "y") (var "x")))))) + + +tFailArithmetic = root $ (add (succ $ succ zero) (mult (anno (Fun Nat Nat) (lam "f" (var "f"))) (succ $ succ $ succ $ succ zero))) +tFailLambdaNoAnno = root $ (anno Nat (app (lam "b" (add (var "b") zero)) (succ $ succ zero))) +tFailLambdaNotAnnoInBindingLet = root $ (let' "e" (lam "n" (add (var "n") (succ zero))) (mult (app (var "e") zero) (succ zero))) + +tFailInL = root $ (anno (Sum Nat Nat) (inl (lam "b" zero))) +tFailInR = root $ (anno (Sum Nat Nat) (inr (lam "x" $ var "x"))) +tFailCaseUnequalReturnTypes = root $ (anno (Fun Nat Nat) (case' (anno Nat (inl $ succ $ succ zero)) "a" (lam "x" (add (var "a") zero)) "b" zero)) +tFailCaseRightShadowBinding = root $ (let' "a" zero (anno Nat (case' (anno (Sum Nat (Fun Nat Nat)) (inr $ lam "x" zero)) "a" zero "b" (var "b")))) + + +tFailWrongTypeForTaggedLabel = root $ (anno (Variant (fromList [("c", Nat), ("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" zero)) +tFailWrongLabel = root $ (anno (Variant (fromList [("c", Nat), ("a", Nat), ("b", (Fun Nat Nat))])) (tag "d" zero)) + +tFailMatchNotVariant = root $ (anno Nat (match1 (anno Nat zero) ("a", "x", succ $ var "x"))) + +tFailMatchUnequalTypesForCases = root $ (anno Nat (match2 (anno (Variant (fromList [("a", Nat), ("b", (Fun Nat Nat))])) (tag "b" (lam "x" (var "x")))) ("a", "x", succ $ var "x") ("b", "x", (var "x") ))) + +tOkZeroRecNatType = root $ (anno (TypeVar "Nat") (tag "zero" unit)) +tOkNestedSuccRecNatType = root $ (anno (TypeVar "Nat") (tag "succ" (tag "succ" (tag "succ" (tag "zero" unit))))) + +tFailRecTypeNotDefined = root $ (anno (TypeVar "Int") zero)