2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
9 module Language.Haskell.LSP.Test.Session
16 , runSessionWithHandles
34 import Control.Applicative
35 import Control.Concurrent hiding (yield)
36 import Control.Exception
37 import Control.Lens hiding (List)
39 import Control.Monad.IO.Class
40 import Control.Monad.Except
41 #if __GLASGOW_HASKELL__ == 806
42 import Control.Monad.Fail
44 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
45 import qualified Control.Monad.Trans.Reader as Reader (ask)
46 import Control.Monad.Trans.State (StateT, runStateT)
47 import qualified Control.Monad.Trans.State as State
48 import qualified Data.ByteString.Lazy.Char8 as B
50 import Data.Aeson.Encode.Pretty
51 import Data.Conduit as Conduit
52 import Data.Conduit.Parser as Parser
56 import qualified Data.Map as Map
57 import qualified Data.Text as T
58 import qualified Data.Text.IO as T
59 import qualified Data.HashMap.Strict as HashMap
62 import Language.Haskell.LSP.Messages
63 import Language.Haskell.LSP.Types.Capabilities
64 import Language.Haskell.LSP.Types
65 import Language.Haskell.LSP.Types.Lens
66 import qualified Language.Haskell.LSP.Types.Lens as LSP
67 import Language.Haskell.LSP.VFS
68 import Language.Haskell.LSP.Test.Compat
69 import Language.Haskell.LSP.Test.Decoding
70 import Language.Haskell.LSP.Test.Exceptions
71 import System.Console.ANSI
72 import System.Directory
74 import System.Process (ProcessHandle())
75 #ifndef mingw32_HOST_OS
76 import System.Process (waitForProcess)
-- | One run of a language server: the server is launched, connected to, and
-- talked to for the lifetime of the 'Session'.
--
-- Inside a 'Session', messages are exchanged with the server through
-- 'Language.Haskell.LSP.Test.message',
-- 'Language.Haskell.LSP.Test.sendRequest' and
-- 'Language.Haskell.LSP.Test.sendNotification'.
--
-- Internally this is a 'ConduitParser' consuming 'FromServerMessage's, stacked
-- on a 'StateT' holding the 'SessionState' and a 'ReaderT' providing the
-- 'SessionContext', all over 'IO'.
newtype Session a
  = Session
      ( ConduitParser
          FromServerMessage
          (StateT SessionState (ReaderT SessionContext IO))
          a
      )
  deriving (Functor, Applicative, Monad, MonadIO, Alternative)
90 #if __GLASGOW_HASKELL__ >= 806
91 instance MonadFail Session where
93 lastMsg <- fromJust . lastReceivedMessage <$> get
94 liftIO $ throw (UnexpectedMessage s lastMsg)
97 -- | Stuff you can configure for a 'Session'.
98 data SessionConfig = SessionConfig
99 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
101 -- ^ Redirect the server's stderr to this stdout, defaults to False.
-- Can be overridden with @LSP_TEST_LOG_STDERR@.
103 , logMessages :: Bool
104 -- ^ Trace the messages sent and received to stdout, defaults to False.
-- Can be overridden with the environment variable @LSP_TEST_LOG_MESSAGES@.
106 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
107 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
108 , ignoreLogNotifications :: Bool
109 -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
110 -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
-- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
--
-- The 'SessionConfig' constructor is applied positionally. Going by the
-- defaults documented on the fields of 'SessionConfig', the arguments read as:
-- a message timeout of 60 seconds, no stderr redirection, no message logging,
-- coloured log output, no initial LSP config, and log notifications not
-- ignored.
--
-- NOTE(review): the argument-to-field mapping relies on the declaration order
-- of 'SessionConfig' — confirm it whenever fields are added or reordered.
defaultConfig :: SessionConfig
defaultConfig = SessionConfig 60 False False True Nothing False
119 instance Default SessionConfig where
122 data SessionMessage = ServerMessage FromServerMessage
126 data SessionContext = SessionContext
129 , rootDir :: FilePath
130 , messageChan :: Chan SessionMessage -- ^ Where all messages come through
-- Keep curTimeoutId in SessionContext, as it's tied to messageChan
132 , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
133 , requestMap :: MVar RequestMap
134 , initRsp :: MVar InitializeResponse
135 , config :: SessionConfig
136 , sessionCapabilities :: ClientCapabilities
139 class Monad m => HasReader r m where
141 asks :: (r -> b) -> m b
-- | A 'Session' obtains its 'SessionContext' by lifting 'Reader.ask' through
-- the two layers that sit above the reader and re-wrapping the result in
-- 'Session'.
instance HasReader SessionContext Session where
  ask = Session . lift . lift $ Reader.ask
-- | Any conduit running over @StateT s (ReaderT r m)@ can read the @r@
-- environment: 'Reader.ask' is lifted once through the state layer and once
-- through the conduit layer.
instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
  ask = lift (lift Reader.ask)
