Fix formatting requests sending misversioned didChange notifications
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
index a3ba35b3a1a46f723d4ca0fce59775991680c862..4b1793f28312986437bae1c8e4e8ff32c1cead5d 100644 (file)
@@ -1,12 +1,13 @@
 {-# LANGUAGE CPP               #-}
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE CPP               #-}
 {-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE RankNTypes #-}
 
 module Language.Haskell.LSP.Test.Session
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE RankNTypes #-}
 
 module Language.Haskell.LSP.Test.Session
-  ( Session
+  ( Session(..)
   , SessionConfig(..)
   , defaultConfig
   , SessionMessage(..)
   , SessionConfig(..)
   , defaultConfig
   , SessionMessage(..)
@@ -22,25 +23,28 @@ module Language.Haskell.LSP.Test.Session
   , sendMessage
   , updateState
   , withTimeout
   , sendMessage
   , updateState
   , withTimeout
+  , getCurTimeoutId
+  , bumpTimeoutId
   , logMsg
   , LogMsgType(..)
   )
 
 where
 
   , logMsg
   , LogMsgType(..)
   )
 
 where
 
+import Control.Applicative
 import Control.Concurrent hiding (yield)
 import Control.Exception
 import Control.Lens hiding (List)
 import Control.Monad
 import Control.Monad.IO.Class
 import Control.Monad.Except
 import Control.Concurrent hiding (yield)
 import Control.Exception
 import Control.Lens hiding (List)
 import Control.Monad
 import Control.Monad.IO.Class
 import Control.Monad.Except
-#if __GLASGOW_HASKELL__ >= 806
+#if __GLASGOW_HASKELL__ == 806
 import Control.Monad.Fail
 #endif
 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
 import qualified Control.Monad.Trans.Reader as Reader (ask)
 import Control.Monad.Trans.State (StateT, runStateT)
 import Control.Monad.Fail
 #endif
 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
 import qualified Control.Monad.Trans.Reader as Reader (ask)
 import Control.Monad.Trans.State (StateT, runStateT)
-import qualified Control.Monad.Trans.State as State (get, put)
+import qualified Control.Monad.Trans.State as State
 import qualified Data.ByteString.Lazy.Char8 as B
 import Data.Aeson
 import Data.Aeson.Encode.Pretty
 import qualified Data.ByteString.Lazy.Char8 as B
 import Data.Aeson
 import Data.Aeson.Encode.Pretty
@@ -58,13 +62,17 @@ import Data.Function
 import Language.Haskell.LSP.Messages
 import Language.Haskell.LSP.Types.Capabilities
 import Language.Haskell.LSP.Types
 import Language.Haskell.LSP.Messages
 import Language.Haskell.LSP.Types.Capabilities
 import Language.Haskell.LSP.Types
-import Language.Haskell.LSP.Types.Lens hiding (error)
+import Language.Haskell.LSP.Types.Lens
+import qualified Language.Haskell.LSP.Types.Lens as LSP
 import Language.Haskell.LSP.VFS
 import Language.Haskell.LSP.VFS
+import Language.Haskell.LSP.Test.Compat
 import Language.Haskell.LSP.Test.Decoding
 import Language.Haskell.LSP.Test.Exceptions
 import System.Console.ANSI
 import System.Directory
 import System.IO
 import Language.Haskell.LSP.Test.Decoding
 import Language.Haskell.LSP.Test.Exceptions
 import System.Console.ANSI
 import System.Directory
 import System.IO
+import System.Process (waitForProcess, ProcessHandle())
+import System.Timeout
 
 -- | A session representing one instance of launching and connecting to a server.
 --
 
 -- | A session representing one instance of launching and connecting to a server.
 --
@@ -73,7 +81,8 @@ import System.IO
 -- 'Language.Haskell.LSP.Test.sendRequest' and
 -- 'Language.Haskell.LSP.Test.sendNotification'.
 
 -- 'Language.Haskell.LSP.Test.sendRequest' and
 -- 'Language.Haskell.LSP.Test.sendNotification'.
 
-type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
+newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
+  deriving (Functor, Applicative, Monad, MonadIO, Alternative)
 
 #if __GLASGOW_HASKELL__ >= 806
 instance MonadFail Session where
 
 #if __GLASGOW_HASKELL__ >= 806
 instance MonadFail Session where
@@ -85,15 +94,24 @@ instance MonadFail Session where
 -- | Stuff you can configure for a 'Session'.
 data SessionConfig = SessionConfig
   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
 -- | Stuff you can configure for a 'Session'.
 data SessionConfig = SessionConfig
   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
-  , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
-  , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
+  , logStdErr      :: Bool
+  -- ^ Redirect the server's stderr to this stdout, defaults to False.
+  -- Can be overriden with @LSP_TEST_LOG_STDERR@.
+  , logMessages    :: Bool
+  -- ^ Trace the messages sent and received to stdout, defaults to False.
+  -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
+  , ignoreLogNotifications :: Bool
+  -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
+  -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
+  --
+  -- @since 0.9.0.0
   }
 
 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
 defaultConfig :: SessionConfig
   }
 
 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
 defaultConfig :: SessionConfig
