2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
8 module Language.Haskell.LSP.Test.Session
15 , runSessionWithHandles
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Decoding
64 import Language.Haskell.LSP.Test.Exceptions
65 import System.Console.ANSI
66 import System.Directory
69 -- | A session representing one instance of launching and connecting to a server.
71 -- You can send and receive messages to the server within 'Session' via 'getMessage',
72 -- 'sendRequest' and 'sendNotification'.
type Session =
  -- Consumes 'FromServerMessage's, threading 'SessionState' and reading 'SessionContext', over IO.
  ParserStateReader FromServerMessage SessionState SessionContext IO
77 #if __GLASGOW_HASKELL__ >= 806
78 instance MonadFail Session where
80 lastMsg <- fromJust . lastReceivedMessage <$> get
81 liftIO $ throw (UnexpectedMessage s lastMsg)
-- | Options that configure how a 'Session' runs.
85 data SessionConfig = SessionConfig
86 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
87 , logStdErr :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
88 , logMessages :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
89 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
90 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
-- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
--
-- Defaults: a 60 second message timeout, no redirection of the server's
-- stderr, no message logging, coloured log output, and no initial LSP config.
defaultConfig :: SessionConfig
defaultConfig = SessionConfig
  { messageTimeout = 60
    -- Field labels (rather than positional arguments) keep the three Bool
    -- defaults below from being silently swapped if the record is reordered.
  , logStdErr      = False
  , logMessages    = False
  , logColor       = True
  , lspConfig      = Nothing
  }
97 instance Default SessionConfig where
100 data SessionMessage = ServerMessage FromServerMessage
104 data SessionContext = SessionContext
107 , rootDir :: FilePath
108 , messageChan :: Chan SessionMessage
109 , requestMap :: MVar RequestMap
110 , initRsp :: MVar InitializeResponse
111 , config :: SessionConfig
112 , sessionCapabilities :: ClientCapabilities
115 class Monad m => HasReader r m where
117 asks :: (r -> b) -> m b
instance Monad m => HasReader r (ParserStateReader a s r m) where
  -- The reader sits two layers down: the first 'lift' enters 'StateT',
  -- the second enters 'ConduitParser'.
  ask = (lift . lift) Reader.ask
instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
  -- Reach the 'ReaderT' layer by lifting first into 'StateT', then into 'ConduitM'.
  ask = (lift . lift) Reader.ask
126 data SessionState = SessionState
130 , curDiagnostics :: Map.Map Uri [Diagnostic]
131 , curTimeoutId :: Int
132 , overridingTimeout :: Bool
133 -- ^ The last received message from the server.
134 -- Used for providing exception information
135 , lastReceivedMessage :: Maybe FromServerMessage
138 class Monad m => HasState s m where
143 modify :: (s -> s) -> m ()
144 modify f = get >>= put . f
146 modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
147 modifyM f = get >>= f >>= put
149 instance Monad m => HasState s (ParserStateReader a s r m) where
151 put = lift . State.put
153 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
156 put = lift . State.put
-- | A 'ConduitParser' over inputs of type @a@ whose base monad is 'StateT' @s@
-- stacked on 'ReaderT' @r@ stacked on @m@.
type ParserStateReader a s r m =
  ConduitParser a (StateT s (ReaderT r m))
