2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
8 module Language.Haskell.LSP.Test.Session
15 , runSessionWithHandles
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Decoding
64 import Language.Haskell.LSP.Test.Exceptions
65 import System.Console.ANSI
66 import System.Directory
-- | One run of a language server: the server is launched, connected to, and
-- then talked to from inside this monad.
--
-- Messages are exchanged with the server via
-- 'Language.Haskell.LSP.Test.message',
-- 'Language.Haskell.LSP.Test.sendRequest' and
-- 'Language.Haskell.LSP.Test.sendNotification'.
--
-- Under the hood this is a 'ParserStateReader' that consumes
-- 'FromServerMessage's, threads a 'SessionState', reads a 'SessionContext',
-- and runs over 'IO'.
type Session =
  ParserStateReader
    FromServerMessage
    SessionState
    SessionContext
    IO
-- | 'MonadFail' for 'Session', only compiled when __GLASGOW_HASKELL__ >= 806.
-- A failure is reported by throwing 'UnexpectedMessage' together with the
-- last message received from the server.
-- NOTE(review): the line binding @s@ (presumably @fail s = do@) and the
-- closing @#endif@ are not visible in this excerpt — confirm.
80 #if __GLASGOW_HASKELL__ >= 806
81 instance MonadFail Session where
-- NOTE(review): 'fromJust' is partial. If this runs before any message has
-- been received, 'lastReceivedMessage' is 'Nothing' and this crashes with a
-- generic error instead of an 'UnexpectedMessage'.
83 lastMsg <- fromJust . lastReceivedMessage <$> get
-- NOTE(review): inside IO, @throwIO@ gives more predictable ordering than
-- @liftIO . throw@.
84 liftIO $ throw (UnexpectedMessage s lastMsg)
87 -- | Options controlling how a 'Session' is run.
88 data SessionConfig = SessionConfig
89 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
90 , logStdErr :: Bool -- ^ Redirect the server's stderr to this process's stdout, defaults to False.
91 , logMessages :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
92 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
93 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
-- NOTE(review): the closing brace (and any deriving clause) of this record
-- is not visible in this excerpt. 'defaultConfig' passes exactly these five
-- fields positionally.
-- | The 'SessionConfig' used by 'Language.Haskell.LSP.Test.runSession'.
-- It waits up to 60 seconds for each message, does not redirect the
-- server's stderr, does not trace messages, colours log output, and starts
-- with no initial LSP configuration.
defaultConfig :: SessionConfig
defaultConfig =
  SessionConfig
    { messageTimeout = 60
    , logStdErr = False
    , logMessages = False
    , logColor = True
    , lspConfig = Nothing
    }
-- | 'Default' instance for 'SessionConfig'.
-- NOTE(review): the method definition is not visible in this excerpt;
-- presumably @def = defaultConfig@ — confirm.
100 instance Default SessionConfig where
-- | Items delivered on a session's 'messageChan'.
-- 'ServerMessage' wraps a message received from the server. 'watchdog' and
-- 'withTimeout' also use a 'TimeoutMessage' constructor carrying an 'Int'
-- timeout id; its declaration is not visible in this excerpt.
103 data SessionMessage = ServerMessage FromServerMessage
-- | Read-only environment shared by everything running inside a 'Session'.
--
-- NOTE(review): the line(s) declaring the first field are not visible in
-- this excerpt. 'runSessionWithHandles' builds this value positionally with
-- the server's input handle first, and 'sendMessage' reads that field via
-- 'serverIn'.
107 data SessionContext = SessionContext
-- | Canonicalized root directory of the session.
110 , rootDir :: FilePath
-- | Channel from which the session's message pipeline reads. It carries
-- server messages and 'withTimeout' timeout notifications.
111 , messageChan :: Chan SessionMessage
-- | Shared request map, initialised from 'newRequestMap' (presumably used
-- to match responses to outstanding requests — confirm).
112 , requestMap :: MVar RequestMap
-- | Starts empty; presumably filled with the server's initialize response.
113 , initRsp :: MVar InitializeResponse
-- | The configuration this session was started with.
114 , config :: SessionConfig
-- | Client capabilities supplied when the session was started.
115 , sessionCapabilities :: ClientCapabilities
-- | Monads @m@ that can read an environment of type @r@.
--
-- NOTE(review): the 'ask' signature and any default for 'asks' are on lines
-- not visible in this excerpt. The visible instances in this module define
-- only 'ask'.
118 class Monad m => HasReader r m where
-- | Read the environment and project a value out of it with the given function.
120 asks :: (r -> b) -> m b
-- | In a 'ParserStateReader' the environment sits in the 'ReaderT' layer
-- two transformers below the parser (parser, then 'StateT', then 'ReaderT'),
-- so it is reached by lifting twice.
instance Monad m => HasReader r (ParserStateReader a s r m) where
  ask = (lift . lift) Reader.ask
