3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7 {-# LANGUAGE FlexibleContexts #-}
8 {-# LANGUAGE RankNTypes #-}
9 {-# LANGUAGE TypeInType #-}
11 module Language.LSP.Test.Session
37 import Control.Applicative
38 import Control.Concurrent hiding (yield)
39 import Control.Exception
40 import Control.Lens hiding (List)
42 import Control.Monad.IO.Class
43 import Control.Monad.Except
44 #if __GLASGOW_HASKELL__ == 806
45 import Control.Monad.Fail
47 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
48 import qualified Control.Monad.Trans.Reader as Reader (ask)
49 import Control.Monad.Trans.State (StateT, runStateT)
50 import qualified Control.Monad.Trans.State as State
51 import qualified Data.ByteString.Lazy.Char8 as B
53 import Data.Aeson.Encode.Pretty
54 import Data.Conduit as Conduit
55 import Data.Conduit.Parser as Parser
59 import qualified Data.Map as Map
60 import qualified Data.Set as Set
61 import qualified Data.Text as T
62 import qualified Data.Text.IO as T
63 import qualified Data.HashMap.Strict as HashMap
66 import Language.LSP.Types.Capabilities
67 import Language.LSP.Types
68 import Language.LSP.Types.Lens
69 import qualified Language.LSP.Types.Lens as LSP
70 import Language.LSP.VFS
71 import Language.LSP.Test.Compat
72 import Language.LSP.Test.Decoding
73 import Language.LSP.Test.Exceptions
74 import System.Console.ANSI
75 import System.Directory
77 import System.Process (ProcessHandle())
78 #ifndef mingw32_HOST_OS
79 import System.Process (waitForProcess)
-- | One running instance of a language server together with the connection
-- to it.
--
-- Inside a 'Session' you can exchange messages with the server through
-- 'Language.LSP.Test.message', 'Language.LSP.Test.sendRequest' and
-- 'Language.LSP.Test.sendNotification'.
--
-- Internally this is a 'ConduitParser' consuming 'FromServerMessage' values,
-- layered over a 'StateT' that carries the 'SessionState' and a 'ReaderT'
-- that carries the 'SessionContext', with 'IO' at the bottom.
newtype Session a
  = Session
      ( ConduitParser
          FromServerMessage
          (StateT SessionState (ReaderT SessionContext IO))
          a
      )
  deriving (Functor, Applicative, Monad, MonadIO, Alternative)
93 #if __GLASGOW_HASKELL__ >= 806
94 instance MonadFail Session where
96 lastMsg <- fromJust . lastReceivedMessage <$> get
97 liftIO $ throw (UnexpectedMessage s lastMsg)
100 -- | Stuff you can configure for a 'Session'.
101 data SessionConfig = SessionConfig
102 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
104 -- ^ Redirect the server's stderr to this stdout, defaults to False.
105 -- Can be overridden with @LSP_TEST_LOG_STDERR@.
106 , logMessages :: Bool
107 -- ^ Trace the messages sent and received to stdout, defaults to False.
108 -- Can be overridden with the environment variable @LSP_TEST_LOG_MESSAGES@.
109 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
110 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
111 , ignoreLogNotifications :: Bool
112 -- ^ Whether or not to ignore 'Language.LSP.Types.ShowMessageNotification' and
113 -- 'Language.LSP.Types.LogMessageNotification', defaults to False.
116 , initialWorkspaceFolders :: Maybe [WorkspaceFolder]
117 -- ^ The initial workspace folders to send in the @initialize@ request.
118 -- Defaults to Nothing.
-- | The 'SessionConfig' that 'Language.LSP.Test.runSession' uses.
--
-- The constructor is applied positionally, so the arguments below follow the
-- declaration order of the 'SessionConfig' fields.
defaultConfig :: SessionConfig
defaultConfig =
  SessionConfig
    60      -- messageTimeout, in seconds
    False   -- do not redirect the server's stderr
    False   -- logMessages
    True    -- logColor
    Nothing -- lspConfig
    False   -- ignoreLogNotifications
    Nothing -- initialWorkspaceFolders
125 instance Default SessionConfig where
128 data SessionMessage = ServerMessage FromServerMessage
132 data SessionContext = SessionContext
135 , rootDir :: FilePath
136 , messageChan :: Chan SessionMessage -- ^ Where all messages come through
137 -- Keep curTimeoutId in SessionContext, as it's tied to messageChan
138 , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
139 , requestMap :: MVar RequestMap
140 , initRsp :: MVar (ResponseMessage Initialize)
141 , config :: SessionConfig
142 , sessionCapabilities :: ClientCapabilities
145 class Monad m => HasReader r m where
147 asks :: (r -> b) -> m b
-- | A 'Session' reads its 'SessionContext' from the reader at the bottom of
-- its stack: the reader's 'Reader.ask' is lifted twice (through the state
-- layer and through the parser layer) and then rewrapped in 'Session'.
instance HasReader SessionContext Session where
  ask = Session . lift . lift $ Reader.ask
-- | Any conduit running over @StateT s (ReaderT r m)@ can read the @r@
-- environment by lifting the reader's 'Reader.ask' through both layers.
instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
  ask = lift (lift Reader.ask)
-- | Read the timeout id currently stored in the context's 'curTimeoutId'
-- variable.
getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
getCurTimeoutId = do
  timeoutVar <- asks curTimeoutId
  liftIO (readMVar timeoutVar)
-- | Advance the shared timeout id past @prev@.
--
-- @prev@ must be the timeout id the caller /was/ waiting on.
bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
bumpTimeoutId prev =
  asks curTimeoutId >>= \timeoutVar ->
    liftIO (modifyMVar_ timeoutVar (pure . advance))
  where
    -- Something else may already have bumped the id in the meantime, so
    -- never move it backwards: keep whichever of the stored id and
    -- @prev + 1@ is larger.
    advance current = max current (prev + 1)
167 data SessionState = SessionState
171 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
172 , overridingTimeout :: Bool
173 -- ^ The last received message from the server.
174 -- Used for providing exception information
175 , lastReceivedMessage :: Maybe FromServerMessage
176 , curDynCaps :: Map.Map T.Text SomeRegistration
177 -- ^ The capabilities that the server has dynamically registered with us so
179 , curProgressSessions :: Set.Set ProgressToken
182 class Monad m => HasState s m where
187 modify :: (s -> s) -> m ()
188 modify f = get >>= put . f
-- | Replace the current state with the result of running a monadic update
-- on it: the state is fetched with 'get', handed to the given action, and
-- whatever that action returns is stored with 'put'.
modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
modifyM step = do
  current <- get
  updated <- step current
  put updated
-- | The 'SessionState' of a 'Session' is kept in the state layer directly
-- beneath the parser, so both operations lift the corresponding
-- "Control.Monad.Trans.State" primitive once and rewrap it in 'Session'.
instance HasState SessionState Session where
  get = Session $ lift State.get
  put newState = Session (lift (State.put newState))
197 instance Monad m => HasState s (StateT s m) where
201 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
206 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
211 runSessionMonad :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
212 runSessionMonad context state (Session session) = runReaderT (runStateT conduit state) context
214 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
216 handler (Unexpected "ConduitParser.empty") = do
217 lastMsg <- fromJust . lastReceivedMessage <$> get
218 name <- getParserName
219 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
224 msg <- liftIO $ readChan (messageChan context)
225 unless (ignoreLogNotifications (config context) && isLogNotification msg) $
229 isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
230 isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
231 isLogNotification _ = False
233 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
234 watchdog = Conduit.awaitForever $ \msg -> do
235 curId <- getCurTimeoutId
237 ServerMessage sMsg -> yield sMsg
238 TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
240 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
241 -- It also does not automatically send initialize and exit messages.
242 runSession' :: Handle -- ^ Server in
243 -> Handle -- ^ Server out
244 -> Maybe ProcessHandle -- ^ Server process
245 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
247 -> ClientCapabilities
248 -> FilePath -- ^ Root directory
249 -> Session () -- ^ To exit the Server properly
252 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
253 absRootDir <- canonicalizePath rootDir
255 hSetBuffering serverIn NoBuffering
256 hSetBuffering serverOut NoBuffering
257 -- This is required to make sure that we don’t get any
258 -- newline conversion or weird encoding issues.
259 hSetBinaryMode serverIn True
260 hSetBinaryMode serverOut True
262 reqMap <- newMVar newRequestMap
263 messageChan <- newChan
264 timeoutIdVar <- newMVar 0
265 initRsp <- newEmptyMVar
267 mainThreadId <- myThreadId
269 let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
270 initState vfs = SessionState 0 vfs mempty False Nothing mempty mempty
271 runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
273 errorHandler = throwTo mainThreadId :: SessionException -> IO ()
274 serverListenerLauncher =
275 forkIO $ catch (serverHandler serverOut context) errorHandler
276 msgTimeoutMs = messageTimeout config * 10^6
277 serverAndListenerFinalizer tid = do
279 | Just sp <- mServerProc = do
280 -- Give the server some time to exit cleanly
281 -- Waiting makes the server hang on Windows, so we have to avoid it there
282 #ifndef mingw32_HOST_OS
283 timeout msgTimeoutMs (waitForProcess sp)
285 cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
286 | otherwise = pure ()
287 finally (timeout msgTimeoutMs (runSession' exitServer))
288 -- Make sure to kill the listener first, before closing
289 -- handles etc via cleanupProcess
290 (killThread tid >> cleanup)
292 (result, _) <- bracket serverListenerLauncher
293 serverAndListenerFinalizer
294 (const $ initVFS $ \vfs -> runSessionMonad context (initState vfs) session)
297 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
298 updateStateC = awaitForever $ \msg -> do
-- | Pull the 'Uri' out of a 'DocumentChange'.
--
-- This lives here rather than in @lsp-types@ because Template Haskell was
-- getting in the way there.
documentChangeUri :: DocumentChange -> Uri
documentChangeUri change = case change of
  -- First alternative: the uri sits on the nested text document.
  InL docEdit -> docEdit ^. textDocument . uri
  InR (InL secondOp) -> secondOp ^. uri
  -- Third alternative: reported under its old uri.
  InR (InR (InL thirdOp)) -> thirdOp ^. oldUri
  InR (InR (InR fourthOp)) -> fourthOp ^. uri
310 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
311 => FromServerMessage -> m ()
312 updateState (FromServerMess SProgress req) = case req ^. params . value of
314 modify $ \s -> s { curProgressSessions = Set.insert (req ^. params . token) $ curProgressSessions s }
316 modify $ \s -> s { curProgressSessions = Set.delete (req ^. params . token) $ curProgressSessions s }
319 -- Keep track of dynamic capability registration
320 updateState (FromServerMess SClientRegisterCapability req) = do
321 let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
323 s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
325 updateState (FromServerMess SClientUnregisterCapability req) = do
326 let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
328 let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
329 in s { curDynCaps = newCurDynCaps }
331 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
332 let List diags = n ^. params . diagnostics
333 doc = n ^. params . uri
335 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
336 in s { curDiagnostics = newDiags }
338 updateState (FromServerMess SWorkspaceApplyEdit r) = do
340 -- First, prefer the versioned documentChanges field
341 allChangeParams <- case r ^. params . edit . documentChanges of
343 mapM_ (checkIfNeedsOpened . documentChangeUri) cs
344 return $ mapMaybe getParamsFromDocumentChange cs
345 -- Then fall back to the changes field
346 Nothing -> case r ^. params . edit . changes of
348 mapM_ checkIfNeedsOpened (HashMap.keys cs)
349 concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
351 error "WorkspaceEdit contains neither documentChanges nor changes!"
354 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
355 return $ s { vfs = newVFS }
357 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
358 mergedParams = map mergeParams groupedParams
360 -- TODO: Don't do this when replaying a session
361 forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
363 -- Update VFS to new document versions
364 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
365 latestVersions = map ((^. textDocument) . last) sortedVersions
366 bumpedVersions = map (version . _Just +~ 1) latestVersions
368 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
371 update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
372 newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
373 in s { vfs = newVFS }
375 where checkIfNeedsOpened uri = do
376 oldVFS <- vfs <$> get
379 -- if its not open, open it
380 unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
381 let fp = fromJust $ uriToFilePath uri
382 contents <- liftIO $ T.readFile fp
383 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
384 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
385 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
388 let (newVFS,_) = openVFS (vfs s) msg
389 return $ s { vfs = newVFS }
391 getParamsFromTextDocumentEdit :: TextDocumentEdit -> DidChangeTextDocumentParams
392 getParamsFromTextDocumentEdit (TextDocumentEdit docId (List edits)) =
393 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
394 in DidChangeTextDocumentParams docId (List changeEvents)
396 getParamsFromDocumentChange :: DocumentChange -> Maybe DidChangeTextDocumentParams
397 getParamsFromDocumentChange (InL textDocumentEdit) = Just $ getParamsFromTextDocumentEdit textDocumentEdit
398 getParamsFromDocumentChange _ = Nothing
401 -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
402 -- where n is the current version
403 textDocumentVersions uri = do
404 m <- vfsMap . vfs <$> get
405 let curVer = fromMaybe 0 $
406 _lsp_version <$> m Map.!? (toNormalizedUri uri)
407 pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
409 textDocumentEdits uri edits = do
410 vers <- textDocumentVersions uri
411 pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
413 getChangeParams uri (List edits) = do
414 map <$> pure getParamsFromTextDocumentEdit <*> textDocumentEdits uri (reverse edits)
416 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
417 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
418 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
419 updateState _ = return ()
421 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
423 h <- serverIn <$> ask
425 liftIO $ B.hPut h (addHeader $ encode msg)
427 -- | Execute a block f that will throw a 'Language.LSP.Test.Exception.Timeout' exception
428 -- after duration seconds. This will override the global timeout
429 -- for waiting for messages to arrive defined in 'SessionConfig'.
430 withTimeout :: Int -> Session a -> Session a
431 withTimeout duration f = do
432 chan <- asks messageChan
433 timeoutId <- getCurTimeoutId
434 modify $ \s -> s { overridingTimeout = True }
436 threadDelay (duration * 1000000)
437 writeChan chan (TimeoutMessage timeoutId)
439 bumpTimeoutId timeoutId
440 modify $ \s -> s { overridingTimeout = False }
443 data LogMsgType = LogServer | LogClient
446 -- | Logs the message if the config specified it
447 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
448 => LogMsgType -> a -> m ()
450 shouldLog <- asks $ logMessages . config
451 shouldColor <- asks $ logColor . config
452 liftIO $ when shouldLog $ do
453 when shouldColor $ setSGR [SetColor Foreground Dull color]
454 putStrLn $ arrow ++ showPretty msg
455 when shouldColor $ setSGR [Reset]
458 | t == LogServer = "<-- "
461 | t == LogServer = Magenta
464 showPretty = B.unpack . encodePretty