3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7 {-# LANGUAGE FlexibleContexts #-}
8 {-# LANGUAGE RankNTypes #-}
9 {-# LANGUAGE TypeInType #-}
11 module Language.LSP.Test.Session
36 import Control.Applicative
37 import Control.Concurrent hiding (yield)
38 import Control.Exception
39 import Control.Lens hiding (List)
41 import Control.Monad.IO.Class
42 import Control.Monad.Except
43 #if __GLASGOW_HASKELL__ == 806
44 import Control.Monad.Fail
46 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
47 import qualified Control.Monad.Trans.Reader as Reader (ask)
48 import Control.Monad.Trans.State (StateT, runStateT)
49 import qualified Control.Monad.Trans.State as State
50 import qualified Data.ByteString.Lazy.Char8 as B
52 import Data.Aeson.Encode.Pretty
53 import Data.Conduit as Conduit
54 import Data.Conduit.Parser as Parser
58 import qualified Data.Map as Map
59 import qualified Data.Text as T
60 import qualified Data.Text.IO as T
61 import qualified Data.HashMap.Strict as HashMap
64 import Language.LSP.Types.Capabilities
65 import Language.LSP.Types
66 import Language.LSP.Types.Lens
67 import qualified Language.LSP.Types.Lens as LSP
68 import Language.LSP.VFS
69 import Language.LSP.Test.Compat
70 import Language.LSP.Test.Decoding
71 import Language.LSP.Test.Exceptions
72 import System.Console.ANSI
73 import System.Directory
75 import System.Process (ProcessHandle())
76 #ifndef mingw32_HOST_OS
77 import System.Process (waitForProcess)
-- | One launched language server together with the client-side connection
-- used to talk to it.
--
-- Inside a 'Session', messages are exchanged with the server through
-- 'Language.LSP.Test.message', 'Language.LSP.Test.sendRequest' and
-- 'Language.LSP.Test.sendNotification'.
--
-- Internally it is a 'ConduitParser' over incoming 'FromServerMessage's,
-- running on top of a 'StateT' carrying the 'SessionState' and a 'ReaderT'
-- carrying the 'SessionContext'.
newtype Session a =
  Session
    (ConduitParser FromServerMessage
                   (StateT SessionState (ReaderT SessionContext IO))
                   a)
  deriving (Functor, Applicative, Alternative, Monad, MonadIO)
91 #if __GLASGOW_HASKELL__ >= 806
92 instance MonadFail Session where
94 lastMsg <- fromJust . lastReceivedMessage <$> get
95 liftIO $ throw (UnexpectedMessage s lastMsg)
98 -- | Stuff you can configure for a 'Session'.
99 data SessionConfig = SessionConfig
100 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
102 -- ^ Redirect the server's stderr to this stdout, defaults to False.
103 -- Can be overridden with the environment variable @LSP_TEST_LOG_STDERR@.
104 , logMessages :: Bool
105 -- ^ Trace the messages sent and received to stdout, defaults to False.
106 -- Can be overridden with the environment variable @LSP_TEST_LOG_MESSAGES@.
107 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
108 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
109 , ignoreLogNotifications :: Bool
110 -- ^ Whether or not to ignore 'Language.LSP.Types.ShowMessageNotification' and
111 -- 'Language.LSP.Types.LogMessageNotification', defaults to False.
114 , initialWorkspaceFolders :: Maybe [WorkspaceFolder]
115 -- ^ The initial workspace folders to send in the @initialize@ request.
116 -- Defaults to Nothing.
-- | The configuration used by 'Language.LSP.Test.runSession'.
--
-- Arguments are given positionally in field-declaration order; each matches
-- the default stated in the 'SessionConfig' field documentation.
defaultConfig :: SessionConfig
defaultConfig =
  SessionConfig
    60      -- message timeout, in seconds
    False   -- do not redirect the server's stderr
    False   -- do not log messages
    True    -- colour logged messages
    Nothing -- no initial LSP config
    False   -- keep show/log message notifications
    Nothing -- no initial workspace folders
123 instance Default SessionConfig where
126 data SessionMessage = ServerMessage FromServerMessage
130 data SessionContext = SessionContext
133 , rootDir :: FilePath
134 , messageChan :: Chan SessionMessage -- ^ Where all messages come through
135 -- Keep curTimeoutId in SessionContext, as it's tied to messageChan
136 , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
137 , requestMap :: MVar RequestMap
138 , initRsp :: MVar (ResponseMessage Initialize)
139 , config :: SessionConfig
140 , sessionCapabilities :: ClientCapabilities
143 class Monad m => HasReader r m where
145 asks :: (r -> b) -> m b
148 instance HasReader SessionContext Session where
149 ask = Session (lift $ lift Reader.ask)
-- | Conduit stages that run over the session's @StateT s (ReaderT r m)@ stack
-- read the environment by lifting 'Reader.ask' through the conduit layer and
-- then through the state layer.
instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
  ask = lift (lift Reader.ask)
-- | Read the id of the timeout that the session is currently waiting on.
getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
getCurTimeoutId = do
  timeoutVar <- asks curTimeoutId
  liftIO (readMVar timeoutVar)
-- | Advance the current timeout id past @prev@.
--
-- Pass the timeout id you /were/ waiting on. Something else may already have
-- bumped the id further in the meantime, so the stored value becomes the
-- larger of its current value and @prev + 1@ and never moves backwards.
bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
bumpTimeoutId prev = do
  timeoutVar <- asks curTimeoutId
  let next = prev + 1
  liftIO $ modifyMVar_ timeoutVar (pure . max next)
165 data SessionState = SessionState
169 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
170 , overridingTimeout :: Bool
171 -- | The last received message from the server.
172 -- Used for providing exception information
173 , lastReceivedMessage :: Maybe FromServerMessage
174 , curDynCaps :: Map.Map T.Text SomeRegistration
175 -- ^ The capabilities that the server has dynamically registered with us so
179 class Monad m => HasState s m where
184 modify :: (s -> s) -> m ()
185 modify f = get >>= put . f
-- | Monadic counterpart of 'modify'. It reads the current state, transforms
-- it with an effectful function, and stores the result.
--
-- No separate @Monad m@ constraint is needed, because 'HasState' already has
-- @Monad m@ as a superclass.
modifyM :: HasState s m => (s -> m s) -> m ()
modifyM f = get >>= f >>= put
-- | The session state lives in the 'StateT' layer directly beneath the
-- 'ConduitParser', so reading and writing it takes a single 'lift'.
instance HasState SessionState Session where
  get = Session $ lift State.get
  put newState = Session (lift (State.put newState))
194 instance Monad m => HasState s (StateT s m) where
198 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
203 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
208 runSessionMonad :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
209 runSessionMonad context state (Session session) = runReaderT (runStateT conduit state) context
211 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
213 handler (Unexpected "ConduitParser.empty") = do
214 lastMsg <- fromJust . lastReceivedMessage <$> get
215 name <- getParserName
216 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
221 msg <- liftIO $ readChan (messageChan context)
222 unless (ignoreLogNotifications (config context) && isLogNotification msg) $
226 isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
227 isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
228 isLogNotification _ = False
230 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
231 watchdog = Conduit.awaitForever $ \msg -> do
232 curId <- getCurTimeoutId
234 ServerMessage sMsg -> yield sMsg
235 TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
237 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
238 -- It also does not automatically send initialize and exit messages.
239 runSession' :: Handle -- ^ Server in
240 -> Handle -- ^ Server out
241 -> Maybe ProcessHandle -- ^ Server process
242 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
244 -> ClientCapabilities
245 -> FilePath -- ^ Root directory
246 -> Session () -- ^ To exit the Server properly
249 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
250 absRootDir <- canonicalizePath rootDir
252 hSetBuffering serverIn NoBuffering
253 hSetBuffering serverOut NoBuffering
254 -- This is required to make sure that we don’t get any
255 -- newline conversion or weird encoding issues.
256 hSetBinaryMode serverIn True
257 hSetBinaryMode serverOut True
259 reqMap <- newMVar newRequestMap
260 messageChan <- newChan
261 timeoutIdVar <- newMVar 0
262 initRsp <- newEmptyMVar
264 mainThreadId <- myThreadId
266 let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
267 initState vfs = SessionState 0 vfs mempty False Nothing mempty
268 runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
270 errorHandler = throwTo mainThreadId :: SessionException -> IO ()
271 serverListenerLauncher =
272 forkIO $ catch (serverHandler serverOut context) errorHandler
273 msgTimeoutMs = messageTimeout config * 10^6
274 serverAndListenerFinalizer tid = do
276 | Just sp <- mServerProc = do
277 -- Give the server some time to exit cleanly
278 -- On Windows this wait makes the server hang, so it is skipped there
279 #ifndef mingw32_HOST_OS
280 timeout msgTimeoutMs (waitForProcess sp)
282 cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
283 | otherwise = pure ()
284 finally (timeout msgTimeoutMs (runSession' exitServer))
285 -- Make sure to kill the listener first, before closing
286 -- handles etc via cleanupProcess
287 (killThread tid >> cleanup)
289 (result, _) <- bracket serverListenerLauncher
290 serverAndListenerFinalizer
291 (const $ initVFS $ \vfs -> runSessionMonad context (initState vfs) session)
294 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
295 updateStateC = awaitForever $ \msg -> do
-- | The 'Uri' that a 'DocumentChange' refers to.
--
-- The rules for each alternative are:
--
-- * For the text-document edit alternative, it is the edited document's uri.
-- * For the alternative carrying an 'oldUri' (a rename), it is that original
--   location.
-- * For the remaining two alternatives, it is their own 'uri'.
documentChangeUri :: DocumentChange -> Uri
documentChangeUri docChange =
  case docChange of
    InL textDocEdit            -> textDocEdit ^. textDocument . uri
    InR (InL fileOp)           -> fileOp ^. uri
    InR (InR (InL renameOp))   -> renameOp ^. oldUri
    InR (InR (InR fileOp'))    -> fileOp' ^. uri
307 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
308 => FromServerMessage -> m ()
310 -- Keep track of dynamic capability registration
311 updateState (FromServerMess SClientRegisterCapability req) = do
312 let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
314 s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
316 updateState (FromServerMess SClientUnregisterCapability req) = do
317 let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
319 let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
320 in s { curDynCaps = newCurDynCaps }
322 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
323 let List diags = n ^. params . diagnostics
324 doc = n ^. params . uri
326 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
327 in s { curDiagnostics = newDiags }
329 updateState (FromServerMess SWorkspaceApplyEdit r) = do
331 -- First, prefer the versioned documentChanges field
332 allChangeParams <- case r ^. params . edit . documentChanges of
334 mapM_ (checkIfNeedsOpened . documentChangeUri) cs
335 return $ mapMaybe getParamsFromDocumentChange cs
336 -- Then fall back to the changes field
337 Nothing -> case r ^. params . edit . changes of
339 mapM_ checkIfNeedsOpened (HashMap.keys cs)
340 concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
342 error "WorkspaceEdit contains neither documentChanges nor changes!"
345 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
346 return $ s { vfs = newVFS }
348 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
349 mergedParams = map mergeParams groupedParams
351 -- TODO: Don't do this when replaying a session
352 forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
354 -- Update VFS to new document versions
355 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
356 latestVersions = map ((^. textDocument) . last) sortedVersions
357 bumpedVersions = map (version . _Just +~ 1) latestVersions
359 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
362 update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
363 newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
364 in s { vfs = newVFS }
366 where checkIfNeedsOpened uri = do
367 oldVFS <- vfs <$> get
370 -- if it's not open, open it
371 unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
372 let fp = fromJust $ uriToFilePath uri
373 contents <- liftIO $ T.readFile fp
374 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
375 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
376 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
379 let (newVFS,_) = openVFS (vfs s) msg
380 return $ s { vfs = newVFS }
382 getParamsFromTextDocumentEdit :: TextDocumentEdit -> DidChangeTextDocumentParams
383 getParamsFromTextDocumentEdit (TextDocumentEdit docId (List edits)) =
384 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
385 in DidChangeTextDocumentParams docId (List changeEvents)
387 getParamsFromDocumentChange :: DocumentChange -> Maybe DidChangeTextDocumentParams
388 getParamsFromDocumentChange (InL textDocumentEdit) = Just $ getParamsFromTextDocumentEdit textDocumentEdit
389 getParamsFromDocumentChange _ = Nothing
392 -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
393 -- where n is the current version
394 textDocumentVersions uri = do
395 m <- vfsMap . vfs <$> get
396 let curVer = fromMaybe 0 $
397 _lsp_version <$> m Map.!? (toNormalizedUri uri)
398 pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
400 textDocumentEdits uri edits = do
401 vers <- textDocumentVersions uri
402 pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
404 getChangeParams uri (List edits) = do
405 map <$> pure getParamsFromTextDocumentEdit <*> textDocumentEdits uri (reverse edits)
407 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
408 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
409 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
410 updateState _ = return ()
412 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
414 h <- serverIn <$> ask
416 liftIO $ B.hPut h (addHeader $ encode msg)
418 -- | Execute a block f that will throw a 'Language.LSP.Test.Exceptions.Timeout' exception
419 -- after duration seconds. This will override the global timeout
420 -- for waiting for messages to arrive defined in 'SessionConfig'.
421 withTimeout :: Int -> Session a -> Session a
422 withTimeout duration f = do
423 chan <- asks messageChan
424 timeoutId <- getCurTimeoutId
425 modify $ \s -> s { overridingTimeout = True }
427 threadDelay (duration * 1000000)
428 writeChan chan (TimeoutMessage timeoutId)
430 bumpTimeoutId timeoutId
431 modify $ \s -> s { overridingTimeout = False }
434 data LogMsgType = LogServer | LogClient
437 -- | Logs the message if the config specified it
438 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
439 => LogMsgType -> a -> m ()
441 shouldLog <- asks $ logMessages . config
442 shouldColor <- asks $ logColor . config
443 liftIO $ when shouldLog $ do
444 when shouldColor $ setSGR [SetColor Foreground Dull color]
445 putStrLn $ arrow ++ showPretty msg
446 when shouldColor $ setSGR [Reset]
449 | t == LogServer = "<-- "
452 | t == LogServer = Magenta
455 showPretty = B.unpack . encodePretty