3 {-# LANGUAGE OverloadedStrings #-}
4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
5 {-# LANGUAGE FlexibleInstances #-}
6 {-# LANGUAGE MultiParamTypeClasses #-}
7 {-# LANGUAGE FlexibleContexts #-}
8 {-# LANGUAGE RankNTypes #-}
10 module Language.Haskell.LSP.Test.Session
35 import Control.Applicative
36 import Control.Concurrent hiding (yield)
37 import Control.Exception
38 import Control.Lens hiding (List)
40 import Control.Monad.IO.Class
41 import Control.Monad.Except
42 #if __GLASGOW_HASKELL__ == 806
43 import Control.Monad.Fail
45 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
46 import qualified Control.Monad.Trans.Reader as Reader (ask)
47 import Control.Monad.Trans.State (StateT, runStateT)
48 import qualified Control.Monad.Trans.State as State
49 import qualified Data.ByteString.Lazy.Char8 as B
51 import Data.Aeson.Encode.Pretty
52 import Data.Conduit as Conduit
53 import Data.Conduit.Parser as Parser
57 import qualified Data.Map as Map
58 import qualified Data.Text as T
59 import qualified Data.Text.IO as T
60 import qualified Data.HashMap.Strict as HashMap
63 import Language.Haskell.LSP.Types.Capabilities
64 import Language.Haskell.LSP.Types
65 import Language.Haskell.LSP.Types.Lens
66 import qualified Language.Haskell.LSP.Types.Lens as LSP
67 import Language.Haskell.LSP.VFS
68 import Language.Haskell.LSP.Test.Compat
69 import Language.Haskell.LSP.Test.Decoding
70 import Language.Haskell.LSP.Test.Exceptions
71 import System.Console.ANSI
72 import System.Directory
74 import System.Process (ProcessHandle())
75 #ifndef mingw32_HOST_OS
76 import System.Process (waitForProcess)
-- | One run of a language server: the server has been launched and we are
-- connected to it.
--
-- Inside a 'Session' you talk to the server with
-- 'Language.Haskell.LSP.Test.message',
-- 'Language.Haskell.LSP.Test.sendRequest' and
-- 'Language.Haskell.LSP.Test.sendNotification'.
--
-- Under the hood this is a 'ConduitParser' over the stream of
-- @FromServerMessage@ values, running on top of a 'StateT' holding the
-- @SessionState@ and a 'ReaderT' holding the @SessionContext@, with 'IO' at
-- the bottom.
newtype Session a
  = Session
      ( ConduitParser
          FromServerMessage
          (StateT SessionState (ReaderT SessionContext IO))
          a
      )
  deriving (Functor, Applicative, Monad, MonadIO, Alternative)
90 #if __GLASGOW_HASKELL__ >= 806
91 instance MonadFail Session where
93 lastMsg <- fromJust . lastReceivedMessage <$> get
94 liftIO $ throw (UnexpectedMessage s lastMsg)
97 -- | Stuff you can configure for a 'Session'.
98 data SessionConfig = SessionConfig
99 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
101 -- ^ Redirect the server's stderr to this stdout, defaults to False.
  -- Can be overridden with @LSP_TEST_LOG_STDERR@.
103 , logMessages :: Bool
104 -- ^ Trace the messages sent and received to stdout, defaults to False.
  -- Can be overridden with the environment variable @LSP_TEST_LOG_MESSAGES@.
106 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
107 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
108 , ignoreLogNotifications :: Bool
109 -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
110 -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
-- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
--
-- The constructor arguments are given positionally, so they must stay in the
-- same order as the fields of 'SessionConfig'.
defaultConfig :: SessionConfig
defaultConfig =
  SessionConfig
    60       -- message timeout, in seconds
    False    -- do not redirect the server's stderr
    False    -- do not trace sent/received messages
    True     -- colour the logged messages
    Nothing  -- no initial LSP config
    False    -- do not ignore show/log message notifications
119 instance Default SessionConfig where
122 data SessionMessage = ServerMessage FromServerMessage
126 data SessionContext = SessionContext
129 , rootDir :: FilePath
130 , messageChan :: Chan SessionMessage -- ^ Where all messages come through
  -- Keep curTimeoutId in SessionContext, as it's tied to messageChan
132 , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
133 , requestMap :: MVar RequestMap
134 , initRsp :: MVar InitializeResponse
135 , config :: SessionConfig
136 , sessionCapabilities :: ClientCapabilities
139 class Monad m => HasReader r m where
141 asks :: (r -> b) -> m b
-- | A 'Session' can read its 'SessionContext': the underlying 'ReaderT' ask
-- is lifted through the state layer and then the parser layer before being
-- re-wrapped in the 'Session' newtype.
instance HasReader SessionContext Session where
  ask = Session . lift . lift $ Reader.ask
-- | Any conduit running over @StateT s (ReaderT r m)@ can read the @r@
-- environment: the reader's ask is lifted once past the state layer and once
-- more into the conduit.
instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
  ask = lift (lift Reader.ask)
-- | Read the timeout id currently stored in the context's 'curTimeoutId'
-- 'MVar'. Uses 'readMVar', so the value is left in place.
getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
getCurTimeoutId = do
  timeoutVar <- asks curTimeoutId
  liftIO (readMVar timeoutVar)
-- | Advance the stored timeout id past @prev@.
--
-- Pass this the timeout id you *were* waiting on. The stored id becomes the
-- larger of its current value and @prev + 1@, so it never moves backwards if
-- something else has already bumped it in the meantime.
bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
bumpTimeoutId prev =
  asks curTimeoutId >>= \timeoutVar ->
    liftIO $ modifyMVar_ timeoutVar (pure . max (prev + 1))
161 data SessionState = SessionState
165 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
166 , overridingTimeout :: Bool
  -- | The last received message from the server.
  -- Used for providing exception information.
169 , lastReceivedMessage :: Maybe FromServerMessage
170 , curDynCaps :: Map.Map T.Text SomeRegistration
171 -- ^ The capabilities that the server has dynamically registered with us so
175 class Monad m => HasState s m where
180 modify :: (s -> s) -> m ()
181 modify f = get >>= put . f
-- | Monadic variant of 'modify': fetch the current state, run the effectful
-- update on it, and store whatever state that update returns.
modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
modifyM f = do
  current <- get
  updated <- f current
  put updated
-- | A 'Session' carries a 'SessionState': both operations delegate to the
-- underlying 'StateT', lifted one level into the parser and re-wrapped in
-- the 'Session' newtype.
instance HasState SessionState Session where
  get = Session $ lift State.get
  put newState = Session (lift (State.put newState))
190 instance Monad m => HasState s (StateT s m) where
194 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
199 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
204 runSessionasdf :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
205 runSessionasdf context state (Session session) = runReaderT (runStateT conduit state) context
207 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
209 handler (Unexpected "ConduitParser.empty") = do
210 lastMsg <- fromJust . lastReceivedMessage <$> get
211 name <- getParserName
212 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
217 msg <- liftIO $ readChan (messageChan context)
218 unless (ignoreLogNotifications (config context) && isLogNotification msg) $
222 isLogNotification (ServerMessage (FromServerMess SWindowShowMessage _)) = True
223 isLogNotification (ServerMessage (FromServerMess SWindowLogMessage _)) = True
224 isLogNotification _ = False
226 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
227 watchdog = Conduit.awaitForever $ \msg -> do
228 curId <- getCurTimeoutId
230 ServerMessage sMsg -> yield sMsg
231 TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
233 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
234 -- It also does not automatically send initialize and exit messages.
235 runSession' :: Handle -- ^ Server in
236 -> Handle -- ^ Server out
237 -> Maybe ProcessHandle -- ^ Server process
238 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
240 -> ClientCapabilities
241 -> FilePath -- ^ Root directory
242 -> Session () -- ^ To exit the Server properly
245 runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exitServer session = do
246 absRootDir <- canonicalizePath rootDir
248 hSetBuffering serverIn NoBuffering
249 hSetBuffering serverOut NoBuffering
250 -- This is required to make sure that we don’t get any
251 -- newline conversion or weird encoding issues.
252 hSetBinaryMode serverIn True
253 hSetBinaryMode serverOut True
255 reqMap <- newMVar newRequestMap
256 messageChan <- newChan
257 timeoutIdVar <- newMVar 0
258 initRsp <- newEmptyMVar
260 mainThreadId <- myThreadId
262 let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
263 initState vfs = SessionState 0 vfs mempty False Nothing mempty
264 runSession' ses = initVFS $ \vfs -> runSessionasdf context (initState vfs) ses
266 errorHandler = throwTo mainThreadId :: SessionException -> IO ()
267 serverListenerLauncher =
268 forkIO $ catch (serverHandler serverOut context) errorHandler
269 server = (Just serverIn, Just serverOut, Nothing, serverProc)
270 msgTimeoutMs = messageTimeout config * 10^6
271 serverAndListenerFinalizer tid = do
273 | Just sp <- mServerProc = cleanupProcess (Just serverIn, Just serverOut, Nothing, sp)
274 | otherwise = pure ()
275 finally (timeout msgTimeoutMs (runSession' exitServer)) $ do
276 -- Make sure to kill the listener first, before closing
277 -- handles etc via cleanupProcess
279 -- Give the server some time to exit cleanly
280 #ifndef mingw32_HOST_OS
281 timeout msgTimeoutMs (waitForProcess serverProc)
285 (result, _) <- bracket serverListenerLauncher
286 serverAndListenerFinalizer
287 (const $ initVFS $ \vfs -> runSessionasdf context (initState vfs) session)
290 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
291 updateStateC = awaitForever $ \msg -> do
295 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
296 => FromServerMessage -> m ()
298 -- Keep track of dynamic capability registration
299 updateState (FromServerMess SClientRegisterCapability req) = do
300 let List newRegs = (\sr@(SomeRegistration r) -> (r ^. LSP.id, sr)) <$> req ^. params . registrations
302 s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
304 updateState (FromServerMess SClientUnregisterCapability req) = do
305 let List unRegs = (^. LSP.id) <$> req ^. params . unregisterations
307 let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
308 in s { curDynCaps = newCurDynCaps }
310 updateState (FromServerMess STextDocumentPublishDiagnostics n) = do
311 let List diags = n ^. params . diagnostics
312 doc = n ^. params . uri
314 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
315 in s { curDiagnostics = newDiags }
317 updateState (FromServerMess SWorkspaceApplyEdit r) = do
319 -- First, prefer the versioned documentChanges field
320 allChangeParams <- case r ^. params . edit . documentChanges of
322 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
323 return $ map getParams cs
324 -- Then fall back to the changes field
325 Nothing -> case r ^. params . edit . changes of
327 mapM_ checkIfNeedsOpened (HashMap.keys cs)
328 concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
330 error "WorkspaceEdit contains neither documentChanges nor changes!"
333 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
334 return $ s { vfs = newVFS }
336 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
337 mergedParams = map mergeParams groupedParams
339 -- TODO: Don't do this when replaying a session
340 forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
342 -- Update VFS to new document versions
343 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
344 latestVersions = map ((^. textDocument) . last) sortedVersions
345 bumpedVersions = map (version . _Just +~ 1) latestVersions
347 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
350 update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
351 newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
352 in s { vfs = newVFS }
354 where checkIfNeedsOpened uri = do
355 oldVFS <- vfs <$> get
          -- If it's not open, open it
359 unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
360 let fp = fromJust $ uriToFilePath uri
361 contents <- liftIO $ T.readFile fp
362 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
363 msg = NotificationMessage "2.0" STextDocumentDidOpen (DidOpenTextDocumentParams item)
364 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
367 let (newVFS,_) = openVFS (vfs s) msg
368 return $ s { vfs = newVFS }
370 getParams (TextDocumentEdit docId (List edits)) =
371 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
372 in DidChangeTextDocumentParams docId (List changeEvents)
374 -- For a uri returns an infinite list of versions [n,n+1,n+2,...]
375 -- where n is the current version
376 textDocumentVersions uri = do
377 m <- vfsMap . vfs <$> get
378 let curVer = fromMaybe 0 $
379 _lsp_version <$> m Map.!? (toNormalizedUri uri)
380 pure $ map (VersionedTextDocumentIdentifier uri . Just) [curVer + 1..]
382 textDocumentEdits uri edits = do
383 vers <- textDocumentVersions uri
384 pure $ map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip vers edits
386 getChangeParams uri (List edits) =
387 map <$> pure getParams <*> textDocumentEdits uri (reverse edits)
389 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
390 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
391 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
392 updateState _ = return ()
394 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
396 h <- serverIn <$> ask
398 liftIO $ B.hPut h (addHeader $ encode msg)
400 -- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception
401 -- after duration seconds. This will override the global timeout
402 -- for waiting for messages to arrive defined in 'SessionConfig'.
403 withTimeout :: Int -> Session a -> Session a
404 withTimeout duration f = do
405 chan <- asks messageChan
406 timeoutId <- getCurTimeoutId
407 modify $ \s -> s { overridingTimeout = True }
409 threadDelay (duration * 1000000)
410 writeChan chan (TimeoutMessage timeoutId)
412 bumpTimeoutId timeoutId
413 modify $ \s -> s { overridingTimeout = False }
416 data LogMsgType = LogServer | LogClient
419 -- | Logs the message if the config specified it
420 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
421 => LogMsgType -> a -> m ()
423 shouldLog <- asks $ logMessages . config
424 shouldColor <- asks $ logColor . config
425 liftIO $ when shouldLog $ do
426 when shouldColor $ setSGR [SetColor Foreground Dull color]
427 putStrLn $ arrow ++ showPretty msg
428 when shouldColor $ setSGR [Reset]
431 | t == LogServer = "<-- "
434 | t == LogServer = Magenta
437 showPretty = B.unpack . encodePretty