X-Git-Url: https://git.lukelau.me/?a=blobdiff_plain;f=src%2FLanguage%2FHaskell%2FLSP%2FTest%2FSession.hs;h=3e9e688bc221f563b8220b63e925cb71176a8668;hb=6f3106ce987b2a3794ee7ab444c8bcc204a7b3d2;hp=3426bcce47c66e105ddb235b13d308a92dcafe45;hpb=20750dca8684bcb05a7c91e8654257ad36e57ebe;p=lsp-test.git diff --git a/src/Language/Haskell/LSP/Test/Session.hs b/src/Language/Haskell/LSP/Test/Session.hs index 3426bcc..3e9e688 100644 --- a/src/Language/Haskell/LSP/Test/Session.hs +++ b/src/Language/Haskell/LSP/Test/Session.hs @@ -1,4 +1,5 @@ {-# LANGUAGE CPP #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE FlexibleInstances #-} @@ -23,6 +24,8 @@ module Language.Haskell.LSP.Test.Session , sendMessage , updateState , withTimeout + , getCurTimeoutId + , bumpTimeoutId , logMsg , LogMsgType(..) ) @@ -57,10 +60,10 @@ import qualified Data.Text.IO as T import qualified Data.HashMap.Strict as HashMap import Data.Maybe import Data.Function -import Language.Haskell.LSP.Messages import Language.Haskell.LSP.Types.Capabilities import Language.Haskell.LSP.Types -import Language.Haskell.LSP.Types.Lens hiding (error) +import Language.Haskell.LSP.Types.Lens +import qualified Language.Haskell.LSP.Types.Lens as LSP import Language.Haskell.LSP.VFS import Language.Haskell.LSP.Test.Compat import Language.Haskell.LSP.Test.Decoding @@ -69,6 +72,9 @@ import System.Console.ANSI import System.Directory import System.IO import System.Process (ProcessHandle()) +#ifndef mingw32_HOST_OS +import System.Process (waitForProcess) +#endif import System.Timeout -- | A session representing one instance of launching and connecting to a server. @@ -91,13 +97,19 @@ instance MonadFail Session where -- | Stuff you can configure for a 'Session'. data SessionConfig = SessionConfig { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60. - , logStdErr :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False. - , logMessages :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False. + , logStdErr :: Bool + -- ^ Redirect the server's stderr to this stdout, defaults to False. + -- Can be overriden with @LSP_TEST_LOG_STDERR@. + , logMessages :: Bool + -- ^ Trace the messages sent and received to stdout, defaults to False. + -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@. , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True. , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing. - -- ^ Whether or not to ignore 'ShowMessageNotification' and 'LogMessageNotification', defaults to False. - -- @since 0.9.0.0 , ignoreLogNotifications :: Bool + -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and + -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False. + -- + -- @since 0.9.0.0 } -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'. @@ -115,7 +127,9 @@ data SessionContext = SessionContext { serverIn :: Handle , rootDir :: FilePath - , messageChan :: Chan SessionMessage + , messageChan :: Chan SessionMessage -- ^ Where all messages come through + -- Keep curTimeoutId in SessionContext, as its tied to messageChan + , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on , requestMap :: MVar RequestMap , initRsp :: MVar InitializeResponse , config :: SessionConfig @@ -133,16 +147,29 @@ instance HasReader SessionContext Session where instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where ask = lift $ lift Reader.ask +getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int +getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar + +-- Pass this the timeoutid you *were* waiting on +bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m () +bumpTimeoutId prev = do + v <- asks curTimeoutId + -- when updating the curtimeoutid, account for the fact that something else + -- might have bumped the timeoutid in the meantime + liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1))) + data SessionState = SessionState { - curReqId :: LspId + curReqId :: Int , vfs :: VFS , curDiagnostics :: Map.Map NormalizedUri [Diagnostic] - , curTimeoutId :: Int , overridingTimeout :: Bool -- ^ The last received message from the server. -- Used for providing exception information , lastReceivedMessage :: Maybe FromServerMessage + , curDynCaps :: Map.Map T.Text SomeRegistration + -- ^ The capabilities that the server has dynamically registered with us so + -- far } class Monad m => HasState s m where @@ -160,15 +187,19 @@ instance HasState SessionState Session where get = Session (lift State.get) put = Session . lift . State.put -instance Monad m => HasState s (ConduitM a b (StateT s m)) +instance Monad m => HasState s (StateT s m) where + get = State.get + put = State.put + +instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m) where - get = lift State.get - put = lift . State.put + get = lift get + put = lift . put -instance Monad m => HasState s (ConduitParser a (StateT s m)) +instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m) where - get = lift State.get - put = lift . State.put + get = lift get + put = lift . put runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState) runSession context state (Session session) = runReaderT (runStateT conduit state) context @@ -188,13 +219,13 @@ runSession context state (Session session) = runReaderT (runStateT conduit state yield msg chanSource - isLogNotification (ServerMessage (NotShowMessage _)) = True - isLogNotification (ServerMessage (NotLogMessage _)) = True + isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True + isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True isLogNotification _ = False watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) () watchdog = Conduit.awaitForever $ \msg -> do - curId <- curTimeoutId <$> get + curId <- getCurTimeoutId case msg of ServerMessage sMsg -> yield sMsg TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout @@ -223,26 +254,35 @@ runSessionWithHandles serverIn serverOut serverProc serverHandler config caps ro reqMap <- newMVar newRequestMap messageChan <- newChan + timeoutIdVar <- newMVar 0 initRsp <- newEmptyMVar mainThreadId <- myThreadId - let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps - initState vfs = SessionState (IdInt 0) vfs - mempty 0 False Nothing + let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps + initState vfs = SessionState 0 vfs mempty False Nothing mempty runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses errorHandler = throwTo mainThreadId :: SessionException -> IO () serverListenerLauncher = forkIO $ catch (serverHandler serverOut context) errorHandler server = (Just serverIn, Just serverOut, Nothing, serverProc) - serverAndListenerFinalizer tid = - finally (timeout (messageTimeout config * 1000000) - (runSession' exitServer)) - (cleanupProcess server >> killThread tid) + msgTimeoutMs = messageTimeout config * 10^6 + serverAndListenerFinalizer tid = do + finally (timeout msgTimeoutMs (runSession' exitServer)) $ do + -- Make sure to kill the listener first, before closing + -- handles etc via cleanupProcess + killThread tid + -- Give the server some time to exit cleanly + -- It makes the server hangs in windows so we have to avoid it +#ifndef mingw32_HOST_OS + timeout msgTimeoutMs (waitForProcess serverProc) +#endif + cleanupProcess server - (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer - (const $ runSession' session) + (result, _) <- bracket serverListenerLauncher + serverAndListenerFinalizer + (const $ initVFS $ \vfs -> runSession context (initState vfs) session) return result updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) () @@ -252,24 +292,40 @@ updateStateC = awaitForever $ \msg -> do updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m () -updateState (NotPublishDiagnostics n) = do + +-- Keep track of dynamic capability registration +updateState (FromServerMess SClientRegisterCapability req) = do + let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations + modify $ \s -> + s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) } + +updateState (FromServerMess SClientUnregisterCapability req) = do + let List unRegs = (^. LSP.id) <$> req ^. params . unregistrations + modify $ \s -> + let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs + in s { curDynCaps = newCurDynCaps } + +updateState (FromServerMess STextDocumentPublishDiagnostics n) = do let List diags = n ^. params . diagnostics doc = n ^. params . uri - modify (\s -> + modify $ \s -> let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s) - in s { curDiagnostics = newDiags }) + in s { curDiagnostics = newDiags } -updateState (ReqApplyWorkspaceEdit r) = do +updateState (FromServerMess SWorkspaceApplyEdit r) = do + -- First, prefer the versioned documentChanges field allChangeParams <- case r ^. params . edit . documentChanges of Just (List cs) -> do mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs return $ map getParams cs + -- Then fall back to the changes field Nothing -> case r ^. params . edit . changes of Just cs -> do mapM_ checkIfNeedsOpened (HashMap.keys cs) - return $ concatMap (uncurry getChangeParams) (HashMap.toList cs) - Nothing -> error "No changes!" + concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs) + Nothing -> + error "WorkspaceEdit contains neither documentChanges nor changes!" modifyM $ \s -> do newVFS <- liftIO $ changeFromServerVFS (vfs s) r @@ -279,7 +335,7 @@ updateState (ReqApplyWorkspaceEdit r) = do mergedParams = map mergeParams groupedParams -- TODO: Don't do this when replaying a session - forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange) + forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange) -- Update VFS to new document versions let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams @@ -302,7 +358,7 @@ updateState (ReqApplyWorkspaceEdit r) = do let fp = fromJust $ uriToFilePath uri contents <- liftIO $ T.readFile fp let item = TextDocumentItem (filePathToUri fp) "" 0 contents - msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item) + msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item) liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg) modifyM $ \s -> do @@ -313,11 +369,20 @@ updateState (ReqApplyWorkspaceEdit r) = do let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits in DidChangeTextDocumentParams docId (List changeEvents) - textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..] + -- For a uri returns an infinite list of versions [n,n+1,n+2,...] + -- where n is the current version + textDocumentVersions uri = do + m <- vfsMap . vfs <$> get + let curVer = fromMaybe 0 $ + _lsp_version <$> m Map.!? (toNormalizedUri uri) + pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..] - textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits + textDocumentEdits uri edits = do + vers <- textDocumentVersions uri + pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits - getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits)) + getChangeParams uri (List edits) = + map <$> pure getParams <*> textDocumentEdits uri (reverse edits) mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params)) @@ -330,21 +395,20 @@ sendMessage msg = do logMsg LogClient msg liftIO $ B.hPut h (addHeader $ encode msg) --- | Execute a block f that will throw a 'Timeout' exception +-- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception -- after duration seconds. This will override the global timeout -- for waiting for messages to arrive defined in 'SessionConfig'. withTimeout :: Int -> Session a -> Session a withTimeout duration f = do chan <- asks messageChan - timeoutId <- curTimeoutId <$> get + timeoutId <- getCurTimeoutId modify $ \s -> s { overridingTimeout = True } liftIO $ forkIO $ do threadDelay (duration * 1000000) writeChan chan (TimeoutMessage timeoutId) res <- f - modify $ \s -> s { curTimeoutId = timeoutId + 1, - overridingTimeout = False - } + bumpTimeoutId timeoutId + modify $ \s -> s { overridingTimeout = False } return res data LogMsgType = LogServer | LogClient @@ -369,5 +433,3 @@ logMsg t msg = do | otherwise = Cyan showPretty = B.unpack . encodePretty - -