2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
9 module Language.Haskell.LSP.Test.Session
16 , runSessionWithHandles
34 import Control.Applicative
35 import Control.Concurrent hiding (yield)
36 import Control.Exception
37 import Control.Lens hiding (List)
39 import Control.Monad.IO.Class
40 import Control.Monad.Except
41 #if __GLASGOW_HASKELL__ == 806
42 import Control.Monad.Fail
44 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
45 import qualified Control.Monad.Trans.Reader as Reader (ask)
46 import Control.Monad.Trans.State (StateT, runStateT)
47 import qualified Control.Monad.Trans.State as State
48 import qualified Data.ByteString.Lazy.Char8 as B
50 import Data.Aeson.Encode.Pretty
51 import Data.Conduit as Conduit
52 import Data.Conduit.Parser as Parser
56 import qualified Data.Map as Map
57 import qualified Data.Text as T
58 import qualified Data.Text.IO as T
59 import qualified Data.HashMap.Strict as HashMap
62 import Language.Haskell.LSP.Messages
63 import Language.Haskell.LSP.Types.Capabilities
64 import Language.Haskell.LSP.Types
65 import Language.Haskell.LSP.Types.Lens
66 import Language.Haskell.LSP.VFS
67 import Language.Haskell.LSP.Test.Compat
68 import Language.Haskell.LSP.Test.Decoding
69 import Language.Haskell.LSP.Test.Exceptions
70 import System.Console.ANSI
71 import System.Directory
73 import System.Process (ProcessHandle())
76 -- | A session representing one instance of launching and connecting to a server.
78 -- You can send and receive messages to the server within 'Session' via
79 -- 'Language.Haskell.LSP.Test.message',
80 -- 'Language.Haskell.LSP.Test.sendRequest' and
81 -- 'Language.Haskell.LSP.Test.sendNotification'.
--
-- Internally a 'Session' is a 'ConduitParser' over the 'FromServerMessage's
-- received from the server, threading a 'SessionState' through 'StateT' and
-- reading a 'SessionContext' through 'ReaderT', on top of 'IO'.
83 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
84 deriving (Functor, Applicative, Monad, MonadIO, Alternative)
86 #if __GLASGOW_HASKELL__ >= 806
-- | A failed pattern match in a 'Session' do-block is reported as an
-- 'UnexpectedMessage' exception carrying the failure message and the last
-- message received from the server.
87 instance MonadFail Session where
-- NOTE(review): 'fromJust' errors if no message has been received yet,
-- since 'lastReceivedMessage' is a 'Maybe'.
89 lastMsg <- fromJust . lastReceivedMessage <$> get
-- NOTE(review): 'throwIO' is the idiomatic way to raise an exception from IO.
90 liftIO $ throw (UnexpectedMessage s lastMsg)
93 -- | Options for configuring a 'Session'.
94 data SessionConfig = SessionConfig
95 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
97 -- ^ Redirect the server's stderr to this process's stdout, defaults to False.
98 -- Can be overridden with the environment variable @LSP_TEST_LOG_STDERR@.
100 -- ^ Trace the messages sent and received to stdout, defaults to False.
101 -- Can be overridden with the environment variable @LSP_TEST_LOG_MESSAGES@.
102 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
103 , lspConfig :: Maybe Value -- ^ The initial LSP config as a JSON value, defaults to Nothing.
104 , ignoreLogNotifications :: Bool
105 -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
106 -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
-- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
--
-- Waits up to 60 seconds for each message, logs neither the server's stderr
-- nor the message trace, colours log output, starts with no LSP config, and
-- does not ignore log notifications.
defaultConfig :: SessionConfig
defaultConfig =
  SessionConfig timeoutSeconds False False colourLogs Nothing False
  where
    timeoutSeconds = 60
    colourLogs = True
