Skip to content

Commit

Permalink
Track data dependencies in SOACs (diku-dk#2006)
Browse files Browse the repository at this point in the history
  • Loading branch information
nhey authored and CKuke committed Nov 8, 2023
1 parent 4575c8f commit e744728
Show file tree
Hide file tree
Showing 18 changed files with 513 additions and 36 deletions.
73 changes: 71 additions & 2 deletions src/Futhark/Analysis/DataDependencies.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
module Futhark.Analysis.DataDependencies
( Dependencies,
dataDependencies,
depsOf,
depsOf',
depsOfArrays,
depsOfShape,
lambdaDependencies,
reductionDependencies,
findNecessaryForReturned,
)
where
Expand All @@ -27,6 +33,23 @@ dataDependencies' ::
Dependencies
dataDependencies' startdeps = foldl grow startdeps . bodyStms
where
grow deps (Let pat _ (WithAcc inputs lam)) =
let input_deps = foldMap depsOfWithAccInput inputs
-- Dependencies of each input reduction are concatenated.
-- Input to lam is cert_1, ..., cert_n, acc_1, ..., acc_n.
lam_deps = lambdaDependencies deps lam (input_deps <> input_deps)
transitive = map (depsOfNames deps) lam_deps
in M.fromList (zip (patNames pat) transitive) `M.union` deps
where
depsOfArrays' shape =
map (\arr -> oneName arr <> depsOfShape shape)
depsOfWithAccInput (shape, arrs, Nothing) =
depsOfArrays' shape arrs
depsOfWithAccInput (shape, arrs, Just (lam', nes)) =
reductionDependencies deps lam' nes (depsOfArrays' shape arrs)
grow deps (Let pat _ (Op op)) =
let op_deps = map (depsOfNames deps) (opDependencies op)
in M.fromList (zip (patNames pat) op_deps) `M.union` deps
grow deps (Let pat _ (Match c cases defbody _)) =
let cases_deps = map (dataDependencies' deps . caseBody) cases
defbody_deps = dataDependencies' deps defbody
Expand All @@ -50,16 +73,62 @@ dataDependencies' startdeps = foldl grow startdeps . bodyStms
in M.unions $ [branchdeps, deps, defbody_deps] ++ cases_deps
grow deps (Let pat _ e) =
let free = freeIn pat <> freeIn e
freeDeps = mconcat $ map (depsOfVar deps) $ namesToList free
in M.fromList [(name, freeDeps) | name <- patNames pat] `M.union` deps
free_deps = depsOfNames deps free
in M.fromList [(name, free_deps) | name <- patNames pat] `M.union` deps

depsOf :: Dependencies -> SubExp -> Names
depsOf _ (Constant _) = mempty
depsOf deps (Var v) = depsOfVar deps v

depsOf' :: SubExp -> Names
depsOf' (Constant _) = mempty
depsOf' (Var v) = depsOfVar mempty v

depsOfVar :: Dependencies -> VName -> Names
depsOfVar deps name = oneName name <> M.findWithDefault mempty name deps

depsOfRes :: Dependencies -> SubExpRes -> Names
depsOfRes deps (SubExpRes _ se) = depsOf deps se

-- | Extend @names@ with direct dependencies in @deps@.
depsOfNames :: Dependencies -> Names -> Names
depsOfNames deps names = mconcat $ map (depsOfVar deps) $ namesToList names

depsOfArrays :: SubExp -> [VName] -> [Names]
depsOfArrays size = map (\arr -> oneName arr <> depsOf mempty size)

depsOfShape :: Shape -> Names
depsOfShape shape = mconcat $ map (depsOf mempty) (shapeDims shape)

-- | Determine the variables on which the results of applying
-- anonymous function @lam@ to @inputs@ depend.
lambdaDependencies ::
(ASTRep rep) =>
Dependencies ->
Lambda rep ->
[Names] ->
[Names]
lambdaDependencies deps lam inputs =
let names_in_scope = freeIn lam <> mconcat inputs
deps_in = M.fromList $ zip (boundByLambda lam) inputs
deps' = dataDependencies' (deps_in <> deps) (lambdaBody lam)
in map
(namesIntersection names_in_scope . depsOfRes deps')
(bodyResult $ lambdaBody lam)

-- | Like 'lambdaDependencies', but @lam@ is a binary operation
-- with a neutral element.
reductionDependencies ::
(ASTRep rep) =>
Dependencies ->
Lambda rep ->
[SubExp] ->
[Names] ->
[Names]
reductionDependencies deps lam nes inputs =
let nes' = map (depsOf deps) nes
in lambdaDependencies deps lam (zipWith (<>) nes' inputs)

-- | @findNecessaryForReturned p merge deps@ computes which of the
-- loop parameters (@merge@) are necessary for the result of the loop,
-- where @p@ given a loop parameter indicates whether the final value
Expand Down
7 changes: 7 additions & 0 deletions src/Futhark/IR/GPU/Op.hs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ instance Rename SizeOp where
instance IsOp SizeOp where
safeOp _ = True
cheapOp _ = True
opDependencies op = [freeIn op]

instance TypedOp SizeOp where
opType (GetSize _ _) = pure [Prim int64]
Expand Down Expand Up @@ -291,6 +292,12 @@ instance (ASTRep rep, IsOp (op rep)) => IsOp (HostOp op rep) where
-- transfer scalars to device.
SQ.null (bodyStms body) && all ((== 0) . arrayRank) types

opDependencies (SegOp op) = opDependencies op
opDependencies (OtherOp op) = opDependencies op
opDependencies op@(SizeOp {}) = [freeIn op]
opDependencies (GPUBody _ body) =
replicate (length . bodyResult $ body) (freeIn body)

instance (TypedOp (op rep)) => TypedOp (HostOp op rep) where
opType (SegOp op) = opType op
opType (OtherOp op) = opType op
Expand Down
3 changes: 3 additions & 0 deletions src/Futhark/IR/MC/Op.hs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ instance (ASTRep rep, IsOp (op rep)) => IsOp (MCOp op rep) where
cheapOp (ParOp _ op) = cheapOp op
cheapOp (OtherOp op) = cheapOp op

opDependencies (ParOp _ op) = opDependencies op
opDependencies (OtherOp op) = opDependencies op

instance (TypedOp (op rep)) => TypedOp (MCOp op rep) where
opType (ParOp _ op) = opType op
opType (OtherOp op) = opType op
Expand Down
2 changes: 2 additions & 0 deletions src/Futhark/IR/Mem.hs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ instance (IsOp (inner rep)) => IsOp (MemOp inner rep) where
safeOp (Inner k) = safeOp k
cheapOp (Inner k) = cheapOp k
cheapOp Alloc {} = True
opDependencies op@(Alloc {}) = [freeIn op]
opDependencies (Inner op) = opDependencies op

instance (CanBeWise inner) => CanBeWise (MemOp inner) where
addOpWisdom (Alloc size space) = Alloc size space
Expand Down
4 changes: 4 additions & 0 deletions src/Futhark/IR/Prop.hs
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,13 @@ class (ASTConstraints op, TypedOp op) => IsOp op where
-- | Should we try to hoist this out of branches?
cheapOp :: op -> Bool

-- | Compute the data dependencies of an operation.
opDependencies :: op -> [Names]

instance IsOp (NoOp rep) where
safeOp NoOp = True
cheapOp NoOp = True
opDependencies NoOp = []

-- | Representation-specific attributes; also means the rep supports
-- some basic facilities.
Expand Down
70 changes: 67 additions & 3 deletions src/Futhark/IR/SOACS/SOAC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ import Data.List (intersperse)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
Expand All @@ -69,7 +70,7 @@ import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (chunks, maybeNth)
import Futhark.Util (chunks, maybeNth, splitAt3)
import Futhark.Util.Pretty (Doc, align, comma, commasep, docText, parens, ppTuple', pretty, (<+>), (</>))
import Futhark.Util.Pretty qualified as PP
import Prelude hiding (id, (.))
Expand Down Expand Up @@ -170,9 +171,13 @@ data Scan rep = Scan
}
deriving (Eq, Ord, Show)

-- | What are the sizes of reduction results produced by these 'Scan's?
scanSizes :: [Scan rep] -> [Int]
scanSizes = map (length . scanNeutral)

-- | How many reduction results are produced by these 'Scan's?
scanResults :: [Scan rep] -> Int
scanResults = sum . map (length . scanNeutral)
scanResults = sum . scanSizes

-- | Combine multiple scan operators to a single operator.
singleScan :: (Buildable rep) => [Scan rep] -> Scan rep
Expand All @@ -189,9 +194,13 @@ data Reduce rep = Reduce
}
deriving (Eq, Ord, Show)

-- | What are the sizes of reduction results produced by these 'Reduce's?
redSizes :: [Reduce rep] -> [Int]
redSizes = map (length . redNeutral)

-- | How many reduction results are produced by these 'Reduce's?
redResults :: [Reduce rep] -> Int
redResults = sum . map (length . redNeutral)
redResults = sum . redSizes

-- | Combine multiple reduction operators to a single operator.
singleReduce :: (Buildable rep) => [Reduce rep] -> Reduce rep
Expand Down Expand Up @@ -588,6 +597,61 @@ instance CanBeAliased SOAC where
instance (ASTRep rep) => IsOp (SOAC rep) where
safeOp _ = False
cheapOp _ = False
opDependencies (Stream w arrs accs lam) =
let accs_deps = map depsOf' accs
arrs_deps = depsOfArrays w arrs
in lambdaDependencies mempty lam (arrs_deps <> accs_deps)
opDependencies (Hist w arrs ops lam) =
let bucket_fun_deps' = lambdaDependencies mempty lam (depsOfArrays w arrs)
-- Bucket function results are indices followed by values.
-- Reshape this to align with list of histogram operations.
ranks = [length (histShape op) | op <- ops]
value_lengths = [length (histNeutral op) | op <- ops]
(indices, values) = splitAt (sum ranks) bucket_fun_deps'
bucket_fun_deps =
zipWith
concatIndicesToEachValue
(chunks ranks indices)
(chunks value_lengths values)
in mconcat $ zipWith (<>) bucket_fun_deps (map depsOfHistOp ops)
where
depsOfHistOp (HistOp dest_shape rf dests nes op) =
let shape_deps = depsOfShape dest_shape
in_deps = map (\vn -> oneName vn <> shape_deps <> depsOf' rf) dests
in reductionDependencies mempty op nes in_deps
-- A histogram operation may use the same index for multiple values.
concatIndicesToEachValue is vs =
let is_flat = mconcat is
in map (is_flat <>) vs
opDependencies (Scatter w arrs lam outputs) =
let deps = lambdaDependencies mempty lam (depsOfArrays w arrs)
in map flattenGroups (groupScatterResults' outputs deps)
where
flattenGroups (indicess, values) = mconcat indicess <> values
opDependencies (JVP lam args vec) =
mconcat $
replicate 2 $
lambdaDependencies mempty lam $
zipWith (<>) (map depsOf' args) (map depsOf' vec)
opDependencies (VJP lam args vec) =
mconcat $
replicate 2 $
lambdaDependencies mempty lam $
zipWith (<>) (map depsOf' args) (map depsOf' vec)
opDependencies (Screma w arrs (ScremaForm scans reds map_lam)) =
let (scans_in, reds_in, map_deps) =
splitAt3 (scanResults scans) (redResults reds) $
lambdaDependencies mempty map_lam (depsOfArrays w arrs)
scans_deps =
concatMap depsOfScan (zip scans $ chunks (scanSizes scans) scans_in)
reds_deps =
concatMap depsOfRed (zip reds $ chunks (redSizes reds) reds_in)
in scans_deps <> reds_deps <> map_deps
where
depsOfScan (Scan lam nes, deps_in) =
reductionDependencies mempty lam nes deps_in
depsOfRed (Reduce _ lam nes, deps_in) =
reductionDependencies mempty lam nes deps_in

substNamesInType :: M.Map VName SubExp -> Type -> Type
substNamesInType _ t@Prim {} = t
Expand Down
74 changes: 43 additions & 31 deletions src/Futhark/IR/SOACS/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -521,37 +521,49 @@ isMapWithOp pat e
-- the data dependencies to see that the "dead" result is not
-- actually used for computing one of the live ones.
removeDeadReduction :: BottomUpRuleOp (Wise SOACS)
removeDeadReduction (_, used) pat aux (Screma w arrs form)
| Just ([Reduce comm redlam nes], maplam) <- isRedomapSOAC form,
not $ all (`UT.used` used) $ patNames pat, -- Quick/cheap check
let (red_pes, map_pes) = splitAt (length nes) $ patElems pat,
let redlam_deps = dataDependencies $ lambdaBody redlam,
let redlam_res = bodyResult $ lambdaBody redlam,
let redlam_params = lambdaParams redlam,
let used_after =
map snd . filter ((`UT.used` used) . patElemName . fst) $
zip red_pes redlam_params,
let necessary =
findNecessaryForReturned
(`elem` used_after)
(zip redlam_params $ map resSubExp $ redlam_res <> redlam_res)
redlam_deps,
let alive_mask = map ((`nameIn` necessary) . paramName) redlam_params,
not $ and (take (length nes) alive_mask) = Simplify $ do
let fixDeadToNeutral lives ne = if lives then Nothing else Just ne
dead_fix = zipWith fixDeadToNeutral alive_mask nes
(used_red_pes, _, used_nes) =
unzip3 . filter (\(_, x, _) -> paramName x `nameIn` necessary) $
zip3 red_pes redlam_params nes

let maplam' = removeLambdaResults (take (length nes) alive_mask) maplam
redlam' <- removeLambdaResults (take (length nes) alive_mask) <$> fixLambdaParams redlam (dead_fix ++ dead_fix)

auxing aux $
letBind (Pat $ used_red_pes ++ map_pes) $
Op $
Screma w arrs $
redomapSOAC [Reduce comm redlam' used_nes] maplam'
removeDeadReduction (_, used) pat aux (Screma w arrs form) =
case isRedomapSOAC form of
Just ([Reduce comm redlam rednes], maplam) ->
let mkOp lam nes' = redomapSOAC [Reduce comm lam nes']
in removeDeadReduction' redlam rednes maplam mkOp
_ ->
case isScanomapSOAC form of
Just ([Scan scanlam nes], maplam) ->
let mkOp lam nes' = scanomapSOAC [Scan lam nes']
in removeDeadReduction' scanlam nes maplam mkOp
_ -> Skip
where
removeDeadReduction' redlam nes maplam mkOp
| not $ all (`UT.used` used) $ patNames pat, -- Quick/cheap check
let (red_pes, map_pes) = splitAt (length nes) $ patElems pat,
let redlam_deps = dataDependencies $ lambdaBody redlam,
let redlam_res = bodyResult $ lambdaBody redlam,
let redlam_params = lambdaParams redlam,
let used_after =
map snd . filter ((`UT.used` used) . patElemName . fst) $
zip red_pes redlam_params,
let necessary =
findNecessaryForReturned
(`elem` used_after)
(zip redlam_params $ map resSubExp $ redlam_res <> redlam_res)
redlam_deps,
let alive_mask = map ((`nameIn` necessary) . paramName) redlam_params,
not $ and (take (length nes) alive_mask) = Simplify $ do
let fixDeadToNeutral lives ne = if lives then Nothing else Just ne
dead_fix = zipWith fixDeadToNeutral alive_mask nes
(used_red_pes, _, used_nes) =
unzip3 . filter (\(_, x, _) -> paramName x `nameIn` necessary) $
zip3 red_pes redlam_params nes

let maplam' = removeLambdaResults (take (length nes) alive_mask) maplam
redlam' <- removeLambdaResults (take (length nes) alive_mask) <$> fixLambdaParams redlam (dead_fix ++ dead_fix)

auxing aux $
letBind (Pat $ used_red_pes ++ map_pes) $
Op $
Screma w arrs $
mkOp redlam' used_nes maplam'
removeDeadReduction' _ _ _ _ = Skip
removeDeadReduction _ _ _ _ = Skip

-- | If we are writing to an array that is never used, get rid of it.
Expand Down
2 changes: 2 additions & 0 deletions src/Futhark/IR/SegOp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,8 @@ instance
where
cheapOp _ = False
safeOp _ = True
opDependencies op =
replicate (length . kernelBodyResult $ segBody op) (freeIn op)

--- Simplification

Expand Down
Loading

0 comments on commit e744728

Please sign in to comment.