diff --git a/src/Futhark/Analysis/DataDependencies.hs b/src/Futhark/Analysis/DataDependencies.hs
index db0d82e7ff..decce9fee8 100644
--- a/src/Futhark/Analysis/DataDependencies.hs
+++ b/src/Futhark/Analysis/DataDependencies.hs
@@ -2,6 +2,12 @@
 module Futhark.Analysis.DataDependencies
   ( Dependencies,
     dataDependencies,
+    depsOf,
+    depsOf',
+    depsOfArrays,
+    depsOfShape,
+    lambdaDependencies,
+    reductionDependencies,
     findNecessaryForReturned,
   )
 where
@@ -27,6 +33,23 @@ dataDependencies' ::
   Dependencies
 dataDependencies' startdeps = foldl grow startdeps . bodyStms
   where
+    grow deps (Let pat _ (WithAcc inputs lam)) =
+      let input_deps = foldMap depsOfWithAccInput inputs
+          -- Dependencies of each input reduction are concatenated.
+          -- Input to lam is cert_1, ..., cert_n, acc_1, ..., acc_n.
+          lam_deps = lambdaDependencies deps lam (input_deps <> input_deps)
+          transitive = map (depsOfNames deps) lam_deps
+       in M.fromList (zip (patNames pat) transitive) `M.union` deps
+      where
+        depsOfArrays' shape =
+          map (\arr -> oneName arr <> depsOfShape shape)
+        depsOfWithAccInput (shape, arrs, Nothing) =
+          depsOfArrays' shape arrs
+        depsOfWithAccInput (shape, arrs, Just (lam', nes)) =
+          reductionDependencies deps lam' nes (depsOfArrays' shape arrs)
+    grow deps (Let pat _ (Op op)) =
+      let op_deps = map (depsOfNames deps) (opDependencies op)
+       in M.fromList (zip (patNames pat) op_deps) `M.union` deps
     grow deps (Let pat _ (Match c cases defbody _)) =
       let cases_deps = map (dataDependencies' deps . caseBody) cases
           defbody_deps = dataDependencies' deps defbody
@@ -50,16 +73,72 @@ dataDependencies' startdeps = foldl grow startdeps . bodyStms
        in M.unions $ [branchdeps, deps, defbody_deps] ++ cases_deps
     grow deps (Let pat _ e) =
       let free = freeIn pat <> freeIn e
-          freeDeps = mconcat $ map (depsOfVar deps) $ namesToList free
-       in M.fromList [(name, freeDeps) | name <- patNames pat] `M.union` deps
+          free_deps = depsOfNames deps free
+       in M.fromList [(name, free_deps) | name <- patNames pat] `M.union` deps
 
+-- | The names a 'SubExp' depends on: nothing for a constant; for a
+-- variable, the variable itself and whatever @deps@ records for it.
 depsOf :: Dependencies -> SubExp -> Names
 depsOf _ (Constant _) = mempty
 depsOf deps (Var v) = depsOfVar deps v
 
+-- | Like 'depsOf', but with no known dependencies: a variable
+-- depends only on itself.
+depsOf' :: SubExp -> Names
+depsOf' = depsOf mempty
+
+-- | A variable depends on itself and on whatever @deps@ records for it.
 depsOfVar :: Dependencies -> VName -> Names
 depsOfVar deps name = oneName name <> M.findWithDefault mempty name deps
 
+-- | The dependencies of the value returned by a body result.
+depsOfRes :: Dependencies -> SubExpRes -> Names
+depsOfRes deps (SubExpRes _ se) = depsOf deps se
+
+-- | Extend @names@ with direct dependencies in @deps@.
+depsOfNames :: Dependencies -> Names -> Names
+depsOfNames deps = foldMap (depsOfVar deps) . namesToList
+
+-- | One set of names per array in the given list: the array itself
+-- along with the variable (if any) giving @size@.
+depsOfArrays :: SubExp -> [VName] -> [Names]
+depsOfArrays size = map (\arr -> oneName arr <> depsOf' size)
+
+-- | The variables occurring as dimensions of the given shape.
+depsOfShape :: Shape -> Names
+depsOfShape = foldMap depsOf' . shapeDims
+
+-- | Determine the variables on which the results of applying
+-- anonymous function @lam@ to @inputs@ depend.  The elements of
+-- @inputs@ are paired positionally with the names bound by @lam@,
+-- and only names free in @lam@ or occurring in @inputs@ are reported.
+lambdaDependencies ::
+  (ASTRep rep) =>
+  Dependencies ->
+  Lambda rep ->
+  [Names] ->
+  [Names]
+lambdaDependencies deps lam inputs =
+  let names_in_scope = freeIn lam <> mconcat inputs
+      deps_in = M.fromList $ zip (boundByLambda lam) inputs
+      deps' = dataDependencies' (deps_in <> deps) (lambdaBody lam)
+   in map
+        (namesIntersection names_in_scope . depsOfRes deps')
+        (bodyResult $ lambdaBody lam)
+
+-- | Like 'lambdaDependencies', but @lam@ is a binary operation
+-- with a neutral element.
+reductionDependencies ::
+  (ASTRep rep) =>
+  Dependencies ->
+  Lambda rep ->
+  [SubExp] ->
+  [Names] ->
+  [Names]
+reductionDependencies deps lam nes inputs =
+  let nes' = map (depsOf deps) nes
+   in lambdaDependencies deps lam (zipWith (<>) nes' inputs)
+
 -- | @findNecessaryForReturned p merge deps@ computes which of the
 -- loop parameters (@merge@) are necessary for the result of the loop,
 -- where @p@ given a loop parameter indicates whether the final value
diff --git a/src/Futhark/IR/GPU/Op.hs b/src/Futhark/IR/GPU/Op.hs
index 7edd4391fb..60d70dc327 100644
--- a/src/Futhark/IR/GPU/Op.hs
+++ b/src/Futhark/IR/GPU/Op.hs
@@ -196,6 +196,7 @@ instance Rename SizeOp where
 instance IsOp SizeOp where
   safeOp _ = True
   cheapOp _ = True
+  opDependencies op = [freeIn op]
 
 instance TypedOp SizeOp where
   opType (GetSize _ _) = pure [Prim int64]
@@ -291,6 +292,12 @@ instance (ASTRep rep, IsOp (op rep)) => IsOp (HostOp op rep) where
     -- transfer scalars to device.
     SQ.null (bodyStms body) && all ((== 0) . arrayRank) types
 
+  opDependencies (SegOp op) = opDependencies op
+  opDependencies (OtherOp op) = opDependencies op
+  opDependencies op@(SizeOp {}) = [freeIn op]
+  opDependencies (GPUBody _ body) =
+    replicate (length . bodyResult $ body) (freeIn body)
+
 instance (TypedOp (op rep)) => TypedOp (HostOp op rep) where
   opType (SegOp op) = opType op
   opType (OtherOp op) = opType op
diff --git a/src/Futhark/IR/MC/Op.hs b/src/Futhark/IR/MC/Op.hs
index de2299fd8d..333668ee79 100644
--- a/src/Futhark/IR/MC/Op.hs
+++ b/src/Futhark/IR/MC/Op.hs
@@ -78,6 +78,9 @@ instance (ASTRep rep, IsOp (op rep)) => IsOp (MCOp op rep) where
   cheapOp (ParOp _ op) = cheapOp op
   cheapOp (OtherOp op) = cheapOp op
 
+  opDependencies (ParOp _ op) = opDependencies op
+  opDependencies (OtherOp op) = opDependencies op
+
 instance (TypedOp (op rep)) => TypedOp (MCOp op rep) where
   opType (ParOp _ op) = opType op
   opType (OtherOp op) = opType op
diff --git a/src/Futhark/IR/Mem.hs b/src/Futhark/IR/Mem.hs
index 3d44b8a7e4..e37fdb9961 100644
--- a/src/Futhark/IR/Mem.hs
+++ b/src/Futhark/IR/Mem.hs
@@ -239,6 +239,8 @@ instance (IsOp (inner rep)) => IsOp (MemOp inner rep) where
   safeOp (Inner k) = safeOp k
   cheapOp (Inner k) = cheapOp k
   cheapOp Alloc {} = True
+  opDependencies op@(Alloc {}) = [freeIn op]
+  opDependencies (Inner op) = opDependencies op
 
 instance (CanBeWise inner) => CanBeWise (MemOp inner) where
   addOpWisdom (Alloc size space) = Alloc size space
diff --git a/src/Futhark/IR/Prop.hs b/src/Futhark/IR/Prop.hs
index 0df63c34dd..fd2741661d 100644
--- a/src/Futhark/IR/Prop.hs
+++ b/src/Futhark/IR/Prop.hs
@@ -191,9 +191,15 @@ class (ASTConstraints op, TypedOp op) => IsOp op where
   -- | Should we try to hoist this out of branches?
   cheapOp :: op -> Bool
 
+  -- | Compute the data dependencies of an operation: one set of
+  -- names per result of the operation, in result order, naming the
+  -- variables that result depends on.
+  opDependencies :: op -> [Names]
+
 instance IsOp (NoOp rep) where
   safeOp NoOp = True
   cheapOp NoOp = True
+  opDependencies NoOp = []
 
 -- | Representation-specific attributes; also means the rep supports
 -- some basic facilities.
diff --git a/src/Futhark/IR/SOACS/SOAC.hs b/src/Futhark/IR/SOACS/SOAC.hs
index 4402053462..488694381c 100644
--- a/src/Futhark/IR/SOACS/SOAC.hs
+++ b/src/Futhark/IR/SOACS/SOAC.hs
@@ -58,6 +58,7 @@ import Data.List (intersperse)
 import Data.Map.Strict qualified as M
 import Data.Maybe
 import Futhark.Analysis.Alias qualified as Alias
+import Futhark.Analysis.DataDependencies
 import Futhark.Analysis.Metrics
 import Futhark.Analysis.PrimExp.Convert
 import Futhark.Analysis.SymbolTable qualified as ST
@@ -69,7 +70,7 @@ import Futhark.IR.TypeCheck qualified as TC
 import Futhark.Optimise.Simplify.Rep
 import Futhark.Transform.Rename
 import Futhark.Transform.Substitute
-import Futhark.Util (chunks, maybeNth)
+import Futhark.Util (chunks, maybeNth, splitAt3)
 import Futhark.Util.Pretty (Doc, align, comma, commasep, docText, parens, ppTuple', pretty, (<+>), (</>))
 import Futhark.Util.Pretty qualified as PP
 import Prelude hiding (id, (.))
@@ -170,9 +171,13 @@ data Scan rep = Scan
   }
   deriving (Eq, Ord, Show)
 
+-- | How many results are produced by each of these 'Scan's?
+scanSizes :: [Scan rep] -> [Int]
+scanSizes = map (length . scanNeutral)
+
 -- | How many reduction results are produced by these 'Scan's?
 scanResults :: [Scan rep] -> Int
-scanResults = sum . map (length . scanNeutral)
+scanResults = sum . scanSizes
 
 -- | Combine multiple scan operators to a single operator.
 singleScan :: (Buildable rep) => [Scan rep] -> Scan rep
@@ -189,9 +194,13 @@ data Reduce rep = Reduce
   }
   deriving (Eq, Ord, Show)
 
+-- | How many results are produced by each of these 'Reduce's?
+redSizes :: [Reduce rep] -> [Int]
+redSizes = map (length . redNeutral)
+
 -- | How many reduction results are produced by these 'Reduce's?
 redResults :: [Reduce rep] -> Int
-redResults = sum . map (length . redNeutral)
+redResults = sum . redSizes
 
 -- | Combine multiple reduction operators to a single operator.
 singleReduce :: (Buildable rep) => [Reduce rep] -> Reduce rep
@@ -588,6 +597,76 @@ instance CanBeAliased SOAC where
 instance (ASTRep rep) => IsOp (SOAC rep) where
   safeOp _ = False
   cheapOp _ = False
+  opDependencies (Stream w arrs accs lam) =
+    let accs_deps = map depsOf' accs
+        arrs_deps = depsOfArrays w arrs
+        -- The lambda is applied to the chunk size, then the
+        -- accumulators, then the array chunks.
+        chunk_deps = depsOf' w
+     in lambdaDependencies mempty lam (chunk_deps : accs_deps <> arrs_deps)
+  opDependencies (Hist w arrs ops lam) =
+    let bucket_fun_deps' = lambdaDependencies mempty lam (depsOfArrays w arrs)
+        -- Bucket function results are indices followed by values.
+        -- Reshape this to align with list of histogram operations.
+        ranks = [length (histShape op) | op <- ops]
+        value_lengths = [length (histNeutral op) | op <- ops]
+        (indices, values) = splitAt (sum ranks) bucket_fun_deps'
+        bucket_fun_deps =
+          zipWith
+            concatIndicesToEachValue
+            (chunks ranks indices)
+            (chunks value_lengths values)
+        op_deps = map depsOfHistOp ops
+        -- Combine per result rather than appending the lists, so that
+        -- every result of every operation gets exactly one set.
+     in concat $ zipWith (zipWith (<>)) bucket_fun_deps op_deps
+    where
+      depsOfHistOp (HistOp dest_shape rf dests nes op) =
+        let shape_deps = depsOfShape dest_shape
+            in_deps = map (\vn -> oneName vn <> shape_deps <> depsOf' rf) dests
+         in reductionDependencies mempty op nes in_deps
+      -- A histogram operation may use the same index for multiple values.
+      concatIndicesToEachValue is vs =
+        let is_flat = mconcat is
+         in map (is_flat <>) vs
+  opDependencies (Scatter w arrs lam outputs) =
+    let deps = lambdaDependencies mempty lam (depsOfArrays w arrs)
+     in map flattenGroups (groupScatterResults' outputs deps)
+    where
+      flattenGroups (indicess, values) = mconcat indicess <> values
+  opDependencies (JVP lam args vec) =
+    mconcat $
+      replicate 2 $
+        lambdaDependencies mempty lam $
+          zipWith (<>) (map depsOf' args) (map depsOf' vec)
+  opDependencies (VJP lam args vec) =
+    let args_deps = map depsOf' args
+        -- The results of @lam@ itself depend only on @args@.
+        res_deps = lambdaDependencies mempty lam args_deps
+        -- The adjoint of an argument is accumulated from every result
+        -- that depends on it, and so depends on whatever those results
+        -- and their seeds in @vec@ depend on.
+        adjDeps arg_deps =
+          mconcat
+            [ deps <> depsOf' seed
+              | (deps, seed) <- zip res_deps vec,
+                namesIntersection arg_deps deps /= mempty
+            ]
+     in res_deps <> map adjDeps args_deps
+  opDependencies (Screma w arrs (ScremaForm scans reds map_lam)) =
+    let (scans_in, reds_in, map_deps) =
+          splitAt3 (scanResults scans) (redResults reds) $
+            lambdaDependencies mempty map_lam (depsOfArrays w arrs)
+        scans_deps =
+          concatMap depsOfScan (zip scans $ chunks (scanSizes scans) scans_in)
+        reds_deps =
+          concatMap depsOfRed (zip reds $ chunks (redSizes reds) reds_in)
+     in scans_deps <> reds_deps <> map_deps
+    where
+      depsOfScan (Scan lam nes, deps_in) =
+        reductionDependencies mempty lam nes deps_in
+      depsOfRed (Reduce _ lam nes, deps_in) =
+        reductionDependencies mempty lam nes deps_in
 
 substNamesInType :: M.Map VName SubExp -> Type -> Type
 substNamesInType _ t@Prim {} = t
diff --git a/src/Futhark/IR/SOACS/Simplify.hs b/src/Futhark/IR/SOACS/Simplify.hs
index 3f82effa07..abdae3db38 100644
--- a/src/Futhark/IR/SOACS/Simplify.hs
+++ b/src/Futhark/IR/SOACS/Simplify.hs
@@ -521,37 +521,49 @@ isMapWithOp pat e
 -- the data dependencies to see that the "dead" result is not
 -- actually used for computing one of the live ones.
removeDeadReduction :: BottomUpRuleOp (Wise SOACS) -removeDeadReduction (_, used) pat aux (Screma w arrs form) - | Just ([Reduce comm redlam nes], maplam) <- isRedomapSOAC form, - not $ all (`UT.used` used) $ patNames pat, -- Quick/cheap check - let (red_pes, map_pes) = splitAt (length nes) $ patElems pat, - let redlam_deps = dataDependencies $ lambdaBody redlam, - let redlam_res = bodyResult $ lambdaBody redlam, - let redlam_params = lambdaParams redlam, - let used_after = - map snd . filter ((`UT.used` used) . patElemName . fst) $ - zip red_pes redlam_params, - let necessary = - findNecessaryForReturned - (`elem` used_after) - (zip redlam_params $ map resSubExp $ redlam_res <> redlam_res) - redlam_deps, - let alive_mask = map ((`nameIn` necessary) . paramName) redlam_params, - not $ and (take (length nes) alive_mask) = Simplify $ do - let fixDeadToNeutral lives ne = if lives then Nothing else Just ne - dead_fix = zipWith fixDeadToNeutral alive_mask nes - (used_red_pes, _, used_nes) = - unzip3 . 
filter (\(_, x, _) -> paramName x `nameIn` necessary) $ - zip3 red_pes redlam_params nes - - let maplam' = removeLambdaResults (take (length nes) alive_mask) maplam - redlam' <- removeLambdaResults (take (length nes) alive_mask) <$> fixLambdaParams redlam (dead_fix ++ dead_fix) - - auxing aux $ - letBind (Pat $ used_red_pes ++ map_pes) $ - Op $ - Screma w arrs $ - redomapSOAC [Reduce comm redlam' used_nes] maplam' +removeDeadReduction (_, used) pat aux (Screma w arrs form) = + case isRedomapSOAC form of + Just ([Reduce comm redlam rednes], maplam) -> + let mkOp lam nes' = redomapSOAC [Reduce comm lam nes'] + in removeDeadReduction' redlam rednes maplam mkOp + _ -> + case isScanomapSOAC form of + Just ([Scan scanlam nes], maplam) -> + let mkOp lam nes' = scanomapSOAC [Scan lam nes'] + in removeDeadReduction' scanlam nes maplam mkOp + _ -> Skip + where + removeDeadReduction' redlam nes maplam mkOp + | not $ all (`UT.used` used) $ patNames pat, -- Quick/cheap check + let (red_pes, map_pes) = splitAt (length nes) $ patElems pat, + let redlam_deps = dataDependencies $ lambdaBody redlam, + let redlam_res = bodyResult $ lambdaBody redlam, + let redlam_params = lambdaParams redlam, + let used_after = + map snd . filter ((`UT.used` used) . patElemName . fst) $ + zip red_pes redlam_params, + let necessary = + findNecessaryForReturned + (`elem` used_after) + (zip redlam_params $ map resSubExp $ redlam_res <> redlam_res) + redlam_deps, + let alive_mask = map ((`nameIn` necessary) . paramName) redlam_params, + not $ and (take (length nes) alive_mask) = Simplify $ do + let fixDeadToNeutral lives ne = if lives then Nothing else Just ne + dead_fix = zipWith fixDeadToNeutral alive_mask nes + (used_red_pes, _, used_nes) = + unzip3 . 
filter (\(_, x, _) -> paramName x `nameIn` necessary) $ + zip3 red_pes redlam_params nes + + let maplam' = removeLambdaResults (take (length nes) alive_mask) maplam + redlam' <- removeLambdaResults (take (length nes) alive_mask) <$> fixLambdaParams redlam (dead_fix ++ dead_fix) + + auxing aux $ + letBind (Pat $ used_red_pes ++ map_pes) $ + Op $ + Screma w arrs $ + mkOp redlam' used_nes maplam' + removeDeadReduction' _ _ _ _ = Skip removeDeadReduction _ _ _ _ = Skip -- | If we are writing to an array that is never used, get rid of it. diff --git a/src/Futhark/IR/SegOp.hs b/src/Futhark/IR/SegOp.hs index 2bfe50f782..30f6a0d397 100644 --- a/src/Futhark/IR/SegOp.hs +++ b/src/Futhark/IR/SegOp.hs @@ -1005,6 +1005,8 @@ instance where cheapOp _ = False safeOp _ = True + opDependencies op = + replicate (length . kernelBodyResult $ segBody op) (freeIn op) --- Simplification diff --git a/tests/dependence-analysis/ad_acc.fut b/tests/dependence-analysis/ad_acc.fut new file mode 100644 index 0000000000..d45b858b95 --- /dev/null +++ b/tests/dependence-analysis/ad_acc.fut @@ -0,0 +1,229 @@ +-- See issue 1989. 
+-- == +-- structure { UpdateAcc 3 } + +def gather1D 't [m] (arr1D: [m]t) (inds: [m]i32) : *[m]t = + map (\ind -> arr1D[ind] ) inds + +def gather2D 't [m][d] (arr2D: [m][d]t) (inds: [m]i32) : *[m][d]t = + map (\ind -> map (\j -> arr2D[ind,j]) (iota d) ) inds + +def sumSqrsSeq [d] (xs: [d]f32) (ys: [d]f32) : f32 = + loop (res) = (0.0f32) for (x,y) in (zip xs ys) do + let z = x-y in res + z*z + +def log2 x = (loop (y,c) = (x,0i32) while y > 1i32 do (y >> 1, c+1)).1 + +def isLeaf (h: i32) (node_index: i32) = + node_index >= ((1 << (h+1)) - 1) + +def findLeaf [q][d] (median_dims: [q]i32) (median_vals: [q]f32) + (height: i32) (query: [d]f32) = + let leaf = + loop (node_index) = (0) + while !(isLeaf height node_index) do + if query[median_dims[node_index]] <= median_vals[node_index] + then (node_index+1)*2-1 + else (node_index+1)*2 + in leaf - i32.i64 q + +def traverseOnce [q] [d] (radius: f32) (height: i32) + (kd_tree: [q](i32,f32,i32)) + (query: [d]f32) + (last_leaf: i32, stack: i32, dist: f32) : (i32, i32, f32) = + + let (median_dims, median_vals, clanc_eqdim) = unzip3 kd_tree + let last_leaf = last_leaf + i32.i64 q + let no_leaf = 2*q + 1 + + let getPackedInd (stk: i32) (ind: i32) : bool = + let b = stk & (1<> (ind+1)) << (ind+1) + let mid = if v then (1 << ind) else 0 + in ( (fst | snd) | mid ) + + let getLevel (node_idx: i32) : i32 = log2 (node_idx+1) + let getAncSameDimContrib (q_m_i: f32) (node_stack: i32) (node: i32) : f32 = + (loop (idx, res) = (node, 0.0f32) + while (idx >= 0) do + let anc = clanc_eqdim[idx] in + if anc == (-1i32) then (-1i32, 0.0f32) + else + let anc_lev = getLevel anc + let is_anc_visited = getPackedInd node_stack anc_lev + in if !is_anc_visited then (anc, res) + else (-1i32, median_vals[anc] - q_m_i) + ).1 + + let (parent_rec, stack, count, dist, rec_node) = + loop (node_index, stack, count, dist, rec_node) = + (last_leaf, stack, height, dist, -1) + for _i2 < height+1 do + if (node_index != 0) && (rec_node < 0) + then + let parent = 
(node_index-1) / 2 + let scnd_visited = getPackedInd stack count --stack[count] + + let q_m_d = query[median_dims[parent]] + let cur_med_dst = median_vals[parent] - q_m_d + let cur_med_sqr = cur_med_dst * cur_med_dst + + let prv_med_dst = getAncSameDimContrib q_m_d stack parent + let prv_med_sqr = prv_med_dst * prv_med_dst + + let dist_minu = f32.abs(dist - cur_med_sqr + prv_med_sqr) + let dist_plus = f32.abs(dist - prv_med_sqr + cur_med_sqr) + + in if scnd_visited + then -- continue backing-up towards the root + (parent, stack, count-1, dist_minu, -1) + + else -- the node_index is actually the `first` child of parent, + let to_visit = dist_plus <= radius + in if !to_visit + then (parent, stack, count-1, dist, -1) + else -- update the stack + let fst_node = node_index + let snd_node = if (fst_node % 2) == 0 then fst_node-1 else fst_node+1 + let stack = setPackedInd stack count true + in (parent, stack, count, dist_plus, snd_node) + else (node_index, stack, count, dist, rec_node) + let (new_leaf, new_stack, _) = + if parent_rec == 0 && rec_node == -1 + then -- we are done, we are at the root node + (i32.i64 no_leaf, stack, 0) + + else -- now traverse downwards by computing `first` + loop (node_index, stack, count) = + (rec_node, stack, count) + for _i3 < height+1 do + if isLeaf height node_index + then (node_index, stack, count) + else + let count = count+1 + let stack = setPackedInd stack count false + let node_index = + if query[median_dims[node_index]] <= median_vals[node_index] + then (node_index+1)*2-1 + else (node_index+1)*2 + in (node_index, stack, count) + + in (new_leaf-i32.i64 q, new_stack, dist) + +def sortQueriesByLeavesRadix [n] (leaves: [n]i32) : ([n]i32, [n]i32) = + (leaves, map i32.i64 (iota n)) + +def bruteForce [m][d] (radius: f32) + (query: [d]f32) + (query_w: f32) + (leaf_refs : [m][d]f32) + (leaf_ws : [m]f32) + : f32 = + map2(\ref i -> + let dist = sumSqrsSeq query ref + in if dist <= radius then query_w * leaf_ws[i] else 0.0f32 + ) leaf_refs 
(iota m) + |> reduce (+) 0.0f32 + +def iterationSorted [q][n][d][num_leaves][ppl] + (radius: f32) + (h: i32) + (kd_tree: [q](i32,f32,i32)) + (leaves: [num_leaves][ppl][d]f32) + (ws: [num_leaves][ppl]f32) + (queries: [n][d]f32) + (query_ws:[n]f32) + (qleaves: [n]i32) + (stacks: [n]i32) + (dists: [n]f32) + (query_inds: [n]i32) + (res: f32) + : ([n]i32, [n]i32, [n]f32, [n]i32, f32) = + + let queries_sorted = gather2D queries query_inds + let query_ws_sorted= gather1D query_ws query_inds + + let new_res = + map3 (\ query query_w leaf_ind -> + if leaf_ind >= i32.i64 num_leaves + then 0.0f32 + else bruteForce radius query query_w (leaves[leaf_ind]) (ws[leaf_ind]) + ) queries_sorted query_ws_sorted qleaves + |> reduce (+) 0.0f32 |> opaque + + let (new_leaves, new_stacks, new_dists) = unzip3 <| + map4 (\ query leaf_ind stack dist -> + if leaf_ind >= i32.i64 num_leaves + then + (leaf_ind, stack, dist) + else traverseOnce radius h kd_tree query + (leaf_ind, stack, dist) + ) queries_sorted qleaves stacks dists + |> opaque + + let (qleaves', sort_inds) = sortQueriesByLeavesRadix new_leaves + let stacks' = gather1D new_stacks sort_inds + let dists' = gather1D new_dists sort_inds + let query_inds' = gather1D query_inds sort_inds + in (qleaves', stacks', dists', query_inds', res + new_res) + + +def propagate [m1][m][q][d][n] (radius: f32) + (ref_pts: [m][d]f32) + (indir: [m]i32) + (kd_tree: [q](i32,f32,i32)) + (queries: [n][d]f32) + (query_ws:[n]f32, ref_ws_orig: [m1]f32) + : f32 = + + let kd_weights = + map i64.i32 indir |> + map (\ind -> if ind >= m1 then 1.0f32 else ref_ws_orig[ind]) + + let (median_dims, median_vals, _) = unzip3 kd_tree + let num_nodes = q -- trace q + let num_leaves = num_nodes + 1 + let h = (log2 (i32.i64 num_leaves)) - 1 + let ppl = m / num_leaves + let leaves = unflatten (sized (num_leaves*ppl) ref_pts) + let kd_ws_sort = unflatten (sized (num_leaves*ppl) kd_weights) + + let query_leaves = map (findLeaf median_dims median_vals h) queries + let (qleaves, 
query_inds) = sortQueriesByLeavesRadix query_leaves + let dists = replicate n 0.0f32 + let stacks = replicate n 0i32 + let res_ws = 0f32 + + let (_qleaves', _stacks', _dists', _query_inds', res_ws') = + loop (qleaves : [n]i32, stacks : [n]i32, dists : [n]f32, query_inds : [n]i32, res_ws : f32) + for _i < 8 do + iterationSorted radius h kd_tree leaves kd_ws_sort queries query_ws qleaves stacks dists query_inds res_ws + + in res_ws' + +def rev_prop [m1][m][q][d][n] (radius: f32) + (ref_pts: [m][d]f32) + (indir: [m]i32) + (kd_tree: [q](i32,f32,i32)) + (queries: [n][d]f32) + (query_ws:[n]f32, ref_ws_orig: [m1]f32) + : (f32, ([n]f32, [m1]f32)) = + let f = propagate radius ref_pts indir kd_tree queries + in vjp2 f (query_ws, ref_ws_orig) 1.0f32 + +def main [d][n][m][m'][q] + (sq_radius: f32) + (queries: [n][d]f32) + (query_ws: [n]f32) + (ref_ws: [m]f32) + (refs_pts : [m'][d]f32) + (indir: [m']i32) + (median_dims : [q]i32) + (median_vals : [q]f32) + (clanc_eqdim : [q]i32) = + let (res, (query_ws_adj, ref_ws_adj)) = + rev_prop sq_radius refs_pts indir (zip3 median_dims median_vals clanc_eqdim) queries (query_ws, ref_ws) + in (res, query_ws_adj, ref_ws_adj) + diff --git a/tests/dependence-analysis/hist0.fut b/tests/dependence-analysis/hist0.fut new file mode 100644 index 0000000000..85d3984a54 --- /dev/null +++ b/tests/dependence-analysis/hist0.fut @@ -0,0 +1,10 @@ +-- == +-- structure { Screma/Hist/BinOp 1 } + +-- The two reduce_by_index get fused into a single histogram operation. 
+def main [m][n] (A: [m]([n]i32, [n]i32)) = + let r = + loop A for _i < n do + map (\(a, b) -> (reduce_by_index (replicate n 0) (+) 0 (map i64.i32 a) a, + reduce_by_index (replicate n 0) (+) 0 (map i64.i32 a) b)) A + in map (.0) r diff --git a/tests/dependence-analysis/hist1.fut b/tests/dependence-analysis/hist1.fut new file mode 100644 index 0000000000..650b6677e6 --- /dev/null +++ b/tests/dependence-analysis/hist1.fut @@ -0,0 +1,10 @@ +-- == +-- structure { Screma/Hist 1 } + +-- The two reduce_by_index produce two separate histogram operations. +def main [m][n] (A: [m]([n]i32, [n]i32)) = + let r = + loop A for _i < n do + map (\(a, b) -> (reduce_by_index (replicate n 0) (+) 0 (map i64.i32 a) a, + reduce_by_index (replicate n 0) (+) 0 (map i64.i32 b) b)) A + in map (.0) r diff --git a/tests/dependence-analysis/jvp0.fut b/tests/dependence-analysis/jvp0.fut new file mode 100644 index 0000000000..eb02dcd188 --- /dev/null +++ b/tests/dependence-analysis/jvp0.fut @@ -0,0 +1,8 @@ +-- == +-- structure { BinOp 2 } + +def main (A: [](i32,i32)) (n: i64) = + let r = + loop A for _i < n do + jvp (map (\(a,b) -> (a*a,b*b))) A A + in map (.0) r diff --git a/tests/dependence-analysis/map0.fut b/tests/dependence-analysis/map0.fut new file mode 100644 index 0000000000..75b72103dc --- /dev/null +++ b/tests/dependence-analysis/map0.fut @@ -0,0 +1,8 @@ +-- == +-- structure { BinOp 1 } + +def main (A: [](i32,i32)) (n: i64) = + let r = + loop A for i < n do + map (\(a,b) -> if i == 0 then (a,b) else (a+1,b+1)) A + in map (.0) r diff --git a/tests/dependence-analysis/reduce0.fut b/tests/dependence-analysis/reduce0.fut new file mode 100644 index 0000000000..d6ff6f0c70 --- /dev/null +++ b/tests/dependence-analysis/reduce0.fut @@ -0,0 +1,12 @@ +-- == +-- structure { BinOp 2 } + +def plus (a,b) (x,y): (i32,i32) = (a+x,b+y) + +def main (A: [](i32,i32)) (n: i64) = + let r = + loop r' = (0,0) for i < n do + reduce (\(a,b) (x,y) -> if i == 0 then (a,b) else (a+x,b+y)) + (0,0) + (map (plus r') 
A) + in r.0 diff --git a/tests/dependence-analysis/scan0.fut b/tests/dependence-analysis/scan0.fut new file mode 100644 index 0000000000..fd81241c3b --- /dev/null +++ b/tests/dependence-analysis/scan0.fut @@ -0,0 +1,8 @@ +-- == +-- structure { BinOp 1 } + +def main (A: [](i32,i32)) (n: i64) = + let r = + loop A for i < n do + scan (\(a,b) (x,y) -> if i == 0 then (a,b) else (a+x,b+y)) (0,0) A + in map (.0) r diff --git a/tests/dependence-analysis/scan1.fut b/tests/dependence-analysis/scan1.fut new file mode 100644 index 0000000000..01ffd0d2fe --- /dev/null +++ b/tests/dependence-analysis/scan1.fut @@ -0,0 +1,13 @@ +-- A simple streamSeq; does not exercise the Stream case in opDependencies, +-- but dead code is still removed by the Screma case. +-- == +-- structure { Stream/BinOp 2 } +-- structure { Screma/BinOp 2 } + +def plus (a,b) (x,y): (i32,i32) = (a+x,b+y) + +def main (xs: [](i32,i32)) (n: i64) = + let r = + loop xs for i < n do + map (\(x,y) -> if i == 0 then (x,y) else (x+1,y+2)) (scan plus (0,0) xs) + in map (.0) r diff --git a/tests/dependence-analysis/scatter0.fut b/tests/dependence-analysis/scatter0.fut new file mode 100644 index 0000000000..88db83e49b --- /dev/null +++ b/tests/dependence-analysis/scatter0.fut @@ -0,0 +1,8 @@ +-- == +-- structure { Replicate 1 } + +def main [n] (A: *[n](i32,i32)) = + let r = + loop A for _i < n do + scatter A (iota n) (copy A) + in map (.0) r diff --git a/tests/dependence-analysis/vjp0.fut b/tests/dependence-analysis/vjp0.fut new file mode 100644 index 0000000000..321b6146e2 --- /dev/null +++ b/tests/dependence-analysis/vjp0.fut @@ -0,0 +1,8 @@ +-- == +-- structure { BinOp 2 } + +def main (A: [](i32,i32)) (n: i64) = + let r = + loop A for _i < n do + vjp (map (\(a,b) -> (a*a,b*b))) A A + in map (.0) r