Kill timeout thread
authorZubin Duggal <zubin@cmi.ac.in>
Wed, 24 Feb 2021 16:31:29 +0000 (22:01 +0530)
committerZubin Duggal <zubin@cmi.ac.in>
Thu, 25 Feb 2021 18:56:13 +0000 (00:26 +0530)
src/Language/LSP/Test/Parsing.hs
src/Language/LSP/Test/Session.hs

index e55909f4c59fc9076c665d9b8f66811ae9981bb9..58785d9bf5979a7227db8d7c91d674399deda8ba 100644 (file)
@@ -83,16 +83,21 @@ satisfyMaybeM pred = do
   
   skipTimeout <- overridingTimeout <$> get
   timeoutId <- getCurTimeoutId
-  unless skipTimeout $ do
+  mtid <-
+    if skipTimeout
+    then pure Nothing
+    else Just <$> do
       chan <- asks messageChan
       timeout <- asks (messageTimeout . config)
-    void $ liftIO $ forkIO $ do
+      liftIO $ forkIO $ do
         threadDelay (timeout * 1000000)
         writeChan chan (TimeoutMessage timeoutId)
 
   x <- Session await
 
-  unless skipTimeout (bumpTimeoutId timeoutId)
+  forM_ mtid $ \tid -> do
+    bumpTimeoutId timeoutId
+    liftIO $ killThread tid
 
   modify $ \s -> s { lastReceivedMessage = Just x }
 
index 55055cdebfce545b65755a9c5b7f3c72f33feb3b..f3d6f8c28e83c7250bd093bde204e79964e08204 100644 (file)
@@ -1,4 +1,5 @@
 {-# LANGUAGE CPP               #-}
+{-# LANGUAGE BangPatterns      #-}
 {-# LANGUAGE GADTs             #-}
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
@@ -56,7 +57,7 @@ import Data.Conduit.Parser as Parser
 import Data.Default
 import Data.Foldable
 import Data.List
-import qualified Data.Map as Map
+import qualified Data.Map.Strict as Map
 import qualified Data.Set as Set
 import qualified Data.Text as T
 import qualified Data.Text.IO as T
@@ -79,6 +80,7 @@ import System.Process (ProcessHandle())
 import System.Process (waitForProcess)
 #endif
 import System.Timeout
+import Data.IORef
 
 -- | A session representing one instance of launching and connecting to a server.
 --
@@ -135,7 +137,7 @@ data SessionContext = SessionContext
   , rootDir :: FilePath
   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
-  , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
+  , curTimeoutId :: IORef Int -- ^ The current timeout we are waiting on
   , requestMap :: MVar RequestMap
   , initRsp :: MVar (ResponseMessage Initialize)
   , config :: SessionConfig
@@ -154,7 +156,7 @@ instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
   ask = lift $ lift Reader.ask
 
 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
-getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
+getCurTimeoutId = asks curTimeoutId >>= liftIO . readIORef
 
 -- Pass this the timeoutid you *were* waiting on
 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
@@ -162,21 +164,21 @@ bumpTimeoutId prev = do
   v <- asks curTimeoutId
   -- when updating the curtimeoutid, account for the fact that something else
   -- might have bumped the timeoutid in the meantime
-  liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
+  liftIO $ atomicModifyIORef' v (\x -> (max x (prev + 1), ()))
 
 data SessionState = SessionState
   {
-    curReqId :: Int
-  , vfs :: VFS
-  , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
-  , overridingTimeout :: Bool
+    curReqId :: !Int
+  , vfs :: !VFS
+  , curDiagnostics :: !(Map.Map NormalizedUri [Diagnostic])
+  , overridingTimeout :: !Bool
   -- ^ The last received message from the server.
   -- Used for providing exception information
-  , lastReceivedMessage :: Maybe FromServerMessage
-  , curDynCaps :: Map.Map T.Text SomeRegistration
+  , lastReceivedMessage :: !(Maybe FromServerMessage)
+  , curDynCaps :: !(Map.Map T.Text SomeRegistration)
   -- ^ The capabilities that the server has dynamically registered with us so
   -- far
-  , curProgressSessions :: Set.Set ProgressToken
+  , curProgressSessions :: !(Set.Set ProgressToken)
   }
 
 class Monad m => HasState s m where
@@ -261,7 +263,7 @@ runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exi
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
-  timeoutIdVar <- newMVar 0
+  timeoutIdVar <- newIORef 0
   initRsp <- newEmptyMVar
 
   mainThreadId <- myThreadId
@@ -441,10 +443,11 @@ withTimeout duration f = do
   chan <- asks messageChan
   timeoutId <- getCurTimeoutId
   modify $ \s -> s { overridingTimeout = True }
-  liftIO $ forkIO $ do
+  tid <- liftIO $ forkIO $ do
     threadDelay (duration * 1000000)
     writeChan chan (TimeoutMessage timeoutId)
   res <- f
+  liftIO $ killThread tid
   bumpTimeoutId timeoutId
   modify $ \s -> s { overridingTimeout = False }
   return res