Skip to content

Commit

Permalink
[ perf ] Calculate total weight with no lazy singleton
Browse files Browse the repository at this point in the history
  • Loading branch information
buzden authored Sep 26, 2023
2 parents b3789e3 + cb8dc8a commit ae67fb3
Showing 1 changed file with 38 additions and 45 deletions.
83 changes: 38 additions & 45 deletions src/Test/DepTyCheck/Gen.idr
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ record RawGen a where
constructor MkRawGen
unRawGen : forall m. MonadRandom m => CanManageLabels m => m a

record OneOfAlternatives (0 em : Emptiness) (0 a : Type)
export
record GenAlternatives (0 mustBeNotEmpty : Bool) (em : Emptiness) (a : Type)

export
data Gen : Emptiness -> Type -> Type where
Expand All @@ -64,18 +65,20 @@ data Gen : Emptiness -> Type -> Type where

OneOf : alem `NoWeaker` em =>
NotImmediatelyEmpty alem =>
OneOfAlternatives alem a -> Gen em a
GenAlternatives True alem a -> Gen em a

Bind : {biem : _} ->
(0 _ : BindToOuter biem em) =>
RawGen c -> (c -> Gen biem a) -> Gen em a

Labelled : Label -> Gen em a -> Gen em a

record OneOfAlternatives (0 em : Emptiness) (0 a : Type) where
constructor MkOneOf
gens : LazyLst1 (PosNat, Lazy (Gen em a))
totalWeight : Lazy (Singleton $ foldl1 (+) (gens <&> \x => fst x))
record GenAlternatives (0 mustBeNotEmpty : Bool) (em : Emptiness) (a : Type) where
constructor MkGenAlts
unGenAlts : LazyLst mustBeNotEmpty (PosNat, Lazy (Gen em a))

(.totalWeight) : GenAlternatives True em a -> PosNat
(.totalWeight) oo = foldl1 (+) (oo.unGenAlts <&> \x => fst x)

public export %inline
Gen1 : Type -> Type
Expand Down Expand Up @@ -121,7 +124,7 @@ data Equiv : Gen lem a -> Gen rem a -> Type where
EE : Empty `Equiv` Empty
EP : Pure x `Equiv` Pure x
ER : Raw x `Equiv` Raw x
EO : lgs `AltsEquiv` rgs => OneOf @{lalemem} @{lalemcd} (MkOneOf lgs _) `Equiv` OneOf @{ralemem} @{ralemcd} (MkOneOf rgs _)
EO : lgs `AltsEquiv` rgs => OneOf @{lalemem} @{lalemcd} (MkGenAlts lgs) `Equiv` OneOf @{ralemem} @{ralemcd} (MkGenAlts rgs)
EB : Bind @{lbo} x g `Equiv` Bind @{rbo} x g