-defaultConfig = SessionConfig 60 False False True Nothing
+defaultConfig = SessionConfig 60 False False True Nothing False
 
 instance Default SessionConfig where
   def = defaultConfig
 
 instance Default SessionConfig where
   def = defaultConfig
@@ -106,7 +124,9 @@ data SessionContext = SessionContext
   {
     serverIn :: Handle
   , rootDir :: FilePath
   {
     serverIn :: Handle
   , rootDir :: FilePath
-  , messageChan :: Chan SessionMessage
+  , messageChan :: Chan SessionMessage -- ^ Where all messages come through
+  -- Keep curTimeoutId in SessionContext, as its tied to messageChan
+  , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
   , requestMap :: MVar RequestMap
   , initRsp :: MVar InitializeResponse
   , config :: SessionConfig
   , requestMap :: MVar RequestMap
   , initRsp :: MVar InitializeResponse
   , config :: SessionConfig
@@ -118,22 +138,35 @@ class Monad m => HasReader r m where
   asks :: (r -> b) -> m b
   asks f = f <$> ask
 
   asks :: (r -> b) -> m b
   asks f = f <$> ask
 
-instance Monad m => HasReader r (ParserStateReader a s r m) where
-  ask = lift $ lift Reader.ask
+instance HasReader SessionContext Session where
+  ask  = Session (lift $ lift Reader.ask)
 
 
-instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
+instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
   ask = lift $ lift Reader.ask
 
   ask = lift $ lift Reader.ask
 
+getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
+getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
+
+-- Pass this the timeoutid you *were* waiting on
+bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
+bumpTimeoutId prev = do
+  v <- asks curTimeoutId
+  -- when updating the curtimeoutid, account for the fact that something else
+  -- might have bumped the timeoutid in the meantime
+  liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
+
 data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
 data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
-  , curDiagnostics :: Map.Map Uri [Diagnostic]
-  , curTimeoutId :: Int
+  , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
   , overridingTimeout :: Bool
   -- ^ The last received message from the server.
   -- Used for providing exception information
   , lastReceivedMessage :: Maybe FromServerMessage
   , overridingTimeout :: Bool
   -- ^ The last received message from the server.
   -- Used for providing exception information
   , lastReceivedMessage :: Maybe FromServerMessage
+  , curDynCaps :: Map.Map T.Text Registration
+  -- ^ The capabilities that the server has dynamically registered with us so
+  -- far
   }
 
 class Monad m => HasState s m where
   }
 
 class Monad m => HasState s m where
@@ -147,19 +180,26 @@ class Monad m => HasState s m where
   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
   modifyM f = get >>= f >>= put
 
   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
   modifyM f = get >>= f >>= put
 
-instance Monad m => HasState s (ParserStateReader a s r m) where
-  get = lift State.get
-  put = lift . State.put
+instance HasState SessionState Session where
+  get = Session (lift State.get)
+  put = Session . lift . State.put
 
 
-instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
+instance Monad m => HasState s (StateT s m) where
+  get = State.get
+  put = State.put
+
+instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
  where
  where
-  get = lift State.get
-  put = lift . State.put
+  get = lift get
+  put = lift . put
 
 
-type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
+instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
+ where
+  get = lift get
+  put = lift . put
 
 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
 
 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
-runSession context state session = runReaderT (runStateT conduit state) context
+runSession context state (Session session) = runReaderT (runStateT conduit state) context
   where
     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
 
   where
     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
 
@@ -172,45 +212,71 @@ runSession context state session = runReaderT (runStateT conduit state) context
 
     chanSource = do
       msg <- liftIO $ readChan (messageChan context)
 
     chanSource = do
       msg <- liftIO $ readChan (messageChan context)
+      unless (ignoreLogNotifications (config context) && isLogNotification msg) $
         yield msg
       chanSource
 
         yield msg
       chanSource
 
+    isLogNotification (ServerMessage (NotShowMessage _)) = True
+    isLogNotification (ServerMessage (NotLogMessage _)) = True
+    isLogNotification _ = False
+
     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
     watchdog = Conduit.awaitForever $ \msg -> do
     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
     watchdog = Conduit.awaitForever $ \msg -> do
-      curId <- curTimeoutId <$> get
+      curId <- getCurTimeoutId
       case msg of
         ServerMessage sMsg -> yield sMsg
       case msg of
         ServerMessage sMsg -> yield sMsg
