2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
9 module Language.Haskell.LSP.Test.Session
16 , runSessionWithHandles
34 import Control.Applicative
35 import Control.Concurrent hiding (yield)
36 import Control.Exception
37 import Control.Lens hiding (List)
39 import Control.Monad.IO.Class
40 import Control.Monad.Except
41 #if __GLASGOW_HASKELL__ == 806
42 import Control.Monad.Fail
44 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
45 import qualified Control.Monad.Trans.Reader as Reader (ask)
46 import Control.Monad.Trans.State (StateT, runStateT)
47 import qualified Control.Monad.Trans.State as State
48 import qualified Data.ByteString.Lazy.Char8 as B
50 import Data.Aeson.Encode.Pretty
51 import Data.Conduit as Conduit
52 import Data.Conduit.Parser as Parser
56 import qualified Data.Map as Map
57 import qualified Data.Set as Set
58 import qualified Data.Text as T
59 import qualified Data.Text.IO as T
60 import qualified Data.HashMap.Strict as HashMap
63 import Language.Haskell.LSP.Messages
64 import Language.Haskell.LSP.Types.Capabilities
65 import Language.Haskell.LSP.Types
66 import Language.Haskell.LSP.Types.Lens
67 import qualified Language.Haskell.LSP.Types.Lens as LSP
68 import Language.Haskell.LSP.VFS
69 import Language.Haskell.LSP.Test.Compat
70 import Language.Haskell.LSP.Test.Decoding
71 import Language.Haskell.LSP.Test.Exceptions
72 import System.Console.ANSI
73 import System.Directory
75 import System.Process (ProcessHandle())
76 #ifndef mingw32_HOST_OS
77 import System.Process (waitForProcess)
-- | The monad in which a test talks to a language server: one 'Session'
-- represents one instance of launching and connecting to a server.
--
-- Inside a 'Session' you exchange messages with the server through
-- 'Language.Haskell.LSP.Test.message',
-- 'Language.Haskell.LSP.Test.sendRequest' and
-- 'Language.Haskell.LSP.Test.sendNotification'.
newtype Session a
  = Session
      ( ConduitParser
          FromServerMessage
          (StateT SessionState (ReaderT SessionContext IO))
          a
      )
  deriving (Functor, Applicative, Monad, MonadIO, Alternative)
91 #if __GLASGOW_HASKELL__ >= 806
92 instance MonadFail Session where
94 lastMsg <- fromJust . lastReceivedMessage <$> get
95 liftIO $ throw (UnexpectedMessage s lastMsg)
98 -- | Stuff you can configure for a 'Session'.
99 data SessionConfig = SessionConfig
100 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
102 -- ^ Redirect the server's stderr to this stdout, defaults to False.
103 -- Can be overridden with @LSP_TEST_LOG_STDERR@.
104 , logMessages :: Bool
105 -- ^ Trace the messages sent and received to stdout, defaults to False.
106 -- Can be overridden with the environment variable @LSP_TEST_LOG_MESSAGES@.
107 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
108 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
109 , ignoreLogNotifications :: Bool
110 -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
111 -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
-- | The 'SessionConfig' that 'Language.Haskell.LSP.Test.runSession' runs with.
--
-- The constructor is applied positionally; each argument is named below so
-- the chosen defaults can be read off without counting fields.
defaultConfig :: SessionConfig
defaultConfig =
  SessionConfig
    timeoutSeconds
    redirectStderr
    traceMessages
    colourTrace
    initialLspConfig
    dropLogNotifications
  where
    -- Maximum time to wait for a message, in seconds.
    timeoutSeconds = 60
    -- NOTE(review): the declaration of the second field is not visible in
    -- this excerpt; going by its documentation it controls redirecting the
    -- server's stderr -- confirm against the record definition.
    redirectStderr = False
    -- Do not trace sent and received messages to stdout.
    traceMessages = False
    -- Use ANSI colour in the logged messages.
    colourTrace = True
    -- No initial LSP configuration.
    initialLspConfig = Nothing
    -- Keep show-message and log-message notifications.
    dropLogNotifications = False
120 instance Default SessionConfig where
123 data SessionMessage = ServerMessage FromServerMessage
127 data SessionContext = SessionContext
130 , rootDir :: FilePath
131 , messageChan :: Chan SessionMessage -- ^ Where all messages come through
132 -- Keep curTimeoutId in SessionContext, as it's tied to messageChan
133 , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
134 , requestMap :: MVar RequestMap
135 , initRsp :: MVar InitializeResponse
136 , config :: SessionConfig
137 , sessionCapabilities :: ClientCapabilities
140 class Monad m => HasReader r m where
142 asks :: (r -> b) -> m b
-- | A 'Session' reads its 'SessionContext' from the reader layer of its
-- transformer stack, which sits two lifts below the parser.
instance HasReader SessionContext Session where
  ask = Session . lift . lift $ Reader.ask
