sourcecode:
|
module Inference.Inference where
import Analysis.ProgInfo
import Analysis.TotallyDefined ( siblingCons )
import Data.Char ( toUpper )
import Data.List
( (\\), maximum, nub, partition, union )
import qualified Data.Map as DM
import Data.Maybe
( fromJust, fromMaybe, isNothing, mapMaybe )
import FlatCurry.Annotated.Goodies
( allVars, allVarsInFunc, annExpr, branchPattern, patCons, unAnnExpr
, unAnnFuncDecl )
import FlatCurry.Annotated.Pretty ( ppFuncDecl )
import qualified FlatCurry.Goodies as FCG
import FlatCurry.ShowIntMod ( showFuncDeclAsFlatCurry )
import FlatCurry.Typed.Goodies
import FlatCurry.Typed.Types
import FlatCurry.Types
import Inference.Flattening
import Inference.Simplification
import Text.Pretty ( pPrint )
import Utils ( encodeSpecialChars )
--- Result of inferring the non-fail conditions (NFCs) of a single function,
--- as produced by `inferNF` and consumed by `filterRelevantNFCs`.
type InfInfo = ( QName -- Qualified name of the analysed function
               , Bool -- Can the function fail?
               , [QName] -- Functions called by the analysed function
               , [TAFuncDecl] -- Non-fail condition declarations generated for this function
               )
--- Infers non-fail conditions (NFCs) for a list of function declarations.
---
--- The declarations are flattened with fresh variables, an NFC is inferred
--- for every flattened declaration (`inferNF`), the NFCs relevant for the
--- given module are selected (`filterRelevantNFCs`) and each of them is
--- finally post-processed by `unrec`.
---
--- @param modname - name of the module whose failing functions are relevant
--- @param info    - program information handed on to `inferNF`
--- @param fdecls  - the function declarations to analyse
--- @return the relevant NFC declarations
inferNFConds :: ModuleName
             -> ProgInfo (TypeDecl, [Constructor])
             -> [TAFuncDecl]
             -> [TAFuncDecl]
inferNFConds modname info fdecls = map unrec decls
  where
    -- Largest variable index in use. Folding with a default of 0 keeps this
    -- defined when no declaration mentions any variable, where `maximum`
    -- would fail on the empty list.
    maxVar         = foldr max 0 (concatMap allVarsInFunc fdecls)
    freshVars      = [maxVar + 1 ..]
    (_, flatDecls) = flattenFuncs freshVars fdecls
    res            = map (inferNF modname info) flatDecls
    decls          = filterRelevantNFCs modname res
