From 9371bf48fb129f27e1bd27bb5e51b64fa71d69b8 Mon Sep 17 00:00:00 2001 From: Andre Pacak Date: Sat, 28 Sep 2019 17:44:01 +0200 Subject: [PATCH] fo recursive types eliminate typemap argument --- .../EliminateTypeMapArgument.hs | 335 ++++++++++++++++++ .../EliminateTypeMapArgumentSpec.hs | 25 ++ 2 files changed, 360 insertions(+) create mode 100644 haskell/src/FORecursiveTypes/EliminateTypeMapArgument.hs create mode 100644 haskell/test/FORecursiveTypes/EliminateTypeMapArgumentSpec.hs diff --git a/haskell/src/FORecursiveTypes/EliminateTypeMapArgument.hs b/haskell/src/FORecursiveTypes/EliminateTypeMapArgument.hs new file mode 100644 index 0000000..967d8bd --- /dev/null +++ b/haskell/src/FORecursiveTypes/EliminateTypeMapArgument.hs @@ -0,0 +1,335 @@ +{-# LANGUAGE ConstraintKinds, TypeFamilies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE RebindableSyntax #-} +module FORecursiveTypes.EliminateTypeMapArgument where + +import Prelude hiding (Monad(..), (>=), (<=), lookup) +import GHC.Exts (Constraint) +import Data.List(find) +import Data.Maybe(isJust, fromJust) +import Data.Map(Map) +import qualified Data.Map as Map + +import FORecursiveTypes.Language +import Util.ErrorMessages + +-- is needed because we use the RebindableSyntax extension +ifThenElse :: Bool -> a -> a -> a +ifThenElse True thn _ = thn +ifThenElse False _ els = els + +type Error = [String] + +data Infer a + = Inferred a + | NotInferred Error deriving (Eq, Show) + +type Check = Infer () + +instance Functor Infer where + fmap _ (NotInferred err) = NotInferred err + fmap f (Inferred ty) = Inferred $ f ty + +instance Applicative Infer where + pure = Inferred + (NotInferred err) <*> _ = NotInferred err + (Inferred a) <*> something = fmap a something + +class WithTop a where + top :: a + +instance WithTop Type where + top = AnyType + +instance WithTop () where + top = () + +instance (WithTop a, WithTop b) => WithTop (a, b) where + top = (top, top) + +instance WithTop a => WithTop [a] where + top = [top] + +instance WithTop a => WithTop (Map Name a) where + top = Map.empty + +-- Had to define an own monad type class. +-- It is not possible otherwise to get the type constraint WithTop a. +-- We use the extension ConstraintKinds to support this. +-- Could not find a simpler solution for this problem. +-- The restricted monad problem is common. +class RMonad m where + type RMonadCtx m a :: Constraint + return :: RMonadCtx m a => a -> m a + (>>=) :: (RMonadCtx m a, RMonadCtx m b) => m a -> (a -> m b) -> m b + (>>) :: (RMonadCtx m a, RMonadCtx m b) => m a -> m b -> m b + m >> k = m >>= \_ -> k + fail :: [String] -> m a + +instance RMonad Infer where + type RMonadCtx Infer a = WithTop a + return = Inferred + (Inferred ty) >>= f = f ty + NotInferred err1 >>= f = + -- we know that top is Inferred AnyType, (AnyType, AnyType) or () by definition + case f top of + Inferred _ -> fail err1 + NotInferred err2 -> fail $ err1 ++ err2 + fail = NotInferred + +-- matching functions that extract the inner types if possible +-- one problem is that we do not get as good error messages, because term is not known in these functions +matchNat :: Type -> String -> Check +matchNat Nat _ = return () +matchNat ty err = fail [natError ty err] + +matchFun :: Type -> String -> Infer (Type, Type) +matchFun (Fun ty1 ty2) _ = return (ty1, ty2) +matchFun ty err = fail [funError ty err] + +matchType :: Type -> Type -> String -> Check +matchType ty1 ty2 _ + | ty1 == ty2 = return () +matchType ty1 ty2 err = fail [generalError (show ty1) ty2 err] + +matchSum :: Type -> String -> Infer (Type, Type) +matchSum (Sum ty1 ty2) _ = return (ty1, ty2) +matchSum ty err = fail [sumError ty err] + +matchVariant :: Type -> String -> Infer (Map.Map Name Type) +matchVariant (Variant types) _ = return types +matchVariant ty err = fail [variantError ty err] + +-- lookupTypeVar :: TypeMap -> Name -> Infer Type +-- lookupTypeVar Empty x = fail ["Unbound type variable " ++ show x] +-- lookupTypeVar (Bind c x t) y +-- | x == y = return t +-- | otherwise = lookupTypeVar c y + + +lookupTypeVar :: Term -> Name -> Infer Type +lookupTypeVar t x = case parent t of + Just p@(Succ term _) -> lookupTypeVar p x + Just p@(Add t1 t2 _) | t == t1 -> lookupTypeVar p x + Just p@(Add t1 t2 _) | t == t2 -> lookupTypeVar p x + Just p@(Mult t1 t2 _) | t == t1 -> lookupTypeVar p x + Just p@(Mult t1 t2 _) | t == t2 -> lookupTypeVar p x + Just p@(Var name _) -> lookupTypeVar p name + Just p@(Let name term body _) | t == term -> lookupTypeVar p x + Just p@(Let name term body _) | t == body -> lookupTypeVar p x + Just p@(Anno term ty _) | t == term -> lookupTypeVar p x + Just p@(App t1 t2 _) | t == t1 -> lookupTypeVar p x + Just p@(App t1 t2 _) | t == t2 -> lookupTypeVar p x + Just p@(LetType n ty term _) | t == term -> do + if x == n + then return ty + else lookupTypeVar p x + Just p@(Lam name term _) | t == term -> lookupTypeVar p x + Just p@(InL term _) | t == term -> lookupTypeVar p x + Just p@(InR term _) | t == term -> lookupTypeVar p x + Just p@(Case e n1 t1 n2 t2 _) | t == e -> lookupTypeVar p x + Just p@(Case e n1 t1 n2 t2 _) | t == t1 -> lookupTypeVar p x + Just p@(Case e n1 t1 n2 t2 _) | t == t2 -> lookupTypeVar p x + Just p@(Tag n term _) | t == term -> lookupTypeVar p x + Just p@(Match m cases _) | t == m -> lookupTypeVar p x + Just p@(Match m cases _) | t /= m -> lookupTypeVar p x + Just p -> lookupTypeVar p x + Nothing -> fail ["Unbound type variable " ++ show x] + + +matchTypeVar :: Term -> Type -> Infer Type +matchTypeVar t (TypeVar x) = lookupTypeVar t x +matchTypeVar _ ty = return ty + +liftMaybe :: WithTop a => Maybe a -> String -> Infer a +liftMaybe (Just a) _ = return a +liftMaybe Nothing err = fail [err] + +lookup :: Term -> Name -> Infer Type +lookup t x = case parent t of + Just p@(Succ t _) -> lookup p x + Just p@(Add t1 t2 _) | t == t1 -> lookup p x + Just p@(Add t1 t2 _) | t == t2 -> lookup p x + Just p@(Mult t1 t2 _) | t == t1 -> lookup p x + Just p@(Mult t1 t2 _) | t == t2 -> lookup p x + Just p@(Var name _) -> lookup p x + Just p@(Let name term body _) | t == term -> lookup p x + Just p@(Let name term body _) | t == body -> + if name == x + then do + ty <- inferType term + return ty + else lookup p x + Just p@(Anno term ty _) -> lookup p x + Just p@(App t1 t2 _) | t == t1 -> lookup p x + Just p@(App t1 t2 _) | t == t2 -> lookup p x + Just p@(Lam name term _) -> + if name == x + then do + ty <- requiredType p + rty <- matchTypeVar p ty + (ty1, ty2) <- matchFun rty (show p) + return ty1 + else lookup p x + Just p@(InL term _) -> lookup p x + Just p@(InR term _) -> lookup p x + Just p@(Case e n1 t1 n2 t2 _) | t == t1 -> + if n1 == x + then do + ety <- inferType e + rty <- matchTypeVar p ety + (ty1, ty2) <- matchSum rty (show e) + return ty1 + else lookup p x + Just p@(Case e n1 t1 n2 t2 _) | t == t2 -> + if n2 == x + then do + ety <- inferType e + rty <- matchTypeVar p ety + (ty1, ty2) <- matchSum rty (show e) + return ty2 + else lookup p x + Just p@(Tag n t' _) | t == t' -> lookup p x + Just p@(Match m cases _) | t == m -> lookup p x + Just p@(Match m cases _) | t /= m -> do + ety <- inferType m + rty <- matchTypeVar p ety + typeMap <- matchVariant rty (show m) + let ml = find (\(_, x', t') -> t == t' && x == x') cases + if isJust ml + then do + let (l, _, _) = fromJust ml + liftMaybe (Map.lookup l typeMap) "Could not find labeled type" + else lookup p x + Just p@(LetType x ty term _) | t == term -> lookup p x + Just p -> lookup p x + Nothing -> fail ["Unbound variable " ++ show x] + +inferType :: Term -> Infer Type +inferType (Unit _) = return UnitT +inferType (Zero _) = return Nat +inferType (Succ t _) = do + checkType t + return Nat +inferType (Add t1 t2 _) = do + checkType t1 + checkType t2 + return Nat +inferType (Mult t1 t2 _) = do + checkType t1 + checkType t2 + return Nat +inferType p@(Var name _) = lookup p name +inferType (Let name t body _) = do + tyt <- inferType t + inferType body +inferType (Anno term ty _) = do + checkType term + return ty +inferType p@(App t1 t2 _) = do + ty <- inferType t1 + rty <- matchTypeVar p ty + (ty1, ty2) <- matchFun rty (show t1) + checkType t2 + return ty2 +inferType (LetType n ty t _) = inferType t +inferType t = fail ["Cannot infer type of term " ++ show t] + +checkType :: Term -> Check +checkType p@(Lam name t _) = do + ty <- requiredType p + rty <- matchTypeVar p ty + (ty1, ty2) <- matchFun rty (show p) + checkType t +checkType p@(InL t _) = do + ty <- requiredType p + rty <- matchTypeVar p ty + (ty1, ty2) <- matchSum rty (show p) + checkType t +checkType p@(InR t _) = do + ty <- requiredType p + rty <- matchTypeVar p ty + (ty1, ty2) <- matchSum rty (show p) + checkType t +checkType p@(Case e n1 t1 n2 t2 _) = do + ety <- inferType e + rty <- matchTypeVar p ety + (ty1, ty2) <- matchSum rty (show e) + ty <- requiredType p + checkType t1 + checkType t2 +checkType p@(Tag n t _) = do + ty <- requiredType p + rty <- matchTypeVar p ty + types <- matchVariant rty (show p) + let lty = Map.lookup n types + lty <- liftMaybe (Map.lookup n types) "Label not contained in Variant" + checkType t +checkType p@(Match m cases _) = do + ety <- inferType m + rty <- matchTypeVar p ety + types <- matchVariant rty (show m) + ty <- requiredType p + let subchecks = + map (\(l, x, t) -> do + lty <- liftMaybe (Map.lookup l types) "Could not find labeled type" + checkType t + ) cases + foldl (>>) (return ()) subchecks + +checkType t = do + ty <- requiredType t + ty' <- inferType t + matchType ty ty' (show t) + + +requiredType :: Term -> Infer Type +requiredType t = case parent t of + Just (Succ t' _) | t == t' -> return Nat + Just (Add t1 t2 _) | t == t1 -> return Nat + Just (Add t1 t2 _) | t == t2 -> return Nat + Just (Mult t1 t2 _) | t == t1 -> return Nat + Just (Mult t1 t2 _) | t == t2 -> return Nat + Just (Anno term ty _) | t == term -> return ty + Just p@(App t1 t2 _) | t == t2 -> do + ty <- inferType t1 + rty <- matchTypeVar p ty + (ty1, ty2) <- matchFun rty (show t1) + return ty1 + Just p@(Lam name term _) | t == term -> do + ty <- requiredType p + rty <- matchTypeVar p ty + (ty1, ty2) <- matchFun rty (show p) + return ty2 + Just p@(InL term _) | t == term -> do + ty <- requiredType p + rty <- matchTypeVar p ty + (ty1, ty2) <- matchSum rty (show p) + return ty1 + Just p@(InR term _) | t == term -> do + ty <- requiredType p + rty <- matchTypeVar p ty + (ty1, ty2) <- matchSum rty (show p) + return ty2 + Just p@(Case e n1 t1 n2 t2 _) | t == t1 -> do + ety <- inferType e + rty <- matchTypeVar p ety + (ty1, ty2) <- matchSum rty (show e) + ty <- requiredType p + return ty + Just p@(Case e n1 t1 n2 t2 _) | t == t2 -> do + ety <- inferType e + rty <- matchTypeVar p ety + (ty1, ty2) <- matchSum rty (show e) + ty <- requiredType p + return ty + Just p@(Tag n term _) | t == term -> do + ty <- requiredType p + rty <- matchTypeVar p ty + types <- matchVariant rty (show p) + let lty = Map.lookup n types + liftMaybe (Map.lookup n types) "Label not contained in Variant" + Just p@(Match m cases _) -> do + ty <- requiredType p + return ty + _ -> fail ["Could not determine required type"] diff --git a/haskell/test/FORecursiveTypes/EliminateTypeMapArgumentSpec.hs b/haskell/test/FORecursiveTypes/EliminateTypeMapArgumentSpec.hs new file mode 100644 index 0000000..7589728 --- /dev/null +++ b/haskell/test/FORecursiveTypes/EliminateTypeMapArgumentSpec.hs @@ -0,0 +1,25 @@ +{-# LANGUAGE FlexibleInstances #-} + +module FORecursiveTypes.EliminateTypeMapArgumentSpec where + +import Prelude hiding (lookup,(*), (**)) +import Test.Hspec + +import FORecursiveTypes.Base as B +import FORecursiveTypes.SharedSpecs +import FORecursiveTypes.EliminateTypeMapArgument as E +import FORecursiveTypes.Language + + +instance ConvertToBInfer E.Infer where + convert (E.Inferred ty) = B.Inferred ty + convert (E.NotInferred err) = B.NotInferred $ head err + + +main :: IO () +main = hspec spec + +spec :: Spec +spec = sharedSpec $ E.inferType + + -- GitLab