-- | Any conduit running over a @StateT s (ReaderT r m)@ stack can read the
-- environment @r@ by lifting through the conduit and state layers.
instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
  ask = lift (lift Reader.ask)
-- | Read the timeout id currently held in the context's 'curTimeoutId'
-- variable.
getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
getCurTimeoutId = do
  timeoutVar <- asks curTimeoutId
  liftIO (readMVar timeoutVar)
-- | Advance the session's current timeout id past @prev@.
--
-- Pass this the timeout id you *were* waiting on. The stored id only ever
-- grows: it becomes the larger of its present value and @prev + 1@.
bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
bumpTimeoutId prev = do
  timeoutVar <- asks curTimeoutId
  -- Something else might have bumped the timeout id in the meantime, so take
  -- the maximum instead of blindly writing @prev + 1@.
  -- The new id is forced ('$!') before it is stored, so that repeated bumps
  -- do not pile up a chain of unevaluated 'max' thunks inside the 'MVar'.
  liftIO $ modifyMVar_ timeoutVar $ \current ->
    pure $! max current (prev + 1)
162 data SessionState = SessionState
166 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
167 , overridingTimeout :: Bool
168 -- | The last received message from the server.
169 -- Used for providing exception information
170 , lastReceivedMessage :: Maybe FromServerMessage
171 , curDynCaps :: Map.Map T.Text Registration
172 -- ^ The capabilities that the server has dynamically registered with us so
174 , curProgressSessions :: Set.Set ProgressToken
177 class Monad m => HasState s m where
182 modify :: (s -> s) -> m ()
183 modify f = get >>= put . f
-- | Run a monadic update on the current state and store its result as the
-- new state.
modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
modifyM f = do
  current <- get
  updated <- f current
  put updated
-- | A 'Session' keeps its 'SessionState' in the state layer directly
-- beneath the parser, so both operations are a single lift away.
instance HasState SessionState Session where
  get = Session $ lift State.get
  put newState = Session (lift (State.put newState))
