2 {-# LANGUAGE BangPatterns #-}
4 {-# LANGUAGE OverloadedStrings #-}
5 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
6 {-# LANGUAGE FlexibleInstances #-}
7 {-# LANGUAGE MultiParamTypeClasses #-}
8 {-# LANGUAGE FlexibleContexts #-}
9 {-# LANGUAGE RankNTypes #-}
10 {-# LANGUAGE TypeInType #-}
12 module Language.LSP.Test.Session
38 import Control.Applicative
39 import Control.Concurrent hiding (yield)
40 import Control.Exception
41 import Control.Lens hiding (List)
43 import Control.Monad.IO.Class
44 import Control.Monad.Except
45 #if __GLASGOW_HASKELL__ == 806
46 import Control.Monad.Fail
48 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
49 import qualified Control.Monad.Trans.Reader as Reader (ask)
50 import Control.Monad.Trans.State (StateT, runStateT)
51 import qualified Control.Monad.Trans.State as State
52 import qualified Data.ByteString.Lazy.Char8 as B
54 import Data.Aeson.Encode.Pretty
55 import Data.Conduit as Conduit
56 import Data.Conduit.Parser as Parser
60 import qualified Data.Map.Strict as Map
61 import qualified Data.Set as Set
62 import qualified Data.Text as T
63 import qualified Data.Text.IO as T
64 import qualified Data.HashMap.Strict as HashMap
67 import Language.LSP.Types.Capabilities
68 import Language.LSP.Types
69 import Language.LSP.Types.Lens
70 import qualified Language.LSP.Types.Lens as LSP
71 import Language.LSP.VFS
72 import Language.LSP.Test.Compat
73 import Language.LSP.Test.Decoding
74 import Language.LSP.Test.Exceptions
75 import System.Console.ANSI
76 import System.Directory
78 import System.Process (ProcessHandle())
79 #ifndef mingw32_HOST_OS
80 import System.Process (waitForProcess)
-- | A session representing one instance of launching and connecting to a server.
--
-- You can send and receive messages to the server within 'Session' via
-- 'Language.LSP.Test.message',
-- 'Language.LSP.Test.sendRequest' and
-- 'Language.LSP.Test.sendNotification'.
--
-- Under the hood a 'Session' is a 'ConduitParser' consuming
-- 'FromServerMessage's, running over a 'StateT' of 'SessionState' on top of a
-- 'ReaderT' of 'SessionContext' in 'IO'. Its class instances are obtained by
-- newtype deriving from that stack.
newtype Session a =
  Session
    (ConduitParser FromServerMessage
      (StateT SessionState (ReaderT SessionContext IO))
      a)
  deriving (Functor, Applicative, Monad, MonadIO, Alternative)
95 #if __GLASGOW_HASKELL__ >= 806
96 instance MonadFail Session where
98 lastMsg <- fromJust . lastReceivedMessage <$> get
99 liftIO $ throw (UnexpectedMessage s lastMsg)
102 -- | Stuff you can configure for a 'Session'.
103 data SessionConfig = SessionConfig
104 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
  -- ^ Redirect the server's stderr to this process's stdout, defaults to False.
  -- Can be overridden with @LSP_TEST_LOG_STDERR@.
108 , logMessages :: Bool
109 -- ^ Trace the messages sent and received to stdout, defaults to False.
  -- Can be overridden with the environment variable @LSP_TEST_LOG_MESSAGES@.
111 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
112 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
113 , ignoreLogNotifications :: Bool
114 -- ^ Whether or not to ignore 'Language.LSP.Types.ShowMessageNotification' and
115 -- 'Language.LSP.Types.LogMessageNotification', defaults to False.
118 , initialWorkspaceFolders :: Maybe [WorkspaceFolder]
119 -- ^ The initial workspace folders to send in the @initialize@ request.
120 -- Defaults to Nothing.
-- | The configuration used in 'Language.LSP.Test.runSession'.
--
-- The constructor arguments are positional, so they follow the field
-- declaration order of 'SessionConfig'; each is annotated with what it sets.
defaultConfig :: SessionConfig
defaultConfig =
  SessionConfig
    60      -- messageTimeout: wait up to 60 seconds for a message
    False   -- do not redirect the server's stderr
    False   -- logMessages: do not trace sent/received messages
    True    -- logColor: colour the logged messages
    Nothing -- lspConfig: no initial LSP config
    False   -- ignoreLogNotifications: keep show/log message notifications
    Nothing -- initialWorkspaceFolders: none sent in @initialize@
127 instance Default SessionConfig where
130 data SessionMessage = ServerMessage FromServerMessage
134 data SessionContext = SessionContext
137 , rootDir :: FilePath
138 , messageChan :: Chan SessionMessage -- ^ Where all messages come through
139 -- Keep curTimeoutId in SessionContext, as its tied to messageChan
140 , curTimeoutId :: IORef Int -- ^ The current timeout we are waiting on
141 , requestMap :: MVar RequestMap
142 , initRsp :: MVar (ResponseMessage Initialize)
143 , config :: SessionConfig
144 , sessionCapabilities :: ClientCapabilities
147 class Monad m => HasReader r m where
149 asks :: (r -> b) -> m b
-- | Reading the 'SessionContext' from inside a 'Session' lifts past the
-- 'ConduitParser' and 'StateT' layers down to the underlying 'ReaderT'.
instance HasReader SessionContext Session where
  ask = Session . lift . lift $ Reader.ask
-- | Conduit stages running over a @StateT s (ReaderT r m)@ stack can read
-- the environment @r@ by lifting past the 'ConduitM' and 'StateT' layers.
instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
  ask = lift (lift Reader.ask)
-- | Read the id of the timeout currently being waited on, stored in the
-- 'curTimeoutId' reference of the 'SessionContext'.
getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
getCurTimeoutId = do
  ref <- asks curTimeoutId
  liftIO (readIORef ref)
-- | Advance the current timeout id past @prev@. Pass the id you /were/
-- waiting on.
--
-- The update is atomic and never moves the counter backwards. If something
-- else already bumped the id beyond @prev + 1@ in the meantime, that larger
-- value is kept.
bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
bumpTimeoutId prev = asks curTimeoutId >>= \ref ->
  liftIO . atomicModifyIORef' ref $ \current -> (max current (prev + 1), ())
169 data SessionState = SessionState
173 , curDiagnostics :: !(Map.Map NormalizedUri [Diagnostic])
174 , overridingTimeout :: !Bool
  -- | The last received message from the server.
  -- Used for providing exception information.
177 , lastReceivedMessage :: !(Maybe FromServerMessage)
178 , curDynCaps :: !(Map.Map T.Text SomeRegistration)
179 -- ^ The capabilities that the server has dynamically registered with us so
181 , curProgressSessions :: !(Set.Set ProgressToken)
184 class Monad m => HasState s m where
189 modify :: (s -> s) -> m ()
190 modify f = get >>= put . f
-- | Effectful variant of 'modify'. Reads the current state, runs @f@ on it
-- in @m@, and writes the result back.
modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
modifyM f = do
  current <- get
  updated <- f current
  put updated
-- | 'Session' state access is delegated to the 'StateT' layer that sits
-- directly beneath the 'ConduitParser'.
instance HasState SessionState Session where
  get = Session $ lift State.get
  put newState = Session (lift (State.put newState))
199 instance Monad m => HasState s (StateT s m) where
203 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
208 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
213 runSessionMonad :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
214 runSessionMonad context state (Session session) = runReaderT (runStateT conduit state) context
216 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
218 handler (Unexpected "ConduitParser.empty") = do
219 lastMsg <- fromJust . lastReceivedMessage <$> get
220 name <- getParserName
221 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
226 msg <- liftIO $ readChan (messageChan context)
227 unless (ignoreLogNotifications (config context) && isLogNotification msg) $
231 isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
232 isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
233 isLogNotification _ = False
235 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
236 watchdog = Conduit.awaitForever $ \msg -> do
237 curId <- getCurTimeoutId
239 ServerMessage sMsg -> yield sMsg
240 TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
242 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
243 -- It also does not automatically send initialize and exit messages.
244 runSession' :: Handle -- ^ Server in
245 -> Handle -- ^ Server out
246 -> Maybe ProcessHandle -- ^ Server process
247 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
249 -> ClientCapabilities
250 -> FilePath -- ^ Root directory
251 -> Session () -- ^ To exit the Server properly
254 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
255 absRootDir <- canonicalizePath rootDir
257 hSetBuffering serverIn NoBuffering
258 hSetBuffering serverOut NoBuffering
259 -- This is required to make sure that we don’t get any
260 -- newline conversion or weird encoding issues.
261 hSetBinaryMode serverIn True
262 hSetBinaryMode serverOut True
264 reqMap <- newMVar newRequestMap
265 messageChan <- newChan
266 timeoutIdVar <- newIORef 0
267 initRsp <- newEmptyMVar
269 mainThreadId <- myThreadId
271 let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
272 initState vfs = SessionState 0 vfs mempty False Nothing mempty mempty
273 runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
275 errorHandler = throwTo mainThreadId :: SessionException -> IO ()
276 serverListenerLauncher =
277 forkIO $ catch (serverHandler serverOut context) errorHandler
278 msgTimeoutMs = messageTimeout config * 10^6
279 serverAndListenerFinalizer tid = do
281 | Just sp <- mServerProc = do
282 -- Give the server some time to exit cleanly
283 -- It makes the server hangs in windows so we have to avoid it
284 #ifndef mingw32_HOST_OS
285 timeout msgTimeoutMs (waitForProcess sp)
287 cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
288 | otherwise = pure ()
289 finally (timeout msgTimeoutMs (runSession' exitServer))
290 -- Make sure to kill the listener first, before closing
291 -- handles etc via cleanupProcess
292 (killThread tid >> cleanup)
294 (result, _) <- bracket serverListenerLauncher
295 serverAndListenerFinalizer
296 (const $ initVFS $ \vfs -> runSessionMonad context (initState vfs) session)
299 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
300 updateStateC = awaitForever $ \msg -> do
305 respond :: (MonadIO m, HasReader SessionContext m) => FromServerMessage -> m ()
306 respond (FromServerMess SWindowWorkDoneProgressCreate req) =
307 sendMessage $ ResponseMessage "2.0" (Just $ req ^. LSP.id) (Right ())
308 respond (FromServerMess SWorkspaceApplyEdit r) = do
309 sendMessage $ ResponseMessage "2.0" (Just $ r ^. LSP.id) (Right $ ApplyWorkspaceEditResponseBody True Nothing)
-- | Extract the 'Uri' that a 'DocumentChange' refers to.
--
-- This lives here rather than in @lsp-types@ because Template Haskell was
-- getting in the way there. For the alternative that carries an 'oldUri',
-- the old location is the one returned.
documentChangeUri :: DocumentChange -> Uri
documentChangeUri change = case change of
  InL docEdit            -> docEdit ^. textDocument . uri
  InR (InL fileOp)       -> fileOp ^. uri
  InR (InR (InL fileOp)) -> fileOp ^. oldUri
  InR (InR (InR fileOp)) -> fileOp ^. uri
321 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
322 => FromServerMessage -> m ()
323 updateState (FromServerMess SProgress req) = case req ^. params . value of
325 modify $ \s -> s { curProgressSessions = Set.insert (req ^. params . token) $ curProgressSessions s }
327 modify $ \s -> s { curProgressSessions = Set.delete (req ^. params . token) $ curProgressSessions s }
330 -- Keep track of dynamic capability registration
331 updateState (FromServerMess SClientRegisterCapability req) = do
332 let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
334 s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
336 updateState (FromServerMess SClientUnregisterCapability req) = do
337 let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
339 let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
340 in s { curDynCaps = newCurDynCaps }
342 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
343 let List diags = n ^. params . diagnostics
344 doc = n ^. params . uri
346 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
347 in s { curDiagnostics = newDiags }
349 updateState (FromServerMess SWorkspaceApplyEdit r) = do
351 -- First, prefer the versioned documentChanges field
352 allChangeParams <- case r ^. params . edit . documentChanges of
354 mapM_ (checkIfNeedsOpened . documentChangeUri) cs
355 return $ mapMaybe getParamsFromDocumentChange cs
356 -- Then fall back to the changes field
357 Nothing -> case r ^. params . edit . changes of
359 mapM_ checkIfNeedsOpened (HashMap.keys cs)
360 concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
362 error "WorkspaceEdit contains neither documentChanges nor changes!"
365 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
366 return $ s { vfs = newVFS }
368 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
369 mergedParams = map mergeParams groupedParams
371 -- TODO: Don't do this when replaying a session
372 forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
374 -- Update VFS to new document versions
375 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
376 latestVersions = map ((^. textDocument) . last) sortedVersions
377 bumpedVersions = map (version . _Just +~ 1) latestVersions
379 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
382 update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
383 newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
384 in s { vfs = newVFS }
386 where checkIfNeedsOpened uri = do
387 oldVFS <- vfs <$> get
390 -- if its not open, open it
391 unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
392 let fp = fromJust $ uriToFilePath uri
393 contents <- liftIO $ T.readFile fp
394 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
395 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
396 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
399 let (newVFS,_) = openVFS (vfs s) msg
400 return $ s { vfs = newVFS }
402 getParamsFromTextDocumentEdit :: TextDocumentEdit -> DidChangeTextDocumentParams
403 getParamsFromTextDocumentEdit (TextDocumentEdit docId (List edits)) =
404 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
405 in DidChangeTextDocumentParams docId (List changeEvents)
407 getParamsFromDocumentChange :: DocumentChange -> Maybe DidChangeTextDocumentParams
408 getParamsFromDocumentChange (InL textDocumentEdit) = Just $ getParamsFromTextDocumentEdit textDocumentEdit
409 getParamsFromDocumentChange _ = Nothing
412 -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
413 -- where n is the current version
414 textDocumentVersions uri = do
415 m <- vfsMap . vfs <$> get
416 let curVer = fromMaybe 0 $
417 _lsp_version <$> m Map.!? (toNormalizedUri uri)
418 pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
420 textDocumentEdits uri edits = do
421 vers <- textDocumentVersions uri
422 pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
424 getChangeParams uri (List edits) = do
425 map <$> pure getParamsFromTextDocumentEdit <*> textDocumentEdits uri (reverse edits)
427 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
428 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
429 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
430 updateState _ = return ()
432 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
434 h <- serverIn <$> ask
436 liftIO $ B.hPut h (addHeader $ encode msg)
-- | Execute a block @f@ that will throw a 'Language.LSP.Test.Exceptions.Timeout' exception
-- after @duration@ seconds. This overrides the global timeout for waiting
-- for messages to arrive, which is defined in 'SessionConfig'.
441 withTimeout :: Int -> Session a -> Session a
442 withTimeout duration f = do
443 chan <- asks messageChan
444 timeoutId <- getCurTimeoutId
445 modify $ \s -> s { overridingTimeout = True }
446 tid <- liftIO $ forkIO $ do
447 threadDelay (duration * 1000000)
448 writeChan chan (TimeoutMessage timeoutId)
450 liftIO $ killThread tid
451 bumpTimeoutId timeoutId
452 modify $ \s -> s { overridingTimeout = False }
455 data LogMsgType = LogServer | LogClient
458 -- | Logs the message if the config specified it
459 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
460 => LogMsgType -> a -> m ()
462 shouldLog <- asks $ logMessages . config
463 shouldColor <- asks $ logColor . config
464 liftIO $ when shouldLog $ do
465 when shouldColor $ setSGR [SetColor Foreground Dull color]
466 putStrLn $ arrow ++ showPretty msg
467 when shouldColor $ setSGR [Reset]
470 | t == LogServer = "<-- "
473 | t == LogServer = Magenta
476 showPretty = B.unpack . encodePretty