-- | Read the timeout id currently held in the context's 'curTimeoutId' 'MVar'.
getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
getCurTimeoutId = do
  timeoutVar <- asks curTimeoutId
  liftIO (readMVar timeoutVar)
-- | Move the current timeout id past @prev@.
--
-- @prev@ must be the timeout id the caller *was* waiting on.
bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
bumpTimeoutId prev = do
  timeoutVar <- asks curTimeoutId
  -- Something else may already have bumped the id in the meantime, so keep
  -- whichever is larger: the stored value or the successor of @prev@.
  liftIO $ modifyMVar_ timeoutVar (pure . max (prev + 1))
161 data SessionState = SessionState
165 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
166 , overridingTimeout :: Bool
-- | The last received message from the server.
168 -- Used for providing exception information
169 , lastReceivedMessage :: Maybe FromServerMessage
170 , curDynCaps :: Map.Map T.Text Registration
171 -- ^ The capabilities that the server has dynamically registered with us so
175 class Monad m => HasState s m where
180 modify :: (s -> s) -> m ()
181 modify f = get >>= put . f
-- | Effectful counterpart of 'modify': fetch the current state, run the
-- monadic update on it, and store whatever it returns.
modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
modifyM f = do
  old <- get
  new <- f old
  put new
-- | 'Session' state access: both operations are the underlying 'StateT'
-- primitives lifted through the parser layer and wrapped in 'Session'.
instance HasState SessionState Session where
  get = Session $ lift State.get
  put newState = Session (lift (State.put newState))
190 instance Monad m => HasState s (StateT s m) where
194 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
199 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
204 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
205 runSession context state (Session session) = runReaderT (runStateT conduit state) context
207 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
209 handler (Unexpected "ConduitParser.empty") = do
210 lastMsg <- fromJust . lastReceivedMessage <$> get
211 name <- getParserName
212 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
217 msg <- liftIO $ readChan (messageChan context)
218 unless (ignoreLogNotifications (config context) && isLogNotification msg) $
-- True exactly for server messages that are show-message or log-message
-- notifications; every other session message yields False.
isLogNotification msg = case msg of
  ServerMessage (NotShowMessage _) -> True
  ServerMessage (NotLogMessage _)  -> True
  _                                -> False
