Add server shutdown check to throw exception
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
index 46155f0607cfa1b752b541b755a761b1551b4b30..ae8ba1e0fbb15917697bd8e2f3589ebcf59a0862 100644 (file)
@@ -48,6 +48,7 @@ import Data.Conduit as Conduit
 import Data.Conduit.Parser as Parser
 import Data.Default
 import Data.Foldable
+import Data.IORef
 import Data.List
 import qualified Data.Map as Map
 import qualified Data.Text as T
@@ -193,6 +194,10 @@ runSessionWithHandles :: Handle -- ^ Server in
                       -> Session a
                       -> IO a
 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
+  -- We use this IORef to make exception non-fatal when the server is supposed to shutdown.
+
+  exitOk <- newIORef False
+  
   absRootDir <- canonicalizePath rootDir
 
   hSetBuffering serverIn  NoBuffering
@@ -210,11 +215,15 @@ runSessionWithHandles serverIn serverOut serverHandler config caps rootDir sessi
 
   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
-      launchServerHandler = forkIO $ catch (serverHandler serverOut context)
-                                           (throwTo mainThreadId :: SessionException -> IO ())
-  (result, _) <- bracket launchServerHandler killThread $
-    const $ runSession context initState session
-
+      errorHandler ex = do x <- readIORef exitOk 
+                           unless x $ throwTo mainThreadId (ex :: SessionException)
+      launchServerHandler = forkIO $ catch (serverHandler serverOut context) errorHandler
+  (result, _) <- bracket
+                   launchServerHandler
+                   (\tid -> do runSession context initState sendExitMessage
+                               killThread tid
+                               atomicWriteIORef exitOk True)
+                   (const $ runSession context initState session)
   return result
 
 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
@@ -301,6 +310,9 @@ sendMessage msg = do
   logMsg LogClient msg
   liftIO $ B.hPut h (addHeader $ encode msg)
 
+sendExitMessage :: (MonadIO m, HasReader SessionContext m) => m ()
+sendExitMessage = sendMessage (NotificationMessage "2.0" Exit ExitParams)
+
 -- | Execute a block f that will throw a 'Timeout' exception
 -- after duration seconds. This will override the global timeout
 -- for waiting for messages to arrive defined in 'SessionConfig'.