Don't use exitServer in Replay
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
index 28d4eddec1591fbbe190fc825abc5c0c1453a0f5..4d75d1defa541f07993874da444ed8d4cf08d0ff 100644 (file)
@@ -3,6 +3,7 @@
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE RankNTypes #-}
 
 module Language.Haskell.LSP.Test.Session
   ( Session
@@ -34,7 +35,7 @@ import Control.Monad
 import Control.Monad.IO.Class
 import Control.Monad.Except
 #if __GLASGOW_HASKELL__ >= 806
-import qualified Control.Monad.Fail as Fail
+import Control.Monad.Fail
 #endif
 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
 import qualified Control.Monad.Trans.Reader as Reader (ask)
@@ -59,31 +60,43 @@ import Language.Haskell.LSP.Types.Capabilities
 import Language.Haskell.LSP.Types
 import Language.Haskell.LSP.Types.Lens hiding (error)
 import Language.Haskell.LSP.VFS
+import Language.Haskell.LSP.Test.Compat
 import Language.Haskell.LSP.Test.Decoding
 import Language.Haskell.LSP.Test.Exceptions
 import System.Console.ANSI
 import System.Directory
 import System.IO
+import System.Process (ProcessHandle())
+import System.Timeout
 
 -- | A session representing one instance of launching and connecting to a server.
 --
--- You can send and receive messages to the server within 'Session' via 'getMessage',
--- 'sendRequest' and 'sendNotification'.
---
+-- You can send and receive messages to the server within 'Session' via
+-- 'Language.Haskell.LSP.Test.message',
+-- 'Language.Haskell.LSP.Test.sendRequest' and
+-- 'Language.Haskell.LSP.Test.sendNotification'.
 
 type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
 
+#if __GLASGOW_HASKELL__ >= 806
+instance MonadFail Session where
+  fail s = do
+    lastMsg <- fromJust . lastReceivedMessage <$> get
+    liftIO $ throw (UnexpectedMessage s lastMsg)
+#endif
+
 -- | Stuff you can configure for a 'Session'.
 data SessionConfig = SessionConfig
   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
   , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
   , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
+  , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
   }
 
 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
 defaultConfig :: SessionConfig
-defaultConfig = SessionConfig 60 False False True
+defaultConfig = SessionConfig 60 False False True Nothing
 
 instance Default SessionConfig where
   def = defaultConfig
@@ -118,7 +131,7 @@ data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
-  , curDiagnostics :: Map.Map Uri [Diagnostic]
+  , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
   , curTimeoutId :: Int
   , overridingTimeout :: Bool
   -- ^ The last received message from the server.
@@ -148,11 +161,6 @@ instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
 
 type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
 
-#if __GLASGOW_HASKELL__ >= 806
-instance (Fail.MonadFail m) => Fail.MonadFail (ParserStateReader a s r m) where
-  fail = Fail.fail
-#endif
-
 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
 runSession context state session = runReaderT (runStateT conduit state) context
   where
@@ -181,30 +189,42 @@ runSession context state session = runReaderT (runStateT conduit state) context
 -- It also does not automatically send initialize and exit messages.
 runSessionWithHandles :: Handle -- ^ Server in
                       -> Handle -- ^ Server out
+                      -> ProcessHandle -- ^ Server process
                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
                       -> SessionConfig
                       -> ClientCapabilities
                       -> FilePath -- ^ Root directory
+                      -> Session () -- ^ To exit the Server properly
                       -> Session a
                       -> IO a
-runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
+runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
   absRootDir <- canonicalizePath rootDir
 
   hSetBuffering serverIn  NoBuffering
   hSetBuffering serverOut NoBuffering
+  -- This is required to make sure that we don’t get any
+  -- newline conversion or weird encoding issues.
+  hSetBinaryMode serverIn True
+  hSetBinaryMode serverOut True
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
   initRsp <- newEmptyMVar
 
+  mainThreadId <- myThreadId
+
   let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
       initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
+      runSession' = runSession context initState
 
-  threadId <- forkIO $ void $ serverHandler serverOut context
-  (result, _) <- runSession context initState session
-
-  killThread threadId
+      errorHandler = throwTo mainThreadId :: SessionException -> IO()
+      serverLauncher = forkIO $ catch (serverHandler serverOut context) errorHandler
+      server = (Just serverIn, Just serverOut, Nothing, serverProc)
+      serverFinalizer tid = finally (timeout (messageTimeout config * 1000000)
+                                             (runSession' exitServer))
+                                    (cleanupRunningProcess server >> killThread tid)
 
+  (result, _) <- bracket serverLauncher serverFinalizer (const $ runSession' session)
   return result
 
 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
@@ -217,7 +237,7 @@ updateState (NotPublishDiagnostics n) = do
   let List diags = n ^. params . diagnostics
       doc = n ^. params . uri
   modify (\s ->
-    let newDiags = Map.insert doc diags (curDiagnostics s)
+    let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
       in s { curDiagnostics = newDiags })
 
 updateState (ReqApplyWorkspaceEdit r) = do
@@ -236,7 +256,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
     return $ s { vfs = newVFS }
 
-  let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
+  let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
       mergedParams = map mergeParams groupedParams
 
   -- TODO: Don't do this when replaying a session
@@ -250,8 +270,8 @@ updateState (ReqApplyWorkspaceEdit r) = do
   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
     modify $ \s ->
       let oldVFS = vfs s
-          update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
-          newVFS = Map.adjust update uri oldVFS
+          update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
+          newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
       in s { vfs = newVFS }
 
   where checkIfNeedsOpened uri = do
@@ -259,7 +279,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
           ctx <- ask
 
           -- if its not open, open it
-          unless (uri `Map.member` oldVFS) $ do
+          unless (toNormalizedUri uri `Map.member` oldVFS) $ do
             let fp = fromJust $ uriToFilePath uri
             contents <- liftIO $ T.readFile fp
             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
@@ -291,7 +311,7 @@ sendMessage msg = do
   logMsg LogClient msg
   liftIO $ B.hPut h (addHeader $ encode msg)
 
--- | Execute a block f that will throw a 'TimeoutException'
+-- | Execute a block f that will throw a 'Timeout' exception
 -- after duration seconds. This will override the global timeout
 -- for waiting for messages to arrive defined in 'SessionConfig'.
 withTimeout :: Int -> Session a -> Session a
@@ -330,3 +350,4 @@ logMsg t msg = do
           | otherwise       = Cyan
 
         showPretty = B.unpack . encodePretty
+