160 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
161 runSession context state session = runReaderT (runStateT conduit state) context
163 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
165 handler (Unexpected "ConduitParser.empty") = do
166 lastMsg <- fromJust . lastReceivedMessage <$> get
167 name <- getParserName
168 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
173 msg <- liftIO $ readChan (messageChan context)
177 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
178 watchdog = Conduit.awaitForever $ \msg -> do
179 curId <- curTimeoutId <$> get
181 ServerMessage sMsg -> yield sMsg
182 TimeoutMessage tId -> when (curId == tId) $ throw Timeout
184 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
185 -- It also does not automatically send initialize and exit messages.
186 runSessionWithHandles :: Handle -- ^ Server in
187 -> Handle -- ^ Server out
188 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
190 -> ClientCapabilities
191 -> FilePath -- ^ Root directory
194 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
195 absRootDir <- canonicalizePath rootDir
197 hSetBuffering serverIn NoBuffering
198 hSetBuffering serverOut NoBuffering
200 reqMap <- newMVar newRequestMap
201 messageChan <- newChan
202 initRsp <- newEmptyMVar
204 mainThreadId <- myThreadId
206 let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
207 initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
208 launchServerHandler = forkIO $ catch (serverHandler serverOut context)
209 (throwTo mainThreadId :: SessionException -> IO ())
210 (result, _) <- bracket launchServerHandler killThread $
211 const $ runSession context initState session
215 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
216 updateStateC = awaitForever $ \msg -> do
220 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
221 updateState (NotPublishDiagnostics n) = do
222 let List diags = n ^. params . diagnostics
223 doc = n ^. params . uri
225 let newDiags = Map.insert doc diags (curDiagnostics s)
226 in s { curDiagnostics = newDiags })
228 updateState (ReqApplyWorkspaceEdit r) = do
230 allChangeParams <- case r ^. params . edit . documentChanges of
232 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
233 return $ map getParams cs
234 Nothing -> case r ^. params . edit . changes of
236 mapM_ checkIfNeedsOpened (HashMap.keys cs)
237 return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
238 Nothing -> error "No changes!"
241 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
242 return $ s { vfs = newVFS }
244 let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
245 mergedParams = map mergeParams groupedParams
247 -- TODO: Don't do this when replaying a session
248 forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
250 -- Update VFS to new document versions
251 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
252 latestVersions = map ((^. textDocument) . last) sortedVersions
253 bumpedVersions = map (version . _Just +~ 1) latestVersions
255 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
258 update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
259 newVFS = Map.adjust update uri oldVFS
260 in s { vfs = newVFS }
262 where checkIfNeedsOpened uri = do
263 oldVFS <- vfs <$> get
266 -- if its not open, open it
267 unless (uri `Map.member` oldVFS) $ do
268 let fp = fromJust $ uriToFilePath uri
269 contents <- liftIO $ T.readFile fp
270 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
271 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
272 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
275 newVFS <- liftIO $ openVFS (vfs s) msg
276 return $ s { vfs = newVFS }
278 getParams (TextDocumentEdit docId (List edits)) =
279 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
280 in DidChangeTextDocumentParams docId (List changeEvents)
282 textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
284 textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
286 getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
288 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
289 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
290 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
291 updateState _ = return ()
293 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
295 h <- serverIn <$> ask
297 liftIO $ B.hPut h (addHeader $ encode msg)
-- | Execute the session action @f@, which will throw a 'Timeout' exception
-- after @duration@ seconds. This overrides the global timeout for waiting
-- for messages to arrive, as defined in 'SessionConfig'.
302 withTimeout :: Int -> Session a -> Session a
303 withTimeout duration f = do
304 chan <- asks messageChan
305 timeoutId <- curTimeoutId <$> get
306 modify $ \s -> s { overridingTimeout = True }
308 threadDelay (duration * 1000000)
309 writeChan chan (TimeoutMessage timeoutId)
311 modify $ \s -> s { curTimeoutId = timeoutId + 1,
312 overridingTimeout = False
316 data LogMsgType = LogServer | LogClient
-- | Log the message to stdout when 'logMessages' is enabled in the session config.
320 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
321 => LogMsgType -> a -> m ()
323 shouldLog <- asks $ logMessages . config
324 shouldColor <- asks $ logColor . config
325 liftIO $ when shouldLog $ do
326 when shouldColor $ setSGR [SetColor Foreground Dull color]
327 putStrLn $ arrow ++ showPretty msg
328 when shouldColor $ setSGR [Reset]
331 | t == LogServer = "<-- "
334 | t == LogServer = Magenta
337 showPretty = B.unpack . encodePretty