Skip to content

Commit

Permalink
SegOp Argument Reordering. (#2225)
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamDue authored Feb 26, 2025
1 parent 7818cdc commit 9bb2e4c
Show file tree
Hide file tree
Showing 20 changed files with 163 additions and 164 deletions.
6 changes: 3 additions & 3 deletions src/Futhark/Analysis/Interference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ analyseSegOp ::
m (InUse, LastUsed, Graph VName)
analyseSegOp lumap inuse (SegMap _ _ _ body) =
analyseKernelBody lumap inuse body
analyseSegOp lumap inuse (SegRed _ _ binops _ body) =
analyseSegOp lumap inuse (SegRed _ _ _ body binops) =
segWithBinOps lumap inuse binops body
analyseSegOp lumap inuse (SegScan _ _ binops _ body) = do
analyseSegOp lumap inuse (SegScan _ _ _ body binops) = do
segWithBinOps lumap inuse binops body
analyseSegOp lumap inuse (SegHist _ _ histops _ body) = do
analyseSegOp lumap inuse (SegHist _ _ _ body histops) = do
(inuse', lus', graph) <- analyseKernelBody lumap inuse body
(inuse'', lus'', graph') <- mconcat <$> mapM (analyseHistOp lumap inuse') histops
pure (inuse'', lus' <> lus'', graph <> graph')
Expand Down
6 changes: 3 additions & 3 deletions src/Futhark/Analysis/LastUse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -312,17 +312,17 @@ lastUseSegOp (SegMap _ _ tps kbody) used_nms = do
(used_nms', lu_vars) <- lastUsedInNames used_nms $ freeIn tps
(body_lutab, used_nms'') <- lastUseKernelBody kbody (mempty, used_nms')
pure (body_lutab, lu_vars, used_nms' <> used_nms'')
lastUseSegOp (SegRed _ _ sbos tps kbody) used_nms = do
lastUseSegOp (SegRed _ _ tps kbody sbos) used_nms = do
(lutab_sbo, lu_vars_sbo, used_nms_sbo) <- lastUseSegBinOp sbos used_nms
(used_nms', lu_vars) <- lastUsedInNames used_nms_sbo $ freeIn tps
(body_lutab, used_nms'') <- lastUseKernelBody kbody (mempty, used_nms')
pure (M.union lutab_sbo body_lutab, lu_vars <> lu_vars_sbo, used_nms_sbo <> used_nms' <> used_nms'')
lastUseSegOp (SegScan _ _ sbos tps kbody) used_nms = do
lastUseSegOp (SegScan _ _ tps kbody sbos) used_nms = do
(lutab_sbo, lu_vars_sbo, used_nms_sbo) <- lastUseSegBinOp sbos used_nms
(used_nms', lu_vars) <- lastUsedInNames used_nms_sbo $ freeIn tps
(body_lutab, used_nms'') <- lastUseKernelBody kbody (mempty, used_nms')
pure (M.union lutab_sbo body_lutab, lu_vars <> lu_vars_sbo, used_nms_sbo <> used_nms' <> used_nms'')
lastUseSegOp (SegHist _ _ hos tps kbody) used_nms = do
lastUseSegOp (SegHist _ _ tps kbody hos) used_nms = do
(lutab_sbo, lu_vars_sbo, used_nms_sbo) <- lastUseHistOp hos used_nms
(used_nms', lu_vars) <- lastUsedInNames used_nms_sbo $ freeIn tps
(body_lutab, used_nms'') <- lastUseKernelBody kbody (mempty, used_nms')
Expand Down
6 changes: 3 additions & 3 deletions src/Futhark/Analysis/MemAlias.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ type MemAliasesM inner a = Reader (Env inner) a
analyzeHostOp :: MemAliases -> HostOp NoOp GPUMem -> MemAliasesM (HostOp NoOp GPUMem) MemAliases
analyzeHostOp m (SegOp (SegMap _ _ _ kbody)) =
analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m (SegOp (SegRed _ _ _ _ kbody)) =
analyzeHostOp m (SegOp (SegRed _ _ _ kbody _)) =
analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m (SegOp (SegScan _ _ _ _ kbody)) =
analyzeHostOp m (SegOp (SegScan _ _ _ kbody _)) =
analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m (SegOp (SegHist _ _ _ _ kbody)) =
analyzeHostOp m (SegOp (SegHist _ _ _ kbody _)) =
analyzeStms (kernelBodyStms kbody) m
analyzeHostOp m SizeOp {} = pure m
analyzeHostOp m GPUBody {} = pure m
Expand Down
6 changes: 3 additions & 3 deletions src/Futhark/CodeGen/ImpGen/GPU.hs
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ segOpCompiler ::
CallKernelGen ()
segOpCompiler pat (SegMap lvl space _ kbody) =
compileSegMap pat lvl space kbody
segOpCompiler pat (SegRed lvl@(SegThread _ _) space reds _ kbody) =
segOpCompiler pat (SegRed lvl@(SegThread _ _) space _ kbody reds) =
compileSegRed pat lvl space reds kbody
segOpCompiler pat (SegScan lvl@(SegThread _ _) space scans _ kbody) =
segOpCompiler pat (SegScan lvl@(SegThread _ _) space _ kbody scans) =
compileSegScan pat lvl space scans kbody
segOpCompiler pat (SegHist lvl@(SegThread _ _) space ops _ kbody) =
segOpCompiler pat (SegHist lvl@(SegThread _ _) space _ kbody ops) =
compileSegHist pat lvl space ops kbody
segOpCompiler pat segop =
compilerBugS $ "segOpCompiler: unexpected " ++ prettyString (segLevel segop) ++ " for rhs of pattern " ++ prettyString pat
Expand Down
6 changes: 3 additions & 3 deletions src/Futhark/CodeGen/ImpGen/GPU/Block.hs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ compileBlockOp pat (Inner (SegOp (SegMap lvl space _ body))) = do
zipWithM_ (compileThreadResult space) (patElems pat) $
kernelBodyResult body
sOp $ Imp.ErrorSync Imp.FenceLocal
compileBlockOp pat (Inner (SegOp (SegScan lvl space scans _ body))) = do
compileBlockOp pat (Inner (SegOp (SegScan lvl space _ body scans))) = do
compileFlatId space

let (ltids, dims) = unzip $ unSegSpace space
Expand Down Expand Up @@ -412,7 +412,7 @@ compileBlockOp pat (Inner (SegOp (SegScan lvl space scans _ body))) = do
(product dims')
(segBinOpLambda scan)
arrs_flat
compileBlockOp pat (Inner (SegOp (SegRed lvl space ops _ body))) = do
compileBlockOp pat (Inner (SegOp (SegRed lvl space _ body ops))) = do
compileFlatId space

let dims' = map pe64 dims
Expand Down Expand Up @@ -533,7 +533,7 @@ compileBlockOp pat (Inner (SegOp (SegRed lvl space ops _ body))) = do
(map (unitSlice 0) (init dims') ++ [DimFix $ last dims' - 1])

sOp $ Imp.Barrier Imp.FenceLocal
compileBlockOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do
compileBlockOp pat (Inner (SegOp (SegHist lvl space _ kbody ops))) = do
compileFlatId space
let (ltids, dims) = unzip $ unSegSpace space

Expand Down
6 changes: 3 additions & 3 deletions src/Futhark/CodeGen/ImpGen/Multicore.hs
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,11 @@ compileSegOp ::
SegOp () MCMem ->
TV Int32 ->
ImpM MCMem HostEnv Imp.Multicore Imp.MCCode
compileSegOp pat (SegHist _ space histops _ kbody) ntasks =
compileSegOp pat (SegHist _ space _ kbody histops) ntasks =
compileSegHist pat space histops kbody ntasks
compileSegOp pat (SegScan _ space scans _ kbody) ntasks =
compileSegOp pat (SegScan _ space _ kbody scans) ntasks =
compileSegScan pat space scans kbody ntasks
compileSegOp pat (SegRed _ space reds _ kbody) ntasks =
compileSegOp pat (SegRed _ space _ kbody reds) ntasks =
compileSegRed pat space reds kbody ntasks
compileSegOp pat (SegMap _ space _ kbody) _ =
compileSegMap pat space kbody
Expand Down
17 changes: 5 additions & 12 deletions src/Futhark/IR/Parse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -919,18 +919,10 @@ pSegOp pr pLvl =
keyword "seghist" *> pSegHist
]
where
pSegMap =
SegOp.SegMap
<$> pLvl
<*> pSegSpace
<* pColon
<*> pTypes
<*> braces (pKernelBody pr)
pSegOp' f p =
pSegOp' f =
f
<$> pLvl
<*> pSegSpace
<*> parens (p `sepBy` pComma)
<* pColon
<*> pTypes
<*> braces (pKernelBody pr)
Expand All @@ -953,9 +945,10 @@ pSegOp pr pLvl =
<*> pShape
<* pComma
<*> pLambda pr
pSegRed = pSegOp' SegOp.SegRed pSegBinOp
pSegScan = pSegOp' SegOp.SegScan pSegBinOp
pSegHist = pSegOp' SegOp.SegHist pHistOp
pSegMap = pSegOp' SegOp.SegMap
pSegRed = pSegOp' SegOp.SegRed <*> parens (pSegBinOp `sepBy` pComma)
pSegScan = pSegOp' SegOp.SegScan <*> parens (pSegBinOp `sepBy` pComma)
pSegHist = pSegOp' SegOp.SegHist <*> parens (pHistOp `sepBy` pComma)

pSegLevel :: Parser GPU.SegLevel
pSegLevel =
Expand Down
Loading

0 comments on commit 9bb2e4c

Please sign in to comment.