-- | A raw conduit stage over @'StateT' s ('ReaderT' 'SessionContext' m)@
-- reads the 'SessionContext' from the 'ReaderT' layer beneath 'StateT'.
instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
  ask = lift . lift $ Reader.ask
-- | Mutable state threaded through a 'Session'.
--
-- NOTE(review): the fields before 'curDiagnostics' (including the @vfs@
-- field used by 'updateState') are not visible in this excerpt.
-- 'runSessionWithHandles' initialises them with @IdInt 0@ and @mempty@.
129 data SessionState = SessionState
-- | Diagnostics most recently published by the server, keyed by document.
133 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
-- | Id of the currently active timeout. A 'TimeoutMessage' carrying any
-- other id is ignored by 'watchdog'.
134 , curTimeoutId :: Int
-- | Set while 'withTimeout' is running a block with its own timeout.
135 , overridingTimeout :: Bool
136 -- | The last received message from the server.
137 -- Used for providing exception information
138 , lastReceivedMessage :: Maybe FromServerMessage
-- | Monads @m@ that carry state of type @s@.
--
-- NOTE(review): the 'get' and 'put' declarations are on lines not visible in
-- this excerpt, and indentation was lost. From here it is unclear whether
-- 'modifyM' is a class method or a top-level function.
141 class Monad m => HasState s m where
-- | Apply a pure update to the state.
146 modify :: (s -> s) -> m ()
147 modify f = get >>= put . f
-- | Apply an effectful update to the state: read it, run the update in @m@,
-- and write the result back.
-- NOTE(review): the @Monad m@ constraint is redundant, because
-- @HasState s m@ already has 'Monad' as its superclass.
149 modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
150 modifyM f = get >>= f >>= put
-- | In a 'ParserStateReader' the state lives in the 'StateT' layer directly
-- beneath the parser, so state operations are lifted once.
-- NOTE(review): the 'get' definition is not visible in this excerpt.
152 instance Monad m => HasState s (ParserStateReader a s r m) where
154 put = lift . State.put
-- | A raw conduit stage over @'StateT' 'SessionState'@ reaches the state
-- with a single lift.
-- NOTE(review): the rest of the instance head (presumably @where@) and the
-- 'get' definition are on lines not visible in this excerpt.
156 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
159 put = lift . State.put
-- | Monad stack used by sessions. A conduit parser consumes @msg@ values on
-- top of a 'StateT' holding @st@. That in turn sits on a 'ReaderT' providing
-- @env@, over the base monad @base@.
type ParserStateReader msg st env base =
  ConduitParser msg (StateT st (ReaderT env base))
