diff --git a/src/Futhark/Analysis/Interference.hs b/src/Futhark/Analysis/Interference.hs index 4742a39a40..e665c5ce4e 100644 --- a/src/Futhark/Analysis/Interference.hs +++ b/src/Futhark/Analysis/Interference.hs @@ -159,11 +159,11 @@ analyseSegOp :: m (InUse, LastUsed, Graph VName) analyseSegOp lumap inuse (SegMap _ _ _ body) = analyseKernelBody lumap inuse body -analyseSegOp lumap inuse (SegRed _ _ binops _ body) = +analyseSegOp lumap inuse (SegRed _ _ _ body binops) = segWithBinOps lumap inuse binops body -analyseSegOp lumap inuse (SegScan _ _ binops _ body) = do +analyseSegOp lumap inuse (SegScan _ _ _ body binops) = do segWithBinOps lumap inuse binops body -analyseSegOp lumap inuse (SegHist _ _ histops _ body) = do +analyseSegOp lumap inuse (SegHist _ _ _ body histops) = do (inuse', lus', graph) <- analyseKernelBody lumap inuse body (inuse'', lus'', graph') <- mconcat <$> mapM (analyseHistOp lumap inuse') histops pure (inuse'', lus' <> lus'', graph <> graph') diff --git a/src/Futhark/Analysis/LastUse.hs b/src/Futhark/Analysis/LastUse.hs index f0389580b5..560f27b3c6 100644 --- a/src/Futhark/Analysis/LastUse.hs +++ b/src/Futhark/Analysis/LastUse.hs @@ -312,17 +312,17 @@ lastUseSegOp (SegMap _ _ tps kbody) used_nms = do (used_nms', lu_vars) <- lastUsedInNames used_nms $ freeIn tps (body_lutab, used_nms'') <- lastUseKernelBody kbody (mempty, used_nms') pure (body_lutab, lu_vars, used_nms' <> used_nms'') -lastUseSegOp (SegRed _ _ sbos tps kbody) used_nms = do +lastUseSegOp (SegRed _ _ tps kbody sbos) used_nms = do (lutab_sbo, lu_vars_sbo, used_nms_sbo) <- lastUseSegBinOp sbos used_nms (used_nms', lu_vars) <- lastUsedInNames used_nms_sbo $ freeIn tps (body_lutab, used_nms'') <- lastUseKernelBody kbody (mempty, used_nms') pure (M.union lutab_sbo body_lutab, lu_vars <> lu_vars_sbo, used_nms_sbo <> used_nms' <> used_nms'') -lastUseSegOp (SegScan _ _ sbos tps kbody) used_nms = do +lastUseSegOp (SegScan _ _ tps kbody sbos) used_nms = do (lutab_sbo, lu_vars_sbo, used_nms_sbo) <- lastUseSegBinOp sbos used_nms (used_nms', lu_vars) <- lastUsedInNames used_nms_sbo $ freeIn tps (body_lutab, used_nms'') <- lastUseKernelBody kbody (mempty, used_nms') pure (M.union lutab_sbo body_lutab, lu_vars <> lu_vars_sbo, used_nms_sbo <> used_nms' <> used_nms'') -lastUseSegOp (SegHist _ _ hos tps kbody) used_nms = do +lastUseSegOp (SegHist _ _ tps kbody hos) used_nms = do (lutab_sbo, lu_vars_sbo, used_nms_sbo) <- lastUseHistOp hos used_nms (used_nms', lu_vars) <- lastUsedInNames used_nms_sbo $ freeIn tps (body_lutab, used_nms'') <- lastUseKernelBody kbody (mempty, used_nms') diff --git a/src/Futhark/Analysis/MemAlias.hs b/src/Futhark/Analysis/MemAlias.hs index bdb0b6b287..39d496081b 100644 --- a/src/Futhark/Analysis/MemAlias.hs +++ b/src/Futhark/Analysis/MemAlias.hs @@ -68,11 +68,11 @@ type MemAliasesM inner a = Reader (Env inner) a analyzeHostOp :: MemAliases -> HostOp NoOp GPUMem -> MemAliasesM (HostOp NoOp GPUMem) MemAliases analyzeHostOp m (SegOp (SegMap _ _ _ kbody)) = analyzeStms (kernelBodyStms kbody) m -analyzeHostOp m (SegOp (SegRed _ _ _ _ kbody)) = +analyzeHostOp m (SegOp (SegRed _ _ _ kbody _)) = analyzeStms (kernelBodyStms kbody) m -analyzeHostOp m (SegOp (SegScan _ _ _ _ kbody)) = +analyzeHostOp m (SegOp (SegScan _ _ _ kbody _)) = analyzeStms (kernelBodyStms kbody) m -analyzeHostOp m (SegOp (SegHist _ _ _ _ kbody)) = +analyzeHostOp m (SegOp (SegHist _ _ _ kbody _)) = analyzeStms (kernelBodyStms kbody) m analyzeHostOp m SizeOp {} = pure m analyzeHostOp m GPUBody {} = pure m diff --git a/src/Futhark/CodeGen/ImpGen/GPU.hs b/src/Futhark/CodeGen/ImpGen/GPU.hs index d709c18ba1..ddae7c543f 100644 --- a/src/Futhark/CodeGen/ImpGen/GPU.hs +++ b/src/Futhark/CodeGen/ImpGen/GPU.hs @@ -153,11 +153,11 @@ segOpCompiler :: CallKernelGen () segOpCompiler pat (SegMap lvl space _ kbody) = compileSegMap pat lvl space kbody -segOpCompiler pat (SegRed lvl@(SegThread _ _) space reds _ kbody) = +segOpCompiler pat (SegRed lvl@(SegThread _ _) space _ kbody reds) = compileSegRed pat lvl space reds kbody -segOpCompiler pat (SegScan lvl@(SegThread _ _) space scans _ kbody) = +segOpCompiler pat (SegScan lvl@(SegThread _ _) space _ kbody scans) = compileSegScan pat lvl space scans kbody -segOpCompiler pat (SegHist lvl@(SegThread _ _) space ops _ kbody) = +segOpCompiler pat (SegHist lvl@(SegThread _ _) space _ kbody ops) = compileSegHist pat lvl space ops kbody segOpCompiler pat segop = compilerBugS $ "segOpCompiler: unexpected " ++ prettyString (segLevel segop) ++ " for rhs of pattern " ++ prettyString pat diff --git a/src/Futhark/CodeGen/ImpGen/GPU/Block.hs b/src/Futhark/CodeGen/ImpGen/GPU/Block.hs index 128e79a873..f8251042ef 100644 --- a/src/Futhark/CodeGen/ImpGen/GPU/Block.hs +++ b/src/Futhark/CodeGen/ImpGen/GPU/Block.hs @@ -365,7 +365,7 @@ compileBlockOp pat (Inner (SegOp (SegMap lvl space _ body))) = do zipWithM_ (compileThreadResult space) (patElems pat) $ kernelBodyResult body sOp $ Imp.ErrorSync Imp.FenceLocal -compileBlockOp pat (Inner (SegOp (SegScan lvl space scans _ body))) = do +compileBlockOp pat (Inner (SegOp (SegScan lvl space _ body scans))) = do compileFlatId space let (ltids, dims) = unzip $ unSegSpace space @@ -412,7 +412,7 @@ compileBlockOp pat (Inner (SegOp (SegScan lvl space scans _ body))) = do (product dims') (segBinOpLambda scan) arrs_flat -compileBlockOp pat (Inner (SegOp (SegRed lvl space ops _ body))) = do +compileBlockOp pat (Inner (SegOp (SegRed lvl space _ body ops))) = do compileFlatId space let dims' = map pe64 dims @@ -533,7 +533,7 @@ compileBlockOp pat (Inner (SegOp (SegRed lvl space ops _ body))) = do (map (unitSlice 0) (init dims') ++ [DimFix $ last dims' - 1]) sOp $ Imp.Barrier Imp.FenceLocal -compileBlockOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do +compileBlockOp pat (Inner (SegOp (SegHist lvl space _ kbody ops))) = do compileFlatId space let (ltids, dims) = unzip $ unSegSpace space diff --git a/src/Futhark/CodeGen/ImpGen/Multicore.hs b/src/Futhark/CodeGen/ImpGen/Multicore.hs index c4cf6f6263..1ca768c43a 100644 --- a/src/Futhark/CodeGen/ImpGen/Multicore.hs +++ b/src/Futhark/CodeGen/ImpGen/Multicore.hs @@ -164,11 +164,11 @@ compileSegOp :: SegOp () MCMem -> TV Int32 -> ImpM MCMem HostEnv Imp.Multicore Imp.MCCode -compileSegOp pat (SegHist _ space histops _ kbody) ntasks = +compileSegOp pat (SegHist _ space _ kbody histops) ntasks = compileSegHist pat space histops kbody ntasks -compileSegOp pat (SegScan _ space scans _ kbody) ntasks = +compileSegOp pat (SegScan _ space _ kbody scans) ntasks = compileSegScan pat space scans kbody ntasks -compileSegOp pat (SegRed _ space reds _ kbody) ntasks = +compileSegOp pat (SegRed _ space _ kbody reds) ntasks = compileSegRed pat space reds kbody ntasks compileSegOp pat (SegMap _ space _ kbody) _ = compileSegMap pat space kbody diff --git a/src/Futhark/IR/Parse.hs b/src/Futhark/IR/Parse.hs index 298e149617..7b7d5b383a 100644 --- a/src/Futhark/IR/Parse.hs +++ b/src/Futhark/IR/Parse.hs @@ -919,18 +919,10 @@ pSegOp pr pLvl = keyword "seghist" *> pSegHist ] where - pSegMap = - SegOp.SegMap - <$> pLvl - <*> pSegSpace - <* pColon - <*> pTypes - <*> braces (pKernelBody pr) - pSegOp' f p = + pSegOp' f = f <$> pLvl <*> pSegSpace - <*> parens (p `sepBy` pComma) <* pColon <*> pTypes <*> braces (pKernelBody pr) @@ -953,9 +945,10 @@ pSegOp pr pLvl = <*> pShape <* pComma <*> pLambda pr - pSegRed = pSegOp' SegOp.SegRed pSegBinOp - pSegScan = pSegOp' SegOp.SegScan pSegBinOp - pSegHist = pSegOp' SegOp.SegHist pHistOp + pSegMap = pSegOp' SegOp.SegMap + pSegRed = pSegOp' SegOp.SegRed <*> parens (pSegBinOp `sepBy` pComma) + pSegScan = pSegOp' SegOp.SegScan <*> parens (pSegBinOp `sepBy` pComma) + pSegHist = pSegOp' SegOp.SegHist <*> parens (pHistOp `sepBy` pComma) pSegLevel :: Parser GPU.SegLevel pSegLevel = diff --git a/src/Futhark/IR/SegOp.hs b/src/Futhark/IR/SegOp.hs index ad5a98f14d..6dd42227a7 100644 --- a/src/Futhark/IR/SegOp.hs +++ b/src/Futhark/IR/SegOp.hs @@ -450,9 +450,9 @@ data SegOp lvl rep = SegMap lvl SegSpace [Type] (KernelBody rep) | -- | The KernelSpace must always have at least two dimensions, -- implying that the result of a SegRed is always an array. - SegRed lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep) - | SegScan lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep) - | SegHist lvl SegSpace [HistOp rep] [Type] (KernelBody rep) + SegRed lvl SegSpace [Type] (KernelBody rep) [SegBinOp rep] + | SegScan lvl SegSpace [Type] (KernelBody rep) [SegBinOp rep] + | SegHist lvl SegSpace [Type] (KernelBody rep) [HistOp rep] deriving (Eq, Ord, Show) -- | The level of a 'SegOp'. @@ -474,9 +474,9 @@ segBody :: SegOp lvl rep -> KernelBody rep segBody segop = case segop of SegMap _ _ _ body -> body - SegRed _ _ _ _ body -> body - SegScan _ _ _ _ body -> body - SegHist _ _ _ _ body -> body + SegRed _ _ _ body _ -> body + SegScan _ _ _ body _ -> body + SegHist _ _ _ body _ -> body segResultShape :: SegSpace -> Type -> KernelResult -> Type segResultShape _ t (WriteReturns {}) = @@ -492,7 +492,7 @@ segResultShape _ t (RegTileReturns _ dims_n_tiles _) = segOpType :: SegOp lvl rep -> [Type] segOpType (SegMap _ space ts kbody) = zipWith (segResultShape space) ts $ kernelBodyResult kbody -segOpType (SegRed _ space reds ts kbody) = +segOpType (SegRed _ space ts kbody reds) = red_ts ++ zipWith (segResultShape space) @@ -505,7 +505,7 @@ segOpType (SegRed _ space reds ts kbody) = op <- reds let shape = Shape segment_dims <> segBinOpShape op map (`arrayOfShape` shape) (lambdaReturnType $ segBinOpLambda op) -segOpType (SegScan _ space scans ts kbody) = +segOpType (SegScan _ space ts kbody scans) = scan_ts ++ zipWith (segResultShape space) @@ -517,7 +517,7 @@ segOpType (SegScan _ space scans ts kbody) = op <- scans let shape = Shape (segSpaceDims space) <> segBinOpShape op map (`arrayOfShape` shape) (lambdaReturnType $ segBinOpLambda op) -segOpType (SegHist _ space ops _ _) = do +segOpType (SegHist _ space _ _ ops) = do op <- ops let shape = Shape segment_dims <> histShape op <> histOpShape op map (`arrayOfShape` shape) (lambdaReturnType $ histOp op) @@ -533,11 +533,11 @@ instance (ASTConstraints lvl) => AliasedOp (SegOp lvl) where consumedInOp (SegMap _ _ _ kbody) = consumedInKernelBody kbody - consumedInOp (SegRed _ _ _ _ kbody) = + consumedInOp (SegRed _ _ _ kbody _) = consumedInKernelBody kbody - consumedInOp (SegScan _ _ _ _ kbody) = + consumedInOp (SegScan _ _ _ kbody _) = consumedInKernelBody kbody - consumedInOp (SegHist _ _ ops _ kbody) = + consumedInOp (SegHist _ _ _ kbody ops) = namesFromList (concatMap histDest ops) <> consumedInKernelBody kbody -- | Type check a 'SegOp', given a checker for its level. @@ -549,7 +549,7 @@ typeCheckSegOp :: typeCheckSegOp checkLvl (SegMap lvl space ts kbody) = do checkLvl lvl checkScanRed space [] ts kbody -typeCheckSegOp checkLvl (SegRed lvl space reds ts body) = do +typeCheckSegOp checkLvl (SegRed lvl space ts body reds) = do checkLvl lvl checkScanRed space reds' ts body where @@ -558,7 +558,7 @@ typeCheckSegOp checkLvl (SegRed lvl space reds ts body) = do (map segBinOpLambda reds) (map segBinOpNeutral reds) (map segBinOpShape reds) -typeCheckSegOp checkLvl (SegScan lvl space scans ts body) = do +typeCheckSegOp checkLvl (SegScan lvl space ts body scans) = do checkLvl lvl checkScanRed space scans' ts body where @@ -567,7 +567,7 @@ typeCheckSegOp checkLvl (SegScan lvl space scans ts body) = do (map segBinOpLambda scans) (map segBinOpNeutral scans) (map segBinOpShape scans) -typeCheckSegOp checkLvl (SegHist lvl space ops ts kbody) = do +typeCheckSegOp checkLvl (SegHist lvl space ts kbody ops) = do checkLvl lvl checkSegSpace space mapM_ TC.checkType ts @@ -705,27 +705,27 @@ mapSegOpM tv (SegMap lvl space ts body) = <*> mapOnSegSpace tv space <*> mapM (mapOnSegOpType tv) ts <*> mapOnSegOpBody tv body -mapSegOpM tv (SegRed lvl space reds ts lam) = +mapSegOpM tv (SegRed lvl space ts body reds) = SegRed <$> mapOnSegOpLevel tv lvl <*> mapOnSegSpace tv space - <*> mapM (mapSegBinOp tv) reds <*> mapM (mapOnType $ mapOnSegOpSubExp tv) ts - <*> mapOnSegOpBody tv lam -mapSegOpM tv (SegScan lvl space scans ts body) = + <*> mapOnSegOpBody tv body + <*> mapM (mapSegBinOp tv) reds +mapSegOpM tv (SegScan lvl space ts body scans) = SegScan <$> mapOnSegOpLevel tv lvl <*> mapOnSegSpace tv space - <*> mapM (mapSegBinOp tv) scans <*> mapM (mapOnType $ mapOnSegOpSubExp tv) ts <*> mapOnSegOpBody tv body -mapSegOpM tv (SegHist lvl space ops ts body) = + <*> mapM (mapSegBinOp tv) scans +mapSegOpM tv (SegHist lvl space ts body ops) = SegHist <$> mapOnSegOpLevel tv lvl <*> mapOnSegSpace tv space - <*> mapM onHistOp ops <*> mapM (mapOnType $ mapOnSegOpSubExp tv) ts <*> mapOnSegOpBody tv body + <*> mapM onHistOp ops where onHistOp (HistOp w rf arrs nes shape op) = HistOp @@ -771,21 +771,18 @@ rephraseKernelBody r (KernelBody dec stms res) = instance RephraseOp (SegOp lvl) where rephraseInOp r (SegMap lvl space ts body) = SegMap lvl space ts <$> rephraseKernelBody r body - rephraseInOp r (SegRed lvl space reds ts body) = - SegRed lvl space - <$> mapM (rephraseBinOp r) reds - <*> pure ts - <*> rephraseKernelBody r body - rephraseInOp r (SegScan lvl space scans ts body) = - SegScan lvl space - <$> mapM (rephraseBinOp r) scans - <*> pure ts - <*> rephraseKernelBody r body - rephraseInOp r (SegHist lvl space hists ts body) = - SegHist lvl space - <$> mapM onOp hists - <*> pure ts - <*> rephraseKernelBody r body + rephraseInOp r (SegRed lvl space ts body reds) = + SegRed lvl space ts + <$> rephraseKernelBody r body + <*> mapM (rephraseBinOp r) reds + rephraseInOp r (SegScan lvl space ts body scans) = + SegScan lvl space ts + <$> rephraseKernelBody r body + <*> mapM (rephraseBinOp r) scans + rephraseInOp r (SegHist lvl space ts body hists) = + SegHist lvl space ts + <$> rephraseKernelBody r body + <*> mapM onOp hists where onOp (HistOp w rf arrs nes shape op) = HistOp w rf arrs nes shape <$> rephraseLambda r op @@ -844,15 +841,15 @@ instance (ASTRep rep, FreeIn lvl) => FreeIn (SegOp lvl rep) where instance (OpMetrics (Op rep)) => OpMetrics (SegOp lvl rep) where opMetrics (SegMap _ _ _ body) = inside "SegMap" $ kernelBodyMetrics body - opMetrics (SegRed _ _ reds _ body) = + opMetrics (SegRed _ _ _ body reds) = inside "SegRed" $ do mapM_ (inside "SegBinOp" . lambdaMetrics . segBinOpLambda) reds kernelBodyMetrics body - opMetrics (SegScan _ _ scans _ body) = + opMetrics (SegScan _ _ _ body scans) = inside "SegScan" $ do mapM_ (inside "SegBinOp" . lambdaMetrics . segBinOpLambda) scans kernelBodyMetrics body - opMetrics (SegHist _ _ ops _ body) = + opMetrics (SegHist _ _ _ body ops) = inside "SegHist" $ do mapM_ (lambdaMetrics . histOp) ops kernelBodyMetrics body @@ -887,30 +884,30 @@ instance (PrettyRep rep, PP.Pretty lvl) => PP.Pretty (SegOp lvl rep) where <+> PP.colon <+> ppTuple' (map pretty ts) <+> PP.nestedBlock "{" "}" (pretty body) - pretty (SegRed lvl space reds ts body) = + pretty (SegRed lvl space ts body reds) = "segred" <> pretty lvl PP.align (pretty space) - PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map pretty reds) PP.colon <+> ppTuple' (map pretty ts) <+> PP.nestedBlock "{" "}" (pretty body) - pretty (SegScan lvl space scans ts body) = + PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map pretty reds) + pretty (SegScan lvl space ts body scans) = "segscan" <> pretty lvl PP.align (pretty space) - PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map pretty scans) PP.colon <+> ppTuple' (map pretty ts) <+> PP.nestedBlock "{" "}" (pretty body) - pretty (SegHist lvl space ops ts body) = + PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map pretty scans) + pretty (SegHist lvl space ts body ops) = "seghist" <> pretty lvl PP.align (pretty space) - PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map ppOp ops) PP.colon <+> ppTuple' (map pretty ts) <+> PP.nestedBlock "{" "}" (pretty body) + PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map ppOp ops) where ppOp (HistOp w rf dests nes shape op) = pretty w @@ -1125,7 +1122,7 @@ simplifySegOp (SegMap lvl space ts kbody) = do ( SegMap lvl' space' ts' kbody', body_hoisted ) -simplifySegOp (SegRed lvl space reds ts kbody) = do +simplifySegOp (SegRed lvl space ts kbody reds) = do (lvl', space', ts') <- Engine.simplify (lvl, space, ts) (reds', reds_hoisted) <- Engine.localVtable (<> scope_vtable) $ @@ -1133,13 +1130,13 @@ simplifySegOp (SegRed lvl space reds ts kbody) = do (kbody', body_hoisted) <- simplifyKernelBody space kbody pure - ( SegRed lvl' space' reds' ts' kbody', + ( SegRed lvl' space' ts' kbody' reds', mconcat reds_hoisted <> body_hoisted ) where scope = scopeOfSegSpace space scope_vtable = ST.fromScope scope -simplifySegOp (SegScan lvl space scans ts kbody) = do +simplifySegOp (SegScan lvl space ts kbody scans) = do (lvl', space', ts') <- Engine.simplify (lvl, space, ts) (scans', scans_hoisted) <- Engine.localVtable (<> scope_vtable) $ @@ -1147,13 +1144,13 @@ simplifySegOp (SegScan lvl space scans ts kbody) = do (kbody', body_hoisted) <- simplifyKernelBody space kbody pure - ( SegScan lvl' space' scans' ts' kbody', + ( SegScan lvl' space' ts' kbody' scans', mconcat scans_hoisted <> body_hoisted ) where scope = scopeOfSegSpace space scope_vtable = ST.fromScope scope -simplifySegOp (SegHist lvl space ops ts kbody) = do +simplifySegOp (SegHist lvl space ts kbody ops) = do (lvl', space', ts') <- Engine.simplify (lvl, space, ts) Engine.localVtable (flip (foldr ST.consume) $ concatMap histDest ops) $ do @@ -1176,7 +1173,7 @@ simplifySegOp (SegHist lvl space ops ts kbody) = do (kbody', body_hoisted) <- simplifyKernelBody space kbody pure - ( SegHist lvl' space' ops' ts' kbody', + ( SegHist lvl' space' ts' kbody' ops', mconcat ops_hoisted <> body_hoisted ) where @@ -1251,7 +1248,7 @@ topDownSegOp vtable (Pat kpes) dec (SegMap lvl space ts (KernelBody _ kstms kres -- If a SegRed contains two reduction operations that have the same -- vector shape, merge them together. This saves on communication -- overhead, but can in principle lead to more shared memory usage. -topDownSegOp _ (Pat pes) _ (SegRed lvl space ops ts kbody) +topDownSegOp _ (Pat pes) _ (SegRed lvl space ts kbody ops) | length ops > 1, op_groupings <- groupBy sameShape $ @@ -1264,7 +1261,7 @@ topDownSegOp _ (Pat pes) _ (SegRed lvl space ops ts kbody) pes' = red_pes' ++ map_pes ts' = red_ts' ++ map_ts kbody' = kbody {kernelBodyResult = red_res' ++ map_res} - letBind (Pat pes') $ Op $ segOp $ SegRed lvl space ops' ts' kbody' + letBind (Pat pes') $ Op $ segOp $ SegRed lvl space ts' kbody' ops' where (red_pes, map_pes) = splitAt (segBinOpResults ops) pes (red_ts, map_ts) = splitAt (segBinOpResults ops) ts @@ -1317,12 +1314,12 @@ segOpGuts :: ) segOpGuts (SegMap lvl space kts body) = (kts, body, 0, SegMap lvl space) -segOpGuts (SegScan lvl space ops kts body) = - (kts, body, segBinOpResults ops, SegScan lvl space ops) -segOpGuts (SegRed lvl space ops kts body) = - (kts, body, segBinOpResults ops, SegRed lvl space ops) -segOpGuts (SegHist lvl space ops kts body) = - (kts, body, sum $ map (length . histDest) ops, SegHist lvl space ops) +segOpGuts (SegScan lvl space kts body ops) = + (kts, body, segBinOpResults ops, \t b -> SegScan lvl space t b ops) +segOpGuts (SegRed lvl space kts body ops) = + (kts, body, segBinOpResults ops, \t b -> SegRed lvl space t b ops) +segOpGuts (SegHist lvl space kts body ops) = + (kts, body, sum $ map (length . histDest) ops, \t b -> SegHist lvl space t b ops) bottomUpSegOp :: (Aliased rep, HasSegOp rep, BuilderOps rep) => @@ -1441,9 +1438,9 @@ segOpReturns :: m [ExpReturns] segOpReturns k@(SegMap _ _ _ kbody) = kernelBodyReturns kbody . extReturns =<< opType k -segOpReturns k@(SegRed _ _ _ _ kbody) = +segOpReturns k@(SegRed _ _ _ kbody _) = kernelBodyReturns kbody . extReturns =<< opType k -segOpReturns k@(SegScan _ _ _ _ kbody) = +segOpReturns k@(SegScan _ _ _ kbody _) = kernelBodyReturns kbody . extReturns =<< opType k -segOpReturns (SegHist _ _ ops _ _) = +segOpReturns (SegHist _ _ _ _ ops) = concat <$> mapM (mapM varReturns . histDest) ops diff --git a/src/Futhark/Optimise/ArrayShortCircuiting.hs b/src/Futhark/Optimise/ArrayShortCircuiting.hs index 0840200024..6db5f2f68f 100644 --- a/src/Futhark/Optimise/ArrayShortCircuiting.hs +++ b/src/Futhark/Optimise/ArrayShortCircuiting.hs @@ -166,15 +166,15 @@ replaceInSegOp :: replaceInSegOp (SegMap lvl sp tps body) = do stms <- updateStms $ kernelBodyStms body pure $ SegMap lvl sp tps $ body {kernelBodyStms = stms} -replaceInSegOp (SegRed lvl sp binops tps body) = do +replaceInSegOp (SegRed lvl sp tps body binops) = do stms <- updateStms $ kernelBodyStms body - pure $ SegRed lvl sp binops tps $ body {kernelBodyStms = stms} -replaceInSegOp (SegScan lvl sp binops tps body) = do + pure $ SegRed lvl sp tps (body {kernelBodyStms = stms}) binops +replaceInSegOp (SegScan lvl sp tps body binops) = do stms <- updateStms $ kernelBodyStms body - pure $ SegScan lvl sp binops tps $ body {kernelBodyStms = stms} -replaceInSegOp (SegHist lvl sp hist_ops tps body) = do + pure $ SegScan lvl sp tps (body {kernelBodyStms = stms}) binops +replaceInSegOp (SegHist lvl sp tps body hist_ops) = do stms <- updateStms $ kernelBodyStms body - pure $ SegHist lvl sp hist_ops tps $ body {kernelBodyStms = stms} + pure $ SegHist lvl sp tps (body {kernelBodyStms = stms}) hist_ops replaceInHostOp :: HostOp NoOp GPUMem -> UpdateM (HostOp NoOp GPUMem) (HostOp NoOp GPUMem) replaceInHostOp (SegOp op) = SegOp <$> replaceInSegOp op diff --git a/src/Futhark/Optimise/ArrayShortCircuiting/ArrayCoalescing.hs b/src/Futhark/Optimise/ArrayShortCircuiting/ArrayCoalescing.hs index 9973c17f09..59c89c8b4e 100644 --- a/src/Futhark/Optimise/ArrayShortCircuiting/ArrayCoalescing.hs +++ b/src/Futhark/Optimise/ArrayShortCircuiting/ArrayCoalescing.hs @@ -201,7 +201,7 @@ shortCircuitSegOp :: shortCircuitSegOp lvlOK lutab pat pat_certs (SegMap lvl space _ kernel_body) td_env bu_env = -- No special handling necessary for 'SegMap'. Just call the helper-function. shortCircuitSegOpHelper 0 lvlOK lvl lutab pat pat_certs space kernel_body td_env bu_env -shortCircuitSegOp lvlOK lutab pat pat_certs (SegRed lvl space binops _ kernel_body) td_env bu_env = +shortCircuitSegOp lvlOK lutab pat pat_certs (SegRed lvl space _ kernel_body binops) td_env bu_env = -- When handling 'SegRed', we we first invalidate all active coalesce-entries -- where any of the variables in 'vartab' are also free in the list of -- 'SegBinOp'. In other words, anything that is used as part of the reduction @@ -218,14 +218,14 @@ shortCircuitSegOp lvlOK lutab pat pat_certs (SegRed lvl space binops _ kernel_bo op <- binops let shp = Shape segment_dims <> segBinOpShape op map (`arrayOfShape` shp) (lambdaReturnType $ segBinOpLambda op) -shortCircuitSegOp lvlOK lutab pat pat_certs (SegScan lvl space binops _ kernel_body) td_env bu_env = +shortCircuitSegOp lvlOK lutab pat pat_certs (SegScan lvl space _ kernel_body binops) td_env bu_env = -- Like in the handling of 'SegRed', we do not want to coalesce anything that -- is used in the 'SegBinOp' let to_fail = M.filter (\entry -> namesFromList (M.keys $ vartab entry) `namesIntersect` foldMap (freeIn . segBinOpLambda) binops) $ activeCoals bu_env (active, inh) = foldl markFailedCoal (activeCoals bu_env, inhibit bu_env) $ M.keys to_fail bu_env' = bu_env {activeCoals = active, inhibit = inh} in shortCircuitSegOpHelper 0 lvlOK lvl lutab pat pat_certs space kernel_body td_env bu_env' -shortCircuitSegOp lvlOK lutab pat pat_certs (SegHist lvl space histops _ kernel_body) td_env bu_env = do +shortCircuitSegOp lvlOK lutab pat pat_certs (SegHist lvl space _ kernel_body histops) td_env bu_env = do -- Need to take zipped patterns and histDest (flattened) and insert transitive coalesces let to_fail = M.filter (\entry -> namesFromList (M.keys $ vartab entry) `namesIntersect` foldMap (freeIn . histOp) histops) $ activeCoals bu_env (active, inh) = foldl markFailedCoal (activeCoals bu_env, inhibit bu_env) $ M.keys to_fail diff --git a/src/Futhark/Optimise/HistAccs.hs b/src/Futhark/Optimise/HistAccs.hs index be7377b94c..0e2f72fa67 100644 --- a/src/Futhark/Optimise/HistAccs.hs +++ b/src/Futhark/Optimise/HistAccs.hs @@ -146,7 +146,7 @@ optimiseStm accs (Let pat aux (Op (SegOp (SegMap lvl space _ kbody)))) (space', kbody'') <- flatKernelBody space kbody' hist_dest_upd <- - letTupExp "hist_dest_upd" $ Op $ SegOp $ SegHist lvl space' [histop] ts' kbody'' + letTupExp "hist_dest_upd" $ Op $ SegOp $ SegHist lvl space' ts' kbody'' [histop] addStm . Let pat aux =<< addArrsToAcc lvl acc_shape hist_dest_upd acc optimiseStm accs (Let pat aux e) = diff --git a/src/Futhark/Optimise/MemoryBlockMerging.hs b/src/Futhark/Optimise/MemoryBlockMerging.hs index 4ee960ce00..8598f90a0c 100644 --- a/src/Futhark/Optimise/MemoryBlockMerging.hs +++ b/src/Futhark/Optimise/MemoryBlockMerging.hs @@ -37,11 +37,11 @@ getAllocsStm _ = mempty getAllocsSegOp :: SegOp lvl GPUMem -> Allocs getAllocsSegOp (SegMap _ _ _ body) = foldMap getAllocsStm (kernelBodyStms body) -getAllocsSegOp (SegRed _ _ _ _ body) = +getAllocsSegOp (SegRed _ _ _ body _) = foldMap getAllocsStm (kernelBodyStms body) -getAllocsSegOp (SegScan _ _ _ _ body) = +getAllocsSegOp (SegScan _ _ _ body _) = foldMap getAllocsStm (kernelBodyStms body) -getAllocsSegOp (SegHist _ _ _ _ body) = +getAllocsSegOp (SegHist _ _ _ body _) = foldMap getAllocsStm (kernelBodyStms body) setAllocsStm :: Map VName SubExp -> Stm GPUMem -> Stm GPUMem @@ -69,15 +69,18 @@ setAllocsSegOp :: setAllocsSegOp m (SegMap lvl sp tps body) = SegMap lvl sp tps $ body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body} -setAllocsSegOp m (SegRed lvl sp segbinops tps body) = - SegRed lvl sp segbinops tps $ - body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body} -setAllocsSegOp m (SegScan lvl sp segbinops tps body) = - SegScan lvl sp segbinops tps $ - body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body} -setAllocsSegOp m (SegHist lvl sp segbinops tps body) = - SegHist lvl sp segbinops tps $ - body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body} +setAllocsSegOp m (SegRed lvl sp tps body ops) = + SegRed lvl sp tps body' ops + where + body' = body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body} +setAllocsSegOp m (SegScan lvl sp tps body ops) = + SegScan lvl sp tps body' ops + where + body' = body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body} +setAllocsSegOp m (SegHist lvl sp tps body ops) = + SegHist lvl sp tps body' ops + where + body' = body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body} maxSubExp :: (MonadBuilder m) => Set SubExp -> m SubExp maxSubExp = helper . S.toList @@ -105,15 +108,15 @@ onKernelBodyStms :: onKernelBodyStms (SegMap lvl space ts body) f = do stms <- f $ kernelBodyStms body pure $ SegMap lvl space ts $ body {kernelBodyStms = stms} -onKernelBodyStms (SegRed lvl space binops ts body) f = do +onKernelBodyStms (SegRed lvl space ts body binops) f = do stms <- f $ kernelBodyStms body - pure $ SegRed lvl space binops ts $ body {kernelBodyStms = stms} -onKernelBodyStms (SegScan lvl space binops ts body) f = do + pure $ SegRed lvl space ts (body {kernelBodyStms = stms}) binops +onKernelBodyStms (SegScan lvl space ts body binops) f = do stms <- f $ kernelBodyStms body - pure $ SegScan lvl space binops ts $ body {kernelBodyStms = stms} -onKernelBodyStms (SegHist lvl space binops ts body) f = do + pure $ SegScan lvl space ts (body {kernelBodyStms = stms}) binops +onKernelBodyStms (SegHist lvl space ts body binops) f = do stms <- f $ kernelBodyStms body - pure $ SegHist lvl space binops ts $ body {kernelBodyStms = stms} + pure $ SegHist lvl space ts (body {kernelBodyStms = stms}) binops -- | This is the actual optimiser. Given an interference graph and a @SegOp@, -- replace allocations and references to memory blocks inside with a (hopefully) @@ -148,15 +151,18 @@ optimiseKernel graph segop0 = do SegMap lvl sp tps body -> SegMap lvl sp tps $ body {kernelBodyStms = maxstms <> stms <> kernelBodyStms body} - SegRed lvl sp binops tps body -> - SegRed lvl sp binops tps $ - body {kernelBodyStms = maxstms <> stms <> kernelBodyStms body} - SegScan lvl sp binops tps body -> - SegScan lvl sp binops tps $ - body {kernelBodyStms = maxstms <> stms <> kernelBodyStms body} - SegHist lvl sp binops tps body -> - SegHist lvl sp binops tps $ - body {kernelBodyStms = maxstms <> stms <> kernelBodyStms body} + SegRed lvl sp tps body ops -> + SegRed lvl sp tps body' ops + where + body' = body {kernelBodyStms = maxstms <> stms <> kernelBodyStms body} + SegScan lvl sp tps body ops -> + SegScan lvl sp tps body' ops + where + body' = body {kernelBodyStms = maxstms <> stms <> kernelBodyStms body} + SegHist lvl sp tps body ops -> + SegHist lvl sp tps body' ops + where + body' = body {kernelBodyStms = maxstms <> stms <> kernelBodyStms body} -- | Helper function that modifies kernels found inside some statements. onKernels :: diff --git a/src/Futhark/Optimise/ReduceDeviceSyncs.hs b/src/Futhark/Optimise/ReduceDeviceSyncs.hs index 0219812b30..a7f4ea33aa 100644 --- a/src/Futhark/Optimise/ReduceDeviceSyncs.hs +++ b/src/Futhark/Optimise/ReduceDeviceSyncs.hs @@ -328,15 +328,18 @@ optimizeWithAccInput acc (shape, arrs, Just (op, nes)) = do optimizeHostOp :: HostOp op GPU -> ReduceM (HostOp op GPU) optimizeHostOp (SegOp (SegMap lvl space types kbody)) = SegOp . SegMap lvl space types <$> addReadsToKernelBody kbody -optimizeHostOp (SegOp (SegRed lvl space ops types kbody)) = do +optimizeHostOp (SegOp (SegRed lvl space types kbody ops)) = do ops' <- mapM addReadsToSegBinOp ops - SegOp . SegRed lvl space ops' types <$> addReadsToKernelBody kbody -optimizeHostOp (SegOp (SegScan lvl space ops types kbody)) = do + kbody' <- addReadsToKernelBody kbody + pure . SegOp $ SegRed lvl space types kbody' ops' +optimizeHostOp (SegOp (SegScan lvl space types kbody ops)) = do ops' <- mapM addReadsToSegBinOp ops - SegOp . SegScan lvl space ops' types <$> addReadsToKernelBody kbody -optimizeHostOp (SegOp (SegHist lvl space ops types kbody)) = do + kbody' <- addReadsToKernelBody kbody + pure . SegOp $ SegScan lvl space types kbody' ops' +optimizeHostOp (SegOp (SegHist lvl space types kbody ops)) = do ops' <- mapM addReadsToHistOp ops - SegOp . SegHist lvl space ops' types <$> addReadsToKernelBody kbody + kbody' <- addReadsToKernelBody kbody + pure . SegOp $ SegHist lvl space types kbody' ops' optimizeHostOp (SizeOp op) = pure (SizeOp op) optimizeHostOp OtherOp {} = diff --git a/src/Futhark/Optimise/ReduceDeviceSyncs/MigrationTable.hs b/src/Futhark/Optimise/ReduceDeviceSyncs/MigrationTable.hs index bd2a00171b..b99a541336 100644 --- a/src/Futhark/Optimise/ReduceDeviceSyncs/MigrationTable.hs +++ b/src/Futhark/Optimise/ReduceDeviceSyncs/MigrationTable.hs @@ -1074,15 +1074,15 @@ graphedScalarOperands e = collectHostOp (SegOp (SegMap lvl sp _ _)) = do collectSegLevel lvl collectSegSpace sp - collectHostOp (SegOp (SegRed lvl sp ops _ _)) = do + collectHostOp (SegOp (SegRed lvl sp _ _ ops)) = do collectSegLevel lvl collectSegSpace sp mapM_ collectSegBinOp ops - collectHostOp (SegOp (SegScan lvl sp ops _ _)) = do + collectHostOp (SegOp (SegScan lvl sp _ _ ops)) = do collectSegLevel lvl collectSegSpace sp mapM_ collectSegBinOp ops - collectHostOp (SegOp (SegHist lvl sp ops _ _)) = do + collectHostOp (SegOp (SegHist lvl sp _ _ ops)) = do collectSegLevel lvl collectSegSpace sp mapM_ collectHistOp ops diff --git a/src/Futhark/Pass/ExpandAllocations.hs b/src/Futhark/Pass/ExpandAllocations.hs index 6d5fb4f422..79d84705bd 100644 --- a/src/Futhark/Pass/ExpandAllocations.hs +++ b/src/Futhark/Pass/ExpandAllocations.hs @@ -117,28 +117,28 @@ transformExp (Op (Inner (SegOp (SegMap lvl space ts kbody)))) = do ( alloc_stms, Op $ Inner $ SegOp $ SegMap lvl' space ts kbody' ) -transformExp (Op (Inner (SegOp (SegRed lvl space reds ts kbody)))) = do +transformExp (Op (Inner (SegOp (SegRed lvl space ts kbody reds)))) = do (alloc_stms, (lvl', lams, kbody')) <- transformScanRed lvl space (map segBinOpLambda reds) kbody let reds' = zipWith (\red lam -> red {segBinOpLambda = lam}) reds lams pure ( alloc_stms, - Op $ Inner $ SegOp $ SegRed lvl' space reds' ts kbody' + Op $ Inner $ SegOp $ SegRed lvl' space ts kbody' reds' ) -transformExp (Op (Inner (SegOp (SegScan lvl space scans ts kbody)))) = do +transformExp (Op (Inner (SegOp (SegScan lvl space ts kbody scans)))) = do (alloc_stms, (lvl', lams, kbody')) <- transformScanRed lvl space (map segBinOpLambda scans) kbody let scans' = zipWith (\red lam -> red {segBinOpLambda = lam}) scans lams pure ( alloc_stms, - Op $ Inner $ SegOp $ SegScan lvl' space scans' ts kbody' + Op $ Inner $ SegOp $ SegScan lvl' space ts kbody' scans' ) -transformExp (Op (Inner (SegOp (SegHist lvl space ops ts kbody)))) = do +transformExp (Op (Inner (SegOp (SegHist lvl space ts kbody ops)))) = do (alloc_stms, (lvl', lams', kbody')) <- transformScanRed lvl space lams kbody let ops' = zipWith onOp ops lams' pure ( alloc_stms, - Op $ Inner $ SegOp $ SegHist lvl' space ops' ts kbody' + Op $ Inner $ SegOp $ SegHist lvl' space ts kbody' ops' ) where lams = map histOp ops diff --git a/src/Futhark/Pass/ExplicitAllocations/GPU.hs b/src/Futhark/Pass/ExplicitAllocations/GPU.hs index 62e30c77d0..57be416534 100644 --- a/src/Futhark/Pass/ExplicitAllocations/GPU.hs +++ b/src/Futhark/Pass/ExplicitAllocations/GPU.hs @@ -105,7 +105,7 @@ kernelExpHints (BasicOp (Manifest perm v)) = do pure [Hint lmad $ Space "device"] kernelExpHints (Op (Inner (SegOp (SegMap lvl@(SegThread _ _) space ts body)))) = zipWithM (mapResultHint lvl space) ts $ kernelBodyResult body -kernelExpHints (Op (Inner (SegOp (SegRed lvl@(SegThread _ _) space reds ts body)))) = +kernelExpHints (Op (Inner (SegOp (SegRed lvl@(SegThread _ _) space ts body reds)))) = (map (const NoHint) red_res <>) <$> zipWithM (mapResultHint lvl space) (drop num_reds ts) map_res where num_reds = segBinOpResults reds diff --git a/src/Futhark/Pass/ExtractKernels/BlockedKernel.hs b/src/Futhark/Pass/ExtractKernels/BlockedKernel.hs index f95d4758b0..d30977b497 100644 --- a/src/Futhark/Pass/ExtractKernels/BlockedKernel.hs +++ b/src/Futhark/Pass/ExtractKernels/BlockedKernel.hs @@ -88,7 +88,7 @@ segRed lvl pat cs w ops map_lam arrs ispace inps = runBuilder_ $ do letBind pat $ Op $ segOp $ - SegRed lvl kspace ops (lambdaReturnType map_lam) kbody + SegRed lvl kspace (lambdaReturnType map_lam) kbody ops segScan :: (MonadFreshNames m, DistRep rep, HasScope rep m) => @@ -107,7 +107,7 @@ segScan lvl pat cs w ops map_lam arrs ispace inps = runBuilder_ $ do letBind pat $ Op $ segOp $ - SegScan lvl kspace ops (lambdaReturnType map_lam) kbody + SegScan lvl kspace (lambdaReturnType map_lam) kbody ops segMap :: (MonadFreshNames m, DistRep rep, HasScope rep m) => @@ -197,7 +197,7 @@ segHist lvl pat arr_w ispace inps ops lam arrs = runBuilder_ $ do forM res $ \(SubExpRes cs se) -> pure $ Returns ResultMaySimplify cs se - letBind pat $ Op $ segOp $ SegHist lvl space ops (lambdaReturnType lam) kbody + letBind pat $ Op $ segOp $ SegHist lvl space (lambdaReturnType lam) kbody ops mapKernelSkeleton :: (DistRep rep, HasScope rep m, MonadFreshNames m) => diff --git a/src/Futhark/Pass/ExtractMulticore.hs b/src/Futhark/Pass/ExtractMulticore.hs index e7a19e3912..1706781da3 100644 --- a/src/Futhark/Pass/ExtractMulticore.hs +++ b/src/Futhark/Pass/ExtractMulticore.hs @@ -188,7 +188,7 @@ transformRedomap rename onBody w reds map_lam arrs = do (reds_stms, reds') <- mapAndUnzipM reduceToSegBinOp reds op' <- renameIfNeeded rename $ - SegRed () space reds' (lambdaReturnType map_lam) kbody + SegRed () space (lambdaReturnType map_lam) kbody reds' pure (reds_stms, op') transformHist :: @@ -205,7 +205,7 @@ transformHist rename onBody w hists map_lam arrs = do (hists_stms, hists') <- mapAndUnzipM histToSegBinOp hists op' <- renameIfNeeded rename $ - SegHist () space hists' (lambdaReturnType map_lam) kbody + SegHist () space (lambdaReturnType map_lam) kbody hists' pure (hists_stms, op') transformSOAC :: Pat Type -> Attrs -> SOAC SOACS -> ExtractM (Stms MC) @@ -245,7 +245,7 @@ transformSOAC pat _ (Screma w arrs form) ( Let pat (defAux ()) $ Op $ ParOp Nothing $ - SegScan () space scans' (lambdaReturnType map_lam) kbody + SegScan () space (lambdaReturnType map_lam) kbody scans' ) | otherwise = do -- This screma is too complicated for us to immediately do diff --git a/src/Futhark/Pass/LiftAllocations.hs b/src/Futhark/Pass/LiftAllocations.hs index 57b1740a23..842125a42a 100644 --- a/src/Futhark/Pass/LiftAllocations.hs +++ b/src/Futhark/Pass/LiftAllocations.hs @@ -120,15 +120,15 @@ liftAllocationsInSegOp :: liftAllocationsInSegOp (SegMap lvl sp tps body) = do stms <- liftAllocationsInStms (kernelBodyStms body) mempty mempty mempty pure $ SegMap lvl sp tps $ body {kernelBodyStms = stms} -liftAllocationsInSegOp (SegRed lvl sp binops tps body) = do +liftAllocationsInSegOp (SegRed lvl sp tps body binops) = do stms <- liftAllocationsInStms (kernelBodyStms body) mempty mempty mempty - pure $ SegRed lvl sp binops tps $ body {kernelBodyStms = stms} -liftAllocationsInSegOp (SegScan lvl sp binops tps body) = do + pure $ SegRed lvl sp tps (body {kernelBodyStms = stms}) binops +liftAllocationsInSegOp (SegScan lvl sp tps body binops) = do stms <- liftAllocationsInStms (kernelBodyStms body) mempty mempty mempty - pure $ SegScan lvl sp binops tps $ body {kernelBodyStms = stms} -liftAllocationsInSegOp (SegHist lvl sp histops tps body) = do + pure $ SegScan lvl sp tps (body {kernelBodyStms = stms}) binops +liftAllocationsInSegOp (SegHist lvl sp tps body histops) = do stms <- liftAllocationsInStms (kernelBodyStms body) mempty mempty mempty - pure $ SegHist lvl sp histops tps $ body {kernelBodyStms = stms} + pure $ SegHist lvl sp tps (body {kernelBodyStms = stms}) histops liftAllocationsInHostOp :: HostOp NoOp (Aliases GPUMem) -> diff --git a/src/Futhark/Pass/LowerAllocations.hs b/src/Futhark/Pass/LowerAllocations.hs index b600c7d5bc..9c8102f556 100644 --- a/src/Futhark/Pass/LowerAllocations.hs +++ b/src/Futhark/Pass/LowerAllocations.hs @@ -111,15 +111,15 @@ lowerAllocationsInSegOp :: lowerAllocationsInSegOp (SegMap lvl sp tps body) = do stms <- lowerAllocationsInStms (kernelBodyStms body) mempty mempty pure $ SegMap lvl sp tps $ body {kernelBodyStms = stms} -lowerAllocationsInSegOp (SegRed lvl sp binops tps body) = do +lowerAllocationsInSegOp (SegRed lvl sp tps body binops) = do stms <- lowerAllocationsInStms (kernelBodyStms body) mempty mempty - pure $ SegRed lvl sp binops tps $ body {kernelBodyStms = stms} -lowerAllocationsInSegOp (SegScan lvl sp binops tps body) = do + pure $ SegRed lvl sp tps (body {kernelBodyStms = stms}) binops +lowerAllocationsInSegOp (SegScan lvl sp tps body binops) = do stms <- lowerAllocationsInStms (kernelBodyStms body) mempty mempty - pure $ SegScan lvl sp binops tps $ body {kernelBodyStms = stms} -lowerAllocationsInSegOp (SegHist lvl sp histops tps body) = do + pure $ SegScan lvl sp tps (body {kernelBodyStms = stms}) binops +lowerAllocationsInSegOp (SegHist lvl sp tps body histops) = do stms <- lowerAllocationsInStms (kernelBodyStms body) mempty mempty - pure $ SegHist lvl sp histops tps $ body {kernelBodyStms = stms} + pure $ SegHist lvl sp tps (body {kernelBodyStms = stms}) histops lowerAllocationsInHostOp :: HostOp NoOp GPUMem -> LowerM (HostOp NoOp GPUMem) (HostOp NoOp GPUMem) lowerAllocationsInHostOp (SegOp op) = SegOp <$> lowerAllocationsInSegOp op