Skip to content

Commit

Permalink
Add TH tuple instances for Simulate
Browse files Browse the repository at this point in the history
  • Loading branch information
lmbollen committed Sep 26, 2024
1 parent 3a4af6e commit 80b24a2
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 2 deletions.
5 changes: 4 additions & 1 deletion clash-protocols-base/src/Protocols/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=20 #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-- TODO: Hide internal documentation
Expand All @@ -31,7 +32,7 @@ import qualified Clash.Prelude as C

import Protocols.Cpp (maxTupleSize)
import Protocols.Internal.Classes
import Protocols.Internal.TH (protocolTupleInstances)
import Protocols.Internal.TH (protocolTupleInstances, simulateTupleInstances)

import Control.Arrow ((***))
import Data.Coerce (coerce)
Expand Down Expand Up @@ -263,6 +264,8 @@ instance (Simulate a, Simulate b) => Simulate (a, b) where
in
((fwdL1, fwdR1), (bwdL1, bwdR1))

simulateTupleInstances 3 maxTupleSize

instance (Drivable a, Drivable b) => Drivable (a, b) where
type ExpectType (a, b) = (ExpectType a, ExpectType b)

Expand Down
77 changes: 76 additions & 1 deletion clash-protocols-base/src/Protocols/Internal/TH.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

module Protocols.Internal.TH where

import qualified Clash.Prelude as C
import Control.Monad.Extra (concatMapM)
import GHC.TypeNats
import Language.Haskell.TH

import Protocols.Internal.Classes

appTs :: Q Type -> [Q Type] -> Q Type
Expand Down Expand Up @@ -62,3 +63,77 @@ idleCircuitTupleInstance n =
mkFwdExpr ty = [e|idleFwd $ Proxy @($ty)|]
bwdExpr = tupE $ map mkBwdExpr circTys
mkBwdExpr ty = [e|idleBwd $ Proxy @($ty)|]

simulateTupleInstances :: Int -> Int -> DecsQ
simulateTupleInstances n m = concatMapM simulateTupleInstance [n .. m]

simulateTupleInstance :: Int -> DecsQ
simulateTupleInstance n =
[d|
instance ($instCtx) => Simulate $instTy where
type SimulateFwdType $instTy = $fwdType
type SimulateBwdType $instTy = $bwdType
type SimulateChannels $instTy = $channelSum

simToSigFwd _ $fwdPat0 = $(tupE $ zipWith (\ty expr -> [e|simToSigFwd (Proxy @($ty)) $expr|]) circTys fwdExpr)
simToSigBwd _ $bwdPat0 = $(tupE $ zipWith (\ty expr -> [e|simToSigBwd (Proxy @($ty)) $expr|]) circTys bwdExpr)
sigToSimFwd _ $fwdPat0 = $(tupE $ zipWith (\ty expr -> [e|sigToSimFwd (Proxy @($ty)) $expr|]) circTys fwdExpr)
sigToSimBwd _ $bwdPat0 = $(tupE $ zipWith (\ty expr -> [e|sigToSimBwd (Proxy @($ty)) $expr|]) circTys bwdExpr)

stallC $(varP $ mkName "conf") $(varP $ mkName "rem0") = $(letE (stallVecs ++ stallCircuits) stallCExpr)
|]
where
-- Generate the types for the instance
circTys = map (\i -> varT $ mkName $ "c" <> show i) [1 .. n]
instTy = foldl appT (tupleT n) circTys
instCtx = foldl appT (tupleT n) $ map (\ty -> [t|Simulate $ty|]) circTys
fwdType = foldl appT (tupleT n) $ map (\ty -> [t|SimulateFwdType $ty|]) circTys
bwdType = foldl appT (tupleT n) $ map (\ty -> [t|SimulateBwdType $ty|]) circTys
channelSum = foldl1 (\a b -> [t|$a + $b|]) $ map (\ty -> [t|SimulateChannels $ty|]) circTys

-- Relevant expressions and patterns
fwdPat0 = tupP $ map (\i -> varP $ mkName $ "fwd" <> show i) [1 .. n]
bwdPat0 = tupP $ map (\i -> varP $ mkName $ "bwd" <> show i) [1 .. n]
fwdExpr = map (\i -> varE $ mkName $ "fwd" <> show i) [1 .. n]
bwdExpr = map (\i -> varE $ mkName $ "bwd" <> show i) [1 .. n]
fwdExpr1 = map (\i -> varE $ mkName $ "fwdStalled" <> show i) [1 .. n]
bwdExpr1 = map (\i -> varE $ mkName $ "bwdStalled" <> show i) [1 .. n]

-- stallC Declaration: Split off the stall vectors from the large input vector
stallVecs = zipWith mkStallVec [1 .. n] circTys
mkStallVec i ty =
valD
mkStallPat
( normalB [e|(C.splitAtI @(SimulateChannels $ty) $(varE (mkName $ "rem" <> show (i - 1))))|]
)
[]
where
mkStallPat =
tupP
[ varP (mkName $ "stalls" <> show i)
, varP (mkName $ if i == n then "_" else "rem" <> show i)
]

-- stallC Declaration: Generate stalling circuits
stallCircuits = zipWith mkStallCircuit [1 .. n] circTys
mkStallCircuit i ty =
valD
[p|Circuit $(varP $ mkName $ "stalled" <> show i)|]
(normalB [e|stallC @($ty) conf $(varE $ mkName $ "stalls" <> show i)|])
[]

-- Generate the stallC expression
stallCExpr =
[e|
Circuit $ \($fwdPat0, $bwdPat0) -> $(letE stallCResultDecs [e|($(tupE fwdExpr1), $(tupE bwdExpr1))|])
|]

stallCResultDecs = map mkStallCResultDec [1 .. n]
mkStallCResultDec i =
valD
(tupP [varP $ mkName $ "fwdStalled" <> show i, varP $ mkName $ "bwdStalled" <> show i])
( normalB $
appE (varE $ mkName $ "stalled" <> show i) $
tupE [varE $ mkName $ "fwd" <> show i, varE $ mkName $ "bwd" <> show i]
)
[]

0 comments on commit 80b24a2

Please sign in to comment.