Add server shutdown check to throw exception
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
index ab09726f2ef5490654d4a41099ed53e42f03b608..ae8ba1e0fbb15917697bd8e2f3589ebcf59a0862 100644 (file)
@@ -48,6 +48,7 @@ import Data.Conduit as Conduit
 import Data.Conduit.Parser as Parser
 import Data.Default
 import Data.Foldable
+import Data.IORef
 import Data.List
 import qualified Data.Map as Map
 import qualified Data.Text as T
@@ -128,7 +129,7 @@ data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
-  , curDiagnostics :: Map.Map Uri [Diagnostic]
+  , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
   , curTimeoutId :: Int
   , overridingTimeout :: Bool
   -- ^ The last received message from the server.
@@ -193,6 +194,10 @@ runSessionWithHandles :: Handle -- ^ Server in
                       -> Session a
                       -> IO a
 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
+  -- We use this IORef to make exception non-fatal when the server is supposed to shutdown.
+
+  exitOk <- newIORef False
+  
   absRootDir <- canonicalizePath rootDir
 
   hSetBuffering serverIn  NoBuffering
@@ -210,11 +215,15 @@ runSessionWithHandles serverIn serverOut serverHandler config caps rootDir sessi
 
   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
-      launchServerHandler = forkIO $ catch (serverHandler serverOut context)
-                                           (throwTo mainThreadId :: SessionException -> IO ())
-  (result, _) <- bracket launchServerHandler killThread $
-    const $ runSession context initState session
-
+      errorHandler ex = do x <- readIORef exitOk 
+                           unless x $ throwTo mainThreadId (ex :: SessionException)
+      launchServerHandler = forkIO $ catch (serverHandler serverOut context) errorHandler
+  (result, _) <- bracket
+                   launchServerHandler
+                   (\tid -> do runSession context initState sendExitMessage
+                               killThread tid
+                               atomicWriteIORef exitOk True)
+                   (const $ runSession context initState session)
   return result
 
 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
@@ -227,7 +236,7 @@ updateState (NotPublishDiagnostics n) = do
   let List diags = n ^. params . diagnostics
       doc = n ^. params . uri
   modify (\s ->
-    let newDiags = Map.insert doc diags (curDiagnostics s)
+    let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
       in s { curDiagnostics = newDiags })
 
 updateState (ReqApplyWorkspaceEdit r) = do
@@ -246,7 +255,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
     return $ s { vfs = newVFS }
 
-  let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
+  let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
       mergedParams = map mergeParams groupedParams
 
   -- TODO: Don't do this when replaying a session
@@ -261,7 +270,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
     modify $ \s ->
       let oldVFS = vfs s
           update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
-          newVFS = Map.adjust update uri oldVFS
+          newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
       in s { vfs = newVFS }
 
   where checkIfNeedsOpened uri = do
@@ -269,7 +278,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
           ctx <- ask
 
           -- if its not open, open it
-          unless (uri `Map.member` oldVFS) $ do
+          unless (toNormalizedUri uri `Map.member` oldVFS) $ do
             let fp = fromJust $ uriToFilePath uri
             contents <- liftIO $ T.readFile fp
             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
@@ -301,6 +310,9 @@ sendMessage msg = do
   logMsg LogClient msg
   liftIO $ B.hPut h (addHeader $ encode msg)
 
+sendExitMessage :: (MonadIO m, HasReader SessionContext m) => m ()
+sendExitMessage = sendMessage (NotificationMessage "2.0" Exit ExitParams)
+
 -- | Execute a block f that will throw a 'Timeout' exception
 -- after duration seconds. This will override the global timeout
 -- for waiting for messages to arrive defined in 'SessionConfig'.