Commit 9371bf48 authored by André Pacak's avatar André Pacak

fo recursive types eliminate typemap argument

parent 0b2702b3
{-# LANGUAGE ConstraintKinds, TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RebindableSyntax #-}
module FORecursiveTypes.EliminateTypeMapArgument where
import Prelude hiding (Monad(..), (>=), (<=), lookup)
import GHC.Exts (Constraint)
import Data.List(find)
import Data.Maybe(isJust, fromJust)
import Data.Map(Map)
import qualified Data.Map as Map
import FORecursiveTypes.Language
import Util.ErrorMessages
-- is needed because we use the RebindableSyntax extension
ifThenElse :: Bool -> a -> a -> a
ifThenElse True thn _ = thn
ifThenElse False _ els = els
type Error = [String]
data Infer a
= Inferred a
| NotInferred Error deriving (Eq, Show)
type Check = Infer ()
instance Functor Infer where
fmap _ (NotInferred err) = NotInferred err
fmap f (Inferred ty) = Inferred $ f ty
instance Applicative Infer where
pure = Inferred
(NotInferred err) <*> _ = NotInferred err
(Inferred a) <*> something = fmap a something
class WithTop a where
top :: a
instance WithTop Type where
top = AnyType
instance WithTop () where
top = ()
instance (WithTop a, WithTop b) => WithTop (a, b) where
top = (top, top)
instance WithTop a => WithTop [a] where
top = [top]
instance WithTop a => WithTop (Map Name a) where
top = Map.empty
-- Had to define an own monad type class.
-- It is not possible otherwise to get the type constraint WithTop a.
-- We use the extension ConstraintKinds to support this.
-- Could not find a simpler solution for this problem.
-- The restricted monad problem is common.
class RMonad m where
type RMonadCtx m a :: Constraint
return :: RMonadCtx m a => a -> m a
(>>=) :: (RMonadCtx m a, RMonadCtx m b) => m a -> (a -> m b) -> m b
(>>) :: (RMonadCtx m a, RMonadCtx m b) => m a -> m b -> m b
m >> k = m >>= \_ -> k
fail :: [String] -> m a
instance RMonad Infer where
type RMonadCtx Infer a = WithTop a
return = Inferred
(Inferred ty) >>= f = f ty
NotInferred err1 >>= f =
-- we know that top is Inferred AnyType, (AnyType, AnyType) or () by definition
case f top of
Inferred _ -> fail err1
NotInferred err2 -> fail $ err1 ++ err2
fail = NotInferred
-- matching functions that extract the inner types if possible
-- one problem is that we do not get as good error messages, because term is not known in these functions
matchNat :: Type -> String -> Check
matchNat Nat _ = return ()
matchNat ty err = fail [natError ty err]
matchFun :: Type -> String -> Infer (Type, Type)
matchFun (Fun ty1 ty2) _ = return (ty1, ty2)
matchFun ty err = fail [funError ty err]
matchType :: Type -> Type -> String -> Check
matchType ty1 ty2 _
| ty1 == ty2 = return ()
matchType ty1 ty2 err = fail [generalError (show ty1) ty2 err]
matchSum :: Type -> String -> Infer (Type, Type)
matchSum (Sum ty1 ty2) _ = return (ty1, ty2)
matchSum ty err = fail [sumError ty err]
matchVariant :: Type -> String -> Infer (Map.Map Name Type)
matchVariant (Variant types) _ = return types
matchVariant ty err = fail [variantError ty err]
-- lookupTypeVar :: TypeMap -> Name -> Infer Type
-- lookupTypeVar Empty x = fail ["Unbound type variable " ++ show x]
-- lookupTypeVar (Bind c x t) y
-- | x == y = return t
-- | otherwise = lookupTypeVar c y
lookupTypeVar :: Term -> Name -> Infer Type
lookupTypeVar t x = case parent t of
Just p@(Succ term _) -> lookupTypeVar p x
Just p@(Add t1 t2 _) | t == t1 -> lookupTypeVar p x
Just p@(Add t1 t2 _) | t == t2 -> lookupTypeVar p x
Just p@(Mult t1 t2 _) | t == t1 -> lookupTypeVar p x
Just p@(Mult t1 t2 _) | t == t2 -> lookupTypeVar p x
Just p@(Var name _) -> lookupTypeVar p name
Just p@(Let name term body _) | t == term -> lookupTypeVar p x
Just p@(Let name term body _) | t == body -> lookupTypeVar p x
Just p@(Anno term ty _) | t == term -> lookupTypeVar p x
Just p@(App t1 t2 _) | t == t1 -> lookupTypeVar p x
Just p@(App t1 t2 _) | t == t2 -> lookupTypeVar p x
Just p@(LetType n ty term _) | t == term -> do
if x == n
then return ty
else lookupTypeVar p x
Just p@(Lam name term _) | t == term -> lookupTypeVar p x
Just p@(InL term _) | t == term -> lookupTypeVar p x
Just p@(InR term _) | t == term -> lookupTypeVar p x
Just p@(Case e n1 t1 n2 t2 _) | t == e -> lookupTypeVar p x
Just p@(Case e n1 t1 n2 t2 _) | t == t1 -> lookupTypeVar p x
Just p@(Case e n1 t1 n2 t2 _) | t == t2 -> lookupTypeVar p x
Just p@(Tag n term _) | t == term -> lookupTypeVar p x
Just p@(Match m cases _) | t == m -> lookupTypeVar p x
Just p@(Match m cases _) | t /= m -> lookupTypeVar p x
Just p -> lookupTypeVar p x
Nothing -> fail ["Unbound type variable " ++ show x]
matchTypeVar :: Term -> Type -> Infer Type
matchTypeVar t (TypeVar x) = lookupTypeVar t x
matchTypeVar _ ty = return ty
liftMaybe :: WithTop a => Maybe a -> String -> Infer a
liftMaybe (Just a) _ = return a
liftMaybe Nothing err = fail [err]
lookup :: Term -> Name -> Infer Type
lookup t x = case parent t of
Just p@(Succ t _) -> lookup p x
Just p@(Add t1 t2 _) | t == t1 -> lookup p x
Just p@(Add t1 t2 _) | t == t2 -> lookup p x
Just p@(Mult t1 t2 _) | t == t1 -> lookup p x
Just p@(Mult t1 t2 _) | t == t2 -> lookup p x
Just p@(Var name _) -> lookup p x
Just p@(Let name term body _) | t == term -> lookup p x
Just p@(Let name term body _) | t == body ->
if name == x
then do
ty <- inferType term
return ty
else lookup p x
Just p@(Anno term ty _) -> lookup p x
Just p@(App t1 t2 _) | t == t1 -> lookup p x
Just p@(App t1 t2 _) | t == t2 -> lookup p x
Just p@(Lam name term _) ->
if name == x
then do
ty <- requiredType p
rty <- matchTypeVar p ty
(ty1, ty2) <- matchFun rty (show p)
return ty1
else lookup p x
Just p@(InL term _) -> lookup p x
Just p@(InR term _) -> lookup p x
Just p@(Case e n1 t1 n2 t2 _) | t == t1 ->
if n1 == x
then do
ety <- inferType e
rty <- matchTypeVar p ety
(ty1, ty2) <- matchSum rty (show e)
return ty1
else lookup p x
Just p@(Case e n1 t1 n2 t2 _) | t == t2 ->
if n2 == x
then do
ety <- inferType e
rty <- matchTypeVar p ety
(ty1, ty2) <- matchSum rty (show e)
return ty2
else lookup p x
Just p@(Tag n t' _) | t == t' -> lookup p x
Just p@(Match m cases _) | t == m -> lookup p x
Just p@(Match m cases _) | t /= m -> do
ety <- inferType m
rty <- matchTypeVar p ety
typeMap <- matchVariant rty (show m)
let ml = find (\(_, x', t') -> t == t' && x == x') cases
if isJust ml
then do
let (l, _, _) = fromJust ml
liftMaybe (Map.lookup l typeMap) "Could not find labeled type"
else lookup p x
Just p@(LetType x ty term _) | t == term -> lookup p x
Just p -> lookup p x
Nothing -> fail ["Unbound variable " ++ show x]
inferType :: Term -> Infer Type
inferType (Unit _) = return UnitT
inferType (Zero _) = return Nat
inferType (Succ t _) = do
checkType t
return Nat
inferType (Add t1 t2 _) = do
checkType t1
checkType t2
return Nat
inferType (Mult t1 t2 _) = do
checkType t1
checkType t2
return Nat
inferType p@(Var name _) = lookup p name
inferType (Let name t body _) = do
tyt <- inferType t
inferType body
inferType (Anno term ty _) = do
checkType term
return ty
inferType p@(App t1 t2 _) = do
ty <- inferType t1
rty <- matchTypeVar p ty
(ty1, ty2) <- matchFun rty (show t1)
checkType t2
return ty2
inferType (LetType n ty t _) = inferType t
inferType t = fail ["Cannot infer type of term " ++ show t]
checkType :: Term -> Check
checkType p@(Lam name t _) = do
ty <- requiredType p
rty <- matchTypeVar p ty
(ty1, ty2) <- matchFun rty (show p)
checkType t
checkType p@(InL t _) = do
ty <- requiredType p
rty <- matchTypeVar p ty
(ty1, ty2) <- matchSum rty (show p)
checkType t
checkType p@(InR t _) = do
ty <- requiredType p
rty <- matchTypeVar p ty
(ty1, ty2) <- matchSum rty (show p)
checkType t
checkType p@(Case e n1 t1 n2 t2 _) = do
ety <- inferType e
rty <- matchTypeVar p ety
(ty1, ty2) <- matchSum rty (show e)
ty <- requiredType p
checkType t1
checkType t2
checkType p@(Tag n t _) = do
ty <- requiredType p
rty <- matchTypeVar p ty
types <- matchVariant rty (show p)
let lty = Map.lookup n types
lty <- liftMaybe (Map.lookup n types) "Label not contained in Variant"
checkType t
checkType p@(Match m cases _) = do
ety <- inferType m
rty <- matchTypeVar p ety
types <- matchVariant rty (show m)
ty <- requiredType p
let subchecks =
map (\(l, x, t) -> do
lty <- liftMaybe (Map.lookup l types) "Could not find labeled type"
checkType t
) cases
foldl (>>) (return ()) subchecks
checkType t = do
ty <- requiredType t
ty' <- inferType t
matchType ty ty' (show t)
requiredType :: Term -> Infer Type
requiredType t = case parent t of
Just (Succ t' _) | t == t' -> return Nat
Just (Add t1 t2 _) | t == t1 -> return Nat
Just (Add t1 t2 _) | t == t2 -> return Nat
Just (Mult t1 t2 _) | t == t1 -> return Nat
Just (Mult t1 t2 _) | t == t2 -> return Nat
Just (Anno term ty _) | t == term -> return ty
Just p@(App t1 t2 _) | t == t2 -> do
ty <- inferType t1
rty <- matchTypeVar p ty
(ty1, ty2) <- matchFun rty (show t1)
return ty1
Just p@(Lam name term _) | t == term -> do
ty <- requiredType p
rty <- matchTypeVar p ty
(ty1, ty2) <- matchFun rty (show p)
return ty2
Just p@(InL term _) | t == term -> do
ty <- requiredType p
rty <- matchTypeVar p ty
(ty1, ty2) <- matchSum rty (show p)
return ty1
Just p@(InR term _) | t == term -> do
ty <- requiredType p
rty <- matchTypeVar p ty
(ty1, ty2) <- matchSum rty (show p)
return ty2
Just p@(Case e n1 t1 n2 t2 _) | t == t1 -> do
ety <- inferType e
rty <- matchTypeVar p ety
(ty1, ty2) <- matchSum rty (show e)
ty <- requiredType p
return ty
Just p@(Case e n1 t1 n2 t2 _) | t == t2 -> do
ety <- inferType e
rty <- matchTypeVar p ety
(ty1, ty2) <- matchSum rty (show e)
ty <- requiredType p
return ty
Just p@(Tag n term _) | t == term -> do
ty <- requiredType p
rty <- matchTypeVar p ty
types <- matchVariant rty (show p)
let lty = Map.lookup n types
liftMaybe (Map.lookup n types) "Label not contained in Variant"
Just p@(Match m cases _) -> do
ty <- requiredType p
return ty
_ -> fail ["Could not determine required type"]
{-# LANGUAGE FlexibleInstances #-}
module FORecursiveTypes.EliminateTypeMapArgumentSpec where
import Prelude hiding (lookup,(*), (**))
import Test.Hspec
import FORecursiveTypes.Base as B
import FORecursiveTypes.SharedSpecs
import FORecursiveTypes.EliminateTypeMapArgument as E
import FORecursiveTypes.Language
instance ConvertToBInfer E.Infer where
convert (E.Inferred ty) = B.Inferred ty
convert (E.NotInferred err) = B.NotInferred $ head err
main :: IO ()
main = hspec spec
spec :: Spec
spec = sharedSpec $ E.inferType
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment