Skip to content

Commit

Permalink
Remove a few more lambdas in Lang IR
Browse files Browse the repository at this point in the history
Slight tweak to lambda lifting, combining nested lambdas.
Also eta contraction to change \x -> f x to f, which helps when 'f' is a
local variable so no need to construct a new thunk.
  • Loading branch information
edwinb committed Dec 15, 2017
1 parent 7282c0a commit 24f580d
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 33 deletions.
17 changes: 9 additions & 8 deletions src/IRTS/Lang.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import Idris.Core.TT

import Control.Monad.State hiding (lift)
import Data.Data (Data)
import Data.List
import Data.List
import qualified Data.Map.Strict as Map
import Data.Typeable (Typeable)
import GHC.Generics (Generic)
Expand Down Expand Up @@ -141,7 +141,7 @@ addTags i ds = tag i ds []
tag i (x : as) acc = tag i as (x : acc)
tag i [] acc = (i, reverse acc)

data LiftState = LS (Maybe Name) Int [(Name, LDecl)]
data LiftState = LS (Maybe Name) Int [(Name, LDecl)]
(Map.Map ([Name], LExp) Name) -- map from args/expressions
-- to names, so we don't create the same function
-- multiple times
Expand Down Expand Up @@ -171,12 +171,12 @@ renameArgs args e
(map snd newargs, rename newargs e)

addFn :: Name -> LDecl -> State LiftState ()
addFn fn d
addFn fn d
= do LS n i ds done <- get
put (LS n i ((fn, d) : ds) done)

makeFn :: [Name] -> LExp -> State LiftState Name
makeFn args exp
makeFn args exp
= do fn <- getNextName
let (args', exp') = renameArgs args exp
LS n i ds done <- get
Expand All @@ -189,7 +189,7 @@ makeFn args exp
return fn

liftAll :: [(Name, LDecl)] -> [(Name, LDecl)]
liftAll xs =
liftAll xs =
let (LS _ _ decls _) = execState (mapM_ liftDef xs) (LS Nothing 0 [] Map.empty) in
decls

Expand Down Expand Up @@ -222,6 +222,7 @@ lift env (LForce e) = do e' <- lift env e
lift env (LLet n v e) = do v' <- lift env v
e' <- lift (env ++ [n]) e
return (LLet n v' e')
lift env (LLam args (LLam args' e)) = lift env (LLam (args ++ args') e)
lift env (LLam args e) = do e' <- lift (env ++ args) e
let usedArgs = nub $ usedIn env e'
fn <- makeFn (usedArgs ++ args) e'
Expand Down Expand Up @@ -363,15 +364,15 @@ lsubst n new (LCase t e alts) = let e' = lsubst n new e
lsubst n new tm = tm