-- | Allows 'def' to be used for a 'SessionConfig'.
-- NOTE(review): presumably @def = defaultConfig@ — confirm.
115 instance Default SessionConfig where
-- | An item delivered on 'messageChan': either a message from the server
-- ('ServerMessage') or a timeout tick ('TimeoutMessage', handled by the
-- watchdog stage in 'runSession').
118 data SessionMessage = ServerMessage FromServerMessage
-- | Environment of a running 'Session', read through its 'ReaderT' layer.
122 data SessionContext = SessionContext
125 , rootDir :: FilePath -- ^ Root directory, canonicalized by 'runSessionWithHandles'
126 , messageChan :: Chan SessionMessage -- ^ Where all messages come through
127 -- Keep curTimeoutId in SessionContext, as it's tied to messageChan
128 , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
129 , requestMap :: MVar RequestMap
130 , initRsp :: MVar InitializeResponse
131 , config :: SessionConfig
132 , sessionCapabilities :: ClientCapabilities
-- | Monads from which an environment of type @r@ can be read.
135 class Monad m => HasReader r m where
-- | Read the environment and project a value out of it.
137 asks :: (r -> b) -> m b
-- | A 'Session' reads its 'SessionContext' from the 'ReaderT' layer that sits
-- beneath both the 'ConduitParser' and the 'StateT' layers.
instance HasReader SessionContext Session where
  ask = Session . lift . lift $ Reader.ask
-- | Conduit stages running over a @StateT s (ReaderT r m)@ stack can read the
-- @r@ environment by lifting past the conduit and the state layers.
instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
  ask = (lift . lift) Reader.ask
-- | Read the id of the timeout the session is currently waiting on.
getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
getCurTimeoutId = do
  timeoutVar <- asks curTimeoutId
  liftIO (readMVar timeoutVar)
-- | Advance the current timeout id past the one that was being waited on.
--
-- Pass this the timeout id you /were/ waiting on. Something else might have
-- bumped the id in the meantime, so the stored id never moves backwards: it
-- becomes the larger of its current value and @prev + 1@.
bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
bumpTimeoutId prev = do
  timeoutVar <- asks curTimeoutId
  liftIO . modifyMVar_ timeoutVar $ \current -> pure (max current (prev + 1))
-- | Per-session state, threaded through the 'StateT' layer of 'Session'.
157 data SessionState = SessionState
161 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic] -- ^ Most recently published diagnostics for each document
162 , overridingTimeout :: Bool -- ^ True while a 'withTimeout' block is running
163 -- | The last received message from the server.
164 -- Used for providing exception information
165 , lastReceivedMessage :: Maybe FromServerMessage
-- | Monads carrying a piece of state of type @s@.
168 class Monad m => HasState s m where
-- | Apply a pure update to the state; the default reads with 'get' and
-- writes back with 'put'.
173 modify :: (s -> s) -> m ()
174 modify f = get >>= put . f
-- | Like 'modify', but the update function may itself run effects in @m@.
modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
modifyM f = put =<< f =<< get
-- | A 'Session' keeps its 'SessionState' in the 'StateT' layer directly
-- beneath the 'ConduitParser'.
instance HasState SessionState Session where
  get = Session $ lift State.get
  put newState = Session (lift (State.put newState))
