3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7 {-# LANGUAGE FlexibleContexts #-}
8 {-# LANGUAGE RankNTypes #-}
9 {-# LANGUAGE TypeInType #-}
11 module Language.LSP.Test.Session
37 import Control.Applicative
38 import Control.Concurrent hiding (yield)
39 import Control.Exception
40 import Control.Lens hiding (List)
42 import Control.Monad.IO.Class
43 import Control.Monad.Except
44 #if __GLASGOW_HASKELL__ == 806
45 import Control.Monad.Fail
47 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
48 import qualified Control.Monad.Trans.Reader as Reader (ask)
49 import Control.Monad.Trans.State (StateT, runStateT)
50 import qualified Control.Monad.Trans.State as State
51 import qualified Data.ByteString.Lazy.Char8 as B
53 import Data.Aeson.Encode.Pretty
54 import Data.Conduit as Conduit
55 import Data.Conduit.Parser as Parser
59 import qualified Data.Map as Map
60 import qualified Data.Set as Set
61 import qualified Data.Text as T
62 import qualified Data.Text.IO as T
63 import qualified Data.HashMap.Strict as HashMap
66 import Language.LSP.Types.Capabilities
67 import Language.LSP.Types
68 import Language.LSP.Types.Lens
69 import qualified Language.LSP.Types.Lens as LSP
70 import Language.LSP.VFS
71 import Language.LSP.Test.Compat
72 import Language.LSP.Test.Decoding
73 import Language.LSP.Test.Exceptions
74 import System.Console.ANSI
75 import System.Directory
77 import System.Process (ProcessHandle())
78 #ifndef mingw32_HOST_OS
79 import System.Process (waitForProcess)
-- | A 'Session' is one launch of, and connection to, a language server.
--
-- Inside a 'Session' you exchange messages with the server using
-- 'Language.LSP.Test.message', 'Language.LSP.Test.sendRequest' and
-- 'Language.LSP.Test.sendNotification'.
--
-- Under the hood it is a conduit parser over 'FromServerMessage's, running on
-- top of a mutable 'SessionState' ('StateT') and a read-only
-- 'SessionContext' ('ReaderT').
newtype Session a = Session
  (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
  deriving (Alternative, Functor, Applicative, Monad, MonadIO)
93 #if __GLASGOW_HASKELL__ >= 806
94 instance MonadFail Session where
96 lastMsg <- fromJust . lastReceivedMessage <$> get
97 liftIO $ throw (UnexpectedMessage s lastMsg)
100 -- | Stuff you can configure for a 'Session'.
101 data SessionConfig = SessionConfig
102 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
  -- ^ Redirect the server's stderr to this process's stdout, defaults to False.
  -- Can be overridden with @LSP_TEST_LOG_STDERR@.
106 , logMessages :: Bool
  -- ^ Trace the messages sent and received to stdout, defaults to False.
  -- Can be overridden with the environment variable @LSP_TEST_LOG_MESSAGES@.
109 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
110 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
111 , ignoreLogNotifications :: Bool
112 -- ^ Whether or not to ignore 'Language.LSP.Types.ShowMessageNotification' and
113 -- 'Language.LSP.Types.LogMessageNotification', defaults to False.
116 , initialWorkspaceFolders :: Maybe [WorkspaceFolder]
117 -- ^ The initial workspace folders to send in the @initialize@ request.
118 -- Defaults to Nothing.
-- | The default 'SessionConfig', i.e. the configuration used by
-- 'Language.LSP.Test.runSession'.
--
-- The constructor is applied positionally, so each argument below is
-- annotated with the field it fills, in declaration order.
defaultConfig :: SessionConfig
defaultConfig =
  SessionConfig
    60      -- message timeout, in seconds
    False   -- redirect the server's stderr
    False   -- log sent/received messages
    True    -- colour the logged messages
    Nothing -- initial LSP config
    False   -- ignore show/log-message notifications
    Nothing -- initial workspace folders
125 instance Default SessionConfig where
128 data SessionMessage = ServerMessage FromServerMessage
132 data SessionContext = SessionContext
135 , rootDir :: FilePath
136 , messageChan :: Chan SessionMessage -- ^ Where all messages come through
  -- Keep curTimeoutId in SessionContext, as it's tied to messageChan
138 , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
139 , requestMap :: MVar RequestMap
140 , initRsp :: MVar (ResponseMessage Initialize)
141 , config :: SessionConfig
142 , sessionCapabilities :: ClientCapabilities
145 class Monad m => HasReader r m where
147 asks :: (r -> b) -> m b
-- | The 'SessionContext' lives in the 'ReaderT' layer, underneath both the
-- conduit parser and the 'StateT' layer, so reading it lifts twice.
instance HasReader SessionContext Session where
  ask = Session . lift . lift $ Reader.ask
-- | Any conduit whose base monad is @StateT s (ReaderT r m)@ can read the @r@
-- environment by lifting through the conduit and then the state layer.
instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
  ask = lift (lift Reader.ask)
-- | Read the id of the timeout the session is currently waiting on.
getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
getCurTimeoutId = do
  timeoutVar <- asks curTimeoutId
  liftIO (readMVar timeoutVar)
-- | Advance the current timeout id past @prev@, the timeout id you /were/
-- waiting on.
--
-- Something else may already have bumped the id in the meantime, so the
-- stored value becomes the larger of the current id and @prev + 1@; this
-- function never decreases the id.
bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
bumpTimeoutId prev = do
  v <- asks curTimeoutId
  -- Force the new id before storing it: with a lazy 'pure', every bump that
  -- is not followed by a read would wrap the previous value in another
  -- unevaluated 'max' thunk inside the MVar.
  liftIO $ modifyMVar_ v (\x -> pure $! max x (prev + 1))
167 data SessionState = SessionState
171 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
172 , overridingTimeout :: Bool
  -- | The last received message from the server.
  -- Used for providing exception information.
175 , lastReceivedMessage :: Maybe FromServerMessage
176 , curDynCaps :: Map.Map T.Text SomeRegistration
177 -- ^ The capabilities that the server has dynamically registered with us so
179 , curProgressSessions :: Set.Set ProgressToken
182 class Monad m => HasState s m where
187 modify :: (s -> s) -> m ()
188 modify f = get >>= put . f
-- | Like 'modify', but the update function may itself run effects in @m@.
modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
modifyM f = do
  old <- get
  new <- f old
  put new
-- | The 'SessionState' lives in the 'StateT' layer directly beneath the
-- conduit parser, so both operations lift through exactly one transformer.
instance HasState SessionState Session where
  get = Session $ lift State.get
  put newState = Session (lift (State.put newState))
197 instance Monad m => HasState s (StateT s m) where
201 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
206 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
211 runSessionMonad :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
212 runSessionMonad context state (Session session) = runReaderT (runStateT conduit state) context
214 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
216 handler (Unexpected "ConduitParser.empty") = do
217 lastMsg <- fromJust . lastReceivedMessage <$> get
218 name <- getParserName
219 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
224 msg <- liftIO $ readChan (messageChan context)
225 unless (ignoreLogNotifications (config context) && isLogNotification msg) $
229 isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
230 isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
231 isLogNotification _ = False
233 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
234 watchdog = Conduit.awaitForever $ \msg -> do
235 curId <- getCurTimeoutId
237 ServerMessage sMsg -> yield sMsg
238 TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
240 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
241 -- It also does not automatically send initialize and exit messages.
242 runSession' :: Handle -- ^ Server in
243 -> Handle -- ^ Server out
244 -> Maybe ProcessHandle -- ^ Server process
245 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
247 -> ClientCapabilities
248 -> FilePath -- ^ Root directory
249 -> Session () -- ^ To exit the Server properly
252 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
253 absRootDir <- canonicalizePath rootDir
255 hSetBuffering serverIn NoBuffering
256 hSetBuffering serverOut NoBuffering
257 -- This is required to make sure that we don’t get any
258 -- newline conversion or weird encoding issues.
259 hSetBinaryMode serverIn True
260 hSetBinaryMode serverOut True
262 reqMap <- newMVar newRequestMap
263 messageChan <- newChan
264 timeoutIdVar <- newMVar 0
265 initRsp <- newEmptyMVar
267 mainThreadId <- myThreadId
269 let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
270 initState vfs = SessionState 0 vfs mempty False Nothing mempty mempty
271 runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
273 errorHandler = throwTo mainThreadId :: SessionException -> IO ()
274 serverListenerLauncher =
275 forkIO $ catch (serverHandler serverOut context) errorHandler
276 msgTimeoutMs = messageTimeout config * 10^6
277 serverAndListenerFinalizer tid = do
279 | Just sp <- mServerProc = do
280 -- Give the server some time to exit cleanly
281 -- It makes the server hangs in windows so we have to avoid it
282 #ifndef mingw32_HOST_OS
283 timeout msgTimeoutMs (waitForProcess sp)
285 cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
286 | otherwise = pure ()
287 finally (timeout msgTimeoutMs (runSession' exitServer))
288 -- Make sure to kill the listener first, before closing
289 -- handles etc via cleanupProcess
290 (killThread tid >> cleanup)
292 (result, _) <- bracket serverListenerLauncher
293 serverAndListenerFinalizer
294 (const $ initVFS $ \vfs -> runSessionMonad context (initState vfs) session)
297 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
298 updateStateC = awaitForever $ \msg -> do
-- | Extract the 'Uri' that a 'DocumentChange' refers to.
--
-- For the third alternative, which carries an @oldUri@ (presumably a file
-- rename; confirm against lsp-types), the old URI is returned.
--
-- This lives here rather than in @lsp-types@ because Template Haskell was
-- getting in the way there.
documentChangeUri :: DocumentChange -> Uri
documentChangeUri dc = case dc of
  InL textDocEdit        -> textDocEdit ^. textDocument . uri
  InR (InL fileOp)       -> fileOp ^. uri
  InR (InR (InL fileOp)) -> fileOp ^. oldUri
  InR (InR (InR fileOp)) -> fileOp ^. uri
310 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
311 => FromServerMessage -> m ()
312 updateState (FromServerMess SWindowWorkDoneProgressCreate req) =
313 sendMessage $ ResponseMessage "2.0" (Just $ req ^. LSP.id) (Right ())
314 updateState (FromServerMess SProgress req) = case req ^. params . value of
316 modify $ \s -> s { curProgressSessions = Set.insert (req ^. params . token) $ curProgressSessions s }
318 modify $ \s -> s { curProgressSessions = Set.delete (req ^. params . token) $ curProgressSessions s }
321 -- Keep track of dynamic capability registration
322 updateState (FromServerMess SClientRegisterCapability req) = do
323 let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
325 s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
327 updateState (FromServerMess SClientUnregisterCapability req) = do
328 let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
330 let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
331 in s { curDynCaps = newCurDynCaps }
333 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
334 let List diags = n ^. params . diagnostics
335 doc = n ^. params . uri
337 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
338 in s { curDiagnostics = newDiags }
340 updateState (FromServerMess SWorkspaceApplyEdit r) = do
342 -- First, prefer the versioned documentChanges field
343 allChangeParams <- case r ^. params . edit . documentChanges of
345 mapM_ (checkIfNeedsOpened . documentChangeUri) cs
346 return $ mapMaybe getParamsFromDocumentChange cs
347 -- Then fall back to the changes field
348 Nothing -> case r ^. params . edit . changes of
350 mapM_ checkIfNeedsOpened (HashMap.keys cs)
351 concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
353 error "WorkspaceEdit contains neither documentChanges nor changes!"
356 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
357 return $ s { vfs = newVFS }
359 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
360 mergedParams = map mergeParams groupedParams
362 -- TODO: Don't do this when replaying a session
363 forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
365 sendMessage $ ResponseMessage "2.0" (Just $ r ^. LSP.id) (Right $ ApplyWorkspaceEditResponseBody True Nothing)
367 -- Update VFS to new document versions
368 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
369 latestVersions = map ((^. textDocument) . last) sortedVersions
370 bumpedVersions = map (version . _Just +~ 1) latestVersions
372 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
375 update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
376 newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
377 in s { vfs = newVFS }
379 where checkIfNeedsOpened uri = do
380 oldVFS <- vfs <$> get
383 -- if its not open, open it
384 unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
385 let fp = fromJust $ uriToFilePath uri
386 contents <- liftIO $ T.readFile fp
387 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
388 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
389 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
392 let (newVFS,_) = openVFS (vfs s) msg
393 return $ s { vfs = newVFS }
395 getParamsFromTextDocumentEdit :: TextDocumentEdit -> DidChangeTextDocumentParams
396 getParamsFromTextDocumentEdit (TextDocumentEdit docId (List edits)) =
397 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
398 in DidChangeTextDocumentParams docId (List changeEvents)
400 getParamsFromDocumentChange :: DocumentChange -> Maybe DidChangeTextDocumentParams
401 getParamsFromDocumentChange (InL textDocumentEdit) = Just $ getParamsFromTextDocumentEdit textDocumentEdit
402 getParamsFromDocumentChange _ = Nothing
405 -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
406 -- where n is the current version
407 textDocumentVersions uri = do
408 m <- vfsMap . vfs <$> get
409 let curVer = fromMaybe 0 $
410 _lsp_version <$> m Map.!? (toNormalizedUri uri)
411 pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
413 textDocumentEdits uri edits = do
414 vers <- textDocumentVersions uri
415 pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
417 getChangeParams uri (List edits) = do
418 map <$> pure getParamsFromTextDocumentEdit <*> textDocumentEdits uri (reverse edits)
420 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
421 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
422 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
423 updateState _ = return ()
425 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
427 h <- serverIn <$> ask
429 liftIO $ B.hPut h (addHeader $ encode msg)
-- | Execute a block @f@ that will throw a 'Language.LSP.Test.Exceptions.Timeout'
-- exception after @duration@ seconds. This overrides the global timeout for
-- waiting for messages to arrive, as defined in 'SessionConfig'.
434 withTimeout :: Int -> Session a -> Session a
435 withTimeout duration f = do
436 chan <- asks messageChan
437 timeoutId <- getCurTimeoutId
438 modify $ \s -> s { overridingTimeout = True }
440 threadDelay (duration * 1000000)
441 writeChan chan (TimeoutMessage timeoutId)
443 bumpTimeoutId timeoutId
444 modify $ \s -> s { overridingTimeout = False }
447 data LogMsgType = LogServer | LogClient
450 -- | Logs the message if the config specified it
451 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
452 => LogMsgType -> a -> m ()
454 shouldLog <- asks $ logMessages . config
455 shouldColor <- asks $ logColor . config
456 liftIO $ when shouldLog $ do
457 when shouldColor $ setSGR [SetColor Foreground Dull color]
458 putStrLn $ arrow ++ showPretty msg
459 when shouldColor $ setSGR [Reset]
462 | t == LogServer = "<-- "
465 | t == LogServer = Magenta
468 showPretty = B.unpack . encodePretty