3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7 {-# LANGUAGE FlexibleContexts #-}
8 {-# LANGUAGE RankNTypes #-}
10 module Language.LSP.Test.Session
35 import Control.Applicative
36 import Control.Concurrent hiding (yield)
37 import Control.Exception
38 import Control.Lens hiding (List)
40 import Control.Monad.IO.Class
41 import Control.Monad.Except
42 #if __GLASGOW_HASKELL__ == 806
43 import Control.Monad.Fail
45 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
46 import qualified Control.Monad.Trans.Reader as Reader (ask)
47 import Control.Monad.Trans.State (StateT, runStateT)
48 import qualified Control.Monad.Trans.State as State
49 import qualified Data.ByteString.Lazy.Char8 as B
51 import Data.Aeson.Encode.Pretty
52 import Data.Conduit as Conduit
53 import Data.Conduit.Parser as Parser
57 import qualified Data.Map as Map
58 import qualified Data.Text as T
59 import qualified Data.Text.IO as T
60 import qualified Data.HashMap.Strict as HashMap
63 import Language.LSP.Types.Capabilities
64 import Language.LSP.Types
65 import Language.LSP.Types.Lens
66 import qualified Language.LSP.Types.Lens as LSP
67 import Language.LSP.VFS
68 import Language.LSP.Test.Compat
69 import Language.LSP.Test.Decoding
70 import Language.LSP.Test.Exceptions
71 import System.Console.ANSI
72 import System.Directory
74 import System.Process (ProcessHandle())
75 #ifndef mingw32_HOST_OS
76 import System.Process (waitForProcess)
-- | A single run of a language server together with the client's connection
-- to it.
--
-- Inside a 'Session' you can exchange messages with the server using
-- 'Language.LSP.Test.message', 'Language.LSP.Test.sendRequest' and
-- 'Language.LSP.Test.sendNotification'.
--
-- Internally it is a 'ConduitParser' over incoming 'FromServerMessage's,
-- layered over 'SessionState' (via 'StateT') and 'SessionContext'
-- (via 'ReaderT') on top of 'IO'.
newtype Session a
  = Session
      ( ConduitParser
          FromServerMessage
          (StateT SessionState (ReaderT SessionContext IO))
          a
      )
  deriving (Functor, Applicative, Monad, MonadIO, Alternative)
-- | Turn a failure inside a 'Session' (for example a failed pattern match in
-- @do@ notation) into an 'UnexpectedMessage' exception that also carries the
-- most recently received server message.
90 #if __GLASGOW_HASKELL__ >= 806
91 instance MonadFail Session where
-- NOTE(review): 'fromJust' assumes some message has already been received;
-- if none has, this fails with an uninformative 'fromJust' error instead.
93 lastMsg <- fromJust . lastReceivedMessage <$> get
-- NOTE(review): 'throwIO' would give better-defined ordering than
-- @liftIO $ throw ...@ here.
94 liftIO $ throw (UnexpectedMessage s lastMsg)
97 -- | Settings that control how a 'Session' runs.
98 data SessionConfig = SessionConfig
99 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
101 -- ^ Redirect the server's stderr to the test process's stdout, defaults to False.
102 -- Can be overridden with @LSP_TEST_LOG_STDERR@.
103 , logMessages :: Bool
104 -- ^ Trace the messages sent and received to stdout, defaults to False.
105 -- Can be overridden with the environment variable @LSP_TEST_LOG_MESSAGES@.
106 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
107 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
108 , ignoreLogNotifications :: Bool
109 -- ^ Whether or not to ignore 'Language.LSP.Types.ShowMessageNotification' and
110 -- 'Language.LSP.Types.LogMessageNotification', defaults to False.
113 , initialWorkspaceFolders :: Maybe [WorkspaceFolder]
114 -- ^ The initial workspace folders to send in the @initialize@ request.
115 -- Defaults to Nothing.
-- | The configuration used in 'Language.LSP.Test.runSession'.
--
-- Arguments are given in field order and match the defaults documented on
-- each 'SessionConfig' field.
defaultConfig :: SessionConfig
defaultConfig =
  SessionConfig
    60      -- message timeout, in seconds
    False   -- server stderr is not redirected
    False   -- sent/received messages are not logged
    True    -- logged messages are coloured
    Nothing -- no initial LSP config
    False   -- log notifications are not ignored
    Nothing -- no initial workspace folders
-- | 'Default' instance for 'SessionConfig'. Its method body lies outside this
-- view; presumably @def = defaultConfig@ — TODO confirm.
122 instance Default SessionConfig where
-- | An item on the session's message channel: either a message from the
-- server ('ServerMessage') or a timeout marker ('TimeoutMessage', carrying a
-- timeout id; its constructor line lies outside this view).
125 data SessionMessage = ServerMessage FromServerMessage
-- | The environment of a running 'Session' (its 'ReaderT' layer). The same
-- value is also handed to the server listener. Some fields lie outside this
-- view.
129 data SessionContext = SessionContext
132 , rootDir :: FilePath -- ^ The session's root directory, canonicalized by 'runSession''.
133 , messageChan :: Chan SessionMessage -- ^ Where all messages come through
134 -- Keep curTimeoutId in SessionContext, as it's tied to messageChan
135 , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
136 , requestMap :: MVar RequestMap -- ^ Starts as 'newRequestMap'; presumably tracks outstanding requests (TODO confirm).
137 , initRsp :: MVar InitializeResponse -- ^ Starts empty; presumably filled with the server's initialize response (TODO confirm).
138 , config :: SessionConfig -- ^ The configuration this session was started with.
139 , sessionCapabilities :: ClientCapabilities -- ^ The client capabilities this session was started with.
-- | Monads that can read an environment of type @r@. The 'ask' method that
-- the instances define is declared outside this view.
142 class Monad m => HasReader r m where
-- Read a projection of the environment.
144 asks :: (r -> b) -> m b
-- | Reach through the 'ConduitParser' and 'StateT' layers to the 'ReaderT'
-- holding the 'SessionContext'.
instance HasReader SessionContext Session where
  ask = Session (lift . lift $ Reader.ask)
-- | Inside the conduit stages of a session, lift twice (past 'ConduitM' and
-- 'StateT') to read the 'ReaderT' environment.
instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
  ask = (lift . lift) Reader.ask
-- | Read the id of the timeout the session is currently waiting on.
getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
getCurTimeoutId = do
  timeoutVar <- asks curTimeoutId
  liftIO (readMVar timeoutVar)
-- | Move the current timeout id past @prev@, the id the caller /was/ waiting
-- on, so that a 'TimeoutMessage' still carrying @prev@ no longer matches.
--
-- The stored id only ever grows: if something else has already advanced it
-- beyond @prev + 1@ in the meantime, that larger value is kept.
bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
bumpTimeoutId prev = do
  timeoutVar <- asks curTimeoutId
  let atLeastNext current = pure (max current (prev + 1))
  liftIO (modifyMVar_ timeoutVar atLeastNext)
-- | Mutable state threaded through a 'Session' (its 'StateT' layer). Some
-- fields lie outside this view.
164 data SessionState = SessionState
168 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
-- ^ The most recent diagnostics published for each document; each
-- publishDiagnostics notification replaces the previous list for that document.
169 , overridingTimeout :: Bool
-- ^ Whether a 'withTimeout' block is currently overriding the global message timeout.
170 -- | The last received message from the server.
171 -- Used for providing exception information
172 , lastReceivedMessage :: Maybe FromServerMessage
173 , curDynCaps :: Map.Map T.Text SomeRegistration
174 -- ^ The capabilities that the server has dynamically registered with us so
-- | Monads that carry a state of type @s@. The 'get' and 'put' methods are
-- declared outside this view.
178 class Monad m => HasState s m where
183 modify :: (s -> s) -> m ()
-- Default: read the state, apply @f@, write it back.
184 modify f = get >>= put . f
-- | Like 'modify', but the update function runs in the monad itself, so it
-- may perform effects while computing the new state.
--
-- @Monad m@ is implied by the 'HasState' superclass, so it is not repeated
-- here.
modifyM :: HasState s m => (s -> m s) -> m ()
modifyM f = get >>= f >>= put
-- | Read and replace the 'SessionState' held in the 'StateT' layer directly
-- beneath the session's parser.
instance HasState SessionState Session where
  get = Session $ lift State.get
  put newState = Session $ lift $ State.put newState
-- 'HasState' instances for the layers that make up 'Session' and its conduit
-- pipeline. Their method bodies lie outside this view. Presumably the
-- 'StateT' instance uses the transformer's own get/put, and the conduit
-- instances lift through to @m@ — TODO confirm.
193 instance Monad m => HasState s (StateT s m) where
197 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
202 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
-- | Run a 'Session' with the given context and initial state, returning the
-- session's result together with the final 'SessionState'.
--
-- Incoming messages flow through @chanSource .| watchdog .| updateStateC@
-- into the session's parser. A parser that runs out of acceptable input
-- ('ConduitParser.empty') is reported as 'UnexpectedMessage'.
207 runSessionMonad :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
208 runSessionMonad context state (Session session) = runReaderT (runStateT conduit state) context
210 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
-- Report a parser that found no matching message as 'UnexpectedMessage',
-- naming the parser and including the last message received.
-- NOTE(review): 'fromJust' assumes a message has already been received.
212 handler (Unexpected "ConduitParser.empty") = do
213 lastMsg <- fromJust . lastReceivedMessage <$> get
214 name <- getParserName
215 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
-- Read the next message from the context's channel (presumably inside
-- 'chanSource', whose header lies outside this view). window/showMessage and
-- window/logMessage notifications are skipped when 'ignoreLogNotifications'
-- is set.
220 msg <- liftIO $ readChan (messageChan context)
221 unless (ignoreLogNotifications (config context) && isLogNotification msg) $
225 isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
226 isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
227 isLogNotification _ = False
-- | Forward server messages downstream. A 'TimeoutMessage' whose id matches
-- the timeout currently being waited on raises 'Timeout' (carrying the last
-- received message). Timeout messages with any other id are stale and are
-- dropped.
229 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
230 watchdog = Conduit.awaitForever $ \msg -> do
231 curId <- getCurTimeoutId
233 ServerMessage sMsg -> yield sMsg
-- NOTE(review): 'throw' used as a monadic action only fires when that action
-- is evaluated; @liftIO . throwIO@ would make the ordering explicit.
234 TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
236 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
237 -- It also does not automatically send initialize and exit messages.
--
-- The listener runs on its own thread; any 'SessionException' it raises is
-- rethrown to the calling thread. When the session finishes (normally or via
-- an exception), @exitServer@ is run with the message timeout as its limit,
-- the listener thread is killed, and the server process, if any, is cleaned up.
238 runSession' :: Handle -- ^ Server in
239 -> Handle -- ^ Server out
240 -> Maybe ProcessHandle -- ^ Server process
241 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
243 -> ClientCapabilities -- ^ Client capabilities stored in the session context
244 -> FilePath -- ^ Root directory
245 -> Session () -- ^ To exit the Server properly
248 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
249 absRootDir <- canonicalizePath rootDir
251 hSetBuffering serverIn NoBuffering
252 hSetBuffering serverOut NoBuffering
253 -- This is required to make sure that we don’t get any
254 -- newline conversion or weird encoding issues.
255 hSetBinaryMode serverIn True
256 hSetBinaryMode serverOut True
258 reqMap <- newMVar newRequestMap
259 messageChan <- newChan
260 timeoutIdVar <- newMVar 0
261 initRsp <- newEmptyMVar
263 mainThreadId <- myThreadId
265 let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
266 initState vfs = SessionState 0 vfs mempty False Nothing mempty
-- Run a session on a fresh VFS. This local binding shadows the enclosing
-- 'runSession''; below it is used to run @exitServer@.
267 runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
-- Rethrow 'SessionException's from the listener thread on the main thread.
269 errorHandler = throwTo mainThreadId :: SessionException -> IO ()
270 serverListenerLauncher =
271 forkIO $ catch (serverHandler serverOut context) errorHandler
-- NOTE(review): despite the "Ms" suffix this value is in microseconds
-- (seconds * 10^6), which is presumably the unit 'timeout' expects.
272 msgTimeoutMs = messageTimeout config * 10^6
-- Finalizer: run @exitServer@ (bounded by the message timeout), then kill the
-- listener and clean up the server process (the @cleanup@ binding's header
-- lies outside this view).
273 serverAndListenerFinalizer tid = do
275 | Just sp <- mServerProc = do
276 -- Give the server some time to exit cleanly
277 -- Waiting makes the server hang on Windows, so we skip it there
278 #ifndef mingw32_HOST_OS
279 timeout msgTimeoutMs (waitForProcess sp)
281 cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
282 | otherwise = pure ()
283 finally (timeout msgTimeoutMs (runSession' exitServer))
284 -- Make sure to kill the listener first, before closing
285 -- handles etc via cleanupProcess
286 (killThread tid >> cleanup)
-- Start the listener, run the user's session on a fresh VFS, and always run
-- the finalizer afterwards.
288 (result, _) <- bracket serverListenerLauncher
289 serverAndListenerFinalizer
290 (const $ initVFS $ \vfs -> runSessionMonad context (initState vfs) session)
-- | Conduit stage between 'watchdog' and the session's parser (see
-- 'runSessionMonad'). The rest of its body lies outside this view;
-- presumably it applies 'updateState' to each message and re-yields it — TODO
-- confirm.
293 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
294 updateStateC = awaitForever $ \msg -> do
-- | Update the 'SessionState' in response to a message from the server. For
-- workspace/applyEdit it also sends the resulting didOpen/didChange
-- notifications. Messages not matched by an equation below are ignored.
298 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
299 => FromServerMessage -> m ()
301 -- Keep track of dynamic capability registration
302 updateState (FromServerMess SClientRegisterCapability req) = do
-- Key each new registration by its id. 'Map.union' is left-biased, so a new
-- registration replaces an existing one with the same id.
303 let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
305 s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
-- Forget every registration whose id is being unregistered.
307 updateState (FromServerMess SClientUnregisterCapability req) = do
308 let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
310 let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
311 in s { curDynCaps = newCurDynCaps }
-- Replace (rather than merge) the stored diagnostics for the document.
313 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
314 let List diags = n ^. params . diagnostics
315 doc = n ^. params . uri
317 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
318 in s { curDiagnostics = newDiags }
-- Apply a server-initiated workspace edit:
--   1. open any edited documents that are not yet in the VFS;
--   2. update the VFS via 'changeFromServerVFS';
--   3. send the matching didChange notifications;
--   4. record the new document versions.
320 updateState (FromServerMess SWorkspaceApplyEdit r) = do
322 -- First, prefer the versioned documentChanges field
323 allChangeParams <- case r ^. params . edit . documentChanges of
325 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
326 return $ map getParams cs
327 -- Then fall back to the changes field
328 Nothing -> case r ^. params . edit . changes of
330 mapM_ checkIfNeedsOpened (HashMap.keys cs)
331 concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
333 error "WorkspaceEdit contains neither documentChanges nor changes!"
336 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
337 return $ s { vfs = newVFS }
-- NOTE(review): 'groupBy' only groups *adjacent* params whose identifiers
-- (URI and version) are equal, so non-adjacent edits to the same document
-- are sent as separate notifications.
339 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
340 mergedParams = map mergeParams groupedParams
342 -- TODO: Don't do this when replaying a session
343 forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
345 -- Update VFS to new document versions
-- For each group, take the identifier with the highest version and bump that
-- version by one (when present). Store it on the VFS entry and also
-- increment the entry's own file version.
346 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
347 latestVersions = map ((^. textDocument) . last) sortedVersions
348 bumpedVersions = map (version . _Just +~ 1) latestVersions
350 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
353 update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
354 newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
355 in s { vfs = newVFS }
-- Send didOpen for a document that is not yet in the VFS, reading its
-- contents from disk, and add it to the VFS.
357 where checkIfNeedsOpened uri = do
358 oldVFS <- vfs <$> get
361 -- if it's not open, open it
362 unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
-- NOTE(review): 'fromJust' crashes for URIs that are not file paths.
363 let fp = fromJust $ uriToFilePath uri
364 contents <- liftIO $ T.readFile fp
-- The document is opened with an empty language id and version 0.
365 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
366 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
367 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
370 let (newVFS,_) = openVFS (vfs s) msg
371 return $ s { vfs = newVFS }
-- Turn a 'TextDocumentEdit' into didChange params with one ranged content
-- change per text edit.
373 getParams (TextDocumentEdit docId (List edits)) =
374 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
375 in DidChangeTextDocumentParams docId (List changeEvents)
377 -- For a uri returns an infinite list of versions [n+1,n+2,n+3,...]
378 -- where n is the current version (0 if the document is not in the VFS)
379 textDocumentVersions uri = do
380 m <- vfsMap . vfs <$> get
381 let curVer = fromMaybe 0 $
382 _lsp_version <$> m Map.!? (toNormalizedUri uri)
383 pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
-- Give each edit its own successive version, one edit per 'TextDocumentEdit'.
385 textDocumentEdits uri edits = do
386 vers <- textDocumentVersions uri
387 pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
-- NOTE(review): unversioned edits are processed in reverse order. Presumably
-- this keeps not-yet-applied ranges valid when edits arrive in ascending
-- position order — confirm.
389 getChangeParams uri (List edits) =
390 map <$> pure getParams <*> textDocumentEdits uri (reverse edits)
-- Concatenate a group's content changes into a single notification, using
-- the first element's identifier. 'head' is safe because this is only
-- applied to 'groupBy' output, which never contains an empty group.
392 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
393 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
394 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
395 updateState _ = return ()
-- | Send a message to the server: JSON-encode it, prepend the header built by
-- 'addHeader' (presumably the LSP Content-Length header), and write it to the
-- 'serverIn' handle. Part of the body lies outside this view.
397 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
399 h <- serverIn <$> ask
401 liftIO $ B.hPut h (addHeader $ encode msg)
403 -- | Execute a block @f@ that will throw a 'Language.LSP.Test.Exceptions.Timeout' exception
404 -- after @duration@ seconds. This will override the global timeout
405 -- for waiting for messages to arrive defined in 'SessionConfig'.
--
-- The timeout is delivered as a 'TimeoutMessage' on the message channel,
-- tagged with the timeout id that was current on entry. Presumably it is
-- therefore only noticed while the session is waiting for a message. Once
-- the block is done, the timeout id is bumped so the pending message becomes
-- stale.
406 withTimeout :: Int -> Session a -> Session a
407 withTimeout duration f = do
408 chan <- asks messageChan
409 timeoutId <- getCurTimeoutId
410 modify $ \s -> s { overridingTimeout = True }
-- Presumably runs on a separate thread (the fork lies outside this view).
412 threadDelay (duration * 1000000)
413 writeChan chan (TimeoutMessage timeoutId)
-- Invalidate the pending timeout message and clear the override flag.
415 bumpTimeoutId timeoutId
416 modify $ \s -> s { overridingTimeout = False }
-- | Which side of the connection a logged message belongs to; 'logMsg' uses
-- it to choose the arrow and the colour.
419 data LogMsgType = LogServer | LogClient
422 -- | Print @msg@ as pretty JSON to stdout when 'logMessages' is enabled,
-- prefixed with a direction arrow and, when 'logColor' is set, coloured with
-- ANSI escape codes.
423 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
424 => LogMsgType -> a -> m ()
426 shouldLog <- asks $ logMessages . config
427 shouldColor <- asks $ logColor . config
428 liftIO $ when shouldLog $ do
429 when shouldColor $ setSGR [SetColor Foreground Dull color]
430 putStrLn $ arrow ++ showPretty msg
-- Reset the colour so later output is unaffected.
431 when shouldColor $ setSGR [Reset]
-- Server messages get an inbound arrow; the other case lies outside this view.
434 | t == LogServer = "<-- "
437 | t == LogServer = Magenta
-- NOTE(review): Char8 'B.unpack' maps each byte to a Char, so non-ASCII text
-- in the UTF-8 JSON output will be garbled when printed.
440 showPretty = B.unpack . encodePretty