rename :: [(Name, Name)] -> LExp -> LExp
rename ns tm@(LV x)
rename ns tm@(LV x)
= case lookup x ns of
Just n -> LV n
_ -> tm
rename ns (LApp t e args)
rename ns (LApp t e args)
= let e' = rename ns e
args' = map (rename ns) args in
LApp t e' args'
rename ns (LLazyApp fn args)
rename ns (LLazyApp fn args)
= let args' = map (rename ns) args in
LLazyApp fn args'
rename ns (LLazyExp e) = LLazyExp (rename ns e)
Expand Down
71 changes: 46 additions & 25 deletions src/IRTS/LangOpts.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ doInline :: LDefs -> LDecl -> LDecl
doInline defs d@(LConstructor _ _ _) = d
doInline defs (LFun opts topn args exp)
= let inl = evalState (eval [] initEnv [topn] defs exp)
(length args)
(length args)
-- do some case floating, which might arise as a result
res = caseFloats 10 inl in
-- then, eta contract
res = eta $ caseFloats 10 inl in
case res of
LLam args' body -> LFun opts topn (map snd initNames ++ args') body
_ -> LFun opts topn (map snd initNames) res
Expand Down Expand Up @@ -110,10 +111,10 @@ eval stk env rec defs (LCase ty e alts)
-- If they're all lambdas, bind the lambda at the top
let prefix = getLams (map getRHS alts')
case prefix of
[] -> return $ LCase ty e' (replaceInAlts e' alts')
[] -> return $ conOpt $ LCase ty e' (replaceInAlts e' alts')
args -> do alts_red <- mapM (dropArgs args) alts'
return $ LLam args
(LCase ty e' (replaceInAlts e' alts_red))
return $ LLam args
(conOpt (LCase ty e' (replaceInAlts e' alts_red)))
eval stk env rec defs (LOp f es)
= unload stk <$> LOp f <$> mapM (eval [] env rec defs) es
eval stk env rec defs (LForeign t s args)
Expand All @@ -128,7 +129,7 @@ eval stk env rec defs (LLam args sc)
[] -> eval stk' env' rec defs sc
as -> do ns' <- mapM (\n -> do n' <- nextN
return (n, n')) args'
LLam (map snd ns') <$>
unload stk' <$> LLam (map snd ns') <$>
eval [] (map (\ (n, n') -> (n, LV n')) ns' ++ env')
rec defs sc
eval stk env rec defs var@(LV n)
Expand Down Expand Up @@ -157,6 +158,11 @@ evalAlt stk env rec defs (LConstCase c e)
evalAlt stk env rec defs (LDefaultCase e)
= LDefaultCase <$> eval stk env rec defs e

apply :: [LExp] -> [(Name, LExp)] -> [Name] -> LDefs -> LExp ->
[Name] -> LExp -> State Int LExp
apply stk env rec defs var args body
= eval stk env rec defs (LLam args body)

dropArgs :: [Name] -> LAlt -> State Int LAlt
dropArgs as (LConCase i n es (LLam args rhs))
= do let old = take (length as) args
Expand All @@ -173,14 +179,15 @@ dropArgs as (LDefaultCase (LLam args rhs))

caseFloat :: LExp -> LExp
caseFloat (LApp tc e es) = LApp tc (caseFloat e) (map caseFloat es)
caseFloat (LLazyExp e) = LLazyExp (caseFloat e)
caseFloat (LForce e) = LForce (caseFloat e)
caseFloat (LCon up i n es) = LCon up i n (map caseFloat es)
caseFloat (LOp f es) = LOp f (map caseFloat es)
caseFloat (LLam ns sc) = LLam ns (caseFloat sc)
caseFloat (LLet v val sc) = LLet v (caseFloat val) (caseFloat sc)
caseFloat (LCase _ (LCase ct exp alts) alts')
| all conRHS alts || length alts == 1
= replaceInCase (LCase ct (caseFloat exp) (map (updateWith alts') alts))
= conOpt $ replaceInCase (LCase ct (caseFloat exp) (map (updateWith alts') alts))
where
conRHS (LConCase _ _ _ (LCon _ _ _ _)) = True
conRHS (LConstCase _ (LCon _ _ _ _)) = True
Expand All @@ -194,10 +201,19 @@ caseFloat (LCase _ (LCase ct exp alts) alts')
updateWith alts (LDefaultCase rhs) =
LDefaultCase (caseFloat (conOpt (LCase Shared (caseFloat rhs) alts)))

conOpt (LCase ct (LCon _ t n args) alts)
= pickAlt n args alts
conOpt tm = tm
caseFloat (LCase ct exp alts')
= conOpt $ replaceInCase (LCase ct (caseFloat exp) (map cfAlt alts'))
where
cfAlt (LConCase i n es rhs) = LConCase i n es (caseFloat rhs)
cfAlt (LConstCase c rhs) = LConstCase c (caseFloat rhs)
cfAlt (LDefaultCase rhs) = LDefaultCase (caseFloat rhs)
caseFloat exp = exp

-- Case of constructor
conOpt :: LExp -> LExp
conOpt (LCase ct (LCon _ t n args) alts)
= pickAlt n args alts
where
pickAlt n args (LConCase i n' es rhs : as) | n == n'
= substAll (zip es args) rhs
pickAlt _ _ (LDefaultCase rhs : as) = rhs
Expand All @@ -206,14 +222,7 @@ caseFloat (LCase _ (LCase ct exp alts) alts')

substAll [] rhs = rhs
substAll ((n, tm) : ss) rhs = lsubst n tm (substAll ss rhs)

caseFloat (LCase ct exp alts')
= replaceInCase (LCase ct (caseFloat exp) (map cfAlt alts'))
where
cfAlt (LConCase i n es rhs) = LConCase i n es (caseFloat rhs)
cfAlt (LConstCase c rhs) = LConstCase c (caseFloat rhs)
cfAlt (LDefaultCase rhs) = LDefaultCase (caseFloat rhs)
caseFloat exp = exp
conOpt tm = tm

replaceInCase :: LExp -> LExp
replaceInCase (LCase ty e alts)
Expand All @@ -225,7 +234,7 @@ replaceInAlts exp alts = dropDups $ concatMap (replaceInAlt exp) alts

-- Drop overlapping case (arising from case merging of overlapping
-- patterns)
dropDups (alt@(LConCase _ i n ns) : alts)
dropDups (alt@(LConCase _ i n ns) : alts)
= alt : dropDups (filter (notTag i) alts)
where
notTag i (LConCase _ j n ns) = i /= j
Expand All @@ -248,9 +257,9 @@ replaceInAlt exp@(LV var) (LDefaultCase (LCase ty (LV var') alts))
replaceInAlt exp a = [a]

replaceExp :: LExp -> LExp -> LExp -> LExp
replaceExp (LCon _ t n args) new (LCon _ t' n' args')
replaceExp (LCon _ t n args) new (LCon _ t' n' args')
| n == n' && args == args' = new
replaceExp (LCon _ t n args) new (LApp _ (LV n') args')
replaceExp (LCon _ t n args) new (LApp _ (LV n') args')
| n == n' && args == args' = new
replaceExp old new tm = tm

Expand All @@ -271,7 +280,19 @@ getLamPrefix as (LLam args tm : cs)
| otherwise = getLamPrefix as cs
getLamPrefix as (_ : cs) = []

apply :: [LExp] -> [(Name, LExp)] -> [Name] -> LDefs -> LExp ->
[Name] -> LExp -> State Int LExp
apply stk env rec defs var args body
= eval stk env rec defs (LLam args body)
-- eta contract ('\x -> f x' can just be compiled as 'f' when f is local)
eta :: LExp -> LExp
eta (LApp tc a es) = LApp tc (eta a) (map eta es)
eta (LLazyApp n es) = LLazyApp n (map eta es)
eta (LLazyExp e) = LLazyExp (eta e)
eta (LForce e) = LForce (eta e)
eta (LLet n val sc) = LLet n (eta val) (eta sc)
eta (LLam args (LApp tc f args'))
| args' == map LV args = eta f
eta (LLam args e) = LLam args (eta e)
eta (LProj e i) = LProj (eta e) i
eta (LCon a t n es) = LCon a t n (map eta es)
eta (LCase ct e alts) = LCase ct (eta e) (map (fmap eta) alts)
eta (LOp f es) = LOp f (map eta es)
eta tm = tm

0 comments on commit 24f580d

Please sign in to comment.