X-Git-Url: https://git.lukelau.me/?a=blobdiff_plain;f=src%2FLanguage%2FHaskell%2FLSP%2FTest%2FSession.hs;h=a82651af11b5617203417505db7bb1268cc3478d;hb=e5da0e9511c679626dbe40a99e8c0de0c968dddf;hp=21c008643e95869badbe46daa67a1fe9bda9209c;hpb=3a38253a1fcd83c83b05fbfbf132d1ead842b0a7;p=lsp-test.git diff --git a/src/Language/Haskell/LSP/Test/Session.hs b/src/Language/Haskell/LSP/Test/Session.hs index 21c0086..a82651a 100644 --- a/src/Language/Haskell/LSP/Test/Session.hs +++ b/src/Language/Haskell/LSP/Test/Session.hs @@ -23,6 +23,8 @@ module Language.Haskell.LSP.Test.Session , sendMessage , updateState , withTimeout + , getCurTimeoutId + , bumpTimeoutId , logMsg , LogMsgType(..) ) @@ -60,16 +62,20 @@ import Data.Function import Language.Haskell.LSP.Messages import Language.Haskell.LSP.Types.Capabilities import Language.Haskell.LSP.Types -import Language.Haskell.LSP.Types.Lens hiding (error) +import Language.Haskell.LSP.Types.Lens +import qualified Language.Haskell.LSP.Types.Lens as LSP import Language.Haskell.LSP.VFS import Language.Haskell.LSP.Test.Compat import Language.Haskell.LSP.Test.Decoding import Language.Haskell.LSP.Test.Exceptions import System.Console.ANSI import System.Directory +import System.FSNotify (watchTree, eventPath, withManager, WatchManager) +import qualified System.FSNotify as FS import System.IO import System.Process (ProcessHandle()) import System.Timeout +import System.FilePath.Glob (match, commonDirectory, compile) -- | A session representing one instance of launching and connecting to a server. -- @@ -99,9 +105,11 @@ data SessionConfig = SessionConfig -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@. , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True. , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing. - -- ^ Whether or not to ignore 'ShowMessageNotification' and 'LogMessageNotification', defaults to False. - -- @since 0.9.0.0 , ignoreLogNotifications :: Bool + -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and + -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False. + -- + -- @since 0.9.0.0 } -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'. @@ -119,11 +127,14 @@ data SessionContext = SessionContext { serverIn :: Handle , rootDir :: FilePath - , messageChan :: Chan SessionMessage + , messageChan :: Chan SessionMessage -- ^ Where all messages come through + -- Keep curTimeoutId in SessionContext, as its tied to messageChan + , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on , requestMap :: MVar RequestMap , initRsp :: MVar InitializeResponse , config :: SessionConfig , sessionCapabilities :: ClientCapabilities + , watchManager :: WatchManager } class Monad m => HasReader r m where @@ -134,19 +145,33 @@ class Monad m => HasReader r m where instance HasReader SessionContext Session where ask = Session (lift $ lift Reader.ask) -instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where +instance Monad m => HasReader r (ConduitT a b (StateT s (ReaderT r m))) where ask = lift $ lift Reader.ask +getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int +getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar + +-- Pass this the timeoutid you *were* waiting on +bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m () +bumpTimeoutId prev = do + v <- asks curTimeoutId + -- when updating the curtimeoutid, account for the fact that something else + -- might have bumped the timeoutid in the meantime + liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1))) + data SessionState = SessionState { curReqId :: LspId , vfs :: VFS , curDiagnostics :: Map.Map NormalizedUri [Diagnostic] - , curTimeoutId :: Int , overridingTimeout :: Bool -- ^ The last received message from the server. -- Used for providing exception information , lastReceivedMessage :: Maybe FromServerMessage + , curDynCaps :: Map.Map T.Text Registration + -- ^ The capabilities that the server has dynamically registered with us so + -- far + , unwatchers :: Map.Map T.Text [IO ()] } class Monad m => HasState s m where @@ -164,15 +189,19 @@ instance HasState SessionState Session where get = Session (lift State.get) put = Session . lift . State.put -instance Monad m => HasState s (ConduitM a b (StateT s m)) +instance Monad m => HasState s (StateT s m) where + get = State.get + put = State.put + +instance (Monad m, (HasState s m)) => HasState s (ConduitT a b m) where - get = lift State.get - put = lift . State.put + get = lift get + put = lift . put -instance Monad m => HasState s (ConduitParser a (StateT s m)) +instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m) where - get = lift State.get - put = lift . State.put + get = lift get + put = lift . put runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState) runSession context state (Session session) = runReaderT (runStateT conduit state) context @@ -196,9 +225,9 @@ runSession context state (Session session) = runReaderT (runStateT conduit state isLogNotification (ServerMessage (NotLogMessage _)) = True isLogNotification _ = False - watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) () + watchdog :: ConduitT SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) () watchdog = Conduit.awaitForever $ \msg -> do - curId <- curTimeoutId <$> get + curId <- getCurTimeoutId case msg of ServerMessage sMsg -> yield sMsg TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout @@ -227,41 +256,69 @@ runSessionWithHandles serverIn serverOut serverProc serverHandler config caps ro reqMap <- newMVar newRequestMap messageChan <- newChan + timeoutIdVar <- newMVar 0 initRsp <- newEmptyMVar mainThreadId <- myThreadId - let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps - initState vfs = SessionState (IdInt 0) vfs - mempty 0 False Nothing + withManager $ \watchManager -> do + let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps watchManager + initState vfs = SessionState (IdInt 0) vfs mempty False Nothing mempty mempty + -- Interesting note: turning on TypeFamilies causes type inference to + -- infer the type runSession' :: Session () -> IO ((), SessionState) + -- instead of runSession' :: Session a -> IO (a , SessionState) runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses errorHandler = throwTo mainThreadId :: SessionException -> IO () serverListenerLauncher = forkIO $ catch (serverHandler serverOut context) errorHandler server = (Just serverIn, Just serverOut, Nothing, serverProc) - serverAndListenerFinalizer tid = - finally (timeout (messageTimeout config * 1000000) + serverAndListenerFinalizer tid = do + finally (timeout (messageTimeout config * 1^6) (runSession' exitServer)) - (cleanupProcess server >> killThread tid) + -- Make sure to kill the listener first, before closing + -- handles etc via cleanupProcess + (killThread tid >> cleanupProcess server) - (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer + (result, _) <- bracket serverListenerLauncher + serverAndListenerFinalizer (const $ runSession' session) return result -updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) () +updateStateC :: ConduitT FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) () updateStateC = awaitForever $ \msg -> do updateState msg yield msg updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m () + +-- Keep track of dynamic capability registration +updateState (ReqRegisterCapability req) = do + let List newRegs = (\r -> (r ^. LSP.id, r)) <$> req ^. params . registrations + modify $ \s -> + s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) } + + -- Process the new registrations + forM_ newRegs $ \(regId, reg) -> do + when (reg ^. method == WorkspaceDidChangeWatchedFiles) $ do + processFileWatchRegistration regId reg + +updateState (ReqUnregisterCapability req) = do + let List unRegs = (^. LSP.id) <$> req ^. params . unregistrations + modify $ \s -> + let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs + in s { curDynCaps = newCurDynCaps } + + -- Process the unregistrations + processFileWatchUnregistrations unRegs + updateState (NotPublishDiagnostics n) = do let List diags = n ^. params . diagnostics doc = n ^. params . uri - modify (\s -> + modify $ \s -> let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s) - in s { curDiagnostics = newDiags }) + in s { curDiagnostics = newDiags } updateState (ReqApplyWorkspaceEdit r) = do @@ -307,6 +364,7 @@ updateState (ReqApplyWorkspaceEdit r) = do contents <- liftIO $ T.readFile fp let item = TextDocumentItem (filePathToUri fp) "" 0 contents msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item) + -- TODO: use 'sendMessage'? liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg) modifyM $ \s -> do @@ -334,23 +392,23 @@ sendMessage msg = do logMsg LogClient msg liftIO $ B.hPut h (addHeader $ encode msg) --- | Execute a block f that will throw a 'Timeout' exception +-- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception -- after duration seconds. This will override the global timeout -- for waiting for messages to arrive defined in 'SessionConfig'. withTimeout :: Int -> Session a -> Session a withTimeout duration f = do chan <- asks messageChan - timeoutId <- curTimeoutId <$> get + timeoutId <- getCurTimeoutId modify $ \s -> s { overridingTimeout = True } liftIO $ forkIO $ do threadDelay (duration * 1000000) writeChan chan (TimeoutMessage timeoutId) res <- f - modify $ \s -> s { curTimeoutId = timeoutId + 1, - overridingTimeout = False - } + bumpTimeoutId timeoutId + modify $ \s -> s { overridingTimeout = False } return res +-- TODO: add a shouldTimeout helper. need to add exceptions within Session data LogMsgType = LogServer | LogClient deriving Eq @@ -374,4 +432,51 @@ logMsg t msg = do showPretty = B.unpack . encodePretty - +-- File watching + +processFileWatchRegistration :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) + => T.Text -> Registration -> m () +processFileWatchRegistration regId reg = do + mgr <- asks watchManager + let mOpts = do + regOpts <- reg ^. registerOptions + case fromJSON regOpts of + Error _ -> Nothing + Success x -> Just x + case mOpts of + Nothing -> pure () + Just (DidChangeWatchedFilesRegistrationOptions (List ws)) -> + forM_ ws $ \(FileSystemWatcher pat' watchKind) -> do + pat <- liftIO $ canonicalizePath pat' + let glob = compile pat + -- the root-most dir before any globbing stuff happens + dir = fst $ commonDirectory glob + pred = match glob . eventPath + -- If no watchKind specified, spec defaults to all true + WatchKind wkC wkM wkD = fromMaybe (WatchKind True True True) watchKind + handle <- asks serverIn + unwatch <- liftIO $ watchTree mgr dir pred $ \event -> do + let fe = FileEvent (filePathToUri (eventPath event)) typ + typ = case event of + FS.Added _ _ _ -> FcCreated + FS.Modified _ _ _ -> FcChanged + FS.Removed _ _ _ -> FcDeleted + -- This is a bit of a guess + FS.Unknown _ _ _ -> FcChanged + matches = case typ of + FcCreated -> wkC + FcChanged -> wkM + FcDeleted -> wkD + params = DidChangeWatchedFilesParams (List [fe]) + msg = fmClientDidChangeWatchedFilesNotification params + liftIO $ when matches $ B.hPut handle (addHeader $ encode msg) + modify $ \s -> + s { unwatchers = Map.insertWith (++) regId [unwatch] (unwatchers s) } + +processFileWatchUnregistrations :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) + => [T.Text] -> m () +processFileWatchUnregistrations regIds = + forM_ regIds $ \regId -> modifyM $ \s -> do + let fs = fromMaybe [] (Map.lookup regId (unwatchers s)) + liftIO $ sequence fs + return $ s { unwatchers = Map.delete regId (unwatchers s) }