From af401b6d0439751d73ea230a219f37eb57286c90 Mon Sep 17 00:00:00 2001 From: Luke Lau Date: Tue, 5 May 2020 18:20:21 +0100 Subject: [PATCH] Fix curtimeoutid being reset in the server exit handler --- src/Language/Haskell/LSP/Test/Parsing.hs | 5 +-- src/Language/Haskell/LSP/Test/Session.hs | 56 ++++++++++++++++-------- 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/src/Language/Haskell/LSP/Test/Parsing.hs b/src/Language/Haskell/LSP/Test/Parsing.hs index 5ce9b52..12ef1a6 100644 --- a/src/Language/Haskell/LSP/Test/Parsing.hs +++ b/src/Language/Haskell/LSP/Test/Parsing.hs @@ -75,7 +75,7 @@ satisfyMaybe :: (FromServerMessage -> Maybe a) -> Session a satisfyMaybe pred = do skipTimeout <- overridingTimeout <$> get - timeoutId <- curTimeoutId <$> get + timeoutId <- getCurTimeoutId unless skipTimeout $ do chan <- asks messageChan timeout <- asks (messageTimeout . config) @@ -85,8 +85,7 @@ satisfyMaybe pred = do x <- Session await - unless skipTimeout $ - modify $ \s -> s { curTimeoutId = timeoutId + 1 } + unless skipTimeout (bumpTimeoutId timeoutId) modify $ \s -> s { lastReceivedMessage = Just x } diff --git a/src/Language/Haskell/LSP/Test/Session.hs b/src/Language/Haskell/LSP/Test/Session.hs index c33d801..ddd07a5 100644 --- a/src/Language/Haskell/LSP/Test/Session.hs +++ b/src/Language/Haskell/LSP/Test/Session.hs @@ -23,6 +23,8 @@ module Language.Haskell.LSP.Test.Session , sendMessage , updateState , withTimeout + , getCurTimeoutId + , bumpTimeoutId , logMsg , LogMsgType(..) ) @@ -121,7 +123,9 @@ data SessionContext = SessionContext { serverIn :: Handle , rootDir :: FilePath - , messageChan :: Chan SessionMessage + , messageChan :: Chan SessionMessage -- ^ Where all messages come through + -- Keep curTimeoutId in SessionContext, as its tied to messageChan + , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on , requestMap :: MVar RequestMap , initRsp :: MVar InitializeResponse , config :: SessionConfig @@ -139,12 +143,22 @@ instance HasReader SessionContext Session where instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where ask = lift $ lift Reader.ask +getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int +getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar + +-- Pass this the timeoutid you *were* waiting on +bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m () +bumpTimeoutId prev = do + v <- asks curTimeoutId + -- when updating the curtimeoutid, account for the fact that something else + -- might have bumped the timeoutid in the meantime + liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1))) + data SessionState = SessionState { curReqId :: LspId , vfs :: VFS , curDiagnostics :: Map.Map NormalizedUri [Diagnostic] - , curTimeoutId :: Int , overridingTimeout :: Bool -- ^ The last received message from the server. -- Used for providing exception information @@ -166,15 +180,19 @@ instance HasState SessionState Session where get = Session (lift State.get) put = Session . lift . State.put -instance Monad m => HasState s (ConduitM a b (StateT s m)) +instance Monad m => HasState s (StateT s m) where + get = State.get + put = State.put + +instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m) where - get = lift State.get - put = lift . State.put + get = lift get + put = lift . put -instance Monad m => HasState s (ConduitParser a (StateT s m)) +instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m) where - get = lift State.get - put = lift . State.put + get = lift get + put = lift . put runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState) runSession context state (Session session) = runReaderT (runStateT conduit state) context @@ -200,7 +218,7 @@ runSession context state (Session session) = runReaderT (runStateT conduit state watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) () watchdog = Conduit.awaitForever $ \msg -> do - curId <- curTimeoutId <$> get + curId <- getCurTimeoutId case msg of ServerMessage sMsg -> yield sMsg TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout @@ -229,25 +247,26 @@ runSessionWithHandles serverIn serverOut serverProc serverHandler config caps ro reqMap <- newMVar newRequestMap messageChan <- newChan + timeoutIdVar <- newMVar 0 initRsp <- newEmptyMVar mainThreadId <- myThreadId - let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps - initState vfs = SessionState (IdInt 0) vfs - mempty 0 False Nothing + let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps + initState vfs = SessionState (IdInt 0) vfs mempty False Nothing runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses errorHandler = throwTo mainThreadId :: SessionException -> IO () serverListenerLauncher = forkIO $ catch (serverHandler serverOut context) errorHandler server = (Just serverIn, Just serverOut, Nothing, serverProc) - serverAndListenerFinalizer tid = - finally (timeout (messageTimeout config * 1000000) + serverAndListenerFinalizer tid = do + finally (timeout (messageTimeout config * 1^6) (runSession' exitServer)) (cleanupProcess server >> killThread tid) - (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer + (result, _) <- bracket serverListenerLauncher + serverAndListenerFinalizer (const $ runSession' session) return result @@ -342,15 +361,14 @@ sendMessage msg = do withTimeout :: Int -> Session a -> Session a withTimeout duration f = do chan <- asks messageChan - timeoutId <- curTimeoutId <$> get + timeoutId <- getCurTimeoutId modify $ \s -> s { overridingTimeout = True } liftIO $ forkIO $ do threadDelay (duration * 1000000) writeChan chan (TimeoutMessage timeoutId) res <- f - modify $ \s -> s { curTimeoutId = timeoutId + 1, - overridingTimeout = False - } + bumpTimeoutId timeoutId + modify $ \s -> s { overridingTimeout = False } return res data LogMsgType = LogServer | LogClient -- 2.30.2