data AltsEquiv : LazyLst lne (PosNat, Lazy (Gen lem a)) -> LazyLst rne (PosNat, Lazy (Gen lem a)) -> Type where
Expand All @@ -143,27 +146,22 @@ data AltsEquiv : LazyLst lne (PosNat, Lazy (Gen lem a)) -> LazyLst rne (PosNat,
mapTaggedLazy : (a -> b) -> LazyLst ne (tag, Lazy a) -> LazyLst ne (tag, Lazy b)
mapTaggedLazy = map . mapSnd . wrapLazy

mapOneOf : OneOfAlternatives iem a -> (Gen iem a -> Gen em b) -> OneOfAlternatives em b
mapOneOf (MkOneOf gs tw) f = MkOneOf (mapTaggedLazy f gs) $ do
rewrite mapFusion (Builtin.fst) (mapSnd $ wrapLazy f) gs
transport tw $ cong (Lazy.foldl1 (+)) $ mapExt gs $ \(_, _) => Refl
mapOneOf : GenAlternatives ne iem a -> (Gen iem a -> Gen em b) -> GenAlternatives ne em b
mapOneOf (MkGenAlts gs) f = MkGenAlts $ mapTaggedLazy f gs

traverseMaybe : (a -> Maybe b) -> LazyLst ne a -> Maybe $ LazyLst ne b
traverseMaybe f [] = Just []
traverseMaybe f (x::xs) = case f x of
Nothing => Nothing
Just y => (y ::) <$> traverseMaybe f xs

trMTaggedLazy : (a -> Maybe b) -> LazyLst1 (tag, Lazy a) -> Maybe $ LazyLst1 (tag, Lazy b)
trMTaggedLazy : (a -> Maybe b) -> LazyLst ne (tag, Lazy a) -> Maybe $ LazyLst ne (tag, Lazy b)
trMTaggedLazy = traverseMaybe . m . wrapLazy where
m : (Lazy a -> Lazy (Maybe b)) -> (tag, Lazy a) -> Maybe (tag, (Lazy b))
m f (tg, lz) = (tg,) . delay <$> f lz

-- TODO to make the proof properly
trMOneOf : OneOfAlternatives iem a -> (Gen iem a -> Maybe $ Gen em b) -> Maybe $ OneOfAlternatives em b
trMOneOf (MkOneOf gs tw) f with (trMTaggedLazy f gs) proof trm
_ | Nothing = Nothing
_ | Just gs' = Just $ MkOneOf gs' $ believe_me tw
trMOneOf : GenAlternatives ne iem a -> (Gen iem a -> Maybe $ Gen em b) -> Maybe $ GenAlternatives ne em b
trMOneOf (MkGenAlts gs) f = MkGenAlts <$> trMTaggedLazy f gs

-----------------------------
--- Emptiness tweakenings ---
Expand Down Expand Up @@ -213,7 +211,7 @@ mkOneOf : alem `NoWeaker` em =>
NotImmediatelyEmpty alem =>
(gens : LazyLst1 (PosNat, Lazy (Gen alem a))) ->
Gen em a
mkOneOf gens = OneOf $ MkOneOf gens $ Val _
mkOneOf gens = OneOf $ MkGenAlts gens
-- TODO to make elimination of a single element

--------------------------
Expand All @@ -226,7 +224,7 @@ export
unGen1 : MonadRandom m => CanManageLabels m => Gen1 a -> m a
unGen1 $ Pure x = pure x
unGen1 $ Raw sf = sf.unRawGen
unGen1 $ OneOf @{NN} oo = assert_total unGen1 . force . pickWeighted oo.gens . finToNat =<< randomFin oo.totalWeight.unVal
unGen1 $ OneOf @{NN} oo = assert_total unGen1 . force . pickWeighted oo.unGenAlts . finToNat =<< randomFin oo.totalWeight
unGen1 $ Bind @{bo} x f = case extractNE bo of Refl => x.unRawGen >>= unGen1 . f
unGen1 $ Labelled l x = manageLabel l $ unGen1 x

Expand All @@ -247,7 +245,7 @@ unGen : MonadRandom m => MonadError () m => CanManageLabels m => Gen em a -> m a
unGen $ Empty = throwError ()
unGen $ Pure x = pure x
unGen $ Raw sf = sf.unRawGen
unGen $ OneOf oo = assert_total unGen . force . pickWeighted oo.gens . finToNat =<< randomFin oo.totalWeight.unVal
unGen $ OneOf oo = assert_total unGen . force . pickWeighted oo.unGenAlts . finToNat =<< randomFin oo.totalWeight
unGen $ Bind x f = x.unRawGen >>= unGen . f
unGen $ Labelled l x = manageLabel l $ unGen x

Expand Down Expand Up @@ -336,7 +334,7 @@ export
Raw g >>= nf = Bind @{reflexive} g nf
(OneOf @{ao} oo >>= nf) {em=NonEmpty} with (ao) _ | NN = OneOf $ mapOneOf oo $ assert_total (>>= nf)
(OneOf @{ao} oo >>= nf) {em=MaybeEmptyDeep} = OneOf $ mapOneOf oo $ assert_total (>>= nf) . relax @{ao}
(OneOf {alem} (MkOneOf gs _) >>= nf) {em=MaybeEmpty} = maybe Empty (mkOneOf {alem=MaybeEmptyDeep}) $
(OneOf {alem} (MkGenAlts gs) >>= nf) {em=MaybeEmpty} = maybe Empty (mkOneOf {alem=MaybeEmptyDeep}) $
strengthen $ flip mapMaybe gs $ traverse $ map delay . strengthen . assert_total (>>= nf) . relax . force
Bind {biem} x f >>= nf with (order {rel=NoWeaker} biem em)
_ | Left _ = Bind x $ \x => assert_total $ relax (f x) >>= nf
Expand All @@ -349,14 +347,9 @@ export

namespace GenAlternatives

export
record GenAlternatives (0 mustBeNotEmpty : Bool) (em : Emptiness) a where
constructor MkGenAlternatives
unGenAlternatives : LazyLst mustBeNotEmpty (PosNat, Lazy (Gen em a))

export %inline
Nil : GenAlternatives False em a
Nil = MkGenAlternatives []
Nil = MkGenAlts []

export %inline
(::) : {em : _} ->
Expand All @@ -367,7 +360,7 @@ namespace GenAlternatives
(0 _ : IfUnsolved lem em) =>
(0 _ : IfUnsolved rem em) =>
Lazy (Gen lem a) -> Lazy (GenAlternatives e rem a) -> GenAlternatives ne em a
x :: xs = MkGenAlternatives $ (1, relax x) :: mapTaggedLazy relax xs.unGenAlternatives
x :: xs = MkGenAlts $ (1, relax x) :: mapTaggedLazy relax xs.unGenAlts

-- This concatenation breaks relative proportions in frequences of given alternative lists
public export %inline
Expand All @@ -379,41 +372,41 @@ namespace GenAlternatives
(0 _ : IfUnsolved nel False) =>
(0 _ : IfUnsolved ner False) =>
GenAlternatives nel lem a -> Lazy (GenAlternatives ner rem a) -> GenAlternatives (nel || ner) em a
xs ++ ys = MkGenAlternatives $ mapTaggedLazy relax xs.unGenAlternatives ++ mapTaggedLazy relax ys.unGenAlternatives
xs ++ ys = MkGenAlts $ mapTaggedLazy relax xs.unGenAlts ++ mapTaggedLazy relax ys.unGenAlts

public export %inline
length : GenAlternatives ne em a -> Nat
length $ MkGenAlternatives alts = length alts
length $ MkGenAlts alts = length alts

export %inline
processAlternatives : (Gen em a -> Gen em b) -> GenAlternatives ne em a -> GenAlternatives ne em b
processAlternatives f $ MkGenAlternatives xs = MkGenAlternatives $ xs <&> mapSnd (wrapLazy f)
processAlternatives = flip mapOneOf

export %inline
processAlternativesMaybe : (Gen em a -> Maybe $ Lazy (Gen em b)) -> GenAlternatives ne em a -> GenAlternatives False em b
processAlternativesMaybe f $ MkGenAlternatives xs = MkGenAlternatives $ mapMaybe (\(t, x) => (t,) <$> f x) xs
processAlternativesMaybe f $ MkGenAlts xs = MkGenAlts $ mapMaybe (\(t, x) => (t,) <$> f x) xs

export %inline
processAlternatives'' : (Gen em a -> GenAlternatives neb em b) -> GenAlternatives nea em a -> GenAlternatives (nea && neb) em b
processAlternatives'' f = mapGens where
mapWeight : forall a, nea. (PosNat -> PosNat) -> GenAlternatives nea em a -> GenAlternatives nea em a
mapWeight f $ MkGenAlternatives xs = MkGenAlternatives $ xs <&> mapFst f
mapWeight f $ MkGenAlts xs = MkGenAlts $ xs <&> mapFst f
mapGens : GenAlternatives nea em a -> GenAlternatives (nea && neb) em b
mapGens $ MkGenAlternatives xs = MkGenAlternatives $ xs `bind` \(w, x) => unGenAlternatives $ mapWeight (w *) $ f x
mapGens $ MkGenAlts xs = MkGenAlts $ xs `bind` \(w, x) => unGenAlts $ mapWeight (w *) $ f x
export %inline
processAlternatives' : (Gen em a -> GenAlternatives ne em b) -> GenAlternatives ne em a -> GenAlternatives ne em b
processAlternatives' f xs = rewrite sym $ andSameNeutral ne in processAlternatives'' f xs
export %inline
relax : GenAlternatives True em a -> GenAlternatives ne em a
relax $ MkGenAlternatives alts = MkGenAlternatives $ relaxT alts
relax $ MkGenAlts alts = MkGenAlts $ relaxT alts
export %inline
strengthen : GenAlternatives ne em a -> Maybe $ GenAlternatives True em a
strengthen $ MkGenAlternatives xs = MkGenAlternatives <$> strengthen xs
strengthen $ MkGenAlts xs = MkGenAlts <$> strengthen xs
export
Functor (GenAlternatives ne em) where
Expand All @@ -427,13 +420,13 @@ namespace GenAlternatives
export
{em : _} -> Alternative (GenAlternatives False em) where
empty = []
xs <|> ys = MkGenAlternatives $ xs.unGenAlternatives <|> ys.unGenAlternatives
xs <|> ys = MkGenAlts $ xs.unGenAlts <|> ys.unGenAlts

-- implementation for `Monad` is below --

export
{em : _} -> Cast (LazyLst ne a) (GenAlternatives ne em a) where
cast = MkGenAlternatives . map (\x => (1, pure x))
cast = MkGenAlts . map (\x => (1, pure x))

public export %inline
altsFromList : {em : _} -> LazyLst ne a -> GenAlternatives ne em a
Expand Down Expand Up @@ -468,9 +461,9 @@ oneOf : {em : _} ->
(0 _ : IfUnsolved alem em) =>
(0 _ : IfUnsolved altsNe $ em /= MaybeEmpty) =>
GenAlternatives altsNe alem a -> Gen em a
oneOf {em=NonEmpty} @{NN} @{NT} $ MkGenAlternatives xs = mkOneOf xs
oneOf {em=MaybeEmptyDeep} @{_} @{DT} x = case x of MkGenAlternatives xs => mkOneOf xs
oneOf {em=MaybeEmpty} x = case x of MkGenAlternatives xs => do
oneOf {em=NonEmpty} @{NN} @{NT} $ MkGenAlts xs = mkOneOf xs
oneOf {em=MaybeEmptyDeep} @{_} @{DT} x = case x of MkGenAlts xs => mkOneOf xs
oneOf {em=MaybeEmpty} x = case x of MkGenAlts xs => do
maybe Empty mkOneOf $ strengthen $ flip mapMaybe xs $
\wg => (fst wg,) . delay <$> Gen.strengthen {em=MaybeEmptyDeep} (snd wg)

Expand All @@ -487,7 +480,7 @@ frequency : {em : _} ->
(0 _ : IfUnsolved alem em) =>
(0 _ : IfUnsolved altsNe $ em /= MaybeEmpty) =>
LazyLst altsNe (PosNat, Lazy (Gen alem a)) -> Gen em a
frequency = oneOf . MkGenAlternatives
frequency = oneOf . MkGenAlts

||| Choose one of the given values uniformly.
|||
Expand All @@ -512,7 +505,7 @@ elements' xs = elements $ relaxF $ fromList $ toList xs

export
alternativesOf : {em : _} -> Gen em a -> GenAlternatives True em a
alternativesOf $ OneOf oo = MkGenAlternatives $ gens $ mapOneOf oo relax
alternativesOf $ OneOf oo = MkGenAlts $ unGenAlts $ mapOneOf oo relax
alternativesOf $ Labelled l x = processAlternatives (label l) $ alternativesOf x
alternativesOf g = [g]

Expand Down Expand Up @@ -540,8 +533,8 @@ forgetAlternatives g@(OneOf {}) = case canBeNotImmediatelyEmpty em of
Left Refl => maybe Empty single $ strengthen {em=MaybeEmptyDeep} g
where
%inline single : iem `NoWeaker` MaybeEmptyDeep => iem `NoWeaker` em => Gen iem a -> Gen em a
single g = label "forgetAlternatives" $ OneOf $ MkOneOf [(1, g)] $ Val _
-- `mkOneOf` is not used here intentionally, since if `mkOneOf` is changed to eliminate single-element `MkOneOf`'s, we still want such behaviour here.
single g = label "forgetAlternatives" $ OneOf $ MkGenAlts [(1, g)]
-- `mkOneOf` is not used here intentionally, since if `mkOneOf` can be changed to eliminate single-element `MkGenAlts`'s, we still want such behaviour here.
forgetAlternatives (Labelled l x) = label l $ forgetAlternatives x
forgetAlternatives g = g

Expand Down

0 comments on commit ae67fb3

Please sign in to comment.