226 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
227 watchdog = Conduit.awaitForever $ \msg -> do
228 curId <- getCurTimeoutId
230 ServerMessage sMsg -> yield sMsg
231 TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
233 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
234 -- It also does not automatically send initialize and exit messages.
235 runSessionWithHandles :: Handle -- ^ Server in
236 -> Handle -- ^ Server out
237 -> ProcessHandle -- ^ Server process
238 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
240 -> ClientCapabilities
241 -> FilePath -- ^ Root directory
242 -> Session () -- ^ To exit the Server properly
245 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
246 absRootDir <- canonicalizePath rootDir
248 hSetBuffering serverIn NoBuffering
249 hSetBuffering serverOut NoBuffering
250 -- This is required to make sure that we don’t get any
251 -- newline conversion or weird encoding issues.
252 hSetBinaryMode serverIn True
253 hSetBinaryMode serverOut True
255 reqMap <- newMVar newRequestMap
256 messageChan <- newChan
257 timeoutIdVar <- newMVar 0
258 initRsp <- newEmptyMVar
260 mainThreadId <- myThreadId
262 let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
263 initState vfs = SessionState (IdInt 0) vfs mempty False Nothing mempty
264 runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
266 errorHandler = throwTo mainThreadId :: SessionException -> IO ()
267 serverListenerLauncher =
268 forkIO $ catch (serverHandler serverOut context) errorHandler
269 server = (Just serverIn, Just serverOut, Nothing, serverProc)
270 msgTimeoutMs = messageTimeout config * 10^6
271 serverAndListenerFinalizer tid = do
272 finally (timeout msgTimeoutMs (runSession' exitServer)) $ do
273 -- Make sure to kill the listener first, before closing
274 -- handles etc via cleanupProcess
276 -- Give the server some time to exit cleanly
-- It makes the server hang on Windows, so we have to avoid it there
278 #ifndef mingw32_HOST_OS
279 timeout msgTimeoutMs (waitForProcess serverProc)
281 cleanupProcess server
283 (result, _) <- bracket serverListenerLauncher
284 serverAndListenerFinalizer
285 (const $ runSession' session)
288 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
289 updateStateC = awaitForever $ \msg -> do
293 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
294 => FromServerMessage -> m ()
296 -- Keep track of dynamic capability registration
297 updateState (ReqRegisterCapability req) = do
298 let List newRegs = (\r -> (r ^. LSP.id, r)) <$> req ^. params . registrations
300 s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
302 updateState (ReqUnregisterCapability req) = do
303 let List unRegs = (^. LSP.id) <$> req ^. params . unregistrations
305 let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
306 in s { curDynCaps = newCurDynCaps }
308 updateState (NotPublishDiagnostics n) = do
309 let List diags = n ^. params . diagnostics
310 doc = n ^. params . uri
312 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
313 in s { curDiagnostics = newDiags }
315 updateState (ReqApplyWorkspaceEdit r) = do
317 -- First, prefer the versioned documentChanges field
318 allChangeParams <- case r ^. params . edit . documentChanges of
320 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
321 return $ map getParams cs
322 -- Then fall back to the changes field
323 Nothing -> case r ^. params . edit . changes of
325 mapM_ checkIfNeedsOpened (HashMap.keys cs)
326 concat <$> mapM (uncurry getChangeParams) (HashMap.toList cs)
328 error "WorkspaceEdit contains neither documentChanges nor changes!"
331 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
332 return $ s { vfs = newVFS }
334 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
335 mergedParams = map mergeParams groupedParams
337 -- TODO: Don't do this when replaying a session
338 forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
340 -- Update VFS to new document versions
341 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
342 latestVersions = map ((^. textDocument) . last) sortedVersions
343 bumpedVersions = map (version . _Just +~ 1) latestVersions
345 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
348 update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
349 newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
350 in s { vfs = newVFS }
352 where checkIfNeedsOpened uri = do
353 oldVFS <- vfs <$> get
-- if it's not open, open it
357 unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
358 let fp = fromJust $ uriToFilePath uri
359 contents <- liftIO $ T.readFile fp
360 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
361 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
362 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
365 let (newVFS,_) = openVFS (vfs s) msg
366 return $ s { vfs = newVFS }
-- Turn one 'TextDocumentEdit' into the parameters of a did-change
-- notification for the same document: every text edit becomes a content
-- change event carrying the edit's range and new text (no range length).
getParams (TextDocumentEdit docId (List edits)) =
  DidChangeTextDocumentParams docId (List (map toChangeEvent edits))
  where
    toChangeEvent e =
      TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)
-- For a uri, returns an infinite list of versioned identifiers whose versions
-- are [n+1, n+2, n+3, ...], where n is the version currently recorded for that
-- uri in the VFS map (0 when the uri has no entry there).
textDocumentVersions uri = do
  openFiles <- vfsMap . vfs <$> get
  let curVer = maybe 0 _lsp_version (Map.lookup (toNormalizedUri uri) openFiles)
  pure [VersionedTextDocumentIdentifier uri (Just v) | v <- [curVer + 1 ..]]
-- Pair each edit with the next fresh version of the document, producing one
-- single-edit 'TextDocumentEdit' per input edit (in the order given).
textDocumentEdits uri edits = do
  vers <- textDocumentVersions uri
  pure (zipWith (\ver e -> TextDocumentEdit ver (List [e])) vers edits)
-- Build the did-change parameters for all edits of one uri: the edits are
-- processed in reverse order, each wrapped as its own versioned
-- 'TextDocumentEdit' and then converted with 'getParams'.
getChangeParams uri (List edits) =
  map getParams <$> textDocumentEdits uri (reverse edits)
-- Collapse a group of did-change parameters into a single one: the document
-- identifier is taken from the first element and the content changes of all
-- elements are concatenated in order.
mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
mergeParams ps = DidChangeTextDocumentParams docId (List allEvents)
  where
    -- NOTE(review): 'head' is partial; this relies on the group being
    -- non-empty — confirm that every caller guarantees it.
    docId = head ps ^. textDocument
    allEvents = concatMap (\p -> toList (p ^. contentChanges)) ps
390 updateState _ = return ()
392 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
394 h <- serverIn <$> ask
396 liftIO $ B.hPut h (addHeader $ encode msg)
398 -- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception
399 -- after duration seconds. This will override the global timeout
400 -- for waiting for messages to arrive defined in 'SessionConfig'.
401 withTimeout :: Int -> Session a -> Session a
402 withTimeout duration f = do
403 chan <- asks messageChan
404 timeoutId <- getCurTimeoutId
405 modify $ \s -> s { overridingTimeout = True }
407 threadDelay (duration * 1000000)
408 writeChan chan (TimeoutMessage timeoutId)
410 bumpTimeoutId timeoutId
411 modify $ \s -> s { overridingTimeout = False }
414 data LogMsgType = LogServer | LogClient
417 -- | Logs the message if the config specified it
418 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
419 => LogMsgType -> a -> m ()
421 shouldLog <- asks $ logMessages . config
422 shouldColor <- asks $ logColor . config
423 liftIO $ when shouldLog $ do
424 when shouldColor $ setSGR [SetColor Foreground Dull color]
425 putStrLn $ arrow ++ showPretty msg
426 when shouldColor $ setSGR [Reset]
429 | t == LogServer = "<-- "
432 | t == LogServer = Magenta
435 showPretty = B.unpack . encodePretty