3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7 {-# LANGUAGE FlexibleContexts #-}
8 {-# LANGUAGE RankNTypes #-}
9 {-# LANGUAGE TypeInType #-}
11 module Language.LSP.Test.Session
37 import Control.Applicative
38 import Control.Concurrent hiding (yield)
39 import Control.Exception
40 import Control.Lens hiding (List)
42 import Control.Monad.IO.Class
43 import Control.Monad.Except
44 #if __GLASGOW_HASKELL__ == 806
45 import Control.Monad.Fail
47 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
48 import qualified Control.Monad.Trans.Reader as Reader (ask)
49 import Control.Monad.Trans.State (StateT, runStateT)
50 import qualified Control.Monad.Trans.State as State
51 import qualified Data.ByteString.Lazy.Char8 as B
53 import Data.Aeson.Encode.Pretty
54 import Data.Conduit as Conduit
55 import Data.Conduit.Parser as Parser
59 import qualified Data.Map as Map
60 import qualified Data.Set as Set
61 import qualified Data.Text as T
62 import qualified Data.Text.IO as T
63 import qualified Data.HashMap.Strict as HashMap
66 import Language.LSP.Types.Capabilities
67 import Language.LSP.Types
68 import Language.LSP.Types.Lens
69 import qualified Language.LSP.Types.Lens as LSP
70 import Language.LSP.VFS
71 import Language.LSP.Test.Compat
72 import Language.LSP.Test.Decoding
73 import Language.LSP.Test.Exceptions
74 import System.Console.ANSI
75 import System.Directory
77 import System.Process (ProcessHandle())
78 #ifndef mingw32_HOST_OS
79 import System.Process (waitForProcess)
-- | One test session: a single launch of a language server together with the
-- connection to it.
--
-- Within 'Session' you can exchange messages with the server using
-- 'Language.LSP.Test.message',
-- 'Language.LSP.Test.sendRequest' and
-- 'Language.LSP.Test.sendNotification'.
--
-- The monad is a parser over the incoming 'FromServerMessage' stream, stacked
-- on mutable 'SessionState', a read-only 'SessionContext', and 'IO'.
newtype Session a
  = Session
      ( ConduitParser
          FromServerMessage
          (StateT SessionState (ReaderT SessionContext IO))
          a
      )
  deriving (Functor, Applicative, Monad, MonadIO, Alternative)
93 #if __GLASGOW_HASKELL__ >= 806
94 instance MonadFail Session where
96 lastMsg <- fromJust . lastReceivedMessage <$> get
97 liftIO $ throw (UnexpectedMessage s lastMsg)
100 -- | Stuff you can configure for a 'Session'.
101 data SessionConfig = SessionConfig
102 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
104 -- ^ Redirect the server's stderr to this stdout, defaults to False.
105 -- Can be overridden with @LSP_TEST_LOG_STDERR@.
106 , logMessages :: Bool
107 -- ^ Trace the messages sent and received to stdout, defaults to False.
108 -- Can be overridden with the environment variable @LSP_TEST_LOG_MESSAGES@.
109 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
110 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
111 , ignoreLogNotifications :: Bool
112 -- ^ Whether or not to ignore 'Language.LSP.Types.ShowMessageNotification' and
113 -- 'Language.LSP.Types.LogMessageNotification', defaults to False.
116 , initialWorkspaceFolders :: Maybe [WorkspaceFolder]
117 -- ^ The initial workspace folders to send in the @initialize@ request.
118 -- Defaults to Nothing.
-- | The default 'SessionConfig'; this is the configuration used in
-- 'Language.LSP.Test.runSession'.
--
-- The constructor is applied positionally, so the arguments must stay in the
-- same order as the fields of 'SessionConfig'.
defaultConfig :: SessionConfig
defaultConfig =
  SessionConfig
    60      -- messageTimeout: seconds to wait for a message
    False   -- do not redirect the server's stderr
    False   -- logMessages: do not trace messages
    True    -- logColor: colour any logged messages
    Nothing -- lspConfig: no initial LSP config
    False   -- ignoreLogNotifications: keep log/show-message notifications
    Nothing -- initialWorkspaceFolders: none sent with @initialize@
125 instance Default SessionConfig where
128 data SessionMessage = ServerMessage FromServerMessage
132 data SessionContext = SessionContext
135 , rootDir :: FilePath
136 , messageChan :: Chan SessionMessage -- ^ Where all messages come through
137 -- Keep curTimeoutId in SessionContext, as it's tied to messageChan
138 , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
139 , requestMap :: MVar RequestMap
140 , initRsp :: MVar (ResponseMessage Initialize)
141 , config :: SessionConfig
142 , sessionCapabilities :: ClientCapabilities
145 class Monad m => HasReader r m where
147 asks :: (r -> b) -> m b
-- | The 'SessionContext' lives in the 'ReaderT' layer, two transformer levels
-- below the parser wrapped by 'Session', hence the double 'lift'.
instance HasReader SessionContext Session where
  ask = Session . lift . lift $ Reader.ask
-- | Any conduit running over the same @StateT s (ReaderT r m)@ stack can read
-- the environment by lifting through the conduit and the state layer.
instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
  ask = lift (lift Reader.ask)
-- | Read the id currently stored in the context's 'curTimeoutId' variable.
getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
getCurTimeoutId = do
  timeoutVar <- asks curTimeoutId
  liftIO (readMVar timeoutVar)
-- | Move the shared timeout id past a timeout that is no longer being waited
-- on.
--
-- Pass this the timeout id you /were/ waiting on; afterwards the stored id is
-- at least @prev + 1@.
bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
bumpTimeoutId prev = do
  v <- asks curTimeoutId
  -- Something else might have bumped the id in the meantime, so never move it
  -- backwards: keep whichever value is larger. The result is forced with '$!'
  -- before it is stored, so repeated bumps cannot pile up a chain of
  -- unevaluated 'max' thunks inside the 'MVar'.
  liftIO $ modifyMVar_ v (\x -> pure $! max x (prev + 1))
167 data SessionState = SessionState
171 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
172 , overridingTimeout :: Bool
173 -- | The last received message from the server.
174 -- Used for providing exception information
175 , lastReceivedMessage :: Maybe FromServerMessage
176 , curDynCaps :: Map.Map T.Text SomeRegistration
177 -- ^ The capabilities that the server has dynamically registered with us so
179 , curProgressSessions :: Set.Set ProgressToken
182 class Monad m => HasState s m where
187 modify :: (s -> s) -> m ()
188 modify f = get >>= put . f
-- | Monadic variant of 'modify': fetch the current state, run the effectful
-- update on it, and store whatever state it returns.
modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
modifyM step = do
  current <- get
  updated <- step current
  put updated
-- | 'Session' keeps its 'SessionState' in the 'StateT' layer directly beneath
-- the parser, so a single 'lift' reaches it.
instance HasState SessionState Session where
  get = Session $ lift State.get
  put newState = Session (lift (State.put newState))
197 instance Monad m => HasState s (StateT s m) where
201 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
206 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
211 runSessionMonad :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
212 runSessionMonad context state (Session session) = runReaderT (runStateT conduit state) context
214 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
216 handler (Unexpected "ConduitParser.empty") = do
217 lastMsg <- fromJust . lastReceivedMessage <$> get
218 name <- getParserName
219 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
224 msg <- liftIO $ readChan (messageChan context)
225 unless (ignoreLogNotifications (config context) && isLogNotification msg) $
229 isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
230 isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
231 isLogNotification _ = False
233 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
234 watchdog = Conduit.awaitForever $ \msg -> do
235 curId <- getCurTimeoutId
237 ServerMessage sMsg -> yield sMsg
238 TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
240 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
241 -- It also does not automatically send initialize and exit messages.
242 runSession' :: Handle -- ^ Server in
243 -> Handle -- ^ Server out
244 -> Maybe ProcessHandle -- ^ Server process
245 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
247 -> ClientCapabilities
248 -> FilePath -- ^ Root directory
249 -> Session () -- ^ To exit the Server properly
252 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
253 absRootDir <- canonicalizePath rootDir
255 hSetBuffering serverIn NoBuffering
256 hSetBuffering serverOut NoBuffering
257 -- This is required to make sure that we don’t get any
258 -- newline conversion or weird encoding issues.
259 hSetBinaryMode serverIn True
260 hSetBinaryMode serverOut True
262 reqMap <- newMVar newRequestMap
263 messageChan <- newChan
264 timeoutIdVar <- newMVar 0
265 initRsp <- newEmptyMVar
267 mainThreadId <- myThreadId
269 let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
270 initState vfs = SessionState 0 vfs mempty False Nothing mempty mempty
271 runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
273 errorHandler = throwTo mainThreadId :: SessionException -> IO ()
274 serverListenerLauncher =
275 forkIO $ catch (serverHandler serverOut context) errorHandler
276 msgTimeoutMs = messageTimeout config * 10^6
277 serverAndListenerFinalizer tid = do
279 | Just sp <- mServerProc = do
280 -- Give the server some time to exit cleanly
281 -- It makes the server hang on Windows, so we have to avoid it there
282 #ifndef mingw32_HOST_OS
283 timeout msgTimeoutMs (waitForProcess sp)
285 cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
286 | otherwise = pure ()
287 finally (timeout msgTimeoutMs (runSession' exitServer))
288 -- Make sure to kill the listener first, before closing
289 -- handles etc via cleanupProcess
290 (killThread tid >> cleanup)
292 (result, _) <- bracket serverListenerLauncher
293 serverAndListenerFinalizer
294 (const $ initVFS $ \vfs -> runSessionMonad context (initState vfs) session)
297 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
298 updateStateC = awaitForever $ \msg -> do
303 respond :: (MonadIO m, HasReader SessionContext m) => FromServerMessage -> m ()
304 respond (FromServerMess SWindowWorkDoneProgressCreate req) =
305 sendMessage $ ResponseMessage "2.0" (Just $ req ^. LSP.id) (Right ())
306 respond (FromServerMess SWorkspaceApplyEdit r) = do
307 sendMessage $ ResponseMessage "2.0" (Just $ r ^. LSP.id) (Right $ ApplyWorkspaceEditResponseBody True Nothing)
-- | Extract the 'Uri' that a 'DocumentChange' refers to.
--
-- This lives here rather than in @lsp-types@ because Template Haskell was
-- getting in the way there.
documentChangeUri :: DocumentChange -> Uri
documentChangeUri change =
  case change of
    -- First alternative: the URI sits on the nested text document identifier.
    InL docEdit -> docEdit ^. textDocument . uri
    -- Second alternative carries its URI directly.
    InR (InL fileOp) -> fileOp ^. uri
    -- Third alternative is read through its old URI.
    InR (InR (InL fileOp)) -> fileOp ^. oldUri
    -- Fourth alternative carries its URI directly.
    InR (InR (InR fileOp)) -> fileOp ^. uri
319 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
320 => FromServerMessage -> m ()
321 updateState (FromServerMess SProgress req) = case req ^. params . value of
323 modify $ \s -> s { curProgressSessions = Set.insert (req ^. params . token) $ curProgressSessions s }
325 modify $ \s -> s { curProgressSessions = Set.delete (req ^. params . token) $ curProgressSessions s }
328 -- Keep track of dynamic capability registration
329 updateState (FromServerMess SClientRegisterCapability req) = do
330 let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
332 s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
334 updateState (FromServerMess SClientUnregisterCapability req) = do
335 let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
337 let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
338 in s { curDynCaps = newCurDynCaps }
340 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
341 let List diags = n ^. params . diagnostics
342 doc = n ^. params . uri
344 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
345 in s { curDiagnostics = newDiags }
347 updateState (FromServerMess SWorkspaceApplyEdit r) = do
349 -- First, prefer the versioned documentChanges field
350 allChangeParams <- case r ^. params . edit . documentChanges of
352 mapM_ (checkIfNeedsOpened . documentChangeUri) cs
353 return $ mapMaybe getParamsFromDocumentChange cs
354 -- Then fall back to the changes field
355 Nothing -> case r ^. params . edit . changes of
357 mapM_ checkIfNeedsOpened (HashMap.keys cs)
358 concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
360 error "WorkspaceEdit contains neither documentChanges nor changes!"
363 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
364 return $ s { vfs = newVFS }
366 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
367 mergedParams = map mergeParams groupedParams
369 -- TODO: Don't do this when replaying a session
370 forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
372 -- Update VFS to new document versions
373 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
374 latestVersions = map ((^. textDocument) . last) sortedVersions
375 bumpedVersions = map (version . _Just +~ 1) latestVersions
377 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
380 update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
381 newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
382 in s { vfs = newVFS }
384 where checkIfNeedsOpened uri = do
385 oldVFS <- vfs <$> get
388 -- if it's not open, open it
389 unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
390 let fp = fromJust $ uriToFilePath uri
391 contents <- liftIO $ T.readFile fp
392 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
393 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
394 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
397 let (newVFS,_) = openVFS (vfs s) msg
398 return $ s { vfs = newVFS }
400 getParamsFromTextDocumentEdit :: TextDocumentEdit -> DidChangeTextDocumentParams
401 getParamsFromTextDocumentEdit (TextDocumentEdit docId (List edits)) =
402 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
403 in DidChangeTextDocumentParams docId (List changeEvents)
405 getParamsFromDocumentChange :: DocumentChange -> Maybe DidChangeTextDocumentParams
406 getParamsFromDocumentChange (InL textDocumentEdit) = Just $ getParamsFromTextDocumentEdit textDocumentEdit
407 getParamsFromDocumentChange _ = Nothing
410 -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
411 -- where n is the current version
412 textDocumentVersions uri = do
413 m <- vfsMap . vfs <$> get
414 let curVer = fromMaybe 0 $
415 _lsp_version <$> m Map.!? (toNormalizedUri uri)
416 pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
418 textDocumentEdits uri edits = do
419 vers <- textDocumentVersions uri
420 pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
422 getChangeParams uri (List edits) = do
423 map <$> pure getParamsFromTextDocumentEdit <*> textDocumentEdits uri (reverse edits)
425 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
426 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
427 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
428 updateState _ = return ()
430 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
432 h <- serverIn <$> ask
434 liftIO $ B.hPut h (addHeader $ encode msg)
436 -- | Execute a block f that will throw a 'Language.LSP.Test.Exceptions.Timeout' exception
437 -- after duration seconds. This will override the global timeout
438 -- for waiting for messages to arrive defined in 'SessionConfig'.
439 withTimeout :: Int -> Session a -> Session a
440 withTimeout duration f = do
441 chan <- asks messageChan
442 timeoutId <- getCurTimeoutId
443 modify $ \s -> s { overridingTimeout = True }
445 threadDelay (duration * 1000000)
446 writeChan chan (TimeoutMessage timeoutId)
448 bumpTimeoutId timeoutId
449 modify $ \s -> s { overridingTimeout = False }
452 data LogMsgType = LogServer | LogClient
455 -- | Logs the message if the config specified it
456 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
457 => LogMsgType -> a -> m ()
459 shouldLog <- asks $ logMessages . config
460 shouldColor <- asks $ logColor . config
461 liftIO $ when shouldLog $ do
462 when shouldColor $ setSGR [SetColor Foreground Dull color]
463 putStrLn $ arrow ++ showPretty msg
464 when shouldColor $ setSGR [Reset]
467 | t == LogServer = "<-- "
470 | t == LogServer = Magenta
473 showPretty = B.unpack . encodePretty