From 6042c31e8b18eefb81b98a8ebb3e1e6f4a004907 Mon Sep 17 00:00:00 2001 From: Zubin Duggal Date: Wed, 24 Feb 2021 22:01:29 +0530 Subject: [PATCH] Kill timeout thread --- src/Language/LSP/Test/Parsing.hs | 11 ++++++++--- src/Language/LSP/Test/Session.hs | 29 ++++++++++++++++------------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/Language/LSP/Test/Parsing.hs b/src/Language/LSP/Test/Parsing.hs index e55909f..58785d9 100644 --- a/src/Language/LSP/Test/Parsing.hs +++ b/src/Language/LSP/Test/Parsing.hs @@ -83,16 +83,21 @@ satisfyMaybeM pred = do skipTimeout <- overridingTimeout <$> get timeoutId <- getCurTimeoutId - unless skipTimeout $ do + mtid <- + if skipTimeout + then pure Nothing + else Just <$> do chan <- asks messageChan timeout <- asks (messageTimeout . config) - void $ liftIO $ forkIO $ do + liftIO $ forkIO $ do threadDelay (timeout * 1000000) writeChan chan (TimeoutMessage timeoutId) x <- Session await - unless skipTimeout (bumpTimeoutId timeoutId) + forM_ mtid $ \tid -> do + bumpTimeoutId timeoutId + liftIO $ killThread tid modify $ \s -> s { lastReceivedMessage = Just x } diff --git a/src/Language/LSP/Test/Session.hs b/src/Language/LSP/Test/Session.hs index 55055cd..f3d6f8c 100644 --- a/src/Language/LSP/Test/Session.hs +++ b/src/Language/LSP/Test/Session.hs @@ -1,4 +1,5 @@ {-# LANGUAGE CPP #-} +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} @@ -56,7 +57,7 @@ import Data.Conduit.Parser as Parser import Data.Default import Data.Foldable import Data.List -import qualified Data.Map as Map +import qualified Data.Map.Strict as Map import qualified Data.Set as Set import qualified Data.Text as T import qualified Data.Text.IO as T @@ -79,6 +80,7 @@ import System.Process (ProcessHandle()) import System.Process (waitForProcess) #endif import System.Timeout +import Data.IORef -- | A session representing one instance of launching and connecting to a server. -- @@ -135,7 +137,7 @@ data SessionContext = SessionContext , rootDir :: FilePath , messageChan :: Chan SessionMessage -- ^ Where all messages come through -- Keep curTimeoutId in SessionContext, as its tied to messageChan - , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on + , curTimeoutId :: IORef Int -- ^ The current timeout we are waiting on , requestMap :: MVar RequestMap , initRsp :: MVar (ResponseMessage Initialize) , config :: SessionConfig @@ -154,7 +156,7 @@ instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where ask = lift $ lift Reader.ask getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int -getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar +getCurTimeoutId = asks curTimeoutId >>= liftIO . readIORef -- Pass this the timeoutid you *were* waiting on bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m () @@ -162,21 +164,21 @@ bumpTimeoutId prev = do v <- asks curTimeoutId -- when updating the curtimeoutid, account for the fact that something else -- might have bumped the timeoutid in the meantime - liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1))) + liftIO $ atomicModifyIORef' v (\x -> (max x (prev + 1), ())) data SessionState = SessionState { - curReqId :: Int - , vfs :: VFS - , curDiagnostics :: Map.Map NormalizedUri [Diagnostic] - , overridingTimeout :: Bool + curReqId :: !Int + , vfs :: !VFS + , curDiagnostics :: !(Map.Map NormalizedUri [Diagnostic]) + , overridingTimeout :: !Bool -- ^ The last received message from the server. -- Used for providing exception information - , lastReceivedMessage :: Maybe FromServerMessage - , curDynCaps :: Map.Map T.Text SomeRegistration + , lastReceivedMessage :: !(Maybe FromServerMessage) + , curDynCaps :: !(Map.Map T.Text SomeRegistration) -- ^ The capabilities that the server has dynamically registered with us so -- far - , curProgressSessions :: Set.Set ProgressToken + , curProgressSessions :: !(Set.Set ProgressToken) } class Monad m => HasState s m where @@ -261,7 +263,7 @@ runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exi reqMap <- newMVar newRequestMap messageChan <- newChan - timeoutIdVar <- newMVar 0 + timeoutIdVar <- newIORef 0 initRsp <- newEmptyMVar mainThreadId <- myThreadId @@ -441,10 +443,11 @@ withTimeout duration f = do chan <- asks messageChan timeoutId <- getCurTimeoutId modify $ \s -> s { overridingTimeout = True } - liftIO $ forkIO $ do + tid <- liftIO $ forkIO $ do threadDelay (duration * 1000000) writeChan chan (TimeoutMessage timeoutId) res <- f + liftIO $ killThread tid bumpTimeoutId timeoutId modify $ \s -> s { overridingTimeout = False } return res -- 2.30.2