diff --git a/.gitignore b/.gitignore index 2851260..82efca7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ dist-* result -cabal.project.local \ No newline at end of file +cabal.project.local +tmp* +.DS_STORE diff --git a/api-contract/.gitignore b/api-contract/.gitignore new file mode 100644 index 0000000..efe3757 --- /dev/null +++ b/api-contract/.gitignore @@ -0,0 +1,5 @@ +dist* +result +dist-newstyle +result +tmp* \ No newline at end of file diff --git a/api-contract/.juspay/api-contract/test/Main.hs.yaml b/api-contract/.juspay/api-contract/test/Main.hs.yaml new file mode 100644 index 0000000..4cb2eca --- /dev/null +++ b/api-contract/.juspay/api-contract/test/Main.hs.yaml @@ -0,0 +1,59 @@ +RefundAttempt: + caseType: null + dataConstructors: + RefundAttempt: + fields': + created: Text + error_code: Maybe Value + error_message: Maybe Text + id': Maybe Text + last_modified: Maybe Text + ref: Maybe Text + sumTypes: [] + instances: + parseJSON: + fieldsList: + - id' + - created + - ref + - error_message + - error_code + - last_modified + typeOfInstance: Custom + toEncoding: + fieldsList: + - id' + - created + - ref + - error_message + - error_code + - last_modified + typeOfInstance: Custom + toJSON: + fieldsList: [] + typeOfInstance: Derived + typeKind: data + +RefundAttempt': + caseType: null + dataConstructors: + RefundAttempt': + fields': + created'': Text + error_code'': Maybe Value + error_message'': Maybe Text + id'': Maybe Text + last_modified'': Maybe Text + ref'': Maybe Text + sumTypes: [] + instances: + parseJSON: + fieldsList: [] + typeOfInstance: Derived + toEncoding: + fieldsList: [] + typeOfInstance: Derived + toJSON: + fieldsList: [] + typeOfInstance: Derived + typeKind: data \ No newline at end of file diff --git a/api-contract/CHANGELOG.md b/api-contract/CHANGELOG.md new file mode 100644 index 0000000..96ba851 --- /dev/null +++ b/api-contract/CHANGELOG.md @@ -0,0 +1,5 @@ +# Revision history for api-contract + +## 0.1.0.0 -- YYYY-mm-dd + +* First version. Released on an unsuspecting world. diff --git a/api-contract/LICENSE b/api-contract/LICENSE new file mode 100644 index 0000000..0a9095c --- /dev/null +++ b/api-contract/LICENSE @@ -0,0 +1,20 @@ +Copyright (c) 2024 eswar2001 + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/api-contract/api-contract.cabal b/api-contract/api-contract.cabal new file mode 100644 index 0000000..798e25f --- /dev/null +++ b/api-contract/api-contract.cabal @@ -0,0 +1,86 @@ +cabal-version: 3.0 +name: api-contract +version: 0.1.0.0 +-- synopsis: +-- description: +license: MIT +license-file: LICENSE +author: eswar2001 +maintainer: eswar.tadiparth@juspay.in +-- copyright: +category: Control +build-type: Simple +extra-doc-files: CHANGELOG.md +data-files: .juspay/api-contract/**/*.hs.yaml +-- extra-source-files: + +flag enable-isolation + description: set this flag to enable these plugins Data.Record.Plugin , Data.Record.Anon.Plugin , Data.Record.Plugin.HasFieldPattern + default: False + manual: True + +common warnings + if flag(enable-isolation) + cpp-options: -DENABLE_ISOLATION + ghc-options: -Wall + +library + import: warnings + exposed-modules: + ApiContract.Plugin + , ApiContract.Types + -- other-modules: + -- other-extensions: + build-depends: + bytestring + , containers + , filepath + , ghc + , unordered-containers + , aeson + , directory + , extra + , aeson-pretty + , base + , text + , base64-bytestring + , optparse-applicative + , deepseq + , time + , async + , cryptonite + , hasbolt + , universum + , streamly-core + , data-default + , large-records + , large-generics + , large-anon + , ghc-hasfield-plugin + , record-dot-preprocessor + , ghc-tcplugin-api + , typelet + , record-hasfield + , binary + , references + , uniplate + , yaml + + hs-source-dirs: src + default-language: Haskell2010 + +test-suite api-contract-test + import: warnings + default-language: Haskell2010 + ghc-options: -fplugin=ApiContract.Plugin + -- other-modules: + -- other-extensions: + type: exitcode-stdio-1.0 + hs-source-dirs: test + main-is: Main.hs + build-depends: + base + , api-contract + , aeson + , record-hasfield + , text diff --git a/api-contract/src/ApiContract/Plugin.hs b/api-contract/src/ApiContract/Plugin.hs new file mode 100644 index 0000000..174e441 --- /dev/null +++ b/api-contract/src/ApiContract/Plugin.hs @@ -0,0 +1,595 @@ + +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE RecordWildCards #-} + +module ApiContract.Plugin where + +#if __GLASGOW_HASKELL__ >= 900 +import Language.Haskell.Syntax.Type +import GHC.Hs.Extension () +import GHC.Parser.Annotation () +import GHC.Utils.Outputable (docToSDoc) +import qualified Data.IntMap.Internal as IntMap +import GHC.Data.Bag +import GHC.HsToCore +import GHC.Types.SrcLoc +import GHC.Driver.Errors +import GHC.Unit.Types +import GHC.Driver.Backpack.Syntax +import GHC.Unit.Info +-- import Streamly.Internal.Data.Stream (fromList,mapM_,mapM,toList) +import GHC hiding (typeKind) +import GHC.Driver.Plugins (Plugin(..),CommandLineOption,defaultPlugin,PluginRecompile(..)) +import GHC.Driver.Env +import GHC.Tc.Types +import GHC.Unit.Module.ModSummary +import GHC.Utils.Outputable (showSDocUnsafe,ppr,SDoc,Outputable,reallyAlwaysQualify) +import GHC.Data.Bag (bagToList) +import GHC.Types.Name hiding (varName) +import GHC.Types.Var +import qualified Data.Aeson.KeyMap as HM +import qualified Data.Aeson.Key as HM +import GHC.Core.Opt.Monad +import GHC.Rename.HsType +import qualified GHC.Tc.Utils.Monad as TCError +import qualified GHC.Types.SourceError as ParseError +import qualified GHC.Types.Error as ParseError +import GHC.Types.Name.Reader ( rdrNameOcc ,rdrNameSpace) +import GHC.Core.TyCo.Rep +import GHC.Data.FastString +import GHC.IO (unsafePerformIO) +import qualified GHC.Utils.Ppr as Pretty +#else +import FastString +import CoreMonad (CoreM, CoreToDo (CoreDoPluginPass), liftIO) +import CoreSyn ( + AltCon (..), + Bind (NonRec, Rec), + CoreBind, + CoreExpr, + Expr (..), + mkStringLit + ) +import TyCoRep +import GHC.IO (unsafePerformIO) +import GHC.Hs +import GHC.Hs.Decls +import GhcPlugins ( + CommandLineOption,Arg (..), + HsParsedModule(..), + Hsc, + Name,SDoc,DataCon,DynFlags,ModSummary(..),TyCon, + Literal (..),typeEnvElts, + ModGuts (mg_binds, mg_loc, mg_module),showSDoc, + Module (moduleName),tyConKind, + NamedThing (getName),getDynFlags,tyConDataCons,dataConOrigArgTys,dataConName, + Outputable (..),dataConFieldLabels,PluginRecompile(..), + Plugin (..), + Var,flLabel,dataConRepType, + coVarDetails, + defaultPlugin, + idName, + mkInternalName, + mkLitString, + mkLocalVar, + mkVarOcc, + moduleNameString, + nameStableString, + noCafIdInfo, + purePlugin, + showSDocUnsafe, + tyVarKind, + unpackFS, + tyConName, + msHsFilePath + ) +import Id (isExportedId,idType) +import Name (getSrcSpan) +import SrcLoc +import Unique (mkUnique) +import Var (isLocalId,varType) +import TcRnTypes +import TcRnMonad +import DataCon +#endif + +import Control.Reference (biplateRef, (^?)) +import ApiContract.Types +-- import Data.Aeson +import Data.List.Extra (intercalate, isSuffixOf, replace, splitOn,groupBy) +import Data.List ( sortBy, intercalate ,foldl') +import qualified Data.Map as Map +import Data.Text (Text, concat, isInfixOf, pack, unpack) +import qualified Data.Text as T +import Data.Text.Encoding (decodeUtf8, encodeUtf8) +import Data.Time +import Data.Map (Map) +import Data.Data +import Data.Maybe (catMaybes,isJust,fromJust) +import Control.Monad.IO.Class (liftIO) +import System.IO (writeFile) +import Streamly.Internal.Data.Stream hiding (concatMap, init, length, map, splitOn,foldl',intercalate) +import System.Directory (createDirectoryIfMissing, removeFile,doesFileExist) +import System.Directory.Internal.Prelude hiding (mapM, mapM_) +import Prelude hiding (id, mapM, mapM_) +import Control.Exception (evaluate) +import qualified Data.Record.Plugin as DRP +import qualified Data.Record.Anon.Plugin as DRAP +import qualified Data.Record.Plugin.HasFieldPattern as DRPH +import qualified RecordDotPreprocessor as RDP +import qualified Data.Yaml as YAML +import Control.Monad (foldM) +import Data.Char +import qualified Data.ByteString as DBS +import Data.Bool (bool) +import GHC (noLoc) +import Data.Aeson.Encode.Pretty (mempty) +import GHC.ExecutionStack (Location(srcLoc)) +import Streamly.Internal.Data.Parser (ParseError(ParseError)) +import ApiContract.Types +import Data.List (nub) +import Control.Concurrent (MVar,putMVar,takeMVar,readMVar,newMVar) +import Data.ByteString (elem) + +-- ENABLE_ISOLATION +#if defined(ENABLE_LR_PLUGINS) +plugin :: Plugin +plugin = (defaultPlugin{ + -- installCoreToDos = install + pluginRecompile = (\_ -> return NoForceRecompile) + , parsedResultAction = collectTypeInfoParser + , typeCheckResultAction = collectInstanceInfo + }) +#if defined(ENABLE_LR_PLUGINS) + <> DRP.plugin + <> DRAP.plugin + <> DRPH.plugin +#endif + <> RDP.plugin + +instance Semigroup Plugin where + p <> q = defaultPlugin { + parsedResultAction = \args summary -> + parsedResultAction p args summary + >=> parsedResultAction q args summary + , typeCheckResultAction = \args summary -> + typeCheckResultAction p args summary + >=> typeCheckResultAction q args summary + , pluginRecompile = \args -> + (<>) + <$> pluginRecompile p args + <*> pluginRecompile q args + , tcPlugin = \args -> + case (tcPlugin p args, tcPlugin q args) of + (Nothing, Nothing) -> Nothing + (Just tp, Nothing) -> Just tp + (Nothing, Just tq) -> Just tq + (Just (TcPlugin tcPluginInit1 tcPluginSolve1 tcPluginStop1), Just (TcPlugin tcPluginInit2 tcPluginSolve2 tcPluginStop2)) -> Just $ TcPlugin + { tcPluginInit = do + ip <- tcPluginInit1 + iq <- tcPluginInit2 + return (ip, iq) + , tcPluginSolve = \(sp,sq) given derived wanted -> do + solveP <- tcPluginSolve1 sp given derived wanted + solveQ <- tcPluginSolve2 sq given derived wanted + return $ combineTcPluginResults solveP solveQ + , tcPluginStop = \(solveP,solveQ) -> do + tcPluginStop1 solveP + tcPluginStop2 solveQ + } + } + +combineTcPluginResults :: TcPluginResult -> TcPluginResult -> TcPluginResult +combineTcPluginResults resP resQ = + case (resP, resQ) of + (TcPluginContradiction ctsP, TcPluginContradiction ctsQ) -> + TcPluginContradiction (ctsP ++ ctsQ) + + (TcPluginContradiction ctsP, TcPluginOk _ _) -> + TcPluginContradiction ctsP + + (TcPluginOk _ _, TcPluginContradiction ctsQ) -> + TcPluginContradiction ctsQ + + (TcPluginOk solvedP newP, TcPluginOk solvedQ newQ) -> + TcPluginOk (solvedP ++ solvedQ) (newP ++ newQ) + + +instance Monoid Plugin where + mempty = defaultPlugin + +instance Outputable Void where + +#else +plugin :: Plugin +plugin = (defaultPlugin{ + pluginRecompile = (\_ -> return NoForceRecompile) + , parsedResultAction = collectTypeInfoParser + , typeCheckResultAction = collectInstanceInfo + }) +#endif + +pprTyCon :: Name -> SDoc +pprTyCon = ppr + +pprDataCon :: Name -> SDoc +pprDataCon = ppr + +instanceToAdd :: [String] +instanceToAdd = ["deriveJSON","FromMultipart Mem","FromHttpApiData","ToHttpApiData","MimeRender JSON","FromForm","ToJSON","FromJSON","toJSON","fromJSON","toEncoding","toXML","ToXml","toXml","fromXml","FromXml","ToCybsXml","toCybsXml"] + +collectTypeInfoParser :: [CommandLineOption] -> ModSummary -> HsParsedModule -> Hsc HsParsedModule +collectTypeInfoParser opts modSummary hpm = do + let prefixPath = "./.juspay/api-contract/" + moduleName' = moduleNameString $ moduleName $ ms_mod modSummary + modulePath = prefixPath <> msHsFilePath modSummary + moduleSrcSpan = mkFileSrcSpan $ ms_location modSummary + hm_module = unLoc $ hpm_module hpm + path = (intercalate "/" . init . splitOn "/") modulePath + liftIO $ createDirectoryIfMissing True path + typesInThisModule <- liftIO $ toList $ mapM (pure . getTypeInfo) (fromList $ hsmodDecls hm_module) + typesToInstancesPresent <- liftIO $ toList $ mapM (getInstancesInfo) (fromList $ hsmodDecls hm_module) + when (generateTypesRules) $ liftIO $ print $ Data.List.nub $ Prelude.concat typesToInstancesPresent + (shouldAddTypes :: [String]) <- foldM (\acc (inst,type_) -> if inst `Prelude.elem` instanceToAdd then pure $ acc <> [type_] else pure $ acc) [] (Data.List.nub $ Prelude.concat typesToInstancesPresent) + let (srcSpansHM :: HM.KeyMap SrcSpan) = HM.fromList $ map (\(srcSpan,a,_) -> (HM.fromString a, srcSpan)) $ Prelude.concat typesInThisModule + (typeVsFields :: HM.KeyMap TypeRule) = HM.fromList $ Prelude.filter (\(typeName,_) -> (HM.toString typeName) `Prelude.elem` shouldAddTypes) $ map (\(_,a,b) -> (HM.fromString a, b)) $ Prelude.concat typesInThisModule + isOldFile <- liftIO $ doesFileExist (modulePath <> ".yaml") + if generateTypesRules || (not $ isOldFile) + then do + newModuleList <- liftIO $ takeMVar newModuleListMvar + liftIO $ putMVar newModuleListMvar (newModuleList <> [modulePath]) + liftIO $ DBS.writeFile (modulePath <> ".yaml") (YAML.encode typeVsFields) + else do + (eitherTypeRules :: Either (YAML.ParseException) (HM.KeyMap TypeRule)) <- liftIO $ fetchRules (modulePath <> ".yaml") + case eitherTypeRules of + Left err -> ParseError.throwErrors + $ listToBag + $ [ParseError.mkErr moduleSrcSpan reallyAlwaysQualify (ParseError.mkDecorated [docToSDoc $ Pretty.text $ (modulePath <> ".yaml") <> " is missing for this module : " <> show err])] + Right typeRules -> do + errors :: [[(SrcSpan,ApiContractError)]] <- liftIO $ toList $ mapM (\(typeName,rules) -> do + case HM.lookup typeName typeVsFields of + Just typeRule -> do + let srcSpan = fromJust $ HM.lookup typeName srcSpansHM + errorList <- runFieldNameAndTypeRule (HM.toString typeName) (caseType rules) (dataConstructors rules) (dataConstructors typeRule) + pure $ map (\x -> (srcSpan,x)) $ errorList + Nothing -> pure [(moduleSrcSpan,(MISSING_TYPE_CODE (HM.toString typeName)))] + ) (fromList $ HM.toList typeRules) + let missingTypesInRulesWithAeson = mempty--map (\x -> if HM.member (HM.fromString x) typeRules then mempty else [(moduleSrcSpan,(MISSING_TYPE_IN_RULE (x) (maybe (mempty) (\y -> (unpack . decodeUtf8 . YAML.encode) $ Map.fromList [(x,y)]) $ HM.lookup (HM.fromString x) typeVsFields)))] ) shouldAddTypes + errorsNubbed :: [(SrcSpan,ApiContractError)] <- pure $ Data.List.nub $ Prelude.concat (errors <> missingTypesInRulesWithAeson) + if (not $ Prelude.null $ errorsNubbed) + then do + errorMessages <- pure $ listToBag $ map (\(srcSpan,errorMessage) -> ParseError.mkErr srcSpan reallyAlwaysQualify (ParseError.mkDecorated [docToSDoc $ Pretty.text $ generateErrorMessage (modulePath <> ".yaml") errorMessage])) errorsNubbed + ParseError.throwErrors errorMessages + else pure () + pure hpm + +newModuleListMvar :: MVar [String] +{-# NOINLINE newModuleListMvar #-} +newModuleListMvar = unsafePerformIO (newMVar mempty) + +collectInstanceInfo :: [CommandLineOption] -> ModSummary -> TcGblEnv -> TcM TcGblEnv +collectInstanceInfo opts modSummary tcEnv = do + let prefixPath = "./.juspay/api-contract/" + moduleName' = moduleNameString $ moduleName $ ms_mod modSummary + modulePath = prefixPath <> msHsFilePath modSummary + moduleSrcSpan = mkFileSrcSpan $ ms_location modSummary + path = (intercalate "/" . init . splitOn "/") modulePath + typeInstances <- liftIO $ toList $ mapM processInstance (fromList $ bagToList $ tcg_binds tcEnv) + (instanceSrcSpansHM :: HM.KeyMap SrcSpan) <- pure $ HM.fromList $ map (\(srcSpan,typeName,instanceName,_) -> (HM.fromString (typeName <> "--" <> instanceName), srcSpan)) $ Prelude.concat typeInstances + (instanceTypeHM :: HM.KeyMap InstanceFromTC) <- pure $ HM.fromList $ map (\(_,typeName,instanceName,x) -> (HM.fromString (typeName <> "--" <> instanceName), x)) $ Prelude.concat typeInstances + (eitherTypeRules :: Either (YAML.ParseException) (HM.KeyMap TypeRule)) <- liftIO $ fetchRules (modulePath <> ".yaml") + newModuleList <- liftIO $ readMVar newModuleListMvar + case eitherTypeRules of + Left err -> TCError.addErrs $ [(moduleSrcSpan,docToSDoc $ Pretty.text $ (modulePath <> ".yaml") <> " " <> "is missing for this module :" <> show err)] + Right typeRules -> do + let updatedTypesRules = foldl' (\hm (_,typeName,instanceName,x) -> + case HM.lookup (HM.fromString typeName) hm of + Just v -> HM.insert (HM.fromString typeName) (v{instances = Map.insert instanceName x (instances v)}) hm + Nothing -> hm + ) typeRules $ Prelude.concat typeInstances + isOldFile <- liftIO $ doesFileExist (modulePath <> ".yaml") + if generateTypesRules || (not $ isOldFile) || (modulePath `Prelude.elem` newModuleList) + then do + when (generateTypesRules) $ liftIO $ print "dumping rules" + liftIO $ DBS.writeFile (modulePath <> ".yaml") (YAML.encode updatedTypesRules) + else do + errors :: [[(SrcSpan,ApiContractError)]] <- liftIO $ toList $ mapM (\(typeName,rules) -> + let (instancesMap :: Map.Map String InstanceFromTC) = instances rules + errors = map (\(x,ruleInst) -> + case HM.lookup (typeName <> (HM.fromString "--") <> (HM.fromString x)) instanceTypeHM of + Just inst -> + let instanceSpan = fromJust $ HM.lookup (typeName <> (HM.fromString "--") <> (HM.fromString x)) instanceSrcSpansHM + typeOfInstanceCheck = + if (typeOfInstance inst) == (typeOfInstance ruleInst) + then mempty + else [(instanceSpan,(TYPE_OF_INSTANCE_CHANGED (HM.toString typeName) x (typeOfInstance ruleInst)))] + fieldListCheck = + map (\y -> + if y `Prelude.elem` fieldsList inst + then mempty + else [(instanceSpan,(MISSING_FIELD_IN_INSTANCE_CODE (HM.toString typeName) x y))] + ) (fieldsList ruleInst) + fieldListCheckInverse = + map (\y -> + if y `Prelude.elem` fieldsList ruleInst + then mempty + else [(instanceSpan,(MISSING_FIELD_IN_INSTANCE_RULES (HM.toString typeName) x y))] + ) (fieldsList inst) + in typeOfInstanceCheck <> (Prelude.concat $ fieldListCheck <> fieldListCheckInverse) + Nothing -> [(moduleSrcSpan, (MISSING_INSTANCE_IN_CODE (HM.toString typeName) x))] + ) $ Map.toList instancesMap + in pure $ Prelude.concat errors + ) (fromList $ HM.toList typeRules) + let missingInstanceConstraintsInRules = map (\(k,v) -> + let typeName = HM.fromString $ Prelude.head $ splitOn "--" $ HM.toString k + instanceName = Prelude.last $ splitOn "--" $ HM.toString k + in case HM.lookup typeName typeRules of + Just rules -> + case Map.lookup instanceName $ instances rules of + Just val -> mempty + Nothing -> [(moduleSrcSpan, (MISSING_INSTANCE_IN_RULES (HM.toString typeName) (instanceName) (maybe (mempty) (\y -> (unpack . decodeUtf8 . YAML.encode) $ Map.fromList [((HM.toString typeName),y)]) $ HM.lookup typeName updatedTypesRules)))] + Nothing -> mempty + ) $ HM.toList instanceTypeHM + errorsNubbed :: [(SrcSpan,ApiContractError)] <- pure $ Data.List.nub $ Prelude.concat (errors <> missingInstanceConstraintsInRules) + if (not $ Prelude.null $ errorsNubbed) + then do + TCError.addErrs $ map (\(srcSpan,errorMessage) -> (srcSpan,docToSDoc $ Pretty.text $ generateErrorMessage (modulePath <> ".yaml") errorMessage)) errorsNubbed + else pure () + pure tcEnv + +processInstance :: LHsBindLR GhcTc GhcTc -> IO [(SrcSpan,String,String,InstanceFromTC)] +processInstance (L l (FunBind _ id' matches _)) = do + let instanceFunctionName = replace "$_in$" "" $ nameStableString $ getName id' + stmts = (mg_alts matches) ^? biplateRef :: [LHsExpr GhcTc] + possibleFields = Data.List.nub $ map (\(_,val) -> getLit $ unXRec @(GhcTc) val) $ Prelude.filter (\(constr,_) -> constr `Prelude.elem` ["HsLit"] ) $ map ((\x -> (show $ toConstr $ unLoc x,x))) (stmts) + typeSignature = getAppliedOnTypeName instanceFunctionName $ varType (unLoc id') + if isJust typeSignature then + if Prelude.null possibleFields + then pure [(locA l,fromMaybe mempty typeSignature,instanceFunctionName,InstanceFromTC possibleFields Derived)] + else pure [(locA l,fromMaybe mempty typeSignature,instanceFunctionName,InstanceFromTC possibleFields Custom)] + else pure mempty +processInstance (L _ (VarBind{var_id = var, var_rhs = expr})) = pure mempty +processInstance (L _ (PatBind{pat_lhs = pat, pat_rhs = expr})) = pure mempty +processInstance (L _ (AbsBinds{abs_binds = binds})) = do + res <- toList $ mapM processInstance $ fromList $ bagToList binds + pure $ Prelude.concat res +processInstance _ = pure mempty + +getLit :: HsExpr p -> [Char] +getLit (HsLit _ (HsChar _ char)) = [char] +getLit (HsLit _ (HsCharPrim _ char)) = [char] +getLit (HsLit _ (HsString _ fs)) = unpackFS fs +getLit (HsLit _ (HsStringPrim _ bs)) = unpack $ decodeUtf8 bs +getLit _ = mempty + +getAppliedOnTypeName :: String -> Type -> Maybe String +getAppliedOnTypeName "parseJSON" (FunTy _ _ arg (TyConApp _ types)) = Just $ (showSDocUnsafe $ ppr $ Prelude.last types) +getAppliedOnTypeName "toJSON" (FunTy _ _ arg res) = Just $ (showSDocUnsafe $ ppr arg) +getAppliedOnTypeName "toEncoding" (FunTy _ _ arg res) = Just $ (showSDocUnsafe $ ppr arg) +getAppliedOnTypeName "toXml" (FunTy _ _ arg res) = Just $ (showSDocUnsafe $ ppr arg) +getAppliedOnTypeName "fromXml" (FunTy _ _ arg (TyConApp _ types)) = Just $ (showSDocUnsafe $ ppr $ Prelude.last types) +getAppliedOnTypeName "fromXml" (FunTy _ _ arg res) = Just $ (showSDocUnsafe $ ppr arg) +getAppliedOnTypeName "toCybsXml" (FunTy _ _ arg res) = Just $ (showSDocUnsafe $ ppr arg) +getAppliedOnTypeName "toXML" (FunTy _ _ arg res) = Just $ (showSDocUnsafe $ ppr arg) +getAppliedOnTypeName _ _ = Nothing + +processHsSplice (HsTypedSplice _ _ name expr) = do + when (generateTypesRules) $ print ("HsTypedSplice",showSDocUnsafe $ ppr name , showSDocUnsafe $ ppr expr) + pure mempty +processHsSplice (HsUntypedSplice _ _ name expr) = do + let types = expr ^? biplateRef :: [HsExpr GhcPs] + typeName = map (\(_,y) -> replace "''" "" y) $ Prelude.filter (\(const,_) -> const `Prelude.elem` ["HsBracket"]) $ map (\x -> (show $ toConstr x,showSDocUnsafe $ ppr x)) types + possibleInstances = map (\(_,y) -> y) $ Prelude.filter (\(const,_) -> const `Prelude.elem` ["HsVar"]) $ map (\x -> (show $ toConstr x,showSDocUnsafe $ ppr x)) types + pure $ Prelude.concat $ map (\x -> map (\y -> (y,x)) possibleInstances) typeName +processHsSplice (HsQuasiQuote _ id1 id2 srcSpan fs) = do + when (generateTypesRules) $ print ("HsQuasiQuote",showSDocUnsafe $ ppr id1 , showSDocUnsafe $ ppr id2) + pure mempty +processHsSplice (HsSpliced _ _ expr) = do + case expr of + (HsSplicedExpr expr' ) -> when (generateTypesRules) $ print (showSDocUnsafe $ ppr expr') + (HsSplicedTy type_ ) -> when (generateTypesRules) $ print (showSDocUnsafe $ ppr type_) + (HsSplicedPat pat) -> when (generateTypesRules) $ print (showSDocUnsafe $ ppr pat) + pure mempty + +getInstancesInfo :: LHsDecl GhcPs -> IO [(String,String)] +getInstancesInfo (L l (TyClD _ (DataDecl _ lname _ _ defn))) = do + let types = defn ^? biplateRef :: [HsType GhcPs] + pure $ Data.List.nub $ map (\x -> (showSDocUnsafe $ ppr x,showSDocUnsafe $ ppr lname)) types +getInstancesInfo (L l (SpliceD _ (SpliceDecl _ (L _ decl) _))) = do + let types = decl ^? biplateRef :: [HsType GhcPs] + when (generateTypesRules) $ print $ map (\x -> (showSDocUnsafe $ ppr x)) types + processHsSplice decl +getInstancesInfo (L l (DerivD _ x@(DerivDecl{deriv_type=derivType}))) = do + case sig_body $ unXRec @(GhcPs) $ hswc_body $ derivType of + (L _ (HsAppTy _ ty1 ty2)) -> pure $ [(showSDocUnsafe $ ppr ty1,showSDocUnsafe $ ppr ty2)] + (L _ (HsQualTy _ mContext (L _ (HsAppTy _ ty1 ty2)))) -> do + pure [(showSDocUnsafe $ ppr ty1,showSDocUnsafe $ ppr ty2)] + (L _ x) -> do + when (generateTypesRules) $ print $ (toConstr x,showSDocUnsafe $ ppr x) + let types = x ^? biplateRef :: [HsType GhcPs] + when (generateTypesRules) $ print $ showSDocUnsafe $ ppr types + pure mempty +getInstancesInfo (L l (InstD _ (ClsInstD _ (ClsInstDecl{cid_poly_ty=cidPolyTy})))) = + case sig_body $ unXRec @(GhcPs) $ cidPolyTy of + (L _ (HsAppTy _ ty1 ty2)) -> pure $ [(showSDocUnsafe $ ppr ty1,showSDocUnsafe $ ppr ty2)] + (L _ (HsQualTy _ mContext (L _ (HsAppTy _ ty1 ty2)))) -> do + pure [(showSDocUnsafe $ ppr ty1,showSDocUnsafe $ ppr ty2)] + (L _ x) -> do + when (generateTypesRules) $ print $ (toConstr x,showSDocUnsafe $ ppr x) + let types = x ^? biplateRef :: [HsType GhcPs] + when (generateTypesRules) $ print $ showSDocUnsafe $ ppr types + pure mempty +-- getInstancesInfo (L l (InstD _ (DataFamInstD _ dfidInst))) = dfidInst ^? biplateRef :: [HsSigType GhcPs] +-- getInstancesInfo (L l (InstD _ (TyFamInstD _ tfidInst))) = tfidInst ^? biplateRef :: [HsSigType GhcPs] +getInstancesInfo (L l x) = do + -- print $ (toConstr x,showSDocUnsafe $ ppr x) + pure mempty + +getTypeInfo :: LHsDecl GhcPs -> [(SrcSpan,String,TypeRule)] +getTypeInfo (L l (TyClD _ (DataDecl _ lname _ _ defn))) = + [(locA l ,showSDocUnsafe' lname ,TypeRule + { typeKind = "data" + , caseType = Nothing + , instances = mempty + , dataConstructors = Map.fromList $ map getDataConInfo (dd_cons defn) + })] +getTypeInfo (L l (TyClD _ (SynDecl _ lname _ _ rhs))) = + [(locA l ,showSDocUnsafe' lname,TypeRule + { typeKind = "type" + , caseType = Nothing + , instances = mempty +#if __GLASGOW_HASKELL__ >= 900 + , dataConstructors = Map.singleton (showSDocUnsafe' lname) (DataConInfo (maybe mempty (Map.singleton "synonym" . unpackHDS) (hsTypeToString $ unLoc rhs)) []) +#else + , dataConstructors = Map.singleton (showSDocUnsafe' lname) (DataConInfo (Map.singleton "synonym" ((showSDocUnsafe . ppr . unLoc) rhs)) []) +#endif + })] +getTypeInfo _ = mempty + +getDataConInfo :: LConDecl GhcPs -> (String,DataConInfo) +getDataConInfo (L _ ConDeclH98{ con_name = lname, con_args = args }) = + (showSDocUnsafe' lname,DataConInfo + { fields' = getFieldMap args + , sumTypes = [] -- For H98-style data constructors, sum types are not applicable + }) +getDataConInfo (L _ ConDeclGADT{ con_names = lnames, con_res_ty = ty }) = + (intercalate ", " (map showSDocUnsafe' lnames),DataConInfo + { +#if __GLASGOW_HASKELL__ >= 900 + fields' = maybe (mempty) (\x -> Map.singleton "gadt" $ unpackHDS x) (hsTypeToString $ unLoc ty) +#else + fields' = Map.singleton "gadt" (showSDocUnsafe $ ppr ty) +#endif + , sumTypes = [] -- For GADT-style data constructors, sum types can be represented by the type itself + }) + +#if __GLASGOW_HASKELL__ >= 900 +hsTypeToString :: HsType GhcPs -> Maybe HsDocString +hsTypeToString = f + where + f :: HsType GhcPs -> Maybe HsDocString + f (HsDocTy _ _ lds) = Just (unLoc lds) + f (HsBangTy _ _ (L _ (HsDocTy _ _ lds))) = Just (unLoc lds) + f x = Just (mkHsDocString $ showSDocUnsafe $ ppr x) + +extractInfixCon :: [HsType GhcPs] -> Map.Map String String +extractInfixCon x = + let l = length x + in Map.fromList $ map (\(a,b) -> (show a , b)) $ Prelude.zip [0..l] (map f x) + where + f :: HsType GhcPs -> (String) + f (HsDocTy _ _ lds) = showSDocUnsafe $ ppr $ (unLoc lds) + f (HsBangTy _ _ (L _ (HsDocTy _ _ lds))) = showSDocUnsafe $ ppr $ (unLoc lds) + f x = (showSDocUnsafe $ ppr x) + +extractConDeclField :: [ConDeclField GhcPs] -> Map.Map String String +extractConDeclField x = Map.fromList (go x) + where + go :: [ConDeclField GhcPs] -> [(String,String)] + go [] = [] + go ((ConDeclField _ cd_fld_names cd_fld_type _):xs) = + [((intercalate "," $ convertRdrNameToString cd_fld_names),(showSDocUnsafe $ ppr cd_fld_type))] <> (go xs) + + convertRdrNameToString x = map (showSDocUnsafe . ppr . rdrNameOcc . unLoc . reLocN . rdrNameFieldOcc . unXRec @(GhcPs)) x + +getFieldMap :: HsConDeclH98Details GhcPs -> Map.Map String String +getFieldMap con_args = + case con_args of + PrefixCon _ args -> extractInfixCon $ map (unLoc . hsScaledThing) args + InfixCon arg1 arg2 -> extractInfixCon $ map (unLoc . hsScaledThing) [arg1,arg2] + RecCon (fields) -> extractConDeclField $ map unLoc $ (unXRec @(GhcPs)) fields + +#else +getFieldMap :: HsConDeclDetails GhcPs -> Map String String +getFieldMap (PrefixCon args) = Map.fromList $ Prelude.zipWith (\i t -> (show i, showSDocUnsafe (ppr t))) [1..] args +getFieldMap (RecCon (L _ fields)) = Map.fromList $ concatMap getRecField fields + where + getRecField (L _ (ConDeclField _ fnames t _)) = [(showSDocUnsafe (ppr fname), showSDocUnsafe (ppr t)) | L _ fname <- fnames] +getFieldMap (InfixCon t1 t2) = Map.fromList [("field1", showSDocUnsafe (ppr t1)), ("field2", showSDocUnsafe (ppr t2))] +#endif + +#if __GLASGOW_HASKELL__ >= 900 +showSDocUnsafe' = showSDocUnsafe . ppr . GHC.unXRec @(GhcPs) +#else +showSDocUnsafe' = showSDocUnsafe . ppr +#endif + +fetchRules :: String -> IO (Either YAML.ParseException (HM.KeyMap TypeRule)) +fetchRules = YAML.decodeFileEither + +runFieldNameAndTypeRule :: String -> Maybe CaseType -> Map.Map String DataConInfo -> Map.Map String DataConInfo -> IO [ApiContractError] +runFieldNameAndTypeRule typeName caseType rules codeExtract = do + foldM (\acc (dataConName,x) -> + case Map.lookup dataConName codeExtract of + Just val -> do + res <- toList $ mapM (checkCaseType caseType) (fromList $ Map.keys $ fields' val) + res' <- checkIfAllFieldsArePresent (Map.keys $ fields' x) (Map.keys $ fields' val) + res' <- checkIfAllFieldsArePresentInverse (Map.keys $ fields' val) (Map.keys $ fields' x) + res'' <- checkAllFieldTypes (HM.fromList $ map (\(a,b) -> (HM.fromString a, b)) $ Map.toList $ fields' x) (HM.fromList $ map (\(a,b) -> (HM.fromString a, b)) $ Map.toList $ fields' val) + pure $ acc <> (Prelude.concat res) <> res' <> res'' + Nothing -> pure $ acc <> [(MISSING_DATACON dataConName typeName)] + ) mempty $ Map.toList rules + where + checkIfAllFieldsArePresent :: [String] -> [String] -> IO [ApiContractError] + checkIfAllFieldsArePresent fromRules fromCode = + foldM (\acc x -> + if x `Prelude.elem` fromCode then pure acc else pure $ acc <> [(MISSING_FIELD_IN_CODE x typeName)] + ) mempty (fromRules) + + checkIfAllFieldsArePresentInverse :: [String] -> [String] -> IO [ApiContractError] + checkIfAllFieldsArePresentInverse fromRules fromCode = + foldM (\acc x -> + if x `Prelude.elem` fromCode then pure acc else pure $ acc <> [(MISSING_FIELD_IN_RULES x typeName)] + ) mempty (fromRules) + + checkAllFieldTypes :: HM.KeyMap String -> HM.KeyMap String -> IO [ApiContractError] + checkAllFieldTypes fromRules fromCode = + foldM (\acc (_fieldName,_type) -> + case HM.lookup _fieldName fromCode of + Just val -> if val == _type + then pure acc + else pure $ acc <> [(TYPE_MISMATCH (HM.toString _fieldName) (_type) (val) typeName)] + Nothing -> pure $ acc <> [(MISSING_FIELD_IN_CODE (HM.toString _fieldName) typeName)] + ) mempty (HM.toList fromRules) + + isSnakeCase :: String -> Bool + isSnakeCase = Prelude.all (\c -> isLower c || c == '_' || isAlphaNum c) + + isCamelCase :: String -> Bool + isCamelCase s = Prelude.all (\(c, i) -> if i == 0 then isLower c else isAlphaNum c) (zip s [0..]) + + isPascalCase :: String -> Bool + isPascalCase s = Prelude.all (\(c, i) -> if i == 0 then isUpper c else isAlphaNum c) (zip s [0..]) + + isKebabCase :: String -> Bool + isKebabCase = Prelude.all (\c -> isLower c || c == '-' || isAlphaNum c) + + checkCaseType :: Maybe CaseType -> String -> IO [ApiContractError] + checkCaseType (Just SnakeCase) field = bool (pure [(FIELD_CASE_MISMATCH typeName field SnakeCase)]) (pure mempty) $ isSnakeCase field + checkCaseType (Just CamelCase) field = bool (pure [(FIELD_CASE_MISMATCH typeName field CamelCase)]) (pure mempty) $ isCamelCase field + checkCaseType (Just PascalCase) field = bool (pure [(FIELD_CASE_MISMATCH typeName field PascalCase)]) (pure mempty) $ isPascalCase field + checkCaseType (Just KebabCase) field = bool (pure [(FIELD_CASE_MISMATCH typeName field KebabCase)]) (pure mempty) $ isKebabCase field + checkCaseType _ _ = pure mempty + +generateTypesRules :: Bool +generateTypesRules = readBool $ unsafePerformIO $ lookupEnv "DUMP_TYPE_RULES" + where + readBool :: Maybe String -> Bool + readBool (Just "true") = True + readBool (Just "True") = True + readBool (Just "TRUE") = True + readBool _ = False + +mkFileSrcSpan :: ModLocation -> SrcSpan +mkFileSrcSpan mod_loc + = case ml_hs_file mod_loc of + Just file_path -> mkGeneralSrcSpan (mkFastString file_path) + Nothing -> interactiveSrcSpan + diff --git a/api-contract/src/ApiContract/Types.hs b/api-contract/src/ApiContract/Types.hs new file mode 100644 index 0000000..b4344d4 --- /dev/null +++ b/api-contract/src/ApiContract/Types.hs @@ -0,0 +1,133 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} +{-# LANGUAGE StandaloneDeriving #-} + +module ApiContract.Types where + +import Data.Aeson (ToJSON,FromJSON) +import GHC.Generics (Generic) +import qualified Data.Map as Map +import GHC.Hs (SrcSpanAnnA) +import GHC (SrcSpan(..), RealSrcSpan(..)) + +data TypeOfInstance = + Derived + | Custom + deriving (Eq,Show, Generic,FromJSON,ToJSON) + +data CaseType = + SnakeCase + | CamelCase + | PascalCase + | KebabCase + deriving (Eq,Show, Generic,FromJSON,ToJSON) + +data InstancePresence = + ToJSON + | ParseJSON + | ToEncoding + deriving (Eq,Show, Generic,FromJSON,ToJSON) + +data TypeRule = TypeRule + { caseType :: Maybe CaseType + , dataConstructors :: Map.Map String DataConInfo + , instances :: Map.Map String InstanceFromTC + , typeKind :: String + } deriving (Show, Generic,FromJSON,ToJSON) + +data Types = Types + { types :: Map.Map String TypeRule + } deriving (Show, Generic,FromJSON,ToJSON) + +data DataConInfo = DataConInfo + { fields' :: Map.Map String String + , sumTypes :: [String] + } deriving (Show, Eq, Ord,Generic,ToJSON,FromJSON) + +data InstanceFromTC = InstanceFromTC + { + fieldsList :: [String] + , typeOfInstance :: TypeOfInstance + } + deriving (Show, Generic,FromJSON,ToJSON) + + +data ApiContractError = + -- fieldName typeName + MISSING_FIELD_IN_RULES String String + -- fieldName typeName + | MISSING_FIELD_IN_CODE String String + -- fieldName expectedType typeFromCode typeName + | TYPE_MISMATCH String String String String + -- dataConName typeName + | MISSING_DATACON String String + -- typeName fieldName caseType + | FIELD_CASE_MISMATCH String String CaseType + -- typeName + | MISSING_TYPE_CODE String + -- typeName ruleYAML + | MISSING_TYPE_IN_RULE String String + -- typeName instanceName + | MISSING_INSTANCE_IN_CODE String String + -- typeName instanceName ruleYAML + | MISSING_INSTANCE_IN_RULES String String String + -- typeName instanceName typeOfInsance + | TYPE_OF_INSTANCE_CHANGED String String TypeOfInstance + -- typeName instanceName fieldName + | MISSING_FIELD_IN_INSTANCE_CODE String String String + -- typeName instanceName fieldName + | MISSING_FIELD_IN_INSTANCE_RULES String String String + deriving (Eq,Show, Generic,FromJSON,ToJSON) + +generateErrorMessage :: FilePath -> ApiContractError -> String +generateErrorMessage yamlFilePath (MISSING_FIELD_IN_RULES fieldName typeName) = + "Error: The field '" ++ fieldName ++ "' is missing in the rules for type '" ++ typeName ++ "'.\n\n" ++ + "\tYou can update the change in the file: " ++ yamlFilePath ++ + "\n\tAdd the field under the appropriate type's fields section." +generateErrorMessage yamlFilePath (MISSING_FIELD_IN_CODE fieldName typeName) = + "Error: The field '" ++ fieldName ++ "' is missing in the code for type '" ++ typeName ++ "'.\n\n" ++ + "\tPlease add the field '" ++ fieldName ++ "' to the type '" ++ typeName ++ "' in your code." +generateErrorMessage yamlFilePath (TYPE_MISMATCH fieldName expectedType typeFromCode typeName) = + "Error: Type mismatch for field '" ++ fieldName ++ "' in type '" ++ typeName ++ "'. Expected type: '" ++ expectedType ++ "', but found: '" ++ typeFromCode ++ "'.\n\n" ++ + "\tYou can update the change in the file: " ++ yamlFilePath ++ + "\n\tChange the type of the field '" ++ fieldName ++ "' to '" ++ expectedType ++ "' under the appropriate type's fields section." +generateErrorMessage yamlFilePath (MISSING_DATACON dataConName typeName) = + "Error: The data constructor '" ++ dataConName ++ "' is missing in type '" ++ typeName ++ "'.\n\n" ++ + "Please add the data constructor '" ++ dataConName ++ "' to the type '" ++ typeName ++ "' in your code." +generateErrorMessage yamlFilePath (FIELD_CASE_MISMATCH typeName fieldName caseType) = + "Error: Field name case mismatch for field '" ++ fieldName ++ "' in type '" ++ typeName ++ "'. Expected case: " ++ show caseType ++ ".\n\n" ++ + "\tYou can update the change in the file: " ++ yamlFilePath ++ + "\n\tChange the field name to follow the " ++ show caseType ++ " convention under the appropriate type's fields section." +generateErrorMessage yamlFilePath (MISSING_TYPE_CODE typeName) = + "Error: The type '" ++ typeName ++ "' is missing.\n\n" ++ + "\tYou can update the change in the file: " ++ yamlFilePath ++ + "\n\tAdd the type '" ++ typeName ++ "' to the types section." +generateErrorMessage yamlFilePath (MISSING_TYPE_IN_RULE typeName yamlRule) = + "Error: The type '" ++ typeName ++ "' is missing in the rules.\n\n" ++ + "\tYou should add the rule in the file: " ++ yamlFilePath ++ + "\n\tAdd the type '" ++ typeName ++ "' to the types section." ++ "\n\n" ++ yamlRule +generateErrorMessage yamlFilePath (MISSING_INSTANCE_IN_CODE typeName instanceName) = + "Error: The instance '" ++ instanceName ++ "' is missing for type '" ++ typeName ++ "'in the code.\n\n" ++ + "\tYou can update the change in the file: " ++ yamlFilePath ++ + "\n\tAdd the instance '" ++ instanceName ++ "' under the appropriate type's instances section." +generateErrorMessage yamlFilePath (MISSING_INSTANCE_IN_RULES typeName instanceName yamlRule) = + "Error: The instance '" ++ instanceName ++ "' is missing for type '" ++ typeName ++ "' in the rules.\n\n" ++ + "\tYou can update the change in the file: " ++ yamlFilePath ++ + "\n\tAdd the instance '" ++ instanceName ++ "' under the appropriate type's instances section." ++ "\n\n" ++ yamlRule +generateErrorMessage yamlFilePath (TYPE_OF_INSTANCE_CHANGED typeName instanceName typeOfInstance) = + "Error: The type of instance '" ++ instanceName ++ "' for type '" ++ typeName ++ "' has changed to " ++ show typeOfInstance ++ ".\n\n" ++ + "\tYou can update the change in the file: " ++ yamlFilePath ++ + "\n\tChange the type of the instance '" ++ instanceName ++ "' under the appropriate type's instances section." +generateErrorMessage yamlFilePath (MISSING_FIELD_IN_INSTANCE_CODE typeName instanceName fieldName) = + "Error: The field '" ++ fieldName ++ "' is missing in the instance '" ++ instanceName ++ "' for type '" ++ typeName ++ "'.\n\n" ++ + "\tYou can update the change in the file: " ++ yamlFilePath ++ + "\n\tAdd/remove the field '" ++ fieldName ++ "' under the appropriate instance's fieldsList section." +generateErrorMessage yamlFilePath (MISSING_FIELD_IN_INSTANCE_RULES typeName instanceName fieldName) = + "Error: The field '" ++ fieldName ++ "' is missing in rules in instance '" ++ instanceName ++ "' for type '" ++ typeName ++ "'.\n\n" ++ + "\tYou can update the change in the file: " ++ yamlFilePath ++ + "\n\tAdd/remove the field '" ++ fieldName ++ "' under the appropriate instance's fieldsList section." diff --git a/api-contract/test/Main.hs b/api-contract/test/Main.hs new file mode 100644 index 0000000..3eb4cc9 --- /dev/null +++ b/api-contract/test/Main.hs @@ -0,0 +1,86 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE StandaloneDeriving,DeriveDataTypeable #-} + +module Main (main) where + +import Data.Aeson +import Data.Text +import GHC.Generics (Generic) +import Data.Aeson.KeyMap +import Data.Data + + +main :: IO () +main = putStrLn "Test suite not yet implemented." + +data RefundAttempt' = RefundAttempt' + { id'' :: Maybe Text + , created'' :: Text + , ref'' :: Maybe Text + , error_message'' :: Maybe Text + , error_code'' :: Maybe Value + , last_modified'' :: Maybe Text + } + deriving (Show, Eq, Generic) + +deriving instance Data RefundAttempt' +deriving instance ToJSON RefundAttempt' +deriving instance FromJSON RefundAttempt' + +data RefundAttempt = RefundAttempt + { id' :: Maybe Text + , created :: Text + , ref :: Maybe Text + , error_message :: Maybe Text + , error_code :: Maybe Value + , last_modified :: Maybe Text + } + deriving (Show, Eq, Generic) + +deriving instance Data RefundAttempt + +instance ToJSON RefundAttempt where +-- toJSON RefundAttempt{..} = Data.Aeson.object $ mconcat [[ +-- "id'" .= id', +-- "created" .= created +-- ] +-- , toV "ref" ref +-- , toV "error_message" error_message +-- , toV "error_code" error_code +-- , toV "last_modified" last_modified +-- ] +-- where +-- toV _ Nothing = [] +-- toV t (Just v) = [t .= v] + + toEncoding RefundAttempt{..} = Data.Aeson.pairs $ mconcat [ + "id'" .= id', + "created" .= created + , toE "ref" ref + , toE "error_message" error_message + , toE "error_code" error_code + , toE "last_modified" last_modified + ] + where + toE _ Nothing = mempty + toE t (Just v) = t .= v + +instance FromJSON RefundAttempt where + parseJSON = withObject "Refund'" $ \o -> do + id' <- o .:? "id'" + created <- o .: "created" + ref <- o .:? "ref" + error_message <- o .:? "error_message" + error_code <- o .:? "error_code" + last_modified <- o .:? "last_modified" + pure $ RefundAttempt {..} + diff --git a/cabal.project b/cabal.project index 2d76ffa..b6b8922 100644 --- a/cabal.project +++ b/cabal.project @@ -1,5 +1,7 @@ packages: ./fdep - ./coresyn2chart - ./sheriff ./fieldInspector + ./sheriff + ./paymentFlow + ./api-contract + ./dc diff --git a/coresyn2chart/coresyn2chart.cabal b/coresyn2chart/coresyn2chart.cabal index 4b2b1f9..f6a917f 100644 --- a/coresyn2chart/coresyn2chart.cabal +++ b/coresyn2chart/coresyn2chart.cabal @@ -38,13 +38,13 @@ common common-options bytestring , containers , filepath - , ghc ^>= 8.10.7 + , ghc , unordered-containers , aeson , directory , extra , aeson-pretty - , base ^>=4.14.3.0 + , base , text , base64-bytestring , optparse-applicative @@ -55,7 +55,6 @@ common common-options , hasbolt , universum , data-default - , streamly library import: common-options diff --git a/dc/CHANGELOG.md b/dc/CHANGELOG.md new file mode 100644 index 0000000..6d8ba26 --- /dev/null +++ b/dc/CHANGELOG.md @@ -0,0 +1,5 @@ +# Revision history for dc + +## 0.1.0.0 -- YYYY-mm-dd + +* First version. Released on an unsuspecting world. diff --git a/dc/LICENSE b/dc/LICENSE new file mode 100644 index 0000000..3ebfa94 --- /dev/null +++ b/dc/LICENSE @@ -0,0 +1,26 @@ +Copyright (c) 2024, Chaitanya Nair +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the + distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/dc/dc.cabal b/dc/dc.cabal new file mode 100644 index 0000000..77462d1 --- /dev/null +++ b/dc/dc.cabal @@ -0,0 +1,106 @@ +cabal-version: 3.0 +-- The cabal-version field refers to the version of the .cabal specification, +-- and can be different from the cabal-install (the tool) version and the +-- Cabal (the library) version you are using. As such, the Cabal (the library) +-- version used must be equal or greater than the version stated in this field. +-- Starting from the specification version 2.2, the cabal-version field must be +-- the first thing in the cabal file. + +-- Initial package description 'dc' generated by +-- 'cabal init'. For further documentation, see: +-- http://haskell.org/cabal/users-guide/ +-- +-- The name of the package. +name: dc + +-- The package version. +-- See the Haskell package versioning policy (PVP) for standards +-- guiding when and how versions should be incremented. +-- https://pvp.haskell.org +-- PVP summary: +-+------- breaking API changes +-- | | +----- non-breaking API additions +-- | | | +--- code changes with no API change +version: 0.1.0.0 + +-- A short (one-line) description of the package. +-- synopsis: + +-- A longer description of the package. +-- description: + +-- The license under which the package is released. +license: BSD-2-Clause + +-- The file containing the license text. +license-file: LICENSE + +-- The package author(s). +author: Chaitanya Nair + +-- An email address to which users can send suggestions, bug reports, and patches. +maintainer: chaitanya.nair@juspay.in + +-- A copyright notice. +-- copyright: +build-type: Simple + +-- Extra doc files to be distributed with the package, such as a CHANGELOG or a README. +extra-doc-files: CHANGELOG.md + +-- Extra source files to be distributed with the package, such as examples, or a tutorial module. +-- extra-source-files: + +common warnings + ghc-options: -Wall + +library + -- Import common warning flags. + import: warnings + + -- Modules exported by the library. + exposed-modules: DC.DefaultCheck + default-extensions: OverloadedStrings + + -- Modules included in this library but not exported. + -- other-modules: + + -- LANGUAGE extensions used by modules in this package. + -- other-extensions: + + -- Other library packages from which modules are imported. + build-depends: base, directory, bytestring, aeson, extra, aeson-pretty, unordered-containers, uniplate, ghc, references, yaml + + + -- Directories containing source files. + hs-source-dirs: src + + -- Base language which the package is written in. + default-language: Haskell2010 + +test-suite dc-test + -- Import common warning flags. + import: warnings + + -- Base language which the package is written in. + default-language: Haskell2010 + + ghc-options: -fplugin=DC.DefaultCheck + -- Modules included in this executable, other than Main. + -- other-modules: + + -- LANGUAGE extensions used by modules in this package. + -- other-extensions: + + -- The interface type and version of the test suite. + type: exitcode-stdio-1.0 + + -- Directories containing source files. + hs-source-dirs: test + + -- The entrypoint to the test suite. + main-is: Main.hs + + -- Test dependencies. + build-depends: + base ^>=4.16.4.0, + dc diff --git a/dc/src/DC/Constants.hs b/dc/src/DC/Constants.hs new file mode 100644 index 0000000..f453ed8 --- /dev/null +++ b/dc/src/DC/Constants.hs @@ -0,0 +1,18 @@ +{-# LANGUAGE ScopedTypeVariables #-} + +module DC.Constants where + +createError :: String +createError = "Should not create field with status as success or failure in default case" + +updateError :: String +updateError = "Should not update field with status as success or failure in default case" + +defaultCase :: String +defaultCase = "Should not use status as success or failure in default case" + +syncError :: String +syncError = "Should not use exception functions for sync in default cases" + +prefixPath :: String +prefixPath = "./.juspay/dc/" \ No newline at end of file diff --git a/dc/src/DC/DefaultCheck.hs b/dc/src/DC/DefaultCheck.hs new file mode 100644 index 0000000..0d4a1be --- /dev/null +++ b/dc/src/DC/DefaultCheck.hs @@ -0,0 +1,1071 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DeriveAnyClass, CPP, TypeApplications #-} +{-# LANGUAGE DeriveGeneric, ScopedTypeVariables, TypeFamilies, RecordWildCards, PartialTypeSignatures #-} +-- {-# OPTIONS_GHC -ddump-parsed-ast #-} + +module DC.DefaultCheck where + +import Control.Reference (biplateRef, (^?)) +import GHC hiding (typeKind) +import Data.Aeson +import Data.Aeson as A +import Data.Generics.Uniplate.Data () +import Data.List +import Data.Maybe (mapMaybe, catMaybes, fromMaybe) +import qualified Data.HashMap.Strict as HM +import Data.Data +import Control.Monad.Extra (anyM) +import Data.List.Extra (replace,splitOn) +import Data.Aeson.Encode.Pretty (encodePretty) +import Control.Monad.Extra (filterM, ifM) +import qualified Data.Aeson as Aeson +import qualified Data.ByteString.Lazy as B +import Control.Exception (try,SomeException) +import System.Directory (createDirectoryIfMissing, doesFileExist ) +import Data.Yaml +import qualified Data.ByteString.Lazy.Char8 as Char8 +import DC.Types +import DC.Constants +#if __GLASGOW_HASKELL__ >= 900 +import GHC.Core.ConLike +import qualified GHC.Utils.Outputable as OP +import GHC.Tc.Utils.Monad (addErrs) +import GHC.Data.Bag (bagToList) +import GHC.Tc.Types +import GHC.Driver.Plugins +import GHC.Types.Var +import GHC.Utils.Outputable hiding ((<>)) +import GHC.Types.Name hiding (varName) +import Control.Monad.IO.Class (liftIO) +import GHC.Utils.Misc +import GHC.Core.DataCon +import GHC.Core.PatSyn +#else +import ConLike +import qualified Outputable as OP +import TcRnMonad ( addErrs) +import Bag (bagToList) +import TcRnTypes (TcGblEnv (..), TcM) +import GhcPlugins hiding ((<>)) --(Plugin(..), defaultPlugin, purePlugin, CommandLineOption, ModSummary(..), Module (..), liftIO, moduleNameString, showSDocUnsafe, ppr, nameStableString, varName) +#endif + +plugin :: Plugin +plugin = + defaultPlugin { + typeCheckResultAction = checkIntegrity + , pluginRecompile = purePlugin + } + +checkIntegrity :: [CommandLineOption] -> ModSummary -> TcGblEnv -> TcM TcGblEnv +checkIntegrity opts modSummary tcEnv = do + let PluginOpts{..} = case opts of + [] -> defaultPluginOpts + (x : _) -> fromMaybe defaultPluginOpts $ A.decode (Char8.pack x) + let moduleName' = moduleNameString $ moduleName $ ms_mod modSummary + modulePath = prefixPath <> ms_hspp_file modSummary + parsedYaml :: Either ParseException CheckerConfig <- liftIO $ parseYAMLFile domainConfigFile + case parsedYaml of + Right conf -> do + let path = (intercalate "/" . reverse . tail . reverse . splitOn "/") modulePath + liftIO $ createDirectoryIfMissing True path + getAllUpdatesLi <- mapM (loopOverLHsBindLR pathsTobeChecked conf moduleName') (bagToList $ tcg_binds tcEnv) + let getAllUpdatesList = fst <$> getAllUpdatesLi + getAllFuns = HM.unions $ snd <$> getAllUpdatesLi + let res = foldl (\(UpdateInfo acc1 acc2 acc3 acc4 acc5 acc6) (UpdateInfo x y z z1 z2 otherFuns) -> UpdateInfo ((acc1) ++ (changeModName moduleName' <$> x)) ((acc2) ++ (changeModName moduleName' <$> y)) ((acc3) ++ (changeModName moduleName' <$> z)) ((acc4) ++ (changeModName moduleName' <$> z1)) ((acc5) ++ (changeModName moduleName' <$> z2)) ((acc6) ++ (changeModName moduleName' <$> otherFuns))) (UpdateInfo [] [] [] [] [] []) $ catMaybes getAllUpdatesList + let combination = lookUpAndConcat getAllFuns + allRes = getAllRes conf moduleName' combination res + liftIO $ B.writeFile (modulePath <> ".json") (encodePretty $ (\(UpdateInfo createRec upRecords upFails cFails allFails otherFuns) -> UpdateInfoAsText (nub $ name <$> createRec) (nub $ name <$> upRecords) (nub $ name <$> upFails) (nub $ name <$> cFails) (nub $ name <$> allFails) (nub $ name <$> otherFuns)) allRes) + !exprs <- mapM (loopOverLHsBindLRTot pathsTobeChecked conf path allRes moduleName') (bagToList $ tcg_binds tcEnv) + case conf of + FieldsCheck _ -> do + let exprsC = foldl (\acc (val) -> acc ++ getErrorrs val ) [] exprs + addErrs $ map (mkGhcCompileError) (exprsC) + FunctionCheck (FunctionCheckConfig{..}) -> do + if moduleName' == moduleNameToCheck then do + let exprsC = foldl (\acc (val) -> acc ++ getErrorrs val ) [] exprs + addErrs $ map (mkGhcCompileError) (exprsC) + else do + let exprsC = foldl (\acc (val) -> HM.union acc (getFuncs val) ) HM.empty exprs + liftIO $ B.writeFile (modulePath <> ".err.json") (encodePretty exprsC) + pure tcEnv + Left err -> do + liftIO $ print $ "Not using dc plugin since no config is found" ++ show err + pure tcEnv + +getErrorrs :: ErrorCase -> [CompileError] +getErrorrs (Errors val) = val +getErrorrs _ = [] + +getFuncs :: ErrorCase -> HM.HashMap String [CompileError] +getFuncs (Functions val) = val +getFuncs _ = HM.empty + +getAllRes :: CheckerConfig -> String -> HM.HashMap String [FunctionInfo] -> UpdateInfo -> UpdateInfo +getAllRes conf moduleName' combination res = + case conf of + FieldsCheck (EnumCheck{..}) -> do + HM.foldlWithKey (\acc key val -> + if any (\x -> name x `elem` (name <$> updatedFailurs res)) val + then acc{updatedFailurs = nub $ updatedFailurs acc ++ ([FunctionInfo "" moduleName' key "" False])} + else if any (\x -> name x `elem` (name <$> createdFailurs res) ) val + then acc{createdFailurs = nub $ createdFailurs acc ++ ([FunctionInfo "" moduleName' key "" False])} + else if any (\x -> name x `elem` (name <$> createdRecordsFun res)) val + then if any (\x -> name x `elem` enumList) val + then + acc{createdFailurs = nub $ createdFailurs acc ++ ([FunctionInfo "" moduleName' key "" False])} + else + acc{createdRecordsFun = nub $ createdRecordsFun acc ++ ([FunctionInfo "" moduleName' key "" False])} + else if any (\x -> name x `elem` (name <$> updatedRecordsFun res) ) val + then if any (\x -> name x `elem` enumList) val + then + acc{updatedFailurs = nub $ updatedFailurs acc ++ ([FunctionInfo "" moduleName' key "" False])} + else acc{updatedRecordsFun = nub $ updatedRecordsFun acc ++ ([FunctionInfo "" moduleName' key "" False])} + else if any (\x -> name x `elem` (name <$> allFailures res) ) val + then acc{allFailures = nub $ allFailures acc ++ ([FunctionInfo "" moduleName' key "" False])} + else acc) res combination + _ -> res + +lookUpAndConcat :: HM.HashMap String [FunctionInfo] -> HM.HashMap String [FunctionInfo] +lookUpAndConcat hm = HM.map (\val -> nub $ foldl (\acc x -> nub $ acc ++ lookupEachKey hm [] x) val val ) hm + +changeModName :: String -> FunctionInfo -> FunctionInfo +changeModName moduleName' (FunctionInfo x y z z2 isF) = FunctionInfo x (if y == "_in" then moduleName' else y) z z2 isF + +lookupEachKey :: HM.HashMap String [FunctionInfo] -> [String] -> FunctionInfo -> [FunctionInfo] +lookupEachKey hm alreadyVisited x = case HM.lookup (name x) hm of + Just val -> [x] ++ if name x `elem` (alreadyVisited) then [] else (nub $ concat $ (lookupEachKey hm (alreadyVisited ++ ([name x])) <$> (nub (val)))) + Nothing -> [x] + +processAllLetPats :: LHsBindLR GhcTc GhcTc -> (Maybe (String, [FunctionInfo])) +#if __GLASGOW_HASKELL__ >= 900 +processAllLetPats (L _ (FunBind _ name matches _)) = do +#else +processAllLetPats (L _ (FunBind _ name matches _ _)) = do +#endif + let inte = unLoc $ mg_alts matches + if null inte then Nothing + else Just (nameStableString $ varName $ unLoc name, concat $ map (\(GRHS _ _ val) -> processExpr val) $ map unLoc $ grhssGRHSs $ m_grhss $ unLoc $ head $ inte ) +processAllLetPats (L _ _) = do + Nothing + + +loopOverLHsBindLRTot :: [String] -> CheckerConfig -> String -> UpdateInfo -> String -> LHsBindLR GhcTc GhcTc -> TcM ErrorCase +loopOverLHsBindLRTot allPaths conf path allFuns moduleName' vals@(L _ AbsBinds {abs_binds = binds}) = do + case conf of + FunctionCheck (FunctionCheckConfig{..}) -> + if moduleName' == moduleNameToCheck + then do + let funName = (getFunctionName vals) + -- liftIO $ print(funName) + if ("$_in$" ++ funNameToCheck) `elem` funName then do -- getTxnStatusFromGateway + let binds1 = ( (bagToList binds) ^? biplateRef :: [LHsExpr GhcTc]) + let allNrFuns = nub $ ((concatMap processExpr binds1)) + -- liftIO $ print("ALLFUN", allNrFuns) + first <- catMaybes <$> (liftIO $ (mapM ((\x@(FunctionInfo _ _ _ _ _) -> do + nc <- checkInOtherModsWithoutErrorFuns allPaths conf moduleName' x + -- print(nc, x) + if null nc then pure Nothing else (pure $ Just $ nc))) (allNrFuns))) + pure $ Errors $ concat first + else pure $ Errors [] + else do + let allVals = ((bagToList binds ^? biplateRef :: [LHsExpr GhcTc])) + allFunsWithFailure <- mapM (getFunctionNameIfFailure allPaths conf "" [] "" "" moduleName') (bagToList binds ^? biplateRef) + let allLetPats = HM.fromList $ ((mapMaybe processAllLetPats (bagToList binds ^? biplateRef :: [LHsBindLR GhcTc GhcTc]))) + let funName = map (\y -> transformFromNameStableString y (showSDocUnsafe $ ppr $ getLoc $ vals) False ) (getFunctionName vals) + let val = map (\(upType,listY) -> (createUpdateInfo upType $ map (\y -> transformFromNameStableString y (showSDocUnsafe $ ppr $ getLoc $ vals) False) listY)) allFunsWithFailure + allC <- nub <$> (mapM (loopOverModBinds allPaths conf path allLetPats allFuns moduleName' val) allVals) + let allV = foldl (\acc (val1) -> acc ++ getErrorrs val1 ) [] allC + -- liftIO $ print (allC, allV, funName) + pure $ if null allV then Functions HM.empty else Functions (foldl (\acc val1 -> HM.insert val1 allV acc) HM.empty (name <$> funName)) + FieldsCheck (EnumCheck{..}) -> do + let allVals = ((bagToList binds ^? biplateRef :: [LHsExpr GhcTc])) + allFunsWithFailure <- mapM (getFunctionNameIfFailure allPaths conf recordType enumList enumType fieldType moduleName') (bagToList binds ^? biplateRef) + let val = map (\(upType,listY) -> (createUpdateInfo upType $ map (\y -> transformFromNameStableString y (showSDocUnsafe $ ppr $ getLoc $ vals) False) listY)) allFunsWithFailure + let allLetPats = HM.fromList $ ((mapMaybe processAllLetPats (bagToList binds ^? biplateRef :: [LHsBindLR GhcTc GhcTc]))) + allC <- nub <$> (mapM (loopOverModBinds allPaths conf path allLetPats allFuns moduleName' val) allVals) + case conf of + FieldsCheck _ -> pure $ Errors $ foldl (\acc (vals1) -> acc ++ getErrorrs vals1 ) [] allC + -- FunctionCheck _ -> do + -- let allV = foldl (\acc (vals1) -> acc ++ getErrorrs vals1 ) [] allC + -- -- liftIO $ print (allC, allV, funName) + -- pure $ if null allV then Functions HM.empty else Functions (foldl (\acc vals1 -> HM.insert vals1 allV acc) HM.empty (name <$> funName)) +loopOverLHsBindLRTot _ _ _ _ _ _ = pure $ Errors [] + +createUpdateInfo :: TypeOfUpdate -> [FunctionInfo] -> UpdateInfo +createUpdateInfo Update list = UpdateInfo [] list [] [] [] [] +createUpdateInfo Create list = UpdateInfo list [] [] [] [] [] +createUpdateInfo CreateWithFailure list = UpdateInfo [] [] [] list [] [] +createUpdateInfo UpdateWithFailure list = UpdateInfo [] [] list [] [] [] +createUpdateInfo Default list = UpdateInfo [] [] list [] [] [] +createUpdateInfo _ _ = UpdateInfo [] [] [] [] [] [] + +loopOverModBinds :: [String] -> CheckerConfig -> String -> HM.HashMap String [FunctionInfo] -> UpdateInfo -> String -> [UpdateInfo] -> LHsExpr GhcTc -> TcM ErrorCase +loopOverModBinds allPaths checkerCase path allFUnsInside allFuns moduleName' allPatsList (L _ (HsCase _ _ exprLStmt)) = do + -- liftIO $ print ("val",allFuns) + allFunsPats <- mapM (loopOverPats allPaths checkerCase path allFUnsInside allFuns moduleName' allPatsList) $ map unLoc $ unLoc $ mg_alts exprLStmt + pure $ Errors $ foldl (\acc (val) -> acc ++ getErrorrs val ) [] allFunsPats +loopOverModBinds _ _ _ _ _ _ _ _ = do + pure $ Errors [] + +getAllEnums :: LHsExpr GhcTc -> Maybe String +getAllEnums (L _ (HsConLikeOut _ liter)) = Just $ showSDocUnsafe $ ppr liter +getAllEnums (L _ _) = Nothing + + +loopOverPats :: [String] -> CheckerConfig -> String -> HM.HashMap String [FunctionInfo] -> UpdateInfo -> String -> [UpdateInfo] -> Match GhcTc (LHsExpr GhcTc) -> TcM ErrorCase +loopOverPats allPaths checkerCase path allFUnsInsid allFunsWithFailure moduleName' allPatsList match = do + case checkerCase of + FieldsCheck (EnumCheck{..}) -> do + let normalBinds = (\(GRHS _ _ stmt )-> stmt ) <$> unLoc <$> (grhssGRHSs $ m_grhss match) + argBinds = m_pats match + checker = any (\x -> isVarPatExprBool x) (normalBinds ^? biplateRef :: [LHsExpr GhcTc] ) + if checker then pure $ Errors [] else + let a = any isVarPat argBinds + in if a then do + -- liftIO $ (print (showSDocUnsafe $ ppr normalBind, showSDocUnsafe $ ppr normalBinds )) + let allLetPats = HM.fromList $ ((mapMaybe processAllLetPats (normalBinds ^? biplateRef :: [LHsBindLR GhcTc GhcTc]))) + let allFUnsInside = HM.union allLetPats allFUnsInsid + allFuns = concat $ map processExpr (normalBinds ^? biplateRef) + check <- mapM (\x -> case HM.lookup (mkStringFromFunctionInfo x) allFUnsInside of + Nothing -> throwErrorRules x allPaths path moduleName' allFunsWithFailure allPatsList + Just val -> do + -- liftIO $ print ("showing " ++ show val) + res <- mapM (\y -> throwErrorRules y allPaths path moduleName' allFunsWithFailure allPatsList) (nub val) + let concatVals = concat $ catMaybes ( res) + if null concatVals then pure Nothing else pure $ Just concatVals) (nub allFuns) --anyM (\x -> if module_name x == moduleName' + -- then pure $ name x `elem` (name <$> allFunsWithFailure) && module_name x `elem` (module_name <$> allFunsWithFailure) + -- else checkInOtherMods x) allFuns + if ((not $ null (catMaybes check))) + then pure $ Errors $ (\x -> CompileError "" "" x (getLocGhc $ head argBinds)) <$> (catMaybes check) + else do + processedPats <- mapM (\x -> do + allCHecks <- liftIO $ mapM (checkInOtherModsWithoutError allPaths checkerCase moduleName') x + pure $ any (==True) allCHecks) allLetPats + let allFailureNames = name <$> (updatedFailurs allFunsWithFailure) + let allNeeded = mapMaybe getExprTypeWithName $ normalBinds ^? biplateRef + allVals = fst <$> allNeeded + allEnums = catMaybes $ snd <$> allNeeded + b = (any (\x -> x `elem` (splitOn " " $ showSDocUnsafe $ ppr match)) (allFailureNames) || any (\x -> x `elem` allEnums) ["FAILURE", "SUCCESS"]) && any (\x -> recordType `isInfixOf` (replace enumType "" x)) (allVals) && any (\x -> enumType `isInfixOf` x) allVals + allFunsUpd <- mapM (getDataTypeDetails recordType enumList enumType fieldType processedPats []) ( (match ^? biplateRef)) + let allFunsUpds = catMaybes allFunsUpd + -- liftIO $ print $ ("TypesInfo ", allFuns) + -- liftIO $ print $ ("check",a, allFunsUpds, toConstr <$> unLoc <$> normalBinds, showSDocUnsafe $ ppr normalBinds, allVals, allFuns) + -- allFunsWithFailure <- mapM getAndPut allFuns + -- liftIO $ print ("Checker", allPatsList) + pure $ Errors $ + if CreateWithFailure `elem` allFunsUpds + then [CompileError "" "" (createError ++ show allFunsUpds) (getLocGhc $ head argBinds)] + else if UpdateWithFailure `elem` allFunsUpds + then [CompileError "" "" (updateError ++ show allFunsUpds) (getLocGhc $ head argBinds)] + else if b then [CompileError "" "" defaultCase (getLocGhc $ head argBinds)] + else [] + else pure $ Errors [] + FunctionCheck (FunctionCheckConfig{..}) -> do + let normalBinds = (\(GRHS _ _ stmt )-> stmt ) <$> unLoc <$> (grhssGRHSs $ m_grhss match) + argBinds = m_pats match + let a = any isVarPat argBinds + if a then do + let allLetPats = HM.fromList $ ((mapMaybe processAllLetPats (normalBinds ^? biplateRef :: [LHsBindLR GhcTc GhcTc]))) + let allFUnsInside = HM.union allLetPats allFUnsInsid + allFuns = concat $ map processExpr (normalBinds ^? biplateRef) + check <- mapM (\x -> if name x `elem` listOfRestrictedFuns then + pure $ Just syncError else do + case HM.lookup (mkStringFromFunctionInfo x) allFUnsInside of + Nothing -> throwFunctionErrorRules x allPaths path moduleName' allFunsWithFailure allPatsList + Just val -> do + -- liftIO $ print ("showing " ++ show val) + res <- mapM (\y -> throwFunctionErrorRules y allPaths path moduleName' allFunsWithFailure allPatsList) (nub val) + let concatVals = concat $ catMaybes ( res) + if null concatVals then pure Nothing else pure $ Just concatVals) (nub allFuns) + if ((not $ null (catMaybes check))) + then pure $ Errors $ (\x -> CompileError "" "" x (getLocGhc $ head argBinds)) <$> (catMaybes check) + else pure $ Errors [] + else pure $ Errors [] + +throwFunctionErrorRules + :: FunctionInfo + -> [String] + -> String + -> String + -> UpdateInfo + -> [UpdateInfo] + -> TcM (Maybe [Char]) +throwFunctionErrorRules x allPaths path moduleName' (UpdateInfo _ _ _ _ _ otherFuns) _ = do + -- liftIO $ print ("Checking " ++ name x ++ show x) + if module_name x == moduleName' || "_in" == module_name x then + pure $ if( name x `elem` (name <$> otherFuns) && module_name x `elem` (module_name <$> otherFuns)) then + Just (syncError) + else Nothing + else checkInOtherModsFunction allPaths path x + +getLocGhc :: _ -> SrcSpan +getLocGhc val = +#if __GLASGOW_HASKELL__ >= 900 + RealSrcSpan (la2r $ getLoc val) Nothing +#else + getLoc val +#endif + +throwErrorRules :: + FunctionInfo + -> [String] + -> String + -> String + -> UpdateInfo + -> [UpdateInfo] + -> TcM (Maybe [Char]) +throwErrorRules x allPaths path moduleName' (UpdateInfo _ _ upFails cFails _ _) _ = do + -- liftIO $ print ("Checking " ++ name x ++ show x) + if module_name x == moduleName' || "_in" == module_name x then + pure $ if( name x `elem` (name <$> cFails) && module_name x `elem` (module_name <$> cFails)) then + -- || (name x `elem` (concat $ map (\x -> name <$> x) $ createdRecordsFun <$> allPatsList)) then + Just (createError ++ show (name x ++ show (name <$> cFails))) + else if name x `elem` ((name <$> upFails)) && module_name x `elem` ((module_name <$> upFails)) then + -- || (name x `elem` (concat $ map (\x -> name <$> x) $ updatedRecordsFun <$> allPatsList)) then + Just (updateError ++ show (name x ++ show (name <$> upFails))) + -- else if name x `elem` (name <$> functions) && module_name x `elem` (module_name <$> functions) + -- || (name x `elem` (concat $ map (\x -> name <$> x) $ updatedFailurs <$> allPatsList)) then + -- Just defaultCase + else Nothing + else checkInOtherMods allPaths path x + +checkInOtherModsFunction :: [String] -> String -> FunctionInfo -> TcM (Maybe String) +checkInOtherModsFunction allPaths path (FunctionInfo _ y z _ _) = do + let newFileName = "/" ++ (intercalate "/" . splitOn "." $ y) ++ ".hs.json" + filterNames <- liftIO $ filterM (\pos -> doesFileExist (path ++ pos ++ newFileName)) allPaths + let orgName = if null filterNames then ("test" ++ newFileName) else prefixPath ++ head filterNames ++ newFileName + fileContents <- liftIO $ (try $ B.readFile orgName :: IO (Either SomeException B.ByteString)) + pure $ either (\_ -> Nothing) (\contents -> + maybe Nothing + (\(UpdateInfoAsText _ _ _ _ _ otherFuns) -> + if z `elem` otherFuns + then Just (syncError ++ show (z,otherFuns)) + else Nothing) (Aeson.decode contents :: Maybe UpdateInfoAsText)) fileContents + +checkInOtherMods :: [String] -> String -> FunctionInfo -> TcM (Maybe String) +checkInOtherMods allPaths path (FunctionInfo _ y z _ _) = do + let newFileName = "/" ++ (intercalate "/" . splitOn "." $ y) ++ ".hs.json" + filterNames <- liftIO $ filterM (\pos -> doesFileExist (path ++ pos ++ newFileName)) allPaths + let orgName = if null filterNames then ("test" ++ newFileName) else prefixPath ++ head filterNames ++ newFileName + fileContents <- liftIO $ (try $ B.readFile orgName :: IO (Either SomeException B.ByteString)) + pure $ either (\_ -> Nothing) (\contents -> + maybe Nothing + (\(UpdateInfoAsText _ _ upFails cFails _ _) -> + if z `elem` cFails + then Just (createError ++ show (z,cFails)) + else if z `elem` upFails + then Just (updateError ++ show (z,upFails)) + else Nothing) (Aeson.decode contents :: Maybe UpdateInfoAsText)) fileContents + +checkInOtherModsWithoutError :: [String] -> CheckerConfig -> String -> FunctionInfo -> IO Bool +checkInOtherModsWithoutError allPaths checkerCase moduleName' fun@(FunctionInfo _ y z _ _) = do + case checkerCase of + FieldsCheck _ -> do + if module_name fun == moduleName' || "_in" == module_name fun then pure False + else do + let newFileName = "/" ++ (intercalate "/" . splitOn "." $ y) ++ ".hs.json" + filterNames <- liftIO $ filterM (\pos -> doesFileExist (prefixPath ++ pos ++ newFileName)) allPaths + let orgName = if null filterNames then ("test" ++ newFileName) else prefixPath ++ head filterNames ++ newFileName + fileContents <- liftIO $ (try $ B.readFile orgName :: IO (Either SomeException B.ByteString)) + pure $ either (\_ -> False) (\contents -> + maybe False + (\(UpdateInfoAsText creRecords upRecords upFails cFails defaultF _) -> + z `elem` creRecords ++ upRecords ++ upFails ++ cFails ++ defaultF) (Aeson.decode contents :: Maybe UpdateInfoAsText)) fileContents + FunctionCheck _ -> + if module_name fun == moduleName' || "_in" == module_name fun then pure False + else do + let newFileName = "/" ++ (intercalate "/" . splitOn "." $ y) ++ ".hs.json" + filterNames <- liftIO $ filterM (\pos -> doesFileExist (prefixPath ++ pos ++ newFileName)) allPaths + let orgName = if null filterNames then ("test" ++ newFileName) else prefixPath ++ head filterNames ++ newFileName + fileContents <- liftIO $ (try $ B.readFile orgName :: IO (Either SomeException B.ByteString)) + pure $ either (\_ -> False) (\contents -> + maybe False + (\(UpdateInfoAsText _ _ _ _ _ checkerF) -> + z `elem` checkerF) (Aeson.decode contents :: Maybe UpdateInfoAsText)) fileContents + + +checkInOtherModsWithoutErrorFuns :: [String] -> CheckerConfig -> String -> FunctionInfo -> IO [CompileError] +checkInOtherModsWithoutErrorFuns allPaths checkerCase moduleName' fun@(FunctionInfo _ y z _ _) = do + case checkerCase of + FunctionCheck _ -> + if module_name fun == moduleName' || "_in" == module_name fun then pure [] + else do + let newFileName = "/" ++ (intercalate "/" . splitOn "." $ y) ++ ".hs.err.json" + filterNames <- liftIO $ filterM (\pos -> doesFileExist (prefixPath ++ pos ++ newFileName)) allPaths + let orgName = if null filterNames then ("test" ++ newFileName) else prefixPath ++ head filterNames ++ newFileName + fileContents <- liftIO $ (try $ B.readFile orgName :: IO (Either SomeException B.ByteString)) + pure $ either (\_ -> []) (\contents -> + maybe [] + (\(checkerF) -> + case HM.lookup z checkerF of + Just val -> val + Nothing -> []) (Aeson.decode contents :: Maybe (HM.HashMap String [CompileError]))) fileContents + _ -> pure [] + +#if __GLASGOW_HASKELL__ >= 900 +extractLHsRecUpdField :: Either [LHsRecUpdField GhcTc] [LHsRecUpdProj GhcTc] -> [FunctionInfo] +extractLHsRecUpdField fields = + case fields of + Left fun -> concatMap (processExprCases) (fun) + Right x -> + let yn = (map (GHC.unXRec @(GhcTc)) x) + in concatMap processExprUps (yn ) +#else +extractLHsRecUpdField :: GenLocated l (HsRecField' id (LHsExpr GhcTc)) -> [FunctionInfo] +extractLHsRecUpdField (L _ (HsRecField {hsRecFieldArg = fun})) = processExpr fun +#endif + +#if __GLASGOW_HASKELL__ >= 900 +processExprUps :: HsRecField' id (GenLocated SrcSpanAnnA (HsExpr GhcTc)) -> [FunctionInfo] +processExprUps (HsRecField {hsRecFieldArg = fun}) = processExpr fun + +processExprCases :: GenLocated l (HsRecField' id (GenLocated SrcSpanAnnA (HsExpr GhcTc))) -> [FunctionInfo] +processExprCases (L _ (HsRecField {hsRecFieldArg = fun})) = processExpr fun +#endif + + +mkStringFromFunctionInfo :: FunctionInfo -> String +mkStringFromFunctionInfo (FunctionInfo pName modName name _ _) = intercalate "$" [pName, modName, name] + +processExpr :: LHsExpr GhcTc -> [FunctionInfo] +processExpr x@(L _ (HsVar _ (L _ var))) = + let name = transformFromNameStableString (nameStableString $ varName var) (showSDocUnsafe $ ppr $ getLoc $ x) False + in [name] +processExpr (L _ (HsUnboundVar _ _)) = [] +processExpr (L _ (HsApp _ funl funr)) = + processExpr funl <> processExpr funr +processExpr (L _ (OpApp _ funl funm funr)) = + processExpr funl <> processExpr funm <> processExpr funr +processExpr (L _ (NegApp _ funl _)) = + processExpr funl +processExpr (L _ (HsTick _ _ fun)) = + processExpr fun +processExpr (L _ (HsStatic _ fun)) = + processExpr fun +processExpr (L _ (HsBinTick _ _ _ fun)) = + processExpr fun +#if __GLASGOW_HASKELL__ < 900 +processExpr (L _ (HsTickPragma _ _ _ _ fun)) = + processExpr fun +processExpr (L _ (HsSCC _ _ _ fun)) = + processExpr fun +processExpr (L _ (HsCoreAnn _ _ _ fun)) = + processExpr fun +processExpr (L _ (HsWrap _ _ fun)) = + processExpr (noLoc fun) +-- processExpr (L _ (HsWrap _ _ fun)) = +-- processExpr (noLoc fun) +processExpr (L _ (ExplicitList _ _ funList)) = + concatMap processExpr funList +processExpr (L _ (HsIf _ exprLStmt funl funm funr)) = + let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) + in nub (concatMap processExpr $ [funl, funm, funr] <> stmts) +processExpr (L _ (HsTcBracketOut _ exprLStmtL exprLStmtR)) = + let stmtsL = (exprLStmtL ^? biplateRef :: [LHsExpr GhcTc]) + stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) + in nub (concatMap processExpr (stmtsL <> stmtsR)) +-- processExpr (L _ (HsIf _ exprLStmt funl funm funr)) = +-- let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) +-- in nub (concatMap processExpr $ [funl, funm, funr] <> stmts) +processExpr (L _ (RecordUpd _ rupd_expr rupd_flds)) = processExpr rupd_expr <> concatMap extractLHsRecUpdField rupd_flds +#else +processExpr (L _ (ExplicitList _ funList)) = + concatMap processExpr funList +processExpr (L _ (HsIf exprLStmt funl funm funr)) = + let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) + in nub (concatMap processExpr $ [funl, funm, funr] <> stmts) +processExpr (L _ (HsTcBracketOut _ _ exprLStmtL exprLStmtR)) = + let stmtsL = (exprLStmtL ^? biplateRef :: [LHsExpr GhcTc]) + stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) + in nub (concatMap processExpr (stmtsL <> stmtsR)) +processExpr (L _ (RecordUpd _ rupd_expr rupd_flds)) = processExpr rupd_expr <> extractLHsRecUpdField rupd_flds +#endif +processExpr (L _ (ExprWithTySig _ fun _)) = + processExpr fun +processExpr (L _ (HsDo _ _ exprLStmt)) = + let stmts = exprLStmt ^? biplateRef :: [LHsExpr GhcTc] + in nub $ concatMap processExpr stmts +processExpr (L _ (HsLet _ exprLStmt func)) = + let stmts = exprLStmt ^? biplateRef :: [LHsExpr GhcTc] + in processExpr func <> nub (concatMap processExpr stmts) +processExpr (L _ (HsMultiIf _ exprLStmt)) = + let stmts = exprLStmt ^? biplateRef :: [LHsExpr GhcTc] + in nub (concatMap processExpr stmts) +processExpr (L _ (HsCase _ funl exprLStmt@(MG _ (L _ _) _))) = + let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) + in nub (concatMap processExpr $ [funl] <> stmts) +processExpr (L _ (ExplicitSum _ _ _ fun)) = processExpr fun +processExpr (L _ (SectionR _ funl funr)) = processExpr funl <> processExpr funr +processExpr (L _ (ExplicitTuple _ exprLStmt _)) = + let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) + in nub (concatMap processExpr stmts) +processExpr (L _ (HsPar _ fun)) = processExpr fun +processExpr (L _ (HsAppType _ fun _)) = processExpr fun +processExpr (L _ (HsLamCase _ exprLStmt)) = + let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) + in nub (concatMap processExpr stmts) +processExpr (L _ (HsLam _ exprLStmt)) = + let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) + in nub (concatMap processExpr stmts) +processExpr x@(L _ (HsLit _ liter)) = + -- let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) + -- let literals = + [FunctionInfo "" (show $ toConstr liter) (showSDocUnsafe $ ppr liter) (showSDocUnsafe $ ppr $ getLoc x) False] +processExpr (L _ (HsOverLit _ exprLStmt)) = + let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) + in nub (concatMap processExpr stmts) +processExpr (L _ (HsRecFld _ exprLStmt)) = + let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) + in nub (concatMap processExpr stmts) +processExpr (L _ (HsSpliceE exprLStmtL exprLStmtR)) = + let stmtsL = (exprLStmtL ^? biplateRef :: [LHsExpr GhcTc]) + stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) + in nub (concatMap processExpr (stmtsL <> stmtsR)) +processExpr (L _ (ArithSeq _ (Just exprLStmtL) exprLStmtR)) = + let stmtsL = (exprLStmtL ^? biplateRef :: [LHsExpr GhcTc]) + stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) + in nub (concatMap processExpr (stmtsL <> stmtsR)) +processExpr (L _ (ArithSeq _ Nothing exprLStmtR)) = + let stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) + in nub (concatMap processExpr stmtsR) +processExpr (L _ (HsRnBracketOut _ exprLStmtL exprLStmtR)) = + let stmtsL = (exprLStmtL ^? biplateRef :: [LHsExpr GhcTc]) + stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) + in nub (concatMap processExpr (stmtsL <> stmtsR)) +-- processExpr (L _ (RecordCon _ (L _ (iD)) rcon_flds)) = Just ((extractRecordBinds (rcon_flds)), False) +-- processExpr (L _ (RecordUpd _ rupd_expr rupd_flds) +-- processExpr (L _ (HsTcBracketOut _ exprLStmtL exprLStmtR)) = + +-- HsIPVar (XIPVar p) HsIPName +-- HsOverLabel (XOverLabel p) (Maybe (IdP p)) FastString +processExpr x@(L _ (HsConLikeOut _ (RealDataCon liter))) = + [FunctionInfo "" "" (showSDocUnsafe $ ppr liter) (showSDocUnsafe $ ppr $ getLoc x) False] +processExpr x@(L _ (HsConLikeOut _ (liter))) = + [FunctionInfo "" "" (showSDocUnsafe $ ppr liter) (showSDocUnsafe $ ppr $ getLoc x) False] +processExpr (L _ _) = [] + +isVarPatMatch :: (LMatch GhcTc body) -> Bool +isVarPatMatch (L _ match) = + let argBinds = m_pats match + in any isVarPat argBinds + +isVarPatExprBool :: LHsExpr GhcTc -> Bool +isVarPatExprBool (L _ (HsCase _ _ (MG _ (L _ mg_alts) _))) = + any isVarPatMatch mg_alts + -- in L loc (HsCase m funl (MG mg_ext (L loc1 y) mg_org)) +isVarPatExprBool _ = False + +-- isVarPatMatch (L _ match) = +-- let normalBinds = (\(GRHS _ _ stmt )-> stmt ) <$> unLoc <$> (grhssGRHSs $ m_grhss match) +-- argBinds = m_pats match +-- in any isVarPat argBinds + +isVarPat :: LPat GhcTc -> Bool +isVarPat (L _ pat) = case pat of + VarPat _ (L _ _) -> True + WildPat _ -> True + x -> if any (\y -> y `isInfixOf` (showSDocUnsafe $ ppr x)) ["Nothing", "Left"] then True else False + +getExprType :: LHsExpr GhcTc -> Maybe String +getExprType (L _ (HsVar _ idT)) = Just $ showSDocUnsafe $ ppr $ idType $ unLoc idT +getExprType (L _ (HsConLikeOut _ (RealDataCon idT))) = Just $ showSDocUnsafe $ ppr $ idType $ dataConWrapId idT +-- getExprType (L _ (HsConLikeOut _ (PatSynCon id))) = Just $ showSDocUnsafe $ ppr $ idType $ dataConWrapId id +getExprType (L _ _)= Nothing + +getExprTypeWithName :: LHsExpr GhcTc -> Maybe (String, Maybe String) +getExprTypeWithName (L _ (HsVar _ idT)) = Just $ (showSDocUnsafe $ ppr $ idType $ unLoc idT, Nothing) +getExprTypeWithName (L _ (HsConLikeOut _ (RealDataCon idT))) = Just $ (showSDocUnsafe $ ppr $ idType $ dataConWrapId idT, Just $ showSDocUnsafe $ ppr idT) +-- getExprTypeWithName (L _ (HsConLikeOut _ (PatSynCon id))) = Just $ showSDocUnsafe $ ppr $ idType $ dataConWrapId id +getExprTypeWithName (L _ _)= Nothing + +getExprTypeAsType :: LHsExpr GhcTc -> Maybe Type +getExprTypeAsType (L _ (HsVar _ idT)) = Just $ idType $ unLoc idT +getExprTypeAsType (L _ (HsConLikeOut _ (RealDataCon idT))) = Just $ idType $ dataConWrapId idT +-- getExprTypeAsType (L _ (HsConLikeOut _ (PatSynCon id))) = Just $ showSDocUnsafe $ ppr $ idType $ dataConWrapId id +getExprTypeAsType (L _ _)= Nothing + +getDataTypeDetails :: String -> [String] -> String -> String -> HM.HashMap String Bool -> [String] -> HsExpr GhcTc -> TcM (Maybe TypeOfUpdate) +#if __GLASGOW_HASKELL__ >= 900 +getDataTypeDetails recordType enumList _ fieldType allLetPats allArgs (RecordCon _ iD rcon_flds) = pure $ if recordType `elem` ((splitOn " " $ replace "->" "" $ showSDocUnsafe $ ppr $ conLikeType (GHC.unXRec @(GhcTc) iD))) then Just (extractRecordBinds rcon_flds allLetPats allArgs enumList fieldType) else Nothing +#else +getDataTypeDetails recordType enumList _ fieldType allLetPats allArgs (RecordCon _ iD rcon_flds) = pure $ if recordType `elem` ((splitOn " " $ replace "->" "" $ showSDocUnsafe $ ppr $ idType $ unLoc iD)) then Just (extractRecordBinds rcon_flds allLetPats allArgs enumList fieldType) else Nothing +#endif +getDataTypeDetails recordType enumList _ fieldType allLetPats allArgs (RecordUpd _ rupd_expr rupd_flds) = + let allVals = mapMaybe getExprTypeAsType $ rupd_expr ^? biplateRef + in pure $ if any (\x -> recordType `elem` ((splitOn " " $ replace "->" "" $ showSDocUnsafe $ ppr x))) allVals then Just (getFieldUpdates rupd_flds allLetPats allArgs enumList fieldType) else Nothing +getDataTypeDetails recordType enumList _ _ allLetPats allArgs x@(OpApp _ funl funm _) = do + -- trace (show (showSDocUnsafe $ ppr x, showSDocUnsafe $ ppr funl,showSDocUnsafe $ ppr funm, showSDocUnsafe $ ppr funr)) Nothing + if (showSDocUnsafe $ ppr funm) == "(#)" + then do + let allVals = mapMaybe getExprTypeAsType $ funl ^? biplateRef + if any (\val -> recordType `elem` ((splitOn " " $ replace "->" "" $ showSDocUnsafe $ ppr val))) allVals + then do + let allOps = mapMaybe (splitOnOpAp enumList allLetPats allArgs) (x ^? biplateRef) + -- liftIO $ print ("Proper ," ++ (show allOps)) + pure $ if UpdateWithFailure `elem` allOps then Just UpdateWithFailure else if Update `elem` allOps then Just Update else Nothing + else do + -- liftIO $ print ("Frist ," ++ (showSDocUnsafe $ ppr allVals), showSDocUnsafe $ ppr funl) + pure Nothing + else do + -- liftIO $ print ("Seconf ," ++ (showSDocUnsafe $ ppr funm)) + pure Nothing +getDataTypeDetails recordType enumList enumType fieldType _ _ pat@(HsApp _ app1 (L _ _)) = do + if "Set" `isInfixOf` (showSDocUnsafe $ ppr app1) then do + let allVals = mapMaybe getExprTypeAsType $ pat ^? biplateRef + pure $ if recordType `isInfixOf` (showSDocUnsafe $ ppr allVals) && fieldType `isInfixOf` (showSDocUnsafe $ ppr pat) && enumType `isInfixOf` (showSDocUnsafe $ ppr allVals) + then if any (\x -> x `isInfixOf` (showSDocUnsafe $ ppr pat)) enumList then + Just UpdateWithFailure + else Just Update + else Nothing + -- liftIO $ print (showSDocUnsafe $ ppr app1, showSDocUnsafe $ ppr allVals, toConstr app2, showSDocUnsafe $ ppr app2) + else pure Nothing + -- if showSDocUnsafe $ ppr app1 + -- pure Nothing +getDataTypeDetails recordType enumList enumType fieldType allLetPats _ (HsPar _ (L _ pat)) = do + -- liftIO $ print (allLetPats) + if "setField" `isInfixOf` (showSDocUnsafe $ ppr pat) then do + let allVals = mapMaybe getExprTypeAsType $ pat ^? biplateRef + allInnerVals = concat $ map processExpr (pat ^? biplateRef) + pure $ if recordType `isInfixOf` (showSDocUnsafe $ ppr allVals) && fieldType `isInfixOf` (showSDocUnsafe $ ppr pat) && enumType `isInfixOf` (showSDocUnsafe $ ppr allVals) + then if any (\x -> x `isInfixOf` (showSDocUnsafe $ ppr pat)) enumList then + Just UpdateWithFailure + else do + let allPosibleVals = + map (\x -> case HM.lookup (mkStringFromFunctionInfo x) allLetPats of + Nothing -> maybe (True,False) (\val -> (False, val)) $ HM.lookup (name x) allLetPats + Just val -> (False, val)) allInnerVals + if any (==False) (fst <$> allPosibleVals) && any (==True) (snd <$> allPosibleVals) then Just UpdateWithFailure + else Just Update + else Nothing + else pure Nothing +getDataTypeDetails _ _ _ _ _ _ (_) = do + -- liftIO $ print (toConstr app1, showSDocUnsafe $ ppr app1) + pure Nothing + +splitOnOpAp :: [String] -> HM.HashMap String Bool -> [String] -> HsExpr GhcTc -> Maybe TypeOfUpdate +splitOnOpAp enumList allLetPats _ (OpApp _ left op right) = + if (showSDocUnsafe $ ppr op) `elem` ["(%~)","(.~)"] then + if (showSDocUnsafe $ ppr left) == "_status" then do + let allVals = concat $ map processExpr (right ^? biplateRef) + if any (\x -> x `isInfixOf` (showSDocUnsafe $ ppr right)) enumList then Just UpdateWithFailure + else do + let allPosibleVals = + map (\x -> case HM.lookup (mkStringFromFunctionInfo x) allLetPats of + Nothing -> maybe (True,False) (\val -> (False, val)) $ HM.lookup (name x) allLetPats + Just val -> (False, val)) allVals + if any (==False) (fst <$> allPosibleVals) && any (==True) (snd <$> allPosibleVals) then Just UpdateWithFailure + else Just Update + else Nothing + else Nothing +splitOnOpAp _ _ _ _ = Nothing + + -- let allVals = mapMaybe getExprType $ x ^? biplateRef + -- in if any (\x -> x `isInfixOf` recordType) allVals then Just (False, True) else Nothing + +processRecordExpr :: HsExpr GhcTc -> String -> [FunctionInfo] +#if __GLASGOW_HASKELL__ >= 900 +processRecordExpr (RecordCon _ (iD) rcon_flds) recordType = if recordType `isInfixOf` (showSDocUnsafe $ ppr $ conLikeType (GHC.unXRec @(GhcTc) iD)) then concat $ map processExpr (rcon_flds ^? biplateRef) else [] +#else +processRecordExpr (RecordCon _ (L _ (iD)) rcon_flds) recordType = if recordType `isInfixOf` (showSDocUnsafe $ ppr $ idType iD) then concat $ map processExpr (rcon_flds ^? biplateRef) else [] +#endif +processRecordExpr (RecordUpd _ rupd_expr rupd_flds) recordType = + let allVals = mapMaybe getExprType $ rupd_expr ^? biplateRef + in if any (\x -> x `isInfixOf` recordType) allVals then + concat $ map processExpr (rupd_flds ^? biplateRef) + else [] +processRecordExpr x recordType = + let allVals = mapMaybe getExprType $ x ^? biplateRef + allExprs = map processExpr (x ^? biplateRef) + in if any (\val -> val `isInfixOf` recordType) allVals then + concat $ allExprs + else [] +-- inferFieldType :: Name -> String +-- inferFieldTypeFieldOcc (L _ (FieldOcc _ (L _ rdrName))) = handleRdrName rdrName +-- inferFieldTypeAFieldOcc = (handleRdrName . rdrNameAmbiguousFieldOcc . unLoc) +#if __GLASGOW_HASKELL__ >= 900 +conLikeType :: ConLike -> Type +conLikeType (RealDataCon data_con) = dataConType data_con +conLikeType (PatSynCon pat_syn) = patSynResultType pat_syn +#endif + +#if __GLASGOW_HASKELL__ >= 900 +getFieldUpdates :: Either [LHsRecUpdField GhcTc] [LHsRecUpdProj GhcTc] -> HM.HashMap String Bool -> [String] -> [String] -> String -> TypeOfUpdate +getFieldUpdates fields allLetPats allArgs enumList fieldType = + case fields of + Left x -> + let allUpdates = map extractField x + in if UpdateWithFailure `elem` allUpdates then UpdateWithFailure + else if Update `elem` allUpdates then Update + else NoChange + Right x -> + let yn = (map (GHC.unXRec @(GhcTc)) x) + allUpdates = map extractField' (yn) + in if UpdateWithFailure `elem` allUpdates then UpdateWithFailure + else if Update `elem` allUpdates then Update + else NoChange + where + extractField :: LHsRecUpdField GhcTc -> TypeOfUpdate + extractField (L _ (HsRecField{hsRecFieldLbl = lbl, hsRecFieldArg = expr})) = + let allNrFuns = nub $ ((concatMap processExpr (map noLocA $ expr ^? biplateRef))) + in if isInfixOf fieldType (showSDocUnsafe $ ppr lbl) then + if any (\x -> isInfixOf x (showSDocUnsafe $ ppr expr)) enumList then + UpdateWithFailure + else if any (\x -> isInfixOf x (showSDocUnsafe $ ppr expr)) allArgs then + Update + else do + let allPosibleVals = + map (\x -> case HM.lookup (mkStringFromFunctionInfo x) allLetPats of + Nothing -> maybe (True,False) (\val -> (False, val)) $ HM.lookup (name x) allLetPats + Just val -> (False, val)) allNrFuns + if any (==False) (fst <$> allPosibleVals) && any (==True) (snd <$> allPosibleVals) then UpdateWithFailure + else Update + else NoChange + -- extractField' :: HsRecUpdField GhcTc -> TypeOfUpdate + extractField' ((HsRecField{hsRecFieldLbl = lbl, hsRecFieldArg = expr})) = + let allNrFuns = nub $ ((concatMap processExpr (map noLocA $ expr ^? biplateRef))) + in if isInfixOf fieldType (showSDocUnsafe $ ppr lbl) then + if any (\x -> isInfixOf x (showSDocUnsafe $ ppr expr)) enumList then + UpdateWithFailure + else if any (\x -> isInfixOf x (showSDocUnsafe $ ppr expr)) allArgs then + Update + else do + let allPosibleVals = + map (\x -> case HM.lookup (mkStringFromFunctionInfo x) allLetPats of + Nothing -> maybe (True,False) (\val -> (False, val)) $ HM.lookup (name x) allLetPats + Just val -> (False, val)) allNrFuns + if any (==False) (fst <$> allPosibleVals) && any (==True) (snd <$> allPosibleVals) then UpdateWithFailure + else Update + else NoChange +#else +getFieldUpdates :: [LHsRecUpdField GhcTc] -> HM.HashMap String Bool -> [String] -> [String] -> String -> TypeOfUpdate +getFieldUpdates fields allLetPats allArgs enumList fieldType = + let allUpdates = map extractField fields + in if UpdateWithFailure `elem` allUpdates then UpdateWithFailure + else if Update `elem` allUpdates then Update + else NoChange + where + extractField :: LHsRecUpdField GhcTc -> TypeOfUpdate + extractField (L _ (HsRecField{hsRecFieldLbl = lbl, hsRecFieldArg = expr})) = + let allNrFuns = nub $ ((concatMap processExpr (map noLoc $ expr ^? biplateRef))) + in if isInfixOf fieldType (showSDocUnsafe $ ppr lbl) then + if any (\x -> isInfixOf x (showSDocUnsafe $ ppr expr)) enumList then + UpdateWithFailure + else if any (\x -> isInfixOf x (showSDocUnsafe $ ppr expr)) allArgs then + Update + else do + let allPosibleVals = + map (\x -> case HM.lookup (mkStringFromFunctionInfo x) allLetPats of + Nothing -> maybe (True,False) (\val -> (False, val)) $ HM.lookup (name x) allLetPats + Just val -> (False, val)) allNrFuns + if any (==False) (fst <$> allPosibleVals) && any (==True) (snd <$> allPosibleVals) then UpdateWithFailure + else Update + else NoChange +#endif + +extractRecordBinds :: HsRecFields GhcTc (LHsExpr GhcTc) -> HM.HashMap String Bool -> [String] -> [String] -> String -> TypeOfUpdate +extractRecordBinds (HsRecFields{rec_flds = fields}) allLetPats allArgs enumList fieldType = + let allUpdates = map extractField fields + in if CreateWithFailure `elem` allUpdates then CreateWithFailure + else if Create `elem` allUpdates then Create + else NoChange + where + extractField :: LHsRecField GhcTc (LHsExpr GhcTc) -> TypeOfUpdate + extractField (L _ (HsRecField{hsRecFieldLbl = lbl, hsRecFieldArg = expr})) = do +#if __GLASGOW_HASKELL__ >= 900 + let allNrFuns = nub $ ((concatMap processExpr (map (noLocA) $ expr ^? biplateRef))) +#else + let allNrFuns = nub $ ((concatMap processExpr (map noLoc $ expr ^? biplateRef))) +#endif + if isInfixOf fieldType (showSDocUnsafe $ ppr lbl) then + if any (\x -> isInfixOf x (showSDocUnsafe $ ppr expr)) enumList then + CreateWithFailure + else if any (\x -> isInfixOf x (showSDocUnsafe $ ppr expr)) allArgs then + Create + else do + let allPosibleVals = + map (\x -> case HM.lookup (mkStringFromFunctionInfo x) allLetPats of + Nothing -> maybe (True,False) (\val -> (False, val)) $ HM.lookup (name x) allLetPats + Just val -> (False, val)) allNrFuns + if any (==False) (fst <$> allPosibleVals) && any (==True) (snd <$> allPosibleVals) then CreateWithFailure + else Create + else NoChange + -- then (FieldRep (showSDocUnsafe $ ppr lbl) (showSDocUnsafe $ ppr lbl) (inferFieldTypeFieldOcc lbl)) + -- else (FieldRep (showSDocUnsafe $ ppr lbl) (showSDocUnsafe $ ppr $ unLoc expr) (inferFieldTypeFieldOcc lbl)) + +getFunctionName :: LHsBindLR GhcTc GhcTc -> [String] +#if __GLASGOW_HASKELL__ < 900 +getFunctionName (L _ (FunBind _ idt _ _ _)) = [nameStableString $ getName idt] +#else +getFunctionName (L _ (FunBind _ idt _ _)) = [nameStableString $ getName idt] +#endif +getFunctionName (L _ (VarBind{var_id = var})) = [nameStableString $ varName var] +getFunctionName (L _ (PatBind{})) = [""] +getFunctionName (L _ (AbsBinds{abs_binds = binds})) = concatMap getFunctionName $ bagToList binds +getFunctionName _ = [] + +getFunctionNameIfFailure :: [String] -> CheckerConfig -> String -> [String] -> String -> String -> String -> LHsBindLR GhcTc GhcTc -> TcM (TypeOfUpdate, [String]) +#if __GLASGOW_HASKELL__ < 900 +getFunctionNameIfFailure allPaths checkerCase recordType enumList enumType fieldType moduleName' (L _ x@(FunBind _ idT _ _ _)) = do +#else +getFunctionNameIfFailure allPaths checkerCase recordType enumList enumType fieldType moduleName' (L _ x@(FunBind _ idT _ _)) = do +#endif + let allValsTypes = mapMaybe getExprType (x ^? biplateRef) + let allVals = (map unLoc (x ^? biplateRef :: [LHsExpr GhcTc])) + let allLetPats = HM.fromList $ ((mapMaybe processAllLetPats (x ^? biplateRef :: [LHsBindLR GhcTc GhcTc]))) + processedPats <- mapM (\funInfo -> do + if any (\val -> val `elem` enumList) (name <$> funInfo) then pure True + else do + allCHecks <- liftIO $ mapM (checkInOtherModsWithoutError allPaths checkerCase moduleName') funInfo + pure $ any (==True) allCHecks) allLetPats + let allBinds = concat $ mapMaybe loopOverFunBind (x ^? biplateRef :: [LHsBindLR GhcTc GhcTc]) + funName = [nameStableString $ getName idT] + (allRecordUpdsAndCrea) <- mapM (getDataTypeDetails recordType enumList enumType fieldType processedPats allBinds) $ allVals + let allRecordUpdsAndCreate = catMaybes allRecordUpdsAndCrea + pure $ if any (\val -> val==CreateWithFailure) allRecordUpdsAndCreate + then (CreateWithFailure, funName) + else if any (\val -> val==UpdateWithFailure) allRecordUpdsAndCreate + then (UpdateWithFailure, funName) + else if any (\val -> val==Create) allRecordUpdsAndCreate + then (Create, funName) + else if any (\val -> val==Update) allRecordUpdsAndCreate + then (Update, funName) + else if any (\val -> val==Default) allRecordUpdsAndCreate + then (Default, funName) + else if any (\val -> isInfixOf val (showSDocUnsafe $ ppr x) ) enumList && (Just enumType) == (lastMaybe (splitOn " " $ replace "->" "" $ showSDocUnsafe $ ppr (lastMaybe allValsTypes))) + then (Default, funName) + else (NoChange,[]) +getFunctionNameIfFailure allPaths checkerCase recordType enumList enumType fieldType moduleName' (L _ x@(VarBind{var_id = var})) = do + let allValsTypes = mapMaybe getExprType (x ^? biplateRef) + let allVals = (map unLoc (x ^? biplateRef :: [LHsExpr GhcTc])) + let allLetPats = HM.fromList $ ((mapMaybe processAllLetPats (x ^? biplateRef :: [LHsBindLR GhcTc GhcTc]))) + processedPats <- mapM (\funInfo -> do + if any (\val -> val `elem` enumList) (name <$> funInfo) then pure True + else do + allCHecks <- liftIO $ mapM (checkInOtherModsWithoutError allPaths checkerCase moduleName') funInfo + pure $ any (==True) allCHecks) allLetPats + let allBinds = concat $ mapMaybe loopOverFunBind (x ^? biplateRef :: [LHsBindLR GhcTc GhcTc]) + funName = [nameStableString $ varName var] + (allRecordUpdsAndCrea) <- mapM (getDataTypeDetails recordType enumList enumType fieldType processedPats allBinds) $ allVals + let allRecordUpdsAndCreate = catMaybes allRecordUpdsAndCrea + pure $ if any (\val -> val==CreateWithFailure) allRecordUpdsAndCreate + then (CreateWithFailure, funName) + else if any (\val -> val==UpdateWithFailure) allRecordUpdsAndCreate + then (UpdateWithFailure, funName) + else if any (\val -> val==Create) allRecordUpdsAndCreate + then (Create, funName) + else if any (\val -> val==Update) allRecordUpdsAndCreate + then (Update, funName) + else if any (\val -> val==Default) allRecordUpdsAndCreate + then (Default, funName) + else if any (\val -> isInfixOf val (showSDocUnsafe $ ppr x) ) enumList && (Just enumType) == (lastMaybe (splitOn " " $ replace "->" "" $ showSDocUnsafe $ ppr (lastMaybe allValsTypes))) + then (Default, funName) + else (NoChange,[]) +getFunctionNameIfFailure allPaths checkerCase recordType enumList enumType fieldType moduleName' (L _ x@(AbsBinds{abs_binds = binds})) = do + let allValsTypes = mapMaybe getExprType (x ^? biplateRef) + let allVals = (map unLoc (bagToList binds ^? biplateRef :: [LHsExpr GhcTc])) + let allLetPats = HM.fromList $ ((mapMaybe processAllLetPats (bagToList binds ^? biplateRef :: [LHsBindLR GhcTc GhcTc]))) + processedPats <- mapM (\funInfo -> do + if any (\val -> val `elem` enumList) (name <$> funInfo) then pure True + else do + allCHecks <- liftIO $ mapM (checkInOtherModsWithoutError allPaths checkerCase moduleName') funInfo + pure $ any (==True) allCHecks) allLetPats + let allBinds = concat $ mapMaybe loopOverFunBind (bagToList binds ^? biplateRef :: [LHsBindLR GhcTc GhcTc]) + funName = concatMap getFunctionName $ bagToList binds + (allRecordUpdsAndCrea) <- mapM (getDataTypeDetails recordType enumList enumType fieldType processedPats allBinds) $ allVals + let allRecordUpdsAndCreate = catMaybes allRecordUpdsAndCrea + pure $ if any (\val -> val==CreateWithFailure) allRecordUpdsAndCreate + then (CreateWithFailure, funName) + else if any (\val -> val==UpdateWithFailure) allRecordUpdsAndCreate + then (UpdateWithFailure, funName) + else if any (\val -> val==Create) allRecordUpdsAndCreate + then (Create, funName) + else if any (\val -> val==Update) allRecordUpdsAndCreate + then (Update, funName) + else if any (\val -> val==Default) allRecordUpdsAndCreate + then (Default, funName) + else if any (\val -> isInfixOf val (showSDocUnsafe $ ppr x) ) enumList && (Just enumType) == (lastMaybe (splitOn " " $ replace "->" "" $ showSDocUnsafe $ ppr (lastMaybe allValsTypes))) + then (Default, funName) + else (NoChange,[]) +-- pure $ if any (\(x,y) -> y==True) allRecordUpdsAndCreate then (Update, funName) +-- else if any (\(x,y) -> x==True) allRecordUpdsAndCreate then (Create, funName) +-- else if any (\val -> isInfixOf val (showSDocUnsafe $ ppr x)) enumList && any (\val -> isInfixOf enumType val ) allValsTypes then (Default, funName) else (NoChange,[]) +getFunctionNameIfFailure _ _ _ _ _ _hasFld _ _ = pure $ (NoChange,[]) + +loopOverLHsBindLR :: [String] -> CheckerConfig -> String -> LHsBindLR GhcTc GhcTc -> TcM ((Maybe UpdateInfo), (HM.HashMap String [FunctionInfo])) +loopOverLHsBindLR allPaths checkerCase moduleName' x@(L _ AbsBinds {abs_binds = binds1}) = do + case checkerCase of + FieldsCheck (EnumCheck{..}) -> do + let binds = ( bagToList binds1 ^? biplateRef) + -- let allValsToCheck = ((bagToList binds1 ^? biplateRef :: [LHsExpr GhcTc])) + -- liftIO (print (showSDocUnsafe $ ppr binds1, "and", showSDocUnsafe $ ppr binds)) + let allVals = binds -- ((binds ^? biplateRef :: [LHsExpr GhcTc])) + let allValsTypes = mapMaybe getExprTypeAsType allVals + isF = any (\val -> isInfixOf val (showSDocUnsafe $ ppr x) ) enumList && (Just enumType) == (lastMaybe (splitOn " " $ replace "->" "" $ showSDocUnsafe $ ppr (lastMaybe allValsTypes))) + let allLetPats = HM.fromList $ ((mapMaybe processAllLetPats (bagToList binds1 ^? biplateRef :: [LHsBindLR GhcTc GhcTc]))) + processedPats <- mapM (\(funInfo :: [FunctionInfo]) -> + if any (\val -> val `elem` enumList) (name <$> funInfo) then pure True + else do + allCHecks <- liftIO $ mapM (checkInOtherModsWithoutError allPaths checkerCase moduleName') funInfo + pure $ any (==True) allCHecks) allLetPats + allBinds <- liftIO $ concat <$> catMaybes <$> mapM loopOverFunBindM (bagToList binds1 ^? biplateRef :: [LHsBindLR GhcTc GhcTc]) + let filteredAllVals = filter processHsCase allVals + let funName = map (\y -> transformFromNameStableString y (showSDocUnsafe $ ppr $ getLoc $ x) isF ) (getFunctionName x) + if length filteredAllVals > 0 then do + let fname = name <$> funName + allRecordUpdsAndCrea <- mapM (\val -> getAllNeededFunOuter recordType enumList enumType fieldType isF x allBinds processedPats val ) filteredAllVals + let allNrFuns = concat $ snd <$> allRecordUpdsAndCrea + let allRecordUpdsAndCreate = concat $ fst <$> allRecordUpdsAndCrea + -- liftIO $ print ("FInal", allRecordUpdsAndCreate) + pure $ (if any (\val -> val==CreateWithFailure) allRecordUpdsAndCreate + then Just $ UpdateInfo [] [] [] funName [] [] + else if any (\val -> val==UpdateWithFailure) allRecordUpdsAndCreate + then Just $ UpdateInfo [] [] funName [] [] [] + else if any (\val -> val==Create) allRecordUpdsAndCreate + then Just $ UpdateInfo funName [] [] [] [] [] + else if any (\val -> val==Update) allRecordUpdsAndCreate + then Just $ UpdateInfo [] funName [] [] [] [] + else if any (\val -> val==Default) allRecordUpdsAndCreate + then Just $ UpdateInfo [] [] [] [] funName [] + else if isF + then Just $ UpdateInfo [] [] [] [] funName [] + else Nothing, foldl (\acc val -> HM.insert (val) (nub allNrFuns) acc) HM.empty fname) + + else do + let allNrFuns = nub $ ((concatMap processExpr allVals)) + allRecordUpdsAndCrea <- mapM (getDataTypeDetails recordType enumList enumType fieldType processedPats allBinds) $ allVals ^? biplateRef + let fname = name <$> funName + let allRecordUpdsAndCreate = catMaybes allRecordUpdsAndCrea + pure $ (if any (\val -> val==CreateWithFailure) allRecordUpdsAndCreate + then Just $ UpdateInfo [] [] [] funName [] [] + else if any (\val -> val==UpdateWithFailure) allRecordUpdsAndCreate + then Just $ UpdateInfo [] [] funName [] [] [] + else if any (\val -> val==Create) allRecordUpdsAndCreate + then Just $ UpdateInfo funName [] [] [] [] [] + else if any (\val -> val==Update) allRecordUpdsAndCreate + then Just $ UpdateInfo [] funName [] [] [] [] + else if any (\val -> val==Default) allRecordUpdsAndCreate + then Just $ UpdateInfo [] [] [] [] funName [] + else if isF + then Just $ UpdateInfo [] [] [] [] funName [] + else Nothing, foldl (\acc val -> HM.insert (val) (nub allNrFuns) acc) HM.empty fname) + FunctionCheck (FunctionCheckConfig{..}) -> do + let binds = ( (bagToList binds1) ^? biplateRef :: [LHsExpr GhcTc]) + let allNrFuns = nub $ ((concatMap processExpr binds)) + let funName = map (\y -> transformFromNameStableString y (showSDocUnsafe $ ppr $ getLoc $ x) False ) (getFunctionName x) + let fname = name <$> funName + first <- liftIO $ (ifM (anyM (\val@(FunctionInfo _ _ y _ _) -> do + let fc = ((y `elem` listOfRestrictedFuns)) + nc <- checkInOtherModsWithoutError allPaths checkerCase moduleName' val + pure $ fc || nc) (allNrFuns)) + (pure $ Just $ UpdateInfo [] [] [] [] [] funName) + (pure Nothing) + ) + pure (first, foldl (\acc val -> HM.insert (val) (nub allNrFuns) acc) HM.empty fname) + + +-- liftIO $ print (allLetPats, showSDocUnsafe $ ppr binds1) +loopOverLHsBindLR _ _ _ _ = pure (Nothing, HM.empty) + +processHsCase :: LHsExpr GhcTc -> Bool +processHsCase (L _ (HsCase _ _ _)) = True +processHsCase _ = False + +getAllNeededFunOuter :: String -> [String] -> String -> String -> Bool -> LHsBindLR GhcTc GhcTc -> [String] -> HM.HashMap String Bool -> LHsExpr GhcTc -> TcM ([TypeOfUpdate], [FunctionInfo]) +getAllNeededFunOuter recordType enumList enumType fieldType isF x allBinds processedPats (L _ (HsCase _ _ exprLStmt)) = do + allRecordUpdsAndCrea <- mapM (getAllNeededFun recordType enumList enumType fieldType isF x allBinds processedPats) $ map unLoc $ unLoc $ mg_alts exprLStmt + -- liftIO $ print (concat allRecordUpdsAndCrea) + pure $ (concat $ fst <$> allRecordUpdsAndCrea, concat $ snd <$> allRecordUpdsAndCrea) +getAllNeededFunOuter _ _ _ _ _ _ _ _ _ = pure ([], []) + + -- pure $ catMaybes $ allRecordUpdsAndCrea + +getAllNeededFun :: String -> [String] -> String -> String -> Bool -> LHsBindLR GhcTc GhcTc -> [String] -> HM.HashMap String Bool -> Match GhcTc (LHsExpr GhcTc) -> TcM ([TypeOfUpdate], [FunctionInfo]) +getAllNeededFun recordType enumList enumType fieldType _ _ allBinds processedPats match = do + let normalBinds = (\(GRHS _ _ stmt )-> stmt ) <$> unLoc <$> (grhssGRHSs $ m_grhss match) + argBinds = m_pats match + a = any isVarPat argBinds + checker = any (\val -> isVarPatExprBool val) (normalBinds ^? biplateRef :: [LHsExpr GhcTc] ) + if checker then pure ([], []) else + if a then do + let allNrFuns = nub $ concatMap processExpr (normalBinds ^? biplateRef :: [LHsExpr GhcTc] ) + allRecordUpdsAndCrea <- mapM (getDataTypeDetails recordType enumList enumType fieldType processedPats allBinds) $ normalBinds ^? biplateRef + -- liftIO $ print (checker, a, showSDocUnsafe $ ppr argBinds, showSDocUnsafe $ ppr match, catMaybes allRecordUpdsAndCrea) + pure $ (catMaybes allRecordUpdsAndCrea, allNrFuns) + else pure ([], []) +-- liftIO $ print ("All updates ", allRecordUpdsAndCreate) +-- liftIO $ print (name, showSDocUnsafe <$> ppr <$> allValsTypes, any (\val -> isInfixOf val (showSDocUnsafe $ ppr x) ) enumList , map (\val -> (lastMaybe (splitOn " " $ replace "->" "" $ showSDocUnsafe $ ppr val))) allValsTypes) + +loopAndColect :: HM.HashMap String [FunctionInfo] -> LHsBindLR GhcTc GhcTc -> IO ((HM.HashMap String [FunctionInfo])) +loopAndColect allFunsList x@(L _ AbsBinds {abs_binds = binds}) = do + let fname = map name $ map (\y -> transformFromNameStableString (y) (showSDocUnsafe $ ppr $ getLoc $ x) False) $ (getFunctionName x) + allVals = ((bagToList binds ^? biplateRef :: [LHsExpr GhcTc])) +-- liftIO $ print (fname, showSDocUnsafe $ ppr x) +-- let allBinds = concat $ mapMaybe loopOverFunBind (bagToList binds ^? biplateRef :: [LHsBindLR GhcTc GhcTc]) + allNrFuns = nub $ ((concatMap processExpr allVals)) + pure $ foldl (\acc val -> HM.insert (val) (nub allNrFuns) acc) allFunsList fname +loopAndColect allFunsList _ = pure allFunsList + +loopOverFunBindM :: LHsBindLR GhcTc GhcTc -> (IO (Maybe [String])) +#if __GLASGOW_HASKELL__ < 900 +loopOverFunBindM (L _ (FunBind _ _ matches _ _)) = do +#else +loopOverFunBindM (L _ (FunBind _ _ matches _)) = do +#endif + let inte = unLoc $ mg_alts matches + -- print ("iam here", showSDocUnsafe $ ppr x, length inte) + if null inte then pure Nothing else do + y <- mapM loopOverVarPatM $ m_pats $ unLoc $ head inte + pure $ Just $ catMaybes y +loopOverFunBindM (L _ _) = do + -- print ("iam not here " ++ show (toConstr x) ++ (showSDocUnsafe $ ppr x)) + pure Nothing + +processAllLetPatsM :: LHsBindLR GhcTc GhcTc -> (Maybe (String, [FunctionInfo])) +#if __GLASGOW_HASKELL__ < 900 +processAllLetPatsM (L _ (FunBind _ name matches _ _)) = do +#else +processAllLetPatsM (L _ (FunBind _ name matches _)) = do +#endif + let inte = unLoc $ mg_alts matches + if null inte then Nothing + else Just (nameStableString $ varName $ unLoc name, concat $ map (\(GRHS _ _ val) -> processExpr val) $ map unLoc $ grhssGRHSs $ m_grhss $ unLoc $ head $ inte ) +processAllLetPatsM (L _ _) = do + Nothing + + +loopOverFunBind :: LHsBindLR GhcTc GhcTc -> (Maybe [String]) +#if __GLASGOW_HASKELL__ < 900 +loopOverFunBind (L _ (FunBind _ _ matches _ _)) = do +#else +loopOverFunBind (L _ (FunBind _ _ matches _)) = do +#endif + let inte = unLoc $ mg_alts matches + if null inte then Nothing else do + let y = mapMaybe loopOverVarPat $ m_pats $ unLoc $ head inte + Just y +loopOverFunBind (L _ _) = do + Nothing + +loopOverVarPat :: LPat GhcTc -> Maybe String +loopOverVarPat (L _ (VarPat _ (L _ name))) = Just $ nameStableString $ varName name +loopOverVarPat (L _ _) = Nothing + +loopOverVarPatM :: LPat GhcTc -> IO (Maybe String) +loopOverVarPatM (L _ (VarPat _ (L _ name))) = pure $ Just $ nameStableString $ varName name +loopOverVarPatM (L _ _) = do + -- print ("loopOverVarPatM", show $ toConstr x) + pure Nothing + +mkGhcCompileError :: CompileError -> (SrcSpan, OP.SDoc) +mkGhcCompileError err = (src_span err, OP.text $ err_msg err) + +transformFromNameStableString :: (String) -> String -> Bool -> FunctionInfo +transformFromNameStableString ( str) loc isF = + let parts = filter (\x -> x /= "") $ splitOn ("$") str + in if length parts == 2 then FunctionInfo "" (parts !! 0) (parts !! 1) loc isF + else if length parts == 3 then FunctionInfo (parts !! 0) (parts !! 1) (parts !! 2) loc isF + else FunctionInfo "" "" "" loc isF + +parseYAMLFile :: (FromJSON a) => FilePath -> IO (Either ParseException a) +parseYAMLFile file = decodeFileEither file \ No newline at end of file diff --git a/dc/src/DC/Types.hs b/dc/src/DC/Types.hs new file mode 100644 index 0000000..923817e --- /dev/null +++ b/dc/src/DC/Types.hs @@ -0,0 +1,134 @@ +{-# LANGUAGE DerivingStrategies, CPP, RecordWildCards, DeriveGeneric, DeriveAnyClass #-} + +module DC.Types where + +import qualified Data.HashMap.Strict as HM +import Data.Aeson +import GHC hiding (typeKind) +import GHC.Generics (Generic) +#if __GLASGOW_HASKELL__ >= 900 +import GHC.Data.FastString +#else +import GhcPlugins hiding ((<>)) +#endif + +data PluginOpts = PluginOpts { + failOnFileNotFound :: Bool, + domainConfigFile :: String, + pathsTobeChecked :: [String] + } deriving (Show, Eq) + +defaultPluginOpts :: PluginOpts +defaultPluginOpts = + PluginOpts { + failOnFileNotFound = True, + domainConfigFile = ".juspay/domainConfig.yaml", + pathsTobeChecked = ["euler-x/src","euler-x/src-generated","euler-x/src-extras","euler-api-decider/src", "ecPrelude/src", "ecPrelude/src-generated","ecPrelude/src-extras", "oltp/src", "oltp/src-generated","oltp/src-extras", "dbTypes/src-generated", "src/"] + } + +instance FromJSON PluginOpts where + parseJSON = withObject "PluginOpts" $ \o -> do + failOnFileNotFound <- o .:? "failOnFileNotFound" .!= (failOnFileNotFound defaultPluginOpts) + domainConfigFile <- o .:? "domainConfigFile" .!= (domainConfigFile defaultPluginOpts) + pathsTobeChecked <- o .:? "pathsTobeChecked" .!= (pathsTobeChecked defaultPluginOpts) + return PluginOpts {domainConfigFile = domainConfigFile, failOnFileNotFound = failOnFileNotFound, pathsTobeChecked = pathsTobeChecked } + +data EnumCheck = + EnumCheck + { enumList :: [String] + , enumType :: String + , recordType :: String + , fieldType :: String + } + deriving (Generic, Show, Eq, Ord) + deriving (ToJSON, FromJSON) + +data FunctionCheckConfig = + FunctionCheckConfig + { listOfRestrictedFuns :: [String] + , moduleNameToCheck :: String + , funNameToCheck :: String + } + deriving (Generic, Show, Eq, Ord) + deriving (ToJSON, FromJSON) + +data CheckerConfig = FieldsCheck EnumCheck | FunctionCheck FunctionCheckConfig + deriving (Generic, Show, Eq, Ord) + deriving (ToJSON, FromJSON) + +data ErrorCase = Errors [CompileError] | Functions (HM.HashMap String [CompileError]) + deriving (Show, Eq) + +data FunctionInfo = FunctionInfo + { package_name :: String + , module_name :: String + , name :: String + , src_loc :: String + , isFailure :: Bool + } deriving (Show, Eq, Ord) + +data UpdateInfo = UpdateInfo + { createdRecordsFun :: [FunctionInfo] + , updatedRecordsFun :: [FunctionInfo] + , updatedFailurs :: [FunctionInfo] + , createdFailurs :: [FunctionInfo] + , allFailures :: [FunctionInfo] + , commonError :: [FunctionInfo] + } deriving (Show, Eq, Ord) + +data UpdateInfoAsText = UpdateInfoAsText + { createdRecords :: [String] + , updatedRecords :: [String] + , updatedFailures :: [String] + , createdFailures :: [String] + , allFailuresRecords :: [String] + , commonErrorFuns :: [String] + } deriving (Generic, Show, Eq, Ord) + deriving (ToJSON, FromJSON) + +data TypeOfUpdate = Update | Create | UpdateWithFailure | CreateWithFailure | Default | NoChange + deriving (Show, Eq, Ord) + +data CompileError = CompileError + { + pkg_name :: String, + mod_name :: String, + err_msg :: String, + src_span :: SrcSpan + } deriving (Eq, Show, Generic) + +instance ToJSON CompileError where +#if __GLASGOW_HASKELL__ < 900 + toJSON (CompileError pkg modName errMsg (RealSrcSpan srcLoc)) = +#else + toJSON (CompileError pkg modName errMsg (RealSrcSpan srcLoc _)) = +#endif + object [ "package_name" .= pkg + , "module_name" .= modName + , "error_message" .= errMsg + , "src_span_name" .= unpackFS (srcSpanFile srcLoc) + , "src_span_startl" .= (srcSpanStartLine srcLoc) + , "src_span_endl" .= (srcSpanEndLine srcLoc) + , "src_span_startC" .= (srcSpanStartCol srcLoc) + , "src_span_endC" .= (srcSpanEndCol srcLoc) + ] + toJSON (CompileError pkg modName errMsg _) = + object [ "package_name" .= pkg + , "module_name" .= modName + , "error_message" .= errMsg + ] + + +instance FromJSON CompileError where + parseJSON = withObject "CompileError" $ \o -> do + pkg_name <- o .: "package_name" + mod_name <- o .: "module_name" + err_msg <- o .: "error_message" + src_span_name <- o .: "src_span_name" + src_span_startl <- o .: "src_span_startl" + src_span_endl <- o .: "src_span_endl" + src_span_startC <- o .: "src_span_startC" + src_span_endC <- o .: "src_span_endC" + src_span <- pure $ mkSrcSpan (mkSrcLoc (mkFastString src_span_name) src_span_startl src_span_startC) (mkSrcLoc (mkFastString src_span_name) src_span_endl src_span_endC) + return CompileError { .. } + diff --git a/dc/test/Main.hs b/dc/test/Main.hs new file mode 100644 index 0000000..3e2059e --- /dev/null +++ b/dc/test/Main.hs @@ -0,0 +1,4 @@ +module Main (main) where + +main :: IO () +main = putStrLn "Test suite not yet implemented." diff --git a/fdep/fdep.cabal b/fdep/fdep.cabal index 4c7f867..1bb8f1f 100644 --- a/fdep/fdep.cabal +++ b/fdep/fdep.cabal @@ -12,7 +12,7 @@ build-type: Simple extra-doc-files: CHANGELOG.md common common-options - build-depends: base ^>=4.14.3.0 + build-depends: base ghc-options: -Wall -Wincomplete-uni-patterns -Wincomplete-record-updates @@ -39,17 +39,26 @@ library bytestring , containers , filepath - , ghc ^>= 8.10.7 + , ghc , ghc-exactprint , unordered-containers - , uniplate >= 1.6 && < 1.7 + , uniplate , references , classyplate , aeson , directory , extra , aeson-pretty - , streamly + , streamly-core + , async + , time + , text + , binary + , conduit + , deepseq + , websockets + , network + , primitive hs-source-dirs: src default-language: Haskell2010 @@ -68,4 +77,5 @@ test-suite fdep-test main-is: Main.hs build-depends: fdep + , text ghc-options: -fplugin=Fdep.Plugin -fplugin-opt Fdep.Plugin:./tmp/fdep/ diff --git a/fdep/fdep_merge.py b/fdep/fdep_merge.py new file mode 100644 index 0000000..27304c6 --- /dev/null +++ b/fdep/fdep_merge.py @@ -0,0 +1,118 @@ +import json +import os +import concurrent.futures +import sys + +base_dir_path = "/tmp/fdep/" +fdep = dict() +def replace_all(text, replacements): + for old, new in replacements: + text = text.replace(old, new) + return text + +def get_module_name(base_dir_path, path, to_replace=""): + path = path.replace(base_dir_path, "") + patterns = [ + ("src/", "src/"), + ("src-generated/", "src-generated/"), + ("src-extras/", "src-extras/") + ] + for pattern, split_pattern in patterns: + if pattern in path: + path = path.split(split_pattern)[-1] + break + module_name = replace_all(path, [("/", "."), (to_replace, "")]) + return module_name + +def list_files_recursive(directory): + files_list = [] + for root, dirs, files in os.walk(directory): + for file in files: + if ".hs.json" in file: + files_list.append(os.path.join(root, file)) + return files_list + +def process_each_fdep_module(obj,code_string_dict): + local_fdep = {} + + def update_nested_key(d,keys, value): + current = d + try: + for key in keys[:-1]: + if current != None: + if current.get(key) == None: + current[key] = {} + current[key]["where_functions"] = {} + else: + if current[key].get("where_functions") == None: + current[key]["where_functions"] = {} + current = current[key] + else: + current = {} + current["where_functions"] = {} + current["where_functions"][keys[-1]] = value + except Exception as e: + print("update_nested_key",e) + + for (functionsName,functionData) in obj.items(): + if not "::" in functionsName: + fName = functionsName.replace("$_in$","") + srcLoc = functionsName.replace("$_in$","").split("**")[1] + try: + if local_fdep != None and local_fdep.get(fName) == None: + local_fdep[fName] = {} + local_fdep[fName]["function_name"] = fName + local_fdep[fName]["src_loc"] = srcLoc + if code_string_dict.get(fName) != None: + local_fdep[fName]["stringified_code"] = code_string_dict.get(fName,{}).get("parser_stringified_code","") + for i in functionData: + if i != None and i.get("typeSignature") != None: + local_fdep[fName]["function_signature"] = i.get("typeSignature") + elif i != None and i.get("expr") != None: + if local_fdep[fName].get("functions_called") == None: + local_fdep[fName]["functions_called"] = [] + local_fdep[fName]["functions_called"].append(i.get("expr")) + else: + local_fdep[fName]["functions_called"] = {} + except Exception as e: + exc_type, exc_obj, exc_tb = sys.exc_info() + fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1] + print(e,fname, exc_tb.tb_lineno) + else: + parentFunctions = functionsName.replace("$_in$","").split("::") + (currentFunctionName,currentFunctionSrcLocation) = parentFunctions[(len(parentFunctions) - 1)].split("**") + currentFunctionDict = dict() + for i in functionData: + if i != None and i.get("typeSignature") != None: + currentFunctionDict["function_signature"] = i.get("typeSignature") + currentFunctionDict["src_log"] = currentFunctionSrcLocation + currentFunctionDict["function_name"] = currentFunctionName + elif i != None and i.get("expr") != None: + if currentFunctionDict.get("functions_called") == None: + currentFunctionDict["functions_called"] = [] + currentFunctionDict["functions_called"].append(i.get("expr")) + update_nested_key(local_fdep,parentFunctions,currentFunctionDict) + return local_fdep + +def process_fdep_output(file): + code_string_dict = dict() + if os.path.exists(file.replace(".hs.json",".hs.function_code.json")): + with open(file.replace(".hs.json",".hs.function_code.json")) as code_string: + code_string_dict = json.load(code_string) + with open(file,'r') as f: + return process_each_fdep_module(json.load(f),code_string_dict) + +files = list_files_recursive("./euler-api-order/tmp/") + +with concurrent.futures.ThreadPoolExecutor() as executor: + future_to_file = {executor.submit(process_fdep_output, file): file for file in files} + for future in concurrent.futures.as_completed(future_to_file): + file = future_to_file[future] + try: + module_name = get_module_name(base_dir_path,file,".hs.json") + fdep[module_name] = future.result() + except Exception as e: + print(f"Error reading {file}: {e}") + +with open("data.json","w") as f: + json.dump(fdep,f,indent=4) \ No newline at end of file diff --git a/fdep/src/Fdep/Group.hs b/fdep/src/Fdep/Group.hs index 9b9abe6..9d7e725 100644 --- a/fdep/src/Fdep/Group.hs +++ b/fdep/src/Fdep/Group.hs @@ -6,30 +6,42 @@ import System.Directory import System.FilePath import Control.Monad import qualified Data.Map as Map -import qualified Data.Set as Set -import Data.Maybe (fromMaybe) -import Data.List import Data.Aeson.Encode.Pretty (encodePretty) -import Data.List.Extra (replace,splitOn) -import System.Environment (lookupEnv) import Fdep.Types +import qualified Data.HashMap.Strict as HM +import Data.Text (Text) +import qualified Data.Text as T - -processDumpFile :: String -> FilePath -> IO (String,Map.Map String Function) -processDumpFile baseDirPath path = do - let module_name = replace ".hs.json" "" - $ replace "/" "." - $ if (("src/")) `isInfixOf` (path) - then last (splitOn ("src/") (replace baseDirPath "" path)) - else if (("src-generated/")) `isInfixOf` (path) - then last (splitOn ("src-generated/") (replace baseDirPath "" path)) - else if (("src-extras/")) `isInfixOf` (path) - then last (splitOn ("src-extras/") (replace baseDirPath "" path)) - else replace baseDirPath "" path - putStrLn module_name - content <- B.readFile path - let d = Map.fromList $ filter (\x -> fst x /= "") $ map (\x -> (function_name x,x)) $ fromMaybe [] (Aeson.decode content :: Maybe [Function]) +processDumpFile :: Text -> Text -> Text -> IO (Text,Map.Map Text Function) +processDumpFile toReplace baseDirPath path = do + let module_name = T.replace toReplace "" + $ T.replace "/" "." + $ if (("src/")) `T.isInfixOf` (path) + then last (T.splitOn ("src/") (T.replace baseDirPath "" path)) + else if (("src-generated/")) `T.isInfixOf` (path) + then last (T.splitOn ("src-generated/") (T.replace baseDirPath "" path)) + else if (("src-extras/")) `T.isInfixOf` (path) + then last (T.splitOn ("src-extras/") (T.replace baseDirPath "" path)) + else T.replace baseDirPath "" path + parserCodeExists <- doesFileExist (T.unpack $ T.replace ".json" ".function_code.json" path) + contentHM <- if parserCodeExists + then do + parsercode <- B.readFile $ T.unpack $ T.replace ".json" ".function_code.json" path + case Aeson.decode parsercode of + (Just (x :: HM.HashMap Text PFunction)) -> pure $ x + Nothing -> pure $ HM.empty + else pure HM.empty + content <- B.readFile $ T.unpack path + decodedContent <- case Aeson.decode content of + (Just (x :: HM.HashMap Text Function)) -> pure $ x + Nothing -> pure $ HM.empty + let d = Map.fromList $ filter (\x -> fst x /= "") $ map (\(name,x) -> (name,updateCodeString (name) x contentHM)) $ HM.toList $ decodedContent pure (module_name, d) + where + updateCodeString functionName functionObject contentHM = + case HM.lookup functionName contentHM of + Just val -> functionObject {stringified_code = (parser_stringified_code val)} + Nothing -> functionObject run :: Maybe String -> IO () run bPath = do @@ -38,8 +50,8 @@ run bPath = do Just val -> val _ -> "/tmp/fdep/" files <- getDirectoryContentsRecursive baseDirPath - let jsonFiles = filter (\x -> (".hs.json" `isSuffixOf`) $ x) files - functionGraphs <- mapM (processDumpFile baseDirPath) jsonFiles + let jsonFiles = map T.pack $ filter (\x -> (".hs.json" `T.isSuffixOf`) $ T.pack x) files + (functionGraphs) <- mapM (processDumpFile ".hs.json" (T.pack baseDirPath)) jsonFiles B.writeFile (baseDirPath <> "data.json") (encodePretty (Map.fromList functionGraphs)) getDirectoryContentsRecursive :: FilePath -> IO [FilePath] diff --git a/fdep/src/Fdep/Plugin.hs b/fdep/src/Fdep/Plugin.hs index b26a3c9..dff78f8 100644 --- a/fdep/src/Fdep/Plugin.hs +++ b/fdep/src/Fdep/Plugin.hs @@ -1,108 +1,97 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE CPP #-} -module Fdep.Plugin (plugin) where +module Fdep.Plugin (plugin,collectDecls) where -import Annotations -import Avail -import Bag (bagToList, listToBag) -import BasicTypes (FractionalLit (..), IntegralLit (..)) -import Control.Concurrent -import Control.Exception (SomeException, try) -import Control.Monad (foldM, when) +import Control.Concurrent ( forkIO ) +import Control.DeepSeq (force) +import Control.Exception (SomeException, evaluate, try) +import Control.Monad (void, when) import Control.Monad.IO.Class (MonadIO (..)) import Control.Reference (biplateRef, (^?)) -import Data.Aeson +import Data.Aeson ( encode, Value(String, Object), ToJSON(toJSON) ) +import qualified Data.Aeson as A import Data.Aeson.Encode.Pretty (encodePretty) import Data.Bool (bool) -import Data.ByteString.Lazy (writeFile) +import Data.ByteString.Lazy (toStrict, writeFile) +import qualified Data.ByteString.Lazy as BL import Data.Data (toConstr) import Data.Generics.Uniplate.Data () -import Data.List -import Data.List (nub) -import Data.List.Extra (replace, splitOn) +import Data.List.Extra (splitOn) import qualified Data.Map as Map -import Data.Maybe (catMaybes, fromJust, fromMaybe, isJust, mapMaybe) -import DynFlags () +import Data.Maybe (fromJust, fromMaybe, isJust) +import Data.Text (Text) +import qualified Data.Text as T +import Data.Text.Encoding (decodeUtf8, encodeUtf8) +import Data.Time ( diffUTCTime, getCurrentTime ) import Fdep.Types -import GHC ( - GRHS (..), - FieldOcc(..), - rdrNameAmbiguousFieldOcc, - GRHSs (..), - GenLocated (L), - GhcPass, - GhcTc, - HsBindLR (..), - HsConDetails (..), - HsConPatDetails, - HsExpr (..), - HsRecField' (..), - HsRecFields (..), - HsValBinds (..), - Id (..), - IdP (..), - LGRHS, - LHsCmd (..), - LHsExpr, - LHsRecField, - LHsRecUpdField (..), - LMatch, - LPat, - Match (m_grhss), - MatchGroup (..), - Module (moduleName), - Name, - OutputableBndrId, - OverLitVal (..), - Pat (..), - PatSynBind (..), - StmtLR (..), - TyCon, - getName, - moduleNameString, - nameSrcSpan, - noLoc, - ol_val, - ol_witness, - ) + ( PFunction(PFunction), + FunctionInfo(FunctionInfo) ) +import Text.Read (readMaybe) +import Prelude hiding (id, mapM, mapM_, writeFile) +import qualified Prelude as P +import qualified Data.List.Extra as Data.List +import Network.Socket (withSocketsDo) +import qualified Network.WebSockets as WS +import System.Directory ( createDirectoryIfMissing ) +import System.Environment (lookupEnv) +import GHC.IO (unsafePerformIO) +#if __GLASGOW_HASKELL__ >= 900 +import Streamly.Internal.Data.Stream (fromList,mapM_,mapM,toList) +import GHC +import GHC.Driver.Plugins (Plugin(..),CommandLineOption,defaultPlugin,PluginRecompile(..)) +import GHC.Driver.Env +import GHC.Tc.Types +import GHC.Unit.Module.ModSummary +import GHC.Utils.Outputable (showSDocUnsafe,ppr) +import GHC.Data.Bag (bagToList) +import GHC.Types.Name hiding (varName) +import GHC.Types.Var +import qualified Data.Aeson.KeyMap as HM +import GHC.Hs.Expr +#else +import Streamly.Internal.Data.Stream (fromList, mapM, mapM_, toList) +import qualified Data.HashMap.Strict as HM +import Bag (bagToList) +import DynFlags () +import GHC import GHC.Hs.Binds -import GHC.Hs.Expr (unboundVarOcc) -import GHC.Hs.Utils as GHCHs -import GhcPlugins (RdrName(..),rdrNameOcc,Plugin (pluginRecompile), PluginRecompile (..), Var (..), binderArgFlag, binderType, binderVars, elemNameSet, getOccString, idName, idType, nameSetElemsStable, ppr, pprPrefixName, pprPrefixOcc, showSDocUnsafe, tidyOpenType, tyConBinders, unLoc, unpackFS) -import HscTypes (ModSummary (..), typeEnvIds) -import Name (nameStableString,occName,occNameString,occNameSpace,occNameFS,pprNameSpaceBrief) + ( HsBindLR(PatBind, FunBind, AbsBinds, VarBind, PatSynBind, + XHsBindsLR, fun_id, abs_binds, var_rhs), + LHsBindLR, + HsValBindsLR(XValBindsLR, ValBinds), + HsLocalBindsLR(HsValBinds), + NHsValBindsLR(NValBinds), + PatSynBind(XPatSynBind, PSB, psb_def) ) +import GHC.Hs.Decls + ( HsDecl(SigD, TyClD, InstD, DerivD, ValD), LHsDecl ) +import GhcPlugins (HsParsedModule, Hsc, Plugin (..), PluginRecompile (..), Var (..), getOccString, hpm_module, ppr, showSDocUnsafe) +import HscTypes (ModSummary (..),msHsFilePath) +import Name (nameStableString) import Outputable () -import PatSyn -import Plugins (CommandLineOption, Plugin (typeCheckResultAction), defaultPlugin) -import SrcLoc -import Streamly -import Streamly.Prelude (drain, fromList, mapM, mapM_, toList) -import System.Directory -import System.Directory (createDirectoryIfMissing, getHomeDirectory) -import TcEnv +import Plugins (CommandLineOption, defaultPlugin) +import SrcLoc ( GenLocated(L), getLoc, noLoc, unLoc ) import TcRnTypes (TcGblEnv (..), TcM) -import TyCoPpr (pprSigmaType, pprTypeApp, pprUserForAll) -import TyCon -import Prelude hiding (id, mapM, mapM_, writeFile) +import StringBuffer +#endif plugin :: Plugin plugin = defaultPlugin { typeCheckResultAction = fDep - , pluginRecompile = purePlugin + , pluginRecompile = (\_ -> return NoForceRecompile) + , parsedResultAction = collectDecls } -purePlugin :: [CommandLineOption] -> IO PluginRecompile -purePlugin _ = return NoForceRecompile - +filterList :: [Text] filterList = [ "show" , "showsPrec" , "from" , "to" + , "showList" , "toConstr" , "toDomResAcc" , "toEncoding" @@ -114,6 +103,7 @@ filterList = , "toJSON" , "toJSONList" , "toJSONWithOptions" + , "encodeJSON" , "gfoldl" , "ghmParser" , "gmapM" @@ -154,388 +144,441 @@ filterList = , "fromXml" ] -fDep :: [CommandLineOption] -> ModSummary -> TcGblEnv -> TcM TcGblEnv -fDep opts modSummary tcEnv = do - liftIO $ +collectDecls :: [CommandLineOption] -> ModSummary -> HsParsedModule -> Hsc HsParsedModule +collectDecls opts modSummary hsParsedModule = do + _ <- liftIO $ forkIO $ do let prefixPath = case opts of [] -> "/tmp/fdep/" local : _ -> local - moduleName' = moduleNameString $ moduleName $ ms_mod modSummary - modulePath = prefixPath <> ms_hspp_file modSummary - when True $ do - let path = (intercalate "/" . reverse . tail . reverse . splitOn "/") modulePath - print ("generating dependancy for module: " <> moduleName' <> " at path: " <> path) - let binds = bagToList $ tcg_binds tcEnv - depsMapList <- toList $ parallely $ mapM loopOverLHsBindLR $ fromList $ binds - functionVsUpdates <- getAllTypeManipulations binds - createDirectoryIfMissing True path - writeFile ((modulePath) <> ".typeUpdates.json") $ (encodePretty $ functionVsUpdates) - writeFile ((modulePath) <> ".json") (encodePretty $ concat depsMapList) - writeFile ((modulePath) <> ".missing.signatures.json") $ - encodePretty $ - Map.fromList $ - map (\element -> (\(x, y) -> (x, typeSignature y)) $ filterForMaxLenTypSig element) $ - groupBy (\a b -> (srcSpan a) == (srcSpan b)) $ - dumpMissingTypeSignatures tcEnv - print ("generated dependancy for module: " <> moduleName' <> " at path: " <> path) - return tcEnv - where - filterForMaxLenTypSig :: [MissingTopLevelBindsSignature] -> (String, MissingTopLevelBindsSignature) - filterForMaxLenTypSig x = - case x of - [el] -> (srcSpan $ el, el) - [el1, el2] -> (srcSpan el1, bool (el2) (el1) ((length $ typeSignature $ el1) > (length $ typeSignature $ el2))) - (xx : xs) -> (\(y, yy) -> (srcSpan xx, bool (yy) (xx) ((length $ typeSignature $ xx) > (length $ typeSignature $ yy)))) $ filterForMaxLenTypSig xs + modulePath = prefixPath <> msHsFilePath modSummary + path = (Data.List.intercalate "/" . reverse . tail . reverse . splitOn "/") modulePath + declsList = hsmodDecls $ unLoc $ hpm_module hsParsedModule + createDirectoryIfMissing True path + functionsVsCodeString <- toList $ mapM getDecls $ fromList declsList + writeFile (modulePath <> ".function_code.json") (encodePretty $ Map.fromList $ concat functionsVsCodeString) + pure hsParsedModule -getAllTypeManipulations :: [LHsBindLR GhcTc GhcTc] -> IO [DataTypeUC] -getAllTypeManipulations binds = do - bindWiseUpdates <- - toList $ - parallely $ - mapM - ( \x -> do - let functionName = getFunctionName x - filterRecordUpdateAndCon = filter (\x -> ((show $ toConstr x) `elem` ["RecordCon", "RecordUpd"])) (x ^? biplateRef :: [HsExpr GhcTc]) - pure $ bool (Nothing) (Just (DataTypeUC functionName (mapMaybe getDataTypeDetails filterRecordUpdateAndCon))) (length filterRecordUpdateAndCon > 0) - ) - (fromList binds) - pure $ catMaybes bindWiseUpdates +getDecls :: LHsDecl GhcPs -> IO [(Text, PFunction)] +getDecls x = do + case x of + (L _ (TyClD _ _)) -> pure mempty + (L _ (InstD _ _)) -> pure mempty + (L _ (DerivD _ _)) -> pure mempty + (L _ (ValD _ bind)) -> pure $ getFunBind bind + (L _ (SigD _ _)) -> pure mempty + _ -> pure mempty where - getDataTypeDetails :: HsExpr GhcTc -> Maybe TypeVsFields - getDataTypeDetails (RecordCon _ (L _ (iD)) rcon_flds) = Just (TypeVsFields (nameStableString $ getName $ idName iD) (extractRecordBinds (rcon_flds))) - getDataTypeDetails (RecordUpd _ rupd_expr rupd_flds) = Just (TypeVsFields (showSDocUnsafe $ ppr rupd_expr) (getFieldUpdates rupd_flds)) - - -- inferFieldType :: Name -> String - inferFieldTypeFieldOcc (L _ (FieldOcc _ (L _ rdrName))) = handleRdrName rdrName - inferFieldTypeAFieldOcc = (handleRdrName . rdrNameAmbiguousFieldOcc . unLoc) + getFunBind f@FunBind{fun_id = funId} = [((T.pack $ showSDocUnsafe $ ppr $ unLoc funId) <> "**" <> (T.pack $ getLoc' funId), PFunction ((T.pack $ showSDocUnsafe $ ppr $ unLoc funId) <> "**" <> (T.pack $ getLoc' funId)) (T.pack $ showSDocUnsafe $ ppr f) (T.pack $ getLoc' funId))] + getFunBind _ = mempty - handleRdrName :: RdrName -> String - handleRdrName x = - case x of - Unqual occName -> ("$" <> (showSDocUnsafe $ pprNameSpaceBrief $ occNameSpace occName) <> "$" <> (occNameString occName) <> "$" <> (unpackFS $ occNameFS occName)) - Qual moduleName occName -> ((moduleNameString moduleName) <> "$" <> (showSDocUnsafe $ pprNameSpaceBrief $ occNameSpace occName) <> "$" <> (occNameString occName) <> "$" <> (unpackFS $ occNameFS occName)) - Orig module' occName -> ((moduleNameString $ moduleName module') <> "$" <> (showSDocUnsafe $ pprNameSpaceBrief $ occNameSpace occName) <> "$" <> (occNameString occName) <> "$" <> (unpackFS $ occNameFS occName)) - Exact name -> nameStableString name - - getFieldUpdates :: [LHsRecUpdField GhcTc] -> [FieldRep] - getFieldUpdates fields = map extractField fields - where - extractField :: LHsRecUpdField GhcTc -> FieldRep - extractField (L _ (HsRecField{hsRecFieldLbl = lbl, hsRecFieldArg = expr, hsRecPun = pun})) = - if pun - then (FieldRep (showSDocUnsafe $ ppr lbl) (showSDocUnsafe $ ppr lbl) (inferFieldTypeAFieldOcc lbl)) - else (FieldRep (showSDocUnsafe $ ppr lbl) (showSDocUnsafe $ ppr (unLoc expr)) (inferFieldTypeAFieldOcc lbl)) - - extractRecordBinds :: HsRecFields GhcTc (LHsExpr GhcTc) -> [FieldRep] - extractRecordBinds (HsRecFields{rec_flds = fields}) = - map extractField fields - where - extractField :: LHsRecField GhcTc (LHsExpr GhcTc) -> FieldRep - extractField (L _ (HsRecField{hsRecFieldLbl = lbl, hsRecFieldArg = expr, hsRecPun = pun})) = - if pun - then (FieldRep (showSDocUnsafe $ ppr lbl) (showSDocUnsafe $ ppr lbl) (inferFieldTypeFieldOcc lbl)) - else (FieldRep (showSDocUnsafe $ ppr lbl) (showSDocUnsafe $ ppr $ unLoc expr) (inferFieldTypeFieldOcc lbl)) +shouldForkPerFile :: Bool +shouldForkPerFile = readBool $ unsafePerformIO $ lookupEnv "SHOULD_FORK" + where + readBool :: (Maybe String) -> Bool + readBool (Just "true") = True + readBool (Just "True") = True + readBool (Just "TRUE") = True + readBool (Just "False") = False + readBool (Just "false") = False + readBool (Just "FALSE") = False + readBool _ = True - getFunctionName :: LHsBindLR GhcTc GhcTc -> [String] - getFunctionName (L _ x@(FunBind fun_ext id matches _ _)) = [nameStableString $ getName id] - getFunctionName (L _ (VarBind{var_id = var, var_rhs = expr, var_inline = inline})) = [nameStableString $ getName var] - getFunctionName (L _ (PatBind{pat_lhs = pat, pat_rhs = expr})) = [""] - getFunctionName (L _ (AbsBinds{abs_binds = binds})) = concatMap getFunctionName $ bagToList binds +shouldGenerateFdep :: Bool +shouldGenerateFdep = readBool $ unsafePerformIO $ lookupEnv "GENERATE_FDEP" + where + readBool :: (Maybe String) -> Bool + readBool (Just "true") = True + readBool (Just "True") = True + readBool (Just "TRUE") = True + readBool (Just "False") = False + readBool (Just "false") = False + readBool (Just "FALSE") = False + readBool _ = True -dumpMissingTypeSignatures :: TcGblEnv -> [MissingTopLevelBindsSignature] -dumpMissingTypeSignatures gbl_env = - let binds = (collectHsBindsBinders $ tcg_binds $ gbl_env) - whereBinds = concatMap (\x -> ((concatMap collectHsBindsBinders $ processHsLocalBindsForWhereFunctions $ unLoc $ processMatchForWhereFunctions x))) ((bagToList $ tcg_binds $ gbl_env) ^? biplateRef :: [LMatch GhcTc (LHsExpr GhcTc)]) - in nub $ mapMaybe add_bind_warn (binds <> whereBinds) +shouldLog :: Bool +shouldLog = readBool $ unsafePerformIO $ lookupEnv "ENABLE_LOGS" where - add_bind_warn :: Id -> Maybe MissingTopLevelBindsSignature - add_bind_warn id = - let name = idName id - ty = (idType id) - ty_msg = pprSigmaType ty - in add_warn (showSDocUnsafe $ ppr $ nameSrcSpan $ getName name) (showSDocUnsafe $ pprPrefixName name) (showSDocUnsafe $ ppr $ ty_msg) + readBool :: (Maybe String) -> Bool + readBool (Just "true") = True + readBool (Just "True") = True + readBool (Just "TRUE") = True + readBool _ = False + +websocketPort :: Int +websocketPort = maybe 8000 (fromMaybe 8000 . readMaybe) $ unsafePerformIO $ lookupEnv "SERVER_PORT" - add_warn "" msg ty_msg = Nothing - add_warn "" msg ty_msg = Nothing - add_warn _ msg "*" = Nothing - add_warn _ msg "* -> *" = Nothing - add_warn _ msg ('_' : xs) = Nothing - add_warn name msg ty_msg = - if "$" `isPrefixOf` msg - then Nothing - else Just $ MissingTopLevelBindsSignature{srcSpan = (name), typeSignature = (msg <> " :: " <> ty_msg)} +websocketHost :: String +websocketHost = fromMaybe "localhost" $ unsafePerformIO $ lookupEnv "SERVER_HOST" - processMatchForWhereFunctions :: LMatch GhcTc (LHsExpr GhcTc) -> LHsLocalBinds GhcTc - processMatchForWhereFunctions (L _ match) = (grhssLocalBinds (m_grhss match)) +decodeBlacklistedFunctions :: IO [Text] +decodeBlacklistedFunctions = do + mBlackListedFunctions <- lookupEnv "BLACKLIST_FUNCTIONS_FDEP" + pure $ case mBlackListedFunctions of + Just val' -> + case A.decode $ BL.fromStrict $ encodeUtf8 (T.pack val') of + Just val -> filterList <> val + _ -> filterList + _ -> filterList - processHsLocalBindsForWhereFunctions :: HsLocalBindsLR GhcTc GhcTc -> [LHsBindsLR GhcTc GhcTc] - processHsLocalBindsForWhereFunctions (HsValBinds _ (ValBinds _ x _)) = [x] - processHsLocalBindsForWhereFunctions (HsValBinds _ (XValBindsLR (NValBinds x _))) = map (\(_, binds) -> binds) $ x - processHsLocalBindsForWhereFunctions x = [] +fDep :: [CommandLineOption] -> ModSummary -> TcGblEnv -> TcM TcGblEnv +fDep opts modSummary tcEnv = do + when (shouldGenerateFdep) $ + liftIO $ + bool P.id (void . forkIO) shouldForkPerFile $ do + let prefixPath = case opts of + [] -> "/tmp/fdep/" + local : _ -> local + moduleName' = moduleNameString $ moduleName $ ms_mod modSummary + modulePath = prefixPath <> msHsFilePath modSummary + let path = (Data.List.intercalate "/" . reverse . tail . reverse . splitOn "/") modulePath + when shouldLog $ print ("generating dependancy for module: " <> moduleName' <> " at path: " <> path) + createDirectoryIfMissing True path + let binds = bagToList $ tcg_binds tcEnv + t1 <- getCurrentTime + withSocketsDo $ do + eres <- try $ WS.runClient websocketHost websocketPort ("/" <> modulePath <> ".json") (\conn -> do mapM_ (loopOverLHsBindLR (Just conn) Nothing (T.pack ("/" <> modulePath <> ".json"))) (fromList binds)) + case eres of + Left (err :: SomeException) -> when shouldLog $ print err + Right _ -> pure () + t2 <- getCurrentTime + when shouldLog $ print ("generated dependancy for module: " <> moduleName' <> " at path: " <> path <> " total-timetaken: " <> show (diffUTCTime t2 t1)) + return tcEnv -transformFromNameStableString :: (Maybe String, Maybe String, Maybe String, [String]) -> Maybe FunctionInfo +transformFromNameStableString :: (Maybe Text, Maybe Text, Maybe Text, [Text]) -> Maybe FunctionInfo transformFromNameStableString (Just str, Just loc, _type, args) = - let parts = filter (\x -> x /= "") $ splitOn ("$") str + let parts = filter (\x -> x /= "") $ T.splitOn ("$") str in Just $ if length parts == 2 then FunctionInfo "" (parts !! 0) (parts !! 1) (fromMaybe "" _type) loc args else FunctionInfo (parts !! 0) (parts !! 1) (parts !! 2) (fromMaybe "" _type) loc args transformFromNameStableString (Just str, Nothing, _type, args) = - let parts = filter (\x -> x /= "") $ splitOn ("$") str + let parts = filter (\x -> x /= "") $ T.splitOn ("$") str in Just $ if length parts == 2 then FunctionInfo "" (parts !! 0) (parts !! 1) (fromMaybe "" _type) "" args else FunctionInfo (parts !! 0) (parts !! 1) (parts !! 2) (fromMaybe "" _type) "" args -filterFunctionInfos :: [Maybe FunctionInfo] -> IO [Maybe FunctionInfo] -filterFunctionInfos infos = do - let grouped = groupBy (\info1 info2 -> src_Loc info1 == src_Loc info2 && name info1 == name info2) $ catMaybes infos - pure $ - map (Just) $ - concat $ - map - ( \group -> - if length group == 1 - then group - else concat $ map (\x -> if (null $ arguments x) then [] else [x]) group - ) - $ grouped +sendTextData' :: WS.Connection -> Text -> Text -> IO () +sendTextData' conn path data_ = do + t1 <- getCurrentTime + res <- try $ WS.sendTextData conn data_ + case res of + Left (err :: SomeException) -> do + when (shouldLog) $ print err + withSocketsDo $ WS.runClient websocketHost websocketPort (T.unpack path) (\nconn -> WS.sendTextData nconn data_) + Right _ -> pure () + t2 <- getCurrentTime + when (shouldLog) $ print ("websocket call timetaken: " <> (T.pack $ show $ diffUTCTime t2 t1)) -loopOverLHsBindLR :: LHsBindLR GhcTc GhcTc -> IO [Function] -loopOverLHsBindLR (L _ x@(FunBind fun_ext id matches _ _)) = do - let funName = getOccString $ unLoc id - matchList = mg_alts matches - fName = nameStableString $ getName id - if ((funName) `elem` filterList || (("$_in$$" `isPrefixOf` fName) && (not $ "$_in$$sel:" `isPrefixOf` fName ))) - then pure [] +loopOverLHsBindLR :: Maybe WS.Connection -> (Maybe Text) -> Text -> LHsBindLR GhcTc GhcTc -> IO () +#if __GLASGOW_HASKELL__ >= 900 +loopOverLHsBindLR mConn mParentName path (L _ x@(FunBind fun_ext id matches _)) = do +#else +loopOverLHsBindLR mConn mParentName path (L _ x@(FunBind fun_ext id matches _ _)) = do +#endif + funName <- evaluate $ force $ T.pack $ getOccString $ unLoc id + fName <- evaluate $ force $ T.pack $ nameStableString $ getName id + let matchList = mg_alts matches + if funName `elem` (unsafePerformIO $ decodeBlacklistedFunctions) || ("$_in$$" `T.isPrefixOf` fName) + then pure mempty else do - (list, funcs) <- - foldM - ( \(x, y) xx -> do - (l, f) <- processMatch xx - pure $ (x <> l, y <> f) - ) - ([], []) - (unLoc matchList) - listTransformed <- filterFunctionInfos $ map transformFromNameStableString list - pure [(Function funName listTransformed (nub funcs) (showSDocUnsafe $ ppr $ getLoc id) (showSDocUnsafe $ ppr x) (showSDocUnsafe $ ppr $ varType $ unLoc id))] -loopOverLHsBindLR x@(L _ VarBind{var_rhs = rhs}) = do - pure [(Function "" (map transformFromNameStableString $ processExpr [] rhs) [] "" (showSDocUnsafe $ ppr x) "")] -loopOverLHsBindLR x@(L _ AbsBinds{abs_binds = binds}) = do - list <- toList $ parallely $ mapM loopOverLHsBindLR $ fromList $ bagToList binds - pure (concat list) -loopOverLHsBindLR x@(L _ (PatSynBind _ PSB{psb_def = def})) = do - let list = map transformFromNameStableString $ map (\(n, srcLoc) -> (Just $ nameStableString n, srcLoc,Nothing, [])) $ processPat def - pure [(Function "" list [] "" (showSDocUnsafe $ ppr x) "")] -loopOverLHsBindLR (L _ (PatSynBind _ (XPatSynBind _))) = do - pure [] -loopOverLHsBindLR (L _ (XHsBindsLR _)) = do - pure [] -loopOverLHsBindLR x@(L _ (PatBind _ _ pat_rhs _)) = do - r <- toList $ parallely $ mapM processGRHS $ fromList $ grhssGRHSs pat_rhs - let l = map transformFromNameStableString $ concat r - pure [(Function "" l [] "" (showSDocUnsafe $ ppr x) "")] + when (shouldLog) $ print ("processing function: " <> fName) +#if __GLASGOW_HASKELL__ >= 900 + name <- evaluate $ force (fName <> "**" <> (T.pack (getLoc' id))) +#else + name <- evaluate $ force (fName <> "**" <> (T.pack ((showSDocUnsafe . ppr . getLoc) id))) +#endif + typeSignature <- evaluate $ force $ (T.pack $ showSDocUnsafe (ppr (varType (unLoc id)))) + nestedNameWithParent <- evaluate $ force $ (maybe (name) (\x -> x <> "::" <> name) mParentName) + data_ <- evaluate $ force $ (decodeUtf8 $ toStrict $ Data.Aeson.encode $ Object $ HM.fromList [("key", String nestedNameWithParent), ("typeSignature", String typeSignature)]) + t1 <- getCurrentTime + if isJust mConn + then do + sendTextData' (fromJust mConn) path data_ + mapM_ (processMatch (fromJust mConn) nestedNameWithParent path) (fromList $ unLoc matchList) + else + withSocketsDo $ + WS.runClient websocketHost websocketPort (T.unpack path) ( \conn -> do + sendTextData' (conn) path data_ + mapM_ (processMatch conn (nestedNameWithParent) path) (fromList $ unLoc matchList) + ) + t2 <- getCurrentTime + when (shouldLog) $ print $ "processed function: " <> fName <> " timetaken: " <> (T.pack $ show $ diffUTCTime t2 t1) +loopOverLHsBindLR mConn mParentName path (L _ AbsBinds{abs_binds = binds}) = + mapM_ (loopOverLHsBindLR mConn mParentName path) $ fromList $ bagToList binds +loopOverLHsBindLR _ _ _ (L _ VarBind{var_rhs = rhs}) = pure mempty +loopOverLHsBindLR _ _ _ (L _ (PatSynBind _ PSB{psb_def = def})) = pure mempty +loopOverLHsBindLR _ _ _ (L _ (PatSynBind _ (XPatSynBind _))) = pure mempty +loopOverLHsBindLR _ _ _ (L _ (XHsBindsLR _)) = pure mempty +loopOverLHsBindLR _ _ _ (L _ (PatBind _ _ pat_rhs _)) = pure mempty --- checkIfCreateOrUpdtingDataTypes binds = mapM_ (go) (fromList $ bagToList binds) --- where --- go (L _ (FunBind fun_ext id matches _ _)) = do +processMatch :: WS.Connection -> Text -> Text -> LMatch GhcTc (LHsExpr GhcTc) -> IO () +processMatch con keyFunction path (L _ match) = do +#if __GLASGOW_HASKELL__ >= 900 + whereClause <- (evaluate . force) =<< (processHsLocalBinds con keyFunction path $ grhssLocalBinds (m_grhss match)) +#else + whereClause <- (evaluate . force) =<< (processHsLocalBinds con keyFunction path $ unLoc $ grhssLocalBinds (m_grhss match)) +#endif + mapM_ (processGRHS con keyFunction path) $ fromList $ grhssGRHSs (m_grhss match) + pure mempty --- pure () --- go (L _ VarBind{var_rhs = rhs}) = pure () --- go (L _ AbsBinds{abs_binds = binds}) = pure () --- go (L _ (PatSynBind _ PSB{psb_def = def})) = pure () --- go (L _ (PatSynBind _ (XPatSynBind _))) = pure () +processGRHS :: WS.Connection -> Text -> Text -> LGRHS GhcTc (LHsExpr GhcTc) -> IO () +processGRHS con keyFunction path (L _ (GRHS _ _ body)) = processExpr con keyFunction path body +processGRHS _ _ _ _ = pure mempty -processMatch :: LMatch GhcTc (LHsExpr GhcTc) -> IO ([(Maybe String, Maybe String, Maybe String, [String])], [Function]) -processMatch (L _ match) = do - -- let stmts = match ^? biplateRef :: [StmtLR GhcTc GhcTc (LHsExpr GhcTc)] - -- mapM (print . showSDocUnsafe . ppr) stmts - -- let stmtsCMD = match ^? biplateRef :: [StmtLR GhcTc GhcTc (LHsCmd GhcTc)] - -- mapM (print . showSDocUnsafe . ppr) stmtsCMD - whereClause <- processHsLocalBinds $ unLoc $ grhssLocalBinds (m_grhss match) - -- let names = map (\x -> (Just (nameStableString x), Just $ showSDocUnsafe $ ppr $ getLoc $ x, mempty)) $ (match ^? biplateRef :: [Name]) - r <- toList $ parallely $ (mapM processGRHS (fromList $ grhssGRHSs (m_grhss match))) - pure $ (concat r, whereClause) +processHsLocalBinds :: WS.Connection -> Text -> Text -> HsLocalBindsLR GhcTc GhcTc -> IO () +processHsLocalBinds con keyFunction path (HsValBinds _ (ValBinds _ x y)) = do + mapM_ (loopOverLHsBindLR (Just con) (Just keyFunction) path) $ fromList $ bagToList $ x +processHsLocalBinds con keyFunction path (HsValBinds _ (XValBindsLR (NValBinds x y))) = do + mapM_ (\(recFlag, binds) -> mapM_ (loopOverLHsBindLR (Just con) (Just keyFunction) path) $ fromList $ bagToList binds) (fromList x) +processHsLocalBinds _ _ _ _ = pure mempty -processGRHS :: LGRHS GhcTc (LHsExpr GhcTc) -> IO [(Maybe String, Maybe String, Maybe String, [String])] -processGRHS (L _ (GRHS _ _ body)) = do - pure $ processExpr [] body -processGRHS _ = pure $ [] +grhsExpr :: LGRHS GhcTc (LHsExpr GhcTc) -> LHsExpr GhcTc +grhsExpr (L _ (GRHS _ _ body)) = body -processHsLocalBinds :: HsLocalBindsLR GhcTc GhcTc -> IO [Function] -processHsLocalBinds (HsValBinds _ (ValBinds _ x y)) = do - res <- toList $ parallely $ mapM loopOverLHsBindLR $ fromList $ bagToList $ x - pure $ concat res -processHsLocalBinds (HsValBinds _ (XValBindsLR (NValBinds x y))) = do - res <- - foldM - ( \acc (recFlag, binds) -> do - funcs <- toList $ parallely $ mapM loopOverLHsBindLR $ fromList $ bagToList binds - pure (acc <> funcs) - ) - [] - x - pure $ concat res -processHsLocalBinds x = - pure [] +hsStmtsExpr :: WS.Connection -> Text -> Text -> [LStmt GhcTc (LHsExpr GhcTc)] -> IO () +hsStmtsExpr con keyFunction path stmts = mapM_ (stmtExpr con keyFunction path) $ fromList stmts -processArgs (funr) = case funr of - (HsUnboundVar _ uv) -> [showSDocUnsafe $ pprPrefixOcc (unboundVarOcc uv)] - (HsConLikeOut _ c) -> [showSDocUnsafe $ pprPrefixOcc c] - (HsIPVar _ v) -> [showSDocUnsafe $ ppr v] - (HsOverLabel _ _ l) -> [showSDocUnsafe $ ppr l] - (HsLit _ lit) -> [showSDocUnsafe $ ppr lit] - (HsOverLit _ lit) -> [showSDocUnsafe $ ppr lit] - (HsPar _ e) -> [showSDocUnsafe $ ppr e] - (HsApp _ funl funr) -> processArgs (unLoc funr) <> processArgs (unLoc funl) - _ -> [] +stmtExpr :: WS.Connection -> Text -> Text -> LStmt GhcTc (LHsExpr GhcTc) -> IO () +stmtExpr con keyFunction path (L _ stmt) = case stmt of +#if __GLASGOW_HASKELL__ >= 900 + BindStmt _ pat expr -> processExpr con keyFunction path expr +#else + BindStmt _ pat expr _ _-> processExpr con keyFunction path expr +#endif + BodyStmt _ expr _ _ -> processExpr con keyFunction path expr + LastStmt _ expr _ _ -> processExpr con keyFunction path expr + ParStmt _ stmtBlocks _ _ -> mapM_ blockExprs (fromList stmtBlocks) + TransStmt{..} -> do + hsStmtsExpr con keyFunction path $ trS_stmts + processExpr con keyFunction path trS_using + maybe (pure ()) (processExpr con keyFunction path) trS_by + ApplicativeStmt _ args _ -> mapM_ (extractApplicativeArg con keyFunction path . snd) (fromList args) +#if __GLASGOW_HASKELL__ >= 900 + RecStmt{..} -> mapM_ (stmtExpr con keyFunction path) (fromList $ unXRec @(GhcTc) recS_stmts) + LetStmt _ binds -> processHsLocalBinds con keyFunction path binds +#else + RecStmt{..} -> mapM_ (stmtExpr con keyFunction path) (fromList recS_stmts) + LetStmt _ binds -> processHsLocalBinds con keyFunction path (unLoc binds) +#endif + XStmtLR{} -> pure () + where + blockExprs :: ParStmtBlock GhcTc GhcTc -> IO () + blockExprs (ParStmtBlock _ stmts _ _) = mapM_ (stmtExpr con keyFunction path) (fromList stmts) +#if __GLASGOW_HASKELL__ >= 900 + extractApplicativeArg con keyFunction path (ApplicativeArgOne _ _ arg_expr _) = processExpr con keyFunction path arg_expr + extractApplicativeArg con keyFunction path (ApplicativeArgMany _ app_stmts final_expr _ _) = do + hsStmtsExpr con keyFunction path app_stmts + processExpr con keyFunction path $ wrapXRec @(GhcTc) final_expr +#else + extractApplicativeArg con keyFunction path (ApplicativeArgOne _ _ arg_expr _ _) = processExpr con keyFunction path arg_expr + extractApplicativeArg con keyFunction path (ApplicativeArgMany _ app_stmts final_expr _) = do + hsStmtsExpr con keyFunction path app_stmts + processExpr con keyFunction path (noLoc final_expr) +#endif + extractApplicativeArg _ _ _ _ = pure () -processExpr :: [String] -> LHsExpr GhcTc -> [(Maybe String, Maybe String, Maybe String, [String])] -processExpr arguments x@(L _ (HsVar _ (L _ var))) = - let name = nameStableString $ varName var - _type = showSDocUnsafe $ ppr $ varType var - in [(Just name, Just $ showSDocUnsafe $ ppr $ getLoc $ x, Just _type, arguments)] -processExpr arguments (L _ (HsUnboundVar _ _)) = [] -processExpr arguments (L _ (HsApp _ funl funr)) = - let processedArgs = nub $ processArgs (unLoc funr) <> arguments - l = processExpr processedArgs funl - r = processExpr arguments funr - in l <> r -processExpr arguments (L _ (OpApp _ funl funm funr)) = - let l = processExpr arguments funl - m = processExpr arguments funm - r = processExpr arguments funr - in nub $ l <> m <> r -processExpr arguments (L _ (NegApp _ funl _)) = - processExpr arguments funl -processExpr arguments (L _ (HsTick _ _ fun)) = - processExpr arguments fun -processExpr arguments (L _ (HsStatic _ fun)) = - processExpr arguments fun -processExpr arguments (L _ x@(HsWrap _ _ fun)) = - let r = processArgs x - in processExpr (arguments <> r) (noLoc fun) -processExpr arguments (L _ (HsBinTick _ _ _ fun)) = - processExpr arguments fun -processExpr arguments (L _ (ExplicitList _ _ funList)) = - concatMap (processExpr arguments) funList -processExpr arguments (L _ (HsTickPragma _ _ _ _ fun)) = - processExpr arguments fun -processExpr arguments (L _ (HsSCC _ _ _ fun)) = - processExpr arguments fun -processExpr arguments (L _ (HsCoreAnn _ _ _ fun)) = - processExpr arguments fun -processExpr arguments (L _ (ExprWithTySig _ fun _)) = - processExpr arguments fun -processExpr arguments (L _ (HsDo _ _ exprLStmt)) = - let stmts = exprLStmt ^? biplateRef :: [LHsExpr GhcTc] - in nub $ - concatMap - ( \x -> - let processedArgs = processArgs (unLoc x) - in processExpr (processedArgs) x - ) - stmts -processExpr arguments (L _ (HsLet _ exprLStmt func)) = - let stmts = exprLStmt ^? biplateRef :: [LHsExpr GhcTc] - in processExpr arguments func <> nub (concatMap (processExpr arguments) stmts) -processExpr arguments (L _ (HsMultiIf _ exprLStmt)) = - let stmts = exprLStmt ^? biplateRef :: [LHsExpr GhcTc] - in nub (concatMap (processExpr arguments) stmts) -processExpr arguments (L _ (HsIf _ exprLStmt funl funm funr)) = - let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) - in nub (concatMap (processExpr arguments) $ [funl, funm, funr] <> stmts) -processExpr arguments (L _ (HsCase _ funl exprLStmt)) = - let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) - in nub (concatMap (processExpr arguments) $ [funl] <> stmts) -processExpr arguments (L _ (ExplicitSum _ _ _ fun)) = processExpr arguments fun -processExpr arguments (L _ (SectionR _ funl funr)) = processExpr arguments funl <> processExpr arguments funr -processExpr arguments (L _ (ExplicitTuple _ exprLStmt _)) = - let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) - in nub (concatMap (processExpr arguments) stmts) -processExpr arguments (L _ (HsPar _ fun)) = - let processedArgs = processArgs (unLoc fun) - in processExpr (processedArgs) fun -processExpr arguments (L _ (HsAppType _ fun _)) = processExpr arguments fun -processExpr arguments (L _ x@(HsLamCase _ exprLStmt)) = - let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) - processedArgs = processArgs (x) - res = - nub - ( concatMap - ( \x -> - let processedArgs = processArgs (unLoc x) - in processExpr (processedArgs) x - ) - stmts - ) - in case res of - [(x, y, t, [])] -> [(x, y, t, processedArgs)] - _ -> res -processExpr arguments (L _ x@(HsLam _ exprLStmt)) = +processExpr :: WS.Connection -> Text -> Text -> LHsExpr GhcTc -> IO () +processExpr con keyFunction path x@(L _ (HsVar _ (L _ var))) = do + let name = T.pack $ nameStableString $ varName var + _type = T.pack $ showSDocUnsafe $ ppr $ varType var + expr <- evaluate $ force $ transformFromNameStableString (Just name, Just $ T.pack $ getLocTC' $ x, Just _type, mempty) + sendTextData' con path (decodeUtf8 $ toStrict $ Data.Aeson.encode $ Object $ HM.fromList [("key", String keyFunction), ("expr", toJSON expr)]) +processExpr _ _ _ (L _ (HsUnboundVar _ _)) = pure mempty +processExpr con keyFunction path (L _ (HsApp _ funl funr)) = do + processExpr con keyFunction path funl + processExpr con keyFunction path funr +processExpr con keyFunction path (L _ (OpApp _ funl funm funr)) = do + processExpr con keyFunction path funl + processExpr con keyFunction path funm + processExpr con keyFunction path funr +processExpr con keyFunction path (L _ (NegApp _ funl _)) = + processExpr con keyFunction path funl +processExpr con keyFunction path (L _ (HsTick _ _ fun)) = + processExpr con keyFunction path fun +processExpr con keyFunction path (L _ (HsStatic _ fun)) = + processExpr con keyFunction path fun +processExpr con keyFunction path (L _ (HsBinTick _ _ _ fun)) = + processExpr con keyFunction path fun +#if __GLASGOW_HASKELL__ < 900 +processExpr con keyFunction path (L _ (ExplicitList _ _ funList)) = + mapM_ (processExpr con keyFunction path) (fromList funList) +processExpr con keyFunction path (L _ (HsTickPragma _ _ _ _ fun)) = + processExpr con keyFunction path fun +processExpr con keyFunction path (L _ (HsSCC _ _ _ fun)) = + processExpr con keyFunction path fun +processExpr con keyFunction path (L _ (HsCoreAnn _ _ _ fun)) = + processExpr con keyFunction path fun +processExpr con keyFunction path (L _ x@(HsWrap _ _ fun)) = + processExpr con keyFunction path (noLoc fun) +processExpr con keyFunction path (L _ (HsIf _ exprLStmt funl funm funr)) = let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) - processedArgs = processArgs (x) - res = - nub - ( concatMap - ( \x -> - let processedArgs = processArgs (unLoc x) - in processExpr (processedArgs <> arguments) x - ) - stmts - ) - in case res of - [(x, y, t, [])] -> [(x, y, t, processedArgs)] - _ -> res -processExpr arguments y@(L _ x@(HsLit _ hsLit)) = - [(Just $ ("$_lit$" <> (showSDocUnsafe $ ppr hsLit)), (Just $ showSDocUnsafe $ ppr $ getLoc $ y), (Just $ show $ toConstr hsLit), [])] -processExpr arguments y@(L _ x@(HsOverLit _ overLitVal)) = - [(Just $ ("$_lit$" <> (showSDocUnsafe $ ppr overLitVal)), (Just $ showSDocUnsafe $ ppr $ getLoc $ y), (Just $ show $ toConstr overLitVal), [])] -processExpr arguments (L _ (HsRecFld _ exprLStmt)) = - let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) - in nub (concatMap (processExpr arguments) stmts) -processExpr arguments (L _ (HsSpliceE exprLStmtL exprLStmtR)) = + in mapM_ (processExpr con keyFunction path) $ fromList $ [funl, funm, funr] <> stmts +processExpr con keyFunction path (L _ (HsTcBracketOut b exprLStmtL exprLStmtR)) = let stmtsL = (exprLStmtL ^? biplateRef :: [LHsExpr GhcTc]) stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) - in nub (concatMap (processExpr arguments) (stmtsL <> stmtsR)) -processExpr arguments (L _ (ArithSeq _ (Just exprLStmtL) exprLStmtR)) = + in mapM_ (processExpr con keyFunction path) (fromList $ stmtsL <> stmtsR) +processExpr con keyFunction path (L _ (ArithSeq _ Nothing exprLStmtR)) = + let stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) + stmtsRNoLoc = (exprLStmtR ^? biplateRef :: [HsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList $ stmtsR <> ((map noLoc) $ stmtsRNoLoc)) +processExpr con keyFunction path (L _ (HsRecFld _ exprLStmt)) = + let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) + stmtsNoLoc = (exprLStmt ^? biplateRef :: [HsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList (stmts <> (map noLoc) stmtsNoLoc)) +processExpr con keyFunction path (L _ (HsRnBracketOut _ exprLStmtL exprLStmtR)) = let stmtsL = (exprLStmtL ^? biplateRef :: [LHsExpr GhcTc]) stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) - in nub (concatMap (processExpr arguments) (stmtsL <> stmtsR)) -processExpr arguments (L _ (ArithSeq _ Nothing exprLStmtR)) = - let stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) - in nub (concatMap (processExpr arguments) stmtsR) -processExpr arguments (L _ (HsRnBracketOut _ exprLStmtL exprLStmtR)) = + stmtsLNoLoc = (exprLStmtL ^? biplateRef :: [HsExpr GhcTc]) + stmtsRNoLoc = (exprLStmtR ^? biplateRef :: [HsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList $ stmtsL <> stmtsR <> (map noLoc $ (stmtsLNoLoc <> stmtsRNoLoc))) +processExpr con keyFunction path (L _ x@(RecordCon expr (L _ (iD)) rcon_flds)) = do + let stmts = (rcon_flds ^? biplateRef :: [LHsExpr GhcTc]) + stmtsNoLoc = (rcon_flds ^? biplateRef :: [HsExpr GhcTc]) + stmtsNoLocexpr = (expr ^? biplateRef :: [HsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList $ stmts <> (map noLoc) (stmtsNoLoc <> stmtsNoLocexpr)) +processExpr con keyFunction path (L _ (RecordUpd _ rupd_expr rupd_flds)) = + let stmts = (rupd_flds ^? biplateRef :: [LHsExpr GhcTc]) + stmtsNoLoc = (rupd_flds ^? biplateRef :: [HsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList $ stmts <> (map noLoc) stmtsNoLoc) +#else +processExpr con keyFunction path (L _ (HsGetField _ exprLStmt _)) = + let stmts = exprLStmt ^? biplateRef :: [LHsExpr GhcTc] + stmtsNoLoc = exprLStmt ^? biplateRef :: [HsExpr GhcTc] + in mapM_ (processExpr con keyFunction path) (fromList (stmts <> (map (wrapXRec @(GhcTc)) stmtsNoLoc))) +processExpr con keyFunction path (L _ (ExplicitList _ funList)) = + mapM_ (processExpr con keyFunction path) (fromList funList) +processExpr con keyFunction path (L _ (HsPragE _ _ fun)) = + processExpr con keyFunction path fun +processExpr con keyFunction path (L _ (HsProc _ lPat fun)) = do + let stmts = lPat ^? biplateRef :: [LHsExpr GhcTc] + stmts' = fun ^? biplateRef :: [LHsExpr GhcTc] + mapM_ (processExpr con keyFunction path) (fromList (stmts <> stmts')) +processExpr con keyFunction path (L _ (HsIf exprLStmt funl funm funr)) = + let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) $ fromList $ [funl, funm, funr] <> stmts +processExpr con keyFunction path (L _ (HsTcBracketOut b mQW exprLStmtL exprLStmtR)) = let stmtsL = (exprLStmtL ^? biplateRef :: [LHsExpr GhcTc]) stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) - in nub (concatMap (processExpr arguments) (stmtsL <> stmtsR)) -processExpr arguments (L _ (HsTcBracketOut _ exprLStmtL exprLStmtR)) = + in mapM_ (processExpr con keyFunction path) (fromList $ stmtsL <> stmtsR) +processExpr con keyFunction path (L _ (ArithSeq _ Nothing exprLStmtR)) = + let stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) + stmtsRNoLoc = (exprLStmtR ^? biplateRef :: [HsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList $ stmtsR <> ((map (wrapXRec @(GhcTc)) $ stmtsRNoLoc))) +processExpr con keyFunction path (L _ (HsRecFld _ exprLStmt)) = + let stmts = (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) + stmtsNoLoc = (exprLStmt ^? biplateRef :: [HsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList (stmts <> (map (wrapXRec @(GhcTc)) stmtsNoLoc))) +processExpr con keyFunction path (L _ (HsRnBracketOut _ exprLStmtL exprLStmtR)) = let stmtsL = (exprLStmtL ^? biplateRef :: [LHsExpr GhcTc]) stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) - in nub (concatMap (processExpr arguments) (stmtsL <> stmtsR)) --- HsIPVar (XIPVar p) HsIPName --- HsOverLabel (XOverLabel p) (Maybe (IdP p)) FastString --- HsConLikeOut (XConLikeOut p) ConLike -processExpr arguments (L _ (RecordCon _ (L _ (iD)) rcon_flds)) = + stmtsLNoLoc = (exprLStmtL ^? biplateRef :: [HsExpr GhcTc]) + stmtsRNoLoc = (exprLStmtR ^? biplateRef :: [HsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList $ stmtsL <> stmtsR <> (map (wrapXRec @(GhcTc)) $ (stmtsLNoLoc <> stmtsRNoLoc))) +processExpr con keyFunction path (L _ x@(RecordCon expr (L _ (iD)) rcon_flds)) = do let stmts = (rcon_flds ^? biplateRef :: [LHsExpr GhcTc]) - in nub (concatMap (processExpr arguments) stmts) - -- extractRecordBinds (rcon_flds) --- processExpr arguments (L _ (RecordUpd _ rupd_expr rupd_flds)) = (processExpr arguments rupd_expr) <> concatMap extractLHsRecUpdField rupd_flds -processExpr arguments (L _ (RecordUpd _ rupd_expr rupd_flds)) = + stmtsNoLoc = (rcon_flds ^? biplateRef :: [HsExpr GhcTc]) + stmtsNoLocexpr = (expr ^? biplateRef :: [HsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList $ stmts <> (map (wrapXRec @(GhcTc)) (stmtsNoLoc <> stmtsNoLocexpr))) +processExpr con keyFunction path (L _ (RecordUpd _ rupd_expr rupd_flds)) = let stmts = (rupd_flds ^? biplateRef :: [LHsExpr GhcTc]) - in nub (concatMap (processExpr arguments) stmts) - -- Just (TypeVsFields (showSDocUnsafe $ ppr rupd_expr) (getFieldUpdates rupd_flds)) -processExpr arguments _ = [] - -extractLHsRecUpdField :: GenLocated l (HsRecField' id (LHsExpr GhcTc)) -> [(Maybe String, Maybe String, Maybe String, [String])] -extractLHsRecUpdField (L _ (HsRecField _ fun _)) = processExpr [] fun - -processPat :: LPat GhcTc -> [(Name, Maybe String)] -processPat (L _ pat) = case pat of - ConPatIn _ details -> processDetails details - VarPat _ x@(L _ var) -> [(varName var, Just $ showSDocUnsafe $ ppr $ getLoc $ x)] - ParPat _ pat' -> processPat pat' - _ -> [] + stmtsNoLoc = (rupd_flds ^? biplateRef :: [HsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList $ stmts <> (map (wrapXRec @(GhcTc)) stmtsNoLoc)) +#endif +processExpr con keyFunction path (L _ (ExprWithTySig _ fun _)) = + processExpr con keyFunction path fun +processExpr con keyFunction path (L _ (HsDo _ smtContext exprLStmt)) = + let stmts = (smtContext ^? biplateRef :: [LHsExpr GhcTc]) <> (exprLStmt ^? biplateRef :: [LHsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList stmts) +processExpr con keyFunction path (L _ (HsLet _ exprLStmt func)) = do +#if __GLASGOW_HASKELL__ >= 900 + processHsLocalBinds con keyFunction path exprLStmt +#else + processHsLocalBinds con keyFunction path (unLoc exprLStmt) +#endif + processExpr con keyFunction path func +processExpr con keyFunction path (L _ (HsMultiIf _ exprLStmt)) = + let stmts = exprLStmt ^? biplateRef :: [LHsExpr GhcTc] + in mapM_ (processExpr con keyFunction path) (fromList stmts) +processExpr con keyFunction path (L _ (HsCase _ funl exprLStmt)) = do + processExpr con keyFunction path funl + mapM_ (processMatch con keyFunction path) (fromList $ unLoc $ mg_alts exprLStmt) +processExpr con keyFunction path (L _ (ExplicitSum _ _ _ fun)) = processExpr con keyFunction path fun +processExpr con keyFunction path (L _ (SectionR _ funl funr)) = processExpr con keyFunction path funl <> processExpr con keyFunction path funr +processExpr con keyFunction path (L _ (ExplicitTuple _ exprLStmt _)) = +#if __GLASGOW_HASKELL__ >= 900 + let l = (fromList exprLStmt) +#else + let l = fromList (unLoc <$> exprLStmt) +#endif + in mapM_ (\x -> + case x of + (Present _ exprs) -> processExpr con keyFunction path exprs + _ -> pure ()) l +processExpr con keyFunction path (L _ (HsPar _ fun)) = + processExpr con keyFunction path fun +processExpr con keyFunction path (L _ (HsAppType _ fun _)) = processExpr con keyFunction path fun +processExpr con keyFunction path (L _ x@(HsLamCase _ exprLStmt)) = + mapM_ (processMatch con keyFunction path) (fromList $ unLoc $ mg_alts exprLStmt) +processExpr con keyFunction path (L _ x@(HsLam _ exprLStmt)) = + mapM_ (processMatch con keyFunction path) (fromList $ unLoc $ mg_alts exprLStmt) +processExpr con keyFunction path y@(L _ x@(HsLit _ hsLit)) = do + expr <- evaluate $ force $ transformFromNameStableString (Just $ ("$_lit$" <> (T.pack $ showSDocUnsafe $ ppr hsLit)), (Just $ T.pack $ getLocTC' $ y), (Just $ T.pack $ show $ toConstr hsLit), mempty) + sendTextData' con path (decodeUtf8 $ toStrict $ Data.Aeson.encode $ Object $ HM.fromList [("key", String keyFunction), ("expr", toJSON expr)]) +processExpr con keyFunction path y@(L _ x@(HsOverLit _ overLitVal)) = do + expr <- evaluate $ force $ transformFromNameStableString (Just $ ("$_lit$" <> (T.pack $ showSDocUnsafe $ ppr overLitVal)), (Just $ T.pack $ getLocTC' $ y), (Just $ T.pack $ show $ toConstr overLitVal), mempty) + sendTextData' con path (decodeUtf8 $ toStrict $ Data.Aeson.encode $ Object $ HM.fromList [("key", String keyFunction), ("expr", toJSON expr)]) +processExpr con keyFunction path (L _ (HsSpliceE exprLStmtL exprLStmtR)) = + let stmtsL = (exprLStmtL ^? biplateRef :: [LHsExpr GhcTc]) + stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList $ stmtsL <> stmtsR) +processExpr con keyFunction path (L _ (ArithSeq _ (Just exprLStmtL) exprLStmtR)) = + let stmtsL = (exprLStmtL ^? biplateRef :: [LHsExpr GhcTc]) + stmtsR = (exprLStmtR ^? biplateRef :: [LHsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList $ stmtsL <> stmtsR) +processExpr con keyFunction path y@(L _ x@(HsConLikeOut _ hsType)) = do + expr <- evaluate $ force $ transformFromNameStableString (Just $ ("$_type$" <> (T.pack $ showSDocUnsafe $ ppr hsType)), (Just $ T.pack $ getLocTC' $ y), (Just $ T.pack $ show $ toConstr hsType), mempty) + sendTextData' con path (decodeUtf8 $ toStrict $ Data.Aeson.encode $ Object $ HM.fromList [("key", String keyFunction), ("expr", toJSON expr)]) +processExpr con keyFunction path y@(L _ x@(HsIPVar _ implicit)) = do + expr <- evaluate $ force $ transformFromNameStableString (Just $ ("$_implicit$" <> T.pack (showSDocUnsafe $ ppr implicit)), (Just $ T.pack $ getLocTC' $ y), (Just $ T.pack $ show $ toConstr x), mempty) + sendTextData' con path (decodeUtf8 $ toStrict $ Data.Aeson.encode $ Object $ HM.fromList [("key", String keyFunction), ("expr", toJSON expr)]) +processExpr con keyFunction path (L _ (SectionL _ funl funr)) = processExpr con keyFunction path funl <> processExpr con keyFunction path funr +#if __GLASGOW_HASKELL__ > 900 +processExpr con keyFunction path y@(L _ (XExpr overLitVal)) = do + processXXExpr con keyFunction path overLitVal +processExpr con keyFunction path y@(L _ x@(HsOverLabel _ fs)) = do + expr <- evaluate $ force $ transformFromNameStableString (Just $ ("$_overLabel$" <> (T.pack $ showSDocUnsafe $ ppr fs)), (Just $ T.pack $ getLocTC' $ y), (Just $ T.pack $ show $ toConstr x), mempty) + sendTextData' con path (decodeUtf8 $ toStrict $ Data.Aeson.encode $ Object $ HM.fromList [("key", String keyFunction), ("expr", toJSON expr)]) +processExpr con keyFunction path (L _ x) = + let stmts = (x ^? biplateRef :: [LHsExpr GhcTc]) + stmtsNoLoc = (x ^? biplateRef :: [HsExpr GhcTc]) + -- ids = (x ^? biplateRef :: [LIdP GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList (stmts <> (map (wrapXRec @(GhcTc)) stmtsNoLoc))) +#else +processExpr con keyFunction path y@(L _ (XExpr overLitVal)) = + let stmts = (overLitVal ^? biplateRef :: [LHsExpr GhcTc]) + stmtsNoLoc = (overLitVal ^? biplateRef :: [HsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList (stmts <> (map (noLoc) stmtsNoLoc))) +processExpr con keyFunction path y@(L _ x@(HsOverLabel _ mIdp fs)) = do + print $ showSDocUnsafe $ ppr mIdp + expr <- evaluate $ force $ transformFromNameStableString (Just $ ("$_overLabel$" <> (T.pack $ showSDocUnsafe $ ppr fs)), (Just $ T.pack $ getLocTC' $ y), (Just $ T.pack $ show $ toConstr x), mempty) + sendTextData' con path (decodeUtf8 $ toStrict $ Data.Aeson.encode $ Object $ HM.fromList [("key", String keyFunction), ("expr", toJSON expr)]) +processExpr con keyFunction path (L _ x) = + let stmts = (x ^? biplateRef :: [LHsExpr GhcTc]) + stmtsNoLoc = (x ^? biplateRef :: [HsExpr GhcTc]) + -- ids = (x ^? biplateRef :: [LIdP GhcTc]) + in mapM_ (processExpr con keyFunction path) (fromList (stmts <> (map (noLoc) stmtsNoLoc))) +#endif +-- processExpr _ _ _ (L _ (HsBracket _ _)) = pure mempty +-- processExpr _ _ _ (L _ (HsProjection _ _)) = pure mempty -processDetails :: HsConPatDetails GhcTc -> [(Name, Maybe String)] -processDetails (PrefixCon args) = concatMap processPat args -processDetails (InfixCon arg1 arg2) = processPat arg1 <> processPat arg2 -processDetails (RecCon rec) = concatMap processPatField (rec_flds rec) +#if __GLASGOW_HASKELL__ > 900 +processXXExpr :: WS.Connection -> Text -> Text -> XXExprGhcTc -> IO () +processXXExpr con keyFunction path (WrapExpr (HsWrap hsWrapper hsExpr)) = do + -- print $ (showSDocUnsafe $ ppr $ hsExpr,toConstr $ hsExpr) + processExpr con keyFunction path (wrapXRec @(GhcTc) hsExpr) +processXXExpr con keyFunction path x = + let stmtsL = (x ^? biplateRef :: [HsExpr GhcTc]) + in mapM_ (processExpr con keyFunction path . (wrapXRec @(GhcTc))) (fromList stmtsL) -processPatField :: LHsRecField GhcTc (LPat GhcTc) -> [(Name, Maybe String)] -processPatField (L _ HsRecField{hsRecFieldArg = arg}) = processPat arg +getLocTC' = (showSDocUnsafe . ppr . la2r . getLoc) +getLoc' = (showSDocUnsafe . ppr . la2r . getLoc) +#else +getLocTC' = (showSDocUnsafe . ppr . getLoc) +getLoc' = (showSDocUnsafe . ppr . getLoc) +#endif \ No newline at end of file diff --git a/fdep/src/Fdep/Types.hs b/fdep/src/Fdep/Types.hs index 24c4039..dfa3370 100644 --- a/fdep/src/Fdep/Types.hs +++ b/fdep/src/Fdep/Types.hs @@ -1,112 +1,37 @@ +{-# LANGUAGE DeriveAnyClass #-} module Fdep.Types where import Data.Aeson - -data DataTypeUC = DataTypeUC { - function_name_ :: [String] - , typeVsFields :: [TypeVsFields] - } deriving (Show, Eq, Ord) - -data TypeVsFields = TypeVsFields { - type_name :: String - , fieldsVsExprs :: [(FieldRep)] -} deriving (Show, Eq, Ord) - -data FieldRep = FieldRep { - field_name :: String - , expression :: String - , field_type :: String -} deriving (Show, Eq, Ord) +import GHC.Generics (Generic) +import Data.Text +import Data.Binary +import Control.DeepSeq data FunctionInfo = FunctionInfo - { package_name :: String - , module_name :: String - , name :: String - , _type :: String - , src_Loc :: String - , arguments :: [String] - } deriving (Show, Eq, Ord) + { package_name :: Text + , module_name :: Text + , name :: Text + , _type :: Text + , src_Loc :: Text + , arguments :: [Text] + } deriving (Show, Eq, Ord,Binary,Generic,NFData,ToJSON,FromJSON) data Function = Function - { function_name :: String + { function_name :: Text , functions_called :: [Maybe FunctionInfo] , where_functions :: [Function] - , src_loc :: String - , stringified_code :: String - , function_signature :: String - } deriving (Show, Eq, Ord) + , src_loc :: Text + , stringified_code :: Text + , function_signature :: Text + } deriving (Show, Eq, Ord,Binary,Generic,NFData,ToJSON,FromJSON) data MissingTopLevelBindsSignature = MissingTopLevelBindsSignature { - srcSpan :: String - , typeSignature :: String -} deriving (Show, Eq, Ord) - -instance ToJSON MissingTopLevelBindsSignature where - toJSON (MissingTopLevelBindsSignature srcSpan typeSignature) = - object [ "srcSpan" .= srcSpan - , "typeSignature" .= typeSignature - ] - -instance ToJSON FunctionInfo where - toJSON (FunctionInfo pkg modName funcName _type srcLoc arguments) = - object [ "package_name" .= pkg - , "module_name" .= modName - , "name" .= funcName - , "_type" .= _type - , "src_loc" .= srcLoc - , "arguments" .= arguments - ] - -instance ToJSON Function where - toJSON (Function funcName funcsCalled whereFuncs srcLoc codeStringified function_signature) = - object [ "function_name" .= funcName - , "functions_called" .= funcsCalled - , "where_functions" .= whereFuncs - , "src_loc" .= srcLoc - , "code_string" .= codeStringified - , "function_signature" .= function_signature - ] - -instance FromJSON FunctionInfo where - parseJSON = withObject "FunctionInfo" $ \v -> - FunctionInfo <$> v .: "package_name" - <*> v .: "module_name" - <*> v .: "name" - <*> v .: "_type" - <*> v .: "src_loc" - <*> v .: "arguments" - -instance FromJSON Function where - parseJSON = withObject "Function" $ \v -> - Function <$> v .: "function_name" - <*> v .: "functions_called" - <*> v .: "where_functions" - <*> v .: "src_loc" - <*> v .: "code_string" - <*> v .: "function_signature" - -instance ToJSON DataTypeUC where - toJSON (DataTypeUC fn fields) = - object ["function_name" .= fn, "typeVsFields" .= fields] - -instance FromJSON DataTypeUC where - parseJSON (Object v) = - DataTypeUC <$> v .: "function_name" <*> v .: "typeVsFields" - parseJSON _ = fail "Invalid DataTypeUC JSON" - -instance ToJSON FieldRep where - toJSON (FieldRep field_name expression field_type) = - object ["field_name" .= field_name, "expression" .= expression , "field_type" .= field_type] - -instance FromJSON FieldRep where - parseJSON (Object v) = - FieldRep <$> v .: "field_name" <*> v .: "expression" <*> v .: "field_type" - parseJSON _ = fail "Invalid FieldRep JSON" - -instance ToJSON TypeVsFields where - toJSON (TypeVsFields tn fes) = - object ["type_name" .= tn, "fieldsVsExprs" .= fes] - -instance FromJSON TypeVsFields where - parseJSON (Object v) = - TypeVsFields <$> v .: "type_name" <*> v .: "fieldsVsExprs" - parseJSON _ = fail "Invalid TypeVsFields JSON" \ No newline at end of file + srcSpan :: Text + , typeSignature :: Text +} deriving (Show, Eq, Ord,Binary,Generic,NFData,ToJSON,FromJSON) + +data PFunction = PFunction { + parser_name :: Text + , parser_stringified_code :: Text + , parser_src_loc :: Text +} + deriving (Show, Eq, Ord,Binary,Generic,NFData,ToJSON,FromJSON) \ No newline at end of file diff --git a/fdep/test/Main.hs b/fdep/test/Main.hs index acd7bf5..e39dba4 100644 --- a/fdep/test/Main.hs +++ b/fdep/test/Main.hs @@ -1,20 +1,194 @@ -{-# LANGUAGE RankNTypes #-} -module Main (main,demo,ddd) where +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeOperators #-} +module Main where + +import qualified Data.Text as T +import Data.List (foldl') +import Control.Monad (when) +import Control.Applicative ((<|>)) import qualified Fdep.Plugin () +import GHC.Base (undefined) + +-- Basic data type +data Person = Person + { name :: String + , age :: Int + } deriving (Show, Eq) + +-- Algebraic data type +data Shape + = Circle Double + | Rectangle Double Double + | Triangle Double Double Double + deriving (Show) + +-- Type with type variable +data Maybe' a = Just' a | Nothing' + deriving (Show) + +-- GADT +data Expr a where + LitInt :: Int -> Expr Int + LitBool :: Bool -> Expr Bool + Add :: Expr Int -> Expr Int -> Expr Int + IsZero :: Expr Int -> Expr Bool + +-- Type family +type family Elem a +type instance Elem [a] = a +type instance Elem (Maybe a) = a + +-- Associated type family +class Container c where + type ContainerElem c + empty :: c + insert :: ContainerElem c -> c -> c + +instance Container [a] where + type ContainerElem [a] = a + empty = [] + insert = (:) + +-- Multi-parameter type class with functional dependency +class Monad m => MonadState s m | m -> s where + get :: m s + put :: s -> m () + +-- Data kind and type-level operators +data Nat = Zero | Succ Nat + +type family (a :: Nat) :+ (b :: Nat) :: Nat where + 'Zero :+ b = b + ('Succ a) :+ b = 'Succ (a :+ b) + +-- Existential type +data SomeShape = forall a. Show a => SomeShape a + +-- Basic function +greet :: String -> String +greet name = "Hello, " ++ name ++ "!" + +-- Function with pattern matching +factorial :: Integer -> Integer +factorial 0 = 1 +factorial n = n * factorial (n - 1) + +-- Higher-order function +applyTwice :: (a -> a) -> a -> a +applyTwice f x = f (f x) + +-- Function using let binding +sumSquares :: Num a => a -> a -> a +sumSquares x y = + let squareX = x * x + squareY = y * y + in squareX + squareY + +-- Function using where clause +pythagoras :: Floating a => a -> a -> a +pythagoras a b = sqrt (aSq + bSq) + where + aSq = a * a + bSq = b * b + +-- Function with guards +absoluteValue :: (Num a, Ord a) => a -> a +absoluteValue x + | x < 0 = -x + | otherwise = x + +-- List comprehension +evenSquares :: [Integer] +evenSquares = [x^2 | x <- [1..10], even x] + +-- Do notation (with list monad) +cartesianProduct :: [a] -> [b] -> [(a, b)] +cartesianProduct xs ys = do + x <- xs + y <- ys + return (x, y) + +-- Lambda function +multiplyBy :: Num a => a -> a -> a +multiplyBy = \x y -> x * y + +-- Partial application +add5 :: Num a => a -> a +add5 = (+) 5 + +-- Function composition +lengthOfGreeting :: String -> Int +lengthOfGreeting = length . greet + +-- Folding +sumList :: Num a => [a] -> a +sumList = foldl' (+) 0 + +-- Recursive data structure +data Tree a = Leaf a | Node (Tree a) (Tree a) + deriving (Show) + +-- Function on recursive data structure +treeDepth :: Tree a -> Int +treeDepth (Leaf _) = 0 +treeDepth (Node left right) = 1 + max (treeDepth left) (treeDepth right) + +-- Monadic function +safeDivide :: (MonadFail m) => Int -> Int -> m Int -- Changed Monad m to MonadFail m +safeDivide _ 0 = fail "Division by zero" +safeDivide x y = return (x `div` y) + +data Person' = Person' String Int String + deriving (Show) + +createPerson :: Maybe String -> Maybe Int -> Maybe String -> Maybe Person' +createPerson name age email = Person' <$> name <*> age <*> email + +-- Type class instance +class Sizeable a where + size :: a -> Int + +instance Sizeable [a] where + size = length + +instance Sizeable (Tree a) where + size (Leaf _) = 1 + size (Node left right) = 1 + size left + size right --- main :: IO () -main = do - putStrLn "Test suite not yet implemented." - print ("HI there" :: String) - where - test :: String -> String - test "HI" = "HI" - test2 v = test (v <> ddd) - test10 :: String -> String - test10 _ = demo "1000" "!))" - -demo :: (Show x) => x -> x -> String -demo x y = (show y <> show x <> show 100) - -ddd = "HITHERE" \ No newline at end of file +main :: IO () +main = pure () +-- putStrLn "What's your name?" +-- name <- getLine +-- putStrLn $ greet name +-- let person = Person name 30 +-- print person +-- when (age person > 18) $ +-- putStrLn "You are an adult." +-- let shapes = [Circle 5, Rectangle 3 4, Triangle 3 4 5] +-- mapM_ print shapes +-- let someShapes = [SomeShape (Circle 1), SomeShape (Rectangle 2 3)] +-- mapM_ (\(SomeShape s) -> print s) someShapes + +-- -- Fixed safeDivide calls +-- result1 <- safeDivide 10 2 +-- print result1 +-- result2 <- safeDivide 10 0 +-- print result2 + +-- -- Fixed createPerson call +-- print $ createPerson (Just "John") (Just 25) (Just "john@example.com") + +-- let tree = Node (Node (Leaf 1) (Leaf 2)) (Leaf 3) +-- print $ treeDepth tree +-- print $ size tree +-- print evenSquares +-- print $ cartesianProduct [1,2] ['a','b'] \ No newline at end of file diff --git a/fieldInspector/.juspay/api-contract/test/Main.hs.yaml b/fieldInspector/.juspay/api-contract/test/Main.hs.yaml new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/fieldInspector/.juspay/api-contract/test/Main.hs.yaml @@ -0,0 +1 @@ +{} diff --git a/fieldInspector/README.MD b/fieldInspector/README.MD index 650774a..30e0d9b 100644 --- a/fieldInspector/README.MD +++ b/fieldInspector/README.MD @@ -51,7 +51,7 @@ The plugin generates two JSON files for each module: Contains information about types and their constructors within the module. -## The exe onsolidates these files into two comprehensive datasets: +## The exe consolidates these files into two comprehensive datasets: ### fieldUsage-data.json: consolidated all modules data diff --git a/fieldInspector/fieldInspector.cabal b/fieldInspector/fieldInspector.cabal index 0bfa8fe..c2623a2 100644 --- a/fieldInspector/fieldInspector.cabal +++ b/fieldInspector/fieldInspector.cabal @@ -48,10 +48,14 @@ build-type: Simple -- Extra doc files to be distributed with the package, such as a CHANGELOG or a README. extra-doc-files: CHANGELOG.md --- Extra source files to be distributed with the package, such as examples, or a tutorial module. --- extra-source-files: +flag enable-lr-plugins + description: set this flag to enable these plugins Data.Record.Plugin , Data.Record.Anon.Plugin , Data.Record.Plugin.HasFieldPattern + default: False + manual: True common common-options + if flag(enable-lr-plugins) + cpp-options: -DENABLE_LR_PLUGINS ghc-options: -Wall -Wincomplete-uni-patterns @@ -75,13 +79,14 @@ common common-options bytestring , containers , filepath - , ghc ^>= 8.10.7 + , ghc , unordered-containers , aeson , directory , extra , aeson-pretty - , base ^>=4.14.3.0 + , aeson + , base , text , base64-bytestring , optparse-applicative @@ -91,24 +96,29 @@ common common-options , cryptonite , hasbolt , universum + , streamly-core , data-default - , streamly + , large-records + , large-generics + , large-anon + , ghc-hasfield-plugin + , record-dot-preprocessor + , ghc-tcplugin-api + , typelet + , record-hasfield + , text + , binary + , references + , uniplate + , api-contract + , fdep library -- Import common warning flags. import: common-options -- Modules exported by the library. - exposed-modules: FieldInspector.Plugin,FieldInspector.Group,FieldInspector.Types - - -- Modules included in this library but not exported. - -- other-modules: - - -- LANGUAGE extensions used by modules in this package. - -- other-extensions: - - -- Other library packages from which modules are imported. - build-depends: base + exposed-modules: FieldInspector.PluginFields,FieldInspector.PluginTypes,FieldInspector.Group,FieldInspector.Types -- Directories containing source files. hs-source-dirs: src @@ -143,7 +153,7 @@ executable fieldInspector test-suite fieldInspector-test -- Import common warning flags. import: common-options - ghc-options: -fplugin=FieldInspector.Plugin -fplugin-opt FieldInspector.Plugin:./tmp/fieldInspector/ + ghc-options: -fplugin=ApiContract.Plugin -fplugin=FieldInspector.PluginFields -fplugin-opt FieldInspector.PluginFields:./tmp/fieldInspector/ -fplugin=FieldInspector.PluginTypes -fplugin-opt FieldInspector.PluginTypes:./tmp/fieldInspector/ -- Base language which the package is written in. default-language: Haskell2010 @@ -164,5 +174,13 @@ test-suite fieldInspector-test -- Test dependencies. build-depends: - base, - fieldInspector + base + , fieldInspector + , large-records + , large-generics + , large-anon + , record-dot-preprocessor + , ghc-tcplugin-api + , typelet + , record-hasfield + , scientific diff --git a/fieldInspector/src/FieldInspector/Group.hs b/fieldInspector/src/FieldInspector/Group.hs index 3650ef4..b480da1 100644 --- a/fieldInspector/src/FieldInspector/Group.hs +++ b/fieldInspector/src/FieldInspector/Group.hs @@ -75,7 +75,6 @@ run bPath = do Just val -> val _ -> "/tmp/fieldInspector/" files <- getDirectoryContentsRecursive baseDirPath - let jsonFiles = filter (\x -> (".hs.fieldUsage.json" `isSuffixOf`) $ x) files fieldUsage <- mapM (processDumpFileFieldUsage baseDirPath) jsonFiles B.writeFile (baseDirPath <> "/" <> "fieldUsage-data.json") (encodePretty (Map.fromList fieldUsage)) diff --git a/fieldInspector/src/FieldInspector/Plugin.hs b/fieldInspector/src/FieldInspector/Plugin.hs deleted file mode 100644 index 6f18b5d..0000000 --- a/fieldInspector/src/FieldInspector/Plugin.hs +++ /dev/null @@ -1,373 +0,0 @@ - -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE UndecidableInstances #-} - -module FieldInspector.Plugin (plugin) where - -import Control.Concurrent (MVar, modifyMVar, newMVar) -import CoreMonad (CoreM, CoreToDo (CoreDoPluginPass), liftIO) -import CoreSyn ( - AltCon (..), - Bind (NonRec, Rec), - CoreBind, - CoreExpr, - Expr (..), - mkStringLit - ) -import Data.Aeson -import Data.Aeson.Encode.Pretty (encodePretty) -import qualified Data.ByteString as DBS -import Data.ByteString.Lazy (toStrict) -import Data.Int (Int64) -import Data.List.Extra (intercalate, isSuffixOf, replace, splitOn,groupBy) -import Data.List ( sortBy, intercalate ,foldl') -import qualified Data.Map as Map -import Data.Text (Text, concat, isInfixOf, pack, unpack) -import qualified Data.Text as T -import Data.Text.Encoding (decodeUtf8, encodeUtf8) -import Data.Time -import TyCoRep -import GHC.IO (unsafePerformIO) -import GHC.Hs -import Data.Map (Map) -import Data.Data -import Data.Maybe (catMaybes) -import Control.Monad.IO.Class (liftIO) -import System.IO (writeFile) -import GHC.Hs.Decls -import GhcPlugins ( - CommandLineOption,Arg (..), - HsParsedModule(..), - Hsc, - Name,SDoc,DataCon,DynFlags,ModSummary(..),TyCon, - Literal (..),typeEnvElts, - ModGuts (mg_binds, mg_loc, mg_module),showSDoc, - Module (moduleName),tyConKind, - NamedThing (getName),getDynFlags,tyConDataCons,dataConOrigArgTys,dataConName, - Outputable (..),dataConFieldLabels, - Plugin (..), - Var,flLabel,dataConRepType, - coVarDetails, - defaultPlugin, - idName, - mkInternalName, - mkLitString, - mkLocalVar, - mkVarOcc, - moduleNameString, - nameStableString, - noCafIdInfo, - purePlugin, - showSDocUnsafe, - tyVarKind, - unpackFS, - tyConName - ) -import Id (isExportedId,idType) -import Name (getSrcSpan) -import Control.Monad (forM) -import SrcLoc -import Streamly (parallely, serially) -import Streamly.Prelude hiding (concatMap, init, length, map, splitOn,foldl') -import System.Directory (createDirectoryIfMissing, removeFile) -import System.Directory.Internal.Prelude hiding (mapM, mapM_) -import Unique (mkUnique) -import Var (isLocalId,varType) -import Prelude hiding (id, mapM, mapM_) -import FieldInspector.Types -import TcRnTypes -import TcRnMonad -import DataCon - -plugin :: Plugin -plugin = - defaultPlugin - { installCoreToDos = install - , pluginRecompile = GhcPlugins.purePlugin - , typeCheckResultAction = collectTypesTC - , parsedResultAction = collectTypeInfoParser - } - -install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo] -install args todos = return (CoreDoPluginPass "FieldInspector" (buildCfgPass args) : todos) - -removeIfExists :: FilePath -> IO () -removeIfExists fileName = removeFile fileName `catch` handleExists - where - handleExists e - | isDoesNotExistError e = return () - | otherwise = throwIO e - -buildCfgPass :: [CommandLineOption] -> ModGuts -> CoreM ModGuts -buildCfgPass opts guts = do - let prefixPath = case opts of - [] -> "/tmp/fieldInspector/" - [local] -> local - _ -> error "unexpected no of arguments" - _ <- liftIO $ forkIO $ do - let binds = mg_binds guts - moduleN = moduleNameString $ GhcPlugins.moduleName $ mg_module guts - moduleLoc = prefixPath Prelude.<> getFilePath (mg_loc guts) - createDirectoryIfMissing True ((intercalate "/" . init . splitOn "/") moduleLoc) - removeIfExists (moduleLoc Prelude.<> ".fieldUsage.json") - print ("start generating fieldUsage for module: " <> moduleN <> " at path: " <> moduleLoc, length binds) - t1 <- getCurrentTime - l <- toList $ serially $ mapM (liftIO . toLBind) (fromList binds) - print ("started writing to file fieldUsage for module: " <> moduleN <> " at path: " <> moduleLoc, length l) - DBS.writeFile (moduleLoc Prelude.<> ".fieldUsage.json") $ toStrict $ encodePretty $ Map.fromList $ groupByFunction $ Prelude.concat l - t2 <- getCurrentTime - print $ diffUTCTime t2 t1 - print ("generated fieldUsage for module: " <> moduleN <> " at path: " <> moduleLoc, length binds) - return guts - -getFilePath :: SrcSpan -> String -getFilePath (RealSrcSpan rSSpan) = unpackFS $ srcSpanFile rSSpan -getFilePath (UnhelpfulSpan fs) = unpackFS fs - --- 1. `HasField _ r _` where r is a variable - --- 2. `HasField _ (T ...) _` if T is a data family --- (because it might have fields introduced later) - --- 3. `HasField x (T ...) _` where x is a variable, --- if T has any fields at all - --- 4. `HasField "foo" (T ...) _` if T has a "foo" field -processHasField :: Text -> Expr Var -> Expr Var -> IO [(Text,[FieldUsage])] -processHasField functionName b@(App (App (App getField (Type fieldName)) (Type haskellType@(TyConApp haskellTypeT _))) (Type finalFieldType)) hasField = - pure [(functionName,[FieldUsage (pack $ showSDocUnsafe $ ppr haskellType) (pack $ showSDocUnsafe $ ppr fieldName) (pack $ showSDocUnsafe $ ppr finalFieldType) (pack $ nameStableString $ GhcPlugins.tyConName haskellTypeT) (pack $ showSDocUnsafe $ ppr b)])] -processHasField functionName b@(App (App (App getField (Type fieldName)) (Type haskellType)) (Type finalFieldType)) hasField = - pure [(functionName,[FieldUsage (pack $ showSDocUnsafe $ ppr haskellType) (pack $ showSDocUnsafe $ ppr fieldName) (pack $ showSDocUnsafe $ ppr finalFieldType) (pack $ show $ toConstr haskellType) (pack $ showSDocUnsafe $ ppr b)])] -processHasField functionName x (Var hasField) = do - res <- toLexpr functionName x - let b = pack $ showSDocUnsafe $ ppr x - parts = words $ T.unpack $ T.replace "\t" "" $ T.replace "\n" "" $ T.strip (pack $ showSDocUnsafe $ ppr $ tyVarKind hasField) - case parts of - ["HasField", fieldName, dataType, fieldType] -> - pure $ res <> [(functionName,[ - FieldUsage - (pack dataType) - (pack $ init (Prelude.tail fieldName)) - (pack fieldType) - ""--(pack $ show $ toConstr $ haskellType) - b - ])] - ("HasField":fieldName:dataType:fieldTypeRest) -> - pure $ res <> [(functionName,[ - FieldUsage - (pack dataType) - (pack fieldName) - (pack $ unwords fieldTypeRest) - ""--(pack $ show $ toConstr $ haskellType) - b - ])] - ("Field":fieldName:dataType:fieldTypeRest) -> - pure $ res <> [(functionName,[ - FieldUsage - (pack dataType) - (pack fieldName) - (pack $ unwords fieldTypeRest) - ""--(pack $ show $ toConstr $ haskellType) - b - ])] - ("Field'":fieldName:dataType:fieldTypeRest) -> - pure $ res <> [(functionName,[ - FieldUsage - (pack dataType) - (pack fieldName) - (pack $ unwords fieldTypeRest) - ""--(pack $ show $ toConstr $ haskellType) - b - ])] - _ -> do - print (showSDocUnsafe $ ppr $ tyVarKind hasField) - pure res - -groupByFunction :: [(Text, [FieldUsage])] -> [(Text, [FieldUsage])] -groupByFunction = map mergeGroups . groupBy ((==) `on` fst) . sortBy (compare `on` fst) - where - mergeGroups :: [(Text, [FieldUsage])] -> (Text, [FieldUsage]) - mergeGroups xs = (fst (Prelude.head xs), concatMap snd xs) - -toLBind :: CoreBind -> IO [(Text,[FieldUsage])] -toLBind (NonRec binder expr) = do - res <- toLexpr (pack $ nameStableString $ idName binder) expr - pure $ groupByFunction res -toLBind (Rec binds) = do - r <- - toList $ - serially $ - mapM - ( \(b, e) -> - toLexpr (pack $ nameStableString (idName b)) e - ) - (fromList binds) - pure $ groupByFunction $ Prelude.concat r - -processFieldExtraction :: Text -> Var -> Var -> Text -> IO [(Text,[FieldUsage])] -processFieldExtraction functionName _field _type b = do - if "FunTy" == show (toConstr $ varType _field) - then do - let fieldType = T.strip $ Prelude.last $ T.splitOn "->" $ pack $ showSDocUnsafe $ ppr $ varType _field - pure [(functionName,[ - FieldUsage - (pack $ showSDocUnsafe $ ppr $ varType _type) - (pack $ showSDocUnsafe $ ppr _field) - fieldType - (pack $ show $ toConstr _type) - b - ])] - else pure mempty - -toLexpr :: Text -> Expr Var -> IO [(Text,[FieldUsage])] -toLexpr functionName (Var x) = pure mempty -toLexpr functionName (Lit x) = pure mempty -toLexpr functionName (Type _id) = pure mempty -toLexpr functionName x@(App func@(Var _field) args@(Var _type)) = do - processFieldExtraction functionName _field _type (pack $ showSDocUnsafe $ ppr x) -toLexpr functionName x@(App func@(App _ _) args@(Var isHasField)) - | "$_sys$$dHasField" == pack (nameStableString $ idName isHasField) = - processHasField functionName func args - | otherwise = do - processApp functionName x -toLexpr functionName x@(App _ _) = processApp functionName x -toLexpr functionName (Lam func args) = - toLexpr functionName args -toLexpr functionName (Let func args) = do - a <- toLexpr functionName args - f <- toLBind func - pure $ map (\(x,y) -> (functionName,y)) f <> a -toLexpr functionName (Case condition bind _type alts) = do - c <- toLexpr functionName condition - a <- toList $ serially $ mapM (toLAlt functionName) (fromList alts) - pure $ c <> Prelude.concat a -toLexpr functionName (Tick _ expr) = toLexpr functionName expr -toLexpr functionName (Cast expr _) = toLexpr functionName expr -toLexpr functionName _ = pure mempty - -processApp functionName x@(App func args) = do - f <- toLexpr functionName func - a <- toLexpr functionName args - pure $ f <> a - -toLAlt :: Text -> (AltCon, [Var], CoreExpr) -> IO [(Text,[FieldUsage])] -toLAlt functionName (DataAlt dataCon, val, e) = - toLexpr functionName e -toLAlt functionName (LitAlt lit, val, e) = - toLexpr functionName e -toLAlt functionName (DEFAULT, val, e) = - toLexpr functionName e - -collectTypesTC :: [CommandLineOption] -> ModSummary -> TcGblEnv -> TcM TcGblEnv -collectTypesTC opts modSummary tcg = do - dflags <- getDynFlags - _ <- liftIO $ - forkIO $ do - let prefixPath = case opts of - [] -> "/tmp/fieldInspector/" - local : _ -> local - moduleName' = moduleNameString $ GhcPlugins.moduleName $ ms_mod modSummary - modulePath = prefixPath <> ms_hspp_file modSummary - typeEnv = tcg_type_env tcg - path = (intercalate "/" . init . splitOn "/") modulePath - print ("generating types data for module: " <> moduleName' <> " at path: " <> path) - types <- toList $ parallely $ mapM (\tyThing -> - case tyThing of - ATyCon tyCon -> collectTyCon dflags tyCon - _ -> return []) (fromList $ typeEnvElts typeEnv) - createDirectoryIfMissing True path - DBS.writeFile (modulePath <> ".types.json") (toStrict $ encodePretty $ Map.fromList $ Prelude.concat types) - print ("generated types data for module: " <> moduleName' <> " at path: " <> path) - return tcg - -collectTyCon :: DynFlags -> GhcPlugins.TyCon -> IO [(String,TypeInfo)] -collectTyCon dflags tyCon' = do - let name = GhcPlugins.tyConName tyCon' - tyConStr = showSDoc dflags (pprTyCon name) - tyConKind' = tyConKind tyCon' - kindStr = showSDoc dflags (ppr tyConKind') - dataCons = tyConDataCons tyCon' - dataConInfos <- toList $ parallely $ mapM (collectDataCon dflags) (fromList dataCons) - return [(tyConStr,TypeInfo - { name = tyConStr - , typeKind = kindStr - , dataConstructors = dataConInfos - })] - -collectDataCon :: DynFlags -> DataCon -> IO DataConInfo -collectDataCon dflags dataCon = do - let name = GhcPlugins.dataConName dataCon - dataConStr = showSDoc dflags (pprDataCon name) - fields = map (unpackFS . flLabel) $ dataConFieldLabels dataCon - fieldTypes = map (showSDoc dflags . ppr) (dataConOrigArgTys dataCon) - fieldInfo = Map.fromList $ zip fields fieldTypes - return DataConInfo - { dataConName = dataConStr - , fields = fieldInfo - , sumTypes = getAllFunTy $ dataConRepType dataCon - } - where - getAllFunTy (FunTy _ ftArg ftRes) = [showSDoc dflags $ ppr ftArg] <> getAllFunTy ftRes - getAllFunTy _ = mempty - -pprTyCon :: Name -> SDoc -pprTyCon = ppr - -pprDataCon :: Name -> SDoc -pprDataCon = ppr - -collectTypeInfoParser :: [CommandLineOption] -> ModSummary -> HsParsedModule -> Hsc HsParsedModule -collectTypeInfoParser opts modSummary hpm = do - _ <- liftIO $ - forkIO $ do - let prefixPath = case opts of - [] -> "/tmp/fieldInspector/" - local : _ -> local - moduleName' = moduleNameString $ GhcPlugins.moduleName $ ms_mod modSummary - modulePath = prefixPath <> ms_hspp_file modSummary - hm_module = unLoc $ hpm_module hpm - path = (intercalate "/" . init . splitOn "/") modulePath - print ("generating types data for module: " <> moduleName' <> " at path: " <> path) - types <- toList $ parallely $ mapM (pure . getTypeInfo) (fromList $ hsmodDecls hm_module) - createDirectoryIfMissing True path - DBS.writeFile (modulePath <> ".types.parser.json") (toStrict $ encodePretty $ Map.fromList $ Prelude.concat types) - print ("generated types data for module: " <> moduleName' <> " at path: " <> path) - return hpm - -getTypeInfo :: LHsDecl GhcPs -> [(String,TypeInfo)] -getTypeInfo (L _ (TyClD _ (DataDecl _ lname _ _ defn))) = - [(showSDocUnsafe (ppr lname) ,TypeInfo - { name = showSDocUnsafe (ppr lname) - , typeKind = "data" - , dataConstructors = map getDataConInfo (dd_cons defn) - })] -getTypeInfo (L _ (TyClD _ (SynDecl _ lname _ _ rhs))) = - [(showSDocUnsafe (ppr lname),TypeInfo - { name = showSDocUnsafe (ppr lname) - , typeKind = "type" - , dataConstructors = [DataConInfo (showSDocUnsafe (ppr lname)) (Map.singleton "synonym" (showSDocUnsafe (ppr rhs))) []] - })] -getTypeInfo _ = [] - -getDataConInfo :: LConDecl GhcPs -> DataConInfo -getDataConInfo (L _ ConDeclH98{ con_name = lname, con_args = args }) = - DataConInfo - { dataConName = showSDocUnsafe (ppr lname) - , fields = getFieldMap args - , sumTypes = [] -- For H98-style data constructors, sum types are not applicable - } -getDataConInfo (L _ ConDeclGADT{ con_names = lnames, con_res_ty = ty }) = - DataConInfo - { dataConName = intercalate ", " (map (showSDocUnsafe . ppr) lnames) - , fields = Map.singleton "gadt" (showSDocUnsafe (ppr ty)) - , sumTypes = [] -- For GADT-style data constructors, sum types can be represented by the type itself - } - -getFieldMap :: HsConDeclDetails GhcPs -> Map String String -getFieldMap (PrefixCon args) = Map.fromList $ Prelude.zipWith (\i t -> (show i, showSDocUnsafe (ppr t))) [1..] args -getFieldMap (RecCon (L _ fields)) = Map.fromList $ concatMap getRecField fields - where - getRecField (L _ (ConDeclField _ fnames t _)) = [(showSDocUnsafe (ppr fname), showSDocUnsafe (ppr t)) | L _ fname <- fnames] -getFieldMap (InfixCon t1 t2) = Map.fromList [("field1", showSDocUnsafe (ppr t1)), ("field2", showSDocUnsafe (ppr t2))] diff --git a/fieldInspector/src/FieldInspector/PluginFields.hs b/fieldInspector/src/FieldInspector/PluginFields.hs new file mode 100644 index 0000000..01d5100 --- /dev/null +++ b/fieldInspector/src/FieldInspector/PluginFields.hs @@ -0,0 +1,764 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE CPP #-} + +module FieldInspector.PluginFields (plugin) where + + +#if __GLASGOW_HASKELL__ >= 900 +import qualified Data.IntMap.Internal as IntMap +import Streamly.Internal.Data.Stream (fromList,mapM_,mapM,toList) +import GHC +import GHC.Driver.Plugins (Plugin(..),CommandLineOption,defaultPlugin,PluginRecompile(..)) +import GHC as GhcPlugins +import GHC.Core.DataCon as GhcPlugins +import GHC.Core.TyCo.Rep +import GHC.Core.TyCon as GhcPlugins +import GHC.Driver.Env +import GHC.Tc.Types +import GHC.Unit.Module.ModSummary +import GHC.Utils.Outputable (showSDocUnsafe,ppr,SDoc) +import GHC.Data.Bag (bagToList) +import GHC.Types.Name hiding (varName) +import GHC.Types.Var +import qualified Data.Aeson.KeyMap as HM +import GHC.Core.Opt.Monad +import GHC.Core +import GHC.Unit.Module.ModGuts +import GHC.Types.Name.Reader +import GHC.Types.Id +import GHC.Data.FastString + +#else +import CoreMonad (CoreM, CoreToDo (CoreDoPluginPass), liftIO) +import CoreSyn ( + AltCon (..), + Bind (NonRec, Rec), + CoreBind, + CoreExpr, + Expr (..), + mkStringLit + ) +import TyCoRep +import GHC.IO (unsafePerformIO) +import GHC.Hs +import GHC.Hs.Decls +import GhcPlugins ( + CommandLineOption,Arg (..), + HsParsedModule(..), + Hsc, + Name,SDoc,DataCon,DynFlags,ModSummary(..),TyCon, + Literal (..),typeEnvElts, + ModGuts (mg_binds, mg_loc, mg_module),showSDoc, + Module (moduleName),tyConKind, + NamedThing (getName),getDynFlags,tyConDataCons,dataConOrigArgTys,dataConName, + Outputable (..),dataConFieldLabels, + Plugin (..), + Var,flLabel,dataConRepType, + coVarDetails, + defaultPlugin, + idName, + mkInternalName, + mkLitString, + mkLocalVar, + mkVarOcc, + moduleNameString, + nameStableString, + noCafIdInfo, + purePlugin, + showSDocUnsafe, + tyVarKind, + unpackFS, + tyConName, + msHsFilePath + ) +import Id (isExportedId,idType) +import Name (getSrcSpan) +import SrcLoc +import Unique (mkUnique) +import Var (isLocalId,varType) +import FieldInspector.Types +import TcRnTypes +import TcRnMonad +import DataCon +import CoreMonad (CoreM, CoreToDo (CoreDoPluginPass)) +import CoreSyn ( + AltCon (..), + Bind (NonRec, Rec), + CoreBind, + CoreExpr, + Expr (..), + ) +import Bag (bagToList) +import GHC.Hs ( + ConDecl ( + ConDeclGADT, + ConDeclH98, + con_args, + con_name, + con_names, + con_res_ty + ), + ConDeclField (ConDeclField), + FieldOcc (FieldOcc), + GhcPs, + GhcTc, + HsBindLR ( + AbsBinds, + FunBind, + PatBind, + VarBind, + abs_binds, + pat_lhs, + pat_rhs, + var_id, + var_inline, + var_rhs + ), + HsConDeclDetails, + HsConDetails (InfixCon, PrefixCon, RecCon), + HsConPatDetails, + HsDataDefn (dd_cons), + HsDecl (TyClD), + HsExpr (RecordCon, RecordUpd), + HsModule (hsmodDecls), + HsRecField' (HsRecField, hsRecFieldArg, hsRecFieldLbl, hsRecPun), + HsRecFields (HsRecFields, rec_flds), + LConDecl, + LHsBindLR, + LHsDecl, + LHsExpr, + LHsRecField, + LHsRecUpdField, + LPat, + Pat (ConPatIn, ParPat, VarPat), + TyClDecl (DataDecl, SynDecl), + rdrNameAmbiguousFieldOcc, + ) +import GhcPlugins ( + CommandLineOption, + HsParsedModule (..), + PluginRecompile(..), + Hsc, + ModGuts (mg_binds, mg_loc), + ModSummary (..), + Module (moduleName), + Name, + NamedThing (getName), + Outputable (..), + Plugin (..), + RdrName (Exact, Orig, Qual, Unqual), + SDoc, + Var, + dataConName, + dataConTyCon, + defaultPlugin, + idName, + moduleNameString, + nameStableString, + purePlugin, + showSDocUnsafe, + tyConKind, + tyConName, + tyVarKind, + unpackFS, + msHsFilePath, + ) +import Name (OccName (occNameFS, occNameSpace), occNameString, pprNameSpaceBrief) +import SrcLoc ( + GenLocated (L), + RealSrcSpan (srcSpanFile), + SrcSpan (..), + getLoc, + unLoc, + ) +import TcRnMonad (MonadIO (liftIO)) +import TcRnTypes (TcGblEnv (tcg_binds), TcM) +import TyCoRep (Type (AppTy, FunTy, TyConApp, TyVarTy)) +import Var (varName, varType) +#endif + +import FieldInspector.Types +import Control.Concurrent (MVar, modifyMVar, newMVar) +import Data.Aeson +import Data.Aeson.Encode.Pretty (encodePretty) +import qualified Data.ByteString as DBS +import Data.ByteString.Lazy (toStrict) +import Data.Int (Int64) +import Data.List.Extra (intercalate, isSuffixOf, replace, splitOn,groupBy) +import Data.List ( sortBy, intercalate ,foldl') +import qualified Data.Map as Map +import Data.Text (Text, concat, isInfixOf, pack, unpack) +import qualified Data.Text as T +import Data.Text.Encoding (decodeUtf8, encodeUtf8) +import Data.Time +import Data.Map (Map) +import Data.Data +import Data.Maybe (catMaybes) +import Control.Monad.IO.Class (liftIO) +import System.IO (writeFile) +import Control.Monad (forM) +import Streamly.Internal.Data.Stream hiding (concatMap, init, length, map, splitOn,foldl',intercalate) +import System.Directory (createDirectoryIfMissing, removeFile) +import System.Directory.Internal.Prelude hiding (mapM, mapM_) +import Prelude hiding (id, mapM, mapM_) +import Control.Exception (evaluate) +import Control.Exception (evaluate) +import Control.Reference (biplateRef, (^?)) +import Data.Aeson.Encode.Pretty (encodePretty) +import Data.Bool (bool) +import qualified Data.ByteString as DBS +import Data.ByteString.Lazy (toStrict) +import Data.Data (Data (toConstr)) +import Data.Generics.Uniplate.Data () +import Data.List (sortBy) +import Data.List.Extra (groupBy, intercalate, splitOn) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe (mapMaybe) +import Data.Text (Text, pack) +import qualified Data.Text as T +import FieldInspector.Types ( + DataConInfo (..), + DataTypeUC (DataTypeUC), + FieldRep (FieldRep), + FieldUsage (FieldUsage), + TypeInfo (..), + TypeVsFields (TypeVsFields), + ) +import Streamly.Internal.Data.Stream (fromList, mapM, toList) +import System.Directory (createDirectoryIfMissing, removeFile) +import System.Directory.Internal.Prelude ( + catMaybes, + catch, + forkIO, + isDoesNotExistError, + on, + throwIO, + ) +import Prelude hiding (id, mapM, mapM_) + +plugin :: Plugin +plugin = + defaultPlugin + { installCoreToDos = install + , pluginRecompile = (\_ -> return NoForceRecompile) + , typeCheckResultAction = collectTypesTC + } +install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo] +install args todos = return (CoreDoPluginPass "FieldInspector" (buildCfgPass args) : todos) + +removeIfExists :: FilePath -> IO () +removeIfExists fileName = removeFile fileName `catch` handleExists + where + handleExists e + | isDoesNotExistError e = return () + | otherwise = throwIO e + +collectTypesTC :: [CommandLineOption] -> ModSummary -> TcGblEnv -> TcM TcGblEnv +collectTypesTC opts modSummary tcEnv = do + _ <- liftIO $ + forkIO $ + do + let prefixPath = case opts of + [] -> "/tmp/fieldInspector/" + local : _ -> local + modulePath = prefixPath <> msHsFilePath modSummary + path = (intercalate "/" . init . splitOn "/") modulePath + binds = bagToList $ tcg_binds tcEnv + createDirectoryIfMissing True path + functionVsUpdates <- getAllTypeManipulations binds + DBS.writeFile ((modulePath) <> ".typeUpdates.json") (toStrict $ encodePretty functionVsUpdates) + return tcEnv + +buildCfgPass :: [CommandLineOption] -> ModGuts -> CoreM ModGuts +buildCfgPass opts guts = do + let prefixPath = case opts of + [] -> "./tmp/fieldInspector/" + [local] -> local + _ -> error "unexpected no of arguments" + _ <- liftIO $ do + let binds = mg_binds guts + moduleLoc = prefixPath Prelude.<> getFilePath (mg_loc guts) + createDirectoryIfMissing True ((intercalate "/" . init . splitOn "/") moduleLoc) + removeIfExists (moduleLoc Prelude.<> ".fieldUsage.json") + l <- toList $ mapM (liftIO . toLBind) (fromList binds) + DBS.writeFile (moduleLoc Prelude.<> ".fieldUsage.json") =<< (evaluate $ toStrict $ encodePretty $ Map.fromList $ groupByFunction $ Prelude.concat l) + return guts + +getAllTypeManipulations :: [LHsBindLR GhcTc GhcTc] -> IO [DataTypeUC] +getAllTypeManipulations binds = do + bindWiseUpdates <- + toList $ + mapM + ( \x -> do + let functionName = getFunctionName x + filterRecordUpdateAndCon = Prelude.filter (\x -> ((show $ toConstr x) `Prelude.elem` ["HsGetField","RecordCon", "RecordUpd"])) (x ^? biplateRef :: [HsExpr GhcTc]) + pure $ bool (Nothing) (Just (DataTypeUC functionName (Data.Maybe.mapMaybe getDataTypeDetails filterRecordUpdateAndCon))) (not (Prelude.null filterRecordUpdateAndCon)) + ) + (fromList binds) + pure $ System.Directory.Internal.Prelude.catMaybes bindWiseUpdates + where + getDataTypeDetails :: HsExpr GhcTc -> Maybe TypeVsFields +#if __GLASGOW_HASKELL__ >= 900 + getDataTypeDetails (RecordCon _ (iD) rcon_flds) = Just (TypeVsFields (T.pack $ nameStableString $ getName (GHC.unXRec @(GhcTc) iD)) (extractRecordBinds (rcon_flds))) +#else + getDataTypeDetails (RecordCon _ (iD) rcon_flds) = Just (TypeVsFields (T.pack $ nameStableString $ getName $ idName $ unLoc $ iD) (extractRecordBinds (rcon_flds))) +#endif + getDataTypeDetails (RecordUpd _ rupd_expr rupd_flds) = Just (TypeVsFields (T.pack $ showSDocUnsafe $ ppr rupd_expr) (getFieldUpdates rupd_flds)) + + -- inferFieldType :: Name -> String + inferFieldTypeFieldOcc (L _ (FieldOcc _ (L _ rdrName))) = handleRdrName rdrName + inferFieldTypeAFieldOcc = (handleRdrName . rdrNameAmbiguousFieldOcc . unLoc) + + handleRdrName :: RdrName -> String + handleRdrName x = + case x of + Unqual occName -> ("$" <> (showSDocUnsafe $ pprNameSpaceBrief $ occNameSpace occName) <> "$" <> (occNameString occName) <> "$" <> (unpackFS $ occNameFS occName)) + Qual moduleName occName -> ((moduleNameString moduleName) <> "$" <> (showSDocUnsafe $ pprNameSpaceBrief $ occNameSpace occName) <> "$" <> (occNameString occName) <> "$" <> (unpackFS $ occNameFS occName)) + Orig module' occName -> ((moduleNameString $ moduleName module') <> "$" <> (showSDocUnsafe $ pprNameSpaceBrief $ occNameSpace occName) <> "$" <> (occNameString occName) <> "$" <> (unpackFS $ occNameFS occName)) + Exact name -> nameStableString name + +#if __GLASGOW_HASKELL__ >= 900 + getFieldUpdates :: Either [LHsRecUpdField GhcTc] [LHsRecUpdProj GhcTc] -> Either [FieldRep] [Text] + getFieldUpdates fields = + case fields of + Left x -> (Left . map (extractField . unLoc)) x + Right x -> (Right . map (T.pack . showSDocUnsafe . ppr)) x + where + extractField :: HsRecUpdField GhcTc -> FieldRep + extractField (HsRecField{hsRecFieldLbl = lbl, hsRecFieldArg = expr, hsRecPun = pun}) = + if pun + then (FieldRep (T.pack $ showSDocUnsafe $ ppr lbl) (T.pack $ showSDocUnsafe $ ppr lbl) (T.pack $ inferFieldTypeAFieldOcc lbl)) + else (FieldRep (T.pack $ showSDocUnsafe $ ppr lbl) (T.pack $ showSDocUnsafe $ ppr (unLoc expr)) (T.pack $ inferFieldTypeAFieldOcc lbl)) +#else + getFieldUpdates :: [LHsRecUpdField GhcTc]-> Either [FieldRep] [Text] + getFieldUpdates fields = Left $ map extractField fields + where + extractField :: LHsRecUpdField GhcTc -> FieldRep + extractField (L _ (HsRecField{hsRecFieldLbl = lbl, hsRecFieldArg = expr, hsRecPun = pun})) = + if pun + then (FieldRep (T.pack $ showSDocUnsafe $ ppr lbl) (T.pack $ showSDocUnsafe $ ppr lbl) (T.pack $ inferFieldTypeAFieldOcc lbl)) + else (FieldRep (T.pack $ showSDocUnsafe $ ppr lbl) (T.pack $ showSDocUnsafe $ ppr (unLoc expr)) (T.pack $ inferFieldTypeAFieldOcc lbl)) +#endif + + extractRecordBinds :: HsRecFields GhcTc (LHsExpr GhcTc) -> Either [FieldRep] [Text] + extractRecordBinds (HsRecFields{rec_flds = fields}) = + Left $ map extractField fields + where + extractField :: LHsRecField GhcTc (LHsExpr GhcTc) -> FieldRep + extractField (L _ (HsRecField{hsRecFieldLbl = lbl, hsRecFieldArg = expr, hsRecPun = pun})) = + if pun + then (FieldRep (T.pack $ showSDocUnsafe $ ppr lbl) (T.pack $ showSDocUnsafe $ ppr lbl) (T.pack $ inferFieldTypeFieldOcc lbl)) + else (FieldRep (T.pack $ showSDocUnsafe $ ppr lbl) (T.pack $ showSDocUnsafe $ ppr $ unLoc expr) (T.pack $ inferFieldTypeFieldOcc lbl)) + + getFunctionName :: LHsBindLR GhcTc GhcTc -> [Text] +#if __GLASGOW_HASKELL__ >= 900 + getFunctionName (L _ x@(FunBind fun_ext id matches _)) = [T.pack $ nameStableString $ getName id] +#else + getFunctionName (L _ x@(FunBind fun_ext id matches _ _)) = [T.pack $ nameStableString $ getName id] +#endif + getFunctionName (L _ (VarBind{var_id = var, var_rhs = expr})) = [T.pack $ nameStableString $ getName var] + getFunctionName (L _ (PatBind{pat_lhs = pat, pat_rhs = expr})) = [""] + getFunctionName (L _ (AbsBinds{abs_binds = binds})) = Prelude.concatMap getFunctionName $ bagToList binds + +processPat :: LPat GhcTc -> [(Name, Maybe Text)] +processPat (L _ pat) = case pat of +#if __GLASGOW_HASKELL__ >= 900 + ConPat _ _ details -> processDetails details +#else + ConPatIn _ details -> processDetails details +#endif + VarPat _ x@(L _ var) -> [(varName var, Just $ T.pack $ showSDocUnsafe $ ppr $ getLoc $ x)] + ParPat _ pat' -> processPat pat' + _ -> [] + +processDetails :: HsConPatDetails GhcTc -> [(Name, Maybe Text)] +#if __GLASGOW_HASKELL__ >= 900 +processDetails (PrefixCon _ args) = Prelude.concatMap processPat args +#else +processDetails (PrefixCon args) = Prelude.concatMap processPat args +#endif +processDetails (InfixCon arg1 arg2) = processPat arg1 <> processPat arg2 +processDetails (RecCon rec) = Prelude.concatMap processPatField (rec_flds rec) + +processPatField :: LHsRecField GhcTc (LPat GhcTc) -> [(Name, Maybe Text)] +processPatField (L _ HsRecField{hsRecFieldArg = arg}) = processPat arg + +#if __GLASGOW_HASKELL__ >= 900 +getFilePath :: SrcSpan -> String +getFilePath (RealSrcSpan rSSpan _) = unpackFS $ srcSpanFile rSSpan +getFilePath (UnhelpfulSpan fs) = showSDocUnsafe $ ppr $ fs +#else +getFilePath :: SrcSpan -> String +getFilePath (RealSrcSpan rSSpan) = unpackFS $ srcSpanFile rSSpan +getFilePath (UnhelpfulSpan fs) = unpackFS fs +#endif + +-- 1. `HasField _ r _` where r is a variable + +-- 2. `HasField _ (T ...) _` if T is a data family +-- (because it might have fields introduced later) + +-- 3. `HasField x (T ...) _` where x is a variable, +-- if T has any fields at all + +-- 4. `HasField "foo" (T ...) _` if T has a "foo" field +processHasField :: Text -> Expr Var -> Expr Var -> IO [(Text, [FieldUsage])] +processHasField functionName b@(App (App (App getField (Type fieldName)) (Type haskellType@(TyConApp haskellTypeT _))) (Type finalFieldType)) hasField = + pure [(functionName, [FieldUsage (pack $ showSDocUnsafe $ ppr haskellType) (pack $ showSDocUnsafe $ ppr fieldName) (pack $ showSDocUnsafe $ ppr finalFieldType) (pack $ nameStableString $ GhcPlugins.tyConName haskellTypeT) (pack $ showSDocUnsafe $ ppr b)])] +processHasField functionName b@(App (App (App getField (Type fieldName)) (Type haskellType)) (Type finalFieldType)) hasField = + pure [(functionName, [FieldUsage (pack $ showSDocUnsafe $ ppr haskellType) (pack $ showSDocUnsafe $ ppr fieldName) (pack $ showSDocUnsafe $ ppr finalFieldType) (pack $ show $ toConstr haskellType) (pack $ showSDocUnsafe $ ppr b)])] +processHasField functionName (Var x) (Var hasField) = do + res <- pure mempty + let b = pack $ showSDocUnsafe $ ppr x + lensString = T.replace "\n" "" $ pack $ showSDocUnsafe $ ppr x + parts = + if ((Prelude.length (T.splitOn " @ " lensString)) >= 2) + then [] + else words $ T.unpack $ T.replace "\t" "" $ T.replace "\n" "" $ T.strip (pack $ showSDocUnsafe $ ppr $ tyVarKind hasField) + case tyVarKind hasField of + (TyConApp haskellTypeT z) -> do + let y = map (\(zz) -> (pack $ showSDocUnsafe $ ppr zz, pack $ extractVarFromType zz)) z + if length y == 4 + then + pure $ + res + <> [ + ( functionName + , + [ FieldUsage + (T.strip $ fst $ y Prelude.!! 2) + (T.strip $ fst $ y Prelude.!! 1) + (T.strip $ fst $ y Prelude.!! 3) + (T.strip $ snd $ y Prelude.!! 2) + lensString + ] + ) + ] + else + if length y == 3 + then + pure $ + res + <> [ + ( functionName + , + [ FieldUsage + (T.strip $ fst $ y Prelude.!! 1) + (T.strip $ fst $ y Prelude.!! 0) + (T.strip $ fst $ y Prelude.!! 2) + (T.strip $ snd $ y Prelude.!! 1) + lensString + ] + ) + ] + else do + pure res +#if __GLASGOW_HASKELL__ >= 900 + (FunTy _ _ a _) -> do +#else + (FunTy _ a _) -> do +#endif + let fieldType = T.strip $ Prelude.last $ T.splitOn "->" $ pack $ showSDocUnsafe $ ppr $ varType hasField + pure $ + res + <> [ + ( functionName + , + [ FieldUsage + (pack $ showSDocUnsafe $ ppr $ varType x) + (pack $ showSDocUnsafe $ ppr hasField) + fieldType + (pack $ extractVarFromType $ varType x) + b + ] + ) + ] + (TyVarTy a) -> do + let fieldType = T.strip $ Prelude.last $ T.splitOn "->" $ pack $ showSDocUnsafe $ ppr $ varType hasField + pure $ + res + <> [ + ( functionName + , + [ FieldUsage + (pack $ showSDocUnsafe $ ppr $ varType x) + (pack $ showSDocUnsafe $ ppr hasField) + fieldType + (pack $ extractVarFromType $ varType x) + b + ] + ) + ] + _ -> do + case parts of + ["HasField", fieldName, dataType, fieldType] -> + pure $ + res + <> [ + ( functionName + , + [ FieldUsage + (pack dataType) + (pack $ init (Prelude.tail fieldName)) + (pack fieldType) + "" --(pack $ show $ toConstr $ haskellType) + b + ] + ) + ] + ("HasField" : fieldName : dataType : fieldTypeRest) -> + pure $ + res + <> [ + ( functionName + , + [ FieldUsage + (pack dataType) + (pack fieldName) + (pack $ unwords fieldTypeRest) + "" --(pack $ show $ toConstr $ haskellType) + b + ] + ) + ] + _ -> do + pure res +processHasField functionName x (Var hasField) = do + res <- toLexpr functionName x + let b = pack $ showSDocUnsafe $ ppr x + lensString = T.replace "\n" "" $ pack $ showSDocUnsafe $ ppr x + parts = + if ((Prelude.length (T.splitOn " @ " lensString)) >= 2) + then [] + else words $ T.unpack $ T.replace "\t" "" $ T.replace "\n" "" $ T.strip (pack $ showSDocUnsafe $ ppr $ tyVarKind hasField) + case tyVarKind hasField of + (TyConApp haskellTypeT z) -> do + let y = map (\(zz) -> (pack $ showSDocUnsafe $ ppr zz, pack $ extractVarFromType zz)) z + if length y == 4 + then + pure $ + res + <> [ + ( functionName + , + [ FieldUsage + (T.strip $ fst $ y Prelude.!! 2) + (T.strip $ fst $ y Prelude.!! 1) + (T.strip $ fst $ y Prelude.!! 3) + (T.strip $ snd $ y Prelude.!! 2) + lensString + ] + ) + ] + else + if length y == 3 + then + pure $ + res + <> [ + ( functionName + , + [ FieldUsage + (T.strip $ fst $ y Prelude.!! 1) + (T.strip $ fst $ y Prelude.!! 0) + (T.strip $ fst $ y Prelude.!! 2) + (T.strip $ snd $ y Prelude.!! 1) + lensString + ] + ) + ] + else do + pure res + _ -> do + case parts of + ["HasField", fieldName, dataType, fieldType] -> + pure $ + res + <> [ + ( functionName + , + [ FieldUsage + (pack dataType) + (pack $ init (Prelude.tail fieldName)) + (pack fieldType) + "" --(pack $ show $ toConstr $ haskellType) + b + ] + ) + ] + ("HasField" : fieldName : dataType : fieldTypeRest) -> + pure $ + res + <> [ + ( functionName + , + [ FieldUsage + (pack dataType) + (pack fieldName) + (pack $ unwords fieldTypeRest) + "" --(pack $ show $ toConstr $ haskellType) + b + ] + ) + ] + _ -> do + pure res + +groupByFunction :: [(Text, [FieldUsage])] -> [(Text, [FieldUsage])] +groupByFunction = map mergeGroups . groupBy ((==) `on` fst) . sortBy (compare `on` fst) + where + mergeGroups :: [(Text, [FieldUsage])] -> (Text, [FieldUsage]) + mergeGroups xs = (fst (Prelude.head xs), concatMap snd xs) + +toLBind :: CoreBind -> IO [(Text, [FieldUsage])] +toLBind (NonRec binder expr) = do + res <- toLexpr (pack $ nameStableString $ idName binder) expr + pure $ groupByFunction res +toLBind (Rec binds) = do + r <- + toList $ + mapM + ( \(b, e) -> do + toLexpr (pack $ nameStableString (idName b)) e + ) + (fromList binds) + pure $ groupByFunction $ Prelude.concat r + +processFieldExtraction :: Text -> Var -> Var -> Text -> IO [(Text, [FieldUsage])] +processFieldExtraction functionName _field _type b = do + res <- case (varType _field) of +#if __GLASGOW_HASKELL__ >= 900 + (FunTy _ _ a _) -> do +#else + (FunTy _ a _) -> do +#endif + let fieldType = T.strip $ Prelude.last $ T.splitOn "->" $ pack $ showSDocUnsafe $ ppr $ varType _field + pure + [ + ( functionName + , + [ FieldUsage + (pack $ showSDocUnsafe $ ppr $ varType _type) + (pack $ showSDocUnsafe $ ppr _field) + fieldType + (pack $ extractVarFromType $ varType _type) + b + ] + ) + ] + (TyConApp haskellTypeT z) -> do + let y = map (\(zz) -> (pack $ showSDocUnsafe $ ppr zz, pack $ extractVarFromType zz)) z + if length y == 4 + then + pure $ + [ + ( functionName + , + [ FieldUsage + (T.strip $ fst $ y Prelude.!! 2) + (T.strip $ fst $ y Prelude.!! 1) + (T.strip $ fst $ y Prelude.!! 3) + (T.strip $ snd $ y Prelude.!! 2) + b + ] + ) + ] + else + if length y == 3 + then + pure $ + [ + ( functionName + , + [ FieldUsage + (T.strip $ fst $ y Prelude.!! 1) + (T.strip $ fst $ y Prelude.!! 0) + (T.strip $ fst $ y Prelude.!! 2) + (T.strip $ snd $ y Prelude.!! 1) + b + ] + ) + ] + else do + pure mempty + _ -> pure mempty + pure $ res + +extractVarFromType :: Type -> String +extractVarFromType = go + where + go :: Type -> String + go (TyVarTy v) = (nameStableString $ varName v) + go (TyConApp haskellTypeT z) = (nameStableString $ GhcPlugins.tyConName haskellTypeT) + go (AppTy a b) = go a <> "," <> go b + go _ = mempty + +toLexpr :: Text -> Expr Var -> IO [(Text, [FieldUsage])] +toLexpr functionName (Var x) = pure mempty +toLexpr functionName (Lit x) = pure mempty +toLexpr functionName (Type _id) = pure mempty +toLexpr functionName x@(App func@(App _ _) args@(Var isHasField)) + | "$_sys$$dHasField" == pack (nameStableString $ idName isHasField) = do + processHasField functionName func args + | otherwise = do + processApp functionName x +toLexpr functionName x@(App func@(Var _field) args@(Var _type)) = do + processFieldExtraction functionName _field _type (pack $ showSDocUnsafe $ ppr x) +toLexpr functionName x@(App _ _) = processApp functionName x +toLexpr functionName (Lam func args) = + toLexpr functionName args +toLexpr functionName (Let func args) = do + a <- toLexpr functionName args + f <- toLBind func + pure $ map (\(x, y) -> (functionName, y)) f <> a +toLexpr functionName (Case condition bind _type alts) = do + c <- toLexpr functionName condition + a <- toList $ mapM (toLAlt functionName) (fromList alts) + pure $ c <> Prelude.concat a +toLexpr functionName (Tick _ expr) = toLexpr functionName expr +toLexpr functionName (Cast expr _) = toLexpr functionName expr +toLexpr functionName _ = pure mempty + +processApp functionName x@(App func args) = do + f <- toLexpr functionName func + a <- toLexpr functionName args + pure $ f <> a + +#if __GLASGOW_HASKELL__ >= 900 +toLAlt :: Text -> Alt Var -> IO [(Text, [FieldUsage])] +toLAlt x (Alt a b c) = toLAlt' x (a,b,c) + where + toLAlt' :: Text -> (AltCon, [Var], CoreExpr) -> IO [(Text, [FieldUsage])] + toLAlt' functionName (DataAlt dataCon, val, e) = do + let typeName = GhcPlugins.tyConName $ GhcPlugins.dataConTyCon dataCon + extractingConstruct = showSDocUnsafe $ ppr $ GhcPlugins.dataConName dataCon + kindStr = showSDocUnsafe $ ppr $ tyConKind $ GhcPlugins.dataConTyCon dataCon + res <- toLexpr functionName e + pure $ ((map (\x -> (functionName, [FieldUsage (pack $ showSDocUnsafe $ ppr $ typeName) (pack $ extractingConstruct) (pack $ showSDocUnsafe $ ppr $ varType x) (pack $ nameStableString $ typeName) (pack $ showSDocUnsafe $ ppr x)])) val)) <> res + toLAlt' functionName (LitAlt lit, val, e) = + toLexpr functionName e + toLAlt' functionName (DEFAULT, val, e) = + toLexpr functionName e +#else +toLAlt :: Text -> (AltCon, [Var], CoreExpr) -> IO [(Text, [FieldUsage])] +toLAlt functionName (DataAlt dataCon, val, e) = do + let typeName = GhcPlugins.tyConName $ GhcPlugins.dataConTyCon dataCon + extractingConstruct = showSDocUnsafe $ ppr $ GhcPlugins.dataConName dataCon + kindStr = showSDocUnsafe $ ppr $ tyConKind $ GhcPlugins.dataConTyCon dataCon + res <- toLexpr functionName e + pure $ ((map (\x -> (functionName, [FieldUsage (pack $ showSDocUnsafe $ ppr $ typeName) (pack $ extractingConstruct) (pack $ showSDocUnsafe $ ppr $ varType x) (pack $ nameStableString $ typeName) (pack $ showSDocUnsafe $ ppr x)])) val)) <> res +toLAlt functionName (LitAlt lit, val, e) = + toLexpr functionName e +toLAlt functionName (DEFAULT, val, e) = + toLexpr functionName e +#endif + +pprTyCon :: Name -> SDoc +pprTyCon = ppr + +pprDataCon :: Name -> SDoc +pprDataCon = ppr \ No newline at end of file diff --git a/fieldInspector/src/FieldInspector/PluginTypes.hs b/fieldInspector/src/FieldInspector/PluginTypes.hs new file mode 100644 index 0000000..084f785 --- /dev/null +++ b/fieldInspector/src/FieldInspector/PluginTypes.hs @@ -0,0 +1,297 @@ + +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE CPP #-} + +module FieldInspector.PluginTypes (plugin) where + +#if __GLASGOW_HASKELL__ >= 900 +import Language.Haskell.Syntax.Type +import GHC.Hs.Extension () +import GHC.Parser.Annotation () +import GHC.Utils.Outputable () +import qualified Data.IntMap.Internal as IntMap +import Streamly.Internal.Data.Stream (fromList,mapM_,mapM,toList) +import GHC +import GHC.Driver.Plugins (Plugin(..),CommandLineOption,defaultPlugin,PluginRecompile(..)) +import GHC.Driver.Env +import GHC.Tc.Types +import GHC.Unit.Module.ModSummary +import GHC.Utils.Outputable (showSDocUnsafe,ppr,SDoc,Outputable) +import GHC.Data.Bag (bagToList) +import GHC.Types.Name hiding (varName) +import GHC.Types.Var +import qualified Data.Aeson.KeyMap as HM +import GHC.Core.Opt.Monad +import GHC.Rename.HsType +-- import GHC.HsToCore.Docs +import GHC.Types.Name.Reader + + +#else +import CoreMonad (CoreM, CoreToDo (CoreDoPluginPass), liftIO) +import CoreSyn ( + AltCon (..), + Bind (NonRec, Rec), + CoreBind, + CoreExpr, + Expr (..), + mkStringLit + ) +import TyCoRep +import GHC.IO (unsafePerformIO) +import GHC.Hs +import GHC.Hs.Decls +import GhcPlugins ( + CommandLineOption,Arg (..), + HsParsedModule(..), + Hsc, + Name,SDoc,DataCon,DynFlags,ModSummary(..),TyCon, + Literal (..),typeEnvElts, + ModGuts (mg_binds, mg_loc, mg_module),showSDoc, + Module (moduleName),tyConKind, + NamedThing (getName),getDynFlags,tyConDataCons,dataConOrigArgTys,dataConName, + Outputable (..),dataConFieldLabels,PluginRecompile(..), + Plugin (..), + Var,flLabel,dataConRepType, + coVarDetails, + defaultPlugin, + idName, + mkInternalName, + mkLitString, + mkLocalVar, + mkVarOcc, + moduleNameString, + nameStableString, + noCafIdInfo, + purePlugin, + showSDocUnsafe, + tyVarKind, + unpackFS, + tyConName, + msHsFilePath + ) +import Id (isExportedId,idType) +import Name (getSrcSpan) +import SrcLoc +import Unique (mkUnique) +import Var (isLocalId,varType) +import FieldInspector.Types +import TcRnTypes +import TcRnMonad +import DataCon +#endif + +import FieldInspector.Types +import Control.Concurrent (MVar, modifyMVar, newMVar) +import Data.Aeson +import Data.Aeson.Encode.Pretty (encodePretty) +import qualified Data.ByteString as DBS +import Data.ByteString.Lazy (toStrict) +import Data.Int (Int64) +import Data.List.Extra (intercalate, isSuffixOf, replace, splitOn,groupBy) +import Data.List ( sortBy, intercalate ,foldl') +import qualified Data.Map as Map +import Data.Text (Text, concat, isInfixOf, pack, unpack) +import qualified Data.Text as T +import Data.Text.Encoding (decodeUtf8, encodeUtf8) +import Data.Time +import Data.Map (Map) +import Data.Data +import Data.Maybe (catMaybes) +import Control.Monad.IO.Class (liftIO) +import System.IO (writeFile) +import Control.Monad (forM) +import Streamly.Internal.Data.Stream hiding (concatMap, init, length, map, splitOn,foldl',intercalate) +import System.Directory (createDirectoryIfMissing, removeFile) +import System.Directory.Internal.Prelude hiding (mapM, mapM_) +import Prelude hiding (id, mapM, mapM_) +import Control.Exception (evaluate) +import qualified Data.Record.Plugin as DRP +import qualified Data.Record.Anon.Plugin as DRAP +import qualified Data.Record.Plugin.HasFieldPattern as DRPH +import qualified RecordDotPreprocessor as RDP +import qualified ApiContract.Plugin as ApiContract +import qualified Fdep.Plugin as Fdep + +plugin :: Plugin +plugin = (defaultPlugin{ + -- installCoreToDos = install + pluginRecompile = (\_ -> return NoForceRecompile) + , parsedResultAction = collectTypeInfoParser + }) + <> ApiContract.plugin +#if defined(ENABLE_LR_PLUGINS) + <> DRP.plugin + <> DRAP.plugin + <> DRPH.plugin +#endif + <> RDP.plugin + +instance Semigroup Plugin where + p <> q = defaultPlugin { + parsedResultAction = \args summary -> + parsedResultAction p args summary + >=> parsedResultAction q args summary + , typeCheckResultAction = \args summary -> + typeCheckResultAction p args summary + >=> typeCheckResultAction q args summary + , pluginRecompile = \args -> + (<>) + <$> pluginRecompile p args + <*> pluginRecompile q args + , tcPlugin = \args -> + case (tcPlugin p args, tcPlugin q args) of + (Nothing, Nothing) -> Nothing + (Just tp, Nothing) -> Just tp + (Nothing, Just tq) -> Just tq + (Just (TcPlugin tcPluginInit1 tcPluginSolve1 tcPluginStop1), Just (TcPlugin tcPluginInit2 tcPluginSolve2 tcPluginStop2)) -> Just $ TcPlugin + { tcPluginInit = do + ip <- tcPluginInit1 + iq <- tcPluginInit2 + return (ip, iq) + , tcPluginSolve = \(sp,sq) given derived wanted -> do + solveP <- tcPluginSolve1 sp given derived wanted + solveQ <- tcPluginSolve2 sq given derived wanted + return $ combineTcPluginResults solveP solveQ + , tcPluginStop = \(solveP,solveQ) -> do + tcPluginStop1 solveP + tcPluginStop2 solveQ + } + } + +combineTcPluginResults :: TcPluginResult -> TcPluginResult -> TcPluginResult +combineTcPluginResults resP resQ = + case (resP, resQ) of + (TcPluginContradiction ctsP, TcPluginContradiction ctsQ) -> + TcPluginContradiction (ctsP ++ ctsQ) + + (TcPluginContradiction ctsP, TcPluginOk _ _) -> + TcPluginContradiction ctsP + + (TcPluginOk _ _, TcPluginContradiction ctsQ) -> + TcPluginContradiction ctsQ + + (TcPluginOk solvedP newP, TcPluginOk solvedQ newQ) -> + TcPluginOk (solvedP ++ solvedQ) (newP ++ newQ) + + +instance Monoid Plugin where + mempty = defaultPlugin + +pprTyCon :: Name -> SDoc +pprTyCon = ppr + +pprDataCon :: Name -> SDoc +pprDataCon = ppr + +collectTypeInfoParser :: [CommandLineOption] -> ModSummary -> HsParsedModule -> Hsc HsParsedModule +collectTypeInfoParser opts modSummary hpm = do + _ <- Fdep.collectDecls opts modSummary hpm + _ <- liftIO $ forkIO $ + do + let prefixPath = case opts of + [] -> "/tmp/fieldInspector/" + local : _ -> local + moduleName' = moduleNameString $ moduleName $ ms_mod modSummary + modulePath = prefixPath <> msHsFilePath modSummary + hm_module = unLoc $ hpm_module hpm + path = (intercalate "/" . init . splitOn "/") modulePath + -- print ("generating types data for module: " <> moduleName' <> " at path: " <> path) + types <- toList $ mapM (pure . getTypeInfo) (fromList $ hsmodDecls hm_module) + createDirectoryIfMissing True path + DBS.writeFile (modulePath <> ".types.parser.json") =<< (evaluate $ toStrict $ encodePretty $ Map.fromList $ Prelude.concat types) + -- print ("generated types data for module: " <> moduleName' <> " at path: " <> path) + pure hpm + +getTypeInfo :: LHsDecl GhcPs -> [(String,TypeInfo)] +getTypeInfo (L _ (TyClD _ (DataDecl _ lname _ _ defn))) = + [((showSDocUnsafe' lname) ,TypeInfo + { name = showSDocUnsafe' lname + , typeKind = "data" + , dataConstructors = map getDataConInfo (dd_cons defn) + })] +getTypeInfo (L _ (TyClD _ (SynDecl _ lname _ _ rhs))) = + [((showSDocUnsafe' lname),TypeInfo + { name = showSDocUnsafe' lname + , typeKind = "type" +#if __GLASGOW_HASKELL__ >= 900 + , dataConstructors = [DataConInfo (showSDocUnsafe' lname) (maybe mempty (Map.singleton "synonym" . unpackHDS) (hsTypeToString $ unLoc rhs)) []] +#else + , dataConstructors = [DataConInfo (showSDocUnsafe' lname) (Map.singleton "synonym" ((showSDocUnsafe . ppr . unLoc) rhs)) []] +#endif + })] +getTypeInfo _ = [] + +instance Outputable Void where + +getDataConInfo :: LConDecl GhcPs -> DataConInfo +getDataConInfo (L _ x@ConDeclH98{ con_name = lname, con_args = args }) = + DataConInfo + { dataConNames = showSDocUnsafe' lname + , fields = getFieldMap args + , sumTypes = [] -- For H98-style data constructors, sum types are not applicable + } +getDataConInfo (L _ ConDeclGADT{ con_names = lnames, con_res_ty = ty }) = + DataConInfo + { dataConNames = intercalate ", " (map (showSDocUnsafe') lnames) +#if __GLASGOW_HASKELL__ >= 900 + , fields = maybe (mempty) (\x -> Map.singleton "gadt" $ unpackHDS x) (hsTypeToString $ unLoc ty) +#else + , fields = Map.singleton "gadt" (showSDocUnsafe $ ppr ty) +#endif + , sumTypes = [] -- For GADT-style data constructors, sum types can be represented by the type itself + } + +#if __GLASGOW_HASKELL__ >= 900 +hsTypeToString :: HsType GhcPs -> Maybe HsDocString +hsTypeToString = f + where + f :: HsType GhcPs -> Maybe HsDocString + f (HsDocTy _ _ lds) = Just (unLoc lds) + f (HsBangTy _ _ (L _ (HsDocTy _ _ lds))) = Just (unLoc lds) + f x = Just (mkHsDocString $ showSDocUnsafe $ ppr x) + +extractInfixCon :: [HsType GhcPs] -> Map.Map String String +extractInfixCon x = + let l = length x + in Map.fromList $ map (\(a,b) -> (show a , b)) $ Prelude.zip [0..l] (map f x) + where + f :: HsType GhcPs -> (String) + f (HsDocTy _ _ lds) = showSDocUnsafe $ ppr $ (unLoc lds) + f (HsBangTy _ _ (L _ (HsDocTy _ _ lds))) = showSDocUnsafe $ ppr $ (unLoc lds) + f x = (showSDocUnsafe $ ppr x) + +extractConDeclField :: [ConDeclField GhcPs] -> Map.Map String String +extractConDeclField x = Map.fromList (go x) + where + go :: [ConDeclField GhcPs] -> [(String,String)] + go [] = [] + go ((ConDeclField _ cd_fld_names cd_fld_type _):xs) = + [((intercalate "," $ convertRdrNameToString cd_fld_names),(showSDocUnsafe $ ppr cd_fld_type))] <> (go xs) + + convertRdrNameToString x = map (showSDocUnsafe . ppr . rdrNameOcc . unLoc . reLocN . rdrNameFieldOcc . unXRec @(GhcPs)) x + +getFieldMap :: HsConDeclH98Details GhcPs -> Map.Map String String +getFieldMap con_args = + case con_args of + PrefixCon _ args -> extractInfixCon $ map (unLoc . hsScaledThing) args + InfixCon arg1 arg2 -> extractInfixCon $ map (unLoc . hsScaledThing) [arg1,arg2] + RecCon (fields) -> extractConDeclField $ map unLoc $ (unXRec @(GhcPs)) fields + +#else +getFieldMap :: HsConDeclDetails GhcPs -> Map String String +getFieldMap (PrefixCon args) = Map.fromList $ Prelude.zipWith (\i t -> (show i, showSDocUnsafe (ppr t))) [1..] args +getFieldMap (RecCon (L _ fields)) = Map.fromList $ concatMap getRecField fields + where + getRecField (L _ (ConDeclField _ fnames t _)) = [(showSDocUnsafe (ppr fname), showSDocUnsafe (ppr t)) | L _ fname <- fnames] +getFieldMap (InfixCon t1 t2) = Map.fromList [("field1", showSDocUnsafe (ppr t1)), ("field2", showSDocUnsafe (ppr t2))] +#endif + +#if __GLASGOW_HASKELL__ >= 900 +showSDocUnsafe' = showSDocUnsafe . ppr . GHC.unXRec @(GhcPs) +#else +showSDocUnsafe' = showSDocUnsafe . ppr +#endif \ No newline at end of file diff --git a/fieldInspector/src/FieldInspector/Types.hs b/fieldInspector/src/FieldInspector/Types.hs index 2022770..2b88b62 100644 --- a/fieldInspector/src/FieldInspector/Types.hs +++ b/fieldInspector/src/FieldInspector/Types.hs @@ -1,8 +1,9 @@ -{-# LANGUAGE BangPatterns #-} + {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE UndecidableInstances,DeriveDataTypeable,DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE UndecidableInstances, DeriveAnyClass #-} module FieldInspector.Types where @@ -11,6 +12,8 @@ import GHC.Generics (Generic) import Data.Text import Data.Data import qualified Data.Map as Map +import Data.Binary +import Control.DeepSeq data FieldUsage = FieldUsage { typeName :: Text @@ -19,16 +22,32 @@ data FieldUsage = FieldUsage { , typeSrcLoc :: Text , beautifiedCode :: Text } - deriving (Generic, Data, Show, ToJSON, FromJSON) + deriving (Show, Eq, Ord,Binary,Generic,NFData,ToJSON,FromJSON) data TypeInfo = TypeInfo { name :: String , typeKind :: String , dataConstructors :: [DataConInfo] - } deriving (Generic, Data, Show, ToJSON, FromJSON) + } deriving (Show, Eq, Ord,Binary,Generic,NFData,ToJSON,FromJSON) data DataConInfo = DataConInfo - { dataConName :: String + { dataConNames :: String , fields :: Map.Map String String , sumTypes :: [String] - } deriving (Generic, Data, Show, ToJSON, FromJSON) \ No newline at end of file + } deriving (Show, Eq, Ord,Binary,Generic,NFData,ToJSON,FromJSON) + +data DataTypeUC = DataTypeUC { + function_name_ :: [Text] + , typeVsFields :: [TypeVsFields] + } deriving (Show, Eq, Ord,Binary,Generic,NFData,ToJSON,FromJSON) + +data TypeVsFields = TypeVsFields { + type_name :: Text + , fieldsVsExprs :: Either [(FieldRep)] [Text] +} deriving (Show, Eq, Ord,Binary,Generic,NFData,ToJSON,FromJSON) + +data FieldRep = FieldRep { + field_name :: Text + , expression :: Text + , field_type :: Text +} deriving (Show, Eq, Ord,Binary,Generic,NFData,ToJSON,FromJSON) diff --git a/fieldInspector/test/Main.hs b/fieldInspector/test/Main.hs index e542539..1e91562 100644 --- a/fieldInspector/test/Main.hs +++ b/fieldInspector/test/Main.hs @@ -1,6 +1,19 @@ -module Main (main) where +{-# LANGUAGE NamedFieldPuns,DataKinds,FlexibleInstances,MultiParamTypeClasses,TypeFamilies,DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts, FlexibleInstances, UndecidableInstances, DeriveDataTypeable, DeriveAnyClass,RecordWildCards #-} +module Main where + +import GHC.Generics (Generic) + main :: IO () -main = putStrLn "Test suite not yet implemented." +main = do + print $ demo $ (A "Test suite not yet implemented." 0) + pure () + +data A = A {name :: String,age :: Int} + deriving (Generic,Show) -data A = B Int String | C String \ No newline at end of file +demo :: A -> String +demo a = + case a of + (A {name}) -> name \ No newline at end of file diff --git a/flake.lock b/flake.lock index 2400bb8..a7b5b0b 100644 --- a/flake.lock +++ b/flake.lock @@ -1,19 +1,41 @@ { "nodes": { - "classyplate": { + "beam": { "flake": false, "locked": { - "lastModified": 1678370822, - "narHash": "sha256-8AJ/55ShKCe49MEcyMqzJ3ADjs5dvtuTIhuTTq2q5nQ=", - "owner": "Chaitanya-nair", + "lastModified": 1696055201, + "narHash": "sha256-BIq3ZjZQWQ0w3zWA19zGBggiVVfnOzR5d4b7De0oVZY=", + "owner": "juspay", + "repo": "beam", + "rev": "c4f86057db76640245c3d1fde040176c53e9b9a3", + "type": "github" + }, + "original": { + "owner": "juspay", + "repo": "beam", + "rev": "c4f86057db76640245c3d1fde040176c53e9b9a3", + "type": "github" + } + }, + "classyplate": { + "inputs": { + "flake-parts": "flake-parts", + "haskell-flake": "haskell-flake", + "nixpkgs": "nixpkgs", + "systems": "systems" + }, + "locked": { + "lastModified": 1721385699, + "narHash": "sha256-Gof2hSQSX581LA8GGnHGjXWu5F899Cot+Id1SYxlUMY=", + "owner": "eswar2001", "repo": "classyplate", - "rev": "46f5e0e7073e1d047f70473bf3c75366a613bfeb", + "rev": "a360f56820df6ca5284091f318bcddcd3e065243", "type": "github" }, "original": { - "owner": "Chaitanya-nair", + "owner": "eswar2001", "repo": "classyplate", - "rev": "46f5e0e7073e1d047f70473bf3c75366a613bfeb", + "rev": "a360f56820df6ca5284091f318bcddcd3e065243", "type": "github" } }, @@ -22,11 +44,11 @@ "nixpkgs-lib": "nixpkgs-lib" }, "locked": { - "lastModified": 1717285511, - "narHash": "sha256-iKzJcpdXih14qYVcZ9QC9XuZYnPc6T8YImb6dX166kw=", + "lastModified": 1719994518, + "narHash": "sha256-pQMhCCHyQGRzdfAkdJ4cIWiw+JNuWsTX7f0ZYSyz0VY=", "owner": "hercules-ci", "repo": "flake-parts", - "rev": "2a55567fcf15b1b1c7ed712a2c6fadaec7412ea8", + "rev": "9227223f6d922fee3c7b190b2cc238a99527bbb7", "type": "github" }, "original": { @@ -39,6 +61,24 @@ "inputs": { "nixpkgs-lib": "nixpkgs-lib_2" }, + "locked": { + "lastModified": 1726153070, + "narHash": "sha256-HO4zgY0ekfwO5bX0QH/3kJ/h4KvUDFZg8YpkNwIbg1U=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "bcef6817a8b2aa20a5a6dbb19b43e63c5bf8619a", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-parts_3": { + "inputs": { + "nixpkgs-lib": "nixpkgs-lib_3" + }, "locked": { "lastModified": 1685662779, "narHash": "sha256-cKDDciXGpMEjP1n6HlzKinN0H+oLmNpgeCTzYnsA2po=", @@ -53,13 +93,211 @@ "type": "github" } }, + "flake-parts_4": { + "inputs": { + "nixpkgs-lib": "nixpkgs-lib_4" + }, + "locked": { + "lastModified": 1719994518, + "narHash": "sha256-pQMhCCHyQGRzdfAkdJ4cIWiw+JNuWsTX7f0ZYSyz0VY=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "9227223f6d922fee3c7b190b2cc238a99527bbb7", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-parts_5": { + "inputs": { + "nixpkgs-lib": "nixpkgs-lib_5" + }, + "locked": { + "lastModified": 1719994518, + "narHash": "sha256-pQMhCCHyQGRzdfAkdJ4cIWiw+JNuWsTX7f0ZYSyz0VY=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "9227223f6d922fee3c7b190b2cc238a99527bbb7", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-parts_6": { + "inputs": { + "nixpkgs-lib": "nixpkgs-lib_6" + }, + "locked": { + "lastModified": 1685662779, + "narHash": "sha256-cKDDciXGpMEjP1n6HlzKinN0H+oLmNpgeCTzYnsA2po=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "71fb97f0d875fd4de4994dfb849f2c75e17eb6c3", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "ghc-hasfield-plugin": { + "inputs": { + "flake-parts": "flake-parts_5", + "haskell-flake": "haskell-flake_4", + "nixpkgs": "nixpkgs_3", + "systems": "systems_2" + }, + "locked": { + "lastModified": 1721371073, + "narHash": "sha256-1xTFZRE/vAHV/mLMW5rNyZH1SkkbyFqDxXZvw7JwOHo=", + "owner": "eswar2001", + "repo": "ghc-hasfield-plugin", + "rev": "c932ebc0d7e824129bb70c8a078f3c68feed85c9", + "type": "github" + }, + "original": { + "owner": "eswar2001", + "repo": "ghc-hasfield-plugin", + "rev": "c932ebc0d7e824129bb70c8a078f3c68feed85c9", + "type": "github" + } + }, + "ghc8-beam": { + "flake": false, + "locked": { + "lastModified": 1689929344, + "narHash": "sha256-uE2/Hq8u9+BjABrM9m6qV+H/88aGnRRzhsE0k8QKSL0=", + "owner": "juspay", + "repo": "beam", + "rev": "e50e6dc6a5a83c4c0c50183416fad33084c81d9e", + "type": "github" + }, + "original": { + "owner": "juspay", + "repo": "beam", + "rev": "e50e6dc6a5a83c4c0c50183416fad33084c81d9e", + "type": "github" + } + }, + "ghc8-classyplate": { + "flake": false, + "locked": { + "lastModified": 1678370822, + "narHash": "sha256-8AJ/55ShKCe49MEcyMqzJ3ADjs5dvtuTIhuTTq2q5nQ=", + "owner": "Chaitanya-nair", + "repo": "classyplate", + "rev": "46f5e0e7073e1d047f70473bf3c75366a613bfeb", + "type": "github" + }, + "original": { + "owner": "Chaitanya-nair", + "repo": "classyplate", + "rev": "46f5e0e7073e1d047f70473bf3c75366a613bfeb", + "type": "github" + } + }, + "ghc8-ghc-hasfield-plugin": { + "flake": false, + "locked": { + "lastModified": 1658487566, + "narHash": "sha256-pZ6kFNfRtBWWqJ3zZSJhZQz7hcdgTdpkqUbzRCuRSl8=", + "owner": "juspay", + "repo": "ghc-hasfield-plugin", + "rev": "d82ac5a6c0ad643eebe2b9b32c91f6523d3f30dc", + "type": "github" + }, + "original": { + "owner": "juspay", + "repo": "ghc-hasfield-plugin", + "rev": "d82ac5a6c0ad643eebe2b9b32c91f6523d3f30dc", + "type": "github" + } + }, + "ghc8-large-records": { + "flake": false, + "locked": { + "lastModified": 1719312727, + "narHash": "sha256-NLs4yiUh4vNf4sqOQUUTCr0Fpld1y6ZyZJNhqSTzAI0=", + "owner": "eswar2001", + "repo": "large-records", + "rev": "e393f4501d76a98b4482b0a5b35d120ae70e5dd3", + "type": "github" + }, + "original": { + "owner": "eswar2001", + "repo": "large-records", + "rev": "e393f4501d76a98b4482b0a5b35d120ae70e5dd3", + "type": "github" + } + }, + "ghc8-nixpkgs": { + "locked": { + "lastModified": 1643795778, + "narHash": "sha256-sBxYgXu+4JTpXPu3c1QGl2a2zzzDJj4VNsVatF1sEIY=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "43e3b6af08f29c4447a6073e3d5b86a4f45dd420", + "type": "github" + }, + "original": { + "owner": "nixos", + "repo": "nixpkgs", + "rev": "43e3b6af08f29c4447a6073e3d5b86a4f45dd420", + "type": "github" + } + }, + "ghc8-record-dot-preprocessor": { + "flake": false, + "locked": { + "lastModified": 1644582826, + "narHash": "sha256-BXprRyjI4ZTG+Orz858xmttiC8O0yuubaaKmeRAL/UY=", + "owner": "ndmitchell", + "repo": "record-dot-preprocessor", + "rev": "99452d27f35ea1ff677be9af570d834e8fab4caf", + "type": "github" + }, + "original": { + "owner": "ndmitchell", + "repo": "record-dot-preprocessor", + "rev": "99452d27f35ea1ff677be9af570d834e8fab4caf", + "type": "github" + } + }, + "ghc8-references": { + "inputs": { + "flake-parts": "flake-parts_3", + "haskell-flake": "haskell-flake_2", + "nixpkgs": "nixpkgs_2" + }, + "locked": { + "lastModified": 1686714318, + "narHash": "sha256-Ogy9S6cF/8WNfpcQ1k65rPjjTfWlH15Jp5JeraYaAQQ=", + "owner": "eswar2001", + "repo": "references", + "rev": "35912f3cc72b67fa63a8d59d634401b79796469e", + "type": "github" + }, + "original": { + "owner": "eswar2001", + "repo": "references", + "rev": "35912f3cc72b67fa63a8d59d634401b79796469e", + "type": "github" + } + }, "haskell-flake": { "locked": { - "lastModified": 1719249394, - "narHash": "sha256-ytIvs6dq1dD3eicwhmqMyhIDH52DfqhOiCpmJbjBYVI=", + "lastModified": 1720977934, + "narHash": "sha256-k9kwz2lpUqafRUpuCMgkv4AWtHEoJPCds1ZPRkyW2XE=", "owner": "srid", "repo": "haskell-flake", - "rev": "dfea80e8a907a7818f11090788f84f1a62985694", + "rev": "cd449f1c04175efdf5b553302d22916640090066", "type": "github" }, "original": { @@ -83,35 +321,175 @@ "type": "github" } }, + "haskell-flake_3": { + "locked": { + "lastModified": 1726441645, + "narHash": "sha256-mXVvqtBqgcDnT2MTJP8eJeQtajKbNrYevPHpoDqKnVQ=", + "owner": "srid", + "repo": "haskell-flake", + "rev": "96aad3a08f30333fead66da396dbf7a21ac4adb6", + "type": "github" + }, + "original": { + "owner": "srid", + "repo": "haskell-flake", + "type": "github" + } + }, + "haskell-flake_4": { + "locked": { + "lastModified": 1720977934, + "narHash": "sha256-k9kwz2lpUqafRUpuCMgkv4AWtHEoJPCds1ZPRkyW2XE=", + "owner": "srid", + "repo": "haskell-flake", + "rev": "cd449f1c04175efdf5b553302d22916640090066", + "type": "github" + }, + "original": { + "owner": "srid", + "repo": "haskell-flake", + "type": "github" + } + }, + "haskell-flake_5": { + "locked": { + "lastModified": 1721530802, + "narHash": "sha256-eUMmQKXjt4WQq+IBscftg/Y9bXWiOYhasfeH5Yb9Psc=", + "owner": "srid", + "repo": "haskell-flake", + "rev": "f8f38ecd259338167cc0c85fd541479297a315af", + "type": "github" + }, + "original": { + "owner": "srid", + "repo": "haskell-flake", + "type": "github" + } + }, + "haskell-flake_6": { + "locked": { + "lastModified": 1686160859, + "narHash": "sha256-UE+0TQHyPxF8jhbLEeqvNQAy7B79bBix/rpFrf5nsn0=", + "owner": "srid", + "repo": "haskell-flake", + "rev": "908a59167f78035a123ab71ed77af79bed519771", + "type": "github" + }, + "original": { + "owner": "srid", + "repo": "haskell-flake", + "type": "github" + } + }, + "large-records": { + "inputs": { + "beam": [ + "beam" + ], + "flake-parts": "flake-parts_4", + "ghc-hasfield-plugin": "ghc-hasfield-plugin", + "haskell-flake": "haskell-flake_5", + "nixpkgs": "nixpkgs_4", + "systems": "systems_3" + }, + "locked": { + "lastModified": 1721562622, + "narHash": "sha256-4XivoIvlVl7UyVCyZneeLIvyKBbRIvDEOEnJBxnZp+c=", + "owner": "eswar2001", + "repo": "large-records", + "rev": "b60bcb312c7d55f1d638aa1a5143696e6586e76d", + "type": "github" + }, + "original": { + "owner": "eswar2001", + "repo": "large-records", + "rev": "b60bcb312c7d55f1d638aa1a5143696e6586e76d", + "type": "github" + } + }, "nixpkgs": { "locked": { - "lastModified": 1643795778, - "narHash": "sha256-sBxYgXu+4JTpXPu3c1QGl2a2zzzDJj4VNsVatF1sEIY=", + "lastModified": 1698266953, + "narHash": "sha256-jf72t7pC8+8h8fUslUYbWTX5rKsRwOzRMX8jJsGqDXA=", "owner": "nixos", "repo": "nixpkgs", - "rev": "43e3b6af08f29c4447a6073e3d5b86a4f45dd420", + "rev": "75a52265bda7fd25e06e3a67dee3f0354e73243c", "type": "github" }, "original": { "owner": "nixos", "repo": "nixpkgs", - "rev": "43e3b6af08f29c4447a6073e3d5b86a4f45dd420", + "rev": "75a52265bda7fd25e06e3a67dee3f0354e73243c", "type": "github" } }, "nixpkgs-lib": { "locked": { - "lastModified": 1717284937, - "narHash": "sha256-lIbdfCsf8LMFloheeE6N31+BMIeixqyQWbSr2vk79EQ=", + "lastModified": 1719876945, + "narHash": "sha256-Fm2rDDs86sHy0/1jxTOKB1118Q0O3Uc7EC0iXvXKpbI=", "type": "tarball", - "url": "https://github.com/NixOS/nixpkgs/archive/eb9ceca17df2ea50a250b6b27f7bf6ab0186f198.tar.gz" + "url": "https://github.com/NixOS/nixpkgs/archive/5daf0514482af3f97abaefc78a6606365c9108e2.tar.gz" }, "original": { "type": "tarball", - "url": "https://github.com/NixOS/nixpkgs/archive/eb9ceca17df2ea50a250b6b27f7bf6ab0186f198.tar.gz" + "url": "https://github.com/NixOS/nixpkgs/archive/5daf0514482af3f97abaefc78a6606365c9108e2.tar.gz" } }, "nixpkgs-lib_2": { + "locked": { + "lastModified": 1725233747, + "narHash": "sha256-Ss8QWLXdr2JCBPcYChJhz4xJm+h/xjl4G0c0XlP6a74=", + "type": "tarball", + "url": "https://github.com/NixOS/nixpkgs/archive/356624c12086a18f2ea2825fed34523d60ccc4e3.tar.gz" + }, + "original": { + "type": "tarball", + "url": "https://github.com/NixOS/nixpkgs/archive/356624c12086a18f2ea2825fed34523d60ccc4e3.tar.gz" + } + }, + "nixpkgs-lib_3": { + "locked": { + "dir": "lib", + "lastModified": 1685564631, + "narHash": "sha256-8ywr3AkblY4++3lIVxmrWZFzac7+f32ZEhH/A8pNscI=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "4f53efe34b3a8877ac923b9350c874e3dcd5dc0a", + "type": "github" + }, + "original": { + "dir": "lib", + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs-lib_4": { + "locked": { + "lastModified": 1719876945, + "narHash": "sha256-Fm2rDDs86sHy0/1jxTOKB1118Q0O3Uc7EC0iXvXKpbI=", + "type": "tarball", + "url": "https://github.com/NixOS/nixpkgs/archive/5daf0514482af3f97abaefc78a6606365c9108e2.tar.gz" + }, + "original": { + "type": "tarball", + "url": "https://github.com/NixOS/nixpkgs/archive/5daf0514482af3f97abaefc78a6606365c9108e2.tar.gz" + } + }, + "nixpkgs-lib_5": { + "locked": { + "lastModified": 1719876945, + "narHash": "sha256-Fm2rDDs86sHy0/1jxTOKB1118Q0O3Uc7EC0iXvXKpbI=", + "type": "tarball", + "url": "https://github.com/NixOS/nixpkgs/archive/5daf0514482af3f97abaefc78a6606365c9108e2.tar.gz" + }, + "original": { + "type": "tarball", + "url": "https://github.com/NixOS/nixpkgs/archive/5daf0514482af3f97abaefc78a6606365c9108e2.tar.gz" + } + }, + "nixpkgs-lib_6": { "locked": { "dir": "lib", "lastModified": 1685564631, @@ -145,35 +523,126 @@ "type": "github" } }, + "nixpkgs_3": { + "locked": { + "lastModified": 1698266953, + "narHash": "sha256-jf72t7pC8+8h8fUslUYbWTX5rKsRwOzRMX8jJsGqDXA=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "75a52265bda7fd25e06e3a67dee3f0354e73243c", + "type": "github" + }, + "original": { + "owner": "nixos", + "repo": "nixpkgs", + "rev": "75a52265bda7fd25e06e3a67dee3f0354e73243c", + "type": "github" + } + }, + "nixpkgs_4": { + "locked": { + "lastModified": 1698266953, + "narHash": "sha256-jf72t7pC8+8h8fUslUYbWTX5rKsRwOzRMX8jJsGqDXA=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "75a52265bda7fd25e06e3a67dee3f0354e73243c", + "type": "github" + }, + "original": { + "owner": "nixos", + "repo": "nixpkgs", + "rev": "75a52265bda7fd25e06e3a67dee3f0354e73243c", + "type": "github" + } + }, + "nixpkgs_5": { + "locked": { + "lastModified": 1698266953, + "narHash": "sha256-jf72t7pC8+8h8fUslUYbWTX5rKsRwOzRMX8jJsGqDXA=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "75a52265bda7fd25e06e3a67dee3f0354e73243c", + "type": "github" + }, + "original": { + "owner": "nixos", + "repo": "nixpkgs", + "rev": "75a52265bda7fd25e06e3a67dee3f0354e73243c", + "type": "github" + } + }, + "nixpkgs_6": { + "locked": { + "lastModified": 1698266953, + "narHash": "sha256-jf72t7pC8+8h8fUslUYbWTX5rKsRwOzRMX8jJsGqDXA=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "75a52265bda7fd25e06e3a67dee3f0354e73243c", + "type": "github" + }, + "original": { + "owner": "nixos", + "repo": "nixpkgs", + "rev": "75a52265bda7fd25e06e3a67dee3f0354e73243c", + "type": "github" + } + }, "references": { "inputs": { - "flake-parts": "flake-parts_2", - "haskell-flake": "haskell-flake_2", - "nixpkgs": "nixpkgs_2" + "flake-parts": "flake-parts_6", + "haskell-flake": "haskell-flake_6", + "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1686714318, - "narHash": "sha256-Ogy9S6cF/8WNfpcQ1k65rPjjTfWlH15Jp5JeraYaAQQ=", + "lastModified": 1721735703, + "narHash": "sha256-0F/xsz64sUwKQvKL5yuU+7+QPiyvlQFUb8zZI1ZTbrI=", "owner": "eswar2001", "repo": "references", - "rev": "35912f3cc72b67fa63a8d59d634401b79796469e", + "rev": "120ae7826a7af01a527817952ad0c3f5ef08efd0", "type": "github" }, "original": { "owner": "eswar2001", "repo": "references", - "rev": "35912f3cc72b67fa63a8d59d634401b79796469e", + "rev": "120ae7826a7af01a527817952ad0c3f5ef08efd0", "type": "github" } }, "root": { "inputs": { + "beam": "beam", "classyplate": "classyplate", - "flake-parts": "flake-parts", - "haskell-flake": "haskell-flake", - "nixpkgs": "nixpkgs", + "flake-parts": "flake-parts_2", + "ghc8-beam": "ghc8-beam", + "ghc8-classyplate": "ghc8-classyplate", + "ghc8-ghc-hasfield-plugin": "ghc8-ghc-hasfield-plugin", + "ghc8-large-records": "ghc8-large-records", + "ghc8-nixpkgs": "ghc8-nixpkgs", + "ghc8-record-dot-preprocessor": "ghc8-record-dot-preprocessor", + "ghc8-references": "ghc8-references", + "haskell-flake": "haskell-flake_3", + "large-records": "large-records", + "nixpkgs": "nixpkgs_5", "references": "references", - "systems": "systems" + "streamly": "streamly", + "systems": "systems_4" + } + }, + "streamly": { + "flake": false, + "locked": { + "lastModified": 1701516357, + "narHash": "sha256-Ap7kdurs4NZyMUeMUIF5qU5eHKifO9YmnO5eSEvdtA8=", + "owner": "composewell", + "repo": "streamly", + "rev": "12d85026291d9305f93f573d284d0d35abf40968", + "type": "github" + }, + "original": { + "owner": "composewell", + "repo": "streamly", + "rev": "12d85026291d9305f93f573d284d0d35abf40968", + "type": "github" } }, "systems": { @@ -190,6 +659,51 @@ "repo": "default", "type": "github" } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_3": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_4": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } } }, "root": "root", diff --git a/flake.nix b/flake.nix index d1fdf55..7495fa1 100644 --- a/flake.nix +++ b/flake.nix @@ -1,23 +1,75 @@ { inputs = { - nixpkgs.url = "github:nixos/nixpkgs/43e3b6af08f29c4447a6073e3d5b86a4f45dd420"; systems.url = "github:nix-systems/default"; flake-parts.url = "github:hercules-ci/flake-parts"; haskell-flake.url = "github:srid/haskell-flake"; - classyplate.flake = false; - classyplate.url = "github:Chaitanya-nair/classyplate/46f5e0e7073e1d047f70473bf3c75366a613bfeb"; - references.flake = true; - references.url = "github:eswar2001/references/35912f3cc72b67fa63a8d59d634401b79796469e"; + streamly.url = "github:composewell/streamly/12d85026291d9305f93f573d284d0d35abf40968"; + streamly.flake = false; + + # ghc 9.2.8 packages + nixpkgs.url = "github:nixos/nixpkgs/75a52265bda7fd25e06e3a67dee3f0354e73243c"; + classyplate.url = "github:eswar2001/classyplate/a360f56820df6ca5284091f318bcddcd3e065243"; + references.url = "github:eswar2001/references/120ae7826a7af01a527817952ad0c3f5ef08efd0"; + beam.url = "github:juspay/beam/c4f86057db76640245c3d1fde040176c53e9b9a3"; + beam.flake = false; + large-records.url = "github:eswar2001/large-records/b60bcb312c7d55f1d638aa1a5143696e6586e76d"; + large-records.inputs.beam.follows = "beam"; + + # ghc 8.10.7 packages + ghc8-nixpkgs.url = "github:nixos/nixpkgs/43e3b6af08f29c4447a6073e3d5b86a4f45dd420"; + ghc8-beam.url = "github:juspay/beam/e50e6dc6a5a83c4c0c50183416fad33084c81d9e"; + ghc8-beam.flake = false; + ghc8-classyplate.url = "github:Chaitanya-nair/classyplate/46f5e0e7073e1d047f70473bf3c75366a613bfeb"; + ghc8-classyplate.flake = false; + ghc8-references.url = "github:eswar2001/references/35912f3cc72b67fa63a8d59d634401b79796469e"; + ghc8-references.flake = true; + ghc8-ghc-hasfield-plugin.url = "github:juspay/ghc-hasfield-plugin/d82ac5a6c0ad643eebe2b9b32c91f6523d3f30dc"; + ghc8-ghc-hasfield-plugin.flake = false; + ghc8-large-records.url = "github:eswar2001/large-records/e393f4501d76a98b4482b0a5b35d120ae70e5dd3"; + ghc8-large-records.flake = false; + ghc8-record-dot-preprocessor.url = "github:ndmitchell/record-dot-preprocessor/99452d27f35ea1ff677be9af570d834e8fab4caf"; + ghc8-record-dot-preprocessor.flake = false; }; outputs = inputs@{ self, nixpkgs, flake-parts, ... }: - flake-parts.lib.mkFlake { inherit inputs; } { + flake-parts.lib.mkFlake { inherit inputs; } ({ withSystem, ...}: { systems = import inputs.systems; imports = [ inputs.haskell-flake.flakeModule ]; - perSystem = { self', pkgs, ... }: { - + perSystem = { self', pkgs, system, ... }: { # Typically, you just want a single project named "default". But # multiple projects are also possible, each using different GHC version. + + # GHC 8 support + haskellProjects.ghc8 = { + projectFlakeName = "spider"; + basePackages = inputs.ghc8-nixpkgs.legacyPackages.${system}.haskell.packages.ghc8107; + imports = [ + inputs.ghc8-references.haskellFlakeProjectModules.output + ]; + packages = { + classyplate.source = inputs.ghc8-classyplate; + ghc-hasfield-plugin.source = inputs.ghc8-ghc-hasfield-plugin; + large-records.source = inputs.ghc8-large-records + /large-records; + large-generics.source = inputs.ghc8-large-records + /large-generics; + large-anon.source = inputs.ghc8-large-records + /large-anon; + ghc-tcplugin-api.source = "0.7.1.0"; + typelet.source = inputs.ghc8-large-records + /typelet; + record-dot-preprocessor.source = inputs.ghc8-record-dot-preprocessor; + streamly-core.source = inputs.streamly + /core; + beam-core.source = inputs.ghc8-beam + /beam-core; + }; + settings = { + beam-core.jailbreak = true; + sheriff.check = false; + }; + devShell = { + mkShellArgs = { + name = "ghc8-spider"; + }; + hlsCheck.enable = inputs.ghc8-nixpkgs.legacyPackages.${system}.stdenv.isDarwin; # On darwin, sandbox is disabled, so HLS can use the network. + }; + }; + haskellProjects.default = { # The base package set representing a specific GHC version. # By default, this is pkgs.haskellPackages. @@ -29,13 +81,21 @@ # Note that local packages are automatically included in `packages` # (defined by `defaults.packages` option). # + # defaults.enable = false; + # devShell.tools = hp: with hp; { + # inherit cabal-install; + # inherit hp; + # }; projectFlakeName = "spider"; - basePackages = pkgs.haskell.packages.ghc8107; + # basePackages = pkgs.haskell.packages.ghc8107; + basePackages = pkgs.haskell.packages.ghc92; imports = [ inputs.references.haskellFlakeProjectModules.output + inputs.classyplate.haskellFlakeProjectModules.output + inputs.large-records.haskellFlakeProjectModules.output ]; packages = { - classyplate.source = inputs.classyplate; + streamly-core.source = inputs.streamly + /core; }; settings = { # aeson = { @@ -45,7 +105,11 @@ # haddock = false; # broken = false; # }; - sheriff.check = false; + # primitive-checked = { + # broken = false; + # jailbreak = true; + # }; + sheriff.check = false; }; devShell = { @@ -54,14 +118,33 @@ # Programs you want to make available in the shell. # Default programs can be disabled by setting to 'null' - # tools = hp: { fourmolu = hp.fourmolu; ghcid = null; }; - + # tools = hp: { fourmolu = null; ghcid = null; }; + mkShellArgs = { + name = "spider"; + }; hlsCheck.enable = pkgs.stdenv.isDarwin; # On darwin, sandbox is disabled, so HLS can use the network. }; }; # haskell-flake doesn't set the default package, but you can do it here. packages.default = self'.packages.fdep; + + }; + + flake.haskellFlakeProjectModules = { + # To use ghc 9 version, use + # inputs.spider.haskellFlakeProjectModules.output + + # To use ghc 8 version, use + # inputs.spider.haskellFlakeProjectModules.output-ghc8 + + output-ghc9 = { pkgs, lib, ... }: withSystem pkgs.system ({ config, ... }: + config.haskellProjects."default".defaults.projectModules.output + ); + + output-ghc8 = { pkgs, lib, ... }: withSystem pkgs.system ({ config, ... }: + config.haskellProjects."ghc8".defaults.projectModules.output + ); }; - }; + }); } diff --git a/paymentFlow/.gitignore b/paymentFlow/.gitignore new file mode 100644 index 0000000..8a0ab4f --- /dev/null +++ b/paymentFlow/.gitignore @@ -0,0 +1,7 @@ +dist-* +result +test/dumps +test/out* +cabal.project.local +.juspay/tmp* +.tmp* \ No newline at end of file diff --git a/paymentFlow/CHANGELOG.md b/paymentFlow/CHANGELOG.md new file mode 100644 index 0000000..30beb0c --- /dev/null +++ b/paymentFlow/CHANGELOG.md @@ -0,0 +1,5 @@ +# Revision history for code-checker + +## 0.1.0.0 -- 2024-07-19 + +* First version. Basic rules based compilation error. diff --git a/paymentFlow/DOC.md b/paymentFlow/DOC.md new file mode 100644 index 0000000..6c75e8a --- /dev/null +++ b/paymentFlow/DOC.md @@ -0,0 +1,7 @@ +### paymentFlow plugin + +#### What it does? + +`paymentFlow` is a compiler plugin designed to incorporate business logic validation checks during compilation. It performs the following verification: + +***Restrict Access to Specified Type Fields***: This check ensures that deprecated fields within a type are not accessed. The goal is to prevent usage of these restricted fields and to suggest alternative methods for accessing the required information. \ No newline at end of file diff --git a/paymentFlow/LICENSE b/paymentFlow/LICENSE new file mode 100644 index 0000000..189cd23 --- /dev/null +++ b/paymentFlow/LICENSE @@ -0,0 +1,20 @@ +Copyright (c) 2024 Juspay + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/paymentFlow/README.MD b/paymentFlow/README.MD new file mode 100644 index 0000000..90609a9 --- /dev/null +++ b/paymentFlow/README.MD @@ -0,0 +1,24 @@ + +# Haskell Code Checker Plugin - Sheriff + +## Overview + +This Haskell plugin automatically verifies `fields` access from a `type` for rule violations. It scans the source code to identify types and evaluates them against predefined rules to detect any violations. Currently, it supports the following rules: + +1. Blocking access to certain `fields` from a specified `type`. +2. Allowing exceptions to the rule for field access from the type based on field_access_whitelisted_fns or whitelisted_line_nos. + +This tool is useful for developers to enforce better coding practices and prevent the use of specific fields from a type in the code. + +## Usage + +Add this to your ghc-options in cabal and mention `paymentFlow` in build-depends + +``` +-fplugin=PaymentFlow.Plugin +``` +Also, we can provide flags to the plugin in as follows: +``` +-fplugin=PaymentFlow.Plugin:{"throwCompilationError":true,"saveToFile":true,"savePath":".juspay/tmp/paymentFlow/","failOnFileNotFound":true} +``` +By default, it throwsCompilationErrors. \ No newline at end of file diff --git a/paymentFlow/paymentFlow.cabal b/paymentFlow/paymentFlow.cabal new file mode 100644 index 0000000..fc6a09e --- /dev/null +++ b/paymentFlow/paymentFlow.cabal @@ -0,0 +1,99 @@ +cabal-version: 3.0 +name: paymentFlow +version: 0.1.0.0 +synopsis: A checker plugin to throw compilation errors based on given rules. +license: MIT +license-file: LICENSE +author: harshith.ak-juspay +maintainer: harshith.ak@juspay.in +category: Development +build-type: Simple +extra-doc-files: CHANGELOG.md + +Flag Dev + Description: Use ghc options to dump ASTs in dev mode + Default: False + Manual: True + +common common-options + build-depends: base + ghc-options: -Wall + -Wincomplete-uni-patterns + -Wincomplete-record-updates + -Wincomplete-patterns + -Wcompat + -Widentities + -Wredundant-constraints + -fhide-source-paths + + default-language: Haskell2010 + default-extensions: DeriveGeneric + GeneralizedNewtypeDeriving + InstanceSigs + LambdaCase + OverloadedStrings + RecordWildCards + ScopedTypeVariables + StandaloneDeriving + TypeApplications + CPP + +library + import: common-options + exposed-modules: + PaymentFlow.Plugin + other-modules: + PaymentFlow.Types + PaymentFlow.Patterns + build-depends: + bytestring + , containers + , filepath + , ghc + , ghc-exactprint + , unordered-containers + , uniplate + , references + , classyplate + , aeson + , directory + , extra + , yaml + , text + , aeson-pretty + hs-source-dirs: src + default-language: Haskell2010 + +test-suite paymentFlow-test + import: common-options + + default-language: Haskell2010 + type: exitcode-stdio-1.0 + + hs-source-dirs: test + + main-is: Main.hs + other-modules: + Types + Types1 + + build-depends: + , paymentFlow + , aeson + , text + , containers + , bytestring + , aeson-pretty + , extra + , record-dot-preprocessor + , record-hasfield + , lens >= 4.0 + if flag(Dev) + ghc-options: + -- -fplugin=PaymentFlow.Plugin + -- -fplugin-opt=PaymentFlow.Plugin:{"rulesConfigPath":".juspay/paymentFlowRules.yaml","failOnFileNotFound":true} + else + ghc-options: + -- -fplugin=PaymentFlow.Plugin + + default-extensions: DataKinds \ No newline at end of file diff --git a/paymentFlow/src/PaymentFlow/Patterns.hs b/paymentFlow/src/PaymentFlow/Patterns.hs new file mode 100644 index 0000000..04faa6a --- /dev/null +++ b/paymentFlow/src/PaymentFlow/Patterns.hs @@ -0,0 +1,30 @@ +{-# LANGUAGE PatternSynonyms #-} + +module PaymentFlow.Patterns where + +import GHC hiding (exprType) + +#if __GLASGOW_HASKELL__ >= 900 +import GHC.Core.TyCo.Rep +import GHC.Tc.Types.Evidence +import Language.Haskell.Syntax.Expr +#else +import GHC.Hs.Expr +import TcEvidence +import TyCoRep +#endif + +#if __GLASGOW_HASKELL__ >= 900 + +pattern PatHsWrap :: HsWrapper -> HsExpr GhcTc -> HsExpr GhcTc +pattern PatHsWrap wrapper expr <- (XExpr (WrapExpr (HsWrap wrapper expr))) + +pattern PatHsExpansion :: HsExpr GhcRn -> HsExpr GhcTc -> HsExpr GhcTc +pattern PatHsExpansion orig expanded <- (XExpr (ExpansionExpr (HsExpanded orig expanded))) + +#else + +pattern PatHsWrap :: HsWrapper -> HsExpr (GhcPass p) -> HsExpr (GhcPass p) +pattern PatHsWrap wrapper expr <- (HsWrap _ wrapper expr) + +#endif \ No newline at end of file diff --git a/paymentFlow/src/PaymentFlow/Plugin.hs b/paymentFlow/src/PaymentFlow/Plugin.hs new file mode 100644 index 0000000..5b2a692 --- /dev/null +++ b/paymentFlow/src/PaymentFlow/Plugin.hs @@ -0,0 +1,373 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE FlexibleInstances #-} + +module PaymentFlow.Plugin (plugin) where + +-- paymentFlow imports +import PaymentFlow.Types (VoilationRuleResult(..), PFRules(..), Rule(..), PluginOpts(..), defaultPluginOpts) +import PaymentFlow.Patterns + +-- GHC imports + +import Control.Applicative ((<|>)) +import Control.Monad (when) +import Control.Monad.IO.Class (MonadIO (..)) +import Control.Reference (biplateRef, (^?), Simple, Traversal) +import Data.Aeson as A +import qualified Data.ByteString.Lazy.Char8 as Char8 +import Data.Data +import Data.Function (on) +import Data.List (nub, sortBy, groupBy, isInfixOf, isSuffixOf, isPrefixOf, stripPrefix) +import Data.Maybe (catMaybes, fromMaybe, listToMaybe) +import Data.Yaml +import GHC hiding (exprType) +import Prelude hiding (id) +import Data.Generics.Uniplate.Data + +#if __GLASGOW_HASKELL__ >= 900 +import GHC.Core.ConLike +import GHC.Core.TyCo.Rep +import GHC.Data.Bag +import GHC.HsToCore.Monad +import GHC.HsToCore.Expr +import GHC.Plugins hiding ((<>), getHscEnv, purePlugin) +import GHC.Tc.Types +import GHC.Tc.Types.Evidence +import GHC.Tc.Utils.Monad +import GHC.Tc.Utils.TcType +import GHC.Types.Annotations +import qualified GHC.Utils.Outputable as OP +#else +import Bag +import ConLike +import DsExpr +import DsMonad +import GhcPlugins hiding ((<>), getHscEnv, purePlugin) +import qualified Outputable as OP +import TcEvidence +import TcRnMonad +import TcRnTypes +import TcType +import TyCoRep +#endif + +mkInvalidYamlFileErr :: String -> OP.SDoc +mkInvalidYamlFileErr err = OP.text err + +parseYAMLFile :: (FromJSON a) => FilePath -> IO (Either ParseException a) +parseYAMLFile file = decodeFileEither file + +plugin :: Plugin +plugin = defaultPlugin { + typeCheckResultAction = paymentFlow + , pluginRecompile = purePlugin + } + +purePlugin :: [CommandLineOption] -> IO PluginRecompile +purePlugin _ = return NoForceRecompile + +paymentFlow :: [CommandLineOption] -> ModSummary -> TcGblEnv -> TcM TcGblEnv +paymentFlow opts modSummary tcEnv = do + let pluginOpts = case opts of + [] -> defaultPluginOpts + (x : _) -> + fromMaybe defaultPluginOpts $ A.decode (Char8.pack x) + moduleNm = moduleNameString $ moduleName $ ms_mod modSummary + paymentFlowRulesConfigPath = rulesConfigPath pluginOpts + parsedPaymentFlowRules <- liftIO $ parseYAMLFile paymentFlowRulesConfigPath + ruleList <- case parsedPaymentFlowRules of + Left err -> do + when (failOnFileNotFound pluginOpts) $ addErr (mkInvalidYamlFileErr (show err)) + pure [] + Right (rule :: PFRules) -> pure (rules rule) + let binds = tcg_binds tcEnv + if ("Types" `isSuffixOf` moduleNm || "Types" `isPrefixOf` moduleNm || "Types" `isInfixOf` moduleNm ) + then pure () + else do + errors <- concat <$> mapM (checkBind ruleList) (bagToList binds) + + let sortedErrors = sortBy (leftmost_smallest `on` srcSpan) errors + groupedErrors = groupBy (\a b -> srcSpan a == srcSpan b) sortedErrors + childFnFilterLogic srcGrpErrArr = do + let srcSpn = maybe Nothing (\value -> Just $ srcSpan value) (listToMaybe srcGrpErrArr) + srcSpanLine = getSrcSpanLine srcSpn + shouldThroughError = (any (\(VoilationRuleResult{..}) -> do + let whitelistedRules = field_access_whitelisted_fns rule + fnName `elem` whitelistedRules || coreFnName `elem` whitelistedRules) srcGrpErrArr) || (any (\result -> srcSpanLine `elem` (whitelisted_line_nos (rule result))) srcGrpErrArr) + if shouldThroughError + then Nothing + else listToMaybe srcGrpErrArr + filteredErrors = (\srcGrpErrArr -> childFnFilterLogic srcGrpErrArr) <$> groupedErrors + mapM_ (\ (VoilationRuleResult {..}) -> addErrAt srcSpan $ OP.text $ field_rule_fixes rule ) (catMaybes filteredErrors) + return tcEnv + +checkBind :: [Rule] -> LHsBindLR GhcTc GhcTc -> TcM [VoilationRuleResult] +checkBind rule (L _ (FunBind{..} )) = do + let funMatches = unLoc $ mg_alts fun_matches + concat <$> mapM (checkMatch rule (getVarNameFromIDP $ unLoc fun_id)) funMatches +checkBind rule (L _ (AbsBinds {abs_binds = binds})) = + concat <$> (mapM (checkBind rule) $ bagToList binds) +checkBind _ _ = pure [] + +checkMatch :: [Rule] -> String -> LMatch GhcTc (LHsExpr GhcTc) -> TcM [VoilationRuleResult] +checkMatch rule coreFn (L _ (Match _ _ _ grhss)) = do + let whereBinds = (grhssLocalBinds grhss) ^? biplateRef :: [LHsExpr GhcTc] + nonWhereBinds = (grhssGRHSs grhss) ^? biplateRef :: [LHsExpr GhcTc] + loopOverExprInArgsPerFnName (nonWhereBinds <> whereBinds) rule coreFn +checkMatch _ _ _ = pure [] + +loopOverExprInArgsPerFnName :: [LHsExpr GhcTc] -> [Rule] -> String -> TcM [VoilationRuleResult] +loopOverExprInArgsPerFnName exprs rules coreFn = do + let fnArgTuple = catMaybes (getFnNameWithAllArgs <$> exprs) + nub <$> concat <$> mapM (lookOverExpr rules coreFn) fnArgTuple +loopOverExprInArgsPerFnName _ _ _ = pure [] + +lookOverExpr :: [Rule] -> String -> (Located Var, [LHsExpr GhcTc]) -> TcM [VoilationRuleResult] +lookOverExpr rules funId (fnName, args) = do + let updatedArgs = args ^? biplateRef :: [LHsExpr GhcTc] + tupleResponse <- catMaybes <$> sequence (checkExpr rules <$> updatedArgs) + pure $ (\(x, y) -> VoilationRuleResult { fnName = getVarName fnName, srcSpan = x, rule = y, coreFnName = funId }) <$> tupleResponse + +checkExpr :: [Rule]-> LHsExpr GhcTc -> TcM (Maybe (SrcSpan, Rule)) +checkExpr rules expr = + case expr of + L _ (HsPar _ exp) -> checkExpr rules exp + +#if __GLASGOW_HASKELL__ >= 900 + L loc (PatHsExpansion orig expanded) -> checkExpr rules (L loc expanded) + + L (SrcSpanAnn _ loc1) (HsApp _ (L _ (HsApp _ op' (L _ (HsVar _ (L _ var))))) (L _ (PatHsWrap _ (HsVar _ (L _ lens))))) -> do + let voilationSatisfiedRules = verifyAndGetRuleVoilatedFnInfoWithLeftExprAsType lens var rules + case listToMaybe voilationSatisfiedRules of + Nothing -> pure Nothing + Just rule -> pure $ Just (loc1, rule) + + L _ (HsApp _ (L _ (PatHsWrap _ (HsAppType _ _ (HsWC _ (L (SrcSpanAnn _ loc) (HsTyLit _ (HsStrTy _ fieldName)))) ))) (L _ (HsVar _ (L _ var)))) -> do + let voilationSatisfiedRules = verifyAndGetRuleVoilatedFnInfoWithExprAndFieldAsName (showS fieldName) var rules + case listToMaybe voilationSatisfiedRules of + Nothing -> pure Nothing + Just rule -> pure $ Just (loc, rule) + + L (SrcSpanAnn _ loc) (HsApp _ (L _ (HsRecFld _ (Unambiguous name _))) (L _ (HsVar _ (L _ var)))) -> do + let voilationSatisfiedRules = verifyAndGetRuleVoilatedFnInfoWithExprAndFieldAsName (showS name) var rules + case listToMaybe voilationSatisfiedRules of + Nothing -> pure Nothing + Just rule -> pure $ Just (loc, rule) + + _ -> pure Nothing + +#else + + L loc1 (HsApp _ (L _ (HsVar _ (L _ var))) _) -> do + let voilationSatisfiedRules = verifyAndGetRuleVoilatedFnInfoWithRightExprAsType var rules + case listToMaybe voilationSatisfiedRules of + Nothing -> pure Nothing + Just rule -> pure $ Just (loc1, rule) + + L _ (OpApp _ (L loc1 (OpApp _ (L _ (HsVar _ (L _ leftVar))) _ (L _ (PatHsWrap _ (HsVar _ (L _ var)))))) _ _) -> do + let voilationSatisfiedRules = verifyAndGetRuleVoilatedFnInfoWithLeftExprAsType var leftVar rules + case listToMaybe voilationSatisfiedRules of + Nothing -> pure Nothing + Just rule -> pure $ Just (loc1, rule) + + L loc1 (OpApp _ (L _ (HsVar _ (L _ leftVar))) _ (L _ (PatHsWrap _ (HsVar _ (L _ var))))) -> do + let voilationSatisfiedRules = verifyAndGetRuleVoilatedFnInfoWithLeftExprAsType var leftVar rules + case listToMaybe voilationSatisfiedRules of + Nothing -> pure Nothing + Just rule -> pure $ Just (loc1, rule) + + L _ (HsApp _ (L loc2 (HsAppType _ (L _ (PatHsWrap (WpCompose (WpCompose (WpCompose (WpEvApp (EvExpr _hasFld)) (WpCompose (WpTyApp _fldType) (WpTyApp tableType))) (WpTyApp (LitTy (StrTyLit fastString)))) (WpTyApp _)) (HsVar _ opr))) _)) _) -> do + let tblName' = case tableType of + AppTy ty1 _ -> showS ty1 + TyConApp ty1 _ -> showS ty1 + ty -> showS ty + filteredRule = filter (\rule -> (type_name rule) == tblName' && fastString == (mkFastString $ blocked_field rule)) rules + case listToMaybe filteredRule of + Nothing -> pure Nothing + Just rule -> pure $ Just (loc2, rule) + + _ -> pure Nothing + +#endif + +showS :: (Outputable a) => a -> String +showS = showSDocUnsafe . ppr + +verifyAndGetRuleVoilatedFnInfoWithLeftExprAsType :: Var -> Var -> [Rule] -> [Rule] +verifyAndGetRuleVoilatedFnInfoWithLeftExprAsType var leftVar rules = do + let name = showS $ varName var + vType = varType leftVar + arrTypeCon = getTypeConFromType vType + updatedName = if "_" `isPrefixOf` name + then fromMaybe name (stripPrefix "_" name) + else name + filter (\rule -> elem (type_name rule) arrTypeCon && updatedName == blocked_field rule) rules + +verifyAndGetRuleVoilatedFnInfoWithExprAndFieldAsName :: String -> Var -> [Rule] -> [Rule] +verifyAndGetRuleVoilatedFnInfoWithExprAndFieldAsName name leftVar rules = do + let vType = varType leftVar + arrTypeCon = getTypeConFromType vType + updatedName = if "_" `isPrefixOf` name + then fromMaybe name (stripPrefix "_" name) + else name + filter (\rule -> elem (type_name rule) arrTypeCon && updatedName == blocked_field rule) rules + +getTypeConFromType :: Type -> [String] +getTypeConFromType vType = + case getTyConInStringFormat vType of + Just value -> value + Nothing -> + case vType of + (TyConApp typ tys) -> + if null tys + then [showS typ] + else + (\var -> do + case tyConAppTyCon_maybe var of + Just tyCon -> showS tyCon + Nothing -> "NA" + ) <$> tys + _ -> [] + +verifyAndGetRuleVoilatedFnInfoWithRightExprAsType :: Var -> [Rule] -> [Rule] +verifyAndGetRuleVoilatedFnInfoWithRightExprAsType var rules = do + let name = showS $ varName var + vType = varType var + arrTypeCon = getTypeConFromType vType + updatedName = if "_" `isPrefixOf` name + then fromMaybe name (stripPrefix "_" name) + else name + filter (\rule -> elem (type_name rule) arrTypeCon && updatedName == blocked_field rule) rules + + where + + getTypeConFromType :: Type -> [String] + getTypeConFromType vType = + case getTyConInStringFormat vType of + Just value -> value + Nothing -> + case vType of + (TyConApp _ tys) -> + (\localVar -> do + case tyConAppTyCon_maybe localVar of + Just tyCon -> showS tyCon + Nothing -> "NA" + ) <$> tys + _ -> [] + +getTyConInStringFormat :: Type -> Maybe [String] +getTyConInStringFormat vType = +#if __GLASGOW_HASKELL__ >= 900 + case splitFunTy_maybe vType of + Just (_, tyCon, _) -> Just [showS tyCon] + Nothing -> Nothing +#else + case splitFunTy_maybe vType of + Just (tyCon, _) -> Just [showS tyCon] + Nothing -> Nothing +#endif + +conLikeWrapId :: ConLike -> Maybe Var +conLikeWrapId (RealDataCon dc) = Just (dataConWrapId dc) +conLikeWrapId _ = Nothing + +#if __GLASGOW_HASKELL__ >= 900 +noExtFieldOrAnn :: EpAnn a +noExtFieldOrAnn = noAnn + +getLoc2 :: GenLocated (SrcSpanAnn' a) e -> SrcSpan +getLoc2 = getLocA + +noExprLoc :: a -> Located a +noExprLoc = noLoc + +getLocated :: GenLocated (SrcSpanAnn' a) e -> Located e +getLocated ap = L (getLocA ap) (unLoc ap) + +getFnNameWithAllArgs :: LHsExpr GhcTc -> Maybe (Located Var, [LHsExpr GhcTc]) +getFnNameWithAllArgs (L _ (HsVar _ v)) = Just (getLocated v, []) +getFnNameWithAllArgs (L _ (HsConLikeOut _ cl)) = (\clId -> (noExprLoc clId, [])) <$> conLikeWrapId cl +getFnNameWithAllArgs (L _ (HsAppType _ expr _)) = getFnNameWithAllArgs expr +getFnNameWithAllArgs (L _ (HsApp _ (L _ (HsVar _ v)) funr)) = Just (getLocated v, [funr]) +getFnNameWithAllArgs (L _ (HsApp _ funl funr)) = do + let res = getFnNameWithAllArgs funl + case res of + Nothing -> Nothing + Just (fnName, ls) -> Just (fnName, ls ++ [funr]) +getFnNameWithAllArgs (L loc (OpApp _ funl op funr)) = do + case op of + (L _ (HsVar _ v)) -> Just (getLocated v, [funl,funr]) + (L _ (PatHsWrap _ (HsVar _ var))) -> Just (getLocated var, [funl,funr]) + _ -> Nothing +getFnNameWithAllArgs (L loc (PatHsWrap _ expr)) = getFnNameWithAllArgs (L loc expr) +getFnNameWithAllArgs (L _ (HsCase _ funl exprLStmt)) = do + let res = getFnNameWithAllArgs funl + case res of + Nothing -> Nothing + Just (fnName, ls) -> do + let exprs = exprLStmt ^? biplateRef :: [LHsExpr GhcTc] + Just (fnName, ls <> exprs) +getFnNameWithAllArgs (L loc ap@(PatHsExpansion orig expanded)) = + case (orig, expanded) of + ((OpApp _ _ op _), (HsApp _ (L _ (HsApp _ op' funl)) funr)) -> case showS op of + "($)" -> getFnNameWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) + _ -> getFnNameWithAllArgs (L loc expanded) + _ -> getFnNameWithAllArgs (L loc expanded) +getFnNameWithAllArgs _ = Nothing + +#else + +noExtFieldOrAnn :: NoExtField +noExtFieldOrAnn = noExtField + +getLoc2 :: HasSrcSpan a => a -> SrcSpan +getLoc2 = getLoc + +noExprLoc :: (HasSrcSpan a) => SrcSpanLess a -> a +noExprLoc = noLoc + +getLocated :: (HasSrcSpan a) => a -> Located (SrcSpanLess a) +getLocated ap = L (getLoc ap) (unLoc ap) + +getFnNameWithAllArgs :: LHsExpr GhcTc -> Maybe (Located Var, [LHsExpr GhcTc]) +getFnNameWithAllArgs (L _ (HsVar _ v)) = Just (v, []) +getFnNameWithAllArgs (L _ (HsConLikeOut _ cl)) = (\clId -> (noLoc clId, [])) <$> conLikeWrapId_maybe cl +getFnNameWithAllArgs (L _ (HsAppType _ expr _)) = getFnNameWithAllArgs expr +getFnNameWithAllArgs (L _ (HsApp _ (L _ (HsVar _ v)) funr)) = Just (v, [funr]) +getFnNameWithAllArgs (L _ (HsApp _ funl funr)) = do + let res = getFnNameWithAllArgs funl + case res of + Nothing -> Nothing + Just (fnName, ls) -> Just (fnName, ls ++ [funr]) +getFnNameWithAllArgs (L loc (OpApp _ funl op funr)) = + case showS op of + "($)" -> getFnNameWithAllArgs $ (L loc (HsApp noExtFieldOrAnn funl funr)) + _ -> Nothing +getFnNameWithAllArgs (L loc (PatHsWrap _ expr)) = getFnNameWithAllArgs (L loc expr) +getFnNameWithAllArgs (L _ (HsCase _ funl exprLStmt)) = do + let res = getFnNameWithAllArgs funl + case res of + Nothing -> Nothing + Just (fnName, ls) -> do + let exprs = exprLStmt ^? biplateRef :: [LHsExpr GhcTc] + Just (fnName, ls <> exprs) +getFnNameWithAllArgs _ = Nothing + +#endif + +getVarNameFromIDP :: IdP GhcTc -> String +getVarNameFromIDP var = occNameString . occName $ var + +getVarName :: Located Var -> String +getVarName var = (getOccString . varName . unLoc) var + +getSrcSpanLine :: Maybe SrcSpan -> Int +getSrcSpanLine = \case +#if __GLASGOW_HASKELL__ >= 900 + (Just (RealSrcSpan span _)) -> srcSpanStartLine span + _ -> 0 +#else + (Just (RealSrcSpan span)) -> srcSpanStartLine span + _ -> 0 +#endif \ No newline at end of file diff --git a/paymentFlow/src/PaymentFlow/Types.hs b/paymentFlow/src/PaymentFlow/Types.hs new file mode 100644 index 0000000..b72f6d0 --- /dev/null +++ b/paymentFlow/src/PaymentFlow/Types.hs @@ -0,0 +1,69 @@ +module PaymentFlow.Types where + +import Data.Aeson + +#if __GLASGOW_HASKELL__ >= 900 +import GHC.Types.SrcLoc +#else +import SrcLoc +#endif + +data PluginOpts = PluginOpts { + failOnFileNotFound :: Bool, + rulesConfigPath :: String + } deriving (Show, Eq) + +defaultPluginOpts :: PluginOpts +defaultPluginOpts = + PluginOpts { + failOnFileNotFound = True, + rulesConfigPath = ".juspay/paymentFlowRules.yaml" + } + +instance FromJSON PluginOpts where + parseJSON = withObject "PluginOpts" $ \o -> do + failOnFileNotFound <- o .:? "failOnFileNotFound" .!= (failOnFileNotFound defaultPluginOpts) + rulesConfigPath <- o .:? "rulesConfigPath" .!= (rulesConfigPath defaultPluginOpts) + return PluginOpts {rulesConfigPath = rulesConfigPath, failOnFileNotFound = failOnFileNotFound } + +type Suggestion = String + +data Rule = + Rule + { type_name :: String + , field_access_whitelisted_fns :: [String] + , blocked_field :: String + , field_rule_fixes :: Suggestion + , whitelisted_line_nos :: [Int] + } deriving (Show, Eq) + +instance FromJSON Rule where + parseJSON = withObject "Rule" $ \o -> do + type_name <- o .: "type_name" + field_access_whitelisted_fns <- o .: "field_access_whitelisted_fns" + blocked_field <- o .: "blocked_field" + field_rule_fixes <- o .: "field_rule_fixes" + whitelisted_line_nos <- o .: "whitelisted_line_nos" + return Rule + { type_name = type_name + , field_access_whitelisted_fns = field_access_whitelisted_fns + , blocked_field = blocked_field + , field_rule_fixes = field_rule_fixes + , whitelisted_line_nos = whitelisted_line_nos + } + +data PFRules = PFRules + { rules :: [Rule] + } deriving (Show, Eq) + +instance FromJSON PFRules where + parseJSON = withObject "PFRules" $ \o -> do + rules <- o .: "rules" + return PFRules { rules = rules } + +data VoilationRuleResult = VoilationRuleResult + { fnName :: String + , srcSpan :: SrcSpan + , rule :: Rule + , coreFnName :: String + } deriving (Show, Eq) \ No newline at end of file diff --git a/paymentFlow/test/Main.hs b/paymentFlow/test/Main.hs new file mode 100644 index 0000000..5414b2f --- /dev/null +++ b/paymentFlow/test/Main.hs @@ -0,0 +1,54 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE OverloadedStrings #-} +{-# OPTIONS_GHC -ddump-tc-ast #-} +{-# OPTIONS_GHC -fplugin=RecordDotPreprocessor #-} +{-# LANGUAGE TemplateHaskell #-} + +module Main (main) where + +import Data.Text as T +import Data.Maybe (fromMaybe) +import Control.Applicative ((<|>)) +import Prelude +import Data.Aeson as A +import Types as PT +import Types1 as PT1 +import Control.Lens + +main :: IO () +main = putStrLn "Test suite not yet implemented." + +decidePayStartPathbySurchargeAmt :: PT.TxnDetail -> Text -> Text -> PT.MerchantAccount -> Text +decidePayStartPathbySurchargeAmt txn defaultStartPayPath payStartPath mAcc = do + -- let surchargeConfigStatusAndValue = getMerchantConfigStatusAndvalueForPaymentFlow (mAcc ^. PT.showSurchargeBreakupScreen) + let surchargeConfigStatusAndValue = getMerchantConfigStatus + -- getMerchantConfigStatusAndvalueForPaymentFlow (getMerchantPIdFromMerchantAccount mAcc) (fromMaybe "" (merchantId mAcc)) (Skip mMCLookupConfig) + shouldShowSurchargePage = case surchargeConfigStatusAndValue of + (PT.PaymentFlowNotEligible, _) -> + -- (mAcc.shouldAddSurcharge ) && (mAcc.showSurchargeBreakupScreen) + -- mAcc.shouldAddSurcharge && mAcc.showSurchargeBreakupScreen + -- mAcc.shouldAddSurcharge && mAcc ^. PT.showSurchargeBreakupScreen + mAcc ^. PT.showSurchargeBreakupScreen && mAcc.shouldAddSurcharge + -- (PT.shouldAddSurcharge mAcc) && (PT.showSurchargeBreakupScreen mAcc) + (PT.Disabled, _) -> False + (PT.Enabled, surchargeConfigV) -> + (fromMaybe False $ (surchargeConfigV >>= (\sc -> sc.showSurchargeBreakupScreen)) <|> (Just $ mAcc ^. PT.showSurchargeBreakupScreen)) + -- (fromMaybe False $ (surchargeConfigV >>= (\sc -> PT1.showSurchargeBreakupScreen sc)) <|> (Just (PT.showSurchargeBreakupScreen mAcc))) + if shouldShowSurchargePage + then payStartPath + else defaultStartPayPath + + where + + getMerchantConfigStatus :: (PT.MerchantConfigStatus, Maybe PT1.SurchargeConfig) + getMerchantConfigStatus = + -- getMerchantConfigStatusAndvalueForPaymentFlow $ PT.showSurchargeBreakupScreen mAcc + -- getMerchantConfigStatusAndvalueForPaymentFlow (PT.showSurchargeBreakupScreen mAcc) + getMerchantConfigStatusAndvalueForPaymentFlow (mAcc ^. PT.showSurchargeBreakupScreen) + +getMerchantConfigStatusAndvalueForPaymentFlow ::Bool -> (PT.MerchantConfigStatus, Maybe PT1.SurchargeConfig) +getMerchantConfigStatusAndvalueForPaymentFlow _ = (PT.Enabled, Just $ PT1.SurchargeConfig {shouldAddSurchargeToRefund = False, showSurchargeBreakupScreen = Just True}) \ No newline at end of file diff --git a/paymentFlow/test/Types.hs b/paymentFlow/test/Types.hs new file mode 100644 index 0000000..c65baaf --- /dev/null +++ b/paymentFlow/test/Types.hs @@ -0,0 +1,24 @@ +{-# OPTIONS_GHC -fplugin=RecordDotPreprocessor #-} +{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses, TypeFamilies, UndecidableInstances #-} +{-# LANGUAGE TemplateHaskell #-} + +module Types where + +import Data.Aeson +import Data.Text +import Control.Lens + +data TxnDetail = TxnDetail + +data MerchantAccount = MerchantAccount { + merchantId :: Maybe Text, + shouldAddSurcharge :: Bool, + -- showSurchargeBreakupScreen :: Bool + _showSurchargeBreakupScreen :: Bool +} + +data AK = Skip Bool | Force + +data MerchantConfigStatus = PaymentFlowNotEligible | Disabled | Enabled + +makeLenses ''MerchantAccount \ No newline at end of file diff --git a/paymentFlow/test/Types1.hs b/paymentFlow/test/Types1.hs new file mode 100644 index 0000000..e6576c1 --- /dev/null +++ b/paymentFlow/test/Types1.hs @@ -0,0 +1,16 @@ +-- {-# LANGUAGE FlexibleInstances #-} +-- {-# LANGUAGE TypeFamilies #-} +-- {-# LANGUAGE MultiParamTypeClasses #-} +-- {-# LANGUAGE DataKinds #-} +{-# OPTIONS_GHC -fplugin=RecordDotPreprocessor #-} +{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses, TypeFamilies, UndecidableInstances #-} +module Types1 where + +import Data.Aeson +import Control.Lens + +data SurchargeConfig = SurchargeConfig + {shouldAddSurchargeToRefund :: Bool, showSurchargeBreakupScreen :: Maybe Bool} + deriving (Show, Eq) + + diff --git a/sheriff/.juspay/indexedKeys.yaml b/sheriff/.juspay/indexedKeys.yaml index 1f09e43..6a9b3b8 100644 --- a/sheriff/.juspay/indexedKeys.yaml +++ b/sheriff/.juspay/indexedKeys.yaml @@ -10,6 +10,9 @@ tables: - partitionKey - name: Token + ignoredModules: [] + ignoredFunctions: [] + checkModules: [] indexedKeys: - provisionedTokenId - tokenReferenceId diff --git a/sheriff/.juspay/sheriffExceptionRules.yaml b/sheriff/.juspay/sheriffExceptionRules.yaml index dde3ba6..aa36e6d 100644 --- a/sheriff/.juspay/sheriffExceptionRules.yaml +++ b/sheriff/.juspay/sheriffExceptionRules.yaml @@ -1,29 +1,12 @@ -# rules: -# - ruleName: "GeneralRuleTest" -# conditions: -# - fnName: "pack" -# isQualified: false -# argNo: 0 -# action: "Blocked" -# argTypes: [] -# argFns: [] -# suggestedFixes: ["Remove `pack` function call from the error location."] -# ruleInfo: -# fnName: "show" -# isQualified: false -# argNo: 1 -# action: "Blocked" -# argTypes: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] -# argFns: [] -# suggestedFixes: ["Remove `show` function call from the error location. If quotes are required, manually add them to the text.","You might want to use a convertor function like `Data.Text.pack`, `Data.Text.unpack`, `decodeUtf8`, `encodeUtf8`, etc."] - -# - fn_rule_name: "ShowRule" -# fn_name: "show" -# arg_no: 1 -# fns_blocked_in_arg: [] -# types_blocked_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] -# types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] -# fn_rule_fixes: ["Remove `show` function call from the error location. If quotes are required, manually add them to the text.","You might want to use a convertor function like `Data.Text.pack`, `Data.Text.unpack`, `decodeUtf8`, `encodeUtf8`, etc."] -# fn_rule_exceptions: [] - -rules: [] \ No newline at end of file +rules: + - fn_rule_name: "ShowRule" + fn_name: "show" + arg_no: 1 + fns_blocked_in_arg: [] + types_blocked_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] + types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] + fn_rule_fixes: ["Remove `show` function call from the error location. If quotes are required, manually add them to the text.","You might want to use a convertor function like `Data.Text.pack`, `Data.Text.unpack`, `decodeUtf8`, `encodeUtf8`, etc."] + fn_rule_exceptions: [] + fn_rule_ignore_modules: # Adding as ignore module, so that this exception is not applied on Below modules + - Test1 + - SubTests.ShowTest \ No newline at end of file diff --git a/sheriff/.juspay/sheriffRules.yaml b/sheriff/.juspay/sheriffRules.yaml index 4c26a4d..e1ebeb6 100644 --- a/sheriff/.juspay/sheriffRules.yaml +++ b/sheriff/.juspay/sheriffRules.yaml @@ -1,21 +1,31 @@ rules: - - ruleName: "GeneralRuleTest" + - ruleName: GeneralRuleTest conditions: - - fnName: "pack" + - fnName: pack isQualified: false argNo: 0 - action: "Blocked" + action: Blocked argTypes: [] argFns: [] - suggestedFixes: ["Remove `pack` function call from the error location."] + suggestedFixes: + - Remove `pack` function call from the error location. ruleInfo: - fnName: "show" + fnName: show isQualified: false argNo: 1 - action: "Blocked" - argTypes: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] + action: Blocked + argTypes: + - Text + - String + - Char + - "[Char]" + - Maybe + - "(,)" + - "[]" argFns: [] - suggestedFixes: ["Remove `show` function call from the error location. If quotes are required, manually add them to the text.","You might want to use a convertor function like `Data.Text.pack`, `Data.Text.unpack`, `decodeUtf8`, `encodeUtf8`, etc."] + suggestedFixes: + - Remove `show` function call from the error location. If quotes are required, manually add them to the text. + - You might want to use a convertor function like `Data.Text.pack`, `Data.Text.unpack`, `decodeUtf8`, `encodeUtf8`, etc. # - fn_rule_name: "ShowRule" # fn_name: "show" @@ -27,172 +37,152 @@ rules: # fn_rule_exceptions: [] # fn_rule_ignore_modules: [] - # - db_rule_name: "DBRuleTest" - # table_name: "TxnRiskCheck" - # indexed_cols_names: - # - partitionKey - # - and: - # - txnId - # - customerId - # db_rule_fixes: ["You might want to include an indexed column in the `where` clause of the query."] - # db_rule_exceptions: [] - - fn_rule_name: "LogRule" - fn_name: "logErrorT" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] + fn_name: + - "logErrorT" + - "logErrorV" + - "logError" + - "logDebugT" + - "logDebugV" + - "logDebug" + - "logInfoT" + - "logInfoV" + - "logInfo" + - "logErrorWithCategoryT" + - "logErrorWithCategoryV" + - "logErrorWithCategory" + - "logWarningT" + - "logWarningV" + - "logWarning" + - "forkErrorLog" + - "forkInfoLog" + - "debugLog" + - "warnLog" + - "logDecryptedResponse" + - "logDecryptedRequest" + arg_no: 2 + fns_blocked_in_arg: + - ["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]] + - ["encode", 1, []] + - ["encodeJSON", 1, []] + types_blocked_in_arg: [] + types_to_check_in_arg: + - "Text" + - "String" + - "Char" + - "[Char]" + - "Maybe" + - "(,)" + - "[]" + fn_rule_fixes: + - "Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module." + - "Make sure that there is `ToJSON` instance on the value we are logging." + - "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)" fn_rule_exceptions: [] fn_rule_ignore_modules: [] - - fn_rule_name: "LogRule" - fn_name: "logErrorV" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] - fn_rule_exceptions: [] - fn_rule_ignore_modules: [] + - fn_rule_name: "ART KVDB Rule" + fn_name: "runKVDB" + arg_no: 0 + fns_blocked_in_arg: [] + types_blocked_in_arg: [] + types_to_check_in_arg: [] + fn_rule_fixes: + - "You might want to use some other wrapper function from `EulerHS.Extra.Redis` module." + - "For e.g. - rExists, rDel, rGet, rExpire, etc." + fn_rule_exceptions: [] + fn_rule_ignore_modules: + - "EulerHS.Extra.Redis" + - "EulerHS.Framework.Interpreter" + - "EulerHS.Framework.Language" + - "EulerHS.KVDB.Interpreter" + - "KVDB.KVDBSpec" + + - fn_rule_name: Test Qualified Function Name Rule + fn_name: TestUtils.throwException + arg_no: 0 + fns_blocked_in_arg: [] + types_blocked_in_arg: [] + types_to_check_in_arg: [] + fn_rule_fixes: + - You are not allowed to use helper function `throwException` from `TestUtils` module. + - Use `throwExceptionV2` or `throwExceptionV4` function from `TestUtils` module. + fn_rule_exceptions: [] + fn_rule_ignore_modules: + - TestUtils - - fn_rule_name: "LogRule" - fn_name: "logError" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] - fn_rule_exceptions: [] - fn_rule_ignore_modules: [] - - - fn_rule_name: "LogRule" - fn_name: "logInfoT" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] - fn_rule_exceptions: [] - fn_rule_ignore_modules: [] - - - fn_rule_name: "LogRule" - fn_name: "logInfoV" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] - fn_rule_exceptions: [] - fn_rule_ignore_modules: [] + - fn_rule_name: Test Multiple Function Names in Single Rule + fn_name: + - TestUtils.throwExceptionV2 + - TestUtils.throwExceptionV4 + arg_no: 0 + fns_blocked_in_arg: [] + types_blocked_in_arg: [] + types_to_check_in_arg: [] + fn_rule_fixes: + - You are not allowed to use helper function `throwException` from `TestUtils` module. + - Use `throwExceptionV2` or `throwExceptionV4` function from `TestUtils` module. + fn_rule_exceptions: [] + fn_rule_ignore_modules: + - TestUtils + - Exceptions - - fn_rule_name: "LogRule" - fn_name: "logInfo" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] + # We should get the errors for below functions two times only in Test1 + - fn_rule_name: Test Allowed Modules + fn_name: + - throwExceptionV2 + - TestUtils.throwExceptionV4 + arg_no: 0 + fns_blocked_in_arg: [] + types_blocked_in_arg: [] + types_to_check_in_arg: [] + fn_rule_fixes: + - Check Allowed Modules + - You are not allowed to use helper function `throwException` from `TestUtils` module. + - Use `throwExceptionV2` or `throwExceptionV4` function from `TestUtils` module. fn_rule_exceptions: [] fn_rule_ignore_modules: [] - - - fn_rule_name: "LogRule" - fn_name: "logDebugT" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] + fn_rule_check_modules: + - SubTests.FunctionUseTest + + # We should get error for usage of any function with Number type + - fn_rule_name: Test Functions usage blocked having any function name with given signature + fn_name: + - "*" + arg_no: 0 + fn_sigs_blocked: + - Number -> TestUtils1.Number -> * + - Number -> Number -> Number + - TestUtils.Number -> TestUtils.Number -> TestUtils.Number + - Maybe (Either (Maybe (Int)) (Maybe (Int))) -> Number -> Number -> Number + fns_blocked_in_arg: [] + types_blocked_in_arg: [] + types_to_check_in_arg: [] + fn_rule_fixes: + - Contact senior dev for the solution. fn_rule_exceptions: [] fn_rule_ignore_modules: [] + fn_rule_check_modules: + - Test1 + - SubTests.FunctionUseTest - - fn_rule_name: "LogRule" - fn_name: "logDebugV" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] - fn_rule_exceptions: [] - fn_rule_ignore_modules: [] + # We should get error for any infinite recursion cases + - infinite_recursion_rule_name: Infinite Recursions + infinite_recursion_rule_fixes: + - "Remove the infinite recursion." + - "Add a base case check." + - "Pass the modified value to function arguments." + infinite_recursion_rule_ignore_functions: + - Exceptions.pattern6 - - fn_rule_name: "LogRule" - fn_name: "logDebug" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] - fn_rule_exceptions: [] - fn_rule_ignore_modules: [] - - - fn_rule_name: "LogRule" - fn_name: "logErrorWithCategoryT" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] - fn_rule_exceptions: [] - fn_rule_ignore_modules: [] - - - fn_rule_name: "LogRule" - fn_name: "logErrorWithCategoryV" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] - fn_rule_exceptions: [] - fn_rule_ignore_modules: [] - - - fn_rule_name: "LogRule" - fn_name: "logErrorWithCategory" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] - fn_rule_exceptions: [] - fn_rule_ignore_modules: [] - - - fn_rule_name: "LogRule" - fn_name: "forkErrorLog" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] - fn_rule_exceptions: [] - fn_rule_ignore_modules: [] - - - fn_rule_name: "LogRule" - fn_name: "forkInfoLog" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] - fn_rule_exceptions: [] - fn_rule_ignore_modules: [] - - - fn_rule_name: "LogRule" - fn_name: "debugLog" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] - fn_rule_exceptions: [] - fn_rule_ignore_modules: [] - - - fn_rule_name: "LogRule" - fn_name: "warnLog" - arg_no: 2 - fns_blocked_in_arg: [["show", 1, ["EnumTypes", "Integer", "Double", "Float", "Int64", "Int", "Bool", "Number", "(,)", "[]", "Maybe"]], ["encode", 1, []], ["encodeJSON", 1, []]] - types_blocked_in_arg: [] - types_to_check_in_arg: ["Text", "String", "Char", "[Char]", "Maybe", "(,)", "[]"] - fn_rule_fixes: ["Remove `show` function call from the error location and use `L.logErrorV @Text` or `L.logDebugV @Text` or `L.logInfoV @Text` function(s) imported from `EulerHS.Language` module.", "Make sure that there is `ToJSON` instance on the value we are logging.", "You may use tuples for combining string and objects. For e.g., (\"Failed to fetch object: \" :: Text, obj)"] - fn_rule_exceptions: [] - fn_rule_ignore_modules: [] \ No newline at end of file + - db_rule_name: "DBRuleTest" + table_name: "Table1" + indexed_cols_names: + - partitionKey + - id + - and: + - customerPhone + - customerEmail + db_rule_fixes: + - "You might want to include an indexed column in the `where` clause of the query." + db_rule_exceptions: [] \ No newline at end of file diff --git a/sheriff/CHANGELOG.md b/sheriff/CHANGELOG.md index 43bd582..4906f7c 100644 --- a/sheriff/CHANGELOG.md +++ b/sheriff/CHANGELOG.md @@ -1,5 +1,77 @@ -# Revision history for code-checker +# Revision history for sheriff -## 0.1.0.0 -- 2024-03-20 +## 0.2.1.8 +* Add module level & function level exceptions to DB rules + +## 0.2.1.7 +* Fix DB rules not working in GHC 9 + +## 0.2.1.6 +* Fix DB rules not getting applied +* Add information about exceptions to be added in infinite recursion rule + +## 0.2.1.5 +* Resolve names to top most name to enable unique matching for name-shadowing cases of function name +* Add test cases for name shadowing of function name + +## 0.2.1.4 +* Fix function level exception not working for Function Rule +* Fix infinite recursion being detected for top level function being called in `where` clause +* Fix infinite recursion being detected for partial function being called inside `fmap` or any other higher order function +* Handle infinite recursion of `LambdaCase` & `Lambda` functions for partial functions +* Add/modify test cases to check above cases +* Refactor partial function infinite recursion detection code for better reusability + +## 0.2.1.3 +* Fix type signature check case by avoiding constraint matching (edge case for functions with constraint cases, only in GHC 9.2.8) + +## 0.2.1.2 +* Add type signature check for self recursive function calls (edge cases for instance calls) + +## 0.2.1.1 +* Add module name check for self recursive function calls +* Add top level module name resolving for instance functions +* Add test cases for calling function with same name but from different module +* Add State based flow implementation for infinite recursion detection + +## 0.2.1.0 +* Add infinite recursion rule +* Add test suite for infinite recursion rule detection +* Add ignore functions in function rules +* Refactor to use implicit params for passing plugin opts +* Refactor CommonTypes, Types, Utils and TypeUtils +* Update documentation for infinite recursion rule and supported patterns +## 0.2.0.6 +* Refactor test cases in cabal file to use smaller test suites +* Revert functionality to provide sheriff plugin opt in single module +* Refactor plugin to split types modules + +## 0.2.0.5 +* Refactor test cases to different sub test cases instead of single test file +* Add functionality to provide sheriff plugin opt in single module +* Refactor & Re-enable log rules check as a subtest +* Add -fkeep-going to keep testing individual sub-test despite compilation error + +## 0.2.0.4 +* Fix modules names not getting matched due to wildcard support logical bug + +## 0.2.0.3 +* Fix non-exhaustive pattern +* Fix wildcard character matching for asterisk operator +* Customizable asterisk matching based on use case + +## 0.2.0.2 +* Remove support for wilcard character "*" in function rules (to be fixed) + +## 0.2.0.1 +* Add derived signature from the arg types for Signature Check rules + +## 0.2.0.0 +* Add allowed modules list for function rule +* Add support for wilcard character "*" in function rules +* Add support for blocking functions with a particular type signature (Signature only) +* Add test rules in sheriff rules + +## 0.1.0.0 -- 2024-03-20 * First version. Basic rules based compilation error. diff --git a/sheriff/DOC.md b/sheriff/DOC.md index cf65691..8877eec 100644 --- a/sheriff/DOC.md +++ b/sheriff/DOC.md @@ -1,11 +1,95 @@ -### Sheriff plugin -#### What it does? -`sheriff` is a compiler plugin to add business logic compilation checks. It checks for the following logics : +# Sheriff plugin +## What it does? +`sheriff` is a compiler plugin to add general code related compilation checks. It can check for the following logics : 1. ***Stringification of text/string*** - It checks if we have applied show function to any string literal or variable. We want this check to make sure that we are not having any unnecessary quotes in the code. 2. ***Logging stringified haskell objects*** - It checks if we have applied show function to any variable inside any logging function. We want this check to make sure that our application level log filtering works fine and is not impacted by stringified values. 3. ***NonIndexed DB queries check*** - It checks if we have any DB query which won't use any index. Indexed columns of a particular table are provided in *.juspay/indexedKeys.yaml* file. It works on all operations - select, update, delete. We want this check to make sure that our application doesn't make any slow sql queries. +4. ***Detecting use of non-allowed functions*** - It checks if we have used some function which should not be used anywhere in the code. The function can be specified by name or signature. +5. ***Detecting infinite recursions in code*** - It checks whether there is any recursive call for which we can say that it is infinite recursion. For more details, check [Infinite Recursion Detected Patterns](./InfiniteRecursionPatterns.md) -#### How to resolve compilation errors? +## How to write rules? +Any additional rules for a package can be provided as `yaml` file. The path to this rules file can be given as plugin option as follows:
+```cabal +-fplugin-opt=Sheriff.Plugin:{"rulesConfigPath":".juspay/sheriffRules.yaml","exceptionsConfigPath":".juspay/sheriffExceptionRules.yaml"} +``` + +- Both rules and exception rules follows same structure & format. +- Rules provided in exceptions file are global exceptions for all rules. It means if the code satisfies both a rule and any exceptionRule, then error won't be thrown. + +> Structure of Function Rules: +>```yaml +> - fn_rule_name: "" +> fn_name: "" +> arg_no: 1 +> fns_blocked_in_arg: +> - [ModuleA.dummyFn, 0, []] +> - [dummyFn2, 0, []] +> types_blocked_in_arg: +> - String +> - Text +> - EnumTypes +> - Int64 +> - Person +> types_to_check_in_arg: +> - String +> - Text +> - EnumTypes +> - Int64 +> - Person +> fn_rule_fixes: +> - Sample Suggested Fix 1" +> - Sample Suggested Fix 2" +> fn_rule_exceptions: +> - fn_rule_name: "" +> fn_name: dummy +> arg_no: 0 +> fns_blocked_in_arg: [] +> types_blocked_in_arg: [] +> types_to_check_in_arg: [] +> fn_rule_fixes: [] +> fn_rule_exceptions: [] +> fn_rule_ignore_modules: +> - ModuleK +> fn_rule_ignore_modules: +> - ModuleT +> - ModuleP +> fn_rule_check_modules: +> - ModuleA +> - ModuleB +>``` + +> Structure of DB Rules: +>```yaml +> - db_rule_name: "DBRuleTest" +> table_name: "TxnRiskCheck" +> indexed_cols_names: +> - partitionKey +> - and: +> - txnId +> - customerId +> db_rule_fixes: +> - You might want to include an indexed column in the `where` clause of the query. +> db_rule_exceptions: [] +>``` + +> Structure of Infinite Recursion Rule: +>```yaml +> - infinite_recursion_rule_name: "Infinite Recursion Rule" +> infinite_recursion_rule_fixes: +> - Fix1 +> - Fix 2 +> infinite_recursion_rule_exceptions: [] +> infinite_recursion_rule_ignore_modules: +> - ModuleA +> infinite_recursion_rule_check_modules: +> - "*" +> infinite_recursion_rule_ignore_functions: +> - ModuleB.fn2 +>``` + +Refer Sample Rules and exception rules in [.juspay](.juspay/sheriffRules.yaml) directory. + +## How to resolve compilation errors? > "show" on Text is not allowed - Remove `show` function call from the error location. If quotes are required, manually add them to the text. - If conversions are required, we can use `encodeUtf8`, `pack`, `unpack`, etc. @@ -44,3 +128,6 @@ > Querying on non-indexed column 'ColumnName' of table 'TableName' is not allowed - Make sure that the `where` clause in the DB query use at least one index, be it composite index or a key + +> Infinite Recursion Detected + - Remove infinite recursion by adding base case or changing the function logic. If it is genuine case that must stay in code like server loop, interval loop, etc., then add the function as ignore function. \ No newline at end of file diff --git a/sheriff/InfiniteRecursionPatterns.md b/sheriff/InfiniteRecursionPatterns.md new file mode 100644 index 0000000..721cc03 --- /dev/null +++ b/sheriff/InfiniteRecursionPatterns.md @@ -0,0 +1,115 @@ +# Detectable Infinite Recursion Patterns + +> For exact examples for what is covered and what is not, refer [InfiniteRecursionTest.hs](./test/SubTests/InfiniteRecursionTest.hs) + +## Pattern 1 : Self call to variable / function without argument +For e.g. - +```haskell +let x = x <> "Dummy" in x +``` + +## Pattern 2 : Call to complete function (non partial function) with same arguments +For e.g. - +```haskell +fn :: Int -> Int +fn val = fn val +``` +```haskell +fn :: String -> String +fn a = + let z = fn a + in z +``` +However, this won't throw error +```haskell +fn :: String -> String +fn a = + let a = "Changed value" + in fn a +``` + +## Pattern 3 : Self recursive call in instance function +for e.g. - +```haskell +toJSON a = A.toJSON a +``` + +## Pattern 4 : Infinite recursion call in where clause +for e.g. - +```haskell +fn :: String -> String -> String +fn x y = fn2 x y + where + fn2 a b = fn3 $ fn2 a b +``` + +## Pattern 5 : Infinite recursion call with specific pattern match +for e.g. - +```haskell +fn :: Int -> Int +fn 10 = fn (10 :: Int) +fn _ = -1 +``` + +## Pattern 6 : Infinite recursion call on partial function +for e.g. - +```haskell +fn :: Int -> Int +fn = fn +``` + +## Pattern 7 : Infinite recursion call on partial function called inside function composition +for e.g. - +```haskell +fn :: Int -> Int +fn = fn2 . fn +``` + +## Pattern 8 : Infinite recursion call on partial function in the last stmt +for e.g. - +```haskell +fn :: Int -> Int +fn = let z = "Dummy" in fn +``` + +## Pattern 9 : Infinite recursion in instance method based on instance type being used +for e.g. - +```haskell +class TypeChanger a b where + changeType :: a -> b + +data SumType = TypeA Int | TypeB | RecType SumType + +instance TypeChanger String SumType where + changeType x = RecType $ changeType x -- Infinite recursion since types are same (TypeChanger String SumType) + +instance TypeChanger Integer SumType where + changeType = TypeA . changeType -- NOT infinite recursion since type is changed (TypeChanger Integer Int) +``` + +## Pattern 10 : Infinite recursion call on partial function but using lambda case (in normal function and instance methods, as first statement or let-in statement or in function composition) +for e.g. - +```haskell +fn :: String -> String +fn = \case + "Pattern" -> fn "Pattern" -- Infinite recursion due to same pattern + b -> fn b -- Infinite recursion due to same variable +``` + +## Pattern 11 : Infinite recursion call on partial function but using lambda functions (in normal function and instance methods, as first statement or let-in statement or in function composition) +for e.g. - +```haskell +fn :: String -> String +fn = \x -> fn x -- Infinite recursion due to same variable +``` + +Following cases are detected as infinite recursion, but these might not be infinite and it can depend on how these are written: +1. Functions modifying state and preventing infinite recursion on the basis of that +2. Functions using mutable variables like IORef, MVar, etc. +3. Functions intended to be used as infinite recursion, such as loopers with threadDelay, server listeners, etc. +4. Functions generating infinite list but used lazily in context of some limiting function like zip, take, etc. +5. Some functions which are infinite only hypothetically, such as function having control flow based on some random number generated + +For such cases, it is advisable to add them to `infinite_recursion_rule_ignore_functions` cases and maintain a list of such functions for visibility. +It is always advised to add ignore functions as fully qualified names. +In case of any other false positive/anamoly, let us know by raising an issue. \ No newline at end of file diff --git a/sheriff/README.MD b/sheriff/README.MD index 21fa4b5..e4e2597 100644 --- a/sheriff/README.MD +++ b/sheriff/README.MD @@ -7,8 +7,12 @@ This Haskell plugin automatically checks function calls for given rule violation It supports the following rules as of now 1. Blocking certain type of argument to a function call :- Give a particular argument number `arg_no` and list of types to be blocked in that argument `types_blocked_in_arg` 2. Blocking use of certain function in an argument to a function call :- Give a particular argument number `arg_no` and list of functions to be blocked in argument.
_Note: The function presence will be checked and blocked only if the type of the argument is in list of `types_to_check_in_arg`_ -3. Blocking use of a particular function in the code (specify argument number `arg_no` as `0` in the rule) -4. Blocking querying in DB on non-indexed columns (indexed columns are provided in a yaml file) +3. Blocking use of a particular function in the code (specify argument number `arg_no` as `0` in the rule). +4. Blocking all or particular functions with given type signature +5. Applying rules to all or particular set of modules +6. Blocking querying in DB on non-indexed columns (indexed columns are provided in a yaml file) +7. Detecting deterministic infinite recursion in the code +8. Blocking use of functions with given signature This tool is useful for developers to enforce better coding practices and prevent use of some specific unsafe function in the code. @@ -21,6 +25,7 @@ Add this to your ghc-options in cabal and mention `sheriff` in build-depends ``` Also, we can provide flags to the plugin in as follows: ``` --fplugin-opt=Sheriff.Plugin:{"throwCompilationError":true,"saveToFile":true,"savePath":".juspay/tmp/sheriff/","indexedKeysPath":".juspay/indexedKeys.yaml","failOnFileNotFound":true,"matchAllInsideAnd":false} +-fplugin-opt=Sheriff.Plugin:{"throwCompilationError":true,"saveToFile":true,"savePath":".juspay/tmp/sheriff/","indexedKeysPath":".juspay/indexedKeys.yaml", +"rulesConfigPath":".juspay/sheriffRules.yaml","exceptionsConfigPath":".juspay/sheriffExceptionRules.yaml","failOnFileNotFound":true,"matchAllInsideAnd":false} ``` -By default, it throwsCompilationErrors and doesn't log to file. Also, it fails, if indexedKeys file is not found or is invalid \ No newline at end of file +By default, it throwsCompilationErrors and doesn't log to file. Also, it fails, if indexedKeys file is not found or is invalid. \ No newline at end of file diff --git a/sheriff/sheriff.cabal b/sheriff/sheriff.cabal index fbedc17..455fb05 100644 --- a/sheriff/sheriff.cabal +++ b/sheriff/sheriff.cabal @@ -1,7 +1,7 @@ cabal-version: 3.0 name: sheriff -version: 0.1.0.0 -synopsis: A plugin to throw compilation errors based on given rules +version: 0.2.1.8 +synopsis: A checker plugin to throw compilation errors based on given rules; basically what a `Sheriff` does license: MIT license-file: LICENSE author: piyushgarg-juspay @@ -10,21 +10,21 @@ category: Development build-type: Simple extra-doc-files: CHANGELOG.md -Flag Dev +Flag SheriffDev Description: Use ghc options to dump ASTs in dev mode Default: False Manual: True common common-options - build-depends: base ^>=4.14.3.0 - ghc-options: -Wall + build-depends: base + ghc-options: -Werror -Wincomplete-uni-patterns -Wincomplete-record-updates + -Wincomplete-patterns -Wcompat -Widentities -Wredundant-constraints -fhide-source-paths - default-language: Haskell2010 default-extensions: DeriveGeneric GeneralizedNewtypeDeriving @@ -35,60 +35,119 @@ common common-options ScopedTypeVariables StandaloneDeriving TypeApplications - + CPP library import: common-options exposed-modules: Sheriff.Plugin other-modules: - Sheriff.Types + Sheriff.CommonTypes + Sheriff.Patterns Sheriff.Rules + Sheriff.Types + Sheriff.TypesUtils + Sheriff.Utils build-depends: - bytestring - , containers - , filepath - , ghc ^>= 8.10.7 - , ghc-exactprint - , unordered-containers - , uniplate >= 1.6 && < 1.7 - , references - , classyplate - , aeson - , directory - , extra - , yaml - , text - , aeson-pretty + bytestring + , containers + , filepath + , ghc + , ghc-exactprint + , hashable + , unordered-containers + , uniplate + , references + , classyplate + , aeson + , directory + , mtl + , extra + , yaml + , text + , aeson-pretty hs-source-dirs: src default-language: Haskell2010 + if flag(SheriffDev) + cpp-options: -DSheriffDev -test-suite sheriff-test - import: common-options - +common test-common-options + import: common-options default-language: Haskell2010 - type: exitcode-stdio-1.0 - hs-source-dirs: test - - main-is: Main.hs other-modules: - Test1 - + TestUtils + Exceptions build-depends: - , sheriff + sheriff , aeson , text , containers , bytestring , aeson-pretty , extra - - if flag(Dev) + if flag(SheriffDev) ghc-options: - -- Plugin options order: {"throwCompilationError":true,"saveToFile":true,"savePath":".juspay/tmp/sheriff/","indexedKeysPath":".juspay/tmp"} + -- Plugin Options available (default values) + -- {"saveToFile":false,"throwCompilationError":true,"failOnFileNotFound":true,"matchAllInsideAnd":false",savePath":".juspay/tmp/sheriff/","indexedKeysPath":".juspay/indexedKeys.yaml","rulesConfigPath":".juspay/sheriffRules.yaml","exceptionsConfigPath":".juspay/sheriffExceptionRules.yaml","logDebugInfo":false,"logWarnInfo":true,"logTypeDebugging":false,"useIOForSourceCode":false} + -fkeep-going -fplugin=Sheriff.Plugin -fplugin-opt=Sheriff.Plugin:{"throwCompilationError":true,"saveToFile":true,"savePath":".juspay/tmp/sheriff/","indexedKeysPath":".juspay/indexedKeys.yaml","failOnFileNotFound":true,"matchAllInsideAnd":true,"logDebugInfo":false,"logTypeDebugging":false,"useIOForSourceCode":true} -dumpdir=.juspay/tmp/sheriff/ -ddump-to-file -ddump-parsed-ast -ddump-tc-ast else ghc-options: + -fkeep-going -fplugin=Sheriff.Plugin + +test-suite sheriff-test + import: test-common-options + default-language: Haskell2010 + type: exitcode-stdio-1.0 + hs-source-dirs: test + main-is: Main.hs + other-modules: + SubTests.FunctionUseTest + SubTests.InfiniteRecursionTest + SubTests.LogTest + SubTests.ShowTest + Test1 + +test-suite sheriff-show-test + import: test-common-options + default-language: Haskell2010 + type: exitcode-stdio-1.0 + hs-source-dirs: test + main-is: Main.hs + other-modules: SubTests.ShowTest + +test-suite sheriff-log-test + import: test-common-options + default-language: Haskell2010 + type: exitcode-stdio-1.0 + hs-source-dirs: test + main-is: Main.hs + other-modules: SubTests.LogTest + +-- TODO: Add Sequelize and Beam and DB Rule Test Suite +-- test-suite sheriff-db-rule-test +-- import: test-common-options +-- default-language: Haskell2010 +-- type: exitcode-stdio-1.0 +-- hs-source-dirs: test +-- main-is: Main.hs +-- other-modules: SubTests.DBRuleTest + +test-suite sheriff-function-use-test + import: test-common-options + default-language: Haskell2010 + type: exitcode-stdio-1.0 + hs-source-dirs: test + main-is: Main.hs + other-modules: SubTests.FunctionUseTest + +test-suite sheriff-infinite-recursion-test + import: test-common-options + default-language: Haskell2010 + type: exitcode-stdio-1.0 + hs-source-dirs: test + main-is: Main.hs + other-modules: SubTests.InfiniteRecursionTest \ No newline at end of file diff --git a/sheriff/src/Sheriff/CommonTypes.hs b/sheriff/src/Sheriff/CommonTypes.hs new file mode 100644 index 0000000..edd2ca2 --- /dev/null +++ b/sheriff/src/Sheriff/CommonTypes.hs @@ -0,0 +1,137 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeSynonymInstances #-} + +module Sheriff.CommonTypes where + +import Data.Hashable +import qualified Data.HashMap.Strict as HM +import GHC hiding (exprType) + +#if __GLASGOW_HASKELL__ >= 900 +import GHC.Plugins hiding ((<>), getHscEnv) +import GHC.Tc.Types +#else +import GhcPlugins hiding ((<>), getHscEnv) +import TcRnMonad +#endif + +data NameModuleValue = + NMV_Name Name + | NMV_ClassModule Name String -- This should be terminating for class name and module name + deriving (Eq) + +instance Show NameModuleValue where + show (NMV_Name name) = "NMV_Name " <> getOccString name <> "_" <> show (nameUnique name) + show (NMV_ClassModule name modName) = "NMV_ClassModule " <> getOccString name <> "_" <> show (nameUnique name) <> modName + +data PluginCommonOpts a = PluginCommonOpts { + currentModule :: String, + nameModuleMap :: HM.HashMap NameModuleValue NameModuleValue, + pluginOpts :: a + } + deriving (Show, Eq) + +type (HasPluginOpts a) = ?pluginOpts :: (PluginCommonOpts a) + +-- Recursive data type for simpler type representation +data TypeData = TextTy String | NestedTy [TypeData] + deriving (Show, Eq) + +data SimpleTcExpr = + SimpleVar Var + | SimpleFnNameVar Var Type -- Just for checking function Variable + | SimpleList [SimpleTcExpr] + | SimpleAliasPat SimpleTcExpr SimpleTcExpr + | SimpleTuple [SimpleTcExpr] + | SimpleDataCon (Maybe Var) [SimpleTcExpr] + | SimpleLit (HsLit GhcTc) + | SimpleOverloadedLit OverLitVal + | SimpleUnhandledTcExpr + +instance Outputable SimpleTcExpr where + ppr simpleTcExpr = case simpleTcExpr of + SimpleVar v -> "SimpleVar " $$ ppr v + SimpleFnNameVar v ty -> "SimpleFnNameVar " $$ ppr v $$ ppr ty + SimpleList ls -> "SimpleList " $$ ppr ls + SimpleAliasPat p1 p2 -> "SimpleAliasPat " + SimpleTuple ls -> "SimpleTuple " $$ ppr ls + SimpleDataCon mbCon ls -> "SimpleDataCon " $$ ppr mbCon $$ ppr ls + SimpleLit lit -> "SimpleLit " $$ ppr lit + SimpleOverloadedLit overloadedLit -> "SimpleOverloadedLit " $$ ppr overloadedLit + SimpleUnhandledTcExpr -> "SimpleUnhandledTcExpr" + +instance Eq SimpleTcExpr where + (==) (SimpleAliasPat pat11 pat12) (SimpleAliasPat pat21 pat22) = pat11 == pat12 || pat12 == pat22 || pat11 == pat22 || pat12 == pat21 + (==) (SimpleAliasPat pat1 pat2) pat = pat1 == pat || pat2 == pat + (==) pat (SimpleAliasPat pat1 pat2) = pat1 == pat || pat2 == pat + (==) (SimpleVar var1) (SimpleVar var2) = var1 == var2 + (==) (SimpleFnNameVar var1 ty1) (SimpleFnNameVar var2 ty2) = nameOccName (varName var1) == nameOccName (varName var2) + (==) (SimpleList pat1) (SimpleList pat2) = pat1 == pat2 + (==) (SimpleTuple pat1) (SimpleTuple pat2) = pat1 == pat2 + (==) (SimpleDataCon mbVar1 pat1) (SimpleDataCon mbVar2 pat2) = mbVar1 == mbVar2 && pat1 == pat2 + (==) (SimpleLit lit1) (SimpleLit lit2) = lit1 == lit2 + (==) (SimpleOverloadedLit lit1) (SimpleOverloadedLit lit2) = lit1 == lit2 + (==) (SimpleUnhandledTcExpr) (SimpleUnhandledTcExpr) = False + (==) _ _ = False + +-- Data type to represent asterisk matching +data AsteriskMatching = AsteriskInFirst | AsteriskInSecond | AsteriskInBoth | NoAsteriskMatching + deriving (Show, Eq) + +-- Type family and GADT for generic phase related stuff +#if __GLASGOW_HASKELL__ >= 900 +type family PassMonad (p :: Pass) a +type instance PassMonad 'Parsed a = Hsc a +type instance PassMonad 'Renamed a = TcRn a +type instance PassMonad 'Typechecked a = TcM a +#else +data MyGhcPass (p :: Pass) where + GhcPs :: MyGhcPass 'Parsed + GhcRn :: MyGhcPass 'Renamed + GhcTc :: MyGhcPass 'Typechecked + +class IsPass (p :: Pass) where + ghcPass :: MyGhcPass p + +instance IsPass 'Parsed where + ghcPass = GhcPs + +instance IsPass 'Renamed where + ghcPass = GhcRn + +instance IsPass 'Typechecked where + ghcPass = GhcTc + +type family PassMonad (p :: Pass) a +type instance PassMonad 'Parsed a = Hsc a +type instance PassMonad 'Renamed a = TcRn a +type instance PassMonad 'Typechecked a = TcM a +#endif + +instance Hashable (Located Var) where + hashWithSalt salt (L srcSpan var) = hashWithSalt salt $ show srcSpan <> "::" <> (nameStableString . getName $ var) + +instance Hashable NameModuleValue where + hashWithSalt salt (NMV_Name name) = hashWithSalt salt (nameStableString name) + hashWithSalt salt (NMV_ClassModule name _) = hashWithSalt salt (nameStableString name) + +class StrictEq a where + (===) :: (HasPluginOpts u) => a -> a -> Bool + +instance (StrictEq a) => StrictEq (Maybe a) where + (===) (Just x) (Just y) = x === y + (===) Nothing Nothing = True + (===) _ _ = False + +instance (StrictEq a) => StrictEq [a] where + (===) [] [] = True + (===) (x:xs) (y:ys) = (x === y && xs === ys) + (===) _ _ = False \ No newline at end of file diff --git a/sheriff/src/Sheriff/Patterns.hs b/sheriff/src/Sheriff/Patterns.hs new file mode 100644 index 0000000..64387dd --- /dev/null +++ b/sheriff/src/Sheriff/Patterns.hs @@ -0,0 +1,48 @@ +{-# LANGUAGE PatternSynonyms #-} + +module Sheriff.Patterns where + +import GHC hiding (exprType) + +#if __GLASGOW_HASKELL__ >= 900 +import GHC.Core.TyCo.Rep +import GHC.Tc.Types.Evidence +import Language.Haskell.Syntax.Expr +#else +import GHC.Hs.Expr +import TcEvidence +import TyCoRep +#endif + +#if __GLASGOW_HASKELL__ >= 900 +pattern PatFunTy :: AnonArgFlag -> Type -> Type -> Type +pattern PatFunTy anonArgFlag ty1 ty2 <- (FunTy anonArgFlag _ ty1 ty2) + +pattern PatHsIf :: LHsExpr (GhcPass p) -> LHsExpr (GhcPass p) -> LHsExpr (GhcPass p) -> HsExpr (GhcPass p) +pattern PatHsIf pred thenCl elseCl <- (HsIf _ pred thenCl elseCl) + +pattern PatHsWrap :: HsWrapper -> HsExpr GhcTc -> HsExpr GhcTc +pattern PatHsWrap wrapper expr = (XExpr (WrapExpr (HsWrap wrapper expr))) + +pattern PatHsExpansion :: HsExpr GhcRn -> HsExpr GhcTc -> HsExpr GhcTc +pattern PatHsExpansion orig expanded <- (XExpr (ExpansionExpr (HsExpanded orig expanded))) + +pattern PatExplicitList :: (XExplicitList (GhcPass p)) -> [LHsExpr (GhcPass p)] -> HsExpr (GhcPass p) +pattern PatExplicitList typ arg = (ExplicitList typ arg) + +#else +pattern PatFunTy :: AnonArgFlag -> Type -> Type -> Type +pattern PatFunTy anonArgFlag ty1 ty2 <- (FunTy anonArgFlag ty1 ty2) + +pattern PatHsIf :: LHsExpr (GhcPass p) -> LHsExpr (GhcPass p) -> LHsExpr (GhcPass p) -> HsExpr (GhcPass p) +pattern PatHsIf pred thenCl elseCl <- (HsIf _ _ pred thenCl elseCl) + +pattern PatHsWrap :: HsWrapper -> HsExpr (GhcPass p) -> HsExpr (GhcPass p) +pattern PatHsWrap wrapper expr <- (HsWrap _ wrapper expr) where + PatHsWrap wrapper expr = (HsWrap NoExtField wrapper expr) + +pattern PatExplicitList :: XExplicitList (GhcPass p) -> [LHsExpr (GhcPass p)] -> HsExpr (GhcPass p) +pattern PatExplicitList typ arg <- (ExplicitList typ _ arg) where + PatExplicitList typ arg = (ExplicitList typ Nothing arg) + +#endif \ No newline at end of file diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index 4c0ac6a..733cc3d 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -1,107 +1,73 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE TypeSynonymInstances #-} {-# LANGUAGE FlexibleInstances #-} -{-# OPTIONS_GHC -Werror=incomplete-patterns #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeSynonymInstances #-} module Sheriff.Plugin (plugin) where -import Bag (bagToList,listToBag, emptyBag) +-- Sheriff imports +import Sheriff.CommonTypes +import Sheriff.Patterns +import Sheriff.Rules +import Sheriff.Types +import Sheriff.TypesUtils +import Sheriff.Utils + +-- GHC imports +import Control.Applicative ((<|>)) +import Control.Monad (foldM, when) import Control.Monad.IO.Class (MonadIO (..)) -import Data.Data -import Control.Reference (biplateRef, (^?), Simple, Traversal) -import Data.Generics.Uniplate.Data (universeBi, childrenBi, contextsBi, holesBi, children) -import Data.List (nub) -import Debug.Trace (traceShowId, trace) -import Data.Yaml -import GHC - ( GRHS (..), - GRHSs (..), - GenLocated (L), - HsValBinds (..), - GhcTc, - HsBindLR (..), - HsConDetails (..), - HsConPatDetails, - HsExpr (..), - HsRecField' (..), - HsRecFields (..), - LGRHS, - MatchGroupTc(..), - HsType(..), - LHsType, - NoGhcTc(..), - HsTyLit(..), - HsWildCardBndrs(..), - LHsExpr, - LHsRecField, - LMatch, - LPat, - Match (m_grhss, m_pats), - MatchGroup (..), - Name, - Pat (..), - PatSynBind (..), - noLoc, noExtField, Module (moduleName), moduleNameString,Id(..),getName,nameSrcSpan,IdP(..),GhcPass, getModSummary - ) -import GHC.Hs.Binds -import GhcPlugins (idName,Var (varName), getOccString, unLoc, Plugin (pluginRecompile), PluginRecompile (..),showSDocUnsafe,ppr,elemNameSet,pprPrefixName,idType,tidyOpenType, isEnumerationTyCon, WarnReason(..), msHsFilePath, getModule) -import HscTypes (ModSummary (..)) -import GhcPlugins (moduleEnvToList, ModIface, lookupModuleEnv) -import HscTypes (hscEPS, eps_PTE, hsc_targets, mgModSummaries, hsc_mod_graph, hsc_HPT, eps_PIT, pprHPT, mi_decls, mi_exports, mi_complete_sigs, mi_insts, eps_complete_matches, eps_stats, eps_rule_base, mi_usages) -import GHC.IORef (readIORef) -import LoadIface (pprModIface, pprModIfaceSimple) -import Name (nameStableString) -import Plugins (CommandLineOption, Plugin (typeCheckResultAction), defaultPlugin) -import TcRnTypes (TcGblEnv (..), TcM) -import Prelude hiding (id,writeFile, appendFile) +import Control.Monad.State import Data.Aeson as A -import Data.ByteString.Lazy (writeFile, appendFile) -import System.Directory (createDirectoryIfMissing,getHomeDirectory) -import Data.Maybe (fromMaybe) -import Control.Exception (try,SomeException, catch) -import SrcLoc -import Annotations -import Outputable (showSDocUnsafe, ppr, Outputable(..)) -import GhcPlugins () -import DynFlags () -import Control.Monad (foldM,when) -import Data.List -import Data.List.Extra (replace, splitOn, sortOn) -import Data.Maybe (fromJust,isJust,mapMaybe) -import Sheriff.Types -import Sheriff.Rules import Data.Aeson.Encode.Pretty (encodePretty) -import Control.Concurrent -import System.Directory -import PatSyn -import Avail -import TcEnv -import GHC.Hs.Utils as GHCHs -import TyCoPpr ( pprUserForAll, pprTypeApp, pprSigmaType ) import Data.Bool (bool) -import qualified Data.Map as Map -import qualified Outputable as OP -import FastString -import Data.Maybe (catMaybes) -import DsMonad (initDsTc) -import DsExpr (dsLExpr) -import TcRnMonad (failWith, addErr, addWarn, addErrAt, addErrs, getEnv, env_top, env_gbl) -import Name (isSystemName) -import GHC (OverLitTc(..), HsOverLit(..)) -import CoreUtils (exprType) -import Control.Applicative ((<|>)) -import Type (isFunTy, funResultTy, splitAppTys, dropForAlls) -import TyCoRep (Type(..), TyLit (..)) -import Data.ByteString.Lazy as BSL () -import Data.String (fromString) +import Data.ByteString.Lazy (writeFile, appendFile) import qualified Data.ByteString.Lazy.Char8 as Char8 -import TcType +import Data.Data +import Data.Function (on) +import qualified Data.HashMap.Strict as HM +import Data.List (nub, sortBy, groupBy, find, isInfixOf, isSuffixOf, isPrefixOf) +import Data.List.Extra (splitOn) +import Data.Maybe (catMaybes, fromMaybe) +import Data.Yaml +import Debug.Trace (traceShowId, trace) +import GHC hiding (exprType) +import Prelude hiding (id, writeFile, appendFile) +import System.Directory (createDirectoryIfMissing, getHomeDirectory) + +#if __GLASGOW_HASKELL__ >= 900 +import GHC.Core.Class +import GHC.Core.ConLike +import GHC.Core.InstEnv +import GHC.Core.TyCo.Rep +import GHC.Data.Bag +import GHC.HsToCore.Monad +import GHC.HsToCore.Expr +import GHC.Plugins hiding ((<>), getHscEnv, purePlugin) +import GHC.Tc.Types +import GHC.Tc.Types.Evidence +import GHC.Tc.Utils.Monad +import GHC.Tc.Utils.TcType +import GHC.Types.Annotations +import qualified GHC.Utils.Outputable as OP +#else +import Bag +import Class import ConLike -import TysWiredIn -import GHC.Hs.Lit (HsLit(..)) +import DsExpr +import DsMonad +import GhcPlugins hiding ((<>), getHscEnv, purePlugin) +import InstEnv +import qualified Outputable as OP import TcEvidence -import qualified Data.HashMap.Strict as HM -import qualified Data.Text as T +import TcRnMonad +import TcRnTypes +import TcType +import TyCoRep +#endif plugin :: Plugin plugin = defaultPlugin { @@ -112,241 +78,580 @@ plugin = defaultPlugin { purePlugin :: [CommandLineOption] -> IO PluginRecompile purePlugin _ = return NoForceRecompile --- Parse the YAML file -parseYAMLFile :: (FromJSON a) => FilePath -> IO (Either ParseException a) -parseYAMLFile file = decodeFileEither file - --- Function to extract the code segment based on SrcSpan -extractSrcSpanSegment :: SrcSpan -> FilePath -> String -> IO String -extractSrcSpanSegment srcSpan' filePath oldCode = case srcSpan' of - RealSrcSpan srcSpan -> do - content' <- try (readFile filePath) :: IO (Either SomeException String) - case content' of - Left _ -> pure oldCode - Right content -> do - let fileLines = T.lines (T.pack content) - startLine = srcSpanStartLine srcSpan - endLine = srcSpanEndLine srcSpan - startCol = srcSpanStartCol srcSpan - endCol = srcSpanEndCol srcSpan - - -- Extract relevant lines - relevantLines = take (endLine - startLine + 1) $ drop (startLine - 1) fileLines - -- Handle single-line and multi-line spans - result = case relevantLines of - [] -> "" - [singleLine] -> T.take (endCol - startCol) $ T.drop (startCol - 1) singleLine - _ -> T.unlines $ [T.drop (startCol - 1) (head relevantLines)] ++ - (init (tail relevantLines)) ++ - [T.take endCol (last relevantLines)] - pure $ T.unpack result - _ -> pure oldCode +--------------------------- Core Logic --------------------------- + +{- + +Stage - 1 SETUP + 1. Parse the following - + 1.1 plugin options + 1.2 Rules yaml file + 1.3 Exceptions yaml file + 1.4 DB indexed keys file + 2. Filter out rules based on module level exceptions + 3. Separate out individual rule level exception rules + +Stage - 2 EXECUTION + 1. Repeat steps 2 to 4 for all function binds + 2. Extract all `LHsExpr` type i.e. all expressions + 3. Perform some simplifications + 4. For each rule, check if that rule is applicable or not. If applicable call, corresponding validation function. + 5. Validation function will return violation found along with other info required + 6. Detect infinite recursion errors + +Stage - 3 ERRORS AND IO + 1. Convert raw error information to high level error + 2. Sort & group errors on basis of src_span + 3. Filter out rules for rule level exceptions -- If current rule in the error group has any exception rule coinciding with any other rule in the error group, then eliminate current rule + 4. Filter out rules for global level exceptions -- if any rule in the error group is part of globalExceptions, then eliminate the group + 6. Throw errors, if configured + 7. Write errors to file, if configured + +-} + +{- + TODO: + 1. Generalize the custom state monad and run things with context available to all functions + 2. Reuse the same type for implicit param and state (for interusability) + 3. Add helper functions to set in implicit params from state + 4. Change module name matching to direct variable matching by means of transforming to top most level +-} + +type SheriffTcM = StateT (HM.HashMap NameModuleValue NameModuleValue) TcM sheriff :: [CommandLineOption] -> ModSummary -> TcGblEnv -> TcM TcGblEnv sheriff opts modSummary tcEnv = do - let pluginOpts = case opts of - [] -> defaultPluginOpts - (x : _) -> fromMaybe defaultPluginOpts $ A.decode (Char8.pack x) - - throwCompilationErrorV = throwCompilationError pluginOpts - saveToFileV = saveToFile pluginOpts - savePathV = savePath pluginOpts - indexedKeysPathV = indexedKeysPath pluginOpts - sheriffRulesPath = rulesConfigPath pluginOpts - sheriffExceptionsPath = exceptionsConfigPath pluginOpts - failOnFileNotFoundV = failOnFileNotFound pluginOpts - moduleName' = moduleNameString $ moduleName $ ms_mod modSummary + -- STAGE-1 + let moduleName' = moduleNameString $ moduleName $ ms_mod modSummary + pluginOpts@PluginOpts{..} = decodeAndUpdateOpts opts defaultPluginOpts + + let ?pluginOpts = PluginCommonOpts moduleName' HM.empty pluginOpts -- parse the yaml file from the path given - parsedYaml <- liftIO $ parseYAMLFile indexedKeysPathV + parsedYaml <- liftIO $ parseYAMLFile indexedKeysPath -- parse the yaml file from the path given for sheriff general rules - parsedRulesYaml <- liftIO $ parseYAMLFile sheriffRulesPath + parsedRulesYaml <- liftIO $ parseYAMLFile rulesConfigPath -- parse the yaml file from the path given for sheriff general exception rules - parsedExceptionsYaml <- liftIO $ parseYAMLFile sheriffExceptionsPath - - -- Check the parsed yaml file for indexedDbKeys and throw compilation error if configured - rulesListWithDbRules <- case parsedYaml of - Left err -> do - when failOnFileNotFoundV $ addErr (mkInvalidYamlFileErr (show err)) - pure badPracticeRules - Right (YamlTables tables) -> pure $ badPracticeRules <> (map yamlToDbRule tables) + parsedExceptionsYaml <- liftIO $ parseYAMLFile exceptionsConfigPath + + -- Check the parsed yaml file for indexedDbKeys and generate DB rules. If failed, throw file error if configured. + dbRules <- case parsedYaml of + Left err -> do + when failOnFileNotFound $ addErr (mkInvalidYamlFileErr (show err)) + pure [] + Right (YamlTables tables) -> pure $ (map yamlToDbRule tables) - rulesList' <- case parsedRulesYaml of + -- Check the parsed rules yaml file. If failed, throw file error if configured. + configuredRules <- case parsedRulesYaml of Left err -> do - when failOnFileNotFoundV $ addErr (mkInvalidYamlFileErr (show err)) - pure rulesListWithDbRules - Right (SheriffRules rules) -> pure $ rulesListWithDbRules <> rules - - let rulesList = filter (isAllowedOnCurrentModule moduleName') rulesList' - - exceptionList' <- case parsedExceptionsYaml of - Left err -> do - when failOnFileNotFoundV $ addErr (mkInvalidYamlFileErr (show err)) - pure exceptionRules - Right (SheriffRules rules) -> pure $ exceptionRules <> rules + when failOnFileNotFound $ addErr (mkInvalidYamlFileErr (show err)) + pure [] + Right (SheriffRules rules) -> pure rules + + -- Check the parsed exception rules yaml file. If failed, throw file error if configured. + configuredExceptionRules <- case parsedExceptionsYaml of + Left err -> do + when failOnFileNotFound $ addErr (mkInvalidYamlFileErr (show err)) + pure [] + Right (SheriffRules exceptionRules) -> pure exceptionRules - let exceptionList = filter (isAllowedOnCurrentModule moduleName') exceptionList' - - let rulesExceptionList = concat $ fmap getRuleExceptions rulesList - - when (logDebugInfo pluginOpts) $ liftIO $ print rulesList - when (logDebugInfo pluginOpts) $ liftIO $ print exceptionList - - let finalRules = rulesList <> exceptionList <> rulesExceptionList - - rawErrors <- concat <$> (mapM (loopOverModBinds finalRules pluginOpts) $ bagToList $ tcg_binds tcEnv) - errors <- mapM (mkCompileError moduleName') rawErrors + let rawGlobalRules = defaultSheriffRules <> dbRules <> configuredRules + globalRules = filter (isAllowedOnCurrentModule moduleName') rawGlobalRules + rawExceptionRules = defaultSheriffExceptionsRules <> configuredExceptionRules + globalExceptionRules = filter (isAllowedOnCurrentModule moduleName') rawExceptionRules + ruleLevelExceptionRules = concat $ fmap getRuleExceptions globalRules + finalSheriffRules = nub $ globalRules <> globalExceptionRules <> ruleLevelExceptionRules + isInfiniteRecursionRule r = case r of + (InfiniteRecursionRuleT rule) -> True + _ -> False + infRule = case find isInfiniteRecursionRule globalRules of + Just (InfiniteRecursionRuleT r) -> r + _ -> defaultInfiniteRecursionRuleT + + when logDebugInfo $ liftIO $ print globalRules + when logDebugInfo $ liftIO $ print globalExceptionRules + + -- STAGE-2 + -- Get Instance declarations and add class name to module name binding in initial state + insts <- tcg_insts . env_gbl <$> getEnv + let namesModTuple = concatMap (\inst -> let clsName = className (is_cls inst) in (is_dfun_name inst, clsName) : fmap (\clsMethod -> (varName clsMethod, clsName)) (classMethods $ is_cls inst)) insts + nameModMap = foldr (\(name, clsName) r -> HM.insert (NMV_Name name) (NMV_ClassModule clsName (getModuleName clsName)) r) HM.empty namesModTuple + + rawErrors <- concat <$> (mapM (loopOverModBinds finalSheriffRules) $ bagToList $ tcg_binds tcEnv) + (rawInfiniteRecursionErrors, _) <- flip runStateT nameModMap $ concat <$> (mapM (checkInfiniteRecursion True infRule) $ bagToList $ tcg_binds tcEnv) + + -- STAGE-3 + errors <- mapM (mkCompileError moduleName') (rawErrors <> rawInfiniteRecursionErrors) - let sortedErrors = sortOn src_span errors + let sortedErrors = sortBy (leftmost_smallest `on` src_span) errors groupedErrors = groupBy (\a b -> src_span a == src_span b) sortedErrors - filteredErrorsForRuleLevelExceptions = fmap (\x -> filter (\err -> not $ (getRuleExceptionsFromCompileError err) `hasAny` (fmap getRuleFromCompileError x)) x) groupedErrors - filteredErrorsForGlobalExceptions = concat $ filter (\x -> not $ (\err -> (getRuleFromCompileError err) `elem` exceptionList) `any` x) filteredErrorsForRuleLevelExceptions - filteredErrors = nub $ filter (\x -> getRuleFromCompileError x `elem` rulesList) filteredErrorsForGlobalExceptions -- Filter errors to take only rules since we might have some individual rule level errors in this list + filteredErrorsForRuleLevelExceptions = fmap (\x -> let errorRulesInCurrentGroup = fmap getRuleFromCompileError x in filter (\err -> not $ (getRuleExceptionsFromCompileError err) `hasAny` errorRulesInCurrentGroup) x) groupedErrors + filteredErrorsForGlobalExceptions = concat $ filter (\x -> not $ (\err -> (getRuleFromCompileError err) `elem` globalExceptionRules) `any` x) filteredErrorsForRuleLevelExceptions + filteredErrors = nub $ filter (\x -> getRuleFromCompileError x `elem` (InfiniteRecursionRuleT infRule : globalRules)) filteredErrorsForGlobalExceptions -- Filter errors to take only rules since we might have some individual rule level errors in this list - if throwCompilationErrorV + if throwCompilationError then addErrs $ map mkGhcCompileError filteredErrors else pure () - if saveToFileV - then addErrToFile modSummary savePathV filteredErrors + if saveToFile + then addErrToFile modSummary savePath filteredErrors else pure () return tcEnv ---------------------------- Core Logic --------------------------- +--------------------------- Infinite Recursion Detection Logic --------------------------- +{- + + 1. Check if bind is AbsBind, add a mapping from mono to poly Var and recurse for binds + 2. Check if bind is VarBind, add mappings for child HsVar to VarId and update state + 3. Check if bind is FunBind, then get the function `var` + 4. Get all the match groups from the match (One match group is single definition for a function, a function may have multiple match groups) + 5. For each match group, perform below steps: + 5.1 Get the Pattern matches + 5.2 Transform pattern matches into common type `SimpleTcExpr` + 5.3 Append function name var to the beginning to complete the transformation + 5.4 Fetch all the HsExpr from the match group's guarded rhs (includes where clause) + 5.5 Filter out FunApp from all the HsExpr + 5.6 Simplify for ($) operator and transform to `SimpleTcExpr` + 5.7 Check if any of the `SimpleTcExpr` representation of HsExpr is same as `SimpleTcExpr` representation of pattern matches + 5.8 Fetch all the FunBinds from the guarded rhs + 5.9 Recur for each fun bind and repeat from step 1 + +TODO: (Optimizations) + 1. Traverse a functions's body HsExpr once only and traverse the list/tree manually filtering based on location for local function binds + 2. Traverse AST manually, passing down all the required info as required + 3. Avoid duplicate recursion + +TODO: (Extending Patterns) + 1. Match on VarBind. PatBind and other binds + 2. Support all the Pat and HsExpr types + 3. Check if infinite recursion possible in lambda function?? + 4. Validate and handle SectionL and SectionR and partial functions and tuple sections + 5. Handle infinite lists traversals like map, fmap, etc. (represented by `ArithSeq`) + 6. Handle function renaming before enabling partial functions +-} + + +-- Function to check if the AST has deterministic infinite recursion +checkInfiniteRecursion :: (HasPluginOpts PluginOpts) => Bool -> InfiniteRecursionRule -> LHsBindLR GhcTc GhcTc -> SheriffTcM [(LHsExpr GhcTc, Violation)] +checkInfiniteRecursion recurseForBinds rule (L _ ap@(FunBind{fun_id = funVar, fun_matches = matches})) = do + currNameModMap <- get + let ?pluginOpts = ?pluginOpts {nameModuleMap = currNameModMap} + errs <- mapM (checkAndVerifyAlt recurseForBinds rule funVar) (fmap unLoc . unLoc $ mg_alts matches) + pure $ concat errs +checkInfiniteRecursion recurseForBinds rule (L _ ap@(AbsBinds{abs_binds = binds, abs_exports = bindVars})) = do + let mbVar = case bindVars of + x : _ -> Just $ (varName $ abe_poly x, varName $ abe_mono x) + _ -> Nothing + currNameModMap <- get + let updatedNameModMap = maybe currNameModMap (\(poly, mono) -> HM.insert (NMV_Name mono) (NMV_Name poly) currNameModMap) mbVar + put updatedNameModMap + list <- mapM (\x -> checkInfiniteRecursion recurseForBinds rule x) $ bagToList binds + pure $ concat list +checkInfiniteRecursion recurseForBinds rule (L loc ap@(VarBind{var_id = varId, var_rhs = rhs})) = do + let currVarName = varName varId + childHsVar = fmap varName (traverseAst rhs) + currNameModMap <- get + let updatedNameModMap = foldr (\childVar r -> HM.insert (NMV_Name childVar) (NMV_Name currVarName) r) currNameModMap childHsVar + put updatedNameModMap + pure [] +checkInfiniteRecursion _ _ _ = pure [] + +-- Helper function to verify if any of the body HsExpr results in infinite recursion +checkAndVerifyAlt :: (HasPluginOpts PluginOpts) => Bool -> InfiniteRecursionRule -> LIdP GhcTc -> Match GhcTc (LHsExpr GhcTc) -> SheriffTcM [(LHsExpr GhcTc, Violation)] +checkAndVerifyAlt recurseForBinds rule ap@(L loc fnVar) match = do + let currentFnNameWithModule = getVarNameWithModuleName fnVar + ignoredFunctions = infinite_recursion_rule_ignore_functions rule + skipCurrentFn = any (\ignoredFnName -> matchNamesWithModuleName currentFnNameWithModule ignoredFnName AsteriskInSecond) ignoredFunctions + currentFnErrors <- case skipCurrentFn of + True -> pure [] + False -> do + let argsInFnDefn = m_pats match + trfArgsInFnDefn = fmap (trfPatToSimpleTcExpr . unLoc) argsInFnDefn + fnVarTyp <- lift $ getHsExprTypeWithResolver False (mkLHsVar $ getLocated ap loc) + let finalTrfArgsInFnDefn = (SimpleFnNameVar fnVar fnVarTyp) : trfArgsInFnDefn + argLenByTy = length (getHsExprTypeAsTypeDataList fnVarTyp) - 1 + argLenByFnDefn = length argsInFnDefn + grhssList = grhssGRHSs (m_grhss match) + when (logDebugInfo . pluginOpts $ ?pluginOpts) $ + liftIO $ do + putStrLn (showS loc <> " :: " <> showS fnVar) >> putStrLn "***" + print (getHsExprTypeAsTypeDataList fnVarTyp) >> putStrLn "***" + putStrLn "******" + concatMapM (checkGrhSS finalTrfArgsInFnDefn argLenByTy argLenByFnDefn) grhssList + + -- Process sub binds if further recursion allowed + subBindsErrors <- + if recurseForBinds + then do + let (subBinds :: [LHsBindLR GhcTc GhcTc]) = traverseAst (m_grhss match) + concat <$> mapM (checkInfiniteRecursion False rule) subBinds + else pure [] + + pure (currentFnErrors <> subBindsErrors) + where + checkGrhSS :: HasPluginOpts PluginOpts => [SimpleTcExpr] -> Int -> Int -> LGRHS GhcTc (LHsExpr GhcTc) -> SheriffTcM [(LHsExpr GhcTc, Violation)] + checkGrhSS finalTrfArgsInFnDefn argLenByTy argLenByFnDefn grhss = do + let lenDiff = argLenByTy - argLenByFnDefn + isPartialFn = lenDiff > 0 + if isPartialFn then case getMaybeLambdaCaseOrLambdaMG grhss of + Just mg -> do + let matches = map unLoc . unLoc $ mg_alts mg + flip concatMapM matches $ \match -> do + let argsInFnDefn = m_pats match + trfArgsInFnDefn = fmap (trfPatToSimpleTcExpr . unLoc) argsInFnDefn + updatedFinalTrfArgsInFnDefn = finalTrfArgsInFnDefn <> trfArgsInFnDefn + updatedArgLenByFnDefn = argLenByFnDefn + length argsInFnDefn + grhssList = grhssGRHSs (m_grhss match) + concatMapM (checkGrhSS updatedFinalTrfArgsInFnDefn argLenByTy updatedArgLenByFnDefn) grhssList + Nothing -> processLHsExprInGrhs (getLastStmt grhss) finalTrfArgsInFnDefn + else processLHsExprInGrhs (traverseAst grhss) finalTrfArgsInFnDefn + + processLHsExprInGrhs :: HasPluginOpts PluginOpts => [LHsExpr GhcTc] -> [SimpleTcExpr] -> SheriffTcM [(LHsExpr GhcTc, Violation)] + processLHsExprInGrhs hsExprs finalTrfArgsInFnDefn = do + let (funApps :: [LHsExpr GhcTc]) = filter (isFunApp True) hsExprs + (simplifiedFnApps :: [(Located Var, (LHsExpr GhcTc, LHsExpr GhcTc, [LHsExpr GhcTc]))]) = HM.toList $ foldr (\x r -> maybe r (\(lVar, typ, args) -> HM.insertWith (\(x1, e1, newArgs) (x2, e2, oldArgs) -> if length newArgs >= length oldArgs then (x1, e1, newArgs) else (x2, e2, oldArgs)) lVar (x, typ, args) r) $ getFnNameAndTypeableExprWithAllArgs x) HM.empty funApps + (trfSimplifiedFunApps :: [(LHsExpr GhcTc, [SimpleTcExpr])]) <- mapM trfSimplifiedFunApp simplifiedFnApps + let currentGrhssErrors = map (\(lhsExpr, _) -> (lhsExpr, InfiniteRecursionDetected rule)) $ filter (\x -> (snd x === finalTrfArgsInFnDefn)) trfSimplifiedFunApps + when (logDebugInfo . pluginOpts $ ?pluginOpts) $ + liftIO $ do + let tyL = concat $ fmap (foldr (\x r -> case x of; SimpleFnNameVar v ty -> (v, ty) : r; _ -> r;) [] . snd) trfSimplifiedFunApps + mapM (\(v, t) -> putStrLn $ showS v <> " ::: " <> show (getHsExprTypeAsTypeDataList t)) tyL + putStrLn "******" + pure currentGrhssErrors + + {- + Assumption for infinite recursion in partial function case: + 1. Function composition , e.g. - fn a = fn1 x . fn2 y . fn a + 2. Let-in statement, and infinite recursion is inside `in` statement + 3. Do statement, infinite recursion can be in last statement + 4. Straight HsApp case or HsVar case + + TODO: + 1. Correct the assumption and handle more possible cases + -} + getLastStmt :: LGRHS GhcTc (LHsExpr GhcTc) -> [LHsExpr GhcTc] + getLastStmt (L _ (GRHS _ _ lExpr)) = checkAndGetExpr lExpr +#if __GLASGOW_HASKELL__ < 900 + getLastStmt (L _ (XGRHS _)) = [] +#endif + + checkAndGetExpr :: LHsExpr GhcTc -> [LHsExpr GhcTc] + checkAndGetExpr (L loc expr) = case expr of + HsLet _ _ inStmt -> checkAndGetExpr inStmt + HsApp _ _ _ -> [L loc expr] + HsVar _ _ -> [L loc expr] + HsDo _ _ doStmts -> concatMap checkAndGetExpr $ foldr isLastStmt [] (traverseAst doStmts) + PatHsWrap wrapper wrapExpr -> fmap (\(L _ trfExpr) -> L loc (PatHsWrap wrapper trfExpr)) $ checkAndGetExpr (L loc wrapExpr) + OpApp _ lApp op rApp -> case showS op of + "(.)" -> checkAndGetExpr rApp + _ -> [] +#if __GLASGOW_HASKELL__ >= 900 + PatHsExpansion (OpApp _ _ op _) (HsApp _ lApp rApp) -> case showS op of + "(.)" -> checkAndGetExpr rApp + _ -> [] +#endif + _ -> [] + + isLastStmt :: ExprLStmt GhcTc -> [LHsExpr GhcTc] -> [LHsExpr GhcTc] + isLastStmt (L _ (LastStmt _ lexpr _ _)) rem = lexpr : rem + isLastStmt _ rem = rem + + getMaybeLambdaCaseOrLambdaMG :: LGRHS GhcTc (LHsExpr GhcTc) -> Maybe (MatchGroup GhcTc (LHsExpr GhcTc)) + getMaybeLambdaCaseOrLambdaMG grhs = case grhs of + (L _ (GRHS _ _ lExpr)) -> checkAndGetMaybeLambdaCaseOrLambdaMG lExpr +#if __GLASGOW_HASKELL__ < 900 + (L _ (XGRHS _)) -> Nothing +#endif + + checkAndGetMaybeLambdaCaseOrLambdaMG :: LHsExpr GhcTc -> Maybe (MatchGroup GhcTc (LHsExpr GhcTc)) + checkAndGetMaybeLambdaCaseOrLambdaMG (L loc expr) = case expr of + (HsLamCase _ mg) -> Just mg -- LambdaCase + (HsLam _ mg) -> Just mg -- Lambda Function + (HsLet _ _ inStmt) -> checkAndGetMaybeLambdaCaseOrLambdaMG inStmt + PatHsWrap _ wrapExpr -> checkAndGetMaybeLambdaCaseOrLambdaMG (L loc wrapExpr) + OpApp _ _ op rApp -> case showS op of + "(.)" -> checkAndGetMaybeLambdaCaseOrLambdaMG rApp + _ -> Nothing +#if __GLASGOW_HASKELL__ >= 900 + PatHsExpansion (OpApp _ _ op _) (HsApp _ _ rApp) -> case showS op of + "(.)" -> checkAndGetMaybeLambdaCaseOrLambdaMG rApp + _ -> Nothing +#endif + _ -> Nothing --- FLOW : --- Perform steps for each top level function binding in a module --- 1. Extract the value bindings & fun bindings inside the definition --- 2. Extract the function arguments --- 3. Get all the FunApp in the definition --- 4. Check and return if the FunApp is Logging Function and corresponding value/description argument is Text --- 5. Now, we have function application Var and corresponding arg to be checked --- 6. Check if the arg has any stringification function, If yes, it is `ErrorCase` --- 7. Check if the arg uses any local binding from WhereBinds or Normal Bind, If yes, then check if that binding has any stringification output --- 8. Check if arg uses top level binding from any module. If yes, then check if that binding has any stringification output + trfSimplifiedFunApp :: (Located Var, (LHsExpr GhcTc, LHsExpr GhcTc, [LHsExpr GhcTc])) -> SheriffTcM (LHsExpr GhcTc, [SimpleTcExpr]) + trfSimplifiedFunApp (lVar, (lHsExpr, typLHsExpr, lHsExprArgsLs)) = do + typ <- lift $ dropForAlls <$> (getHsExprTypeWithResolver False) typLHsExpr + pure (lHsExpr, SimpleFnNameVar (unLoc lVar) typ : (fmap trfLHsExprToSimpleTcExpr lHsExprArgsLs)) -- Loop over top level function binds -loopOverModBinds :: Rules -> PluginOpts -> LHsBindLR GhcTc GhcTc -> TcM [(LHsExpr GhcTc, Violation)] -loopOverModBinds rules opts (L _ ap@(FunBind _ id matches _ _)) = do +loopOverModBinds :: (HasPluginOpts PluginOpts) => Rules -> LHsBindLR GhcTc GhcTc -> TcM [(LHsExpr GhcTc, Violation)] +loopOverModBinds rules (L _ ap@(FunBind{fun_id = (L _ funVar)})) = do -- liftIO $ print "FunBinds" >> showOutputable ap - badCalls <- getBadFnCalls rules opts ap + let currentFnNameWithModule = getVarNameWithModuleName funVar + filteredRulesForFunction = filter (isAllowedOnCurrentFunction currentFnNameWithModule) rules + badCalls <- getBadFnCalls filteredRulesForFunction ap pure badCalls -loopOverModBinds _ _ (L _ ap@(PatBind _ _ pat_rhs _)) = do +loopOverModBinds _ (L _ ap@(PatBind{})) = do -- liftIO $ print "PatBinds" >> showOutputable ap pure [] -loopOverModBinds _ _ (L _ ap@(VarBind {var_rhs = rhs})) = do +loopOverModBinds _ (L _ ap@(VarBind{})) = do -- liftIO $ print "VarBinds" >> showOutputable ap pure [] -loopOverModBinds rules opts (L _ ap@(AbsBinds {abs_binds = binds})) = do +loopOverModBinds rules (L _ ap@(AbsBinds {abs_binds = binds})) = do -- liftIO $ print "AbsBinds" >> showOutputable ap - list <- mapM (loopOverModBinds rules opts) $ bagToList binds + list <- mapM (loopOverModBinds rules) $ bagToList binds pure (concat list) -loopOverModBinds _ _ _ = pure [] +loopOverModBinds _ _ = pure [] -- Get all the FunApps inside the top level function bind -- This call can be anywhere in `where` clause or `regular` RHS -getBadFnCalls :: Rules -> PluginOpts -> HsBindLR GhcTc GhcTc -> TcM [(LHsExpr GhcTc, Violation)] -getBadFnCalls rules opts (FunBind _ id matches _ _) = do +getBadFnCalls :: (HasPluginOpts PluginOpts) => Rules -> HsBindLR GhcTc GhcTc -> TcM [(LHsExpr GhcTc, Violation)] +getBadFnCalls rules (FunBind{fun_matches = matches}) = do let funMatches = map unLoc $ unLoc $ mg_alts matches concat <$> mapM getBadFnCallsHelper funMatches where getBadFnCallsHelper :: Match GhcTc (LHsExpr GhcTc) -> TcM [(LHsExpr GhcTc, Violation)] getBadFnCallsHelper match = do - let whereBinds = (grhssLocalBinds $ m_grhss match) ^? biplateRef :: [LHsBinds GhcTc] - normalBinds = (grhssGRHSs $ m_grhss match) ^? biplateRef :: [LHsBinds GhcTc] + let whereBinds = traverseAst (grhssLocalBinds $ m_grhss match) :: [LHsBinds GhcTc] + normalBinds = traverseAst (grhssGRHSs $ m_grhss match) :: [LHsBinds GhcTc] argBinds = m_pats match -- exprs = match ^? biplateRef :: [LHsExpr GhcTc] -- use childrenBi and then repeated children usage as per use case - exprs = traverseConditionalUni (noWhereClauseExpansion) (childrenBi match :: [LHsExpr GhcTc]) - concat <$> mapM (isBadFunApp rules opts) exprs -getBadFnCalls _ _ _ = pure [] - --- Takes a predicate which return true if further expansion is not required, false otherwise -traverseConditionalUni :: (Data a) => (a -> Bool) -> [a] -> [a] -traverseConditionalUni _ [] = [] -traverseConditionalUni p (x : xs) = - if p x - then x : traverseConditionalUni p xs - else (x : traverseConditionalUni p (children x)) <> traverseConditionalUni p xs - -noGivenFunctionCallExpansion :: String -> LHsExpr GhcTc -> Bool -noGivenFunctionCallExpansion fnName expr = case expr of - (L loc (HsWrap _ _ expr)) -> noWhereClauseExpansion (L loc expr) - _ -> case getFnNameWithAllArgs expr of - Just (lVar, _) -> (getOccString . varName . unLoc $ lVar) == fnName - Nothing -> False + -- (exprs :: [LHsExpr GhcTc]) = traverseConditionalUni (noWhereClauseExpansion) (childrenBi match :: [LHsExpr GhcTc]) + (exprs :: [LHsExpr GhcTc]) = traverseAstConditionally match noWhereClauseExpansion + concat <$> mapM (isBadExpr rules) exprs +getBadFnCalls _ _ = pure [] +-- Do not expand sequelize `where` clause further noWhereClauseExpansion :: LHsExpr GhcTc -> Bool noWhereClauseExpansion expr = case expr of - (L loc (HsWrap _ _ expr)) -> noWhereClauseExpansion (L loc expr) - (L _ (ExplicitList (TyConApp ty _) _ _)) -> showS ty == "Clause" + (L loc (PatHsWrap _ expr)) -> noWhereClauseExpansion (L loc expr) + (L _ (PatExplicitList (TyConApp ty _) _)) -> showS ty == "Clause" _ -> False -isBadFunApp :: Rules -> PluginOpts -> LHsExpr GhcTc -> TcM [(LHsExpr GhcTc, Violation)] -isBadFunApp rules opts ap@(L _ (HsVar _ v)) = isBadFunAppHelper rules opts ap -isBadFunApp rules opts ap@(L _ (HsApp _ funl funr)) = isBadFunAppHelper rules opts ap -isBadFunApp rules opts ap@(L loc (HsWrap _ _ expr)) = isBadFunApp rules opts (L loc expr) >>= mapM (\(x, y) -> trfViolationErrorInfo opts y ap x >>= \z -> pure (x, z)) -isBadFunApp rules opts ap@(L _ (ExplicitList _ _ _)) = isBadFunAppHelper rules opts ap -isBadFunApp rules opts ap@(L loc (OpApp _ lfun op rfun)) = do - case showS op of - "($)" -> isBadFunAppHelper rules opts (L loc (HsApp noExtField lfun rfun)) >>= mapM (\(x, y) -> trfViolationErrorInfo opts y ap x >>= \z -> pure (x, z)) - _ -> pure [] -isBadFunApp _ _ _ = pure [] +-- Takes a function name which should not be expanded further while traversing AST +noGivenFunctionCallExpansion :: (HasPluginOpts a) => String -> LHsExpr GhcTc -> Bool +noGivenFunctionCallExpansion fnName expr = case expr of + (L loc (PatHsWrap _ expr)) -> noGivenFunctionCallExpansion fnName (L loc expr) + _ -> case getFnNameWithAllArgs expr of + Just (lVar, _) -> matchNamesWithModuleName (getLocatedVarNameWithModuleName lVar) fnName AsteriskInSecond -- (getOccString . varName . unLoc $ lVar) == fnName + Nothing -> False -isBadFunAppHelper :: Rules -> PluginOpts -> LHsExpr GhcTc -> TcM [(LHsExpr GhcTc, Violation)] -isBadFunAppHelper rules opts ap = concat <$> mapM (\rule -> checkAndApplyRule rule opts ap) rules +-- Simplifies few things and handles some final transformations +isBadExpr :: (HasPluginOpts PluginOpts) => Rules -> LHsExpr GhcTc -> TcM [(LHsExpr GhcTc, Violation)] +isBadExpr rules ap@(L _ (HsVar _ v)) = isBadExprHelper rules ap +isBadExpr rules ap@(L _ (HsApp _ funl funr)) = isBadExprHelper rules ap +isBadExpr rules ap@(L _ (PatExplicitList _ _)) = isBadExprHelper rules ap +isBadExpr rules ap@(L loc (PatHsWrap _ expr)) = isBadExpr rules (L loc expr) >>= mapM (\(x, y) -> trfViolationErrorInfo y ap x >>= \z -> pure (x, z)) +isBadExpr rules ap@(L loc (OpApp _ lfun op rfun)) = do + case showS op of + "($)" -> isBadExpr rules (L loc (HsApp noExtFieldOrAnn lfun rfun)) >>= mapM (\(x, y) -> trfViolationErrorInfo y ap x >>= \z -> pure (x, z)) + _ -> isBadExprHelper rules ap +#if __GLASGOW_HASKELL__ >= 900 +isBadExpr rules ap@(L loc (PatHsExpansion orig expanded)) = do + case (orig, expanded) of + ((OpApp _ _ op _), (HsApp _ (L _ (HsApp _ op' funl)) funr)) -> case showS op of + "($)" -> isBadExpr rules (L loc (HsApp noExtFieldOrAnn funl funr)) >>= mapM (\(x, y) -> trfViolationErrorInfo y ap x >>= \z -> pure (x, z)) + _ -> isBadExpr rules (L loc expanded) + _ -> isBadExpr rules (L loc expanded) +#endif +isBadExpr rules ap = pure [] + +-- Calls checkAndApplyRule, can be used to directly call without simplifier if needed +isBadExprHelper :: (HasPluginOpts PluginOpts) => Rules -> LHsExpr GhcTc -> TcM [(LHsExpr GhcTc, Violation)] +isBadExprHelper rules ap = concat <$> mapM (\rule -> checkAndApplyRule rule ap) rules + +-- Check if a particular rule applies to given expr +checkAndApplyRule :: (HasPluginOpts PluginOpts) => Rule -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) +checkAndApplyRule ruleT ap = case ruleT of + DBRuleT rule@(DBRule {table_name = ruleTableName}) -> + case ap of + (L _ (PatExplicitList (TyConApp ty [_, tblName]) exprs)) -> do + case (showS ty == "Clause" && showS tblName == (ruleTableName <> "T")) of + True -> validateDBRule rule (showS tblName) exprs ap + False -> pure [] + _ -> pure [] + FunctionRuleT rule@(FunctionRule {fn_name = ruleFnNames, arg_no}) -> do + let res = getFnNameWithAllArgs ap + case res of + Nothing -> pure [] + Just (fnLocatedVar, args) -> do + let fnName = getLocatedVarNameWithModuleName fnLocatedVar + fnLHsExpr = mkLHsVar fnLocatedVar + case (find (\ruleFnName -> matchNamesWithModuleName fnName ruleFnName AsteriskInSecond && length args >= arg_no) ruleFnNames) of + Just ruleFnName -> validateFunctionRule rule ruleFnName fnName fnLHsExpr args ap + Nothing -> pure [] + InfiniteRecursionRuleT rule -> pure [] --TODO: Add handling of infinite recursion rule + GeneralRuleT rule -> pure [] --TODO: Add handling of general rule -validateFunctionRule :: FunctionRule -> PluginOpts -> String -> [LHsExpr GhcTc] -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) -validateFunctionRule rule opts fnName args expr = do - if (arg_no rule) == 0 -- considering arg 0 as the case for blocking the whole function occurence - then pure [(expr, FnUseBlocked rule)] +--------------------------- Function Rule Validation Logic --------------------------- +{- + +Part-1 Checking Applicability + 1. Get function name and arguments list + 2. Check if function name matches with required name + 3. Check if function has more than or equal number of arguments than required as per rule + +Part-2 Validation + 1. Check if argument number in rule is 0, then the use of function is not allowed in code. + 2. Extract the required argument as per rule from the argument list. + 3. Get the type of required argument + 4. Check if argument type is in the blocked types list as per rule, then the use of this argument type is not allowed + 5. Check if argument type is in the to_be_checked types list as per rule, then check for function type blocked + 5.1 Extract the list of all function application inside the argument + 5.2 For each function application, get the function name and list of arguments + 5.3 For each function name match, if the type of required argument is not in exception list, then it is a Function Blocked in argument violation. + +-} + +-- Function to check if given function rule is violated or not +validateFunctionRule :: (HasPluginOpts PluginOpts) => FunctionRule -> String -> String -> LHsExpr GhcTc -> [LHsExpr GhcTc] -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) +validateFunctionRule rule ruleFnName fnName fnNameExpr args expr = do + if arg_no rule == 0 && fn_sigs_blocked rule == [] -- considering arg 0 as the case for blocking the whole function occurence + then pure [(fnNameExpr, FnUseBlocked ruleFnName rule)] + else if arg_no rule == 0 + then do + -- Check argument types for functions with polymorphic signature + argTyps <- concat <$> mapM (\arg -> getHsExprTypeAsTypeDataList <$> getHsExprTypeWithResolver (logTypeDebugging . pluginOpts $ ?pluginOpts) arg) args + fnReturnType <- getHsExprTypeAsTypeDataList <$> getHsExprTypeWithResolver (logTypeDebugging . pluginOpts $ ?pluginOpts) expr + let fnSigFromArg = argTyps <> fnReturnType + + -- Given function signature + fnExprTyp <- getHsExprTypeWithResolver (logTypeDebugging . pluginOpts $ ?pluginOpts) fnNameExpr + let fnSigTypList = getHsExprTypeAsTypeDataList fnExprTyp + + pure . concat $ fmap (\ruleFnSig -> if matchFnSignatures fnSigTypList ruleFnSig || matchFnSignatures fnSigFromArg ruleFnSig then [(fnNameExpr, FnSigBlocked fnName ruleFnSig rule)] else []) (fn_sigs_blocked rule) else do let matches = drop ((arg_no rule) - 1) args if length matches == 0 then pure [] else do let arg = head matches - argTypeGhc <- getHsExprTypeWithResolver opts arg + argTypeGhc <- getHsExprTypeWithResolver (logTypeDebugging . pluginOpts $ ?pluginOpts) arg let argType = showS argTypeGhc argTypeBlocked = validateType argTypeGhc $ types_blocked_in_arg rule isArgTypeToCheck = validateType argTypeGhc $ types_to_check_in_arg rule - when (logDebugInfo opts && fnName /= "NA") $ + when ((logDebugInfo . pluginOpts $ ?pluginOpts) && fnName /= "NA") $ liftIO $ do print $ (fnName, map showS args) print $ (fnName, showS arg) - print $ rule + print $ fn_rule_name rule print $ "Arg Type = " <> argType if argTypeBlocked then do - exprType <- getHsExprTypeWithResolver opts expr - pure [(expr, ArgTypeBlocked argType (showS exprType) rule)] - else if not isArgTypeToCheck - then pure [] - else do - -- It's a rule function with to_be_checked type argument - blockedFnsList <- getBlockedFnsList opts arg rule -- check if the expression has any stringification function - mapM (\(lExpr, blockedFnName, blockedFnArgTyp) -> mkFnBlockedInArgErrorInfo opts expr lExpr >>= \errorInfo -> pure (lExpr, FnBlockedInArg (blockedFnName, blockedFnArgTyp) errorInfo rule)) blockedFnsList + exprType <- getHsExprTypeWithResolver (logTypeDebugging . pluginOpts $ ?pluginOpts) expr + pure [(expr, ArgTypeBlocked argType (showS exprType) ruleFnName rule)] + else if isArgTypeToCheck + then do + blockedFnsList <- getBlockedFnsList arg rule -- check if the expression has any stringification function + mapM (\(lExpr, blockedFnName, blockedFnArgTyp) -> mkFnBlockedInArgErrorInfo expr lExpr >>= \errorInfo -> pure (lExpr, FnBlockedInArg (blockedFnName, blockedFnArgTyp) ruleFnName errorInfo rule)) blockedFnsList + else pure [] +-- Helper to validate types based on custom types present in the rules -- tuples, list, maybe validateType :: Type -> TypesToCheckInArg -> Bool validateType argTyp@(TyConApp tyCon ls) typs = - if showS tyCon == "(,)" && "(,)" `elem` typs - then (\t -> validateType t typs) `any` ls - else if showS tyCon == "[]" && "[]" `elem` typs - then (\t -> validateType t typs) `any` ls - else if showS tyCon == "Maybe" && "Maybe" `elem` typs - then (\t -> validateType t typs) `any` ls + let tyConStr = showS tyCon in + if tyConStr `elem` typs + then case tyConStr of + "(,)" -> (\t -> validateType t typs) `any` ls + "[]" -> (\t -> validateType t typs) `any` ls + "Maybe" -> (\t -> validateType t typs) `any` ls + _ -> showS argTyp `elem` typs else showS argTyp `elem` typs validateType argTyp typs = showS argTyp `elem` typs -validateDBRule :: DBRule -> PluginOpts -> String -> [LHsExpr GhcTc] -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) -validateDBRule rule@(DBRule ruleName ruleTableName ruleColNames _ _) opts tableName clauses expr = do - simplifiedExprs <- trfWhereToSOP opts clauses - let checkDBViolation = case (matchAllInsideAnd opts) of +-- Get List of blocked functions used inside a HsExpr; Uses `getBlockedFnsList` +getBlockedFnsList :: (HasPluginOpts PluginOpts) => LHsExpr GhcTc -> FunctionRule -> TcM [(LHsExpr GhcTc, String, String)] +getBlockedFnsList arg rule@(FunctionRule { arg_no, fns_blocked_in_arg = fnsBlocked }) = do + let argHsExprs = traverseAst arg :: [LHsExpr GhcTc] + fnApps = filter (isFunApp False) argHsExprs + when ((logDebugInfo . pluginOpts $ ?pluginOpts)) $ liftIO $ do + print "getBlockedFnsList" + showOutputable arg + showOutputable fnApps + catMaybes <$> mapM checkFnBlockedInArg fnApps + where + checkFnBlockedInArg :: LHsExpr GhcTc -> TcM (Maybe (LHsExpr GhcTc, String, String)) + checkFnBlockedInArg expr = do + let res = getFnNameWithAllArgs expr + when ((logDebugInfo . pluginOpts $ ?pluginOpts)) $ liftIO $ do + print "checkFnBlockedInArg" + showOutputable res + case res of + Nothing -> pure Nothing + Just (fnNameVar, args) -> isPresentInBlockedFnList expr fnsBlocked (getLocatedVarNameWithModuleName fnNameVar) args + + isPresentInBlockedFnList :: LHsExpr GhcTc -> FnsBlockedInArg -> String -> [LHsExpr GhcTc] -> TcM (Maybe (LHsExpr GhcTc, String, String)) + isPresentInBlockedFnList expr [] _ _ = pure Nothing + isPresentInBlockedFnList expr ((ruleFnName, ruleArgNo, ruleAllowedTypes) : ls) fnName fnArgs = do + when ((logDebugInfo . pluginOpts $ ?pluginOpts)) $ liftIO $ do + print "isPresentInBlockedFnList" + print (ruleFnName, ruleArgNo, ruleAllowedTypes) + case matchNamesWithModuleName fnName ruleFnName AsteriskInSecond && length fnArgs >= ruleArgNo of + False -> isPresentInBlockedFnList expr ls fnName fnArgs + True -> do + let reqArg = head $ drop (ruleArgNo - 1) fnArgs + argType <- getHsExprType (logTypeDebugging . pluginOpts $ ?pluginOpts) reqArg + when ((logDebugInfo . pluginOpts $ ?pluginOpts)) $ liftIO $ do + showOutputable reqArg + showOutputable argType + if validateAllowedTypes argType ruleAllowedTypes + then isPresentInBlockedFnList expr ls fnName fnArgs + else pure $ Just (expr, fnName, showS argType) + + validateAllowedTypes :: Type -> TypesAllowedInArg -> Bool + validateAllowedTypes argType@(TyConApp tyCon ls) ruleAllowedTypes = + if showS tyCon == "(,)" && "(,)" `elem` ruleAllowedTypes + then (\t -> validateAllowedTypes t ruleAllowedTypes) `all` ls + else if showS tyCon == "[]" && "[]" `elem` ruleAllowedTypes + then (\t -> validateAllowedTypes t ruleAllowedTypes) `all` ls + else if showS tyCon == "Maybe" && "Maybe" `elem` ruleAllowedTypes + then (\t -> validateAllowedTypes t ruleAllowedTypes) `all` ls + else (isEnumType argType && "EnumTypes" `elem` ruleAllowedTypes) || (showS argType) `elem` ruleAllowedTypes + validateAllowedTypes argType ruleAllowedTypes = (isEnumType argType && "EnumTypes" `elem` ruleAllowedTypes) || (showS argType) `elem` ruleAllowedTypes + +--------------------------- DB Rule Validation Logic --------------------------- +{- + +Part-1 Checking Applicability + 1. Check if the expression is an explicit list, basically a hard coded list + 2. Extract the type of the explicit list and from it further extract actual type (i.e. `Clause`) and table name + 3. Check if extracted table name matches the given table name + +Part-2 Validation + 1. Simplify the given list as SOP form (OR of AND) + 1.1 For each non breakable clause (i.e. `Se.Is`), get the field name and table name + 1.1.1 Get DB field specifier, i.e. how field is being extracted - Lens, Selector, RecordDot + 1.1.2 Extract data according to the way we have field written + 1.2 For each OR, cross product the result of current list and remaining result, creating individual lists for each clause in current list + 1.3 For each AND, cross product the list of current list and remaining result, simplifying each element of the current list + 2. Check whether we want to match only 1st column in composite key or all columns of composite key + 3. For each AND clause in SOP, perform 4 or 5, if all AND clause are indexed, then only this query is indexed + 4. Case when we want to check only 1st column of composite key + 3.1 If any of the field in AND clause does not violate the rule, then we mark this AND clause as Indexed + 3.2 For each field of AND clause, perform following for all rule keys,: + 3.2.1 If it is a composite key, we check if the 1st column of composite key matches the current field, then consider this current field as indexed + 3.2.2 If it is non-composite key, then we directly compare and if it matches, then consider this current field as indexed + 5. Case when we want to check all columns of composite key + 5.1 If any of the field in AND clause violate the rule, then we mark this AND clause as Non-Indexed + 5.2 For each field of AND clause, perform following for all rule keys,: + 3.2.1 If it is a composite key, and if it current column of composite key and it matches the current field of AND, then consider this current field as indexed + 3.2.2 If it is a composite key, and if it does not match, then if it is present in overall AND clause fields, then we skip this and check next composite key column. + 3.2.2 If it is non-composite key, then we directly compare and if it matches, then consider this current field as indexed + +-} +-- Function to check if given DB rules is violated or not +-- TODO: Fix this, keep two separate options for - 1. Match All Fields in AND 2. Use 1st column matching or all columns matching for composite key +validateDBRule :: (HasPluginOpts PluginOpts) => DBRule -> String -> [LHsExpr GhcTc] -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) +validateDBRule rule@(DBRule {db_rule_name = ruleName, table_name = ruleTableName, indexed_cols_names = ruleColNames}) tableName clauses expr = do + simplifiedExprs <- trfWhereToSOP clauses + let checkDBViolation = case (matchAllInsideAnd . pluginOpts $ ?pluginOpts) of True -> checkDBViolationMatchAll False -> checkDBViolationWithoutMatchAll violations <- catMaybes <$> mapM checkDBViolation simplifiedExprs @@ -396,77 +701,85 @@ doesMatchColNameInDbRule colName (key : keys) = type SimplifiedIsClause = (LHsExpr GhcTc, String, String) -trfWhereToSOP :: PluginOpts -> [LHsExpr GhcTc] -> TcM [[SimplifiedIsClause]] -trfWhereToSOP _ [] = pure [[]] -trfWhereToSOP opts (clause : ls) = do +-- Simplify the complex `where` clause of SQL queries as OR queries at top (i.e. ((C1 and C2 and C3) OR (C1 AND C5) OR (C6))) +trfWhereToSOP :: (HasPluginOpts PluginOpts) => [LHsExpr GhcTc] -> TcM [[SimplifiedIsClause]] +trfWhereToSOP [] = pure [[]] +trfWhereToSOP (clause : ls) = do let res = getWhereClauseFnNameWithAllArgs clause (fnName, args) = fromMaybe ("NA", []) res case (fnName, args) of - ("And", [(L _ (ExplicitList _ _ arg))]) -> do - curr <- trfWhereToSOP opts arg - rem <- trfWhereToSOP opts ls + ("And", [(L _ (PatExplicitList _ arg))]) -> do + curr <- trfWhereToSOP arg + rem <- trfWhereToSOP ls pure [x <> y | x <- curr, y <- rem] - ("Or", [(L _ (ExplicitList _ _ arg))]) -> do - curr <- foldM (\r cls -> fmap (<> r) $ trfWhereToSOP opts [cls]) [] arg - rem <- trfWhereToSOP opts ls + ("Or", [(L _ (PatExplicitList _ arg))]) -> do + curr <- foldM (\r cls -> fmap (<> r) $ trfWhereToSOP [cls]) [] arg + rem <- trfWhereToSOP ls pure [x <> y | x <- curr, y <- rem] ("$WIs", [arg1, arg2]) -> do - curr <- getIsClauseData opts arg1 arg2 clause - rem <- trfWhereToSOP opts ls + curr <- getIsClauseData arg1 arg2 clause + rem <- trfWhereToSOP ls case curr of Nothing -> pure rem Just (tblName, colName) -> pure $ fmap (\lst -> (clause, tblName, colName) : lst) rem - (fn, _) -> when (logWarnInfo opts) (liftIO $ print $ "Invalid/unknown clause in `where` clause : " <> fn <> " at " <> (showS . getLoc $ clause)) >> trfWhereToSOP opts ls + (fn, _) -> when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ print $ "Invalid/unknown clause in `where` clause : " <> fn <> " at " <> (showS . getLoc2 $ clause)) >> trfWhereToSOP ls -getIsClauseData :: PluginOpts -> LHsExpr GhcTc -> LHsExpr GhcTc -> LHsExpr GhcTc -> TcM (Maybe (String, String)) -getIsClauseData opts fieldArg _comp _clause = do +-- Get table field name and table name for the `Se.Is` clause +-- Patterns to match 'getField`, `recordDot`, `overloadedRecordDot` (ghc > 9), selector (duplicate record fields), rec fields (ghc 9), lens +-- TODO: Refactor this to use HasField instance if possible +getIsClauseData :: (HasPluginOpts PluginOpts) => LHsExpr GhcTc -> LHsExpr GhcTc -> LHsExpr GhcTc -> TcM (Maybe (String, String)) +getIsClauseData fieldArg _comp _clause = do let fieldSpecType = getDBFieldSpecType fieldArg mbColNameAndTableName <- case fieldSpecType of - None -> when (logWarnInfo opts) (liftIO $ print "Can't identify the way in which DB field is specified") >> pure Nothing + None -> when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ print $ "Can't identify the way in which DB field is specified: " <> showS fieldArg) >> pure Nothing Selector -> do - case (splitOn ":" $ showS fieldArg) of + let modFieldArg arg = case arg of + (L _ (HsRecFld _ fldOcc)) -> showS $ selectorAmbiguousFieldOcc fldOcc + (L loc (PatHsWrap _ wExpr)) -> modFieldArg (L loc wExpr) + (L _ expr) -> showS expr + case (splitOn ":" $ modFieldArg fieldArg) of ("$sel" : colName : tableName : []) -> pure $ Just (colName, tableName) - _ -> when (logWarnInfo opts) (liftIO $ print "Invalid pattern for Selector way") >> pure Nothing + _ -> when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ print "Invalid pattern for Selector way") >> pure Nothing RecordDot -> do let tyApps = filter (\x -> case x of (HsApp _ (L _ (HsAppType _ _ fldName)) tableVar) -> True - (HsWrap _ (WpCompose (WpEvApp (EvExpr _hasFld)) (WpCompose (WpTyApp _fldType) (WpTyApp tableVar))) (HsAppType _ _ fldName)) -> True + (PatHsWrap (WpCompose (WpEvApp (EvExpr _hasFld)) (WpCompose (WpTyApp _fldType) (WpTyApp tableVar))) (HsAppType _ _ fldName)) -> True _ -> False - ) $ (fieldArg ^? biplateRef :: [HsExpr GhcTc]) + ) $ (traverseAst fieldArg :: [HsExpr GhcTc]) if length tyApps > 0 then case head tyApps of (HsApp _ (L _ (HsAppType _ _ fldName)) tableVar) -> do - typ <- getHsExprType opts tableVar + typ <- getHsExprType (logTypeDebugging . pluginOpts $ ?pluginOpts) tableVar let tblName' = case typ of AppTy ty1 _ -> showS ty1 TyConApp ty1 _ -> showS ty1 ty -> showS ty pure $ Just (getStrFromHsWildCardBndrs fldName, take (length tblName' - 1) tblName') - (HsWrap _ (WpCompose (WpEvApp (EvExpr _hasFld)) (WpCompose (WpTyApp _fldType) (WpTyApp tableType))) (HsAppType _ _ fldName)) -> + (PatHsWrap (WpCompose (WpEvApp (EvExpr _hasFld)) (WpCompose (WpTyApp _fldType) (WpTyApp tableType))) (HsAppType _ _ fldName)) -> let tblName' = case tableType of AppTy ty1 _ -> showS ty1 TyConApp ty1 _ -> showS ty1 ty -> showS ty in pure $ Just (getStrFromHsWildCardBndrs fldName, take (length tblName' - 1) tblName') - _ -> when (logWarnInfo opts) (liftIO $ putStrLn "HsAppType not present. Should never be the case as we already filtered.") >> pure Nothing - else when (logWarnInfo opts) (liftIO $ putStrLn "HsAppType not present after filtering. Should never reach as already deduced RecordDot.") >> pure Nothing + _ -> when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ putStrLn "HsAppType not present. Should never be the case as we already filtered.") >> pure Nothing + else when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ putStrLn "HsAppType not present after filtering. Should never reach as already deduced RecordDot.") >> pure Nothing Lens -> do - let opApps = filter isLensOpApp (fieldArg ^? biplateRef :: [HsExpr GhcTc]) + let opApps = filter isLensOpApp (traverseAst fieldArg :: [HsExpr GhcTc]) case opApps of - [] -> when (logWarnInfo opts) (liftIO $ putStrLn "No lens operator application present in lens case.") >> pure Nothing + [] -> when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ putStrLn "No lens operator application present in lens case.") >> pure Nothing (opExpr : _) -> do case opExpr of (OpApp _ tableVar _ fldVar) -> do let fldName = tail $ showS fldVar - typ <- getHsExprType opts tableVar + typ <- getHsExprType (logTypeDebugging . pluginOpts $ ?pluginOpts) tableVar let tblName' = case typ of AppTy ty1 _ -> showS ty1 TyConApp ty1 _ -> showS ty1 ty -> showS ty pure $ Just (fldName, take (length tblName' - 1) tblName') (SectionR _ _ (L _ lens)) -> do - let tys = lens ^? biplateRef :: [Type] + let tys = traverseAst lens :: [Type] typeForTableName = filter (\typ -> case typ of (TyConApp typ1 [typ2]) -> ("T" `isSuffixOf` showS typ1) && (showS typ2 == "Columnar' f") (AppTy typ1 typ2) -> ("T" `isSuffixOf` showS typ1) && (showS typ2 == "Columnar' f") @@ -477,39 +790,38 @@ getIsClauseData opts fieldArg _comp _clause = do TyConApp ty1 _ -> showS ty1 ty -> showS ty pure $ Just (tail $ showS lens, take (length tblName' - 1) tblName') - _ -> when (logWarnInfo opts) (liftIO $ putStrLn "OpApp not present. Should never be the case as we already filtered.") >> pure Nothing +#if __GLASGOW_HASKELL__ >= 900 + (PatHsExpansion orig (HsApp _ (L _ (HsApp _ _ tableVar)) fldVar)) -> do + let fldName = tail $ showS fldVar + typ <- getHsExprType (logTypeDebugging . pluginOpts $ ?pluginOpts) tableVar + let tblName' = case typ of + AppTy ty1 _ -> showS ty1 + TyConApp ty1 _ -> showS ty1 + ty -> showS ty + pure $ Just (fldName, take (length tblName' - 1) tblName') +#endif + _ -> when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ putStrLn "OpApp not present. Should never be the case as we already filtered.") >> pure Nothing pure mbColNameAndTableName -checkAndApplyRule :: Rule -> PluginOpts -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) -checkAndApplyRule ruleT opts ap = case ruleT of - DBRuleT rule@(DBRule _ ruleTableName _ _ _) -> - case ap of - (L _ (ExplicitList (TyConApp ty [_, tblName]) _ exprs)) -> do - case (showS ty == "Clause" && showS tblName == (ruleTableName <> "T")) of - True -> validateDBRule rule opts (showS tblName) exprs ap - False -> pure [] - _ -> pure [] - FunctionRuleT rule@(FunctionRule _ ruleFnName arg_no _ _ _ _ _ _) -> do - let res = getFnNameWithAllArgs ap - -- let (fnName, args) = maybe ("NA", []) (\(x, y) -> ((nameStableString . varName . unLoc) x, y)) $ res - (fnName, args) = maybe ("NA", []) (\(x, y) -> ((getOccString . varName . unLoc) x, y)) $ res - case (fnName == ruleFnName && length args >= arg_no) of - True -> validateFunctionRule rule opts fnName args ap - False -> pure [] - GeneralRuleT rule -> pure [] --TODO: Add handling of general rule - +-- Get how DB field is being extracted in sequelize getDBFieldSpecType :: LHsExpr GhcTc -> DBFieldSpecType -getDBFieldSpecType (L _ expr) - | isPrefixOf "$sel" (showS expr) = Selector - | isInfixOf "^." (showS expr) = Lens - | (\x -> isInfixOf "@" x) (showS expr) = RecordDot - | otherwise = None +getDBFieldSpecType (L loc expr) + | (PatHsWrap _ wExpr) <- expr = getDBFieldSpecType (L loc wExpr) + | (HsRecFld _ fldOcc) <- expr = checkExprString . showS $ selectorAmbiguousFieldOcc fldOcc + | otherwise = checkExprString $ showS expr + where + checkExprString exprStr + | isPrefixOf "$sel" exprStr = Selector + | isInfixOf "^." exprStr = Lens + | (\x -> isInfixOf "@" x) exprStr = RecordDot + | otherwise = None +-- Get function name for the where clause for db rules cases getWhereClauseFnNameWithAllArgs :: LHsExpr GhcTc -> Maybe (String, [LHsExpr GhcTc]) -getWhereClauseFnNameWithAllArgs (L _ (HsVar _ v)) = Just (getVarName v, []) -getWhereClauseFnNameWithAllArgs (L _ (HsConLikeOut _ cl)) = (\clId -> (getVarName $ noLoc clId, [])) <$> conLikeWrapId_maybe cl -getWhereClauseFnNameWithAllArgs (L _ (HsApp _ (L _ (HsVar _ v)) funr)) = Just (getVarName v, [funr]) +getWhereClauseFnNameWithAllArgs (L _ (HsVar _ v)) = Just (getVarName $ unLoc v, []) +getWhereClauseFnNameWithAllArgs (L _ (HsConLikeOut _ cl)) = (\clId -> (getVarName clId, [])) <$> conLikeWrapId cl +getWhereClauseFnNameWithAllArgs (L _ (HsApp _ (L _ (HsVar _ v)) funr)) = Just (getVarName $ unLoc v, [funr]) getWhereClauseFnNameWithAllArgs (L _ (HsApp _ funl funr)) = do let res = getWhereClauseFnNameWithAllArgs funl case res of @@ -517,23 +829,60 @@ getWhereClauseFnNameWithAllArgs (L _ (HsApp _ funl funr)) = do Just (fnName, ls) -> Just (fnName, ls ++ [funr]) getWhereClauseFnNameWithAllArgs (L loc (OpApp _ lfun op rfun)) = do case showS op of - "($)" -> getWhereClauseFnNameWithAllArgs $ (L loc (HsApp noExtField lfun rfun)) + "($)" -> getWhereClauseFnNameWithAllArgs $ (L loc (HsApp noExtFieldOrAnn lfun rfun)) _ -> Nothing getWhereClauseFnNameWithAllArgs (L loc ap@(HsPar _ expr)) = getWhereClauseFnNameWithAllArgs expr -- If condition inside the list, add dummy type -getWhereClauseFnNameWithAllArgs (L loc ap@(HsIf _ _ _pred thenCl elseCl)) = Just ("Or", [L loc (ExplicitList (LitTy (StrTyLit "Dummy")) Nothing [thenCl, elseCl])]) -getWhereClauseFnNameWithAllArgs (L loc ap@(HsWrap _ _ expr)) = getWhereClauseFnNameWithAllArgs (L loc expr) +getWhereClauseFnNameWithAllArgs (L loc ap@(PatHsIf _pred thenCl elseCl)) = Just ("Or", [L loc (PatExplicitList (LitTy (StrTyLit "Dummy")) [thenCl, elseCl])]) +getWhereClauseFnNameWithAllArgs (L loc ap@(PatHsWrap _ expr)) = getWhereClauseFnNameWithAllArgs (L loc expr) +#if __GLASGOW_HASKELL__ >= 900 +getWhereClauseFnNameWithAllArgs (L loc ap@(PatHsExpansion orig expanded)) = + case (orig, expanded) of + ((OpApp _ _ op _), (HsApp _ (L _ (HsApp _ op' funl)) funr)) -> case showS op of + "($)" -> getWhereClauseFnNameWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) + _ -> getWhereClauseFnNameWithAllArgs (L loc expanded) + _ -> getWhereClauseFnNameWithAllArgs (L loc expanded) +#endif getWhereClauseFnNameWithAllArgs (L loc ap@(ExprWithTySig _ expr _)) = getWhereClauseFnNameWithAllArgs expr getWhereClauseFnNameWithAllArgs _ = Nothing -getVarName :: Located Var -> String -getVarName var = (getOccString . varName . unLoc) var - +-- TODO: Verify the correctness of this function before moving it to utils +-- Get function name & LHsExpr which gives resolved type with all it's arguments +getFnNameAndTypeableExprWithAllArgs :: LHsExpr GhcTc -> Maybe (Located Var, LHsExpr GhcTc, [LHsExpr GhcTc]) +getFnNameAndTypeableExprWithAllArgs ap@(L loc (HsVar _ v)) = Just (getLocated v loc, ap, []) +getFnNameAndTypeableExprWithAllArgs ap@(L _ (HsConLikeOut _ cl)) = (\clId -> (noExprLoc clId, ap, [])) <$> conLikeWrapId cl +getFnNameAndTypeableExprWithAllArgs (L _ (HsAppType _ expr _)) = getFnNameAndTypeableExprWithAllArgs expr +getFnNameAndTypeableExprWithAllArgs (L _ (HsApp _ ap@(L loc (HsVar _ v)) funr)) = Just (getLocated v loc, ap, [funr]) +getFnNameAndTypeableExprWithAllArgs (L _ (HsPar _ expr)) = getFnNameAndTypeableExprWithAllArgs expr +getFnNameAndTypeableExprWithAllArgs (L _ (HsApp _ funl funr)) = do + let res = getFnNameAndTypeableExprWithAllArgs funl + case res of + Nothing -> Nothing + Just (fnName, typLHsExpr, ls) -> Just (fnName, typLHsExpr, ls ++ [funr]) +getFnNameAndTypeableExprWithAllArgs (L loc (OpApp _ funl op funr)) = do + case showS op of + "($)" -> getFnNameAndTypeableExprWithAllArgs $ (L loc (HsApp noExtFieldOrAnn funl funr)) + _ -> Nothing +getFnNameAndTypeableExprWithAllArgs ap@(L loc (PatHsWrap _ (HsVar _ v))) = Just (getLocated v loc, ap, []) +getFnNameAndTypeableExprWithAllArgs (L loc ap@(PatHsWrap _ expr)) = getFnNameAndTypeableExprWithAllArgs (L loc expr) +#if __GLASGOW_HASKELL__ >= 900 +getFnNameAndTypeableExprWithAllArgs (L loc ap@(PatHsExpansion orig expanded)) = + case (orig, expanded) of + ((OpApp _ _ op _), (HsApp _ (L _ (HsApp _ op' funl)) funr)) -> case showS op of + "($)" -> getFnNameAndTypeableExprWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) + _ -> getFnNameAndTypeableExprWithAllArgs (L loc expanded) + _ -> getFnNameAndTypeableExprWithAllArgs (L loc expanded) +#endif +getFnNameAndTypeableExprWithAllArgs _ = Nothing + + +-- TODO: Verify the correctness of this function before moving it to utils +-- Get function name with all it's arguments getFnNameWithAllArgs :: LHsExpr GhcTc -> Maybe (Located Var, [LHsExpr GhcTc]) -getFnNameWithAllArgs (L _ (HsVar _ v)) = Just (v, []) -getFnNameWithAllArgs (L _ (HsConLikeOut _ cl)) = (\clId -> (noLoc clId, [])) <$> conLikeWrapId_maybe cl +getFnNameWithAllArgs (L loc (HsVar _ v)) = Just (getLocated v loc, []) +getFnNameWithAllArgs (L _ (HsConLikeOut _ cl)) = (\clId -> (noExprLoc clId, [])) <$> conLikeWrapId cl getFnNameWithAllArgs (L _ (HsAppType _ expr _)) = getFnNameWithAllArgs expr -getFnNameWithAllArgs (L _ (HsApp _ (L _ (HsVar _ v)) funr)) = Just (v, [funr]) +getFnNameWithAllArgs (L _ (HsApp _ (L loc (HsVar _ v)) funr)) = Just (getLocated v loc, [funr]) getFnNameWithAllArgs (L _ (HsPar _ expr)) = getFnNameWithAllArgs expr getFnNameWithAllArgs (L _ (HsApp _ funl funr)) = do let res = getFnNameWithAllArgs funl @@ -542,58 +891,66 @@ getFnNameWithAllArgs (L _ (HsApp _ funl funr)) = do Just (fnName, ls) -> Just (fnName, ls ++ [funr]) getFnNameWithAllArgs (L loc (OpApp _ funl op funr)) = do case showS op of - "($)" -> getFnNameWithAllArgs $ (L loc (HsApp noExtField funl funr)) + "($)" -> getFnNameWithAllArgs $ (L loc (HsApp noExtFieldOrAnn funl funr)) _ -> Nothing -getFnNameWithAllArgs (L loc ap@(HsWrap _ _ expr)) = do - getFnNameWithAllArgs (L loc expr) +getFnNameWithAllArgs (L loc ap@(PatHsWrap _ expr)) = getFnNameWithAllArgs (L loc expr) +#if __GLASGOW_HASKELL__ >= 900 +getFnNameWithAllArgs (L loc ap@(PatHsExpansion orig expanded)) = + case (orig, expanded) of + ((OpApp _ _ op _), (HsApp _ (L _ (HsApp _ op' funl)) funr)) -> case showS op of + "($)" -> getFnNameWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) + _ -> getFnNameWithAllArgs (L loc expanded) + _ -> getFnNameWithAllArgs (L loc expanded) +#endif getFnNameWithAllArgs _ = Nothing ---------------------------- Utils --------------------------- -isAllowedOnCurrentModule :: String -> Rule -> Bool -isAllowedOnCurrentModule moduleName rule = - let ignoreModules = getRuleIgnoreModules rule - in moduleName `notElem` ignoreModules - -hasAny :: Eq a => [a] -- ^ List of elements to look for - -> [a] -- ^ List to search - -> Bool -- ^ Result -hasAny [] _ = False -- An empty search list: always false -hasAny _ [] = False -- An empty list to scan: always false -hasAny search (x:xs) = if x `elem` search then True else hasAny search xs +--------------------------- Sheriff Plugin Utils --------------------------- +-- Transform the FnBlockedInArg Violation with correct expression +trfViolationErrorInfo :: (HasPluginOpts PluginOpts) => Violation -> LHsExpr GhcTc -> LHsExpr GhcTc -> TcM Violation +trfViolationErrorInfo violation@(FnBlockedInArg p1 ruleFnName _ rule) outsideExpr insideExpr = do + errorInfo <- mkFnBlockedInArgErrorInfo outsideExpr insideExpr + pure $ FnBlockedInArg p1 ruleFnName errorInfo rule +trfViolationErrorInfo violation _ _ = pure violation --- Check if HsExpr is Function Application -isFunApp :: LHsExpr GhcTc -> Bool -isFunApp (L _ (HsApp _ _ _)) = True -isFunApp (L _ (OpApp _ funl op funr)) = True -isFunApp _ = False - --- Check if HsExpr is Lens operator application -isLensOpApp :: HsExpr GhcTc -> Bool -isLensOpApp (OpApp _ _ op _) = showS op == "(^.)" -isLensOpApp (SectionR _ op _) = showS op == "(^.)" -isLensOpApp _ = False - --- If the type is literal type, get the string name of the literal, else return the showS verison of the type -getStrFromHsWildCardBndrs :: HsWildCardBndrs (NoGhcTc GhcTc) (LHsType (NoGhcTc GhcTc)) -> String -getStrFromHsWildCardBndrs (HsWC _ (L _ (HsTyLit _ (HsStrTy _ fs)))) = unpackFS fs -getStrFromHsWildCardBndrs typ = showS typ - --- Check if a Var is fun type -isFunVar :: Var -> Bool -isFunVar = isFunTy . dropForAlls . idType - --- Check if a Type is Enum type -isEnumType :: Type -> Bool -isEnumType (TyConApp tyCon _) = isEnumerationTyCon tyCon -isEnumType _ = False - --- Pretty print the Internal Representations -showOutputable :: (MonadIO m, Outputable a) => a -> m () -showOutputable = liftIO . putStrLn . showS +-- Create Error Info for FnBlockedInArg Violation +mkFnBlockedInArgErrorInfo :: (HasPluginOpts PluginOpts) => LHsExpr GhcTc -> LHsExpr GhcTc -> TcM Value +mkFnBlockedInArgErrorInfo lOutsideExpr@(L _ outsideExpr) lInsideExpr@(L _ insideExpr) = do + let loc1 = getLoc2 lOutsideExpr + loc2 = getLoc2 lInsideExpr + filePath <- unpackFS . srcSpanFile . tcg_top_loc . env_gbl <$> getEnv + let overall_src_span = showS loc1 + overall_err_line_orig = showS lOutsideExpr + err_fn_src_span = showS loc2 + err_fn_err_line_orig = showS lInsideExpr + overall_err_line <- + if (useIOForSourceCode . pluginOpts $ ?pluginOpts) + then liftIO $ extractSrcSpanSegment loc1 filePath overall_err_line_orig + else pure overall_err_line_orig + err_fn_err_line <- + if (useIOForSourceCode . pluginOpts $ ?pluginOpts) + then liftIO $ extractSrcSpanSegment loc2 filePath err_fn_err_line_orig + else pure err_fn_err_line_orig + pure $ A.object [ + ("overall_src_span", A.toJSON overall_src_span), + ("overall_err_line", A.toJSON overall_err_line), + ("err_fn_src_span", A.toJSON err_fn_src_span), + ("err_fn_err_line", A.toJSON err_fn_err_line) + ] --- Print the AST -printAst :: (MonadIO m, Data a) => a -> m () -printAst = liftIO . putStrLn . showAst +-- Check if a rule is allowed on current module +isAllowedOnCurrentModule :: String -> Rule -> Bool +isAllowedOnCurrentModule moduleName rule = + let ignoredModules = getRuleIgnoreModules rule + allowedModules = getRuleCheckModules rule + isCurrentModuleAllowed = any (matchNamesWithAsterisk AsteriskInBoth moduleName) allowedModules + isCurrentModuleIgnored = any (matchNamesWithAsterisk AsteriskInBoth moduleName) ignoredModules + in isCurrentModuleAllowed && not isCurrentModuleIgnored + +-- Check if a rule is allowed on current Function +isAllowedOnCurrentFunction :: String -> Rule -> Bool +isAllowedOnCurrentFunction currentFnNameWithModule rule = + let ignoredFunctions = getRuleIgnoreFunctions rule + in not $ any (matchNamesWithAsterisk AsteriskInSecond currentFnNameWithModule) ignoredFunctions -- Create GHC compilation error from CompileError mkGhcCompileError :: CompileError -> (SrcSpan, OP.SDoc) @@ -616,38 +973,52 @@ mkInvalidYamlFileErr err = OP.text err -- Create Internal Representation of Logging Error mkCompileError :: String -> (LHsExpr GhcTc, Violation) -> TcM CompileError -mkCompileError modName (expr, violation) = pure $ CompileError "" modName (show violation) (getLoc expr) violation (getViolationSuggestions violation) (getErrorInfoFromViolation violation) +mkCompileError modName (expr, violation) = pure $ CompileError "" modName (show violation) (getLoc2 expr) violation (getViolationSuggestions violation) (getErrorInfoFromViolation violation) + +-- Add GHC error to a file +addErrToFile :: ModSummary -> String -> [CompileError] -> TcM () +addErrToFile modSummary path errs = do + let moduleName' = moduleNameString $ moduleName $ ms_mod modSummary + res = encodePretty errs + liftIO $ createDirectoryIfMissing True path + liftIO $ writeFile (path <> moduleName' <> "_compilationErrors.json") res --- Create Error Info for FnBlockedInArg Violation -mkFnBlockedInArgErrorInfo :: PluginOpts -> LHsExpr GhcTc -> LHsExpr GhcTc -> TcM Value -mkFnBlockedInArgErrorInfo opts lOutsideExpr@(L loc1 outsideExpr) lInsideExpr@(L loc2 insideExpr) = do - filePath <- unpackFS . srcSpanFile . tcg_top_loc . env_gbl <$> getEnv - let overall_src_span = showS loc1 - overall_err_line_orig = showS lOutsideExpr - err_fn_src_span = showS loc2 - err_fn_err_line_orig = showS lInsideExpr - overall_err_line <- - if useIOForSourceCode opts - then liftIO $ extractSrcSpanSegment loc1 filePath overall_err_line_orig - else pure overall_err_line_orig - err_fn_err_line <- - if useIOForSourceCode opts - then liftIO $ extractSrcSpanSegment loc2 filePath err_fn_err_line_orig - else pure err_fn_err_line_orig - pure $ A.object [ - ("overall_src_span", A.toJSON overall_src_span), - ("overall_err_line", A.toJSON overall_err_line), - ("err_fn_src_span", A.toJSON err_fn_src_span), - ("err_fn_err_line", A.toJSON err_fn_err_line) - ] +-- TODO: Verify the correctness of function +-- Check if HsExpr is HsVar which can be simple variable or function application +isHsVar :: LHsExpr GhcTc -> Bool +isHsVar (L _ (HsVar _ _)) = True +isHsVar _ = False --- Transform the FnBlockedInArg Violation with correct expression -trfViolationErrorInfo :: PluginOpts -> Violation -> LHsExpr GhcTc -> LHsExpr GhcTc -> TcM Violation -trfViolationErrorInfo opts violation@(FnBlockedInArg p1 _ rule) outsideExpr insideExpr = do - errorInfo <- mkFnBlockedInArgErrorInfo opts outsideExpr insideExpr - pure $ FnBlockedInArg p1 errorInfo rule -trfViolationErrorInfo _ violation _ _ = pure violation +-- TODO: Verify the correctness of function +-- Check if HsExpr is Function Application +isFunApp :: Bool -> LHsExpr GhcTc -> Bool +isFunApp hsVarAsFunApp (L _ (HsVar _ _)) = hsVarAsFunApp +isFunApp _ (L _ (HsApp _ _ _)) = True +isFunApp _ (L _ (OpApp _ funl op funr)) = True +isFunApp hsVarAsFunApp (L loc (PatHsWrap _ expr)) = isFunApp hsVarAsFunApp (L loc expr) +#if __GLASGOW_HASKELL__ >= 900 +isFunApp _ (L _ (PatHsExpansion orig expanded)) = + case orig of + (OpApp{}) -> True + _ -> False +#endif +isFunApp _ _ = False +-- Check if HsExpr is Lens operator application +isLensOpApp :: HsExpr GhcTc -> Bool +isLensOpApp (OpApp _ _ op _) = showS op == "(^.)" +isLensOpApp (SectionR _ op _) = showS op == "(^.)" +#if __GLASGOW_HASKELL__ >= 900 +isLensOpApp (PatHsExpansion (OpApp _ _ op _) expanded) = showS op == "(^.)" +#endif +isLensOpApp _ = False + +-- If the type is literal type, get the string name of the literal, else return the showS verison of the type +getStrFromHsWildCardBndrs :: HsWildCardBndrs (NoGhcTc GhcTc) (LHsType (NoGhcTc GhcTc)) -> String +getStrFromHsWildCardBndrs (HsWC _ (L _ (HsTyLit _ (HsStrTy _ fs)))) = unpackFS fs +getStrFromHsWildCardBndrs typ = showS typ + +-- -------------------------------- DEPRECATED CODE (Might be useful for some other use cases or some other plugin) -------------------------------- -- -- [DEPRECATED] Get Return type of the function application arg getArgTypeWrapper :: LHsExpr GhcTc -> [Type] getArgTypeWrapper expr@(L _ (HsApp _ lfun rfun)) = getArgType expr True @@ -657,7 +1028,7 @@ getArgTypeWrapper expr@(L _ (OpApp _ lfun op rfun)) = "(.)" -> getArgTypeWrapper lfun "(<>)" -> getArgTypeWrapper lfun _ -> getArgType op True -getArgTypeWrapper (L loc (HsWrap _ _ expr)) = getArgTypeWrapper (L loc expr) +getArgTypeWrapper (L loc (PatHsWrap _ expr)) = getArgTypeWrapper (L loc expr) getArgTypeWrapper (L loc (HsPar _ expr)) = getArgTypeWrapper expr getArgTypeWrapper expr = getArgType expr False @@ -665,10 +1036,10 @@ getArgTypeWrapper expr = getArgType expr False getArgType :: LHsExpr GhcTc -> Bool -> [Type] getArgType (L _ (HsLit _ v)) _ = getLitType v getArgType (L _ (HsOverLit _ (OverLit (OverLitTc _ typ) v _))) _ = [typ] -getArgType (L loc (HsWrap _ _ expr)) shouldReturnFinalType = getArgType (L loc expr) shouldReturnFinalType +getArgType (L loc (PatHsWrap _ expr)) shouldReturnFinalType = getArgType (L loc expr) shouldReturnFinalType getArgType (L loc (HsApp _ lfun rfun)) shouldReturnFinalType = getArgType lfun shouldReturnFinalType getArgType arg shouldReturnFinalType = - let vars = filter (not . isSystemName . varName) $ arg ^? biplateRef in + let vars = filter (not . isSystemName . varName) $ traverseAst arg in if length vars == 0 then [] else @@ -678,29 +1049,6 @@ getArgType arg shouldReturnFinalType = actualReturnTyp = (trfUsingConstraints constraints $ typeReturnFn actualTyp) in actualReturnTyp --- [DEPRECATED] Get HsLit literal type -getLitType :: HsLit GhcTc -> [Type] -getLitType (HsChar _ _ ) = [charTy] -getLitType (HsCharPrim _ _) = [charTy] -getLitType (HsString _ _) = [stringTy] -getLitType (HsStringPrim _ _) = [stringTy] -getLitType (HsInt _ _) = [intTy] -getLitType (HsIntPrim _ _) = [intTy] -getLitType (HsWordPrim _ _) = [wordTy] -getLitType (HsInt64Prim _ _) = [intTy] -getLitType (HsWord64Prim _ _) = [wordTy] -getLitType (HsInteger _ _ _) = [intTy] -getLitType (HsRat _ _ _) = [doubleTy] -getLitType (HsFloatPrim _ _) = [floatTy] -getLitType (HsDoublePrim _ _) = [doubleTy] -getLitType _ = [] - --- [DEPRECATED] Get final return type of any type/function signature -getReturnType :: Type -> [Type] -getReturnType typ - | isFunTy typ = getReturnType $ tcFunResultTy typ - | otherwise = let (x, y) = tcSplitAppTys typ in x : y - -- [DECRECATED] Transform the type from the constraints trfUsingConstraints :: [PredType] -> [Type] -> [Type] trfUsingConstraints constraints typs = @@ -721,123 +1069,9 @@ trfUsingConstraints constraints typs = replacer replacements typ@(AppTy ty1 ty2) = AppTy (replacer replacements ty1) (replacer replacements ty2) replacer replacements typ@(TyConApp tyCon typOrKinds) = TyConApp tyCon $ map (replacer replacements) typOrKinds replacer replacements typ@(ForAllTy bndrs typ') = ForAllTy bndrs (replacer replacements typ') +#if __GLASGOW_HASKELL__ >= 900 + replacer replacements typ@(FunTy flag mult ty1 ty2) = FunTy flag mult (replacer replacements ty1) (replacer replacements ty2) +#else replacer replacements typ@(FunTy flag ty1 ty2) = FunTy flag (replacer replacements ty1) (replacer replacements ty2) - replacer replacements typ = maybe typ snd $ (\x -> eqType (fst x) typ) `find` replacements - --- [DEPRECATED] Get List of stringification functions used inside a HsExpr; Uses `stringifierFns` -getStringificationFns :: LHsExpr GhcTc -> TcM [String] -getStringificationFns (L _ ap@(HsVar _ v)) = do - liftIO $ putStrLn "Inside HsVar" >> putStrLn (showS ap) - pure $ [getOccString v] - -- case (getOccString v) `elem` stringifierFns of - -- True -> pure [getOccString v] - -- False -> pure [] -getStringificationFns (L _ ap@(HsApp _ lfun rfun)) = do - liftIO $ putStrLn "Inside HsApp" >> putStrLn (showS ap) - x1 <- getStringificationFns lfun - x2 <- getStringificationFns rfun - pure $ x1 <> x2 -getStringificationFns (L _ ap@(OpApp _ lfun op rfun)) = do - liftIO $ putStrLn "Inside OpApp" >> putStrLn (showS ap) - x1 <- getStringificationFns lfun - x2 <- getStringificationFns op - x3 <- getStringificationFns rfun - pure $ x1 <> x2 <> x3 -getStringificationFns (L _ ap@(HsPar _ expr)) = do - liftIO $ putStrLn "Inside HsPar" >> putStrLn (showS ap) - getStringificationFns expr -getStringificationFns (L loc ap@(HsWrap _ _ expr)) = do - liftIO $ putStrLn "Inside HsWrap" >> putStrLn (showS ap) - getStringificationFns (L loc expr) -getStringificationFns _ = do - liftIO $ putStrLn $ "Inside _" - pure [] - --- [DEPRECATED] Get List of stringification functions used inside a HsExpr; Uses `stringifierFns` -getStringificationFns2 :: LHsExpr GhcTc -> FunctionRule -> [String] -getStringificationFns2 arg rule = - let vars = arg ^? biplateRef :: [Var] - blockedFns = fmap (\(fname, _, _) -> fname) $ fns_blocked_in_arg rule - in map getOccString $ filter (\x -> ((getOccString x) `elem` blockedFns)) $ takeWhile isFunVar $ filter (not . isSystemName . varName) vars - --- Get List of blocked functions used inside a HsExpr; Uses `getBlockedFnsList` -getBlockedFnsList :: PluginOpts -> LHsExpr GhcTc -> FunctionRule -> TcM [(LHsExpr GhcTc, String, String)] -getBlockedFnsList opts arg rule@(FunctionRule _ _ arg_no fnsBlocked _ _ _ _ _) = do - let argHsExprs = arg ^? biplateRef :: [LHsExpr GhcTc] - fnApps = filter isFunApp argHsExprs - when (logDebugInfo opts) $ liftIO $ do - print "getBlockedFnsList" - showOutputable arg - showOutputable fnApps - catMaybes <$> mapM checkFnBlockedInArg fnApps - -- vars = arg ^? biplateRef :: [Var] - -- blockedFns = fmap (\(fname, _, _) -> fname) $ fns_blocked_in_arg rule - -- in map getOccString $ filter (\x -> ((getOccString x) `elem` blockedFns) && (not . isSystemName . varName) x) vars - where - checkFnBlockedInArg :: LHsExpr GhcTc -> TcM (Maybe (LHsExpr GhcTc, String, String)) - checkFnBlockedInArg expr = do - let res = getFnNameWithAllArgs expr - when (logDebugInfo opts) $ liftIO $ do - print "checkFnBlockedInArg" - showOutputable res - case res of - Nothing -> pure Nothing - Just (fnName, args) -> isPresentInBlockedFnList expr fnsBlocked ((getOccString . varName . unLoc) fnName) args - - isPresentInBlockedFnList :: LHsExpr GhcTc -> FnsBlockedInArg -> String -> [LHsExpr GhcTc] -> TcM (Maybe (LHsExpr GhcTc, String, String)) - isPresentInBlockedFnList expr [] _ _ = pure Nothing - isPresentInBlockedFnList expr ((ruleFnName, ruleArgNo, ruleAllowedTypes) : ls) fnName fnArgs = do - when (logDebugInfo opts) $ liftIO $ do - print "isPresentInBlockedFnList" - print (ruleFnName, ruleArgNo, ruleAllowedTypes) - case ruleFnName == fnName && length fnArgs >= ruleArgNo of - False -> isPresentInBlockedFnList expr ls fnName fnArgs - True -> do - let reqArg = head $ drop (ruleArgNo - 1) fnArgs - argType <- getHsExprType opts reqArg - when (logDebugInfo opts) $ liftIO $ do - showOutputable reqArg - showOutputable argType - if validateAllowedTypes argType ruleAllowedTypes - then isPresentInBlockedFnList expr ls fnName fnArgs - else pure $ Just (expr, fnName, showS argType) - - validateAllowedTypes :: Type -> TypesAllowedInArg -> Bool - validateAllowedTypes argType@(TyConApp tyCon ls) ruleAllowedTypes = - if showS tyCon == "(,)" && "(,)" `elem` ruleAllowedTypes - then (\t -> validateAllowedTypes t ruleAllowedTypes) `all` ls - else if showS tyCon == "[]" && "[]" `elem` ruleAllowedTypes - then (\t -> validateAllowedTypes t ruleAllowedTypes) `all` ls - else if showS tyCon == "Maybe" && "Maybe" `elem` ruleAllowedTypes - then (\t -> validateAllowedTypes t ruleAllowedTypes) `all` ls - else (isEnumType argType && "EnumTypes" `elem` ruleAllowedTypes) || (showS argType) `elem` ruleAllowedTypes - validateAllowedTypes argType ruleAllowedTypes = (isEnumType argType && "EnumTypes" `elem` ruleAllowedTypes) || (showS argType) `elem` ruleAllowedTypes - --- Add GHC error to a file -addErrToFile :: ModSummary -> String -> [CompileError] -> TcM () -addErrToFile modSummary path errs = do - let moduleName' = moduleNameString $ moduleName $ ms_mod modSummary - -- res = (encodePretty moduleName') <> (fromString ": ") <> encodePretty errs <> (fromString ",") - res = encodePretty errs - liftIO $ createDirectoryIfMissing True path - liftIO $ writeFile (path <> moduleName' <> "_compilationErrors.json") res - --- Get type for a LHsExpr GhcTc -getHsExprType :: PluginOpts -> LHsExpr GhcTc -> TcM Type -getHsExprType opts expr = do - coreExpr <- initDsTc $ dsLExpr expr - when (logTypeDebugging opts) $ liftIO $ print $ "DebugType = " <> (debugPrintType $ exprType coreExpr) - pure $ exprType coreExpr - --- Get type for a LHsExpr GhcTc -getHsExprTypeWithResolver :: PluginOpts -> LHsExpr GhcTc -> TcM Type -getHsExprTypeWithResolver opts expr = deNoteType <$> getHsExprType opts expr - -debugPrintType :: Type -> String -debugPrintType (TyVarTy v) = "(TyVar " <> showS v <> ")" -debugPrintType (AppTy ty1 ty2) = "(AppTy " <> debugPrintType ty1 <> " " <> debugPrintType ty2 <> ")" -debugPrintType (TyConApp tycon tys) = "(TyCon (" <> showS tycon <> ") [" <> foldr (\x r -> debugPrintType x <> ", " <> r) "" tys <> "]" -debugPrintType (ForAllTy _ ty) = "(ForAllTy " <> debugPrintType ty <> ")" -debugPrintType (FunTy _ ty1 ty2) = "(FunTy " <> debugPrintType ty1 <> " " <> debugPrintType ty2 <> ")" -debugPrintType (LitTy litTy) = "(LitTy " <> showS litTy <> ")" -debugPrintType _ = "" +#endif + replacer replacements typ = maybe typ snd $ (\x -> eqType (fst x) typ) `find` replacements diff --git a/sheriff/src/Sheriff/Rules.hs b/sheriff/src/Sheriff/Rules.hs index cfdbd83..688ff02 100644 --- a/sheriff/src/Sheriff/Rules.hs +++ b/sheriff/src/Sheriff/Rules.hs @@ -1,144 +1,39 @@ module Sheriff.Rules where import Sheriff.Types +import Sheriff.TypesUtils +import Sheriff.Utils --- TODO: Take these from the configuration file -badPracticeRules :: Rules -badPracticeRules = [ +defaultSheriffRules :: Rules +defaultSheriffRules = [ defaultRule - -- , logRule1 - -- , logRule2 - -- , logRule3 - -- , logRule4 - -- , logRule5 - -- , logRule6 - -- , logRule7 - -- , logRule8 - -- , logRule9 - -- , logRule10 - -- , logRule11 - -- , logRule12 - -- , logRule13 - -- , logRule14 - -- , logRule15 + -- , noKVDBRule + -- , infiniteRecursionRule , showRule ] -- Exceptions to rule out if these rules are also applied to same LHsExpr -exceptionRules :: Rules -exceptionRules = [ +defaultSheriffExceptionsRules :: Rules +defaultSheriffExceptionsRules = [ defaultRule - -- , updateFunctionRuleArgNo logRule1 1 - -- , updateFunctionRuleArgNo logRule2 1 - -- , updateFunctionRuleArgNo logRule3 1 - -- , updateFunctionRuleArgNo logRule4 1 - -- , updateFunctionRuleArgNo logRule5 1 - -- , updateFunctionRuleArgNo logRule6 1 - -- , updateFunctionRuleArgNo logRule7 1 - -- , updateFunctionRuleArgNo logRule8 1 - -- , updateFunctionRuleArgNo logRule9 1 - -- , updateFunctionRuleArgNo logRule10 1 - -- , updateFunctionRuleArgNo logRule11 1 - -- , updateFunctionRuleArgNo logRule12 1 - -- , updateFunctionRuleArgNo logRule13 1 - -- , updateFunctionRuleArgNo logRule14 1 - -- , updateFunctionRuleArgNo logRule15 1 ] -logArgNo :: ArgNo -logArgNo = 2 - -logRule1 :: Rule -logRule1 = FunctionRuleT $ FunctionRule "LogRule" "logErrorT" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - -logRule2 :: Rule -logRule2 = FunctionRuleT $ FunctionRule "LogRule" "logErrorV" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - -logRule3 :: Rule -logRule3 = FunctionRuleT $ FunctionRule "LogRule" "logError" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - -logRule4 :: Rule -logRule4 = FunctionRuleT $ FunctionRule "LogRule" "logInfoT" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - -logRule5 :: Rule -logRule5 = FunctionRuleT $ FunctionRule "LogRule" "logInfoV" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - -logRule6 :: Rule -logRule6 = FunctionRuleT $ FunctionRule "LogRule" "logInfo" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - -logRule7 :: Rule -logRule7 = FunctionRuleT $ FunctionRule "LogRule" "logDebugT" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - -logRule8 :: Rule -logRule8 = FunctionRuleT $ FunctionRule "LogRule" "logDebugV" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - -logRule9 :: Rule -logRule9 = FunctionRuleT $ FunctionRule "LogRule" "logDebug" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - -logRule10 :: Rule -logRule10 = FunctionRuleT $ FunctionRule "LogRule" "logErrorWithCategory" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - -logRule11 :: Rule -logRule11 = FunctionRuleT $ FunctionRule "LogRule" "logErrorWithCategoryV" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - -logRule12 :: Rule -logRule12 = FunctionRuleT $ FunctionRule "LogRule" "forkErrorLog" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - -logRule13 :: Rule -logRule13 = FunctionRuleT $ FunctionRule "LogRule" "forkInfoLog" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - -logRule14 :: Rule -logRule14 = FunctionRuleT $ FunctionRule "LogRule" "debugLog" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - -logRule15 :: Rule -logRule15 = FunctionRuleT $ FunctionRule "LogRule" "warnLog" logArgNo stringifierFns [] textTypesToCheck logRuleSuggestions [] [] - showRuleExceptions :: Rules showRuleExceptions = [ defaultRule - , logRule1 - , logRule2 - , logRule3 - , logRule4 - , logRule5 - , logRule6 - , logRule7 - , logRule8 - , logRule9 - , logRule10 - , logRule11 - , logRule12 - , logRule13 - , logRule14 - , logRule15 - , updateFunctionRuleArgNo logRule1 1 - , updateFunctionRuleArgNo logRule2 1 - , updateFunctionRuleArgNo logRule3 1 - , updateFunctionRuleArgNo logRule4 1 - , updateFunctionRuleArgNo logRule5 1 - , updateFunctionRuleArgNo logRule6 1 - , updateFunctionRuleArgNo logRule7 1 - , updateFunctionRuleArgNo logRule8 1 - , updateFunctionRuleArgNo logRule9 1 - , updateFunctionRuleArgNo logRule10 1 - , updateFunctionRuleArgNo logRule11 1 - , updateFunctionRuleArgNo logRule12 1 - , updateFunctionRuleArgNo logRule13 1 - , updateFunctionRuleArgNo logRule14 1 - , updateFunctionRuleArgNo logRule15 1 ] showRule :: Rule -showRule = FunctionRuleT $ FunctionRule "ShowRule" "show" 1 stringifierFns textTypesBlocked textTypesToCheck showRuleSuggestions showRuleExceptions [] +showRule = FunctionRuleT $ FunctionRule "ShowRule" ["show"] 1 [] stringifierFns textTypesBlocked textTypesToCheck showRuleSuggestions showRuleExceptions [] ["*"] [] -noUseRule :: Rule -noUseRule = FunctionRuleT $ FunctionRule "NoDecodeUtf8Rule" "$text-1.2.4.1$Data.Text.Encoding$decodeUtf8" 0 [] [] [] ["You might want to use some other wrapper function."] [] [] +infiniteRecursionRule :: Rule +infiniteRecursionRule = InfiniteRecursionRuleT defaultInfiniteRecursionRuleT -dbRule :: Rule -dbRule = DBRuleT $ DBRule "NonIndexedDBRule" "TxnRiskCheck" [NonCompositeKey "partitionKey"] dbRuleSuggestions [] +defaultInfiniteRecursionRuleT :: InfiniteRecursionRule +defaultInfiniteRecursionRuleT = defaultInfiniteRecursionRule {infinite_recursion_rule_name = "Infinite Recursion", infinite_recursion_rule_fixes = ["Remove the infinite recursion.", "Add a base case check.", "Pass the modified value to function arguments."]} -dbRuleCustomer :: Rule -dbRuleCustomer = DBRuleT $ DBRule "NonIndexedDBRule" "MerchantKey" [NonCompositeKey "status"] dbRuleSuggestions [] +noKVDBRule :: Rule +noKVDBRule = FunctionRuleT $ FunctionRule "ART KVDB Rule" ["runKVDB"] 0 [] [] [] [] ["You might want to use some other wrapper function from `EulerHS.Extra.Redis` module.", "For e.g. - rExists, rDel, rGet, rExpire, etc."] [] [] ["*"] [] updateFunctionRuleArgNo :: Rule -> ArgNo -> Rule updateFunctionRuleArgNo (FunctionRuleT fnRule) newArgNo = FunctionRuleT $ fnRule{arg_no = newArgNo} diff --git a/sheriff/src/Sheriff/Types.hs b/sheriff/src/Sheriff/Types.hs index 8dd6203..24fd04c 100644 --- a/sheriff/src/Sheriff/Types.hs +++ b/sheriff/src/Sheriff/Types.hs @@ -1,80 +1,84 @@ +{-# LANGUAGE RecordWildCards #-} + module Sheriff.Types where +import Sheriff.Utils import Data.Aeson as A -import SrcLoc -import Var -import Outputable as OP hiding ((<>)) import Control.Applicative ((<|>)) import Data.Text (unpack) import Data.Data (Data) -import GHC.Hs.Dump -import Language.Haskell.GHC.ExactPrint (exactPrint) -import Language.Haskell.GHC.ExactPrint.Annotater (Annotate) + +#if __GLASGOW_HASKELL__ >= 900 +import GHC.Types.SrcLoc +import GHC.Types.Var +#else +import SrcLoc +import Var +#endif data PluginOpts = PluginOpts { - saveToFile :: Bool, + saveToFile :: Bool, throwCompilationError :: Bool, - failOnFileNotFound :: Bool, - savePath :: String, - indexedKeysPath :: String, - rulesConfigPath :: String, - exceptionsConfigPath :: String, - matchAllInsideAnd :: Bool, + failOnFileNotFound :: Bool, + savePath :: String, + indexedKeysPath :: String, + rulesConfigPath :: String, + exceptionsConfigPath :: String, + matchAllInsideAnd :: Bool, shouldCheckExceptions :: Bool, - logDebugInfo :: Bool, - logWarnInfo :: Bool, - logTypeDebugging :: Bool, - useIOForSourceCode :: Bool + logDebugInfo :: Bool, + logWarnInfo :: Bool, + logTypeDebugging :: Bool, + useIOForSourceCode :: Bool } deriving (Show, Eq) defaultPluginOpts :: PluginOpts defaultPluginOpts = PluginOpts { - saveToFile = False, + saveToFile = False, throwCompilationError = True, - failOnFileNotFound = True, - matchAllInsideAnd = False, - savePath = ".juspay/tmp/sheriff/", - indexedKeysPath = ".juspay/indexedKeys.yaml" , - rulesConfigPath = ".juspay/sheriffRules.yaml", - exceptionsConfigPath = ".juspay/sheriffExceptionRules.yaml", - logDebugInfo = False, - logWarnInfo = True, - logTypeDebugging = False, + failOnFileNotFound = True, + matchAllInsideAnd = False, + savePath = ".juspay/tmp/sheriff/", + indexedKeysPath = ".juspay/indexedKeys.yaml" , + rulesConfigPath = ".juspay/sheriffRules.yaml", + exceptionsConfigPath = ".juspay/sheriffExceptionRules.yaml", + logDebugInfo = False, + logWarnInfo = True, + logTypeDebugging = False, shouldCheckExceptions = True, - useIOForSourceCode = False + useIOForSourceCode = False } instance FromJSON PluginOpts where parseJSON = withObject "PluginOpts" $ \o -> do - saveToFile <- o .:? "saveToFile" .!= (saveToFile defaultPluginOpts) - failOnFileNotFound <- o .:? "failOnFileNotFound" .!= (failOnFileNotFound defaultPluginOpts) + saveToFile <- o .:? "saveToFile" .!= (saveToFile defaultPluginOpts) + failOnFileNotFound <- o .:? "failOnFileNotFound" .!= (failOnFileNotFound defaultPluginOpts) throwCompilationError <- o .:? "throwCompilationError" .!= (throwCompilationError defaultPluginOpts) - savePath <- o .:? "savePath" .!= (savePath defaultPluginOpts) - indexedKeysPath <- o .:? "indexedKeysPath" .!= (indexedKeysPath defaultPluginOpts) - rulesConfigPath <- o .:? "rulesConfigPath" .!= (rulesConfigPath defaultPluginOpts) - exceptionsConfigPath <- o .:? "exceptionsConfigPath" .!= (exceptionsConfigPath defaultPluginOpts) - matchAllInsideAnd <- o .:? "matchAllInsideAnd" .!= (matchAllInsideAnd defaultPluginOpts) - shouldCheckExceptions <- o .:? "matchAllInsideAnd" .!= (shouldCheckExceptions defaultPluginOpts) - logDebugInfo <- o .:? "logDebugInfo" .!= (logDebugInfo defaultPluginOpts) - logWarnInfo <- o .:? "logWarnInfo" .!= (logWarnInfo defaultPluginOpts) - logTypeDebugging <- o .:? "logTypeDebugging" .!= (logTypeDebugging defaultPluginOpts) - useIOForSourceCode <- o .:? "useIOForSourceCode" .!= (useIOForSourceCode defaultPluginOpts) - return PluginOpts { - saveToFile = saveToFile, - throwCompilationError = throwCompilationError, - matchAllInsideAnd = matchAllInsideAnd, - savePath = savePath, - indexedKeysPath = indexedKeysPath, - rulesConfigPath = rulesConfigPath, - exceptionsConfigPath = exceptionsConfigPath, - failOnFileNotFound = failOnFileNotFound, - shouldCheckExceptions = shouldCheckExceptions, - logWarnInfo = logWarnInfo, - logDebugInfo = logDebugInfo, - logTypeDebugging = logTypeDebugging , - useIOForSourceCode = useIOForSourceCode - } + savePath <- o .:? "savePath" .!= (savePath defaultPluginOpts) + indexedKeysPath <- o .:? "indexedKeysPath" .!= (indexedKeysPath defaultPluginOpts) + rulesConfigPath <- o .:? "rulesConfigPath" .!= (rulesConfigPath defaultPluginOpts) + exceptionsConfigPath <- o .:? "exceptionsConfigPath" .!= (exceptionsConfigPath defaultPluginOpts) + matchAllInsideAnd <- o .:? "matchAllInsideAnd" .!= (matchAllInsideAnd defaultPluginOpts) + shouldCheckExceptions <- o .:? "matchAllInsideAnd" .!= (shouldCheckExceptions defaultPluginOpts) + logDebugInfo <- o .:? "logDebugInfo" .!= (logDebugInfo defaultPluginOpts) + logWarnInfo <- o .:? "logWarnInfo" .!= (logWarnInfo defaultPluginOpts) + logTypeDebugging <- o .:? "logTypeDebugging" .!= (logTypeDebugging defaultPluginOpts) + useIOForSourceCode <- o .:? "useIOForSourceCode" .!= (useIOForSourceCode defaultPluginOpts) + return PluginOpts {..} + +type Rules = [Rule] +type ArgNo = Int +type ArgTypes = [String] +type SignaturesBlockedInFn = [String] +type FnsBlockedInArg = [(String, ArgNo, TypesAllowedInArg)] +type TypesAllowedInArg = [String] +type TypesBlockedInArg = [String] +type TypesToCheckInArg = [String] +type Suggestions = [String] +type Modules = [String] +type FunctionNames = [String] +type ModulesWithFunctions = [String] data SheriffRules = SheriffRules { rules :: Rules @@ -95,17 +99,25 @@ instance FromJSON YamlTables where return YamlTables { tables = tableList } data YamlTable = YamlTable - { tableName :: String - , indexedKeys :: [YamlTableKeys] + { tableName :: String + , indexedKeys :: [YamlTableKeys] + , ignoredFunctions :: ModulesWithFunctions + , ignoredModules :: Modules + , checkModules :: Modules } deriving (Show, Eq) instance FromJSON YamlTable where parseJSON = withObject "YamlTable" $ \o -> do - name <- o .: "name" - keys <- o .: "indexedKeys" - return YamlTable { tableName = name, indexedKeys = keys } - -data YamlTableKeys = NonCompositeKey String | CompositeKey { cols :: [String] } + tableName <- o .: "name" + indexedKeys <- o .: "indexedKeys" + ignoredFunctions <- o .:? "ignoredFunctions" .!= [] + ignoredModules <- o .:? "ignoredModules" .!= [] + checkModules <- o .:? "checkModules" .!= ["*"] + return YamlTable {..} + +data YamlTableKeys = + NonCompositeKey String + | CompositeKey { cols :: [String] } deriving (Show, Eq) instance FromJSON YamlTableKeys where @@ -119,85 +131,135 @@ instance FromJSON YamlTableKeys where data CompileError = CompileError { - pkg_name :: String, - mod_name :: String, - err_msg :: String, - src_span :: SrcSpan, - violation :: Violation, - suggested_fixes :: Suggestions, - error_info :: Value + pkg_name :: String, + mod_name :: String, + err_msg :: String, + src_span :: SrcSpan, + violation :: Violation, + suggested_fixes :: Suggestions, + error_info :: Value } deriving (Eq, Show) -instance ToJSON CompileError where - toJSON (CompileError pkg modName errMsg srcLoc vlt suggestions errorInfo) = - object [ "package_name" .= pkg - , "module_name" .= modName - , "error_message" .= errMsg - , "src_span" .= show srcLoc - , "violation_type" .= getViolationType vlt - , "violated_rule" .= getViolationRuleName vlt - , "suggested_fixes" .= suggestions - , "error_info" .= errorInfo - ] - -type Rules = [Rule] -type ArgNo = Int -type ArgTypes = [String] -type FnsBlockedInArg = [(String, ArgNo, TypesAllowedInArg)] -type TypesAllowedInArg = [String] -type TypesBlockedInArg = [String] -type TypesToCheckInArg = [String] -type Suggestions = [String] -type Modules = [String] - data FunctionRule = FunctionRule { - fn_rule_name :: String, - fn_name :: String, - arg_no :: ArgNo, - fns_blocked_in_arg :: FnsBlockedInArg, - types_blocked_in_arg :: TypesBlockedInArg, - types_to_check_in_arg :: TypesToCheckInArg, - fn_rule_fixes :: Suggestions, - fn_rule_exceptions :: Rules, - fn_rule_ignore_modules :: Modules + fn_rule_name :: String, + fn_name :: FunctionNames, + arg_no :: ArgNo, + fn_sigs_blocked :: SignaturesBlockedInFn, + fns_blocked_in_arg :: FnsBlockedInArg, + types_blocked_in_arg :: TypesBlockedInArg, + types_to_check_in_arg :: TypesToCheckInArg, + fn_rule_fixes :: Suggestions, + fn_rule_exceptions :: Rules, + fn_rule_ignore_modules :: Modules, + fn_rule_check_modules :: Modules, + fn_rule_ignore_functions :: ModulesWithFunctions } deriving (Show, Eq) +defaultFunctionRule :: FunctionRule +defaultFunctionRule = FunctionRule { + fn_rule_name = "NA", + fn_name = [], + arg_no = -1, + fn_sigs_blocked = [], + fns_blocked_in_arg = [], + types_blocked_in_arg = [], + types_to_check_in_arg = [], + fn_rule_fixes = [], + fn_rule_exceptions = [], + fn_rule_ignore_modules = [], + fn_rule_check_modules = ["*"], + fn_rule_ignore_functions = [] + } + instance FromJSON FunctionRule where parseJSON = withObject "FunctionRule" $ \o -> do - fn_rule_name <- o .: "fn_rule_name" - fn_name <- o .: "fn_name" - arg_no <- o .: "arg_no" - fns_blocked_in_arg <- o .: "fns_blocked_in_arg" - types_blocked_in_arg <- o .: "types_blocked_in_arg" - types_to_check_in_arg <- o .: "types_to_check_in_arg" - fn_rule_fixes <- o .: "fn_rule_fixes" - fn_rule_exceptions <- o .: "fn_rule_exceptions" - fn_rule_ignore_modules <- o .: "fn_rule_ignore_modules" - return (FunctionRule {fn_rule_name = fn_rule_name, fn_name = fn_name, arg_no = arg_no, fns_blocked_in_arg = fns_blocked_in_arg, types_blocked_in_arg = types_blocked_in_arg, types_to_check_in_arg = types_to_check_in_arg, fn_rule_fixes = fn_rule_fixes, fn_rule_exceptions = fn_rule_exceptions, fn_rule_ignore_modules = fn_rule_ignore_modules }) + fn_rule_name <- o .: "fn_rule_name" + fn_name <- o .: "fn_name" >>= parseAsListOrString + arg_no <- o .: "arg_no" + fn_sigs_blocked <- o .:? "fn_sigs_blocked" .!= (fn_sigs_blocked defaultFunctionRule) + fns_blocked_in_arg <- o .: "fns_blocked_in_arg" + types_blocked_in_arg <- o .: "types_blocked_in_arg" + types_to_check_in_arg <- o .: "types_to_check_in_arg" + fn_rule_fixes <- o .: "fn_rule_fixes" + fn_rule_exceptions <- o .: "fn_rule_exceptions" + fn_rule_ignore_modules <- o .: "fn_rule_ignore_modules" + fn_rule_check_modules <- o .:? "fn_rule_check_modules" .!= (fn_rule_check_modules defaultFunctionRule) + fn_rule_ignore_functions <- o .:? "fn_rule_ignore_functions" .!= (fn_rule_ignore_functions defaultFunctionRule) + return FunctionRule {..} + +data InfiniteRecursionRule = + InfiniteRecursionRule + { + infinite_recursion_rule_name :: String, + infinite_recursion_rule_fixes :: Suggestions, + infinite_recursion_rule_exceptions :: Rules, + infinite_recursion_rule_ignore_modules :: Modules, + infinite_recursion_rule_check_modules :: Modules, + infinite_recursion_rule_ignore_functions :: ModulesWithFunctions + } + deriving (Show, Eq) + +defaultInfiniteRecursionRule :: InfiniteRecursionRule +defaultInfiniteRecursionRule = InfiniteRecursionRule { + infinite_recursion_rule_name = "NA", + infinite_recursion_rule_fixes = [], + infinite_recursion_rule_exceptions = [], + infinite_recursion_rule_ignore_modules = [], + infinite_recursion_rule_check_modules = ["*"], + infinite_recursion_rule_ignore_functions = [] + } + +instance FromJSON InfiniteRecursionRule where + parseJSON = withObject "InfiniteRecursionRule" $ \o -> do + infinite_recursion_rule_name <- o .: "infinite_recursion_rule_name" + infinite_recursion_rule_fixes <- o .: "infinite_recursion_rule_fixes" + infinite_recursion_rule_exceptions <- o .:? "infinite_recursion_rule_exceptions" .!= (infinite_recursion_rule_exceptions defaultInfiniteRecursionRule) + infinite_recursion_rule_ignore_modules <- o .:? "infinite_recursion_rule_ignore_modules" .!= (infinite_recursion_rule_ignore_modules defaultInfiniteRecursionRule) + infinite_recursion_rule_check_modules <- o .:? "infinite_recursion_rule_check_modules" .!= (infinite_recursion_rule_check_modules defaultInfiniteRecursionRule) + infinite_recursion_rule_ignore_functions <- o .:? "infinite_recursion_rule_ignore_functions" .!= (infinite_recursion_rule_ignore_functions defaultInfiniteRecursionRule) + return InfiniteRecursionRule {..} data DBRule = DBRule { - db_rule_name :: String, - table_name :: String, - indexed_cols_names :: [YamlTableKeys], - db_rule_fixes :: Suggestions, - db_rule_exceptions :: Rules + db_rule_name :: String, + table_name :: String, + indexed_cols_names :: [YamlTableKeys], + db_rule_fixes :: Suggestions, + db_rule_exceptions :: Rules, + db_rule_check_modules :: Modules, + db_rule_ignore_modules :: Modules, + db_rule_ignore_functions :: ModulesWithFunctions } deriving (Show, Eq) +defaultDBRule :: DBRule +defaultDBRule = DBRule { + db_rule_name = "NA", + table_name = "NA", + indexed_cols_names = [], + db_rule_fixes = [], + db_rule_exceptions = [], + db_rule_check_modules = ["*"], + db_rule_ignore_modules = [], + db_rule_ignore_functions = [] + } + instance FromJSON DBRule where parseJSON = withObject "DBRule" $ \o -> do - db_rule_name <- o .: "db_rule_name" - table_name <- o .: "table_name" - indexed_cols_names <- o .: "indexed_cols_names" - db_rule_fixes <- o .: "db_rule_fixes" - db_rule_exceptions <- o .: "db_rule_exceptions" - return (DBRule {db_rule_name = db_rule_name, table_name = table_name, indexed_cols_names = indexed_cols_names, db_rule_fixes = db_rule_fixes, db_rule_exceptions = db_rule_exceptions}) + db_rule_name <- o .: "db_rule_name" + table_name <- o .: "table_name" + indexed_cols_names <- o .: "indexed_cols_names" + db_rule_fixes <- o .: "db_rule_fixes" + db_rule_exceptions <- o .: "db_rule_exceptions" + db_rule_check_modules <- o .:? "db_rule_check_modules" .!= (db_rule_check_modules defaultDBRule) + db_rule_ignore_modules <- o .:? "db_rule_ignore_modules" .!= (db_rule_ignore_modules defaultDBRule) + db_rule_ignore_functions <- o .:? "db_rule_ignore_functions" .!= (db_rule_ignore_functions defaultDBRule) + return DBRule {..} data Action = Allowed | Blocked deriving (Show, Eq) @@ -224,14 +286,14 @@ data FunctionInfo = instance FromJSON FunctionInfo where parseJSON = withObject "FunctionInfo" $ \o -> do - fnName <- o .: "fnName" + fnName <- o .: "fnName" isQualified <- o .:? "isQualified" .!= False - argNo <- o .: "argNo" - fnAction <- o .: "action" - argTypes <- o .: "argTypes" - argFns <- o .: "argFns" - suggestedFixes <- o .: "suggestedFixes" - return (FunctionInfo { fnName = fnName, isQualified = isQualified, argNo = argNo, fnAction = fnAction, argTypes = argTypes, argFns = argFns, suggestedFixes = suggestedFixes }) + argNo <- o .: "argNo" + fnAction <- o .: "action" + argTypes <- o .: "argTypes" + argFns <- o .: "argFns" + suggestedFixes <- o .: "suggestedFixes" + return FunctionInfo {..} -- First check for all conditions to be true, and if satisfied, then data GeneralRule = @@ -248,16 +310,17 @@ instance FromJSON GeneralRule where ruleName <- o .: "ruleName" conditions <- o .: "conditions" ruleInfo <- o .: "ruleInfo" - return (GeneralRule {ruleName = ruleName, conditions = conditions, ruleInfo = ruleInfo}) + return GeneralRule {..} data Rule = DBRuleT DBRule | FunctionRuleT FunctionRule + | InfiniteRecursionRuleT InfiniteRecursionRule | GeneralRuleT GeneralRule deriving (Show, Eq) instance FromJSON Rule where - parseJSON str = (DBRuleT <$> parseJSON str) <|> (FunctionRuleT <$> parseJSON str) <|> (GeneralRuleT <$> parseJSON str) <|> (fail $ "Invalid Rule: " <> show str) + parseJSON str = (DBRuleT <$> parseJSON str) <|> (FunctionRuleT <$> parseJSON str) <|> (InfiniteRecursionRuleT <$> parseJSON str) <|> (GeneralRuleT <$> parseJSON str) <|> (fail $ "Invalid Rule: " <> show str) data LocalVar = FnArg Var | FnWhere Var | FnLocal Var deriving (Eq) @@ -275,94 +338,21 @@ data DBFieldSpecType = deriving (Show, Eq) data Violation = - ArgTypeBlocked String String FunctionRule - | FnBlockedInArg (String, String) Value FunctionRule + ArgTypeBlocked String String String FunctionRule + | FnBlockedInArg (String, String) String Value FunctionRule | NonIndexedDBColumn String String DBRule - | FnUseBlocked FunctionRule + | FnUseBlocked String FunctionRule + | FnSigBlocked String String FunctionRule + | InfiniteRecursionDetected InfiniteRecursionRule | NoViolation deriving (Eq) instance Show Violation where - show (ArgTypeBlocked typ exprTy rule) = "Use of '" <> (fn_name rule) <> "' on '" <> typ <> "' is not allowed in the overall expression type '" <> exprTy <> "'." - show (FnBlockedInArg (fnName, typ) _ rule) = "Use of '" <> fnName <> "' on type '" <> typ <> "' inside argument of '" <> (fn_name rule) <> "' is not allowed." - show (FnUseBlocked rule) = "Use of '" <> (fn_name rule) <> "' in the code is not allowed." - show (NonIndexedDBColumn colName tableName _) = "Querying on non-indexed column '" <> colName <> "' of table '" <> (tableName) <> "' is not allowed." - show NoViolation = "NoViolation" - -getViolationSuggestions :: Violation -> Suggestions -getViolationSuggestions v = case v of - ArgTypeBlocked _ _ r -> fn_rule_fixes r - FnBlockedInArg _ _ r -> fn_rule_fixes r - FnUseBlocked r -> fn_rule_fixes r - NonIndexedDBColumn _ _ r -> db_rule_fixes r - NoViolation -> [] - -getViolationType :: Violation -> String -getViolationType v = case v of - ArgTypeBlocked _ _ _ -> "ArgTypeBlocked" - FnBlockedInArg _ _ _ -> "FnBlockedInArg" - FnUseBlocked _ -> "FnUseBlocked" - NonIndexedDBColumn _ _ _ -> "NonIndexedDBColumn" - NoViolation -> "NoViolation" - -getViolationRule :: Violation -> Rule -getViolationRule v = case v of - ArgTypeBlocked _ _ r -> FunctionRuleT r - FnBlockedInArg _ _ r -> FunctionRuleT r - FnUseBlocked r -> FunctionRuleT r - NonIndexedDBColumn _ _ r -> DBRuleT r - NoViolation -> defaultRule - -getViolationRuleName :: Violation -> String -getViolationRuleName v = case v of - ArgTypeBlocked _ _ r -> fn_rule_name r - FnBlockedInArg _ _ r -> fn_rule_name r - FnUseBlocked r -> fn_rule_name r - NonIndexedDBColumn _ _ r -> db_rule_name r - NoViolation -> "NA" - -getViolationRuleExceptions :: Violation -> Rules -getViolationRuleExceptions = getRuleExceptions . getViolationRule - -getErrorInfoFromViolation :: Violation -> Value -getErrorInfoFromViolation violation = case violation of - FnBlockedInArg _ errInfo _ -> errInfo - _ -> A.Null - -getRuleFromCompileError :: CompileError -> Rule -getRuleFromCompileError = getViolationRule . violation - -getRuleExceptionsFromCompileError :: CompileError -> Rules -getRuleExceptionsFromCompileError = getRuleExceptions . getRuleFromCompileError - -getRuleExceptions :: Rule -> Rules -getRuleExceptions rule = case rule of - DBRuleT dbRule -> db_rule_exceptions dbRule - FunctionRuleT fnRule -> fn_rule_exceptions fnRule - _ -> [] - -getRuleIgnoreModules :: Rule -> Modules -getRuleIgnoreModules rule = case rule of - FunctionRuleT fnRule -> fn_rule_ignore_modules fnRule - _ -> [] - -showS :: (Outputable a) => a -> String -showS = showSDocUnsafe . ppr - -showPrettyPrinted :: (Annotate a) => Located a -> String -showPrettyPrinted = flip exactPrint mempty - -showAst :: Data a => a -> String -showAst = showSDocUnsafe . showAstData BlankSrcSpan - -noSuggestion :: Suggestions -noSuggestion = [] - -defaultRule :: Rule -defaultRule = FunctionRuleT $ FunctionRule "NA" "NA" (-1) [] [] [] noSuggestion [] [] - -emptyLoggingError :: CompileError -emptyLoggingError = CompileError "" "" "$NA$" noSrcSpan NoViolation noSuggestion A.Null - -yamlToDbRule :: YamlTable -> Rule -yamlToDbRule table = DBRuleT $ DBRule "NonIndexedDBRule" (tableName table) (indexedKeys table) ["You might want to include an indexed column in the `where` clause of the query."] [] \ No newline at end of file + show violation = case violation of + (ArgTypeBlocked typ exprTy ruleFnName rule) -> "Use of '" <> ruleFnName <> "' on '" <> typ <> "' is not allowed in the overall expression type '" <> exprTy <> "'." + (FnBlockedInArg (fnName, typ) ruleFnName _ rule) -> "Use of '" <> fnName <> "' on type '" <> typ <> "' inside argument of '" <> ruleFnName <> "' is not allowed." + (FnUseBlocked ruleFnName rule) -> "Use of '" <> ruleFnName <> "' in the code is not allowed." + (FnSigBlocked ruleFnName ruleFnSig rule) -> "Use of '" <> ruleFnName <> "' with signature '" <> ruleFnSig <> "' is not allowed in the code." + (NonIndexedDBColumn colName tableName _) -> "Querying on non-indexed column '" <> colName <> "' of table '" <> (tableName) <> "' is not allowed." + (InfiniteRecursionDetected _) -> "Infinite recursion detected in expression" + NoViolation -> "NoViolation" \ No newline at end of file diff --git a/sheriff/src/Sheriff/TypesUtils.hs b/sheriff/src/Sheriff/TypesUtils.hs new file mode 100644 index 0000000..50f5555 --- /dev/null +++ b/sheriff/src/Sheriff/TypesUtils.hs @@ -0,0 +1,213 @@ +module Sheriff.TypesUtils where + +import Data.Aeson as A +import qualified Data.ByteString.Lazy.Char8 as Char8 +import Data.List.Extra (splitOn) +import Sheriff.CommonTypes +import Sheriff.Types +import Sheriff.Utils + +#if __GLASGOW_HASKELL__ >= 900 +import GHC.Types.SrcLoc +import GHC.Types.Var +import GHC.Utils.Outputable as OP hiding ((<>)) +#else +import SrcLoc +import Var +import Outputable as OP hiding ((<>)) +#endif + +instance ToJSON CompileError where + toJSON (CompileError pkg modName errMsg srcLoc vlt suggestions errorInfo) = + object [ + "error_info" .= errorInfo + , "error_message" .= errMsg + , "module_name" .= modName + , "package_name" .= pkg + , "src_span" .= show srcLoc + , "suggested_fixes" .= suggestions + , "violated_rule" .= getViolationRuleName vlt + , "violation_type" .= getViolationType vlt + ] + +getViolationSuggestions :: Violation -> Suggestions +getViolationSuggestions v = case v of + ArgTypeBlocked _ _ _ r -> fn_rule_fixes r + FnBlockedInArg _ _ _ r -> fn_rule_fixes r + FnUseBlocked _ r -> fn_rule_fixes r + FnSigBlocked _ _ r -> fn_rule_fixes r + NonIndexedDBColumn _ _ r -> db_rule_fixes r + InfiniteRecursionDetected r -> infinite_recursion_rule_fixes r + NoViolation -> [] + +getViolationType :: Violation -> String +getViolationType v = case v of + ArgTypeBlocked _ _ _ _ -> "ArgTypeBlocked" + FnBlockedInArg _ _ _ _ -> "FnBlockedInArg" + FnUseBlocked _ _ -> "FnUseBlocked" + FnSigBlocked _ _ _ -> "FnSigBlocked" + NonIndexedDBColumn _ _ _ -> "NonIndexedDBColumn" + InfiniteRecursionDetected _ -> "InfiniteRecursionDetected" + NoViolation -> "NoViolation" + +getViolationRule :: Violation -> Rule +getViolationRule v = case v of + ArgTypeBlocked _ _ _ r -> FunctionRuleT r + FnBlockedInArg _ _ _ r -> FunctionRuleT r + FnUseBlocked _ r -> FunctionRuleT r + FnSigBlocked _ _ r -> FunctionRuleT r + NonIndexedDBColumn _ _ r -> DBRuleT r + InfiniteRecursionDetected r -> InfiniteRecursionRuleT r + NoViolation -> defaultRule + +getViolationRuleName :: Violation -> String +getViolationRuleName v = case v of + ArgTypeBlocked _ _ _ r -> fn_rule_name r + FnBlockedInArg _ _ _ r -> fn_rule_name r + FnUseBlocked _ r -> fn_rule_name r + FnSigBlocked _ _ r -> fn_rule_name r + NonIndexedDBColumn _ _ r -> db_rule_name r + InfiniteRecursionDetected r -> infinite_recursion_rule_name r + NoViolation -> "NA" + +getViolationRuleExceptions :: Violation -> Rules +getViolationRuleExceptions = getRuleExceptions . getViolationRule + +getErrorInfoFromViolation :: Violation -> Value +getErrorInfoFromViolation violation = case violation of + FnBlockedInArg _ _ errInfo _ -> errInfo + _ -> A.Null + +getRuleFromCompileError :: CompileError -> Rule +getRuleFromCompileError = getViolationRule . violation + +getRuleExceptionsFromCompileError :: CompileError -> Rules +getRuleExceptionsFromCompileError = getRuleExceptions . getRuleFromCompileError + +getRuleExceptions :: Rule -> Rules +getRuleExceptions rule = case rule of + DBRuleT dbRule -> db_rule_exceptions dbRule + FunctionRuleT fnRule -> fn_rule_exceptions fnRule + InfiniteRecursionRuleT infiniteRecursionRule -> infinite_recursion_rule_exceptions infiniteRecursionRule + _ -> [] + +getRuleIgnoreModules :: Rule -> Modules +getRuleIgnoreModules rule = case rule of + FunctionRuleT fnRule -> fn_rule_ignore_modules fnRule + InfiniteRecursionRuleT infiniteRecursionRule -> infinite_recursion_rule_ignore_modules infiniteRecursionRule + DBRuleT dbRule -> db_rule_ignore_modules dbRule + _ -> [] + +getRuleIgnoreFunctions :: Rule -> Modules +getRuleIgnoreFunctions rule = case rule of + FunctionRuleT fnRule -> fn_rule_ignore_functions fnRule + InfiniteRecursionRuleT infiniteRecursionRule -> infinite_recursion_rule_ignore_functions infiniteRecursionRule + DBRuleT dbRule -> db_rule_ignore_functions dbRule + _ -> [] + +getRuleCheckModules :: Rule -> Modules +getRuleCheckModules rule = case rule of + FunctionRuleT fnRule -> fn_rule_check_modules fnRule + InfiniteRecursionRuleT infiniteRecursionRule -> infinite_recursion_rule_check_modules infiniteRecursionRule + DBRuleT dbRule -> db_rule_check_modules dbRule + _ -> ["*"] + +getRuleName :: Rule -> String +getRuleName rule = case rule of + FunctionRuleT fnRule -> fn_rule_name fnRule + DBRuleT dbRule -> db_rule_name dbRule + InfiniteRecursionRuleT infiniteRecursionRule -> infinite_recursion_rule_name infiniteRecursionRule + _ -> "Rule not handled" + +noSuggestion :: Suggestions +noSuggestion = [] + +defaultRule :: Rule +defaultRule = FunctionRuleT defaultFunctionRule + +emptyLoggingError :: CompileError +emptyLoggingError = CompileError "" "" "$NA$" noSrcSpan NoViolation noSuggestion A.Null + +yamlToDbRule :: YamlTable -> Rule +yamlToDbRule table = DBRuleT $ DBRule { + db_rule_name = "NonIndexedDBRule", + table_name = tableName table, + indexed_cols_names = indexedKeys table, + db_rule_fixes = ["You might want to include an indexed column in the `where` clause of the query."], + db_rule_exceptions = [], + db_rule_check_modules = checkModules table, + db_rule_ignore_modules = ignoredModules table, + db_rule_ignore_functions = ignoredFunctions table + } + +updateValInOpts :: String -> String -> PluginOpts -> PluginOpts +updateValInOpts key val currentOpts = case key of + "saveToFile" -> + case decode (Char8.pack val) of + Just v -> currentOpts {saveToFile = v} + Nothing -> currentOpts + "throwCompilationError" -> + case decode (Char8.pack val) of + Just v -> currentOpts {throwCompilationError = v} + Nothing -> currentOpts + "failOnFileNotFound" -> + case decode (Char8.pack val) of + Just v -> currentOpts {failOnFileNotFound = v} + Nothing -> currentOpts + "savePath" -> + case decode (Char8.pack val) of + Just v -> currentOpts {savePath = v} + Nothing -> currentOpts + "indexedKeysPath" -> + case decode (Char8.pack val) of + Just v -> currentOpts {indexedKeysPath = v} + Nothing -> currentOpts + "rulesConfigPath" -> + case decode (Char8.pack val) of + Just v -> currentOpts {rulesConfigPath = v} + Nothing -> currentOpts + "exceptionsConfigPath" -> + case decode (Char8.pack val) of + Just v -> currentOpts {exceptionsConfigPath = v} + Nothing -> currentOpts + "matchAllInsideAnd" -> + case decode (Char8.pack val) of + Just v -> currentOpts {matchAllInsideAnd = v} + Nothing -> currentOpts + "shouldCheckExceptions" -> + case decode (Char8.pack val) of + Just v -> currentOpts {shouldCheckExceptions = v} + Nothing -> currentOpts + "logDebugInfo" -> + case decode (Char8.pack val) of + Just v -> currentOpts {logDebugInfo = v} + Nothing -> currentOpts + "logWarnInfo" -> + case decode (Char8.pack val) of + Just v -> currentOpts {logWarnInfo = v} + Nothing -> currentOpts + "logTypeDebugging" -> + case decode (Char8.pack val) of + Just v -> currentOpts {logTypeDebugging = v} + Nothing -> currentOpts + "useIOForSourceCode" -> + case decode (Char8.pack val) of + Just v -> currentOpts {useIOForSourceCode = v} + Nothing -> currentOpts + _ -> currentOpts + +{- Note: We do not allow sheriff plugin opts in individual module as of now +decodeAndUpdateOpts :: [String] -> PluginOpts -> PluginOpts +decodeAndUpdateOpts [] currentOpts = currentOpts +decodeAndUpdateOpts (x : xs) currentOpts = case A.decode (Char8.pack x) of + Just decodedOpts -> decodeAndUpdateOpts xs decodedOpts + Nothing -> case (splitOn "=" x) of + (key:val:[]) -> decodeAndUpdateOpts xs (updateValInOpts key val currentOpts) + _ -> decodeAndUpdateOpts xs currentOpts +-} + +decodeAndUpdateOpts :: [String] -> PluginOpts -> PluginOpts +decodeAndUpdateOpts [] currentOpts = currentOpts +decodeAndUpdateOpts (x : xs) currentOpts = case A.decode (Char8.pack x) of + Just decodedOpts -> decodedOpts + Nothing -> currentOpts \ No newline at end of file diff --git a/sheriff/src/Sheriff/Utils.hs b/sheriff/src/Sheriff/Utils.hs new file mode 100644 index 0000000..1ef0ea0 --- /dev/null +++ b/sheriff/src/Sheriff/Utils.hs @@ -0,0 +1,466 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE TypeFamilies #-} + +module Sheriff.Utils where + +import Control.Applicative ((<|>)) +import Control.Exception +import Control.Monad (when) +import Control.Monad.IO.Class (MonadIO (..)) +import Data.Aeson +import Data.Bool +import Data.Data (Data) +import Data.Generics.Uniplate.Data +import qualified Data.HashMap.Strict as HM +import Data.List.Extra (splitOn, trim, isInfixOf) +import Data.Maybe (maybe) +import qualified Data.Text as T +import Data.Yaml +import GHC hiding (exprType) +import GHC.Hs.Dump +import GHC.Hs.Extension +import Language.Haskell.GHC.ExactPrint (exactPrint) +import Sheriff.Patterns +import Sheriff.CommonTypes + +#if __GLASGOW_HASKELL__ >= 900 +import GHC.Core.ConLike +import GHC.Core.TyCo.Rep +import GHC.Data.IOEnv +import GHC.Driver.Main +import GHC.HsToCore.Expr +import GHC.HsToCore.Monad +import GHC.Plugins hiding ((<>), getHscEnv) +import GHC.Tc.Gen.Expr +import GHC.Tc.Module +import GHC.Tc.Types +import GHC.Tc.Utils.TcType +import Language.Haskell.GHC.ExactPrint (ExactPrint) +#else +import ConLike +import DsMonad +import DsExpr +import GhcPlugins hiding ((<>), getHscEnv) +import HscMain +import Language.Haskell.GHC.ExactPrint.Annotater (Annotate) +import TcExpr +import TcRnDriver +import TcRnMonad +import TcRnTypes +import TcType +import TyCoRep +#endif + +{- + These are the common utility functions which can be used for building any plugin of any sort + Mainly it has generic functions for all - parse, rename and typecheck plugin. +-} + +-- Debug Show any haskell internal representation type +showS :: (Outputable a) => a -> String +showS = showSDocUnsafe . ppr + +matchLocatedVarNamesWithModuleName :: (HasPluginOpts a) => Located Var -> Located Var -> AsteriskMatching -> Bool +matchLocatedVarNamesWithModuleName v1 v2 asteriskMatching = matchVarNamesWithModuleName (unLoc v1) (unLoc v2) asteriskMatching + +matchVarNamesWithModuleName :: (HasPluginOpts a) => Var -> Var -> AsteriskMatching -> Bool +matchVarNamesWithModuleName v1 v2 asteriskMatching = + let var1nameWithModule = getVarNameWithModuleName v1 + var2nameWithModule = getVarNameWithModuleName v2 + in matchNamesWithModuleName var1nameWithModule var2nameWithModule asteriskMatching + +getLocatedVarNameWithModuleName :: (HasPluginOpts a) => Located Var -> String +getLocatedVarNameWithModuleName lvar = getVarNameWithModuleName $ unLoc lvar + +getVarNameWithModuleName :: (HasPluginOpts a) => Var -> String +getVarNameWithModuleName var = getNameWithModuleName $ varName var + +getNameWithModuleName :: (HasPluginOpts a) => Name -> String +getNameWithModuleName name = + let occName = getOccString name + in getModuleName name <> "." <> occName + +getModuleName :: (HasPluginOpts a) => Name -> String +getModuleName name = + case nameModule_maybe name of + Just modName -> (moduleNameString $ moduleName modName) + Nothing -> (currentModule ?pluginOpts) + +getNameAndModuleNameWithNMV :: (HasPluginOpts a) => Name -> (Name, String) +getNameAndModuleNameWithNMV name = + let modNameMap = nameModuleMap ?pluginOpts + in case getNameAndModuleFromNMV modNameMap (NMV_Name name) of + (nm, Just modName) -> (nm, modName) + (nm, Nothing) -> (nm, getModuleName nm) + +getNameAndModuleFromNMV :: HM.HashMap NameModuleValue NameModuleValue -> NameModuleValue -> (Name, Maybe String) +getNameAndModuleFromNMV mp nmv = case HM.lookup nmv mp of + Just val -> getNameAndModuleFromNMV mp val + Nothing -> case nmv of + NMV_Name nm -> (nm, Nothing) + NMV_ClassModule nm modName -> (nm, Just modName) + +matchNamesWithModuleName :: String -> String -> AsteriskMatching -> Bool +matchNamesWithModuleName varNameWithModule fnToMatch asteriskMatching = + let (varModuleName, varName) = splitAtLastChar '.' varNameWithModule + in case splitAtLastChar '.' fnToMatch of + ("", fnName) -> matchNamesWithAsterisk asteriskMatching varName fnName + (modName, fnName) -> matchNamesWithAsterisk AsteriskInBoth varModuleName modName && matchNamesWithAsterisk asteriskMatching varName fnName + where + splitAtLastChar :: Char -> String -> (String, String) + splitAtLastChar ch str = + let (before, after) = break (== ch) (reverse str) + in (reverse (drop 1 after), reverse before) + +matchNamesWithAsterisk :: AsteriskMatching -> String -> String -> Bool +matchNamesWithAsterisk asteriskMatching str1 str2 = + let splitList1 = splitOn "." str1 + splitList2 = splitOn "." str2 + in go "" "" splitList1 splitList2 + where + checkAsteriskInFirst = (asteriskMatching == AsteriskInFirst || asteriskMatching == AsteriskInBoth) + checkAsteriskInSecond = (asteriskMatching == AsteriskInSecond || asteriskMatching == AsteriskInBoth) + + go :: String -> String -> [String] -> [String] -> Bool + go lastX lastY [] [] = True + go lastX lastY xs [] = lastY == "*" && checkAsteriskInSecond + go lastX lastY [] ys = lastX == "*" && checkAsteriskInFirst + go lastX lastY (x : xs) (y : ys) = (x == y || checkAsteriskInFirst && x == "*" || y == "*" && checkAsteriskInSecond) && go x y xs ys + +-- Pretty print haskell internal representation types using `exactprint` +#if __GLASGOW_HASKELL__ >= 900 +showPrettyPrinted :: (ExactPrint a) => Located a -> String +showPrettyPrinted = exactPrint + +showAst :: Data a => a -> String +showAst = showSDocUnsafe . showAstData BlankSrcSpan BlankEpAnnotations + +noExtFieldOrAnn :: EpAnn a +noExtFieldOrAnn = noAnn + +getLoc2 :: GenLocated (SrcSpanAnn' a) e -> SrcSpan +getLoc2 = getLocA + +noExprLoc :: a -> Located a +noExprLoc = noLoc + +getLocated :: GenLocated (SrcSpanAnn' a) e -> (SrcSpanAnn' b) -> Located e +getLocated ap (SrcSpanAnn _ loc) = L loc (unLoc ap) + +mkGenLocated :: a -> SrcSpan -> GenLocated (SrcAnn ann) a +mkGenLocated e srcSpan = L (noAnnSrcSpan srcSpan) e + +#else +showPrettyPrinted :: (Annotate a) => Located a -> String +showPrettyPrinted = flip exactPrint mempty + +showAst :: Data a => a -> String +showAst = showSDocUnsafe . showAstData BlankSrcSpan + +noExtFieldOrAnn :: NoExtField +noExtFieldOrAnn = noExtField + +getLoc2 :: HasSrcSpan a => a -> SrcSpan +getLoc2 = getLoc + +noExprLoc :: (HasSrcSpan a) => SrcSpanLess a -> a +noExprLoc = noLoc + +getLocated :: (HasSrcSpan a) => a -> SrcSpan -> Located (SrcSpanLess a) +getLocated ap loc = L loc (unLoc ap) + +mkGenLocated :: a -> SrcSpan -> GenLocated SrcSpan a +mkGenLocated e srcSpan = L srcSpan e +#endif + +-- Create Located HSExpr for HsVar type +mkLHsVar :: Located Var -> LHsExpr GhcTc +mkLHsVar (L srcSpan e) = mkGenLocated (HsVar noExtField $ mkGenLocated e srcSpan) srcSpan + +-- Debug print the Type represented in Haskell +debugPrintType :: Type -> String +debugPrintType (TyVarTy v) = "(TyVarTy " <> showS v <> ")" +debugPrintType (AppTy ty1 ty2) = "(AppTy " <> debugPrintType ty1 <> " " <> debugPrintType ty2 <> ")" +debugPrintType (TyConApp tycon tys) = "(TyConApp (" <> showS tycon <> ") [" <> foldr (\x r -> debugPrintType x <> ", " <> r) "" tys <> "]" +debugPrintType (ForAllTy _ ty) = "(ForAllTy " <> debugPrintType ty <> ")" +debugPrintType (PatFunTy _ ty1 ty2) = "(FunTy " <> debugPrintType ty1 <> " " <> debugPrintType ty2 <> ")" +debugPrintType (LitTy litTy) = "(LitTy " <> showS litTy <> ")" +debugPrintType _ = "" + +-- Get final return type of any type/function signature +getReturnType :: Type -> [Type] +getReturnType typ + | isFunTy typ = getReturnType $ tcFunResultTy typ + | otherwise = let (x, y) = tcSplitAppTys typ in x : y + +-- Get HsLit literal type +-- Similar to GHC library's `hsLitType` function +getLitType :: HsLit (GhcPass p) -> [Type] +getLitType (HsChar _ _) = [charTy] +getLitType (HsCharPrim _ _) = [charTy] +getLitType (HsString _ _) = [stringTy] +getLitType (HsStringPrim _ _) = [stringTy] +getLitType (HsInt _ _) = [intTy] +getLitType (HsIntPrim _ _) = [intTy] +getLitType (HsWordPrim _ _) = [wordTy] +getLitType (HsInt64Prim _ _) = [intTy] +getLitType (HsWord64Prim _ _) = [wordTy] +getLitType (HsInteger _ _ _) = [intTy] +getLitType (HsRat _ _ _) = [doubleTy] +getLitType (HsFloatPrim _ _) = [floatTy] +getLitType (HsDoublePrim _ _) = [doubleTy] +#if __GLASGOW_HASKELL__ < 900 +getLitType _ = [] +#endif + +-- Check if 1st array has any element in 2nd array +hasAny :: Eq a => [a] -- ^ List of elements to look for + -> [a] -- ^ List to search + -> Bool -- ^ Result +hasAny [] _ = False -- An empty search list: always false +hasAny _ [] = False -- An empty list to scan: always false +hasAny search (x:xs) = if x `elem` search then True else hasAny search xs + +-- Check if a Var is fun type +isFunVar :: Var -> Bool +isFunVar = isFunTy . dropForAlls . idType + +-- Check if a Type is Enum type +isEnumType :: Type -> Bool +isEnumType (TyConApp tyCon _) = isEnumerationTyCon tyCon +isEnumType _ = False + +-- Pretty print the Internal Representations +showOutputable :: (MonadIO m, Outputable a) => a -> m () +showOutputable = liftIO . putStrLn . showS + +-- Print the AST +printAst :: (MonadIO m, Data a) => a -> m () +printAst = liftIO . putStrLn . showAst + +-- Parse the YAML file +parseYAMLFile :: (FromJSON a) => FilePath -> IO (Either ParseException a) +parseYAMLFile file = decodeFileEither file + +-- get RealSrcSpan from SrcSpanAnn +extractRealSrcSpan :: SrcSpan -> Maybe RealSrcSpan +extractRealSrcSpan srcSpan = case srcSpan of +#if __GLASGOW_HASKELL__ >= 900 + RealSrcSpan span _ -> Just span + _ -> Nothing +#else + RealSrcSpan span -> Just span + _ -> Nothing +#endif + +-- Function to extract the code segment based on SrcSpan +extractSrcSpanSegment :: SrcSpan -> FilePath -> String -> IO String +extractSrcSpanSegment srcSpan filePath oldCode = case extractRealSrcSpan srcSpan of + Just span -> do + content' <- try (readFile filePath) :: IO (Either SomeException String) + case content' of + Left _ -> pure oldCode + Right content -> do + let fileLines = T.lines (T.pack content) + startLine = srcSpanStartLine span + endLine = srcSpanEndLine span + startCol = srcSpanStartCol span + endCol = srcSpanEndCol span + + -- Extract relevant lines + relevantLines = take (endLine - startLine + 1) $ drop (startLine - 1) fileLines + -- Handle single-line and multi-line spans + result = case relevantLines of + [] -> "" + [singleLine] -> T.take (endCol - startCol) $ T.drop (startCol - 1) singleLine + _ -> T.unlines $ [T.drop (startCol - 1) (head relevantLines)] ++ + (init (tail relevantLines)) ++ + [T.take endCol (last relevantLines)] + pure $ T.unpack result + _ -> pure oldCode + +-- Get all nodes with given type `b` starting from `a` (Alternative to `biplateRef`) +traverseAst :: (Data from, Data to) => from -> [to] +traverseAst node = traverseAstConditionally node (const False) + +-- Get all nodes with given type `b` starting from `a` (Alternative to `biplateRef` but with more granular control using a predicate) +traverseAstConditionally :: (Data from, Data to) => from -> (to -> Bool) -> [to] +traverseAstConditionally node pred = + let firstLevel = childrenBi node + in traverseConditionalUni pred firstLevel + +-- Takes a predicate which return true if further expansion is not required while traversing AST, false otherwise +traverseConditionalUni :: (Data to) => (to -> Bool) -> [to] -> [to] +traverseConditionalUni _ [] = [] +traverseConditionalUni p (x : xs) = + if p x + then x : traverseConditionalUni p xs + else (x : traverseConditionalUni p (children x)) <> traverseConditionalUni p xs + +-- Get type for a LHsExpr GhcTc +getHsExprType :: Bool -> LHsExpr GhcTc -> TcM Type +getHsExprType logTypeDebugging expr = do + coreExpr <- initDsTc $ dsLExpr expr + let typ = exprType coreExpr + when logTypeDebugging $ liftIO . print $ "DebugType = " <> (debugPrintType typ) + pure typ + +-- Get type for a LHsExpr GhcTc with resolving type aliases to `data` or `newtype` +getHsExprTypeWithResolver :: Bool -> LHsExpr GhcTc -> TcM Type +getHsExprTypeWithResolver logTypeDebugging expr = deNoteType <$> getHsExprType logTypeDebugging expr + +-- TODO: Add support for matching constraints +-- Get Qualified Types as List +getHsExprTypeAsTypeDataListWithConstraintCheck :: (HasPluginOpts a) => Bool -> Type -> [TypeData] +getHsExprTypeAsTypeDataListWithConstraintCheck ignoreConstraintArg typ = case typ of + LitTy ty -> [TextTy $ showS ty] + TyVarTy var -> [TextTy $ getVarNameWithModuleName var] + TyConApp tycon tys -> [NestedTy $ [TextTy $ getNameWithModuleName (tyConName tycon)] <> (concat $ fmap (getHsExprTypeAsTypeDataListWithConstraintCheck ignoreConstraintArg) tys)] + AppTy ty1 ty2 -> getHsExprTypeAsTypeDataListWithConstraintCheck ignoreConstraintArg ty1 <> getHsExprTypeAsTypeDataListWithConstraintCheck ignoreConstraintArg ty2 + ForAllTy _ ty -> getHsExprTypeAsTypeDataListWithConstraintCheck ignoreConstraintArg ty + PatFunTy anonArgFlag ty1 ty2 -> bool (getHsExprTypeAsTypeDataListWithConstraintCheck ignoreConstraintArg ty1 <> getHsExprTypeAsTypeDataListWithConstraintCheck ignoreConstraintArg ty2) (getHsExprTypeAsTypeDataListWithConstraintCheck ignoreConstraintArg ty2) (ignoreConstraintArg && anonArgFlag == InvisArg) + _ -> [] + +-- Get Qualified Types as List Ignoring constraint checks +getHsExprTypeAsTypeDataList :: (HasPluginOpts a) => Type -> [TypeData] +getHsExprTypeAsTypeDataList = getHsExprTypeAsTypeDataListWithConstraintCheck True + +-- Get Qualified Types as List +getHsExprTypeAsTypeDataListKeepConstraints :: (HasPluginOpts a) => Type -> [TypeData] +getHsExprTypeAsTypeDataListKeepConstraints = getHsExprTypeAsTypeDataListWithConstraintCheck False + +parseParenData :: String -> ([TypeData], String) +parseParenData [] = ([], []) +parseParenData (x:xs) + | x == '(' = let (nestedData, rest) = parseParenData xs + (remainingData, rest') = parseParenData rest + in (NestedTy nestedData : remainingData, rest') + | x == ')' = ([], xs) + | otherwise = let (textData, rest) = parseParenData xs + in case textData of + (TextTy t : ts) -> if x == ' ' then (TextTy t : ts, rest) else (TextTy (x:t) : ts, rest) -- append char to current text if it is not empty space + _ -> if x == ' ' then (textData, rest) else (TextTy [x] : textData, rest) -- start new text if it is not empty space + +-- Top-level function to handle parsing from the root +extractParenData :: String -> [TypeData] +extractParenData str = fst (parseParenData str) + +-- Match function signatures +matchFnSignatures :: [TypeData] -> String -> Bool +matchFnSignatures exprSig ruleSig = + let splitRuleSig = fmap (NestedTy . extractParenData . trim) $ splitOn "->" ruleSig + in go exprSig splitRuleSig + where + go :: [TypeData] -> [TypeData] -> Bool + go [] [] = True + go (x : xs) [] = x == TextTy "*" + go [] (y : ys) = y == TextTy "*" + go (x : xs) (y : ys) + | x == TextTy "*" = go xs ys + | y == TextTy "*" = go xs ys + | otherwise = case (x, y) of + (TextTy a, TextTy b) -> matchNamesWithModuleName a b AsteriskInBoth && go xs ys + (NestedTy a, NestedTy b) -> go a b && go xs ys + _ -> False + +-- Get name of the variable +getVarName :: IdP GhcTc -> String +getVarName var = occNameString . occName $ var + +-- Generic function to get type for a LHsExpr (GhcPass p) at any compilation phase p +getHsExprTypeGeneric :: forall p m. (IsPass p) => Bool -> LHsExpr (GhcPass p) -> PassMonad p (Maybe Type) +getHsExprTypeGeneric logTypeDebugging expr = case ghcPass @p of + GhcPs -> do + e <- getHscEnv + (_, mbType) <- liftIO $ tcRnExpr e TM_Inst expr + when logTypeDebugging $ liftIO . print $ "DebugType = " <> (maybe "Type can not be decoded" debugPrintType mbType) + pure mbType + GhcRn -> do + e <- getEnv + (_, typ) <- liftIO $ runIOEnv e $ tcInferRho expr + when logTypeDebugging $ liftIO . print $ "DebugType = " <> (debugPrintType typ) + pure (Just typ) + GhcTc -> do + e <- getEnv + typ <- liftIO $ runIOEnv e $ exprType <$> initDsTc (dsLExpr expr) + when logTypeDebugging $ liftIO . print $ "DebugType = " <> (debugPrintType typ) + pure (Just typ) + +parseAsListOrString :: Value -> Parser [String] +parseAsListOrString v = parseJSON v <|> fmap (:[]) (parseJSON v) + +-- Get Var for the data constructor +conLikeWrapId :: ConLike -> Maybe Var +conLikeWrapId (RealDataCon dc) = Just (dataConWrapId dc) +conLikeWrapId _ = Nothing + +-- TODO: Verify the correctness of this function +-- Get Pattern Match as SimpleTcExpr +trfPatToSimpleTcExpr :: Pat GhcTc -> SimpleTcExpr +trfPatToSimpleTcExpr pat = case pat of + VarPat _ (L _ var) -> SimpleVar var + LazyPat _ (L _ lPat) -> trfPatToSimpleTcExpr lPat + AsPat _ (L _ var) (L _ sPat) -> SimpleAliasPat (SimpleVar var) (trfPatToSimpleTcExpr sPat) + ParPat _ (L _ sPat) -> trfPatToSimpleTcExpr sPat + BangPat _ (L _ sPat) -> trfPatToSimpleTcExpr sPat + SigPat _ (L _ sPat) _ -> trfPatToSimpleTcExpr sPat + ListPat _ lPatList -> SimpleList (fmap (trfPatToSimpleTcExpr . unLoc) lPatList) + TuplePat _ lPatList _ -> SimpleTuple (fmap (trfPatToSimpleTcExpr . unLoc) lPatList) + LitPat _ lit -> SimpleLit lit + NPat _ (L _ (OverLit{ol_val = overloadedLit})) _ _ -> SimpleOverloadedLit overloadedLit +#if __GLASGOW_HASKELL__ >= 900 + ConPat _ (L _ con) (PrefixCon [] lPatList) -> SimpleDataCon (conLikeWrapId con) (fmap (trfPatToSimpleTcExpr . unLoc) lPatList) +#else + ConPatIn (L _ con) (PrefixCon lPatList) -> SimpleDataCon (Just con) (fmap (trfPatToSimpleTcExpr . unLoc) lPatList) + ConPatOut (L _ con) _ _ _ _ (PrefixCon lPatList) _ -> SimpleDataCon (conLikeWrapId con) (fmap (trfPatToSimpleTcExpr . unLoc) lPatList) +#endif + _ -> SimpleUnhandledTcExpr + +-- TODO: Verify the correctness of this function +-- Get LHsExpr as SimpleTcExpr +trfLHsExprToSimpleTcExpr :: LHsExpr GhcTc -> SimpleTcExpr +trfLHsExprToSimpleTcExpr (L loc hsExpr) = case hsExpr of + HsVar _ (L _ var) -> SimpleVar var + HsConLikeOut _ cl -> SimpleDataCon (conLikeWrapId cl) [] + HsLit _ lit -> SimpleLit lit + HsPar _ expr -> trfLHsExprToSimpleTcExpr expr + HsAppType _ expr _ -> trfLHsExprToSimpleTcExpr expr + PatHsWrap _ expr -> trfLHsExprToSimpleTcExpr (L loc expr) + ExplicitTuple _ ls _ -> SimpleTuple (fmap trfTupleArg ls) + PatExplicitList _ ls -> SimpleList (fmap trfLHsExprToSimpleTcExpr ls) + ExprWithTySig _ expr _ -> trfLHsExprToSimpleTcExpr expr +#if __GLASGOW_HASKELL__ >= 900 + PatHsExpansion _ expanded -> trfLHsExprToSimpleTcExpr (L loc expanded) +#endif + HsOverLit _ (OverLit{ol_val = overloadedLit}) -> SimpleOverloadedLit overloadedLit + HsApp _ (L _ (HsConLikeOut _ cl)) funr -> SimpleDataCon (conLikeWrapId cl) [trfLHsExprToSimpleTcExpr funr] + HsApp _ funl funr -> + case trfLHsExprToSimpleTcExpr funl of + SimpleDataCon mbVar ls -> SimpleDataCon mbVar (ls ++ [trfLHsExprToSimpleTcExpr funr]) + _ -> SimpleUnhandledTcExpr + _ -> SimpleUnhandledTcExpr + where +#if __GLASGOW_HASKELL__ >= 900 + trfTupleArg :: HsTupArg GhcTc -> SimpleTcExpr + trfTupleArg hsTupleArg = case hsTupleArg of + Present _ lhsExpr -> trfLHsExprToSimpleTcExpr lhsExpr + _ -> SimpleUnhandledTcExpr +#else + trfTupleArg :: LHsTupArg GhcTc -> SimpleTcExpr + trfTupleArg (L _ hsTupleArg) = case hsTupleArg of + Present _ lhsExpr -> trfLHsExprToSimpleTcExpr lhsExpr + _ -> SimpleUnhandledTcExpr +#endif + +instance StrictEq SimpleTcExpr where + (===) (SimpleFnNameVar var1 ty1) (SimpleFnNameVar var2 ty2) = + -- trace (if "sameName" `isInfixOf` getVarName var1; then show (getNameAndModuleNameWithNMV (varName var1)) <> " ::: " <> show (getNameAndModuleNameWithNMV (varName var2)); else "") $ + (getNameAndModuleNameWithNMV (varName var1) == getNameAndModuleNameWithNMV (varName var2)) && -- match name unique and module name + (getVarName var1 == getVarName var2) && -- match function name (can be avoided) + (getHsExprTypeAsTypeDataList ty1 == getHsExprTypeAsTypeDataList ty2) -- Match types for instances resolution + (===) var1 var2 = (var1 == var2) \ No newline at end of file diff --git a/sheriff/test/Exceptions.hs b/sheriff/test/Exceptions.hs new file mode 100644 index 0000000..eea3253 --- /dev/null +++ b/sheriff/test/Exceptions.hs @@ -0,0 +1,28 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} + +module Exceptions where + +import Control.Concurrent (threadDelay) +import qualified Sheriff.Plugin () +import qualified TestUtils as TU + +main :: IO () +main = do + print $ TU.throwExceptionV2 "Hello" -- Should not throw Error as per rules + print $ TU.throwExceptionV4 "Hello" -- Should not throw Error as per rules + +-- Infinite recursion, but genuine case; should be ignored to test +pattern6 :: IO a -> Int -> IO a +pattern6 flow delay = do + !_ <- flow + !_ <- threadDelay delay + pattern6 flow delay diff --git a/sheriff/test/Main.hs b/sheriff/test/Main.hs index 11a43f4..3de07eb 100644 --- a/sheriff/test/Main.hs +++ b/sheriff/test/Main.hs @@ -1,127 +1,11 @@ -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} module Main (main) where -import qualified Sheriff.Plugin () -import Data.Text as T -import qualified Data.Text.Lazy as DTL -import qualified Data.Text.Encoding as DTE -import Data.Aeson as A -import GHC.Generics -import qualified Data.ByteString.Lazy as BSL -import qualified Test1 as T1 - --- -- Data Types Declarations --- data A = A Int String --- deriving (Generic, Show, ToJSON, FromJSON) - --- data B = B {f1 :: Int, f2 :: A, f3 :: Text} --- deriving (Generic, Show, ToJSON, FromJSON) - --- data CC = C1 | C2 Text | C3 Int --- deriving (Generic, Show, ToJSON, FromJSON) - --- -- Data objects --- obA :: A --- obA = A 25 "Hello ObjectA" - --- obB :: B --- obB = B 20 obA "Hello ObjectB" - --- obC1 :: CC --- obC1 = C1 - --- obC2 :: CC --- obC2 = C2 "Hello ObjectC" - --- obC3 :: CC --- obC3 = C3 30 - --- str1 :: Text --- str1 = encodeJSON ("Hello Str1" :: Text) - --- str2 :: Text --- str2 = "Hello Str2" - --- str3 :: Text --- str3 = T.pack $ show "Hello Str3" - --- -- Helper function --- encodeJSON :: (ToJSON a) => a -> Text --- encodeJSON = DTE.decodeUtf8 . BSL.toStrict . A.encode - --- logErrorV :: (ToJSON a) => a -> IO () --- logErrorV = print . toJSON - --- logErrorT :: Text -> IO () --- logErrorT = print - --- -- Test Cases Objects --- obAT1 :: Text --- obAT1 = T.pack $ show obA - --- obAT2 :: Text --- obAT2 = encodeJSON obA - --- obBT1 :: Text --- obBT1 = T.pack $ show obB - --- obBT2 :: Text --- obBT2 = encodeJSON obB - --- obC1T1 :: Text --- obC1T1 = T.pack $ show obC1 - --- obC1T2 :: Text --- obC1T2 = encodeJSON obC1 - --- obC2T1 :: Text --- obC2T1 = T.pack $ show obC2 - --- obC2T2 :: Text --- obC2T2 = encodeJSON obC2 - --- obC3T1 :: Text --- obC3T1 = T.pack $ show obC3 - --- obC3T2 :: Text --- obC3T2 = encodeJSON obC3 - --- Test Case 1: Text inside logErrorT (No error should be raised by plugin) --- Test Case 2: Text inside logErrorV (An error should be raised by plugin) --- Test Case 3: Object inside logErrorV (No error should be generated) --- Test Case 4: Object inside logErrorT (By default, compile time error) --- Test Case 5: `show Object` inside logErrorT (An error should be raised by the plugin) --- Test Case 6: `show object` inside logErrorV (An error should be raised by the plugin) --- Test Case 7: `encode object` inside logErrorT (An error should be raised by the plugin) --- Test Case 8: `encode object` inside logErrorV (An error should be raised by the plugin) - --- Overall, when passing things inside log functions, we should never had `show` or `encode` applied to them --- If we are passing some object, then we should use logErrorV and ideally it should have instance of `ToJSON` - -ob :: T1.SeqIs -ob = T1.SeqIs "fldName" "fldValue" - -loger :: Text -> IO () -loger x = T1.logErrorT (T1.encodeJSON x <> (T.pack $ show "Hello")) - main :: IO () main = do - putStrLn "Test suite not yet implemented." - print ("Hello there" :: String) - let obAT1 = "Hello" <> (T.pack $ show "Dummy2") - loger T1.obAT1 - loger obAT1 - loger obAT1 - T1.logErrorT (T.pack $ show obAT1) - T1.logErrorT ((T.pack . show) obAT1) - T1.logErrorT ((T.pack . show) "Hello" <> obAT1) - T1.logErrorT (T.pack "Hello" <> obAT1) - where - obAT1 = T1.encodeJSON ("Dummy3" :: String) - --- main2 :: \ No newline at end of file + putStrLn "Test suite not yet implemented." \ No newline at end of file diff --git a/sheriff/test/SubTests/FunctionUseTest.hs b/sheriff/test/SubTests/FunctionUseTest.hs new file mode 100644 index 0000000..5398212 --- /dev/null +++ b/sheriff/test/SubTests/FunctionUseTest.hs @@ -0,0 +1,273 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} + +module SubTests.FunctionUseTest where + +import qualified Sheriff.Plugin () +import qualified TestUtils as TU +import qualified TestUtils +import Data.Text as T +import qualified Data.Text.Lazy as DTL +import qualified Data.Text.Encoding as DTE +import Data.Aeson as A +import Data.String +import GHC.Generics +import qualified Data.ByteString.Lazy as BSL +import Data.Functor.Identity (Identity) +import Data.Kind (Type) + + +-- Data Types Declarations +data A = A Int String + deriving (Generic, Show, ToJSON, FromJSON) + +data B = B {f1 :: Int, f2 :: A, f3 :: Text} + deriving (Generic, Show, ToJSON, FromJSON) + +data DBT (f :: Type -> Type) = DB {dbf1 :: f Int} + deriving (Generic) + +type DB = DBT Identity + +deriving stock instance Show (DBT Identity) +deriving anyclass instance FromJSON (DBT Identity) +deriving anyclass instance ToJSON (DBT Identity) + +data CC = C1 | C2 Text | C3 Int | C4 Bool + deriving (Generic, Show, ToJSON, FromJSON) + +data SeqIs = SeqIs Text Text + deriving (Generic, Show, ToJSON, FromJSON) + +data EnumT = X | Y | Z + deriving (Generic, Show, ToJSON, FromJSON) + +data EnumT2 = U EnumT | V + deriving (Generic, Show, ToJSON, FromJSON) + +data EnumT3 x = M | N + deriving (Generic, Show, ToJSON, FromJSON) + +type P = Text + +en :: EnumT +en = Y + +en2 :: EnumT2 +en2 = V + +en21 :: P +en21 = "Hello" + +en22 :: Text +en22 = "Hello" + +en3 :: EnumT3 () +en3 = M + +ob :: SeqIs +ob = SeqIs "fldName" "fldValue" + +-- Data objects +obA :: A +obA = A 25 "Hello ObjectA" + +obB :: B +obB = B 20 obA "Hello ObjectB" + +obC1 :: CC +obC1 = C1 + +obC2 :: CC +obC2 = C2 "Hello ObjectC" + +obC3 :: CC +obC3 = C3 30 + +obC4 :: CC +obC4 = C4 False + +str1 :: Text +str1 = encodeJSON ("Hello Str1" :: Text) + +str2 :: Text +str2 = "Hello Str2" + +str3 :: Text +str3 = T.pack $ show "Hello Str3" + +str4 :: Text +str4 = T.pack $ show (T.pack "Hello Str4") + +db1 :: DB +db1 = DB 500 + +-- Helper function +encodeJSON :: (ToJSON a) => a -> Text +encodeJSON = DTE.decodeUtf8 . BSL.toStrict . A.encode + +runKVDB :: IO () +runKVDB = print "Somehow it's runKVDB" + +logErrorV :: (ToJSON a) => a -> IO () +logErrorV = print . toJSON + +logDebugT :: Text -> Text -> IO () +logDebugT _ = print + +logDebug :: (Show b) => a -> b -> IO () +logDebug _ = print + +forkErrorLog :: (Show b) => a -> b -> IO () +forkErrorLog _ = print + +logErrorT :: Text -> Text -> IO () +logErrorT _ = print + +logError :: String -> String -> IO () +logError _ = print + +-- Test Cases Objects +obAT1 :: Text +obAT1 = T.pack $ show obA + +obAT2 :: Text +obAT2 = encodeJSON obA + +obBT1 :: Text +obBT1 = T.pack $ show obB + +obBT2 :: Text +obBT2 = encodeJSON obB + +obC1T1 :: Text +obC1T1 = T.pack $ show obC1 + +obC1T2 :: Text +obC1T2 = encodeJSON obC1 + +obC2T1 :: Text +obC2T1 = T.pack $ show obC2 + +obC2T2 :: Text +obC2T2 = encodeJSON obC2 + +obC3T1 :: Text +obC3T1 = T.pack $ show obC3 + +obC3T2 :: Text +obC3T2 = encodeJSON obC3 + +num1 :: TU.Number +num1 = TU.Number 20 + +num2 :: TU.Number +num2 = TU.Number 10 + +-- Test Case 1: Text inside logErrorT (No error should be raised by plugin) +-- Test Case 2: Text inside logErrorV (An error should be raised by plugin) +-- Test Case 3: Object inside logErrorV (No error should be generated) +-- Test Case 4: Object inside logErrorT (By default, compile time error) +-- Test Case 5: `show Object` inside logErrorT (An error should be raised by the plugin) +-- Test Case 6: `show object` inside logErrorV (An error should be raised by the plugin) +-- Test Case 7: `encode object` inside logErrorT (An error should be raised by the plugin) +-- Test Case 8: `encode object` inside logErrorV (An error should be raised by the plugin) + +-- Also, from what sources we might be passing the value to logErrorT +-- 1. Received as function argument +-- 2. Created as local bind +-- 3. Imported from another module +-- 1 and 3 are same for me, I have to go and check from where I received it, change of module does not matter + +-- Overall, when passing things inside log functions, we should never had `show` or `encode` or encodeJSON applied to them +-- If we are passing some object, then we should use logErrorV and ideally it should have instance of `ToJSON` + +-- Scenario 1: Parameter sent to logger is modified in the current function and the before modification is not text (We just need to verify that the modification is not `encode`, `encodeJSON` or `show`) +-- Scenario 2: Parameter sent to logger is modified in the current function and the before modification is text and modification is stringification (We need to throw error without recursive backtracking) +-- Scenario 3: Parameter sent to logger is modified in the current function and the before modification is text and modification is not stringification (We need to mark current function as a logger function, and it needs to be checked all calls to this function) +-- Scenario 4: Parameter sent to logger is same as some argument in the current function and that argument is of text type (mark current function as a logger function, and it needs to be checked all calls to this function, it will behave like `logErrorT`) +-- Scenario 5: Parameter sent to logger is same as some argument in the current function and that argument is of non-text type (PASS case) + +addQuotes :: Text -> Text +addQuotes t = "\"" <> t <> "\"" + +noLogFn :: String -> String -> IO () +noLogFn _ _ = pure () + +throwException :: () +throwException = () + +main :: IO () +main = do + putStrLn "Test suite not yet implemented." + print ("HI there" :: String) + let obAT1 = "Dummy" + let b = logInfoT "tester" logger + + -- Test for Qualified Function Names Rules + print $ TU.throwException "Hello" -- should throw error + print throwException -- should NOT throw error + print $ TU.throwExceptionV2 "Hello" -- should throw error as part of combined rule "Hello" + print $ TU.throwExceptionV3 "Hello" + print $ TU.throwExceptionV4 "Hello" -- should throw error as part of combined rule "Hello" + + let (TU.Number sRes) = num1 `TU.subtractNumber` num2 + (TU.Number aRes) = TU.addNumber num1 num2 + (TU.Number mRes) = (TU.*?) Nothing num1 num2 + (TU.Number n1) = TU.fstArg num1 num2 + (TU.Number n2) = TU.sndArg num1 num2 + + print sRes + print aRes + print mRes + print n1 + print n2 + print (n1 * n2) + print ((*) 10 20) + + runKVDB -- Should be error + where + logErrorT = SubTests.FunctionUseTest.logErrorT + +(^*^) :: Num a => a -> a -> a +(^*^) a b = a * b + +logInfoT :: String -> (forall a b. (IsString b, Show a) => String -> a -> b) -> String +logInfoT x _ = x + +logger :: forall a b. (IsString b, Show a) => String -> a -> b +logger _ = fromString . show + +temp :: [Text] +temp = [] + +temp1 :: Maybe Text +temp1 = Nothing + +temp2 :: (Text, Text) +temp2 = ("A", "B") + +temp3 :: (Text, Int) +temp3 = ("A", 10) + +temp4 :: (Int, Int) +temp4 = (20, 10) + +temp5 :: [EnumT] +temp5 = [] + +temp6 :: [EnumT2] +temp6 = [] + +fn :: IO () -> IO () +fn x = do + _ <- x + pure () diff --git a/sheriff/test/SubTests/InfiniteRecursionTest.hs b/sheriff/test/SubTests/InfiniteRecursionTest.hs new file mode 100644 index 0000000..9185c55 --- /dev/null +++ b/sheriff/test/SubTests/InfiniteRecursionTest.hs @@ -0,0 +1,290 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeSynonymInstances #-} + +module SubTests.InfiniteRecursionTest where + +import Control.Concurrent (threadDelay) +import qualified Data.Aeson as A +import qualified TestUtils as TU + +fn1 :: IO String +fn1 = pure "Fn1" + +fn2 :: String +fn2 = "Fn2" + +fn3 :: String -> String +fn3 a = a + +data SumType = TypeA Int | TypeB | RecType SumType + +instance A.ToJSON SumType where + toJSON (TypeA v) = A.toJSON v + toJSON a = A.toJSON a -- STE :: Should Throw Error + +data RecTypeA = RecTypeA { + field1 :: Int, + field2 :: Int + } + +-- Recursive but with different data/variable +pattern1 :: String -> IO String +pattern1 x = do + let y = x <> "SameReturn1" + x <- fn1 + z <- pattern1 x -- Should not throw error since x is changed + pure (y <> z) + +-- Self recursive variable +pattern2 :: String -> String +pattern2 _ = + let sameVal = sameVal <> "Dummy" in sameVal -- STE :: Should Throw Error since infinite self recursive variable usage + +-- Self recursive function +pattern3 :: String -> String +pattern3 a = + let x = fn2 + y = "Dummy" + z = pattern3 a -- STE :: Should Throw Error since infinite self recursive function invocation + in x <> y <> z + +-- Infinite Recursive function definition in `let` +pattern4 :: String -> String -> String +pattern4 x y = + let recFn1 a b = fn3 $ recFn1 a b -- STE :: Should Throw error + in recFn1 x y + +-- Infinite Recursive function definition in `where` +pattern5 :: String -> String -> String +pattern5 x y = recFn2 x y + where + recFn2 a b = fn3 $ recFn2 a b -- STE :: Should Throw error + +-- Infinite recursion, but genuine case; not to be ignored here, but should be ignored in Test2 +pattern6 :: IO a -> Int -> IO a +pattern6 flow delay = do + !_ <- flow + !_ <- threadDelay delay + pattern6 flow delay -- STE :: Should Throw Error + +-- Self recursive straight away +pattern7 :: Int -> Int +pattern7 val = pattern7 val -- STE :: Should Throw Error + +-- Self recursive straight away but partial function +pattern8 :: Int -> Int +pattern8 = pattern8 -- STE :: Should Throw Error + +-- Indirect infinite function +pattern9 :: String -> String +pattern9 a = + let x = fn2 <> y + y = "Dummy" <> x + z = pattern9 a -- STE :: Should throw error since infinite self recursive function invocation + in x <> y <> z + +-- Infinite recursion on pattern match on data function +pattern10 :: SumType -> SumType +pattern10 (TypeA num) = + let res = pattern10 (TypeA num) -- STE :: Should Throw Error + in res +pattern10 _ = TypeB + +-- Pattern matching with infinite recursion +pattern11 :: Int -> Int +pattern11 10 = pattern11 (10 :: Int) -- STE :: Should Throw Error +pattern11 _ = -1 + +-- Pattern matching with infinite recursion +pattern12 :: String -> String +pattern12 "Pattern" = pattern12 "Pattern" -- STE :: Should Throw Error +pattern12 _ = "" + +-- Partial function with lambda case +pattern13 :: String -> String +pattern13 = \case + "Pattern" -> pattern13 "Pattern" -- STE :: Should Throw Error + _ -> "" + +-- Partial function with lambda case +pattern14 :: String -> String +pattern14 = \case + "Pattern" -> "Terminate" + _ -> pattern14 "Pattern" -- Should NOT throw Error + +-- Partial function with lambda case with extra args +pattern15 :: String -> String -> String +pattern15 a = \case + "Pattern" -> "Terminate" + _ -> pattern15 a "Pattern" -- Should NOT throw Error + +-- Partial function with lambda case with extra args +pattern16 :: String -> String -> String -> String +pattern16 a b = \case + "Pattern" -> "Terminate" + _ -> pattern16 a b "Pattern" -- Should NOT throw Error + +-- Partial function with lambda case with extra args but renamed function +pattern17 :: String -> String -> String -> String +pattern17 a b = \case + "Pattern" -> "Terminate" + _ -> + let fn' = pattern17 a b -- Should NOT throw Error + fn1' = pattern17 a -- Should NOT throw Error + in fn' "Pattern" + +-- Partial function with function composition +pattern18 :: String -> String -> String -> String +pattern18 a b = pattern15 "Hello" . pattern18 a b -- STE :: Should Throw Error + +-- Partial function with function composition chain +pattern19 :: String -> String -> String -> String +pattern19 a b = pattern9 . pattern15 "Hello" . pattern19 a b -- STE :: Should Throw Error + +-- Indirect recursion in where clause +pattern20 :: String -> String -> String +pattern20 a b = tempFn + where + tempFn :: String + tempFn = pattern20 a b -- Should NOT Throw Error + +-- Partial function with let-in +pattern21 :: String -> String +pattern21 = + let z = "Dummy" + y = "Hello" + in pattern21 -- STE :: Should Throw Error + +-- Self recursive straight away +pattern22 :: Int +pattern22 = pattern22 -- STE :: Should Throw Error + +-- Same function name but from different module +toJSON :: (A.ToJSON a) => a -> A.Value +toJSON = A.toJSON -- Should NOT Throw Error + +pattern23 :: (Num a) => a -> a +pattern23 numVal = pattern23 numVal -- STE :: Should throw error + +pattern24 :: forall a. (Num a) => a -> a +pattern24 numVal = pattern24 numVal -- STE :: Should throw error + +pattern25 :: forall a. a -> a +pattern25 numVal = pattern25 numVal -- STE :: Should throw error + +-- Same function call for partial function but within some other function +pattern26 :: [Int] -> [Int] +pattern26 = (<>) (concat $ fmap pattern26 [[1..10]]) -- Should NOT throw error + +-- Indirect recursion (may or may not be infinite) +pattern27 :: Int +pattern27 = 10 + where + whereFn :: Int + whereFn = pattern27 -- Should NOT throw Error + +-- Single argument in lambda case +pattern28 :: Int -> Int +pattern28 = \case + 10 -> 20 + lamArg -> pattern28 lamArg -- STE :: Should Throw Error + +-- Nested lambda case +pattern29 :: Int -> Int -> Int +pattern29 = \case + 10 -> \case + 20 -> pattern29 10 20 -- STE :: Should Throw Error + _ -> pattern29 50 60 -- Should NOT Throw Error + lamArg1 -> \case + lamArg2 -> pattern29 lamArg1 lamArg2 -- STE :: Should Throw Error + +-- Same function call for complete function but within some other function +pattern30 :: [Int] -> [Int] +pattern30 inpList = concat $ fmap (: pattern30 inpList) [1..10] -- STE :: Should Throw Error + +-- Lambda with 1 argument +pattern31 :: Int -> Int +pattern31 = \lamArg -> pattern31 lamArg -- STE :: Should Throw Error + +-- Lambda with 2 argument +pattern32 :: Int -> Int -> Int +pattern32 = \lamArg1 lamArg2 -> pattern32 lamArg1 lamArg2 -- STE :: Should Throw Error + +-- Lambda with 2 argument but returning partial function +pattern33 :: Int -> Int +pattern33 = \lamArg -> pattern33 lamArg -- STE :: Should Throw Error + +-- Lambda with 2 argument but changed arg +pattern34 :: Int -> Int -> Int +pattern34 = \lamArg1 lamArg2 -> pattern34 lamArg2 lamArg1 -- Should NOT Throw Error + +-- Lambda with 1 argument but changed arg +pattern35 :: Int -> Int +pattern35 = \lamArg -> pattern35 (lamArg + 5) -- Should NOT Throw Error + +-- Lambda with 1 argument but in let-in statement +pattern36 :: Int -> Int +pattern36 = let u = 20 in \lamArg -> pattern36 lamArg -- STE :: Should Throw Error + +-- Lambda with 1 argument but in function chaining +pattern37 :: Int -> Int +pattern37 = (+ 5) . \lamArg -> pattern37 lamArg -- STE :: Should Throw Error + +-- foldl case +pattern38 :: Int +pattern38 = + let sameName = foldl (\sameName x -> x + sameName) 0 [1..10] -- Should NOT Throw Error + in sameName + +pattern39 :: Int -> Int +pattern39 status = + let status + | status == 0 = status -- STE :: Should Throw Error -- STE :: Should Throw Error + | status > 0 = 1 -- STE :: Should Throw Error + | otherwise = -1 + in status + +class TypeChanger a b where + changeType :: a -> b + +instance TypeChanger Integer Int where + changeType = fromIntegral -- Should not throw error since no recursion + +instance TypeChanger Integer SumType where + changeType = TypeA . changeType -- Should NOT throw Error since type is changed + +instance TypeChanger String SumType where + changeType x = RecType $ changeType x -- STE :: Should Throw Error + +instance TypeChanger Char SumType where + changeType x = let changeType = foldr (\changeType x -> x + changeType) 0 [1..10] in TypeB -- Should NOT throw error + +instance TypeChanger Integer Integer where + changeType = changeType -- STE :: Should throw Error + +instance TypeChanger Integer String where + changeType = \case + 10 -> changeType (20 :: Integer) -- Should NOT throw Error + 50 -> changeType (50 :: Integer) -- STE :: Should Throw Error + val -> changeType val -- STE :: Should throw Error + +main :: IO () +main = do + pure () + + {- + + Note: + 1. Shadow Binding are not an issue as of now + + TODO: + 1. Add check + tests for indirect infinite recursion + 2. Add check + tests for infinite list patterns + 3. Add check + tests for record construction patterns + 4. Add check + tests for Infix Constructor patterns + 5. Validate and add more tests for Pattern matching cases + 6. Add checks for infinite recursion involving guards with pattern matching + + -} \ No newline at end of file diff --git a/sheriff/test/SubTests/LogTest.hs b/sheriff/test/SubTests/LogTest.hs new file mode 100644 index 0000000..93755cb --- /dev/null +++ b/sheriff/test/SubTests/LogTest.hs @@ -0,0 +1,282 @@ +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} + +module SubTests.LogTest where + +import qualified Sheriff.Plugin () +import qualified TestUtils as TU +import qualified TestUtils +import Data.Text as T +import qualified Data.Text.Lazy as DTL +import qualified Data.Text.Encoding as DTE +import Data.Aeson as A +import Data.String +import GHC.Generics +import qualified Data.ByteString.Lazy as BSL +import Data.Functor.Identity (Identity) +import Data.Kind (Type) + + +-- Data Types Declarations +data A = A Int String + deriving (Generic, Show, ToJSON, FromJSON) + +data B = B {f1 :: Int, f2 :: A, f3 :: Text} + deriving (Generic, Show, ToJSON, FromJSON) + +data DBT (f :: Type -> Type) = DB {dbf1 :: f Int} + deriving (Generic) + +type DB = DBT Identity + +deriving stock instance Show (DBT Identity) +deriving anyclass instance FromJSON (DBT Identity) +deriving anyclass instance ToJSON (DBT Identity) + +data CC = C1 | C2 Text | C3 Int | C4 Bool + deriving (Generic, Show, ToJSON, FromJSON) + +data SeqIs = SeqIs Text Text + deriving (Generic, Show, ToJSON, FromJSON) + +data EnumT = X | Y | Z + deriving (Generic, Show, ToJSON, FromJSON) + +data EnumT2 = U EnumT | V + deriving (Generic, Show, ToJSON, FromJSON) + +data EnumT3 x = M | N + deriving (Generic, Show, ToJSON, FromJSON) + +type P = Text + +en :: EnumT +en = Y + +en2 :: EnumT2 +en2 = V + +en21 :: P +en21 = "Hello" + +en22 :: Text +en22 = "Hello" + +en3 :: EnumT3 () +en3 = M + +ob :: SeqIs +ob = SeqIs "fldName" "fldValue" + +-- Data objects +obA :: A +obA = A 25 "Hello ObjectA" + +obB :: B +obB = B 20 obA "Hello ObjectB" + +obC1 :: CC +obC1 = C1 + +obC2 :: CC +obC2 = C2 "Hello ObjectC" + +obC3 :: CC +obC3 = C3 30 + +obC4 :: CC +obC4 = C4 False + +str1 :: Text +str1 = encodeJSON ("Hello Str1" :: Text) + +str2 :: Text +str2 = "Hello Str2" + +str3 :: Text +str3 = T.pack $ show "Hello Str3" + +str4 :: Text +str4 = T.pack $ show (T.pack "Hello Str4") + +db1 :: DB +db1 = DB 500 + +-- Helper function +encodeJSON :: (ToJSON a) => a -> Text +encodeJSON = DTE.decodeUtf8 . BSL.toStrict . A.encode + +runKVDB :: IO () +runKVDB = print "Somehow it's runKVDB" + +logErrorV :: (ToJSON a) => a -> IO () +logErrorV = print . toJSON + +logDebugT :: Text -> Text -> IO () +logDebugT _ = print + +logDebug :: (Show b) => a -> b -> IO () +logDebug _ = print + +forkErrorLog :: (Show b) => a -> b -> IO () +forkErrorLog _ = print + +logErrorT :: Text -> Text -> IO () +logErrorT _ = print + +logError :: String -> String -> IO () +logError _ = print + +-- Test Cases Objects +obAT1 :: Text +obAT1 = T.pack $ show obA + +obAT2 :: Text +obAT2 = encodeJSON obA + +obBT1 :: Text +obBT1 = T.pack $ show obB + +obBT2 :: Text +obBT2 = encodeJSON obB + +obC1T1 :: Text +obC1T1 = T.pack $ show obC1 + +obC1T2 :: Text +obC1T2 = encodeJSON obC1 + +obC2T1 :: Text +obC2T1 = T.pack $ show obC2 + +obC2T2 :: Text +obC2T2 = encodeJSON obC2 + +obC3T1 :: Text +obC3T1 = T.pack $ show obC3 + +obC3T2 :: Text +obC3T2 = encodeJSON obC3 + +num1 :: TU.Number +num1 = TU.Number 20 + +num2 :: TU.Number +num2 = TU.Number 10 + +-- Test Case 1: Text inside logErrorT (No error should be raised by plugin) +-- Test Case 2: Text inside logErrorV (An error should be raised by plugin) +-- Test Case 3: Object inside logErrorV (No error should be generated) +-- Test Case 4: Object inside logErrorT (By default, compile time error) +-- Test Case 5: `show Object` inside logErrorT (An error should be raised by the plugin) +-- Test Case 6: `show object` inside logErrorV (An error should be raised by the plugin) +-- Test Case 7: `encode object` inside logErrorT (An error should be raised by the plugin) +-- Test Case 8: `encode object` inside logErrorV (An error should be raised by the plugin) + +-- Also, from what sources we might be passing the value to logErrorT +-- 1. Received as function argument +-- 2. Created as local bind +-- 3. Imported from another module +-- 1 and 3 are same for me, I have to go and check from where I received it, change of module does not matter + +-- Overall, when passing things inside log functions, we should never had `show` or `encode` or encodeJSON applied to them +-- If we are passing some object, then we should use logErrorV and ideally it should have instance of `ToJSON` + +-- Scenario 1: Parameter sent to logger is modified in the current function and the before modification is not text (We just need to verify that the modification is not `encode`, `encodeJSON` or `show`) +-- Scenario 2: Parameter sent to logger is modified in the current function and the before modification is text and modification is stringification (We need to throw error without recursive backtracking) +-- Scenario 3: Parameter sent to logger is modified in the current function and the before modification is text and modification is not stringification (We need to mark current function as a logger function, and it needs to be checked all calls to this function) +-- Scenario 4: Parameter sent to logger is same as some argument in the current function and that argument is of text type (mark current function as a logger function, and it needs to be checked all calls to this function, it will behave like `logErrorT`) +-- Scenario 5: Parameter sent to logger is same as some argument in the current function and that argument is of non-text type (PASS case) + +addQuotes :: Text -> Text +addQuotes t = "\"" <> t <> "\"" + +noLogFn :: String -> String -> IO () +noLogFn _ _ = pure () + +throwException :: () +throwException = () + +main :: IO () +main = do + print ("HI there" :: String) + let obAT1 = "Dummy" + let b = logInfoT "tester" logger + logError "tag" $ show obC1 + logError "tag" $ show en + logErrorT "tag" $ encodeJSON en + logError "tag" $ show en2 + logErrorT (T.pack $ show "tag") $ encodeJSON en2 + logError "tag" $ show en3 + logErrorT "tag" $ encodeJSON en3 + logError "tag" $ show (en, "This is Text" :: String) + logErrorT "tag" $ encodeJSON (en, "This is Text" :: String) + logError "tag" $ show (en, 20 :: Int) -- Should not throw error because of show is allowed on both enums and int + logErrorT "tag" $ encodeJSON (en, 20 :: Int) -- Should throw error because of encodeJSON + logError "tag" $ obAT1 <> show SubTests.LogTest.obAT1 + fn $ logError "tag2" $ show obA + fn $ logError "tag2" $ ("Hello" <> show obA) + forkErrorLog "tag2" $ ("Hello" <> (show $ addQuotes "Testing forkErrorLog")) + forkErrorLog "tag2" $ T.pack $ show "Testing Multiple dollar" + noLogFn "tag2" $ ("Hello" <> (show $ addQuotes "Testing Show on text")) + + logDebug ("Some Tag" :: Text) (show "This to print") + + logError "tag" $ show en2 <> show obA + logError "tag" $ show en <> show obB + logError "tag" $ show temp5 + logError "tag" $ show temp6 + + logDebugT "validateMandate" $ "effective limit is: " <> T.pack (show 10) <> ", custom limit for key: " <> " is " <> T.pack (show (Just ("Hello" :: Text))) + logErrorT "Incorrect Feature in DB" + $ "merchantId: " <> ", error: " <> T.pack (show (["A", "B"] :: [Text])) + where + logErrorT = SubTests.LogTest.logErrorT + +(^*^) :: Num a => a -> a -> a +(^*^) a b = a * b + +logInfoT :: String -> (forall a b. (IsString b, Show a) => String -> a -> b) -> String +logInfoT x _ = x + +logger :: forall a b. (IsString b, Show a) => String -> a -> b +logger _ = fromString . show + +temp :: [Text] +temp = [] + +temp1 :: Maybe Text +temp1 = Nothing + +temp2 :: (Text, Text) +temp2 = ("A", "B") + +temp3 :: (Text, Int) +temp3 = ("A", 10) + +temp4 :: (Int, Int) +temp4 = (20, 10) + +temp5 :: [EnumT] +temp5 = [] + +temp6 :: [EnumT2] +temp6 = [] + +fn :: IO () -> IO () +fn x = do + _ <- x + pure () + +-- myFun :: A -> IO Text +-- myFun ob = do +-- let (A x y) = ob +-- res1 = A x "Modified" +-- res2 = encodeJSON \ No newline at end of file diff --git a/sheriff/test/SubTests/ShowTest.hs b/sheriff/test/SubTests/ShowTest.hs new file mode 100644 index 0000000..d3bf31b --- /dev/null +++ b/sheriff/test/SubTests/ShowTest.hs @@ -0,0 +1,293 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} + +module SubTests.ShowTest where + +import qualified Sheriff.Plugin () +import qualified TestUtils as TU +import qualified TestUtils +import Data.Text as T +import qualified Data.Text.Lazy as DTL +import qualified Data.Text.Encoding as DTE +import Data.Aeson as A +import Data.String +import GHC.Generics +import qualified Data.ByteString.Lazy as BSL +import Data.Functor.Identity (Identity) +import Data.Kind (Type) + + +-- Data Types Declarations +data A = A Int String + deriving (Generic, Show, ToJSON, FromJSON) + +data B = B {f1 :: Int, f2 :: A, f3 :: Text} + deriving (Generic, Show, ToJSON, FromJSON) + +data DBT (f :: Type -> Type) = DB {dbf1 :: f Int} + deriving (Generic) + +type DB = DBT Identity + +deriving stock instance Show (DBT Identity) +deriving anyclass instance FromJSON (DBT Identity) +deriving anyclass instance ToJSON (DBT Identity) + +data CC = C1 | C2 Text | C3 Int | C4 Bool + deriving (Generic, Show, ToJSON, FromJSON) + +data SeqIs = SeqIs Text Text + deriving (Generic, Show, ToJSON, FromJSON) + +data EnumT = X | Y | Z + deriving (Generic, Show, ToJSON, FromJSON) + +data EnumT2 = U EnumT | V + deriving (Generic, Show, ToJSON, FromJSON) + +data EnumT3 x = M | N + deriving (Generic, Show, ToJSON, FromJSON) + +type P = Text + +en :: EnumT +en = Y + +en2 :: EnumT2 +en2 = V + +en21 :: P +en21 = "Hello" + +en22 :: Text +en22 = "Hello" + +en3 :: EnumT3 () +en3 = M + +ob :: SeqIs +ob = SeqIs "fldName" "fldValue" + +-- Data objects +obA :: A +obA = A 25 "Hello ObjectA" + +obB :: B +obB = B 20 obA "Hello ObjectB" + +obC1 :: CC +obC1 = C1 + +obC2 :: CC +obC2 = C2 "Hello ObjectC" + +obC3 :: CC +obC3 = C3 30 + +obC4 :: CC +obC4 = C4 False + +str1 :: Text +str1 = encodeJSON ("Hello Str1" :: Text) + +str2 :: Text +str2 = "Hello Str2" + +str3 :: Text +str3 = T.pack $ show "Hello Str3" + +str4 :: Text +str4 = T.pack $ show (T.pack "Hello Str4") + +db1 :: DB +db1 = DB 500 + +-- Helper function +encodeJSON :: (ToJSON a) => a -> Text +encodeJSON = DTE.decodeUtf8 . BSL.toStrict . A.encode + +runKVDB :: IO () +runKVDB = print "Somehow it's runKVDB" + +logErrorV :: (ToJSON a) => a -> IO () +logErrorV = print . toJSON + +logDebugT :: Text -> Text -> IO () +logDebugT _ = print + +logDebug :: (Show b) => a -> b -> IO () +logDebug _ = print + +forkErrorLog :: (Show b) => a -> b -> IO () +forkErrorLog _ = print + +logErrorT :: Text -> Text -> IO () +logErrorT _ = print + +logError :: String -> String -> IO () +logError _ = print + +-- Test Cases Objects +obAT1 :: Text +obAT1 = T.pack $ show obA + +obAT2 :: Text +obAT2 = encodeJSON obA + +obBT1 :: Text +obBT1 = T.pack $ show obB + +obBT2 :: Text +obBT2 = encodeJSON obB + +obC1T1 :: Text +obC1T1 = T.pack $ show obC1 + +obC1T2 :: Text +obC1T2 = encodeJSON obC1 + +obC2T1 :: Text +obC2T1 = T.pack $ show obC2 + +obC2T2 :: Text +obC2T2 = encodeJSON obC2 + +obC3T1 :: Text +obC3T1 = T.pack $ show obC3 + +obC3T2 :: Text +obC3T2 = encodeJSON obC3 + +num1 :: TU.Number +num1 = TU.Number 20 + +num2 :: TU.Number +num2 = TU.Number 10 + +-- Test Case 1: Text inside logErrorT (No error should be raised by plugin) +-- Test Case 2: Text inside logErrorV (An error should be raised by plugin) +-- Test Case 3: Object inside logErrorV (No error should be generated) +-- Test Case 4: Object inside logErrorT (By default, compile time error) +-- Test Case 5: `show Object` inside logErrorT (An error should be raised by the plugin) +-- Test Case 6: `show object` inside logErrorV (An error should be raised by the plugin) +-- Test Case 7: `encode object` inside logErrorT (An error should be raised by the plugin) +-- Test Case 8: `encode object` inside logErrorV (An error should be raised by the plugin) + +-- Also, from what sources we might be passing the value to logErrorT +-- 1. Received as function argument +-- 2. Created as local bind +-- 3. Imported from another module +-- 1 and 3 are same for me, I have to go and check from where I received it, change of module does not matter + +-- Overall, when passing things inside log functions, we should never had `show` or `encode` or encodeJSON applied to them +-- If we are passing some object, then we should use logErrorV and ideally it should have instance of `ToJSON` + +-- Scenario 1: Parameter sent to logger is modified in the current function and the before modification is not text (We just need to verify that the modification is not `encode`, `encodeJSON` or `show`) +-- Scenario 2: Parameter sent to logger is modified in the current function and the before modification is text and modification is stringification (We need to throw error without recursive backtracking) +-- Scenario 3: Parameter sent to logger is modified in the current function and the before modification is text and modification is not stringification (We need to mark current function as a logger function, and it needs to be checked all calls to this function) +-- Scenario 4: Parameter sent to logger is same as some argument in the current function and that argument is of text type (mark current function as a logger function, and it needs to be checked all calls to this function, it will behave like `logErrorT`) +-- Scenario 5: Parameter sent to logger is same as some argument in the current function and that argument is of non-text type (PASS case) + +addQuotes :: Text -> Text +addQuotes t = "\"" <> t <> "\"" + +noLogFn :: String -> String -> IO () +noLogFn _ _ = pure () + +throwException :: () +throwException = () + +-- TODO: Add this validation in sheriff (Yet to decide whether to do using variable tracing or show constraint) +showConstraint :: (Show a) => a -> IO () +showConstraint x = putStrLn $ show x + +main :: IO () +main = do + print ("HI there" :: String) + let obAT1 = "Dummy" + mbStr = Just "ToBeFailed" + let b = logInfoT "tester" logger + !_ = show obC1 + !_ = show en + !_ = show en2 + !_ = (T.pack $ show "tag") + !_ = show en3 + !_ = show (en, "This is Text" :: String) + !_ = show (en, 20 :: Int) -- Should not throw error because of show is allowed on both enums and int + !_ = obAT1 <> show SubTests.ShowTest.obAT1 + !_ = fn $ pure $ show obA + !_ = fn $ pure $ ("Hello" <> show obA) + !_ = ("Hello" <> (show $ addQuotes "Testing forkErrorLog")) + !_ = T.pack $ show "Testing Multiple dollar" + !_ = ("Hello" <> (show $ addQuotes "Testing Show on text")) + !_ = (show "This to print") + !_ = show $ mbStr + + print $ show temp + print $ show temp1 + print $ (show) (temp1) + print $ show temp2 + print $ show en2 + print $ show en21 + print $ show en22 + print $ show temp3 + print $ show temp4 + print $ show $ dbf1 db1 + let !_ = show en2 <> show obA + !_ = show en <> show obB + !_ = show temp5 + !_ = show temp6 + !_ = "effective limit is: " <> T.pack (show 10) <> ", custom limit for key: " <> " is " <> T.pack (show (Just ("Hello" :: Text))) + !_ = "merchantId: " <> ", error: " <> T.pack (show (["A", "B"] :: [Text])) + pure () + where + logErrorT = SubTests.ShowTest.logErrorT + +(^*^) :: Num a => a -> a -> a +(^*^) a b = a * b + +logInfoT :: String -> (forall a b. (IsString b, Show a) => String -> a -> b) -> String +logInfoT x _ = x + +logger :: forall a b. (IsString b, Show a) => String -> a -> b +logger _ = fromString . show + +temp :: [Text] +temp = [] + +temp1 :: Maybe Text +temp1 = Nothing + +temp2 :: (Text, Text) +temp2 = ("A", "B") + +temp3 :: (Text, Int) +temp3 = ("A", 10) + +temp4 :: (Int, Int) +temp4 = (20, 10) + +temp5 :: [EnumT] +temp5 = [] + +temp6 :: [EnumT2] +temp6 = [] + +fn :: IO String -> IO () +fn x = do + !_ <- x + pure () + +-- myFun :: A -> IO Text +-- myFun ob = do +-- let (A x y) = ob +-- res1 = A x "Modified" +-- res2 = encodeJSON \ No newline at end of file diff --git a/sheriff/test/Test1.hs b/sheriff/test/Test1.hs index 693d409..b000ec5 100644 --- a/sheriff/test/Test1.hs +++ b/sheriff/test/Test1.hs @@ -1,16 +1,18 @@ -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE DuplicateRecordFields #-} module Test1 where import qualified Sheriff.Plugin () +import qualified TestUtils as TU +import qualified TestUtils import Data.Text as T import qualified Data.Text.Lazy as DTL import qualified Data.Text.Encoding as DTE @@ -38,8 +40,6 @@ deriving stock instance Show (DBT Identity) deriving anyclass instance FromJSON (DBT Identity) deriving anyclass instance ToJSON (DBT Identity) -type P = Text - data CC = C1 | C2 Text | C3 Int | C4 Bool deriving (Generic, Show, ToJSON, FromJSON) @@ -55,21 +55,23 @@ data EnumT2 = U EnumT | V data EnumT3 x = M | N deriving (Generic, Show, ToJSON, FromJSON) +type P = Text + en :: EnumT en = Y en2 :: EnumT2 en2 = V -en3 :: EnumT3 () -en3 = M - en21 :: P en21 = "Hello" en22 :: Text en22 = "Hello" +en3 :: EnumT3 () +en3 = M + ob :: SeqIs ob = SeqIs "fldName" "fldValue" @@ -102,21 +104,30 @@ str3 :: Text str3 = T.pack $ show "Hello Str3" str4 :: Text -str4 = T.pack $ show (T.pack "Hello Str3") +str4 = T.pack $ show (T.pack "Hello Str4") db1 :: DB -db1 = DB 20 +db1 = DB 500 -- Helper function encodeJSON :: (ToJSON a) => a -> Text encodeJSON = DTE.decodeUtf8 . BSL.toStrict . A.encode +runKVDB :: IO () +runKVDB = print "Somehow it's runKVDB" + logErrorV :: (ToJSON a) => a -> IO () logErrorV = print . toJSON logDebugT :: Text -> Text -> IO () logDebugT _ = print +logDebug :: (Show b) => a -> b -> IO () +logDebug _ = print + +forkErrorLog :: (Show b) => a -> b -> IO () +forkErrorLog _ = print + logErrorT :: Text -> Text -> IO () logErrorT _ = print @@ -154,6 +165,12 @@ obC3T1 = T.pack $ show obC3 obC3T2 :: Text obC3T2 = encodeJSON obC3 +num1 :: TU.Number +num1 = TU.Number 20 + +num2 :: TU.Number +num2 = TU.Number 10 + -- Test Case 1: Text inside logErrorT (No error should be raised by plugin) -- Test Case 2: Text inside logErrorV (An error should be raised by plugin) -- Test Case 3: Object inside logErrorV (No error should be generated) @@ -178,6 +195,15 @@ obC3T2 = encodeJSON obC3 -- Scenario 4: Parameter sent to logger is same as some argument in the current function and that argument is of text type (mark current function as a logger function, and it needs to be checked all calls to this function, it will behave like `logErrorT`) -- Scenario 5: Parameter sent to logger is same as some argument in the current function and that argument is of non-text type (PASS case) +addQuotes :: Text -> Text +addQuotes t = "\"" <> t <> "\"" + +noLogFn :: String -> String -> IO () +noLogFn _ _ = pure () + +throwException :: () +throwException = () + main :: IO () main = do putStrLn "Test suite not yet implemented." @@ -197,6 +223,19 @@ main = do logErrorT "tag" $ encodeJSON (en, 20 :: Int) -- Should throw error because of encodeJSON logError "tag" $ obAT1 <> show Test1.obAT1 fn $ logError "tag2" $ show obA + fn $ logError "tag2" $ ("Hello" <> show obA) + forkErrorLog "tag2" $ ("Hello" <> (show $ addQuotes "Testing forkErrorLog")) + forkErrorLog "tag2" $ T.pack $ show "Testing Multiple dollar" + noLogFn "tag2" $ ("Hello" <> (show $ addQuotes "Testing Show on text")) + + logDebug ("Some Tag" :: Text) (show "This to print") + + -- Test for Qualified Function Names Rules + print $ TU.throwException "Hello" -- should throw error + print throwException -- should NOT throw error + print $ TU.throwExceptionV2 "Hello" -- should throw error as part of combined rule "Hello" + print $ TU.throwExceptionV3 "Hello" + print $ TU.throwExceptionV4 "Hello" -- should throw error as part of combined rule "Hello" print $ show temp print $ show temp1 @@ -212,6 +251,22 @@ main = do logError "tag" $ show en <> show obB logError "tag" $ show temp5 logError "tag" $ show temp6 + + let (TU.Number sRes) = num1 `TU.subtractNumber` num2 + (TU.Number aRes) = TU.addNumber num1 num2 + (TU.Number mRes) = (TU.*?) Nothing num1 num2 + (TU.Number n1) = TU.fstArg num1 num2 + (TU.Number n2) = TU.sndArg num1 num2 + + print sRes + print aRes + print mRes + print n1 + print n2 + print (n1 * n2) + print ((*) 10 20) + + runKVDB -- Should be error logDebugT "validateMandate" $ "effective limit is: " <> T.pack (show 10) <> ", custom limit for key: " <> " is " <> T.pack (show (Just ("Hello" :: Text))) @@ -220,6 +275,9 @@ main = do where logErrorT = Test1.logErrorT +(^*^) :: Num a => a -> a -> a +(^*^) a b = a * b + logInfoT :: String -> (forall a b. (IsString b, Show a) => String -> a -> b) -> String logInfoT x _ = x diff --git a/sheriff/test/TestUtils.hs b/sheriff/test/TestUtils.hs new file mode 100644 index 0000000..2b59764 --- /dev/null +++ b/sheriff/test/TestUtils.hs @@ -0,0 +1,44 @@ +module TestUtils where + +import Data.Text + +newtype Number = Number Int + +throwException :: Text -> () +throwException _ = () + +throwExceptionV1 :: Text -> () +throwExceptionV1 = throwException + +throwExceptionV2 :: Text -> () +throwExceptionV2 = throwExceptionV1 + +throwExceptionV3 :: Text -> () +throwExceptionV3 = throwExceptionV2 + +throwExceptionV4 :: Text -> () +throwExceptionV4 = throwException + +addNumber :: Number -> Number -> Number +addNumber (Number a) (Number b) = Number (a + b) + +subtractNumber :: Number -> Number -> Number +subtractNumber (Number a) (Number b) = Number (a - b) + +multiplyNumber :: Number -> Number -> Number +multiplyNumber (Number a) (Number b) = Number (a * b) + +(+?) :: Number -> Number -> Number +(+?) = addNumber + +(-?) :: Number -> Number -> Number +(-?) = subtractNumber + +(*?) :: Maybe (Either (Maybe Int) (Maybe Int)) -> Number -> Number -> Number +(*?) _ = multiplyNumber + +fstArg :: a -> a -> a +fstArg a1 a2 = a1 + +sndArg :: a -> a -> a +sndArg a1 a2 = a2 \ No newline at end of file diff --git a/sockets.py b/sockets.py new file mode 100644 index 0000000..47d068c --- /dev/null +++ b/sockets.py @@ -0,0 +1,64 @@ +import asyncio +import os +import concurrent.futures +import json +import asyncio +import json +import websockets +from aiohttp import web +# nix-shell -p python311Packages.websockets +data = dict() + +async def handler(websocket, path): + try: + async for message in websocket: + try: + obj = json.loads(message) + if data.get(path) == None: + data[path] = dict() + if data[path].get(obj.get("key")) == None: + data[path][obj.get("key")] = [] + data[path][obj.get("key")].append(obj) + except Exception as e: + print(e) + except websockets.exceptions.ConnectionClosed as e: + print(e,path) + except Exception as e: + print(e) + +def process_fdep_output(k,v): + os.makedirs(k[1:].replace(".json",""), exist_ok=True) + with open(k[1:],'w') as f: + json.dump(v,f,indent=4) + +async def drain_data(request): + global data + with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor: + future_to_file = {executor.submit(process_fdep_output, k,v): (k,v) for (k,v) in data.items()} + for future in concurrent.futures.as_completed(future_to_file): + pass + print(json.dumps(list(data.keys()))) + exit() + +async def start_websocket_server(): + async with websockets.serve(handler, "localhost", 8000,ping_interval=None,ping_timeout=None,close_timeout=None,max_queue=1000): + print("WebSocket server started on ws://localhost:8000") + await asyncio.Future() + +async def start_http_server(): + app = web.Application() + app.router.add_get('/drain', drain_data) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, 'localhost', 8080) + await site.start() + print("HTTP server started on http://localhost:8080") + +async def main(): + await asyncio.gather( + start_websocket_server(), + start_http_server() + ) + +if __name__ == "__main__": + asyncio.run(main())