ErrorList.hs 5.43 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
module ADTypes.ErrorList where

import Prelude hiding ((>=), (<=))
import qualified Data.Map as M
import Data.Maybe
import ADTypes.Language
import Util.ErrorMessages

-- | Error messages collected while type checking.
type Error = [String]

-- | Result of a type-checking step: either a successfully inferred value,
-- or the error that stopped inference.
data Infer a
  = Inferred a
  | NotInferred Error deriving (Eq, Show)

-- | A type-checking step that yields no value, only success or failure.
type Check = Infer ()

instance Functor Infer where
  fmap _ (NotInferred err) = NotInferred err
  fmap f (Inferred ty) = Inferred $ f ty

instance Applicative Infer where
  pure = Inferred
  -- The left error wins; the right operand is never inspected.
  (NotInferred err) <*> _ = NotInferred err
  (Inferred a) <*> something = fmap a something

-- | Sequencing short-circuits on the first 'NotInferred'.
-- 'return' is left to its default ('pure' = 'Inferred').
instance Monad Infer where
  (Inferred ty) >>= f = f ty
  (NotInferred err) >>= _ = NotInferred err

-- | Turns a failure message into a single-message 'NotInferred'.
--
-- Since GHC 8.8 (base 4.13) @fail@ is a method of 'MonadFail', not of
-- 'Monad'. Defining it inside the 'Monad' instance no longer compiles.
instance MonadFail Infer where
  fail msg = NotInferred [msg]

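-- Matching helpers: each checks that a type has the expected shape and fails
-- with a descriptive error otherwise; matchFun and matchADT also return the
-- components they matched.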
-- | Succeed only when the type is 'Nat'. Otherwise fail with the message
-- 'natError' builds from the offending type and the context string.
matchNat :: Type -> String -> Check
matchNat ty err = case ty of
  Nat -> return ()
  _ -> fail $ natError ty err

-- | Succeed only when the type is 'Bool'. Otherwise fail with the message
-- 'boolError' builds from the offending type and the context string.
matchBool :: Type -> String -> Check
matchBool ty err = case ty of
  Bool -> return ()
  _ -> fail $ boolError ty err

-- | Split a function type into its argument and result types. Any other
-- type fails with the message built by 'funError'.
matchFun :: Type -> String -> Infer (Type, Type)
matchFun ty err = case ty of
  Fun argTy resTy -> return (argTy, resTy)
  _ -> fail $ funError ty err

-- | Succeed when both types are equal according to 'Type''s 'Eq' instance.
-- Otherwise fail with the message 'generalError' builds from the shown
-- first type, the second type and the context string.
matchType :: Type -> Type -> String -> Check
matchType expected actual err
  | expected == actual = return ()
  | otherwise = fail $ generalError (show expected) actual err


-- | Deconstruct an algebraic data type into its name and its map from
-- constructor names to argument types. Any other type fails with the
-- message built by 'sumError'.
matchADT :: Type -> String -> Infer (Name, M.Map Name [Type])
matchADT ty err = case ty of
  ADT adtName ctors -> return (adtName, ctors)
  _ -> fail $ sumError ty err

-- | Lift a 'Maybe' into a failing computation. A 'Just' value is returned
-- unchanged; 'Nothing' fails with the given message.
--
-- The constraint is 'MonadFail' rather than 'Monad' because @fail@ has been
-- a 'MonadFail' method since GHC 8.8 (base 4.13).
liftMaybe :: MonadFail m => Maybe a -> String -> m a
liftMaybe mx err = maybe (fail err) return mx

-- | Find the type bound to a term variable. The most recently added binding
-- is checked first, so inner bindings shadow outer ones. Fails with
-- "Unbound variable <name>" when no binding matches.
lookupVar :: Ctx -> Name -> Infer Type
lookupVar ctx y = case ctx of
  Empty -> fail $ "Unbound variable " ++ show y
  Bind rest x t
    | x == y -> return t
    | otherwise -> lookupVar rest y

-- | Find the type bound to a type-level name, checking the most recently
-- added binding first. Fails with "Unbound type variable <name>" when no
-- binding matches.
lookupTypeVar :: TypeMap -> Name -> Infer Type
lookupTypeVar Empty name = fail $ "Unbound type variable " ++ show name
lookupTypeVar (Bind outer bound ty) name =
  if bound == name
    then return ty
    else lookupTypeVar outer name

-- | Resolve a type by one level. A 'TypeVar' is replaced by whatever the
-- type map binds it to (the result is not resolved again); every other
-- type is returned unchanged.
matchTypeVar :: TypeMap -> Type -> Infer Type
matchTypeVar tymap ty = case ty of
  TypeVar name -> lookupTypeVar tymap name
  _ -> return ty

-- | Find the algebraic data type that declares a constructor.
--
-- Walks the type map from the most recent binding outward and skips
-- bindings that are not 'ADT's. Returns the whole ADT type together with
-- the constructor's argument types. Fails with "Unbound constructor <name>"
-- when no ADT in scope declares it.
lookupCotr :: TypeMap -> Name -> Infer (Type, [Type])
lookupCotr Empty x = fail $ "Unbound constructor " ++ show x
lookupCotr (Bind c _ ty) y = case ty of
  -- Match on the lookup result directly instead of isJust/fromJust.
  ADT _ cotrs | Just argTys <- M.lookup y cotrs -> return (ty, argTys)
  _ -> lookupCotr c y

-- | Synthesize the type of a term from the variable context and type map.
--
-- * Arithmetic and boolean operators check their operands against 'Nat' or
--   'Bool'.
-- * Variables are looked up and resolved by one level of type variable.
-- * Annotations are checked, then their type is returned.
-- * Applications require a function type and check the argument against its
--   parameter type.
-- * Constructor applications are checked argument by argument against the
--   ADT that declares the constructor.
--
-- Other terms (e.g. 'Lam', 'If' and 'Match', which only 'checkType' handles)
-- fail with "Cannot infer type of term ...".
inferType :: Ctx -> TypeMap -> Term -> Infer Type
inferType _ _ (Zero _) = return Nat
inferType ctx tymap (Succ t _) = do
  checkType ctx tymap t Nat
  return Nat
inferType ctx tymap (Add t1 t2 _) = do
  checkType ctx tymap t1 Nat
  checkType ctx tymap t2 Nat
  return Nat
inferType ctx tymap (Mult t1 t2 _) = do
  checkType ctx tymap t1 Nat
  checkType ctx tymap t2 Nat
  return Nat
inferType _ _ (Tru _) = return Bool
inferType _ _ (Fls _) = return Bool
inferType ctx tymap (Not t _) = do
  checkType ctx tymap t Bool
  return Bool
inferType ctx tymap (And t1 t2 _) = do
  checkType ctx tymap t1 Bool
  checkType ctx tymap t2 Bool
  return Bool
inferType ctx tymap (Or t1 t2 _) = do
  checkType ctx tymap t1 Bool
  checkType ctx tymap t2 Bool
  return Bool
inferType ctx tymap (Var name _) = do
  ty <- lookupVar ctx name
  matchTypeVar tymap ty
inferType ctx tymap (Let name t body _) = do
  tyt <- inferType ctx tymap t
  inferType (Bind ctx name tyt) tymap body
inferType ctx tymap (Anno term ty _) = do
  checkType ctx tymap term ty
  return ty
inferType ctx tymap (App t1 t2 _) = do
  ty <- inferType ctx tymap t1
  (ty1, ty2) <- matchFun ty (show t1)
  checkType ctx tymap t2 ty1
  return ty2
inferType ctx tymap (LetData n adty t _) =
  -- TODO check that n is not already bound
  -- TODO check that elements in n are not bound
  inferType ctx (Bind tymap n adty) t
inferType ctx tymap (Cotr n ts p) = do
  -- TODO find better solution
  -- PROBLEM we need to find the binding that contains the constructor
  (adty, tys) <- lookupCotr tymap n
  if length ts == length tys
  then do
    -- Check each argument against its declared type. Stops at the first
    -- argument that fails.
    mapM_ checkArg (zip ts tys)
    return adty
  else fail $ "Expected number of arguments violated for " ++ show (Cotr n ts p)
  where
    -- Declared argument types may be type variables; resolve one level
    -- before checking.
    checkArg (arg, declTy) = do
      declTy' <- matchTypeVar tymap declTy
      checkType ctx tymap arg declTy'
inferType _ _ t = fail $ "Cannot infer type of term " ++ show t

-- | Check a term against an expected type.
--
-- * Lambdas require a function type and check their body under the extended
--   context.
-- * Conditionals check the condition against 'Bool' and both branches
--   against the expected type.
-- * Pattern matches require the scrutinee to have an ADT type (after one
--   level of type-variable resolution) and exactly one case per constructor.
--   Each case body is checked with its bindings added to the context.
-- * Every other term has its type inferred with 'inferType' and compared to
--   the expected type with 'matchType'.
checkType :: Ctx -> TypeMap -> Term -> Type -> Check
checkType ctx tymap p@(Lam name t _) ty = do
  (ty1, ty2) <- matchFun ty (show p)
  checkType (Bind ctx name ty1) tymap t ty2
checkType ctx tymap (If cond t1 t2 _) ty = do
  checkType ctx tymap cond Bool
  checkType ctx tymap t1 ty
  checkType ctx tymap t2 ty
checkType ctx tymap (Match m cases _) ty = do
  mty <- inferType ctx tymap m
  mty' <- matchTypeVar tymap mty
  (_, cotrs) <- matchADT mty' (show m)
  -- Exhaustiveness needs exactly one case per constructor. Comparing the
  -- counts alone would accept a duplicated case that hides a missing one.
  -- Unknown labels are rejected later, in checkCase.
  let labels = M.fromList [(labelOfCase c, ()) | c <- cases]
  if length cases /= M.size cotrs
    then fail $ "cases do not match number of constructors of " ++ show m
    else if M.size labels /= length cases
      then fail $ "duplicate cases in match on " ++ show m
      else mapM_ (checkCase cotrs) cases
  where
    -- Check one case: its constructor must exist, it must bind exactly as
    -- many names as the constructor has arguments, and its body must have
    -- the expected type with those bindings in scope.
    checkCase cotrMap c = do
      tys <- liftMaybe (M.lookup (labelOfCase c) cotrMap) "Could not find constructor" :: Infer [Type]
      if length (bindingsOfCase c) == length tys
      then do
        let ctx' = foldl (\r (b, bty) -> Bind r b bty) ctx (zip (bindingsOfCase c) tys)
        checkType ctx' tymap (termOfCase c) ty
      else fail "number of bindings does not match number of args of constructor"
checkType ctx tymap t ty = do
  ty' <- inferType ctx tymap t
  matchType ty ty' (show t)