--- Selects the NFC declarations that are relevant for a module.
--- An inference result is relevant if
---   1. its function belongs to the given module and can fail itself, or
---   2. its function is (transitively) called by such a failing function.
filterRelevantNFCs :: ModuleName -> [InfInfo] -> [TAFuncDecl]
filterRelevantNFCs modname res =
  concatMap nfcsOf (failing `union` called calledQNs [])
  where
    nfcsOf (_, _, _, ds) = ds

    -- Results for functions of this module that can fail themselves.
    failing = filter localAndFailing res
    localAndFailing ((mname, _), fails, _, _) = mname == modname && fails

    -- Functions directly called by the failing functions.
    calledQNs = nub (concatMap (\(_, _, qns, _) -> qns) failing)

    -- Collects the results of the given functions and of everything they
    -- call, iterating until the set of function names no longer grows.
    called :: [QName] -> [InfInfo] -> [InfInfo]
    called qns acc
      | qns == qns' = acc'
      | otherwise   = called qns' acc'
      where
        found       = nub (concatMap infosFor qns)
        infosFor qn = filter (\(qn', _, _, _) -> qn == qn') res
        acc'        = acc `union` found
        qns'        = nub (qns ++ concatMap (\(qn, _, cqns, _) -> qn : cqns) found)
--- Infers the NFC declarations for a single function declaration.
---
--- Two declarations are generated for a function `f` of arity `n`:
---   * `f'nonfail` with arity `n`, whose rule is built by `inferNFCallRule`;
---   * `f_nonfailspec` with arity `2 * n`, whose rule is built by
---     `inferNFRule`.
--- The result also reports whether the function can fail and which
--- functions it calls (both as computed by `inferNFRule`).
inferNF
  :: ModuleName -> ProgInfo (TypeDecl, [Constructor]) -> TAFuncDecl -> InfInfo
inferNF _ info (AFunc qn@(mname, fname) arity vis ty rule) =
  let argTys        = FCG.argTypes ty
      toBoolFun tys = foldr FuncType boolType tys
      -- Type of `f'nonfail`: the argument types with result Bool.
      callTy        = toBoolFun argTys
      -- Type of `f_nonfailspec`: `arity` Booleans in front of the arguments.
      specTy        = toBoolFun (replicate arity boolType ++ argTys)
      callName      = (mname, fname ++ "'nonfail")
      specName      = (mname, fname ++ "_nonfailspec")
      (fails, callees, specRule) = inferNFRule info arity specTy rule
      callRule      = inferNFCallRule arity callTy specTy specName rule
                        (isTrivialRule specRule)
      callDecl      = AFunc callName arity vis callTy callRule
      specDecl      = AFunc specName (2 * arity) vis specTy specRule
  in (qn, fails, callees, [callDecl, specDecl])
--- Returns the right-hand side of a rule if it is an application of
--- `Prelude.True` or `Prelude.False`; external rules and all other
--- right-hand sides yield `Nothing`.
isTrivialRule :: TARule -> Maybe TAExpr
isTrivialRule (AExternal _ _) = Nothing
isTrivialRule (ARule _ _ expr)
  | isBoolConst = Just expr
  | otherwise   = Nothing
  where
    isBoolConst = case expr of
        AComb _ _ (("Prelude", "True"), _) _  -> True
        AComb _ _ (("Prelude", "False"), _) _ -> True
        _                                     -> False
--- Transforms a rule into its NFC rule.
---
--- Returns a triple consisting of
---   * whether the rule can fail (according to `inferNFExpr`),
---   * the functions called by the NFC expression, and
---   * the transformed rule of type `ty'`.
--- External rules are returned unchanged and reported as non-failing.
--- In the transformed rule every original argument variable is preceded by
--- a Boolean "starred" variable (see `starVar`).
--- If the rule cannot fail, its body is replaced by `True`, no calls are
--- reported and the argument list is extended by fresh variables for the
--- remaining argument types of `ty'`.
---
--- @param info  - program information handed on to `inferNFExpr`
--- @param arity - arity of the original function (currently unused)
--- @param ty'   - type of the generated NFC rule
--- @param rule  - the rule to transform
inferNFRule :: ProgInfo (TypeDecl, [Constructor])
            -> Int
            -> TypeExpr
            -> TARule
            -> (Bool, [QName], TARule)
inferNFRule _ _ _ r@(AExternal _ _) = (False, [], r)
inferNFRule info _ ty' (ARule _ argVars expr)
  | canFail   = (canFail, calls, ARule ty' argVars' expr')
  | otherwise = -- Calls of non-failing functions can be omitted
                (False, [], ARule ty' (argVars' ++ rhsVars) (boolExpr "True"))
  where
    (canFail, calls, expr') = inferNFExpr info expr
    bArgs    = map (\(v, _) -> (starVar v, boolType)) argVars
    argVars' = bArgs ++ argVars
    argIndxs = map fst argVars'
    -- Folding with a default of 0 keeps the fresh index defined for rules
    -- without arguments whose NFC expression contains no variables, where
    -- `maximum` would fail on the empty list.
    freshVar = foldr max 0 (argIndxs ++ allVars expr') + 1
    rhsVars  = drop (length argVars')
                 (addTypes2Vars (argIndxs ++ [freshVar ..]) (stripForall ty'))
    -- Pairs variables with the successive argument types of a function type.
    addTypes2Vars []       _  = []
    addTypes2Vars (v : vs) ty = case ty of
        FuncType t1 t2 -> (v, t1) : addTypes2Vars vs t2
        _              -> []
--- Builds the rule of the `'nonfail` entry function, which calls the
--- inferred NFC with the Boolean value `True` for every argument.
---
--- External rules are returned unchanged. If a trivial (constant) NFC body
--- is supplied, it is used directly instead of a call. The argument list is
--- extended by fresh variables for argument types of `nftype` that are not
--- covered by the original argument variables.
---
--- @param arity  - arity of the original function (currently unused)
--- @param nftype - type of the generated rule
--- @param ty'    - type of the called NFC function
--- @param qn     - name of the called NFC function
--- @param rule   - the original rule (only its argument variables are used)
--- @param texp   - constant NFC body, if the NFC is trivial
inferNFCallRule
  :: Int -> TypeExpr -> TypeExpr -> QName -> TARule -> Maybe TAExpr -> TARule
inferNFCallRule _ _ _ _ r@(AExternal _ _) _ = r
inferNFCallRule _ nftype ty' qn (ARule _ argVars _) texp =
  ARule nftype (argVars ++ rhsVars) expr
  where
    expr = case texp of
        Nothing -> AComb boolType FuncCall (qn, ty')
                     (replicate (length argVars) (boolExpr "True")
                        ++ map (\(i, t) -> AVar t i) argVars)
        Just e  -> e
    argIndxs = map fst argVars
    -- Folding with a default of 0 keeps the fresh index defined for rules
    -- without arguments, where `maximum` would fail on the empty list.
    freshVar = foldr max 0 (argIndxs ++ allVars expr) + 1
    rhsVars  = drop (length argVars)
                 (addTypes2Vars (argIndxs ++ [freshVar ..]) (stripForall nftype))
    -- Pairs variables with the successive argument types of a function type.
    addTypes2Vars []       _  = []
    addTypes2Vars (v : vs) ty = case ty of
        FuncType t1 t2 -> (v, t1) : addTypes2Vars vs t2
        _              -> []
--- Maps a variable index to the index of its "starred" companion variable
--- by adding a fixed offset of 10000.
starVar :: Int -> Int
starVar idx = idx + 10000
--- Replaces a variable expression by its starred variable of type Bool.
--- Any expression that is not a variable raises an error.
starVarExp :: TAExpr -> TAExpr
starVarExp e = maybe notAVariable (AVar boolType . starVar) varIndex
  where
    varIndex = case e of
        AVar _ i -> Just i
        _        -> Nothing
    notAVariable =
      error ("Function argument must be variable but found " ++ show e)
-- TODO: type annotations => type as result?
--- Transforms an expression into its NFC expression.
---
--- Returns a triple consisting of
---   * whether the expression can fail,
---   * the functions whose `_nonfailspec` is called, and
---   * the Boolean NFC expression.
---
--- Scrutinees of case expressions and arguments of transformed function
--- calls must be variables (see `starVarExp`).
---
--- @param info - program information used to find missing constructors
--- @param expr - the expression to transform
inferNFExpr
  :: ProgInfo (TypeDecl, [Constructor]) -> TAExpr -> (Bool, [QName], TAExpr)
inferNFExpr info expr
  = let inf = inferNFExpr info
    in case expr of
      -- The NFC of a variable is its starred Boolean companion.
      AVar _ i -> (False, [], AVar boolType (starVar i))
      ALit _ _ -> (False, [], boolExpr "True")
      -- Explicit failures.
      AComb _ _ (("Prelude", "failed"), _) [] ->
        (True, [], boolExpr "False")
      AComb _ _ (("Prelude", "error"), _) _ -> (True, [], boolExpr "False")
      -- Constructor applications are treated as non-failing.
      AComb _ ConsCall _ _ -> (False, [], boolExpr "True")
      AComb _ (ConsPartCall _) _ _ -> (False, [], boolExpr "True")
      -- Functions whose name starts with "_impl" are treated as non-failing.
      AComb _ FuncCall ((_, '_' : 'i' : 'm' : 'p' : 'l' : _), _) _ ->
        (False, [], boolExpr "True") -- TODO: Implement whitelist?
      AComb _ FuncCall (("Prelude", "apply"), _) _ ->
        (False, [], boolExpr "True")
      -- Other calls are replaced by a call to the callee's `_nonfailspec`
      -- with the starred arguments followed by the original arguments,
      -- unless the callee takes a function as an argument.
      AComb _ FuncCall (qn@(modname, fun), qnty) es
        | isHO qnty -> (False, [], boolExpr "True")
        | otherwise ->
          ( False
          , [qn]
          , AComb boolType FuncCall
              ((modname, fun ++ "_nonfailspec"), nonFailType qnty)
              (map starVarExp es ++ es)
          )
        where
          -- Does some argument of the function type have a function type?
          isHO :: TypeExpr -> Bool
          isHO t = case t of
            FuncType (FuncType _ _) _ -> True
            FuncType _ t'             -> isHO t'
            _                         -> False
      -- Partial function calls are treated as non-failing.
      AComb _ (FuncPartCall _) _ _ ->
        (False, [], boolExpr "True")
        -- ( False
        -- , [qn]
        -- , AComb boolType (FuncPartCall n)
        --     ((modname, fun ++ "_nonfailspec"), nonFailType qnty) es -- Intertwine arguments with Boolean values?
        -- )
      -- Every let binding gets a starred binding holding the NFC of its
      -- right-hand side; the original bindings are kept.
      ALet _ binds e ->
        let (b, qs, e') = inf e
            (bs, qss, starBinds) = unzip3
              (map (\((v, _), rhs) ->
                      let (b2, qs2, rhs') = inf rhs
                      in (b2, qs2, ((starVar v, boolType), rhs'))) binds)
        in ( b || or bs
           , qs ++ concat qss
           , ALet boolType (starBinds ++ binds) e'
           )
      -- A choice is rejected entirely as soon as one alternative can fail.
      AOr _ e1 e2 ->
        let (b1, qs1, e1') = inf e1
            (b2, qs2, e2') = inf e2
            canFail = b1 || b2
        in if canFail
             then (True, [], boolExpr "False")
             else (False, qs1 `union` qs2, AOr boolType e1' e2')
      -- A case can fail if a branch can fail or a constructor is missing.
      -- Missing constructors get a branch yielding False, and the whole case
      -- is guarded by the starred scrutinee.
      ACase _ ct e brs ->
        let (mdecl, misscons, _) = missingConsInBranch info brs
            (bs, qss, brs') = unzip3
              (map (\(ABranch p body) -> let (b, qs, body') = inf body
                                         in (b, qs, ABranch p body')) brs)
            canFail = or bs || not (null misscons)
            qs' = concat qss
            e' = case mdecl of
              Just (Type _ _ _ cs) -> ACase boolType ct e
                (map addStarVars2Branch (brs' ++ newBrs))
                where
                  newBrs = map c2br misscons
                  c2br c = ABranch (patGen c) (boolExpr "False")
                  cqnts = map (\(Cons cqn _ _ ts) -> (cqn, ts)) cs
                  consType = annExpr e -- TODO: Specialize type?
                  patGen (cqn, ar) = APattern consType (cqn, consType)
                    (zip [1 .. ar] (fromMaybe [] (lookup cqn cqnts)))
              -- Literal patterns: use the transformed branches so that the
              -- Bool-typed case really contains Boolean NFC expressions and
              -- matches the flags and calls collected from them above.
              Nothing -> ACase boolType Rigid e brs' -- TODO: Literal cases
              _ -> error "Something went wrong" -- Should not happen
            ite = ACase boolType Rigid (starVarExp e)
              [ ABranch (boolPat "True") e'
              , ABranch (boolPat "False") (boolExpr "False")
              ]
            boolPat str = APattern boolType (("Prelude", str), boolType) []
        in (canFail, qs', ite)
      ATyped _ e ty' -> let (b, qs, e') = inf e
                        in (b, qs, ATyped boolType e' ty')
      AFree _ vars e -> let (b, qs, e') = inf e
                        in (b, qs, AFree boolType vars e')
--- Wraps the body of a constructor branch into a let expression that binds
--- the starred variable of every pattern variable to `True`.
--- Branches whose pattern binds no variables are returned unchanged.
addStarVars2Branch :: TABranchExpr -> TABranchExpr
addStarVars2Branch br@(ABranch pat body) = case pat of
    APattern _ _ vars@(_ : _) ->
      ABranch pat (ALet (annExpr body) (map starBind vars) body)
    _ -> br
  where
    starBind (v, _) = ((starVar v, boolType), boolExpr "True")
--- Number of arguments of a function or constructor.
type Arity = Int
--- A constructor represented by its qualified name and its arity.
type Constructor = (QName, Arity)
--- Name of a module.
type ModuleName = String
--- Smart constructor for Boolean constants: builds the application of the
--- nullary `Prelude` constructor with the given name, typed as Bool.
boolExpr :: String -> TAExpr
boolExpr cons = AComb boolType ConsCall (consName, boolType) []
  where
    consName = ("Prelude", cons)
--- Shows a qualified name as a single identifier: the module name followed
--- by the function name with special characters encoded and its first
--- character converted to upper case.
--- If the encoded function name is empty, only the module name is returned
--- (instead of failing on the pattern match).
showQNameAsFun :: QName -> String
showQNameAsFun (mname, fun) = case encodeSpecialChars fun of
    []       -> mname
    (c : cs) -> mname ++ toUpper c : cs
--- The type expression of `Prelude.Bool`.
boolType :: TypeExpr
boolType = TCons boolName []
  where
    boolName = ("Prelude", "Bool")
-- Splits the constructors (name/arity) of the type matched by a case
-- construct into those missing in the given branches and those covered.
-- The type declaration is looked up via the constructor of the first branch.
-- For a case on literals, no declaration and no constructors are returned.
missingConsInBranch :: ProgInfo (TypeDecl, [Constructor])
                    -> [TABranchExpr]
                    -> (Maybe TypeDecl, [Constructor], [Constructor])
missingConsInBranch info branches = case branches of
    [] -> error "missingConsInBranch: case with empty branches!"
    ABranch (ALPattern _ _) _ : _ -> (Nothing, [], []) --error "TODO: case with literal pattern"
    ABranch (APattern _ (cons, _) args) _ : brs ->
      (Just decl, missing, (cons, length args) : covered)
      where
        -- Kept lazy: the error is only raised once the declaration or the
        -- constructor lists are actually demanded.
        (decl, othercons) = fromMaybe notFound (lookupProgInfo cons info)
        notFound =
          error ("Sibling constructors of " ++ showQName cons ++ " not found!")
        branchcons = map (patCons . branchPattern) brs
        (missing, covered) =
          partition (\(c, _) -> c `notElem` branchcons) othercons
--- Computes the type of the `_nonfailspec` function for a function type:
--- one Bool per argument, followed by the argument types, with result Bool.
nonFailType :: TypeExpr -> TypeExpr
nonFailType ty = foldr FuncType boolType (flagTys ++ argTys)
  where
    argTys  = FCG.argTypes ty
    flagTys = map (const boolType) argTys
-- Information about variables: We either know that a variable
-- has a specific outermost constructor (`Left`) or that it has an outermost
-- constructor not contained in a list of constructor names (`Right`).
-- NOTE(review): the code visible in this module only ever inserts `Left`
-- entries; confirm whether `Right` entries are produced elsewhere.
type VarMap = DM.Map VarIndex (Either QName [QName])
--- Removes redundant recursion from a function declaration (experimental).
--- The body of a defined rule is processed by `unrecExpr`, starting with no
--- knowledge about variables, and then simplified by `simplifyExpr`.
--- External rules are left untouched.
unrec :: TAFuncDecl -> TAFuncDecl
unrec decl@(AFunc qn arity vis ty rule) = AFunc qn arity vis ty rule'
  where
    rule' = case rule of
        AExternal _ _ -> rule
        ARule typ argVars expr ->
          ARule typ argVars (simplifyExpr (unrecExpr decl DM.empty expr))
--- Removes recursion from an expression: an application of the function
--- being processed is replaced by its constant result whenever
--- `constantResult` can determine one from the variable information.
---
--- @param decl - the declaration of the function being processed
--- @param vmap - what is known about the outermost constructors of variables
--- @param expr - the expression to process
unrecExpr :: TAFuncDecl -> VarMap -> TAExpr -> TAExpr
unrecExpr decl@(AFunc fname _ _ _ _) vmap expr = case expr of
    AVar _ _ -> expr
    ALit _ _ -> expr
    AComb ty ct qn@(fn, _) es
      | fn == fname -> fromMaybe rebuilt (constantResult decl vmap es)
      | otherwise   -> rebuilt
      where
        rebuilt = AComb ty ct qn (map recur es)
    -- todo: unrec binds
    ALet ty binds e ->
      ALet ty binds (unrecExpr decl (DM.union vmap (binds2VarMap binds)) e)
    AFree ty fvars e -> AFree ty fvars (recur e)
    AOr ty e1 e2 -> AOr ty (recur e1) (recur e2)
    -- Case on a variable: inside a constructor branch the variable is
    -- recorded as having that constructor. The result is always Rigid.
    ACase ty _ v@(AVar _ i) branches ->
      ACase ty Rigid v (map unrecBranch branches)
      where
        unrecBranch (ABranch pat body) = case pat of
          APattern _ (qn, _) _ ->
            ABranch pat
              (unrecExpr decl (DM.insertWith combine i (Left qn) vmap) body)
          ALPattern _ _ -> ABranch pat (recur body)
        -- Combining function for `DM.insertWith`: a `Left` in the second
        -- position wins; otherwise two `Right` lists are appended and a
        -- `Left` in the first position wins over a `Right`.
        combine _ snd'@(Left _) = snd'
        combine (Right qs1) (Right qs2) = Right (qs1 ++ qs2)
        combine fst'@(Left _) (Right _) = fst'
    ACase ty ct e branches ->
      ACase ty ct (recur e)
        (map (\(ABranch p body) -> ABranch p (recur body)) branches)
    ATyped ty e typ -> ATyped ty (recur e) typ
  where
    -- Processes a subexpression with unchanged variable information.
    recur = unrecExpr decl vmap
--- Extracts variable information from let bindings: every variable bound to
--- a constructor application is recorded with that constructor; all other
--- bindings are ignored.
binds2VarMap :: [((VarIndex, TypeExpr), TAExpr)] -> VarMap
binds2VarMap binds = foldr record DM.empty binds
  where
    record ((i, _), e) known = case e of
        AComb _ ConsCall (qn, _) _ -> DM.insert i (Left qn) known
        _                          -> known
--- Determines whether a function returns a constant result for a list of arguments
--- which might be bound to values contained in a map.
---
--- The rule body is traversed by `go`, which collects the Boolean constants
--- the body may evaluate to. `Just` is returned only if exactly one
--- candidate is collected. Literals and calls other than `True`/`False`
--- contribute no candidate at all.
--- Only declarations with a defined rule (`ARule`) are matched, and every
--- argument expression must be a variable.
---
--- @param decl    - the function declaration (only its rule is inspected)
--- @param vmap    - information about variables at the call site
--- @param argExps - the actual arguments of the call
constantResult :: TAFuncDecl -> VarMap -> [TAExpr] -> Maybe TAExpr
constantResult decl@(AFunc qn arity vis ty (ARule rty vars exp)) vmap argExps
  = case go vmap' exp of
      [x] -> Just x
      _ -> Nothing -- todo: no results?
  where
    -- Transfers the information known about the actual argument variables
    -- to the corresponding formal parameters of the rule.
    vmap' = foldr upd DM.empty (zip vars argExps)
    upd ((j, _), e) m = case e of
      AVar _ i -> case DM.lookup i vmap of
        Just x -> DM.insert j x m
        Nothing -> m
      _ ->
        error $ "Inference.constantResult: normalization failure for " ++ show e
    -- Collects the candidate results of an expression under the given
    -- variable information.
    go :: VarMap -> TAExpr -> [TAExpr]
    go vm e = case e of
      -- A variable in result position must be known to be True or False.
      AVar _ i -> case DM.lookup i vm of
        Just x -> case x of
          Left qn -> case qn of
            ("Prelude", "True") -> [boolExpr "True"]
            ("Prelude", "False") -> [boolExpr "False"]
            _ -> error
              $ "Inference.constantResult.go: Unknown constructor " ++ show qn
          Right qns -> error
            $ "Inference.constantResult.go: Missing constructor for variable "
            ++ show i
        Nothing -> error
          $ "Inference.constantResult.go: Missing binding for variable "
          ++ show i
      ALit _ _ -> [] -- todo: literals
      AComb _ _ (("Prelude", "True"), _) [] -> [boolExpr "True"]
      AComb _ _ (("Prelude", "False"), _) [] -> [boolExpr "False"]
      AComb _ _ _ _ -> [] -- todo: function calls
      ALet _ binds expr -> go (DM.union vm (binds2VarMap binds)) expr
      AFree _ binds expr -> go vm expr -- todo: binds
      AOr _ e1 e2 -> go vm e1 ++ go vm e2
      -- Case on a variable: only the branches compatible with the
      -- information about the scrutinee are followed; without information
      -- all branches are followed.
      ACase _ _ (AVar _ i) branches -> case DM.lookup i vm of
          Just x -> case x of
            Left qn -> selectBranchExprs qn branches
            Right qns -> removeBranchExprs qns branches
          Nothing -> concatMap (\(ABranch _ e) -> go vm e) branches
        where
          -- Results of the branches matching the given constructor.
          selectBranchExprs :: QName -> [TABranchExpr] -> [TAExpr]
          selectBranchExprs qname brs = concatMap (go vm) $ mapMaybe match brs
            where
              match (ABranch (APattern _ (pqn, _) _) e) | pqn == qname = Just e
                                                        | otherwise = Nothing
          -- Results of the branches not matching any excluded constructor.
          removeBranchExprs :: [QName] -> [TABranchExpr] -> [TAExpr]
          removeBranchExprs qns brs = concatMap (go vm) $ mapMaybe match brs
            where
              match (ABranch (APattern _ (pqn, _) _) e) | pqn `elem` qns = Nothing
                                                        | otherwise = Just e
      ATyped _ e _ -> go vm e
--- Debugging aid: flattens a function declaration (with fresh variables
--- starting at 42), applies `unrec` and prints the result as FlatCurry.
debug f = putStrLn (showFuncDeclAsFlatCurry (unAnnFuncDecl simplified))
  where
    flattened  = snd (flattenFunc [42 ..] f)
    simplified = unrec flattened
|