Skip to content

Commit

Permalink
Keep all casts, and cast (Signal a ~ a) where appropriate
Browse files Browse the repository at this point in the history
  • Loading branch information
christiaanb committed Feb 9, 2020
1 parent ea4d877 commit 598039e
Show file tree
Hide file tree
Showing 11 changed files with 220 additions and 98 deletions.
3 changes: 3 additions & 0 deletions clash-ghc/src-ghc/Clash/GHC/Evaluator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3465,6 +3465,9 @@ naturalLiteral v =
DC dc [Left (Literal (ByteArrayLiteral (Vector.Vector _ _ (ByteArray.ByteArray ba))))]
| dcTag dc == 2
-> Just (Jp# (BN# ba))
CastValue v0 _ _
| Just n <- naturalLiteral v0
-> Just n
_ -> Nothing

integerLiterals' :: [Value] -> [Integer]
Expand Down
90 changes: 60 additions & 30 deletions clash-ghc/src-ghc/Clash/GHC/GHC2Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ import TyCon (AlgTyConRhs (..), TyCon, tyConName,
tyConArity,
tyConDataCons, tyConKind,
tyConName, tyConUnique, isClassTyCon)
import Type (mkTvSubstPrs, substTy, coreView)
import Type (mkTvSubstPrs, substTy, coreView, piResultTys)
import TyCoRep (Coercion (..), TyLit (..), Type (..))
import Unique (Uniquable (..), Unique, getKey, hasKey)
import Var (Id, TyVar, Var, idDetails,
Expand Down Expand Up @@ -288,71 +288,101 @@ coreToTerm primMap unlocs = term
, let (nm, _) = RWS.evalRWS (qualifiedNameString (varName x))
noSrcSpan
emptyGHC2CoreState
= go nm args
= go nm (varType x) args
| otherwise
= term' e
where
-- Remove most Signal transformers
go "Clash.Signal.Internal.mapSignal#" args
| length args == 5
= term (App (args!!3) (args!!4))
go "Clash.Signal.Internal.signal#" args
| length args == 3
= term (args!!2)
go "Clash.Signal.Internal.appSignal#" args
| length args == 5
= term (App (args!!3) (args!!4))
go "Clash.Signal.Internal.joinSignal#" args
go "Clash.Signal.Internal.mapSignal#" pTy args
| [Type aTy, Type bTy, Type domTy, fTm, aSigTm] <- args
= do
let aSigTy = piResultTys pTy [bTy,aTy,domTy,aTy,aTy]
bSigTy = piResultTys pTy [aTy,bTy,domTy,bTy,bTy]
aTyC <- coreToType aTy
bTyC <- coreToType bTy
aSigTyC <- coreToType aSigTy
bSigTyC <- coreToType bSigTy
C.Cast <$> (C.App <$> term fTm
<*> (C.Cast <$> term aSigTm
<*> pure aSigTyC
<*> pure aTyC))
<*> pure bTyC
<*> pure bSigTyC
go "Clash.Signal.Internal.signal#" pty args
| [Type aTy, Type domTy, aTm] <- args
= let aSigTy = piResultTys pty [aTy,domTy,aTy]
in C.Cast <$> term aTm <*> coreToType aTy <*> coreToType aSigTy
go "Clash.Signal.Internal.appSignal#" pTy args
| [Type domTy, Type aTy, Type bTy, fSigTm, aSigTm] <- args
= do
let aSigTy = piResultTys pTy [domTy,bTy,aTy,aTy,aTy]
bSigTy = piResultTys pTy [domTy,aTy,bTy,bTy,bTy]
fSigTy = piResultTys pTy [domTy,aTy,FunTy aTy bTy,aTy,aTy]
aTyC <- coreToType aTy
bTyC <- coreToType bTy
aSigTyC <- coreToType aSigTy
bSigTyC <- coreToType bSigTy
fSigTyC <- coreToType fSigTy
let fTyC = C.mkFunTy aTyC bTyC
C.Cast <$> (C.App <$> (C.Cast <$> term fSigTm
<*> pure fSigTyC
<*> pure fTyC)
<*> (C.Cast <$> term aSigTm
<*> pure aSigTyC
<*> pure aTyC))
<*> pure bTyC
<*> pure bSigTyC
go "Clash.Signal.Internal.joinSignal#" _ args
| length args == 3
= term (args!!2)
go "Clash.Signal.Bundle.vecBundle#" args
go "Clash.Signal.Bundle.vecBundle#" _ args
| length args == 4
= term (args!!3)
--- Remove `$`
go "GHC.Base.$" args
go "GHC.Base.$" _ args
| length args == 5
= term (App (args!!3) (args!!4))
go "GHC.Magic.noinline" args -- noinline :: forall a. a -> a
go "GHC.Magic.noinline" _ args -- noinline :: forall a. a -> a
| [_ty, x] <- args
= term x
-- Remove most CallStack logic
go "GHC.Stack.Types.PushCallStack" args = term (last args)
go "GHC.Stack.Types.FreezeCallStack" args = term (last args)
go "GHC.Stack.withFrozenCallStack" args
go "GHC.Stack.Types.PushCallStack" _ args = term (last args)
go "GHC.Stack.Types.FreezeCallStack" _ args = term (last args)
go "GHC.Stack.withFrozenCallStack" _ args
| length args == 3
= term (App (args!!2) (args!!1))
go "Clash.Class.BitPack.packXWith" args
go "Clash.Class.BitPack.packXWith" _ args
| [_nTy,_aTy,_kn,f] <- args
= term f
go "Clash.Sized.BitVector.Internal.checkUnpackUndef" args
go "Clash.Sized.BitVector.Internal.checkUnpackUndef" _ args
| [_nTy,_aTy,_kn,_typ,f] <- args
= term f
go "Clash.Magic.prefixName" args
go "Clash.Magic.prefixName" _ args
| [Type nmTy,_aTy,f] <- args
= C.Tick <$> (C.NameMod C.PrefixName <$> coreToType nmTy) <*> term f
go "Clash.Magic.suffixName" args
go "Clash.Magic.suffixName" _ args
| [Type nmTy,_aTy,f] <- args
= C.Tick <$> (C.NameMod C.SuffixName <$> coreToType nmTy) <*> term f
go "Clash.Magic.suffixNameFromNat" args
go "Clash.Magic.suffixNameFromNat" _ args
| [Type nmTy,_aTy,f] <- args
= C.Tick <$> (C.NameMod C.SuffixName <$> coreToType nmTy) <*> term f
go "Clash.Magic.suffixNameP" args
go "Clash.Magic.suffixNameP" _ args
| [Type nmTy,_aTy,f] <- args
= C.Tick <$> (C.NameMod C.SuffixNameP <$> coreToType nmTy) <*> term f
go "Clash.Magic.suffixNameFromNatP" args
go "Clash.Magic.suffixNameFromNatP" _ args
| [Type nmTy,_aTy,f] <- args
= C.Tick <$> (C.NameMod C.SuffixNameP <$> coreToType nmTy) <*> term f
go "Clash.Magic.setName" args
go "Clash.Magic.setName" _ args
| [Type nmTy,_aTy,f] <- args
= C.Tick <$> (C.NameMod C.SetName <$> coreToType nmTy) <*> term f
go "Clash.Magic.deDup" args
go "Clash.Magic.deDup" _ args
| [_aTy,f] <- args
= C.Tick C.DeDup <$> term f
go "Clash.Magic.noDeDup" args
go "Clash.Magic.noDeDup" _ args
| [_aTy,f] <- args
= C.Tick C.NoDeDup <$> term f

go _ _ = term' e
go _ _ _ = term' e
term' (Var x) = var x
term' (Lit l) = return $ C.Literal (coreToLiteral l)
term' (App eFun (Type tyArg)) = C.TyApp <$> term eFun <*> coreToType tyArg
Expand Down Expand Up @@ -405,7 +435,7 @@ coreToTerm primMap unlocs = term
case hasPrimCoM of
Just _ | ty1_I || ty2_I
-> C.Cast <$> term e <*> coreToType ty1 <*> coreToType ty2
_ -> term e
_ -> C.Cast <$> term e <*> coreToType ty1 <*> coreToType ty2
term' (Tick (SourceNote rsp _) e) =
C.Tick (C.SrcSpan (RealSrcSpan rsp)) <$> addUsefull (RealSrcSpan rsp) (term e)
term' (Tick _ e) = term e
Expand Down
18 changes: 12 additions & 6 deletions clash-lib/src/Clash/Core/Evaluator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ unwindStack m
let term = Tick sp (getTerm m')
in unwindStack (setTerm term m')

Castish ty1 ty2 ->
let term = Cast (getTerm m') ty1 ty2
in unwindStack (setTerm term m')

-- | A single step in the partial evaluator. The result is the new heap and
-- stack, and the next expression to be reduced.
--
Expand Down Expand Up @@ -232,6 +236,8 @@ stepApp x y m tcm =
GT -> let (m0, n) = newLetBinding tcm m y
in Just . setTerm x $ stackPush (Apply n) m0

Cast {} -> error "stepApp QQ"

_ -> let (m0, n) = newLetBinding tcm m y
in Just . setTerm x $ stackPush (Apply n) m0
where
Expand Down Expand Up @@ -264,6 +270,8 @@ stepTyApp x ty m tcm =
LT -> newBinder tys' (TyApp x ty) m tcm
GT -> Just . setTerm x $ stackPush (Instantiate ty) m

Cast {} -> error "stepTyApp QQ"

_ -> Just . setTerm x $ stackPush (Instantiate ty) m
where
(term, args, _) = collectArgsTicks (TyApp x ty)
Expand All @@ -273,17 +281,14 @@ stepLetRec :: [LetBinding] -> Term -> Step
stepLetRec bs x m _ = Just (allocate bs x m)

stepCase :: Term -> Type -> [Alt] -> Step
stepCase (Cast {}) _ty _alts _m _ = error "stepCase QQ"
stepCase scrut ty alts m _ =
Just . setTerm scrut $ stackPush (Scrutinise ty alts) m

-- TODO Support stepwise evaluation of casts.
--
stepCast :: Term -> Type -> Type -> Step
stepCast _ _ _ _ _ =
flip trace Nothing $ unlines
[ "WARNING: " <> $(curLoc) <> "Clash can't symbolically evaluate casts"
, "Please file an issue at https://github.com/clash-lang/clash-compiler/issues"
]
stepCast x ty1 ty2 m _ = Just . setTerm x $ stackPush (Castish ty1 ty2) m

stepTick :: TickInfo -> Term -> Step
stepTick tick x m _ =
Expand Down Expand Up @@ -356,7 +361,8 @@ unwind tcm m v = do
go (Instantiate ty) = return . instantiate v ty
go (PrimApply p tys vs tms) = mPrimUnwind m tcm p tys vs v tms
go (Scrutinise _ as) = return . scrutinise v as
go (Tickish _) = return . setTerm (valToTerm v)
go (Tickish t) = flip (unwind tcm) (TickValue t v)
go (Castish ty1 ty2) = flip (unwind tcm) (CastValue v ty1 ty2)

-- | Update the Heap with the evaluated term
update :: IdScope -> Id -> Value -> Machine -> Machine
Expand Down
3 changes: 3 additions & 0 deletions clash-lib/src/Clash/Core/Evaluator/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ data StackFrame
| PrimApply PrimInfo [Type] [Value] [Term]
| Scrutinise Type [Alt]
| Tickish TickInfo
| Castish Type Type
deriving Show

instance ClashPretty StackFrame where
Expand All @@ -134,6 +135,8 @@ instance ClashPretty StackFrame where
fromPpr (Case (Literal (CharLiteral '_')) a b)]
clashPretty (Tickish sp) =
hsep ["Tick", fromPpr sp]
clashPretty (Castish ty1 ty2) =
hsep ["Cast", fromPpr ty1, fromPpr ty2]

-- Values
data Value
Expand Down
6 changes: 3 additions & 3 deletions clash-lib/src/Clash/Core/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ coreView1 tcMap ty = case tyView ty of
| nameOcc tcNm == "Clash.Signal.BiSignal.BiSignalOut"
, [_,_,_,elTy] <- args
-> Just elTy
| nameOcc tcNm == "Clash.Signal.Internal.Signal"
, [_,elTy] <- args
-> Just elTy
-- | nameOcc tcNm == "Clash.Signal.Internal.Signal"
-- , [_,elTy] <- args
-- -> Just elTy
| otherwise
-> case tcMap `lookupUniqMap'` tcNm of
AlgTyCon {algTcRhs = (NewTyCon _ nt)}
Expand Down
18 changes: 11 additions & 7 deletions clash-lib/src/Clash/Core/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ import Clash.Core.Type
coreView, coreView1, isFunTy, isPolyFunCoreTy, mkFunTy, splitFunTy, tyView,
undefinedTy, isTypeFamilyApplication)
import Clash.Core.TyCon
(TyConMap, tyConDataCons)
(TyConMap, TyConName, tyConDataCons)
import Clash.Core.TysPrim (typeNatKind)
import Clash.Core.Var
(Id, TyVar, Var (..), isLocalId, mkLocalId, mkTyVar)
Expand Down Expand Up @@ -995,15 +995,17 @@ shouldSplit
shouldSplit tcm (tyView -> TyConApp (nameOcc -> "Clash.Explicit.SimIO.SimIO") [tyArg]) =
-- We also look through `SimIO` to find things like Files
shouldSplit tcm tyArg
shouldSplit tcm ty = shouldSplit0 tcm (tyView (coreView tcm ty))
shouldSplit tcm ty = shouldSplit0 emptyUniqSet tcm (tyView (coreView tcm ty))

-- | Worker of 'shouldSplit', works on 'TypeView' instead of 'Type'
shouldSplit0
:: TyConMap
:: UniqSet TyConName
-> TyConMap
-> TypeView
-> Maybe (Term,[Type])
shouldSplit0 tcm (TyConApp tcNm tyArgs)
| Just tc <- lookupUniqMap tcNm tcm
shouldSplit0 seen tcm (TyConApp tcNm tyArgs)
| tcNm `notElemUniqSet` seen
, Just tc <- lookupUniqMap tcNm tcm
, [dc] <- tyConDataCons tc
, let dcArgs = substArgTys dc tyArgs
, let dcArgVs = map (tyView . coreView tcm) dcArgs
Expand All @@ -1012,8 +1014,10 @@ shouldSplit0 tcm (TyConApp tcNm tyArgs)
else
Nothing
where
seen1 = extendUniqSet seen tcNm

shouldSplitTy :: TypeView -> Bool
shouldSplitTy ty = isJust (shouldSplit0 tcm ty) || splitTy ty
shouldSplitTy ty = isJust (shouldSplit0 seen1 tcm ty) || splitTy ty

-- Hidden constructs (HiddenClock, HiddenReset, ..) don't need to be split
-- because KnownDomain will be filtered anyway during netlist generation due
Expand Down Expand Up @@ -1046,7 +1050,7 @@ shouldSplit0 tcm (TyConApp tcNm tyArgs)
]
splitTy _ = False

shouldSplit0 _ _ = Nothing
shouldSplit0 _ _ _ = Nothing

-- | Potentially split apart a list of function argument types. e.g. given:
--
Expand Down
6 changes: 6 additions & 0 deletions clash-lib/src/Clash/Netlist.hs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,9 @@ mkDeclarations'
-> Term
-- ^ RHS of the let-binder
-> NetlistMonad [Declaration]
mkDeclarations' declType bndr (collectTicks -> (Cast e _ _,ticks)) =
mkDeclarations' declType bndr (mkTicks e ticks)

mkDeclarations' _declType bndr (collectTicks -> (Var v,ticks)) =
withTicks ticks $ \tickDecls -> do
mkFunApp (id2identifier bndr) v [] tickDecls
Expand Down Expand Up @@ -734,6 +737,9 @@ mkExpr bbEasD declType bndr app =
decls <- concat <$> mapM (uncurry mkDeclarations) binders
(bodyE,bodyDecls) <- mkExpr bbEasD declType bndr (mkApps (mkTicks body ticks) args)
return (bodyE,netDecls ++ decls ++ bodyDecls)

Cast e0 _ _ | null args ->
mkExpr bbEasD declType bndr (mkTicks e0 ticks)
_ -> throw (ClashException sp ($(curLoc) ++ "Not in normal form: application of a Lambda-expression\n\n" ++ showPpr app) Nothing)

-- | Generate an expression that projects a field out of a data-constructor.
Expand Down
5 changes: 4 additions & 1 deletion clash-lib/src/Clash/Netlist/BlackBox.hs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ import Clash.Core.Type as C
(Type (..), ConstTy (..), TypeView (..), mkFunTy, splitFunTys, splitFunTy, tyView)
import Clash.Core.TyCon as C (TyConMap, tyConDataCons)
import Clash.Core.Util
(collectBndrs, inverseTopSortLetBindings, isFun, mkApps, splitShouldSplit, termType)
(collectBndrs, inverseTopSortLetBindings, isFun, mkApps, splitShouldSplit,
termType, mkTicks)
import Clash.Core.Var as V
(Id, Var (..), mkLocalId, modifyVarName)
import Clash.Core.VarEnv
Expand Down Expand Up @@ -188,6 +189,7 @@ isLiteral e = case collectArgs e of
(Data _, args) -> all (either isLiteral (const True)) args
(Prim _, args) -> all (either isLiteral (const True)) args
(C.Literal _,_) -> True
(Cast e0 _ _, args) -> all (either isLiteral (const True)) (Left e0:args)
_ -> False

mkArgument
Expand Down Expand Up @@ -239,6 +241,7 @@ mkArgument bndr e = do
(Case scrut ty' [alt],[],_) -> do
(projection,decls) <- mkProjection False (NetlistId bndr ty) scrut ty' alt
return ((projection,hwTy,False),decls)
(Cast e0 _ _,[],ticks) -> mkArgument bndr (mkTicks e0 ticks)
_ ->
return ((Identifier (error ($(curLoc) ++ "Forced to evaluate unexpected function argument: " ++ eTyMsg)) Nothing
,hwTy,False),[])
Expand Down
Loading

0 comments on commit 598039e

Please sign in to comment.