192 instance Monad m => HasState s (StateT s m) where
196 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
201 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
206 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
207 runSession context state (Session session) = runReaderT (runStateT conduit state) context
209 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
211 handler (Unexpected "ConduitParser.empty") = do
212 lastMsg <- fromJust . lastReceivedMessage <$> get
213 name <- getParserName
214 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
219 msg <- liftIO $ readChan (messageChan context)
220 unless (ignoreLogNotifications (config context) && isLogNotification msg) $
224 isLogNotification (ServerMessage (NotShowMessage _)) = True
225 isLogNotification (ServerMessage (NotLogMessage _)) = True
226 isLogNotification _ = False
228 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
229 watchdog = Conduit.awaitForever $ \msg -> do
230 curId <- getCurTimeoutId
232 ServerMessage sMsg -> yield sMsg
233 TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
235 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
236 -- It also does not automatically send initialize and exit messages.
237 runSessionWithHandles :: Handle -- ^ Server in
238 -> Handle -- ^ Server out
239 -> ProcessHandle -- ^ Server process
240 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
242 -> ClientCapabilities
243 -> FilePath -- ^ Root directory
244 -> Session () -- ^ To exit the Server properly
247 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
248 absRootDir <- canonicalizePath rootDir
250 hSetBuffering serverIn NoBuffering
251 hSetBuffering serverOut NoBuffering
252 -- This is required to make sure that we don’t get any
253 -- newline conversion or weird encoding issues.
254 hSetBinaryMode serverIn True
255 hSetBinaryMode serverOut True
257 reqMap <- newMVar newRequestMap
258 messageChan <- newChan
259 timeoutIdVar <- newMVar 0
260 initRsp <- newEmptyMVar
262 mainThreadId <- myThreadId
264 let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
265 initState vfs = SessionState (IdInt 0) vfs mempty False Nothing mempty mempty
266 runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
268 errorHandler = throwTo mainThreadId :: SessionException -> IO ()
269 serverListenerLauncher =
270 forkIO $ catch (serverHandler serverOut context) errorHandler
271 server = (Just serverIn, Just serverOut, Nothing, serverProc)
272 msgTimeoutMs = messageTimeout config * 10^6
273 serverAndListenerFinalizer tid = do
274 finally (timeout msgTimeoutMs (runSession' exitServer)) $ do
275 -- Make sure to kill the listener first, before closing
276 -- handles etc via cleanupProcess
278 -- Give the server some time to exit cleanly
279 -- Doing so makes the server hang on Windows, so we have to avoid it there
280 #ifndef mingw32_HOST_OS
281 timeout msgTimeoutMs (waitForProcess serverProc)
283 cleanupProcess server
285 (result, _) <- bracket serverListenerLauncher
286 serverAndListenerFinalizer
287 (const $ runSession' session)
290 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
291 updateStateC = awaitForever $ \msg -> do
295 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
296 => FromServerMessage -> m ()
297 updateState (NotWorkDoneProgressBegin req) =
298 modify $ \s -> s { curProgressSessions = Set.insert (req ^. params . token) $ curProgressSessions s }
299 updateState (NotWorkDoneProgressEnd req) =
300 modify $ \s -> s { curProgressSessions = Set.delete (req ^. params . token) $ curProgressSessions s }
302 -- Keep track of dynamic capability registration
303 updateState (ReqRegisterCapability req) = do
304 let List newRegs = (\r -> (r ^. LSP.id, r)) <$> req ^. params . registrations
306 s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
308 updateState (ReqUnregisterCapability req) = do
309 let List unRegs = (^. LSP.id) <$> req ^. params . unregistrations
311 let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
312 in s { curDynCaps = newCurDynCaps }
314 updateState (NotPublishDiagnostics n) = do
315 let List diags = n ^. params . diagnostics
316 doc = n ^. params . uri
318 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
319 in s { curDiagnostics = newDiags }
321 updateState (ReqApplyWorkspaceEdit r) = do
323 -- First, prefer the versioned documentChanges field
324 allChangeParams <- case r ^. params . edit . documentChanges of
326 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
327 return $ map getParams cs
328 -- Then fall back to the changes field
329 Nothing -> case r ^. params . edit . changes of
331 mapM_ checkIfNeedsOpened (HashMap.keys cs)
332 concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
334 error "WorkspaceEdit contains neither documentChanges nor changes!"
337 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
338 return $ s { vfs = newVFS }
340 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
341 mergedParams = map mergeParams groupedParams
343 -- TODO: Don't do this when replaying a session
344 forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
346 -- Update VFS to new document versions
347 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
348 latestVersions = map ((^. textDocument) . last) sortedVersions
349 bumpedVersions = map (version . _Just +~ 1) latestVersions
351 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
354 update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
355 newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
356 in s { vfs = newVFS }
358 where checkIfNeedsOpened uri = do
359 oldVFS <- vfs <$> get
362 -- if it's not open, open it
363 unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
364 let fp = fromJust $ uriToFilePath uri
365 contents <- liftIO $ T.readFile fp
366 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
367 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
368 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
371 let (newVFS,_) = openVFS (vfs s) msg
372 return $ s { vfs = newVFS }
374 getParams (TextDocumentEdit docId (List edits)) =
375 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
376 in DidChangeTextDocumentParams docId (List changeEvents)
378 -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
379 -- where n is the current version
380 textDocumentVersions uri = do
381 m <- vfsMap . vfs <$> get
382 let curVer = fromMaybe 0 $
383 _lsp_version <$> m Map.!? (toNormalizedUri uri)
384 pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
386 textDocumentEdits uri edits = do
387 vers <- textDocumentVersions uri
388 pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
390 getChangeParams uri (List edits) =
391 map <$> pure getParams <*> textDocumentEdits uri (reverse edits)
393 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
394 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
395 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
396 updateState _ = return ()
398 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
400 h <- serverIn <$> ask
402 liftIO $ B.hPut h (addHeader $ encode msg)
404 -- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception
405 -- after duration seconds. This will override the global timeout
406 -- for waiting for messages to arrive defined in 'SessionConfig'.
407 withTimeout :: Int -> Session a -> Session a
408 withTimeout duration f = do
409 chan <- asks messageChan
410 timeoutId <- getCurTimeoutId
411 modify $ \s -> s { overridingTimeout = True }
413 threadDelay (duration * 1000000)
414 writeChan chan (TimeoutMessage timeoutId)
416 bumpTimeoutId timeoutId
417 modify $ \s -> s { overridingTimeout = False }
420 data LogMsgType = LogServer | LogClient
423 -- | Logs the message if the config specified it
424 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
425 => LogMsgType -> a -> m ()
427 shouldLog <- asks $ logMessages . config
428 shouldColor <- asks $ logColor . config
429 liftIO $ when shouldLog $ do
430 when shouldColor $ setSGR [SetColor Foreground Dull color]
431 putStrLn $ arrow ++ showPretty msg
432 when shouldColor $ setSGR [Reset]
435 | t == LogServer = "<-- "
438 | t == LogServer = Magenta
441 showPretty = B.unpack . encodePretty