Skip to content

Commit

Permalink
Merge branch 'master' into hip
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Aug 17, 2023
2 parents beebbd9 + 344cd12 commit cb6d433
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions src/Futhark/Pass/ExtractKernels/Intragroup.hs
Original file line number Diff line number Diff line change
Expand Up @@ -184,19 +184,29 @@ intraGroupBody body = do
stms <- collectStms_ $ intraGroupStms $ bodyStms body
pure $ mkBody stms $ bodyResult body

intraGroupLambda :: Lambda SOACS -> IntraGroupM (Lambda GPU)
intraGroupLambda lam =
mkLambda (lambdaParams lam) $
bodyBind =<< intraGroupBody (lambdaBody lam)

intraGroupWithAccInput :: WithAccInput SOACS -> IntraGroupM (WithAccInput GPU)
intraGroupWithAccInput (shape, arrs, Nothing) =
pure (shape, arrs, Nothing)
intraGroupWithAccInput (shape, arrs, Just (lam, nes)) = do
lam' <- intraGroupLambda lam
pure (shape, arrs, Just (lam', nes))

intraGroupStm :: Stm SOACS -> IntraGroupM ()
intraGroupStm stm@(Let pat aux e) = do
scope <- askScope
let lvl = SegThread SegNoVirt Nothing

case e of
Loop merge form loopbody ->
localScope (scopeOf form') $
localScope (scopeOfFParams $ map fst merge) $ do
loopbody' <- intraGroupBody loopbody
certifying (stmAuxCerts aux) $
letBind pat $
Loop merge form' loopbody'
localScope (scopeOf form' <> scopeOfFParams (map fst merge)) $ do
loopbody' <- intraGroupBody loopbody
certifying (stmAuxCerts aux) . letBind pat $
Loop merge form' loopbody'
where
form' = case form of
ForLoop i it bound inps -> ForLoop i it bound inps
Expand All @@ -206,6 +216,10 @@ intraGroupStm stm@(Let pat aux e) = do
defbody' <- intraGroupBody defbody
certifying (stmAuxCerts aux) . letBind pat $
Match cond cases' defbody' ifdec
WithAcc inputs lam -> do
inputs' <- mapM intraGroupWithAccInput inputs
lam' <- intraGroupLambda lam
certifying (stmAuxCerts aux) . letBind pat $ WithAcc inputs' lam'
Op soac
| "sequential_outer" `inAttrs` stmAuxAttrs aux ->
intraGroupStms . fmap (certify (stmAuxCerts aux))
Expand Down

0 comments on commit cb6d433

Please sign in to comment.