Skip to content

[feat] syntactic semantic tokens #4672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
# Compiler toolchain
hpkgs.ghc
hpkgs.haskell-language-server
pkgs.stack
pkgs.haskellPackages.cabal-install
# Dependencies needed to build some parts of Hackage
gmp zlib ncurses
Expand Down
42 changes: 42 additions & 0 deletions ghcide/src/Development/IDE/GHC/Compat/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -630,11 +630,53 @@ instance HasSrcSpan SrcSpan where
instance HasSrcSpan (SrcLoc.GenLocated SrcSpan a) where
getLoc = GHC.getLoc

#if MIN_VERSION_ghc(9,11,0)
instance HasSrcSpan (GHC.EpToken sym) where
getLoc = GHC.getHasLoc
instance HasSrcSpan (GHC.EpUniToken sym sym') where
getLoc = GHC.getHasLoc
#elif MIN_VERSION_ghc(9,9,0)
instance HasSrcSpan (GHC.EpToken sym) where
getLoc = GHC.getHasLoc . \case
GHC.NoEpTok -> Nothing
GHC.EpTok loc -> Just loc
instance HasSrcSpan (GHC.EpUniToken sym sym') where
getLoc = GHC.getHasLoc . \case
GHC.NoEpUniTok -> Nothing
GHC.EpUniTok loc _ -> Just loc
#endif

#if MIN_VERSION_ghc(9,9,0)
instance HasSrcSpan (EpAnn a) where
getLoc = GHC.getHasLoc
#endif

#if !MIN_VERSION_ghc(9,11,0)
instance HasSrcSpan GHC.AddEpAnn where
getLoc (GHC.AddEpAnn _ loc) = getLoc loc

instance HasSrcSpan GHC.EpaLocation where
#if MIN_VERSION_ghc(9,9,0)
getLoc loc = GHC.getHasLoc loc
#else
getLoc loc = case loc of
GHC.EpaSpan span bufspan -> RealSrcSpan span $ case bufspan of Strict.Nothing -> Nothing; Strict.Just a -> Just a
GHC.EpaDelta {} -> panic "compiler inserted epadelta in EpaLocation"
#endif
#endif

instance HasSrcSpan GHC.LEpaComment where
#if MIN_VERSION_ghc(9,9,0)
getLoc :: GHC.LEpaComment -> SrcSpan
getLoc (GHC.L l _) = case l of
SrcLoc.EpaDelta {} -> panic "compiler inserted epadelta into NoCommentsLocation"
SrcLoc.EpaSpan span -> span
#else
getLoc :: GHC.LEpaComment -> SrcSpan
getLoc c = case c of
SrcLoc.L (GHC.Anchor realSpan _) _ -> RealSrcSpan realSpan Nothing
#endif

#if MIN_VERSION_ghc(9,9,0)
instance HasSrcSpan (SrcLoc.GenLocated (EpAnn ann) a) where
getLoc (L l _) = getLoc l
Expand Down
1 change: 1 addition & 0 deletions haskell-language-server.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,7 @@ library hls-semantic-tokens-plugin
, containers
, extra
, text-rope
, ghc
, mtl >= 2.2
, ghcide == 2.11.0.0
, hls-plugin-api == 2.11.0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ descriptor recorder plId =
{ Ide.Types.pluginHandlers =
mkPluginHandler SMethod_TextDocumentSemanticTokensFull (Internal.semanticTokensFull recorder)
<> mkPluginHandler SMethod_TextDocumentSemanticTokensFullDelta (Internal.semanticTokensFullDelta recorder),
Ide.Types.pluginRules = Internal.getSemanticTokensRule recorder,
Ide.Types.pluginRules = Internal.getSemanticTokensRule recorder <> Internal.getSyntacticTokensRule recorder,
pluginConfigDescriptor =
defaultConfigDescriptor
{ configInitialGenericConfig = (configInitialGenericConfig defaultConfigDescriptor) {plcGlobalOn = False}
{ configInitialGenericConfig = (configInitialGenericConfig defaultConfigDescriptor) {plcGlobalOn = True}
, configCustomConfig = mkCustomConfig Internal.semanticConfigProperties
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE LiberalTypeSynonyms #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnicodeSyntax #-}

-- |
-- This module provides the core functionality of the plugin.
module Ide.Plugin.SemanticTokens.Internal (semanticTokensFull, getSemanticTokensRule, semanticConfigProperties, semanticTokensFullDelta) where
module Ide.Plugin.SemanticTokens.Internal (semanticTokensFull, getSemanticTokensRule, getSyntacticTokensRule, semanticConfigProperties, semanticTokensFullDelta) where

import Control.Concurrent.STM (stateTVar)
import Control.Concurrent.STM.Stats (atomically)
Expand All @@ -20,31 +27,39 @@ import Control.Monad.Except (ExceptT, liftEither,
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Except (runExceptT)
import Control.Monad.Trans.Maybe
import Data.Data (Data (..))
import Data.List
import qualified Data.Map.Strict as M
import Data.Maybe
import Data.Text (Text)
import qualified Data.Text as T
import Development.IDE (Action,
GetDocMap (GetDocMap),
GetHieAst (GetHieAst),
GetParsedModuleWithComments (..),
HieAstResult (HAR, hieAst, hieModule, refMap),
IdeResult, IdeState,
Priority (..),
Recorder, Rules,
WithPriority,
cmapWithPrio, define,
fromNormalizedFilePath,
hieKind)
hieKind,
srcSpanToRange,
useWithStale)
import Development.IDE.Core.PluginUtils (runActionE, useE,
useWithStaleE)
import Development.IDE.Core.PositionMapping
import Development.IDE.Core.Rules (toIdeResult)
import Development.IDE.Core.RuleTypes (DocAndTyThingMap (..))
import Development.IDE.Core.Shake (ShakeExtras (..),
getShakeExtras,
getVirtualFile)
import Development.IDE.GHC.Compat hiding (Warning)
import Development.IDE.GHC.Compat.Util (mkFastString)
import GHC.Parser.Annotation
import Ide.Logger (logWith)
import Ide.Plugin.Error (PluginError (PluginInternalError),
import Ide.Plugin.Error (PluginError (PluginInternalError, PluginRuleFailed),
getNormalizedFilePathE,
handleMaybe,
handleMaybeM)
Expand All @@ -58,10 +73,17 @@ import qualified Language.LSP.Protocol.Lens as L
import Language.LSP.Protocol.Message (MessageResult,
Method (Method_TextDocumentSemanticTokensFull, Method_TextDocumentSemanticTokensFullDelta))
import Language.LSP.Protocol.Types (NormalizedFilePath,
Range,
SemanticTokens,
fromNormalizedFilePath,
type (|?) (InL, InR))
import Prelude hiding (span)
import qualified StmContainers.Map as STM
import Type.Reflection (Typeable, eqTypeRep,
pattern App,
type (:~~:) (HRefl),
typeOf, typeRep,
withTypeable)


$mkSemanticConfigFunctions
Expand All @@ -75,8 +97,17 @@ computeSemanticTokens recorder pid _ nfp = do
config <- lift $ useSemanticConfigAction pid
logWith recorder Debug (LogConfig config)
semanticId <- lift getAndIncreaseSemanticTokensId
(RangeHsSemanticTokenTypes {rangeSemanticList}, mapping) <- useWithStaleE GetSemanticTokens nfp
withExceptT PluginInternalError $ liftEither $ rangeSemanticsSemanticTokens semanticId config mapping rangeSemanticList

tokenList <- sortOn fst <$> do
rangesyntacticTypes <- lift $ useWithStale GetSyntacticTokens nfp
rangesemanticTypes <- lift $ useWithStale GetSemanticTokens nfp
let mk w u (toks, mapping) = map (\(ran, tok) -> (toCurrentRange mapping ran, w tok)) $ u toks
maybeToExceptT (PluginRuleFailed "no syntactic nor semantic tokens") $ hoistMaybe $
(mk HsSyntacticTokenType rangeSyntacticList <$> rangesyntacticTypes)
<> (mk HsSemanticTokenType rangeSemanticList <$> rangesemanticTypes)

-- NOTE: rangeSemanticsSemanticTokens actually assumes that the tokesn are in order. that means they have to be sorted by position
withExceptT PluginInternalError $ liftEither $ rangeSemanticsSemanticTokens semanticId config tokenList

semanticTokensFull :: Recorder (WithPriority SemanticLog) -> PluginMethodHandler IdeState 'Method_TextDocumentSemanticTokensFull
semanticTokensFull recorder state pid param = runActionE "SemanticTokens.semanticTokensFull" state computeSemanticTokensFull
Expand Down Expand Up @@ -130,6 +161,132 @@ getSemanticTokensRule recorder =
let hsFinder = idSemantic getTyThingMap (hieKindFunMasksKind hieKind) refMap
return $ computeRangeHsSemanticTokenTypeList hsFinder virtualFile ast

getSyntacticTokensRule :: Recorder (WithPriority SemanticLog) -> Rules ()
getSyntacticTokensRule recorder =
define (cmapWithPrio LogShake recorder) $ \GetSyntacticTokens nfp -> handleError recorder $ do
(parsedModule, _) <- withExceptT LogDependencyError $ useWithStaleE GetParsedModuleWithComments nfp
pure $ computeRangeHsSyntacticTokenTypeList parsedModule

astTraversalWith :: forall b r. Data b => b -> (forall a. Data a => a -> [r]) -> [r]
astTraversalWith ast f = mconcat $ flip gmapQ ast \y -> f y <> astTraversalWith y f

{-# inline extractTyToTyToTy #-}
extractTyToTyToTy :: forall f a. (Typeable f, Data a) => a -> Maybe (forall r. (forall b c. (Typeable b, Typeable c) => f b c -> r) -> r)
extractTyToTyToTy node
| App (App conRep argRep1) argRep2 <- typeOf node
, Just HRefl <- eqTypeRep conRep (typeRep @f)
= Just $ withTypeable argRep1 $ withTypeable argRep2 \k -> k node
| otherwise = Nothing

{-# inline extractTyToTy #-}
extractTyToTy :: forall f a. (Typeable f, Data a) => a -> Maybe (forall r. (forall b. Typeable b => f b -> r) -> r)
extractTyToTy node
| App conRep argRep <- typeOf node
, Just HRefl <- eqTypeRep conRep (typeRep @f)
= Just $ withTypeable argRep \k -> k node
| otherwise = Nothing

{-# inline extractTy #-}
extractTy :: forall b a. (Typeable b, Data a) => a -> Maybe b
extractTy node
| Just HRefl <- eqTypeRep (typeRep @b) (typeOf node)
= Just node
| otherwise = Nothing

computeRangeHsSyntacticTokenTypeList :: ParsedModule -> RangeHsSyntacticTokenTypes
computeRangeHsSyntacticTokenTypeList ParsedModule {pm_parsed_source} =
let toks = astTraversalWith pm_parsed_source \node -> mconcat
[
#if MIN_VERSION_ghc(9,9,0)
maybeToList $ mkFromLocatable TKeyword . (\k -> k \x k' -> k' x) =<< extractTyToTy @EpToken node,
maybeToList $ mkFromLocatable TKeyword . (\k -> k \x k' -> k' x) =<< extractTyToTyToTy @EpUniToken node,
do
AnnContext {ac_darrow, ac_open, ac_close} <- maybeToList $ extractTy node
let mkFromTok :: (Foldable f, HasSrcSpan a) => f a -> [(Range,HsSyntacticTokenType)]
mkFromTok = foldMap (\tok -> maybeToList $ mkFromLocatable TKeyword \k -> k tok)
mconcat
#if MIN_VERSION_ghc(9,11,0)
[ mkFromTok ac_darrow
#else
[ foldMap (\(_, loc) -> maybeToList $ mkFromLocatable TKeyword \k -> k loc) ac_darrow
#endif
, mkFromTok ac_open
, mkFromTok ac_close
],
#endif

#if !MIN_VERSION_ghc(9,11,0)
maybeToList $ mkFromLocatable TKeyword . (\x k -> k x) =<< extractTy @AddEpAnn node,
do
EpAnnImportDecl i p s q pkg a <- maybeToList $ extractTy @EpAnnImportDecl node
mapMaybe (mkFromLocatable TKeyword . (\x k -> k x)) $ catMaybes $ [Just i, s, q, pkg, a] <> foldMap (\(l, l') -> [Just l, Just l']) p,
#endif
maybeToList do
comment <- extractTy @LEpaComment node
#if !MIN_VERSION_ghc(9,7,0)
-- NOTE: on ghc 9.6 there's an empty comment that is supposed to
-- located the end of file
case comment of
L _ (EpaComment {ac_tok = EpaEofComment}) -> Nothing
_ -> pure ()
#endif
mkFromLocatable TComment \k -> k comment,
do
L loc expr <- maybeToList $ extractTy @(LHsExpr GhcPs) node
let fromSimple = maybeToList . flip mkFromLocatable \k -> k loc
case expr of
HsOverLabel {} -> fromSimple TStringLit
HsOverLit _ (OverLit _ lit) -> fromSimple case lit of
HsIntegral {} -> TNumberLit
HsFractional {} -> TNumberLit

HsIsString {} -> TStringLit
HsLit _ lit -> fromSimple case lit of
-- NOTE: unfortunately, lsp semantic tokens doesn't have a notion of char literals
HsChar {} -> TStringLit
HsCharPrim {} -> TStringLit

HsInt {} -> TNumberLit
HsInteger {} -> TNumberLit
HsIntPrim {} -> TNumberLit
HsWordPrim {} -> TNumberLit
#if MIN_VERSION_ghc(9,9,0)
HsWord8Prim {} -> TNumberLit
HsWord16Prim {} -> TNumberLit
HsWord32Prim {} -> TNumberLit
#endif
HsWord64Prim {} -> TNumberLit
#if MIN_VERSION_ghc(9,9,0)
HsInt8Prim {} -> TNumberLit
HsInt16Prim {} -> TNumberLit
HsInt32Prim {} -> TNumberLit
#endif
HsInt64Prim {} -> TNumberLit
HsFloatPrim {} -> TNumberLit
HsDoublePrim {} -> TNumberLit
HsRat {} -> TNumberLit

HsString {} -> TStringLit
HsStringPrim {} -> TStringLit
#if MIN_VERSION_ghc(9,11,0)
HsMultilineString {} -> TStringLit
#endif
HsGetField _ _ field -> maybeToList $ mkFromLocatable TRecordSelector \k -> k field
#if MIN_VERSION_ghc(9,11,0)
HsProjection _ projs -> foldMap (\dotFieldOcc -> maybeToList $ mkFromLocatable TRecordSelector \k -> k dotFieldOcc.dfoLabel) projs
#else
HsProjection _ projs -> foldMap (\proj -> maybeToList $ mkFromLocatable TRecordSelector \k -> k proj) projs
#endif
_ -> []
]
in RangeHsSyntacticTokenTypes toks

{-# inline mkFromLocatable #-}
mkFromLocatable
:: HsSyntacticTokenType
-> (forall r. (forall a. HasSrcSpan a => a -> r) -> r)
-> Maybe (Range, HsSyntacticTokenType)
mkFromLocatable tt w = w \tok -> let mrange = srcSpanToRange $ getLoc tok in fmap (, tt) mrange

-- taken from /haskell-language-server/plugins/hls-code-range-plugin/src/Ide/Plugin/CodeRange/Rules.hs

Expand Down
Loading
Loading