-        TimeoutMessage tId -> when (curId == tId) $ throw Timeout
+        TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
 
 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
 -- It also does not automatically send initialize and exit messages.
 runSessionWithHandles :: Handle -- ^ Server in
                       -> Handle -- ^ Server out
 
 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
 -- It also does not automatically send initialize and exit messages.
 runSessionWithHandles :: Handle -- ^ Server in
                       -> Handle -- ^ Server out
+                      -> ProcessHandle -- ^ Server process
                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
                       -> SessionConfig
                       -> ClientCapabilities
                       -> FilePath -- ^ Root directory
                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
                       -> SessionConfig
                       -> ClientCapabilities
                       -> FilePath -- ^ Root directory
+                      -> Session () -- ^ To exit the Server properly
                       -> Session a
                       -> IO a
                       -> Session a
                       -> IO a
-runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
+runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
   absRootDir <- canonicalizePath rootDir
 
   hSetBuffering serverIn  NoBuffering
   hSetBuffering serverOut NoBuffering
   absRootDir <- canonicalizePath rootDir
 
   hSetBuffering serverIn  NoBuffering
   hSetBuffering serverOut NoBuffering
+  -- This is required to make sure that we don’t get any
+  -- newline conversion or weird encoding issues.
+  hSetBinaryMode serverIn True
+  hSetBinaryMode serverOut True
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
+  timeoutIdVar <- newMVar 0
   initRsp <- newEmptyMVar
 
   mainThreadId <- myThreadId
 
   initRsp <- newEmptyMVar
 
   mainThreadId <- myThreadId
 