-- | Run a 'Session' with the given context, starting from the given state.
-- Returns the session's result together with the final state.
--
-- Messages read from the context's 'messageChan' flow through 'watchdog'
-- (timeouts) and then 'updateStateC' (state tracking) before reaching the
-- session's parser.
163 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
164 runSession context state session = runReaderT (runStateT conduit state) context
166 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
-- A parser failure of @Unexpected "ConduitParser.empty"@ (presumably raised
-- when no alternative matched) is rethrown as 'UnexpectedMessage'. It names
-- the active parser and carries the last message seen. Any other handler
-- cases are on lines not visible here.
-- NOTE(review): 'fromJust' crashes if no message has been received yet, and
-- @throwIO@ would be preferable to @liftIO . throw@.
168 handler (Unexpected "ConduitParser.empty") = do
169 lastMsg <- fromJust . lastReceivedMessage <$> get
170 name <- getParserName
171 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
-- Presumably part of 'chanSource' (its header is not visible here): each
-- message is read from the context's 'messageChan'.
176 msg <- liftIO $ readChan (messageChan context)
-- | Pipeline stage between the message channel and the rest of the session.
-- Server messages are passed downstream unchanged. A 'TimeoutMessage' throws
-- 'Timeout' only when its id equals the current 'curTimeoutId'. Stale ids,
-- left over from timeouts that 'withTimeout' has since advanced past, are
-- dropped.
180 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
181 watchdog = Conduit.awaitForever $ \msg -> do
182 curId <- curTimeoutId <$> get
-- (The @case msg of@ line is not visible in this excerpt.)
184 ServerMessage sMsg -> yield sMsg
-- NOTE(review): this uses pure 'throw', which relies on the bottom action
-- being forced when it is sequenced. @liftIO (throwIO Timeout)@ would make
-- the ordering explicit.
185 TimeoutMessage tId -> when (curId == tId) $ throw Timeout
187 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
188 -- It also does not automatically send initialize and exit messages.
--
-- The caller supplies the exit action (@exitServer@). It runs in the
-- 'bracket' finalizer, bounded by the configured 'messageTimeout' (in
-- seconds), after which the server process is terminated. The server
-- listener runs on its own thread. Any 'SessionException' it throws is
-- rethrown to the thread that called this function.
189 runSessionWithHandles :: Handle -- ^ Server in
190 -> Handle -- ^ Server out
191 -> ProcessHandle -- ^ Server process
192 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
-- NOTE(review): one parameter line is not visible here; judging by the
-- argument list and its use with 'messageTimeout', it is presumably the
-- 'SessionConfig'.
194 -> ClientCapabilities
-- ^ Client capabilities, stored in the context as 'sessionCapabilities'
195 -> FilePath -- ^ Root directory
196 -> Session () -- ^ To exit the Server properly
-- NOTE(review): the remaining signature lines (the session to run and the
-- result type) are not visible in this excerpt.
199 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
-- Resolve the root directory to an absolute, canonical path.
201 absRootDir <- canonicalizePath rootDir
203 hSetBuffering serverIn NoBuffering
204 hSetBuffering serverOut NoBuffering
205 -- This is required to make sure that we don’t get any
206 -- newline conversion or weird encoding issues.
207 hSetBinaryMode serverIn True
208 hSetBinaryMode serverOut True
210 reqMap <- newMVar newRequestMap
211 messageChan <- newChan
212 initRsp <- newEmptyMVar
-- Remembered so the listener thread can forward exceptions to this thread.
214 mainThreadId <- myThreadId
216 let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
217 initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
218 runSession' = runSession context initState
220 errorHandler = throwTo mainThreadId :: SessionException -> IO()
221 serverLauncher = forkIO $ catch (serverHandler serverOut context) errorHandler
-- NOTE(review): 'server' is not used on any visible line; confirm whether it
-- is still needed.
222 server = (Just serverIn, Just serverOut, Nothing, serverProc)
-- Give the exit action at most 'messageTimeout' seconds ('timeout' takes
-- microseconds), then terminate the server process regardless of the
-- outcome. The rest of the finalizer is on lines not visible here
-- (presumably it also stops the listener thread @tid@ — confirm).
223 serverFinalizer tid = finally (timeout (messageTimeout config * 1000000)
224 (runSession' exitServer))
225 (terminateProcess serverProc
-- Start the listener, run the session, and run the finalizer even if the
-- session throws.
229 (result, _) <- bracket serverLauncher serverFinalizer (const $ runSession' session)
-- | Pipeline stage that sees every 'FromServerMessage' on its way from
-- 'watchdog' to the session's parser.
-- NOTE(review): the body is not visible in this excerpt. Presumably it calls
-- 'updateState' (and records 'lastReceivedMessage') before yielding each
-- message downstream — confirm.
232 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
233 updateStateC = awaitForever $ \msg -> do
-- | Update the session's state in response to a message from the server.
--
-- * 'NotPublishDiagnostics': the diagnostics stored for that document are
--   replaced by the newly published ones.
-- * 'ReqApplyWorkspaceEdit':
--   - documents touched by the edit that the VFS is not yet tracking are
--     opened (a didOpen is sent for each);
--   - the edit is applied to the VFS;
--   - one merged didChange notification is sent per group of edits;
--   - the VFS document versions are bumped.
-- * Any other message is ignored.
--
-- NOTE(review): several lines of this definition are not visible in this
-- excerpt, including the state-update calls that bind @s@ and @ctx@.
237 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
238 updateState (NotPublishDiagnostics n) = do
239 let List diags = n ^. params . diagnostics
240 doc = n ^. params . uri
242 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
243 in s { curDiagnostics = newDiags })
245 updateState (ReqApplyWorkspaceEdit r) = do
-- Prefer the versioned 'documentChanges'; fall back to the unversioned
-- 'changes' map.
247 allChangeParams <- case r ^. params . edit . documentChanges of
249 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
250 return $ map getParams cs
251 Nothing -> case r ^. params . edit . changes of
253 mapM_ checkIfNeedsOpened (HashMap.keys cs)
254 return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
-- NOTE(review): a workspace edit with neither field set crashes the session
-- through 'error'.
255 Nothing -> error "No changes!"
258 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
259 return $ s { vfs = newVFS }
-- NOTE(review): 'groupBy' only groups *adjacent* elements. Changes to the
-- same document that are not contiguous in 'allChangeParams' therefore land
-- in separate groups, producing separate didChange notifications and version
-- bumps. The predicate also compares the whole versioned identifier
-- (presumably version as well as URI — confirm its Eq instance).
261 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
262 mergedParams = map mergeParams groupedParams
264 -- TODO: Don't do this when replaying a session
265 forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
267 -- Update VFS to new document versions
-- Within each group, take the identifier with the highest version and bump
-- it by one. ('_Just' leaves a 'Nothing' version untouched.) 'last' is safe
-- here because 'groupBy' never produces an empty group.
268 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
269 latestVersions = map ((^. textDocument) . last) sortedVersions
270 bumpedVersions = map (version . _Just +~ 1) latestVersions
272 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
275 update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
276 newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
277 in s { vfs = newVFS }
-- For a document the VFS is not tracking yet, send didOpen straight to the
-- server's input handle and add the document to the VFS.
279 where checkIfNeedsOpened uri = do
280 oldVFS <- vfs <$> get
283 -- if it's not open, open it
284 unless (toNormalizedUri uri `Map.member` oldVFS) $ do
-- NOTE(review): 'fromJust' crashes for URIs that do not map to a file path.
285 let fp = fromJust $ uriToFilePath uri
286 contents <- liftIO $ T.readFile fp
-- Opened with an empty language id and version 0.
287 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
288 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
289 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
292 newVFS <- liftIO $ openVFS (vfs s) msg
293 return $ s { vfs = newVFS }
-- Turn one 'TextDocumentEdit' into didChange params, with one ranged content
-- change per text edit.
295 getParams (TextDocumentEdit docId (List edits)) =
296 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
297 in DidChangeTextDocumentParams docId (List changeEvents)
-- Synthetic identifiers for @uri@ with versions 0, 1, 2, ...
299 textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
-- Wrap each edit in its own 'TextDocumentEdit', numbering versions from 0.
301 textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
-- Edits from the unversioned 'changes' map are processed in reverse order
-- (presumably so that later ranges are edited first — confirm).
303 getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
-- Concatenate the content changes of a group, keeping the first member's
-- document identifier. 'head' is safe because 'groupBy' never yields an
-- empty group.
305 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
306 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
307 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
308 updateState _ = return ()
-- | Serialise @msg@ to JSON, prepend the header produced by 'addHeader', and
-- write the result to the server's input handle.
-- NOTE(review): the equation head binding @msg@ and one body line are not
-- visible in this excerpt.
310 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
312 h <- serverIn <$> ask
314 liftIO $ B.hPut h (addHeader $ encode msg)
316 -- | Run the block @f@, throwing a 'Timeout' exception if it has not
317 -- finished after @duration@ seconds. While @f@ runs, this overrides the
318 -- global timeout for waiting for messages defined in 'SessionConfig'.
319 withTimeout :: Int -> Session a -> Session a
320 withTimeout duration f = do
321 chan <- asks messageChan
322 timeoutId <- curTimeoutId <$> get
323 modify $ \s -> s { overridingTimeout = True }
-- After @duration@ seconds, post a 'TimeoutMessage' tagged with the id that
-- was current when the block started. 'watchdog' throws 'Timeout' only if
-- that id is still current. The line that starts this delay is not visible
-- here; presumably it forks a thread.
-- NOTE(review): @duration * 1000000@ overflows 'Int' for durations above
-- about 2147 seconds on platforms where 'Int' is 32 bits.
325 threadDelay (duration * 1000000)
326 writeChan chan (TimeoutMessage timeoutId)
-- Advance the id so that the pending 'TimeoutMessage' becomes stale and is
-- ignored.
328 modify $ \s -> s { curTimeoutId = timeoutId + 1,
329 overridingTimeout = False
-- | Direction of a logged message: presumably 'LogServer' for messages from
-- the server and 'LogClient' for messages sent to it. Only 'LogServer' is
-- matched explicitly in the visible code. The deriving clause is not visible
-- here, but 'logMsg' compares values with '==', so an 'Eq' instance exists.
333 data LogMsgType = LogServer | LogClient
336 -- | Log @msg@ to stdout when 'logMessages' is enabled in the session's 'SessionConfig'.
--
-- The message is pretty-printed as JSON after a direction arrow (@"<-- "@
-- for 'LogServer'). When 'logColor' is enabled, the output is wrapped in
-- ANSI colour codes ('Magenta' for 'LogServer'). The equation head and the
-- arrow/colour for other message types are on lines not visible here.
337 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
338 => LogMsgType -> a -> m ()
340 shouldLog <- asks $ logMessages . config
341 shouldColor <- asks $ logColor . config
342 liftIO $ when shouldLog $ do
343 when shouldColor $ setSGR [SetColor Foreground Dull color]
344 putStrLn $ arrow ++ showPretty msg
-- Reset terminal attributes after printing.
345 when shouldColor $ setSGR [Reset]
348 | t == LogServer = "<-- "
351 | t == LogServer = Magenta
-- NOTE(review): 'B' is Data.ByteString.Lazy.Char8, so 'B.unpack' turns each
-- UTF-8 byte of the JSON into a separate 'Char'. Non-ASCII text in messages
-- is therefore printed garbled; decoding as UTF-8 first would avoid this.
354 showPretty = B.unpack . encodePretty