2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
8 module Language.Haskell.LSP.Test.Session
15 , runSessionWithHandles
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
53 import qualified Data.Map as Map
54 import qualified Data.Text as T
55 import qualified Data.Text.IO as T
56 import qualified Data.HashMap.Strict as HashMap
59 import Language.Haskell.LSP.Messages
60 import Language.Haskell.LSP.Types.Capabilities
61 import Language.Haskell.LSP.Types
62 import Language.Haskell.LSP.Types.Lens hiding (error)
63 import Language.Haskell.LSP.VFS
64 import Language.Haskell.LSP.Test.Decoding
65 import Language.Haskell.LSP.Test.Exceptions
66 import System.Console.ANSI
67 import System.Directory
-- | A 'Session' models one launch of, and connection to, a language server.
--
-- Inside a 'Session' messages are exchanged with the server through
-- 'Language.Haskell.LSP.Test.message',
-- 'Language.Haskell.LSP.Test.sendRequest' and
-- 'Language.Haskell.LSP.Test.sendNotification'.
type Session =
  ParserStateReader
    FromServerMessage -- values the parser consumes: messages from the server
    SessionState      -- the 'StateT' state threaded through the session
    SessionContext    -- the 'ReaderT' environment shared by the session
    IO                -- base monad
79 #if __GLASGOW_HASKELL__ >= 806
80 instance MonadFail Session where
82 lastMsg <- fromJust . lastReceivedMessage <$> get
83 liftIO $ throw (UnexpectedMessage s lastMsg)
-- | Configuration options for a 'Session'.
87 data SessionConfig = SessionConfig
88 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
89 , logStdErr :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
90 , logMessages :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
91 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
92 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
-- | The configuration used by 'Language.Haskell.LSP.Test.runSession'.
--
-- Messages time out after 60 seconds, the server's stderr is not redirected,
-- message tracing is off (colored output is enabled should tracing be turned
-- on) and there is no initial LSP config value.
defaultConfig :: SessionConfig
defaultConfig = SessionConfig
  { messageTimeout = 60
  , logStdErr      = False
  , logMessages    = False
  , logColor       = True
  , lspConfig      = Nothing
  }
99 instance Default SessionConfig where
102 data SessionMessage = ServerMessage FromServerMessage
106 data SessionContext = SessionContext
109 , rootDir :: FilePath
110 , messageChan :: Chan SessionMessage
111 , requestMap :: MVar RequestMap
112 , initRsp :: MVar InitializeResponse
113 , config :: SessionConfig
114 , sessionCapabilities :: ClientCapabilities
117 class Monad m => HasReader r m where
119 asks :: (r -> b) -> m b
-- | The environment of a 'ParserStateReader' is obtained by lifting
-- 'Reader.ask' twice: once into the 'StateT' layer and once more into the
-- 'ConduitParser' layer.
instance Monad m => HasReader r (ParserStateReader a s r m) where
  ask = (lift . lift) Reader.ask
-- | A conduit stage running over @StateT s (ReaderT SessionContext m)@ reads
-- the 'SessionContext' by lifting 'Reader.ask' through the 'StateT' layer and
-- then through the 'ConduitM' layer.
instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
  ask = lift . lift $ Reader.ask
128 data SessionState = SessionState
132 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
133 , curTimeoutId :: Int
134 , overridingTimeout :: Bool
135 -- ^ The last received message from the server.
136 -- Used for providing exception information
137 , lastReceivedMessage :: Maybe FromServerMessage
140 class Monad m => HasState s m where
145 modify :: (s -> s) -> m ()
146 modify f = get >>= put . f
-- | Update the state with a monadic function: read the current state, run
-- @f@ on it, and store the state that @f@ returns.
--
-- No separate 'Monad' constraint is needed: 'HasState' already has
-- @Monad m@ as a superclass, so repeating it was redundant.
modifyM :: HasState s m => (s -> m s) -> m ()
modifyM f = get >>= f >>= put
151 instance Monad m => HasState s (ParserStateReader a s r m) where
153 put = lift . State.put
155 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
158 put = lift . State.put
-- | A conduit parser that consumes values of type @input@. It runs over a
-- 'StateT' carrying @st@, layered on a 'ReaderT' carrying @env@, over the
-- base monad @base@.
type ParserStateReader input st env base =
  ConduitParser input (StateT st (ReaderT env base))
162 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
163 runSession context state session = runReaderT (runStateT conduit state) context
165 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
167 handler (Unexpected "ConduitParser.empty") = do
168 lastMsg <- fromJust . lastReceivedMessage <$> get
169 name <- getParserName
170 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
175 msg <- liftIO $ readChan (messageChan context)
179 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
180 watchdog = Conduit.awaitForever $ \msg -> do
181 curId <- curTimeoutId <$> get
183 ServerMessage sMsg -> yield sMsg
184 TimeoutMessage tId -> when (curId == tId) $ throw Timeout
-- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
-- It does not send the initialize request itself, but it does send an exit
-- notification when the session ends (normally or through an exception).
188 runSessionWithHandles :: Handle -- ^ Server in
189 -> Handle -- ^ Server out
190 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
192 -> ClientCapabilities
193 -> FilePath -- ^ Root directory
196 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
197 -- We use this IORef to make exception non-fatal when the server is supposed to shutdown.
199 exitOk <- newIORef False
201 absRootDir <- canonicalizePath rootDir
203 hSetBuffering serverIn NoBuffering
204 hSetBuffering serverOut NoBuffering
205 -- This is required to make sure that we don’t get any
206 -- newline conversion or weird encoding issues.
207 hSetBinaryMode serverIn True
208 hSetBinaryMode serverOut True
210 reqMap <- newMVar newRequestMap
211 messageChan <- newChan
212 initRsp <- newEmptyMVar
214 mainThreadId <- myThreadId
216 let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
217 initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
218 errorHandler ex = do x <- readIORef exitOk
219 unless x $ throwTo mainThreadId (ex :: SessionException)
220 launchServerHandler = forkIO $ catch (serverHandler serverOut context) errorHandler
221 (result, _) <- bracket
223 (\tid -> do runSession context initState sendExitMessage
225 atomicWriteIORef exitOk True)
226 (const $ runSession context initState session)
229 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
230 updateStateC = awaitForever $ \msg -> do
234 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
235 updateState (NotPublishDiagnostics n) = do
236 let List diags = n ^. params . diagnostics
237 doc = n ^. params . uri
239 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
240 in s { curDiagnostics = newDiags })
242 updateState (ReqApplyWorkspaceEdit r) = do
244 allChangeParams <- case r ^. params . edit . documentChanges of
246 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
247 return $ map getParams cs
248 Nothing -> case r ^. params . edit . changes of
250 mapM_ checkIfNeedsOpened (HashMap.keys cs)
251 return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
252 Nothing -> error "No changes!"
255 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
256 return $ s { vfs = newVFS }
258 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
259 mergedParams = map mergeParams groupedParams
261 -- TODO: Don't do this when replaying a session
262 forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
264 -- Update VFS to new document versions
265 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
266 latestVersions = map ((^. textDocument) . last) sortedVersions
267 bumpedVersions = map (version . _Just +~ 1) latestVersions
269 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
272 update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
273 newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
274 in s { vfs = newVFS }
276 where checkIfNeedsOpened uri = do
277 oldVFS <- vfs <$> get
280 -- if its not open, open it
281 unless (toNormalizedUri uri `Map.member` oldVFS) $ do
282 let fp = fromJust $ uriToFilePath uri
283 contents <- liftIO $ T.readFile fp
284 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
285 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
286 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
289 newVFS <- liftIO $ openVFS (vfs s) msg
290 return $ s { vfs = newVFS }
292 getParams (TextDocumentEdit docId (List edits)) =
293 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
294 in DidChangeTextDocumentParams docId (List changeEvents)
296 textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
298 textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
300 getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
302 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
303 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
304 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
305 updateState _ = return ()
307 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
309 h <- serverIn <$> ask
311 liftIO $ B.hPut h (addHeader $ encode msg)
-- | Send the @exit@ notification to the server, writing it to the server's
-- input handle via 'sendMessage'.
sendExitMessage :: (MonadIO m, HasReader SessionContext m) => m ()
sendExitMessage = sendMessage exitNotification
  where
    exitNotification = NotificationMessage "2.0" Exit ExitParams
-- | Run the session action @f@ with a timeout of @duration@ seconds. If @f@ is
-- still waiting for a message from the server when the time is up, a
-- 'Timeout' exception is thrown. While @f@ runs, this overrides the global
-- message timeout defined in 'SessionConfig'.
319 withTimeout :: Int -> Session a -> Session a
320 withTimeout duration f = do
321 chan <- asks messageChan
322 timeoutId <- curTimeoutId <$> get
323 modify $ \s -> s { overridingTimeout = True }
325 threadDelay (duration * 1000000)
326 writeChan chan (TimeoutMessage timeoutId)
328 modify $ \s -> s { curTimeoutId = timeoutId + 1,
329 overridingTimeout = False
333 data LogMsgType = LogServer | LogClient
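-- | Print the message to stdout, prefixed with a direction arrow, when 'logMessages' is enabled in the session's 'SessionConfig'; the output is colored when 'logColor' is set.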
337 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
338 => LogMsgType -> a -> m ()
340 shouldLog <- asks $ logMessages . config
341 shouldColor <- asks $ logColor . config
342 liftIO $ when shouldLog $ do
343 when shouldColor $ setSGR [SetColor Foreground Dull color]
344 putStrLn $ arrow ++ showPretty msg
345 when shouldColor $ setSGR [Reset]
348 | t == LogServer = "<-- "
351 | t == LogServer = Magenta
354 showPretty = B.unpack . encodePretty