2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
9 module Language.Haskell.LSP.Test.Session
16 , runSessionWithHandles
34 import Control.Applicative
35 import Control.Concurrent hiding (yield)
36 import Control.Exception
37 import Control.Lens hiding (List)
39 import Control.Monad.IO.Class
40 import Control.Monad.Except
41 #if __GLASGOW_HASKELL__ == 806
42 import Control.Monad.Fail
44 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
45 import qualified Control.Monad.Trans.Reader as Reader (ask)
46 import Control.Monad.Trans.State (StateT, runStateT)
47 import qualified Control.Monad.Trans.State as State
48 import qualified Data.ByteString.Lazy.Char8 as B
50 import Data.Aeson.Encode.Pretty
51 import Data.Conduit as Conduit
52 import Data.Conduit.Parser as Parser
56 import qualified Data.Map as Map
57 import qualified Data.Text as T
58 import qualified Data.Text.IO as T
59 import qualified Data.HashMap.Strict as HashMap
62 import Language.Haskell.LSP.Messages
63 import Language.Haskell.LSP.Types.Capabilities
64 import Language.Haskell.LSP.Types
65 import Language.Haskell.LSP.Types.Lens
66 import qualified Language.Haskell.LSP.Types.Lens as LSP
67 import Language.Haskell.LSP.VFS
68 import Language.Haskell.LSP.Test.Compat
69 import Language.Haskell.LSP.Test.Decoding
70 import Language.Haskell.LSP.Test.Exceptions
71 import System.Console.ANSI
72 import System.Directory
73 import System.FSNotify (watchTree, eventPath, withManager, WatchManager)
74 import qualified System.FSNotify as FS
76 import System.Process (ProcessHandle())
78 import System.FilePath.Glob (match, commonDirectory, compile)
-- | A session representing one instance of launching and connecting to a
-- server.
--
-- Inside a 'Session' you can exchange messages with the server through
-- 'Language.Haskell.LSP.Test.message',
-- 'Language.Haskell.LSP.Test.sendRequest' and
-- 'Language.Haskell.LSP.Test.sendNotification'.
--
-- Internally a session is a conduit parser consuming 'FromServerMessage's,
-- running over a 'StateT' holding the 'SessionState', which in turn runs
-- over a 'ReaderT' providing the 'SessionContext'.
newtype Session a =
  Session
    (ConduitParser
       FromServerMessage
       (StateT SessionState (ReaderT SessionContext IO))
       a)
  deriving (Functor, Applicative, Alternative, Monad, MonadIO)
90 #if __GLASGOW_HASKELL__ >= 806
91 instance MonadFail Session where
93 lastMsg <- fromJust . lastReceivedMessage <$> get
94 liftIO $ throw (UnexpectedMessage s lastMsg)
97 -- | Stuff you can configure for a 'Session'.
98 data SessionConfig = SessionConfig
99 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
-- ^ Redirect the server's stderr to the test's stdout, defaults to False.
-- Can be overridden with @LSP_TEST_LOG_STDERR@.
103 , logMessages :: Bool
-- ^ Trace the messages sent and received to stdout, defaults to False.
-- Can be overridden with the environment variable @LSP_TEST_LOG_MESSAGES@.
106 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
107 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
108 , ignoreLogNotifications :: Bool
109 -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
110 -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
-- | The configuration used by 'Language.Haskell.LSP.Test.runSession'.
--
-- Waits up to 60 seconds for each message, does not redirect the server's
-- stderr, does not trace messages, colours traced messages when tracing is
-- turned on, sends no initial LSP config, and keeps log notifications.
defaultConfig :: SessionConfig
defaultConfig =
  SessionConfig
    60      -- message timeout, in seconds
    False   -- redirect the server's stderr
    False   -- trace sent/received messages
    True    -- ANSI colour for traced messages
    Nothing -- initial LSP config JSON value
    False   -- ignore show/log message notifications
