Fix curtimeoutid being reset in the server exit handler
authorLuke Lau <luke_lau@icloud.com>
Tue, 5 May 2020 17:20:21 +0000 (18:20 +0100)
committerLuke Lau <luke_lau@icloud.com>
Tue, 5 May 2020 17:20:21 +0000 (18:20 +0100)
src/Language/Haskell/LSP/Test/Parsing.hs
src/Language/Haskell/LSP/Test/Session.hs

index 5ce9b52995c05f0deb899c1327ae7476043760d2..12ef1a6281547c4fc73dd4ad812cb71c529e95f5 100644 (file)
@@ -75,7 +75,7 @@ satisfyMaybe :: (FromServerMessage -> Maybe a) -> Session a
 satisfyMaybe pred = do
 
   skipTimeout <- overridingTimeout <$> get
-  timeoutId <- curTimeoutId <$> get
+  timeoutId <- getCurTimeoutId
   unless skipTimeout $ do
     chan <- asks messageChan
     timeout <- asks (messageTimeout . config)
@@ -85,8 +85,7 @@ satisfyMaybe pred = do
 
   x <- Session await
 
-  unless skipTimeout $
-    modify $ \s -> s { curTimeoutId = timeoutId + 1 }
+  unless skipTimeout (bumpTimeoutId timeoutId)
 
   modify $ \s -> s { lastReceivedMessage = Just x }
 
index c33d801efc35c9bb1d5c46c6c3715c5fbf8bbde3..ddd07a5da6693cc3ccc27bd0845df8cd7b58734d 100644 (file)
@@ -23,6 +23,8 @@ module Language.Haskell.LSP.Test.Session
   , sendMessage
   , updateState
   , withTimeout
+  , getCurTimeoutId
+  , bumpTimeoutId
   , logMsg
   , LogMsgType(..)
   )
@@ -121,7 +123,9 @@ data SessionContext = SessionContext
   {
     serverIn :: Handle
   , rootDir :: FilePath
-  , messageChan :: Chan SessionMessage
+  , messageChan :: Chan SessionMessage -- ^ Where all messages come through
+  -- Keep curTimeoutId in SessionContext, as its tied to messageChan
+  , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
   , requestMap :: MVar RequestMap
   , initRsp :: MVar InitializeResponse
   , config :: SessionConfig
@@ -139,12 +143,22 @@ instance HasReader SessionContext Session where
 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
   ask = lift $ lift Reader.ask
 
+getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
+getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
+
+-- Pass this the timeoutid you *were* waiting on
+bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
+bumpTimeoutId prev = do
+  v <- asks curTimeoutId
+  -- when updating the curtimeoutid, account for the fact that something else
+  -- might have bumped the timeoutid in the meantime
+  liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
+
 data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
-  , curTimeoutId :: Int
   , overridingTimeout :: Bool
   -- ^ The last received message from the server.
   -- Used for providing exception information
@@ -166,15 +180,19 @@ instance HasState SessionState Session where
   get = Session (lift State.get)
   put = Session . lift . State.put
 
-instance Monad m => HasState s (ConduitM a b (StateT s m))
+instance Monad m => HasState s (StateT s m) where
+  get = State.get
+  put = State.put
+
+instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
  where
-  get = lift State.get
-  put = lift . State.put
+  get = lift get
+  put = lift . put
 
-instance Monad m => HasState s (ConduitParser a (StateT s m))
+instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
  where
-  get = lift State.get
-  put = lift . State.put
+  get = lift get
+  put = lift . put
 
 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
 runSession context state (Session session) = runReaderT (runStateT conduit state) context
@@ -200,7 +218,7 @@ runSession context state (Session session) = runReaderT (runStateT conduit state
 
     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
     watchdog = Conduit.awaitForever $ \msg -> do
-      curId <- curTimeoutId <$> get
+      curId <- getCurTimeoutId
       case msg of
         ServerMessage sMsg -> yield sMsg
         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
@@ -229,25 +247,26 @@ runSessionWithHandles serverIn serverOut serverProc serverHandler config caps ro
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
+  timeoutIdVar <- newMVar 0
   initRsp <- newEmptyMVar
 
   mainThreadId <- myThreadId
 
-  let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
-      initState vfs = SessionState (IdInt 0) vfs
-                                       mempty 0 False Nothing
+  let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
+      initState vfs = SessionState (IdInt 0) vfs mempty False Nothing
       runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
 
       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
       serverListenerLauncher =
         forkIO $ catch (serverHandler serverOut context) errorHandler
       server = (Just serverIn, Just serverOut, Nothing, serverProc)
-      serverAndListenerFinalizer tid =
-        finally (timeout (messageTimeout config * 1000000)
+      serverAndListenerFinalizer tid = do
+        finally (timeout (messageTimeout config * 1^6)
                          (runSession' exitServer))
                 (cleanupProcess server >> killThread tid)
 
-  (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer
+  (result, _) <- bracket serverListenerLauncher
+                         serverAndListenerFinalizer
                          (const $ runSession' session)
   return result
 
@@ -342,15 +361,14 @@ sendMessage msg = do
 withTimeout :: Int -> Session a -> Session a
 withTimeout duration f = do
   chan <- asks messageChan
-  timeoutId <- curTimeoutId <$> get
+  timeoutId <- getCurTimeoutId
   modify $ \s -> s { overridingTimeout = True }
   liftIO $ forkIO $ do
     threadDelay (duration * 1000000)
     writeChan chan (TimeoutMessage timeoutId)
   res <- f
-  modify $ \s -> s { curTimeoutId = timeoutId + 1,
-                     overridingTimeout = False
-                   }
+  bumpTimeoutId timeoutId
+  modify $ \s -> s { overridingTimeout = False }
   return res
 
 data LogMsgType = LogServer | LogClient