-  let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
-      initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
-      launchServerHandler = forkIO $ catch (serverHandler serverOut context)
-                                           (throwTo mainThreadId :: SessionException -> IO ())
-  (result, _) <- bracket launchServerHandler killThread $
-    const $ runSession context initState session
-
+  let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
+      initState vfs = SessionState (IdInt 0) vfs mempty False Nothing mempty
+      runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
+
+      errorHandler = throwTo mainThreadId :: SessionException -> IO ()
+      serverListenerLauncher =
+        forkIO $ catch (serverHandler serverOut context) errorHandler
+      server = (Just serverIn, Just serverOut, Nothing, serverProc)
+      msgTimeoutMs = messageTimeout config * 10^6
+      serverAndListenerFinalizer tid = do
+        finally (timeout msgTimeoutMs (runSession' exitServer)) $ do
+          -- Make sure to kill the listener first, before closing
+          -- handles etc via cleanupProcess
+          killThread tid
+          -- Give the server some time to exit cleanly
+          timeout msgTimeoutMs (waitForProcess serverProc)
+          cleanupProcess server
+
+  (result, _) <- bracket serverListenerLauncher
+                         serverAndListenerFinalizer
+                         (const $ runSession' session)
   return result
 
 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
   return result
 
 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
@@ -218,31 +284,48 @@ updateStateC = awaitForever $ \msg -> do
   updateState msg
   yield msg
 
   updateState msg
   yield msg
 
-updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
+updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
+            => FromServerMessage -> m ()
+
+-- Keep track of dynamic capability registration
+updateState (ReqRegisterCapability req) = do
+  let List newRegs = (\r -> (r ^. LSP.id, r)) <$> req ^. params . registrations
+  modify $ \s ->
+    s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
+
+updateState (ReqUnregisterCapability req) = do
+  let List unRegs = (^. LSP.id) <$> req ^. params . unregistrations
+  modify $ \s ->
+    let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
+    in s { curDynCaps = newCurDynCaps }
+
 updateState (NotPublishDiagnostics n) = do
   let List diags = n ^. params . diagnostics
       doc = n ^. params . uri
 updateState (NotPublishDiagnostics n) = do
   let List diags = n ^. params . diagnostics
       doc = n ^. params . uri
-  modify (\s ->
-    let newDiags = Map.insert doc diags (curDiagnostics s)
-      in s { curDiagnostics = newDiags })
+  modify \s ->
+    let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
+      in s { curDiagnostics = newDiags }
 
 updateState (ReqApplyWorkspaceEdit r) = do
 
 
 updateState (ReqApplyWorkspaceEdit r) = do
 
+  -- First, prefer the versioned documentChanges field
   allChangeParams <- case r ^. params . edit . documentChanges of
     Just (List cs) -> do
       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
       return $ map getParams cs
   allChangeParams <- case r ^. params . edit . documentChanges of
     Just (List cs) -> do
       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
       return $ map getParams cs
+    -- Then fall back to the changes field
     Nothing -> case r ^. params . edit . changes of
       Just cs -> do
         mapM_ checkIfNeedsOpened (HashMap.keys cs)
     Nothing -> case r ^. params . edit . changes of
       Just cs -> do
         mapM_ checkIfNeedsOpened (HashMap.keys cs)
-        return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
-      Nothing -> error "No changes!"
+        concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
+      Nothing ->
+        error "WorkspaceEdit contains neither documentChanges nor changes!"
 
   modifyM $ \s -> do
     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
     return $ s { vfs = newVFS }
 
 
   modifyM $ \s -> do
     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
     return $ s { vfs = newVFS }
 
-  let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
+  let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
       mergedParams = map mergeParams groupedParams
 
   -- TODO: Don't do this when replaying a session
       mergedParams = map mergeParams groupedParams
 
   -- TODO: Don't do this when replaying a session
@@ -256,8 +339,8 @@ updateState (ReqApplyWorkspaceEdit r) = do
   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
     modify $ \s ->
       let oldVFS = vfs s
   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
     modify $ \s ->
       let oldVFS = vfs s
-          update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
-          newVFS = Map.adjust update uri oldVFS
+          update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
+          newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
       in s { vfs = newVFS }
 
   where checkIfNeedsOpened uri = do
       in s { vfs = newVFS }
 
   where checkIfNeedsOpened uri = do
@@ -265,7 +348,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
           ctx <- ask
 
           -- if its not open, open it
           ctx <- ask
 
           -- if its not open, open it
-          unless (uri `Map.member` oldVFS) $ do
+          unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
             let fp = fromJust $ uriToFilePath uri
             contents <- liftIO $ T.readFile fp
             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
             let fp = fromJust $ uriToFilePath uri
             contents <- liftIO $ T.readFile fp
             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
@@ -273,18 +356,27 @@ updateState (ReqApplyWorkspaceEdit r) = do
             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
 
             modifyM $ \s -> do
             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
 
             modifyM $ \s -> do
-              newVFS <- liftIO $ openVFS (vfs s) msg
+              let (newVFS,_) = openVFS (vfs s) msg
               return $ s { vfs = newVFS }
 
         getParams (TextDocumentEdit docId (List edits)) =
           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
             in DidChangeTextDocumentParams docId (List changeEvents)
 
               return $ s { vfs = newVFS }
 
         getParams (TextDocumentEdit docId (List edits)) =
           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
             in DidChangeTextDocumentParams docId (List changeEvents)
 
-        textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
+        -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
+        -- where n is the current version
+        textDocumentVersions uri = do
+          m <- vfsMap . vfs <$> get
+          let curVer = fromMaybe 0 $
+                _lsp_version <$> m Map.!? (toNormalizedUri uri)
+          pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer..]
 
 
-        textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
+        textDocumentEdits uri edits = do
+          vers <- textDocumentVersions uri
+          pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
 
 
-        getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
+        getChangeParams uri (List edits) =
+          map <$> pure getParams <*> textDocumentEdits uri (reverse edits)
 
         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
 
         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
@@ -297,21 +389,20 @@ sendMessage msg = do
   logMsg LogClient msg
   liftIO $ B.hPut h (addHeader $ encode msg)
 
   logMsg LogClient msg
   liftIO $ B.hPut h (addHeader $ encode msg)
 
--- | Execute a block f that will throw a 'Timeout' exception
+-- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception
 -- after duration seconds. This will override the global timeout
 -- for waiting for messages to arrive defined in 'SessionConfig'.
 withTimeout :: Int -> Session a -> Session a
 withTimeout duration f = do
   chan <- asks messageChan
 -- after duration seconds. This will override the global timeout
 -- for waiting for messages to arrive defined in 'SessionConfig'.
 withTimeout :: Int -> Session a -> Session a
 withTimeout duration f = do
   chan <- asks messageChan
-  timeoutId <- curTimeoutId <$> get
+  timeoutId <- getCurTimeoutId
   modify $ \s -> s { overridingTimeout = True }
   liftIO $ forkIO $ do
     threadDelay (duration * 1000000)
     writeChan chan (TimeoutMessage timeoutId)
   res <- f
   modify $ \s -> s { overridingTimeout = True }
   liftIO $ forkIO $ do
     threadDelay (duration * 1000000)
     writeChan chan (TimeoutMessage timeoutId)
   res <- f
-  modify $ \s -> s { curTimeoutId = timeoutId + 1,
-                     overridingTimeout = False
-                   }
+  bumpTimeoutId timeoutId
+  modify $ \s -> s { overridingTimeout = False }
   return res
 
 data LogMsgType = LogServer | LogClient
   return res
 
 data LogMsgType = LogServer | LogClient
@@ -336,4 +427,3 @@ logMsg t msg = do
           | otherwise       = Cyan
 
         showPretty = B.unpack . encodePretty
           | otherwise       = Cyan
 
         showPretty = B.unpack . encodePretty
-