119 instance Default SessionConfig where
122 data SessionMessage = ServerMessage FromServerMessage
126 data SessionContext = SessionContext
129 , rootDir :: FilePath
130 , messageChan :: Chan SessionMessage -- ^ Where all messages come through
-- Keep curTimeoutId in SessionContext, as it's tied to messageChan
132 , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
133 , requestMap :: MVar RequestMap
134 , initRsp :: MVar InitializeResponse
135 , config :: SessionConfig
136 , sessionCapabilities :: ClientCapabilities
137 , watchManager :: WatchManager
140 class Monad m => HasReader r m where
142 asks :: (r -> b) -> m b
-- | The 'SessionContext' is two transformer layers down: beneath the
-- conduit parser and beneath the 'StateT', so it takes two lifts to reach.
instance HasReader SessionContext Session where
  ask = Session . lift . lift $ Reader.ask
-- | Within a conduit stage of the session pipeline the reader environment
-- sits below the 'StateT', so the read is lifted through both layers.
instance Monad m => HasReader r (ConduitT a b (StateT s (ReaderT r m))) where
  ask = lift (lift Reader.ask)
-- | Read the id of the timeout the session is currently waiting on, taken
-- from the 'curTimeoutId' 'MVar' in the 'SessionContext'.
getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
getCurTimeoutId = do
  timeoutVar <- asks curTimeoutId
  liftIO (readMVar timeoutVar)
-- | Move the current timeout id past @prev@, which must be the id the caller
-- /was/ waiting on.
--
-- Something else may have bumped the id in the meantime, so the stored
-- value never decreases: it becomes the larger of its current value and
-- @prev + 1@.
bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
bumpTimeoutId prev = do
  timeoutVar <- asks curTimeoutId
  let next = prev + 1
  liftIO . modifyMVar_ timeoutVar $ \current -> pure (max current next)
162 data SessionState = SessionState
166 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
167 , overridingTimeout :: Bool
-- | The last received message from the server.
-- Used for providing exception information.
170 , lastReceivedMessage :: Maybe FromServerMessage
171 , curDynCaps :: Map.Map T.Text Registration
172 -- ^ The capabilities that the server has dynamically registered with us so
174 , unwatchers :: Map.Map T.Text [IO ()]
177 class Monad m => HasState s m where
182 modify :: (s -> s) -> m ()
183 modify f = get >>= put . f
185 modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
186 modifyM f = get >>= f >>= put
-- | The 'SessionState' is held by the 'StateT' directly beneath the conduit
-- parser, so one 'lift' is enough to read or replace it.
instance HasState SessionState Session where
  get = Session $ lift State.get
  put newState = Session (lift (State.put newState))