-- | A 'StateT' whose state type is @s@ provides 'HasState' @s@.
183 instance Monad m => HasState s (StateT s m) where
-- | A 'ConduitM' over a monad with 'HasState' @s@ also has 'HasState' @s@.
187 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
-- | A 'ConduitParser' over a monad with 'HasState' @s@ also has 'HasState' @s@.
192 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
-- | Run a 'Session' against the given context and initial state, returning
-- its result together with the final state.
--
-- Messages are read from 'messageChan', passed through the watchdog and
-- 'updateStateC' stages, and then fed to the session's parser.
197 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
198 runSession context state (Session session) = runReaderT (runStateT conduit state) context
200 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
-- A parser that fails with 'ConduitParser.empty' is reported as an
-- 'UnexpectedMessage' carrying the parser's name and the last message.
-- NOTE(review): 'fromJust' errors if nothing has been received yet.
202 handler (Unexpected "ConduitParser.empty") = do
203 lastMsg <- fromJust . lastReceivedMessage <$> get
204 name <- getParserName
205 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
-- Read the next message from 'messageChan'; show/log-message notifications
-- are skipped when 'ignoreLogNotifications' is set.
210 msg <- liftIO $ readChan (messageChan context)
211 unless (ignoreLogNotifications (config context) && isLogNotification msg) $
215 isLogNotification (ServerMessage (NotShowMessage _)) = True
216 isLogNotification (ServerMessage (NotLogMessage _)) = True
217 isLogNotification _ = False
-- Forward server messages downstream. A 'TimeoutMessage' throws 'Timeout'
-- only when its id is the one currently being waited on; ticks for ids that
-- have since been bumped are dropped.
-- NOTE(review): 'liftIO . throwIO' would be the idiomatic way to raise here.
219 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
220 watchdog = Conduit.awaitForever $ \msg -> do
221 curId <- getCurTimeoutId
223 ServerMessage sMsg -> yield sMsg
224 TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
226 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
227 -- It also does not automatically send initialize and exit messages.
--
-- The listener runs on a forked thread. After the session finishes or throws,
-- @exitServer@ is run under the message timeout. The server process is then
-- cleaned up and the listener thread killed.
228 runSessionWithHandles :: Handle -- ^ Server in
229 -> Handle -- ^ Server out
230 -> ProcessHandle -- ^ Server process
231 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
233 -> ClientCapabilities
234 -> FilePath -- ^ Root directory
235 -> Session () -- ^ To exit the Server properly
238 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
239 absRootDir <- canonicalizePath rootDir
241 hSetBuffering serverIn NoBuffering
242 hSetBuffering serverOut NoBuffering
243 -- This is required to make sure that we don’t get any
244 -- newline conversion or weird encoding issues.
245 hSetBinaryMode serverIn True
246 hSetBinaryMode serverOut True
248 reqMap <- newMVar newRequestMap
249 messageChan <- newChan
250 timeoutIdVar <- newMVar 0
251 initRsp <- newEmptyMVar
253 mainThreadId <- myThreadId
255 let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
256 initState vfs = SessionState (IdInt 0) vfs mempty False Nothing
257 runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
-- 'SessionException's raised on the listener thread are rethrown to the
-- thread that called 'runSessionWithHandles'.
259 errorHandler = throwTo mainThreadId :: SessionException -> IO ()
260 serverListenerLauncher =
261 forkIO $ catch (serverHandler serverOut context) errorHandler
262 server = (Just serverIn, Just serverOut, Nothing, serverProc)
263 serverAndListenerFinalizer tid = do
-- FIXME(review): @1^6@ evaluates to 1, so this waits 'messageTimeout'
-- microseconds rather than seconds. This assumes 'timeout' is
-- System.Timeout.timeout, which takes microseconds.
-- It should be @messageTimeout config * 10^6@, matching the seconds-based
-- conversion in 'withTimeout' (@duration * 1000000@).
264 finally (timeout (messageTimeout config * 1^6)
265 (runSession' exitServer))
266 (cleanupProcess server >> killThread tid)
-- 'finally' guarantees the server process is cleaned up and the listener
-- thread killed even if @exitServer@ throws or times out.
268 (result, _) <- bracket serverListenerLauncher
269 serverAndListenerFinalizer
270 (const $ runSession' session)
-- | Conduit stage between the watchdog and the session parser (see
-- 'runSession'); it consumes and re-emits 'FromServerMessage's.
-- NOTE(review): presumably applies 'updateState' to each message before
-- passing it on — confirm.
273 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
274 updateStateC = awaitForever $ \msg -> do
-- | Update the 'SessionState' in response to a message from the server.
--
-- - Diagnostics notifications replace the stored diagnostics for their
--   document.
-- - Workspace edit requests are applied to the VFS and forwarded to the
--   server as didOpen/didChange notifications.
-- - Every other message is ignored.
278 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
279 => FromServerMessage -> m ()
280 updateState (NotPublishDiagnostics n) = do
281 let List diags = n ^. params . diagnostics
282 doc = n ^. params . uri
284 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
285 in s { curDiagnostics = newDiags })
287 updateState (ReqApplyWorkspaceEdit r) = do
-- Collect change params from 'documentChanges' if present, otherwise from
-- 'changes'. Any document not yet in the VFS is opened first.
289 allChangeParams <- case r ^. params . edit . documentChanges of
291 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
292 return $ map getParams cs
293 Nothing -> case r ^. params . edit . changes of
295 mapM_ checkIfNeedsOpened (HashMap.keys cs)
296 return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
-- NOTE(review): an edit with neither field set aborts via 'error'.
297 Nothing -> error "No changes!"
300 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
301 return $ s { vfs = newVFS }
-- NOTE(review): 'groupBy' only groups *adjacent* elements, so changes to the
-- same document that are not next to each other in 'allChangeParams' land in
-- separate groups and are sent as separate notifications. The comparison is
-- on the whole identifier (presumably URI and version).
303 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
304 mergedParams = map mergeParams groupedParams
306 -- TODO: Don't do this when replaying a session
307 forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
309 -- Update VFS to new document versions
-- For each group: take the highest version, bump it by one, and store it in
-- that document's VFS entry along with an incremented file version. 'last' is
-- safe because 'groupBy' never yields an empty group.
310 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
311 latestVersions = map ((^. textDocument) . last) sortedVersions
312 bumpedVersions = map (version . _Just +~ 1) latestVersions
314 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
317 update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
318 newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
319 in s { vfs = newVFS }
321 where checkIfNeedsOpened uri = do
322 oldVFS <- vfs <$> get
325 -- if it's not open, open it
326 unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
-- NOTE(review): 'fromJust' fails for URIs that are not file paths.
327 let fp = fromJust $ uriToFilePath uri
328 contents <- liftIO $ T.readFile fp
329 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
330 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
-- NOTE(review): writes to the server handle directly instead of going through
-- 'sendMessage' like the didChange notifications above — confirm intended.
331 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
334 let (newVFS,_) = openVFS (vfs s) msg
335 return $ s { vfs = newVFS }
-- Turn a 'TextDocumentEdit' into didChange params with one ranged content
-- change per text edit.
337 getParams (TextDocumentEdit docId (List edits)) =
338 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
339 in DidChangeTextDocumentParams docId (List changeEvents)
-- Version identifiers 0, 1, 2, ... for the given document.
341 textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
-- Wrap each edit in its own 'TextDocumentEdit', numbered with increasing versions.
343 textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
-- NOTE(review): the edits are reversed before being numbered — presumably so
-- later ranges are applied first and earlier ranges stay valid; confirm.
345 getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
-- Merge a group of params into one notification carrying all of their
-- content changes, using the first element's document identifier. 'head' is
-- safe at the call site above because 'groupBy' never yields an empty group.
347 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
348 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
349 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
350 updateState _ = return ()
-- | Encode a message as JSON, prepend the header, and write it to the
-- server's input handle ('serverIn' of the 'SessionContext').
352 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
354 h <- serverIn <$> ask
356 liftIO $ B.hPut h (addHeader $ encode msg)
358 -- | Execute a block @f@ that will throw a 'Timeout' exception
359 -- after @duration@ seconds. This will override the global timeout
360 -- for waiting for messages to arrive defined in 'SessionConfig'.
361 withTimeout :: Int -> Session a -> Session a
362 withTimeout duration f = do
363 chan <- asks messageChan
364 timeoutId <- getCurTimeoutId
365 modify $ \s -> s { overridingTimeout = True }
-- After @duration@ seconds, post a 'TimeoutMessage' for the timeout id
-- captured above. The watchdog stage throws 'Timeout' only if that id is
-- still the current one.
367 threadDelay (duration * 1000000)
368 writeChan chan (TimeoutMessage timeoutId)
-- Bumping the id makes any still-pending tick for this block stale, so the
-- watchdog ignores it.
370 bumpTimeoutId timeoutId
371 modify $ \s -> s { overridingTimeout = False }
-- | Direction of a logged message: received from the server ('LogServer') or
-- sent by the client ('LogClient').
374 data LogMsgType = LogServer | LogClient
377 -- | Log the message to stdout if 'logMessages' is enabled in the config.
-- The message is pretty-printed as JSON after a direction arrow, and is
-- wrapped in ANSI colour codes when 'logColor' is enabled.
378 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
379 => LogMsgType -> a -> m ()
381 shouldLog <- asks $ logMessages . config
382 shouldColor <- asks $ logColor . config
383 liftIO $ when shouldLog $ do
384 when shouldColor $ setSGR [SetColor Foreground Dull color]
385 putStrLn $ arrow ++ showPretty msg
386 when shouldColor $ setSGR [Reset]
389 | t == LogServer = "<-- "
392 | t == LogServer = Magenta
395 showPretty = B.unpack . encodePretty