3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7 {-# LANGUAGE FlexibleContexts #-}
8 {-# LANGUAGE RankNTypes #-}
9 {-# LANGUAGE TypeInType #-}
11 module Language.LSP.Test.Session
36 import Control.Applicative
37 import Control.Concurrent hiding (yield)
38 import Control.Exception
39 import Control.Lens hiding (List)
41 import Control.Monad.IO.Class
42 import Control.Monad.Except
43 #if __GLASGOW_HASKELL__ == 806
44 import Control.Monad.Fail
46 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
47 import qualified Control.Monad.Trans.Reader as Reader (ask)
48 import Control.Monad.Trans.State (StateT, runStateT)
49 import qualified Control.Monad.Trans.State as State
50 import qualified Data.ByteString.Lazy.Char8 as B
52 import Data.Aeson.Encode.Pretty
53 import Data.Conduit as Conduit
54 import Data.Conduit.Parser as Parser
58 import qualified Data.Map as Map
59 import qualified Data.Text as T
60 import qualified Data.Text.IO as T
61 import qualified Data.HashMap.Strict as HashMap
64 import Language.LSP.Types.Capabilities
65 import Language.LSP.Types
66 import Language.LSP.Types.Lens
67 import qualified Language.LSP.Types.Lens as LSP
68 import Language.LSP.VFS
69 import Language.LSP.Test.Compat
70 import Language.LSP.Test.Decoding
71 import Language.LSP.Test.Exceptions
72 import System.Console.ANSI
73 import System.Directory
75 import System.Process (ProcessHandle())
76 #ifndef mingw32_HOST_OS
77 import System.Process (waitForProcess)
-- | A session represents a single launch of, and connection to, a language
-- server.
--
-- Inside a 'Session' you can exchange messages with the server using
-- 'Language.LSP.Test.message', 'Language.LSP.Test.sendRequest' and
-- 'Language.LSP.Test.sendNotification'.
--
-- Internally it is a 'ConduitParser' over the stream of 'FromServerMessage's,
-- running on top of a mutable 'SessionState' ('StateT') and a read-only
-- 'SessionContext' ('ReaderT').
newtype Session a =
  Session
    (ConduitParser
       FromServerMessage
       (StateT SessionState (ReaderT SessionContext IO))
       a)
  deriving (Functor, Applicative, Monad, Alternative, MonadIO)
91 #if __GLASGOW_HASKELL__ >= 806
92 instance MonadFail Session where
94 lastMsg <- fromJust . lastReceivedMessage <$> get
95 liftIO $ throw (UnexpectedMessage s lastMsg)
98 -- | Stuff you can configure for a 'Session'.
99 data SessionConfig = SessionConfig
100 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
102 -- ^ Redirect the server's stderr to this process's stdout, defaults to False.
103 -- Can be overridden with the environment variable @LSP_TEST_LOG_STDERR@.
104 , logMessages :: Bool
105 -- ^ Trace the messages sent and received to stdout, defaults to False.
106 -- Can be overridden with the environment variable @LSP_TEST_LOG_MESSAGES@.
107 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
108 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
109 , ignoreLogNotifications :: Bool
110 -- ^ Whether or not to ignore 'Language.LSP.Types.ShowMessageNotification' and
111 -- 'Language.LSP.Types.LogMessageNotification', defaults to False.
114 , initialWorkspaceFolders :: Maybe [WorkspaceFolder]
115 -- ^ The initial workspace folders to send in the @initialize@ request.
116 -- Defaults to Nothing.
-- | The configuration used in 'Language.LSP.Test.runSession'.
--
-- 'SessionConfig' is built positionally here, so the arguments below must
-- stay in the same order as the record's fields.
defaultConfig :: SessionConfig
defaultConfig =
  SessionConfig
    timeoutSeconds -- messageTimeout: 60 seconds
    False          -- redirect the server's stderr: off
    False          -- logMessages: off
    True           -- logColor: on
    Nothing        -- lspConfig: none
    False          -- ignoreLogNotifications: off
    Nothing        -- initialWorkspaceFolders: none
  where
    timeoutSeconds :: Int
    timeoutSeconds = 60
123 instance Default SessionConfig where
126 data SessionMessage = ServerMessage FromServerMessage
130 data SessionContext = SessionContext
133 , rootDir :: FilePath
134 , messageChan :: Chan SessionMessage -- ^ Where all messages come through
135 -- Keep curTimeoutId in SessionContext, as it's tied to messageChan
136 , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
137 , requestMap :: MVar RequestMap
138 , initRsp :: MVar (ResponseMessage Initialize)
139 , config :: SessionConfig
140 , sessionCapabilities :: ClientCapabilities
143 class Monad m => HasReader r m where
145 asks :: (r -> b) -> m b
-- | A 'Session' reads its 'SessionContext' from the 'ReaderT' layer, lifting
-- through both the parser and the 'StateT' layer above it.
instance HasReader SessionContext Session where
  ask = Session . lift . lift $ Reader.ask
-- | Conduits running over a @StateT s (ReaderT r m)@ stack can read the
-- environment @r@ by lifting through the conduit and 'StateT' layers.
instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
  ask = lift (lift Reader.ask)
-- | Read the id of the timeout that is currently being waited on, as stored
-- in the context's 'curTimeoutId' 'MVar'.
getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
getCurTimeoutId = do
  timeoutVar <- asks curTimeoutId
  liftIO (readMVar timeoutVar)
-- | Advance the current timeout id past the one that was just waited on.
--
-- Pass this the timeout id you /were/ waiting on (@prev@). Something else
-- might have bumped the id in the meantime, so the stored value never goes
-- backwards: it becomes the larger of its current value and @prev + 1@.
--
-- The new value is forced before it is put back, so repeated bumps do not
-- accumulate a chain of unevaluated 'max' thunks inside the 'MVar'.
bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
bumpTimeoutId prev = do
  v <- asks curTimeoutId
  liftIO $ modifyMVar_ v (\x -> pure $! max x (prev + 1))
165 data SessionState = SessionState
169 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
170 , overridingTimeout :: Bool
171 -- ^ The last received message from the server.
172 -- Used for providing exception information
173 , lastReceivedMessage :: Maybe FromServerMessage
174 , curDynCaps :: Map.Map T.Text SomeRegistration
175 -- ^ The capabilities that the server has dynamically registered with us so
179 class Monad m => HasState s m where
184 modify :: (s -> s) -> m ()
185 modify f = get >>= put . f
-- | Replace the state with the result of a monadic transformation of it:
-- read the state, run @f@ on it, and store what @f@ returns.
--
-- 'HasState' already has 'Monad' as a superclass, so a separate @Monad m@
-- constraint is redundant and has been dropped.
modifyM :: HasState s m => (s -> m s) -> m ()
modifyM f = put =<< f =<< get
-- | The 'SessionState' lives in the 'StateT' layer directly beneath the
-- parser, so both operations are a single 'lift'.
instance HasState SessionState Session where
  get = Session $ lift State.get
  put newState = Session (lift (State.put newState))
194 instance Monad m => HasState s (StateT s m) where
198 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
203 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
208 runSessionMonad :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
209 runSessionMonad context state (Session session) = runReaderT (runStateT conduit state) context
211 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
213 handler (Unexpected "ConduitParser.empty") = do
214 lastMsg <- fromJust . lastReceivedMessage <$> get
215 name <- getParserName
216 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
221 msg <- liftIO $ readChan (messageChan context)
222 unless (ignoreLogNotifications (config context) && isLogNotification msg) $
226 isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
227 isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
228 isLogNotification _ = False
230 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
231 watchdog = Conduit.awaitForever $ \msg -> do
232 curId <- getCurTimeoutId
234 ServerMessage sMsg -> yield sMsg
235 TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
237 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
238 -- It also does not automatically send initialize and exit messages.
239 runSession' :: Handle -- ^ Server in
240 -> Handle -- ^ Server out
241 -> Maybe ProcessHandle -- ^ Server process
242 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
244 -> ClientCapabilities
245 -> FilePath -- ^ Root directory
246 -> Session () -- ^ To exit the Server properly
249 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
250 absRootDir <- canonicalizePath rootDir
252 hSetBuffering serverIn NoBuffering
253 hSetBuffering serverOut NoBuffering
254 -- This is required to make sure that we don’t get any
255 -- newline conversion or weird encoding issues.
256 hSetBinaryMode serverIn True
257 hSetBinaryMode serverOut True
259 reqMap <- newMVar newRequestMap
260 messageChan <- newChan
261 timeoutIdVar <- newMVar 0
262 initRsp <- newEmptyMVar
264 mainThreadId <- myThreadId
266 let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
267 initState vfs = SessionState 0 vfs mempty False Nothing mempty
268 runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
270 errorHandler = throwTo mainThreadId :: SessionException -> IO ()
271 serverListenerLauncher =
272 forkIO $ catch (serverHandler serverOut context) errorHandler
273 msgTimeoutMs = messageTimeout config * 10^6
274 serverAndListenerFinalizer tid = do
276 | Just sp <- mServerProc = do
277 -- Give the server some time to exit cleanly
278 -- Waiting makes the server hang on Windows, so we skip it there
279 #ifndef mingw32_HOST_OS
280 timeout msgTimeoutMs (waitForProcess sp)
282 cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
283 | otherwise = pure ()
284 finally (timeout msgTimeoutMs (runSession' exitServer))
285 -- Make sure to kill the listener first, before closing
286 -- handles etc via cleanupProcess
287 (killThread tid >> cleanup)
289 (result, _) <- bracket serverListenerLauncher
290 serverAndListenerFinalizer
291 (const $ initVFS $ \vfs -> runSessionMonad context (initState vfs) session)
294 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
295 updateStateC = awaitForever $ \msg -> do
299 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
300 => FromServerMessage -> m ()
302 -- Keep track of dynamic capability registration
303 updateState (FromServerMess SClientRegisterCapability req) = do
304 let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
306 s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
308 updateState (FromServerMess SClientUnregisterCapability req) = do
309 let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
311 let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
312 in s { curDynCaps = newCurDynCaps }
314 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
315 let List diags = n ^. params . diagnostics
316 doc = n ^. params . uri
318 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
319 in s { curDiagnostics = newDiags }
321 updateState (FromServerMess SWorkspaceApplyEdit r) = do
323 -- First, prefer the versioned documentChanges field
324 allChangeParams <- case r ^. params . edit . documentChanges of
326 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
327 return $ map getParams cs
328 -- Then fall back to the changes field
329 Nothing -> case r ^. params . edit . changes of
331 mapM_ checkIfNeedsOpened (HashMap.keys cs)
332 concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
334 error "WorkspaceEdit contains neither documentChanges nor changes!"
337 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
338 return $ s { vfs = newVFS }
340 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
341 mergedParams = map mergeParams groupedParams
343 -- TODO: Don't do this when replaying a session
344 forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
346 -- Update VFS to new document versions
347 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
348 latestVersions = map ((^. textDocument) . last) sortedVersions
349 bumpedVersions = map (version . _Just +~ 1) latestVersions
351 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
354 update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
355 newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
356 in s { vfs = newVFS }
358 where checkIfNeedsOpened uri = do
359 oldVFS <- vfs <$> get
362 -- if it's not open, open it
363 unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
364 let fp = fromJust $ uriToFilePath uri
365 contents <- liftIO $ T.readFile fp
366 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
367 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
368 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
371 let (newVFS,_) = openVFS (vfs s) msg
372 return $ s { vfs = newVFS }
374 getParams (TextDocumentEdit docId (List edits)) =
375 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
376 in DidChangeTextDocumentParams docId (List changeEvents)
378 -- For a uri, returns an infinite list of versions [n+1,n+2,n+3,...]
379 -- where n is the current version
380 textDocumentVersions uri = do
381 m <- vfsMap . vfs <$> get
382 let curVer = fromMaybe 0 $
383 _lsp_version <$> m Map.!? (toNormalizedUri uri)
384 pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
386 textDocumentEdits uri edits = do
387 vers <- textDocumentVersions uri
388 pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
390 getChangeParams uri (List edits) =
391 map <$> pure getParams <*> textDocumentEdits uri (reverse edits)
393 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
394 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
395 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
396 updateState _ = return ()
398 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
400 h <- serverIn <$> ask
402 liftIO $ B.hPut h (addHeader $ encode msg)
404 -- | Execute a block f that will throw a 'Language.LSP.Test.Exceptions.Timeout' exception
405 -- after duration seconds. This will override the global timeout
406 -- for waiting for messages to arrive defined in 'SessionConfig'.
407 withTimeout :: Int -> Session a -> Session a
408 withTimeout duration f = do
409 chan <- asks messageChan
410 timeoutId <- getCurTimeoutId
411 modify $ \s -> s { overridingTimeout = True }
413 threadDelay (duration * 1000000)
414 writeChan chan (TimeoutMessage timeoutId)
416 bumpTimeoutId timeoutId
417 modify $ \s -> s { overridingTimeout = False }
420 data LogMsgType = LogServer | LogClient
423 -- | Logs the message if the config specified it
424 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
425 => LogMsgType -> a -> m ()
427 shouldLog <- asks $ logMessages . config
428 shouldColor <- asks $ logColor . config
429 liftIO $ when shouldLog $ do
430 when shouldColor $ setSGR [SetColor Foreground Dull color]
431 putStrLn $ arrow ++ showPretty msg
432 when shouldColor $ setSGR [Reset]
435 | t == LogServer = "<-- "
438 | t == LogServer = Magenta
441 showPretty = B.unpack . encodePretty