-- ContinueAfterFail.hs
{-# LANGUAGE ConstraintKinds, TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RebindableSyntax #-}
module ADTypes.ContinueAfterFail where

import Prelude hiding (Monad(..), (>=), (<=), lookup)
import GHC.Exts (Constraint)
import qualified Data.Map as M
import Data.Maybe
import ADTypes.Language
import Util.ErrorMessages
import Util.PartialOrd

-- | Target of @if ... then ... else ...@ expressions in this module.
-- With the RebindableSyntax extension switched on, conditionals are
-- resolved to whatever 'ifThenElse' is in scope, so one has to be supplied.
ifThenElse :: Bool -> a -> a -> a
ifThenElse cond onTrue onFalse = case cond of
  True  -> onTrue
  False -> onFalse

-- | A collection of error messages.
type Error = [String]

-- | Outcome of an inference step: either a value or the reported errors.
data Infer a
  = Inferred a         -- ^ the step succeeded and produced this value
  | NotInferred Error  -- ^ the step failed with these messages
  deriving (Eq, Show)

-- | An inference step that carries no value; only success or failure counts.
type Check = Infer ()

-- | Maps over a successful result; a failure is passed through with its
-- messages untouched.
instance Functor Infer where
  fmap g outcome = case outcome of
    Inferred val    -> Inferred (g val)
    NotInferred msg -> NotInferred msg

-- | Short-circuiting applicative: when the function side has failed, its
-- messages are returned and the argument side is never inspected.
instance Applicative Infer where
  pure = Inferred
  wrapped <*> arg = case wrapped of
    Inferred g      -> g <$> arg
    NotInferred msg -> NotInferred msg

-- | Types that provide a designated placeholder value, 'top'.
class WithTop a where
  top :: a

-- | The only value of the unit type.
instance WithTop () where
  top = ()

-- | An underscore character.
instance WithTop Char where
  top = '_'

-- | The 'AnyType' type.
instance WithTop Type where
  top = AnyType

-- | A singleton list holding the element type's placeholder.
instance WithTop a => WithTop [a] where
  top = [top]

-- | The pair of both components' placeholders.
instance (WithTop a, WithTop b) => WithTop (a, b) where
  top = (top, top)

-- | The empty map.
instance WithTop v => WithTop (M.Map Name v) where
  top = M.empty

-- A restricted monad class of our own.
-- The standard Monad class leaves no room for a constraint on the carried
-- value (here: WithTop a), so the constraint is exposed as an associated
-- constraint synonym, which the ConstraintKinds extension permits.
-- This is the well-known "restricted monad" problem; no simpler solution
-- was found.
class RMonad m where
  -- | Constraint a value type has to satisfy to be carried by @m@.
  type RMonadCtx m a :: Constraint
  -- | Wraps a plain value.
  return :: RMonadCtx m a => a -> m a
  -- | Sequences two computations, handing the first result to the second.
  (>>=) :: (RMonadCtx m a, RMonadCtx m b) => m a -> (a -> m b) -> m b
  -- | Sequences two computations, ignoring the first result.
  -- Defaults to '>>=' with a continuation that discards its argument.
  (>>) :: (RMonadCtx m a, RMonadCtx m b) => m a -> m b -> m b
  first >> next = first >>= const next
  -- | Signals failure with the given messages.
  fail :: [String] -> m a

-- | Bind that keeps going after a failure. When the left-hand side has
-- already failed, the continuation is still run, on the placeholder 'top',
-- so that any messages it reports are appended to the ones collected so
-- far. A bind whose left-hand side failed always yields a failure.
instance RMonad Infer where
  type RMonadCtx Infer a = WithTop a
  return = Inferred
  fail = NotInferred
  outcome >>= k = case outcome of
    Inferred val     -> k val
    NotInferred errs ->
      -- Probe the continuation with 'top' only to harvest its errors;
      -- a successful probe result is discarded.
      case k top of
        Inferred _       -> NotInferred errs
        NotInferred more -> NotInferred (errs ++ more)

-- matching functions that extract the inner types if possible
-- | Succeeds when the type is 'Nat'; otherwise fails with the message that
-- 'natError' builds from the offending type and the caller-supplied string.
matchNat :: Type -> String -> Check
matchNat ty err = case ty of
  Nat -> return ()
  _   -> fail [natError ty err]

-- | Succeeds when the type is 'Bool'; otherwise fails with the message that
-- 'boolError' builds from the offending type and the caller-supplied string.
matchBool :: Type -> String -> Check
matchBool ty err = case ty of
  Bool -> return ()
  _    -> fail [boolError ty err]

-- | Splits a function type into its two component types, in constructor
-- order. Any other type fails with the message produced by 'funError'.
matchFun :: Type -> String -> Infer (Type, Type)
matchFun ty err = case ty of
  Fun ty1 ty2 -> return (ty1, ty2)
  _           -> fail [funError ty err]

-- | Succeeds when both types are equal under '=='; otherwise fails with the
-- message 'generalError' builds from the shown first type, the second type
-- and the caller-supplied string.
matchType :: Type -> Type -> String -> Check
matchType ty1 ty2 err
  | ty1 == ty2 = return ()
  | otherwise  = fail [generalError (show ty1) ty2 err]


-- | Extracts the name and the constructor map of an 'ADT' type. Any other
-- type fails with the message produced by 'sumError'.
matchADT :: Type -> String -> Infer (Name, M.Map Name [Type])
matchADT ty err = case ty of
  ADT n cotrs -> return (n, cotrs)
  _           -> fail [sumError ty err]

-- | Turns a 'Maybe' into an 'Infer': 'Just' becomes a success, 'Nothing'
-- becomes a failure carrying the given message as its only error.
liftMaybe :: WithTop a => Maybe a -> String -> Infer a
liftMaybe found err = maybe (fail [err]) return found

-- | Finds the type bound to a variable, searching from the outermost
-- 'Bind' inwards so that the most recent binding of a name wins.
-- Reaching 'Empty' fails with an "Unbound variable" message.
lookupVar :: Ctx -> Name -> Infer Type
lookupVar env wanted = case env of
  Empty -> fail ["Unbound variable " ++ show wanted]
  Bind rest bound ty
    | bound == wanted -> return ty
    | otherwise       -> lookupVar rest wanted

-- | Finds the type bound to a type variable, searching from the outermost
-- 'Bind' inwards so that the most recent binding of a name wins.
-- Reaching 'Empty' fails with an "Unbound type variable" message.
lookupTypeVar :: TypeMap -> Name -> Infer Type
lookupTypeVar env wanted = case env of
  Empty -> fail ["Unbound type variable " ++ show wanted]
  Bind rest bound ty
    | bound == wanted -> return ty
    | otherwise       -> lookupTypeVar rest wanted

-- | Resolves a 'TypeVar' through the type map by one 'lookupTypeVar' step;
-- every other type is returned unchanged.
matchTypeVar :: TypeMap -> Type -> Infer Type
matchTypeVar tymap ty = case ty of
  TypeVar x -> lookupTypeVar tymap x
  _         -> return ty

-- | Finds the ADT that declares the given constructor.
--
-- Walks the type map from the outermost 'Bind' inwards and returns the
-- first bound 'ADT' type whose constructor map contains the name, together
-- with that constructor's argument types. Bindings that are not ADTs, or
-- ADTs without the constructor, are skipped. Reaching 'Empty' fails with an
-- "Unbound constructor" message.
lookupCotr :: TypeMap -> Name -> Infer (Type, [Type])
lookupCotr Empty x = fail ["Unbound constructor " ++ show x]
lookupCotr (Bind c _ ty) y = case ty of
  -- The pattern guard replaces the former isJust/fromJust pair, so the
  -- lookup result is only ever used when it is known to be present.
  ADT _ cotrs | Just argTys <- M.lookup y cotrs -> return (ty, argTys)
  _ -> lookupCotr c y

-- | Infers the type of a term under a variable context and a type map.
--
-- * Numeric terms ('Zero', 'Succ', 'Add', 'Mult') have type 'Nat'; their
--   operands are checked against 'Nat'.
-- * Boolean terms ('Tru', 'Fls', 'Not', 'And', 'Or') have type 'Bool';
--   their operands are checked against 'Bool'.
-- * 'Var' is looked up in the context and resolved with 'matchTypeVar'.
-- * 'Let' infers the bound term and extends the context for the body.
-- * 'Anno' checks the term against the annotation and returns it.
-- * 'App' infers the head, requires a function type, checks the argument
--   against its first component and returns the second.
-- * 'LetData' extends the type map and infers the body.
-- * 'Cotr' finds the declaring ADT via 'lookupCotr', requires the argument
--   count to match and checks every argument.
--
-- Any other term fails with a "Cannot infer type of term" message.
-- The do-blocks use the module's own bind, so later steps still run after a
-- failed one and their errors are collected as well.
inferType :: Ctx -> TypeMap -> Term -> Infer Type
inferType _ _ (Zero _) = return Nat
inferType ctx tymap (Succ t _) = do
  checkType ctx tymap t Nat
  return Nat
inferType ctx tymap (Add t1 t2 _) = do
  checkType ctx tymap t1 Nat
  checkType ctx tymap t2 Nat
  return Nat
inferType ctx tymap (Mult t1 t2 _) = do
  checkType ctx tymap t1 Nat
  checkType ctx tymap t2 Nat
  return Nat
inferType _ _ (Tru _) = return Bool
inferType _ _ (Fls _) = return Bool
inferType ctx tymap (Not t _) = do
  checkType ctx tymap t Bool
  return Bool
inferType ctx tymap (And t1 t2 _) = do
  checkType ctx tymap t1 Bool
  checkType ctx tymap t2 Bool
  return Bool
inferType ctx tymap (Or t1 t2 _) = do
  checkType ctx tymap t1 Bool
  checkType ctx tymap t2 Bool
  return Bool
inferType ctx tymap (Var name _) = do
  ty <- lookupVar ctx name
  matchTypeVar tymap ty
inferType ctx tymap (Let name t body _) = do
  tyt <- inferType ctx tymap t
  inferType (Bind ctx name tyt) tymap body
inferType ctx tymap (Anno term ty _) = do
  checkType ctx tymap term ty
  return ty
inferType ctx tymap (App t1 t2 _) = do
  ty <- inferType ctx tymap t1
  (ty1, ty2) <- matchFun ty (show t1)
  checkType ctx tymap t2 ty1
  return ty2
inferType ctx tymap (LetData n adty t _) =
  -- TODO check that n is not already bound
  -- TODO check that elements in n are not bound
  inferType ctx (Bind tymap n adty) t
inferType ctx tymap (Cotr n ts p) = do
  -- TODO find better solution
  -- PROBLEM we need to find the binding that contains the constructor
  (adty, tys) <- lookupCotr tymap n
  if length ts == length tys
    then do
      -- Check each argument against its declared type, resolving type
      -- variables through the type map first.
      let checkArg t ty = do
            ty' <- matchTypeVar tymap ty
            checkType ctx tymap t ty'
          subchecks = zipWith checkArg ts tys
      -- Chain all checks with (>>) so every argument is examined.
      foldl (>>) (return ()) subchecks
      return adty
    else fail ["Expected number of arguments violated for " ++ show (Cotr n ts p)]
inferType _ _ t = fail ["Cannot infer type of term " ++ show t]

-- | Checks a term against an expected type.
--
-- * 'Lam' requires the expected type to be a function type, binds the
--   parameter to its first component and checks the body against the second.
-- * 'If' checks the condition against 'Bool' and both branches against the
--   expected type.
-- * 'Match' infers the scrutinee, requires it to be an ADT with as many
--   constructors as there are cases, and checks every case body against the
--   expected type with the case's bindings added to the context.
-- * Every other term is inferred with 'inferType' and the result compared
--   with the expected type by 'matchType'.
--
-- The do-blocks use the module's own bind, so later steps still run after a
-- failed one and their errors are collected as well.
checkType :: Ctx -> TypeMap -> Term -> Type -> Check
checkType ctx tymap p@(Lam name t _) ty = do
  (ty1, ty2) <- matchFun ty (show p)
  checkType (Bind ctx name ty1) tymap t ty2
checkType ctx tymap (If cond t1 t2 _) ty = do
  checkType ctx tymap cond Bool
  checkType ctx tymap t1 ty
  checkType ctx tymap t2 ty
checkType ctx tymap (Match m cases _) ty = do
  mty <- inferType ctx tymap m
  mty' <- matchTypeVar tymap mty
  (_, cotrs) <- matchADT mty' (show m)
  if length cases == length cotrs
    then do
      -- Checks one case: its label must name a constructor and it must bind
      -- exactly as many names as that constructor has arguments.
      let checkCase c = do
            tys <- liftMaybe (M.lookup (labelOfCase c) cotrs) "Could not find constructor" :: Infer [Type]
            if length (bindingsOfCase c) == length tys
              then do
                -- Bind each pattern variable to the matching argument type.
                let ctx' = foldl (\r (b, bty) -> Bind r b bty) ctx (zip (bindingsOfCase c) tys)
                checkType ctx' tymap (termOfCase c) ty
              else fail ["number of bindings does not match number of args of constructor"]
      -- Chain all case checks with (>>) so every case is examined.
      foldl (>>) (return ()) (map checkCase cases)
    -- The message now has a space before the shown scrutinee.
    else fail ["cases do not match number of constructors of " ++ show m]
checkType ctx tymap t ty = do
  ty' <- inferType ctx tymap t
  matchType ty ty' (show t)