diff --git a/clash-protocols/src/Protocols.hs b/clash-protocols/src/Protocols.hs index 0fbf89b5..6adb1472 100644 --- a/clash-protocols/src/Protocols.hs +++ b/clash-protocols/src/Protocols.hs @@ -66,3 +66,4 @@ module Protocols ( import Data.Default (def) import Protocols.Df (Df) import Protocols.Internal +import Protocols.Plugin diff --git a/clash-protocols/src/Protocols/Internal.hs b/clash-protocols/src/Protocols/Internal.hs index f42e726b..67b56b80 100644 --- a/clash-protocols/src/Protocols/Internal.hs +++ b/clash-protocols/src/Protocols/Internal.hs @@ -34,7 +34,11 @@ import qualified Clash.Explicit.Prelude as CE import Clash.Prelude (type (*), type (+)) import qualified Clash.Prelude as C -import Protocols.Internal.TH (simulateTupleInstances) +import Protocols.Internal.TH ( + backPressureTupleInstances, + drivableTupleInstances, + simulateTupleInstances, + ) import Protocols.Internal.Types import Protocols.Plugin import Protocols.Plugin.Cpp (maxTupleSize) @@ -97,13 +101,7 @@ instance Backpressure () where instance (Backpressure a, Backpressure b) => Backpressure (a, b) where boolsToBwd _ bs = (boolsToBwd (Proxy @a) bs, boolsToBwd (Proxy @b) bs) -instance (Backpressure a, Backpressure b, Backpressure c) => Backpressure (a, b, c) where - boolsToBwd _ bs = - ( boolsToBwd (Proxy @a) bs - , boolsToBwd (Proxy @b) bs - , boolsToBwd (Proxy @c) bs - ) - +backPressureTupleInstances 3 maxTupleSize instance (C.KnownNat n, Backpressure a) => Backpressure (C.Vec n a) where boolsToBwd _ bs = C.repeat (boolsToBwd (Proxy @a) bs) @@ -265,10 +263,7 @@ instance (Drivable a, Drivable b) => Drivable (a, b) where , sampleC @b conf (Circuit $ \_ -> ((), fwd2)) ) --- TODO TemplateHaskell? --- instance SimulateType (a, b, c) --- instance SimulateType (a, b, c, d) - +drivableTupleInstances 3 maxTupleSize instance (CE.KnownNat n, Simulate a) => Simulate (C.Vec n a) where type SimulateFwdType (C.Vec n a) = C.Vec n (SimulateFwdType a) type SimulateBwdType (C.Vec n a) = C.Vec n (SimulateBwdType a) diff --git a/clash-protocols/src/Protocols/Internal/TH.hs b/clash-protocols/src/Protocols/Internal/TH.hs index 02b7d290..4dc7586a 100644 --- a/clash-protocols/src/Protocols/Internal/TH.hs +++ b/clash-protocols/src/Protocols/Internal/TH.hs @@ -4,6 +4,7 @@ module Protocols.Internal.TH where import qualified Clash.Prelude as C import Control.Monad.Extra (concatMapM) +import Data.Proxy import GHC.TypeNats import Language.Haskell.TH import Protocols.Internal.Types @@ -108,3 +109,69 @@ simulateTupleInstance n = tupE [varE $ mkName $ "fwd" <> show i, varE $ mkName $ "bwd" <> show i] ) [] + +drivableTupleInstances :: Int -> Int -> DecsQ +drivableTupleInstances n m = concatMapM drivableTupleInstance [n .. m] + +drivableTupleInstance :: Int -> DecsQ +drivableTupleInstance n = + [d| + instance ($instCtx) => Drivable $instTy where + type + ExpectType $instTy = + $(foldl appT (tupleT n) $ map (\ty -> [t|ExpectType $ty|]) circTys) + toSimulateType Proxy $(tupP circPats) = $toSimulateExpr + + fromSimulateType Proxy $(tupP circPats) = $fromSimulateExpr + + driveC $(varP $ mkName "conf") $(tupP fwdPats) = $(letE driveCDecs driveCExpr) + sampleC conf (Circuit f) = + let + $(varP $ mkName "bools") = replicate (resetCycles conf) False <> repeat True + (_, $(tupP fwdPats)) = f ((), $(tupE $ map mkSampleCExpr circTys)) + in + $( tupE $ + zipWith (\ty fwd -> [|sampleC @($ty) conf (Circuit $ const ((), $fwd))|]) circTys fwdExprs + ) + |] + where + circStrings = map (\i -> "c" <> show i) [1 .. n] + circTys = map (varT . mkName) circStrings + circPats = map (varP . mkName) circStrings + circExprs = map (varE . mkName) circStrings + instCtx = foldl appT (tupleT n) $ map (\ty -> [t|Drivable $ty|]) circTys + instTy = foldl appT (tupleT n) circTys + fwdPats = map (varP . mkName . ("fwd" <>)) circStrings + fwdExprs = map (varE . mkName . ("fwd" <>)) circStrings + bwdExprs = map (varE . mkName . ("bwd" <>)) circStrings + bwdPats = map (varP . mkName . ("bwd" <>)) circStrings + + mkSampleCExpr ty = [e|boolsToBwd (Proxy @($ty)) bools|] + driveCDecs = + pure $ + valD + (tupP $ map (\p -> [p|(Circuit $p)|]) circPats) + (normalB $ tupE $ zipWith (\ty fwd -> [e|driveC @($ty) conf $fwd|]) circTys fwdExprs) + [] + + driveCExpr = + [e| + Circuit $ \(_, $(tildeP $ tupP bwdPats)) -> ((), $(tupE $ zipWith mkDriveCExpr circExprs bwdExprs)) + |] + mkDriveCExpr c bwd = [e|snd ($c ((), $bwd))|] + toSimulateExpr = tupE $ zipWith (\ty c -> [|toSimulateType (Proxy @($ty)) $c|]) circTys circExprs + fromSimulateExpr = tupE $ zipWith (\ty c -> [|fromSimulateType (Proxy @($ty)) $c|]) circTys circExprs + +backPressureTupleInstances :: Int -> Int -> DecsQ +backPressureTupleInstances n m = concatMapM backPressureTupleInstance [n .. m] + +backPressureTupleInstance :: Int -> DecsQ +backPressureTupleInstance n = + [d| + instance ($instCtx) => Backpressure $instTy where + boolsToBwd _ bs = $(tupE $ map (\ty -> [e|boolsToBwd (Proxy @($ty)) bs|]) circTys) + |] + where + circTys = map (\i -> varT $ mkName $ "c" <> show i) [1 .. n] + instCtx = foldl appT (tupleT n) $ map (\ty -> [t|Backpressure $ty|]) circTys + instTy = foldl appT (tupleT n) circTys