192 instance Monad m => HasState s (StateT s m) where
196 instance (Monad m, (HasState s m)) => HasState s (ConduitT a b m)
201 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
206 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
207 runSession context state (Session session) = runReaderT (runStateT conduit state) context
209 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
211 handler (Unexpected "ConduitParser.empty") = do
212 lastMsg <- fromJust . lastReceivedMessage <$> get
213 name <- getParserName
214 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
219 msg <- liftIO $ readChan (messageChan context)
220 unless (ignoreLogNotifications (config context) && isLogNotification msg) $
224 isLogNotification (ServerMessage (NotShowMessage _)) = True
225 isLogNotification (ServerMessage (NotLogMessage _)) = True
226 isLogNotification _ = False
228 watchdog :: ConduitT SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
229 watchdog = Conduit.awaitForever $ \msg -> do
230 curId <- getCurTimeoutId
232 ServerMessage sMsg -> yield sMsg
233 TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
235 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
236 -- It also does not automatically send initialize and exit messages.
237 runSessionWithHandles :: Handle -- ^ Server in
238 -> Handle -- ^ Server out
239 -> ProcessHandle -- ^ Server process
240 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
242 -> ClientCapabilities
243 -> FilePath -- ^ Root directory
244 -> Session () -- ^ To exit the Server properly
247 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
248 absRootDir <- canonicalizePath rootDir
250 hSetBuffering serverIn NoBuffering
251 hSetBuffering serverOut NoBuffering
252 -- This is required to make sure that we don’t get any
253 -- newline conversion or weird encoding issues.
254 hSetBinaryMode serverIn True
255 hSetBinaryMode serverOut True
257 reqMap <- newMVar newRequestMap
258 messageChan <- newChan
259 timeoutIdVar <- newMVar 0
260 initRsp <- newEmptyMVar
262 mainThreadId <- myThreadId
264 withManager $ \watchManager -> do
265 let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps watchManager
266 initState vfs = SessionState (IdInt 0) vfs mempty False Nothing mempty mempty
267 -- Interesting note: turning on TypeFamilies causes type inference to
268 -- infer the type runSession' :: Session () -> IO ((), SessionState)
269 -- instead of runSession' :: Session a -> IO (a , SessionState)
270 runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
272 errorHandler = throwTo mainThreadId :: SessionException -> IO ()
273 serverListenerLauncher =
274 forkIO $ catch (serverHandler serverOut context) errorHandler
275 server = (Just serverIn, Just serverOut, Nothing, serverProc)
serverAndListenerFinalizer tid =
  -- Give the exit sequence at most 'messageTimeout' seconds. 'timeout'
  -- takes its limit in microseconds, so the configured seconds are scaled
  -- by 1000000 (the same factor 'withTimeout' uses). Note that 1^6 == 1,
  -- which would leave only a few microseconds for a clean shutdown.
  finally (timeout (messageTimeout config * 1000000)
                   (runSession' exitServer))
          -- Make sure to kill the listener first, before closing
          -- handles etc via cleanupProcess
          (killThread tid >> cleanupProcess server)
283 (result, _) <- bracket serverListenerLauncher
284 serverAndListenerFinalizer
285 (const $ runSession' session)
288 updateStateC :: ConduitT FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
289 updateStateC = awaitForever $ \msg -> do
293 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
294 => FromServerMessage -> m ()
296 -- Keep track of dynamic capability registration
297 updateState (ReqRegisterCapability req) = do
298 let List newRegs = (\r -> (r ^. LSP.id, r)) <$> req ^. params . registrations
300 s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
302 -- Process the new registrations
303 forM_ newRegs $ \(regId, reg) -> do
304 when (reg ^. method == WorkspaceDidChangeWatchedFiles) $ do
305 processFileWatchRegistration regId reg
307 updateState (ReqUnregisterCapability req) = do
308 let List unRegs = (^. LSP.id) <$> req ^. params . unregistrations
310 let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
311 in s { curDynCaps = newCurDynCaps }
313 -- Process the unregistrations
314 processFileWatchUnregistrations unRegs
316 updateState (NotPublishDiagnostics n) = do
317 let List diags = n ^. params . diagnostics
318 doc = n ^. params . uri
320 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
321 in s { curDiagnostics = newDiags }
323 updateState (ReqApplyWorkspaceEdit r) = do
325 allChangeParams <- case r ^. params . edit . documentChanges of
327 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
328 return $ map getParams cs
329 Nothing -> case r ^. params . edit . changes of
331 mapM_ checkIfNeedsOpened (HashMap.keys cs)
332 return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
333 Nothing -> error "No changes!"
336 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
337 return $ s { vfs = newVFS }
339 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
340 mergedParams = map mergeParams groupedParams
342 -- TODO: Don't do this when replaying a session
343 forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
345 -- Update VFS to new document versions
346 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
347 latestVersions = map ((^. textDocument) . last) sortedVersions
348 bumpedVersions = map (version . _Just +~ 1) latestVersions
350 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
353 update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
354 newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
355 in s { vfs = newVFS }
357 where checkIfNeedsOpened uri = do
358 oldVFS <- vfs <$> get
361 -- if its not open, open it
362 unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
363 let fp = fromJust $ uriToFilePath uri
364 contents <- liftIO $ T.readFile fp
365 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
366 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
367 -- TODO: use 'sendMessage'?
368 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
371 let (newVFS,_) = openVFS (vfs s) msg
372 return $ s { vfs = newVFS }
374 getParams (TextDocumentEdit docId (List edits)) =
375 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
376 in DidChangeTextDocumentParams docId (List changeEvents)
378 textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
380 textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
382 getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
384 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
385 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
386 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
387 updateState _ = return ()
389 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
391 h <- serverIn <$> ask
393 liftIO $ B.hPut h (addHeader $ encode msg)
395 -- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception
396 -- after duration seconds. This will override the global timeout
397 -- for waiting for messages to arrive defined in 'SessionConfig'.
398 withTimeout :: Int -> Session a -> Session a
399 withTimeout duration f = do
400 chan <- asks messageChan
401 timeoutId <- getCurTimeoutId
402 modify $ \s -> s { overridingTimeout = True }
404 threadDelay (duration * 1000000)
405 writeChan chan (TimeoutMessage timeoutId)
407 bumpTimeoutId timeoutId
408 modify $ \s -> s { overridingTimeout = False }
411 -- TODO: add a shouldTimeout helper. need to add exceptions within Session
412 data LogMsgType = LogServer | LogClient
415 -- | Logs the message if the config specified it
416 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
417 => LogMsgType -> a -> m ()
419 shouldLog <- asks $ logMessages . config
420 shouldColor <- asks $ logColor . config
421 liftIO $ when shouldLog $ do
422 when shouldColor $ setSGR [SetColor Foreground Dull color]
423 putStrLn $ arrow ++ showPretty msg
424 when shouldColor $ setSGR [Reset]
427 | t == LogServer = "<-- "
430 | t == LogServer = Magenta
433 showPretty = B.unpack . encodePretty
437 processFileWatchRegistration :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
438 => T.Text -> Registration -> m ()
439 processFileWatchRegistration regId reg = do
440 mgr <- asks watchManager
442 regOpts <- reg ^. registerOptions
443 case fromJSON regOpts of
448 Just (DidChangeWatchedFilesRegistrationOptions (List ws)) ->
449 forM_ ws $ \(FileSystemWatcher pat' watchKind) -> do
450 pat <- liftIO $ canonicalizePath pat'
451 let glob = compile pat
452 -- the root-most dir before any globbing stuff happens
453 dir = fst $ commonDirectory glob
454 pred = match glob . eventPath
455 -- If no watchKind specified, spec defaults to all true
456 WatchKind wkC wkM wkD = fromMaybe (WatchKind True True True) watchKind
457 handle <- asks serverIn
458 unwatch <- liftIO $ watchTree mgr dir pred $ \event -> do
459 let fe = FileEvent (filePathToUri (eventPath event)) typ
461 FS.Added _ _ _ -> FcCreated
462 FS.Modified _ _ _ -> FcChanged
463 FS.Removed _ _ _ -> FcDeleted
464 -- This is a bit of a guess
465 FS.Unknown _ _ _ -> FcChanged
466 matches = case typ of
470 params = DidChangeWatchedFilesParams (List [fe])
471 msg = fmClientDidChangeWatchedFilesNotification params
472 liftIO $ when matches $ B.hPut handle (addHeader $ encode msg)
474 s { unwatchers = Map.insertWith (++) regId [unwatch] (unwatchers s) }
476 processFileWatchUnregistrations :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
478 processFileWatchUnregistrations regIds =
479 forM_ regIds $ \regId -> modifyM $ \s -> do
480 let fs = fromMaybe [] (Map.lookup regId (unwatchers s))
482 return $ s { unwatchers = Map.delete regId (unwatchers s) }