2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
9 module Language.Haskell.LSP.Test.Session
16 , runSessionWithHandles
34 import Control.Applicative
35 import Control.Concurrent hiding (yield)
36 import Control.Exception
37 import Control.Lens hiding (List)
39 import Control.Monad.IO.Class
40 import Control.Monad.Except
41 #if __GLASGOW_HASKELL__ == 806
42 import Control.Monad.Fail
44 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
45 import qualified Control.Monad.Trans.Reader as Reader (ask)
46 import Control.Monad.Trans.State (StateT, runStateT)
47 import qualified Control.Monad.Trans.State as State
48 import qualified Data.ByteString.Lazy.Char8 as B
50 import Data.Aeson.Encode.Pretty
51 import Data.Conduit as Conduit
52 import Data.Conduit.Parser as Parser
56 import qualified Data.Map as Map
57 import qualified Data.Text as T
58 import qualified Data.Text.IO as T
59 import qualified Data.HashMap.Strict as HashMap
62 import Language.Haskell.LSP.Messages
63 import Language.Haskell.LSP.Types.Capabilities
64 import Language.Haskell.LSP.Types
65 import Language.Haskell.LSP.Types.Lens
66 import qualified Language.Haskell.LSP.Types.Lens as LSP
67 import Language.Haskell.LSP.VFS
68 import Language.Haskell.LSP.Test.Compat
69 import Language.Haskell.LSP.Test.Decoding
70 import Language.Haskell.LSP.Test.Exceptions
71 import System.Console.ANSI
72 import System.Directory
74 import System.Process (waitForProcess, ProcessHandle())
-- | A 'Session' stands for a single launch of, and connection to, a
-- language server.
--
-- Inside a 'Session' messages are exchanged with the server through
-- 'Language.Haskell.LSP.Test.message',
-- 'Language.Haskell.LSP.Test.sendRequest' and
-- 'Language.Haskell.LSP.Test.sendNotification'.
newtype Session a = Session
  (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
  deriving (Functor, Applicative, Monad, MonadIO, Alternative)
-- | Lets 'fail' (e.g. a failed pattern match in @do@ notation) be used inside
-- a 'Session'. The failure is reported as an 'UnexpectedMessage' exception
-- built from the failure string (presumably the @s@ bound by 'fail') and the
-- last message received from the server.
-- NOTE(review): 'lastReceivedMessage' is a 'Maybe', so 'fromJust' crashes if
-- no message has been received yet; 'throwIO' would also give a more
-- predictable raise point than 'throw' under 'liftIO'.
87 #if __GLASGOW_HASKELL__ >= 806
88 instance MonadFail Session where
90 lastMsg <- fromJust . lastReceivedMessage <$> get
91 liftIO $ throw (UnexpectedMessage s lastMsg)
94 -- | Configuration options for a 'Session'.
95 data SessionConfig = SessionConfig
96 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
98 -- ^ Redirect the server's stderr to this process's stdout, defaults to False.
99 -- Can be overridden with the environment variable @LSP_TEST_LOG_STDERR@.
100 , logMessages :: Bool
101 -- ^ Trace the messages sent and received to stdout, defaults to False.
102 -- Can be overridden with the environment variable @LSP_TEST_LOG_MESSAGES@.
103 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
104 , lspConfig :: Maybe Value -- ^ The initial LSP config as a JSON value, defaults to Nothing.
105 , ignoreLogNotifications :: Bool
106 -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
107 -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
-- | The 'SessionConfig' that 'Language.Haskell.LSP.Test.runSession' uses.
-- The constructor arguments are positional, in field declaration order.
defaultConfig :: SessionConfig
defaultConfig =
  SessionConfig
    60      -- wait at most 60 seconds for a message
    False   -- don't redirect the server's stderr
    False   -- don't trace sent/received messages
    True    -- colour logged messages with ANSI codes
    Nothing -- no initial LSP config
    False   -- keep show/log-message notifications
-- | 'Default' instance for 'SessionConfig'.
-- NOTE(review): presumably @def = defaultConfig@ — confirm.
116 instance Default SessionConfig where
-- | An item delivered on the session's 'messageChan'. It is either a message
-- from the server ('ServerMessage') or a 'TimeoutMessage' carrying a timeout
-- id, which 'withTimeout' posts and 'watchdog' checks.
119 data SessionMessage = ServerMessage FromServerMessage
-- | Read-only environment of a running 'Session', provided through the
-- 'ReaderT' layer of its monad stack.
123 data SessionContext = SessionContext
126 , rootDir :: FilePath
-- ^ Root directory of the session; 'runSessionWithHandles' stores the
-- result of 'canonicalizePath' here.
127 , messageChan :: Chan SessionMessage -- ^ Where all messages come through
128 -- Keep curTimeoutId in SessionContext, as it's tied to messageChan
129 , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
130 , requestMap :: MVar RequestMap
131 , initRsp :: MVar InitializeResponse
132 , config :: SessionConfig
133 , sessionCapabilities :: ClientCapabilities
-- | Monads from which a read-only environment of type @r@ can be obtained:
-- 'ask' returns the whole environment and 'asks' applies a projection to it.
136 class Monad m => HasReader r m where
138 asks :: (r -> b) -> m b
-- | The 'SessionContext' sits in the 'ReaderT' layer, two lifts below the
-- parser: one past 'ConduitParser', one past 'StateT'.
instance HasReader SessionContext Session where
  ask = Session . lift . lift $ Reader.ask

-- | Conduit stages running over the session's @StateT@/@ReaderT@ stack read
-- the environment by lifting past 'ConduitM' and then 'StateT'.
instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
  ask = lift (lift Reader.ask)
-- | Read the id of the timeout the session is currently waiting on.
getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
getCurTimeoutId = do
  timeoutVar <- asks curTimeoutId
  liftIO (readMVar timeoutVar)
-- | Move the current timeout id past @prev@, the id the caller /was/ waiting
-- on. The id never goes backwards: if it has already moved beyond
-- @prev + 1@, it is left alone.
bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
bumpTimeoutId prev = do
  timeoutVar <- asks curTimeoutId
  -- Something else may have bumped the id in the meantime, hence 'max'
  -- rather than unconditionally writing @prev + 1@.
  liftIO . modifyMVar_ timeoutVar $ pure . max (prev + 1)
-- | Mutable state of a running 'Session', held in the 'StateT' layer of its
-- monad stack.
158 data SessionState = SessionState
162 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
-- ^ The most recently published diagnostics for each document.
163 , overridingTimeout :: Bool
-- ^ True while a 'withTimeout' block is running.
164 -- | The last received message from the server.
165 -- Used for providing exception information
166 , lastReceivedMessage :: Maybe FromServerMessage
167 , curDynCaps :: Map.Map T.Text Registration
168 -- ^ The capabilities that the server has dynamically registered with us so
-- | Monads carrying a mutable state of type @s@, read with 'get' and written
-- with 'put'. 'modify' applies a pure update; by default it is 'get'
-- followed by 'put'.
172 class Monad m => HasState s m where
177 modify :: (s -> s) -> m ()
178 modify f = get >>= put . f
-- | Like 'modify', but the update function runs in the monad itself: the
-- current state is read, passed to @f@, and the result is written back.
--
-- Only 'HasState' is required; 'Monad' already follows from its superclass,
-- so repeating it would be a redundant constraint.
modifyM :: HasState s m => (s -> m s) -> m ()
modifyM f = get >>= f >>= put
-- | The session state lives in the 'StateT' layer directly beneath the
-- parser, so a single 'lift' reaches it.
instance HasState SessionState Session where
  get = Session $ lift State.get
  put newState = Session (lift (State.put newState))
-- | A 'StateT' exposes its own state.
187 instance Monad m => HasState s (StateT s m) where
-- | Conduit stages expose the state of their underlying monad.
191 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
-- | Conduit parsers expose the state of their underlying monad.
196 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
-- | Run a 'Session' with the given context and initial state, returning its
-- result together with the final state.
--
-- Messages are read from the context's 'messageChan' and piped through
-- 'watchdog' (timeout handling) and 'updateStateC' before reaching the
-- session's parser.
201 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
202 runSession context state (Session session) = runReaderT (runStateT conduit state) context
204 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
-- A parser failure of the form @Unexpected "ConduitParser.empty"@ is
-- reported as an 'UnexpectedMessage' naming the current parser and carrying
-- the last received message.
-- NOTE(review): 'fromJust' crashes if no message has been received yet.
206 handler (Unexpected "ConduitParser.empty") = do
207 lastMsg <- fromJust . lastReceivedMessage <$> get
208 name <- getParserName
209 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
-- Read the next message from the session's channel. The action that follows
-- is skipped for show/log-message notifications when
-- 'ignoreLogNotifications' is set.
214 msg <- liftIO $ readChan (messageChan context)
215 unless (ignoreLogNotifications (config context) && isLogNotification msg) $
-- True only for the server's 'NotShowMessage' and 'NotLogMessage' notifications.
219 isLogNotification (ServerMessage (NotShowMessage _)) = True
220 isLogNotification (ServerMessage (NotLogMessage _)) = True
221 isLogNotification _ = False
-- | Forward each 'ServerMessage' downstream. A 'TimeoutMessage' whose id
-- equals the current timeout id (see 'getCurTimeoutId') throws 'Timeout'
-- carrying the last received message. Timeout messages with any other
-- (stale) id are dropped.
-- NOTE(review): 'throw' here is a pure throw sequenced through '>>=';
-- @liftIO . throwIO@ would make the raise point explicit.
223 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
224 watchdog = Conduit.awaitForever $ \msg -> do
225 curId <- getCurTimeoutId
227 ServerMessage sMsg -> yield sMsg
228 TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
230 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
231 -- It also does not automatically send initialize and exit messages.
-- The listener runs on a separate thread. Any 'SessionException' it raises is
-- rethrown to the thread that called this function. On the way out,
-- @exitServer@ is run (bounded by 'messageTimeout'). The server process is
-- then given a bounded amount of time to exit before 'cleanupProcess' is
-- called.
232 runSessionWithHandles :: Handle -- ^ Server in
233 -> Handle -- ^ Server out
234 -> ProcessHandle -- ^ Server process
235 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
237 -> ClientCapabilities
238 -> FilePath -- ^ Root directory
239 -> Session () -- ^ To exit the Server properly
242 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
243 absRootDir <- canonicalizePath rootDir
245 hSetBuffering serverIn NoBuffering
246 hSetBuffering serverOut NoBuffering
247 -- This is required to make sure that we don’t get any
248 -- newline conversion or weird encoding issues.
249 hSetBinaryMode serverIn True
250 hSetBinaryMode serverOut True
252 reqMap <- newMVar newRequestMap
253 messageChan <- newChan
254 timeoutIdVar <- newMVar 0
255 initRsp <- newEmptyMVar
257 mainThreadId <- myThreadId
259 let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
260 initState vfs = SessionState (IdInt 0) vfs mempty False Nothing mempty
261 runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
-- Forward listener failures to the calling thread so the session aborts.
263 errorHandler = throwTo mainThreadId :: SessionException -> IO ()
264 serverListenerLauncher =
265 forkIO $ catch (serverHandler serverOut context) errorHandler
266 server = (Just serverIn, Just serverOut, Nothing, serverProc)
-- 'messageTimeout' is in seconds; multiplying by 10^6 yields microseconds.
-- Assuming 'timeout' is 'System.Timeout.timeout' (which takes microseconds),
-- this is the right unit, but 'msgTimeoutMs' is not in milliseconds despite
-- its name.
267 msgTimeoutMs = messageTimeout config * 10^6
268 serverAndListenerFinalizer tid = do
269 finally (timeout msgTimeoutMs (runSession' exitServer)) $ do
270 -- Make sure to kill the listener first, before closing
271 -- handles etc. via cleanupProcess
273 -- Give the server some time to exit cleanly
274 timeout msgTimeoutMs (waitForProcess serverProc)
275 cleanupProcess server
277 (result, _) <- bracket serverListenerLauncher
278 serverAndListenerFinalizer
279 (const $ runSession' session)
-- | Stage between 'watchdog' and the session's parser that sees every
-- message from the server.
-- NOTE(review): the per-message body presumably calls 'updateState' and
-- re-yields @msg@ — confirm.
282 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
283 updateStateC = awaitForever $ \msg -> do
-- | Update the 'SessionState' in response to a message from the server. It
-- records dynamic capability registrations and unregistrations, stores the
-- latest published diagnostics for each document, and mirrors
-- server-initiated workspace edits into the VFS. All other messages leave
-- the state untouched.
287 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
288 => FromServerMessage -> m ()
290 -- Keep track of dynamic capability registration
291 updateState (ReqRegisterCapability req) = do
292 let List newRegs = (\r -> (r ^. LSP.id, r)) <$> req ^. params . registrations
-- 'Map.union' is left-biased, so a re-registration under an existing id
-- replaces the old entry.
294 s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
296 updateState (ReqUnregisterCapability req) = do
297 let List unRegs = (^. LSP.id) <$> req ^. params . unregistrations
299 let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
300 in s { curDynCaps = newCurDynCaps }
-- Replace (not merge) the diagnostics stored for the published document.
302 updateState (NotPublishDiagnostics n) = do
303 let List diags = n ^. params . diagnostics
304 doc = n ^. params . uri
306 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
307 in s { curDiagnostics = newDiags }
-- Server-initiated workspace edit. The steps are:
--   1. make sure every affected document is open, sending didOpen if needed;
--   2. apply the edit to the VFS;
--   3. send one merged didChange per document;
--   4. bump each document's version in the VFS.
309 updateState (ReqApplyWorkspaceEdit r) = do
311 -- First, prefer the versioned documentChanges field
312 allChangeParams <- case r ^. params . edit . documentChanges of
314 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
315 return $ map getParams cs
316 -- Then fall back to the changes field
317 Nothing -> case r ^. params . edit . changes of
319 mapM_ checkIfNeedsOpened (HashMap.keys cs)
320 concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
-- NOTE(review): partial; an edit with neither field crashes the session.
322 error "WorkspaceEdit contains neither documentChanges nor changes!"
325 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
326 return $ s { vfs = newVFS }
-- NOTE: 'groupBy' only groups adjacent elements, so changes to the same
-- document are merged only when they are contiguous in 'allChangeParams'.
328 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
329 mergedParams = map mergeParams groupedParams
331 -- TODO: Don't do this when replaying a session
332 forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
334 -- Update VFS to new document versions
-- ('last' is safe here: groups produced by 'groupBy' are never empty.)
335 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
336 latestVersions = map ((^. textDocument) . last) sortedVersions
337 bumpedVersions = map (version . _Just +~ 1) latestVersions
339 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
342 update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
343 newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
344 in s { vfs = newVFS }
-- For a URI the VFS does not know yet: read the file from disk, send didOpen
-- to the server, and record the document in the VFS.
-- NOTE(review): 'fromJust' on 'uriToFilePath' crashes for non-file URIs.
346 where checkIfNeedsOpened uri = do
347 oldVFS <- vfs <$> get
350 -- if it's not open, open it
351 unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
352 let fp = fromJust $ uriToFilePath uri
353 contents <- liftIO $ T.readFile fp
354 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
355 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
356 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
359 let (newVFS,_) = openVFS (vfs s) msg
360 return $ s { vfs = newVFS }
-- Turn a versioned text edit into didChange params with one range-based
-- content change per edit.
362 getParams (TextDocumentEdit docId (List edits)) =
363 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
364 in DidChangeTextDocumentParams docId (List changeEvents)
366 -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
367 -- where n is the current version
368 textDocumentVersions uri = do
369 m <- vfsMap . vfs <$> get
370 let curVer = fromMaybe 0 $
371 _lsp_version <$> m Map.!? (toNormalizedUri uri)
372 pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer..]
-- Pair each edit with its own consecutive document version.
374 textDocumentEdits uri edits = do
375 vers <- textDocumentVersions uri
376 pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
-- NOTE(review): edits are processed in reverse order, presumably so earlier
-- ranges are not shifted by later edits — confirm.
378 getChangeParams uri (List edits) =
379 map <$> pure getParams <*> textDocumentEdits uri (reverse edits)
-- Merge a group of didChange params into one, keeping the first group
-- member's document identifier. 'head' is safe because this is only applied
-- to 'groupBy' groups, which are non-empty.
381 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
382 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
383 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
384 updateState _ = return ()
-- | Encode a message as JSON, wrap it with 'addHeader', and write it to the
-- server's input handle ('serverIn').
386 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
388 h <- serverIn <$> ask
390 liftIO $ B.hPut h (addHeader $ encode msg)
392 -- | Run the block @f@, throwing a 'Language.Haskell.LSP.Test.Exceptions.Timeout' exception
393 -- if it has not finished after @duration@ seconds. This overrides the global timeout
394 -- for waiting for messages to arrive defined in 'SessionConfig'.
395 withTimeout :: Int -> Session a -> Session a
396 withTimeout duration f = do
397 chan <- asks messageChan
398 timeoutId <- getCurTimeoutId
399 modify $ \s -> s { overridingTimeout = True }
-- After @duration@ seconds, post a timeout for the id captured above;
-- 'watchdog' ignores it if the id has been bumped by then.
401 threadDelay (duration * 1000000)
402 writeChan chan (TimeoutMessage timeoutId)
-- Invalidate the pending timeout so a late 'TimeoutMessage' is ignored.
404 bumpTimeoutId timeoutId
405 modify $ \s -> s { overridingTimeout = False }
-- | Which side of the connection a logged message belongs to. 'logMsg'
-- prefixes 'LogServer' messages with a @<-- @ arrow.
408 data LogMsgType = LogServer | LogClient
411 -- | Pretty-print a message to stdout when 'logMessages' is enabled, prefixed by an arrow showing its direction and, when 'logColor' is set, coloured with ANSI codes.
412 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
413 => LogMsgType -> a -> m ()
415 shouldLog <- asks $ logMessages . config
416 shouldColor <- asks $ logColor . config
417 liftIO $ when shouldLog $ do
418 when shouldColor $ setSGR [SetColor Foreground Dull color]
419 putStrLn $ arrow ++ showPretty msg
420 when shouldColor $ setSGR [Reset]
423 | t == LogServer = "<-- "
426 | t == LogServer = Magenta
-- NOTE(review): 'encodePretty' produces UTF-8 bytes, and the Char8 'B.unpack'
-- turns each byte into a separate 'Char', so non-ASCII text is printed
-- garbled. Decoding as UTF-8 first would fix this.
429 showPretty = B.unpack . encodePretty