2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
8 module Language.Haskell.LSP.Test.Session
15 , runSessionWithHandles
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Decoding
64 import Language.Haskell.LSP.Test.Exceptions
65 import System.Console.ANSI
66 import System.Directory
70 -- | A session representing one instance of launching and connecting to a server.
72 -- You can send and receive messages to the server within 'Session' via
73 -- 'Language.Haskell.LSP.Test.message',
74 -- 'Language.Haskell.LSP.Test.sendRequest' and
75 -- 'Language.Haskell.LSP.Test.sendNotification'.
type Session =
  ParserStateReader
    FromServerMessage -- parser input: messages received from the server
    SessionState      -- state threaded through the session ('StateT' layer)
    SessionContext    -- read-only environment ('ReaderT' layer)
    IO
79 #if __GLASGOW_HASKELL__ >= 806
80 instance MonadFail Session where
82 lastMsg <- fromJust . lastReceivedMessage <$> get
83 liftIO $ throw (UnexpectedMessage s lastMsg)
86 -- | Stuff you can configure for a 'Session'.
87 data SessionConfig = SessionConfig
88 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
89 , logStdErr :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
90 , logMessages :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
91 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
92 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
-- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
--
-- Waits at most 60 seconds for a message. It does not redirect the server's
-- stderr and does not trace messages. Colour is enabled for when logging is
-- turned on, and no initial LSP config is sent.
-- Fields are named explicitly so that each default is tied to its field
-- rather than to its position in the constructor.
defaultConfig :: SessionConfig
defaultConfig = SessionConfig
  { messageTimeout = 60
  , logStdErr      = False
  , logMessages    = False
  , logColor       = True
  , lspConfig      = Nothing
  }
99 instance Default SessionConfig where
102 data SessionMessage = ServerMessage FromServerMessage
106 data SessionContext = SessionContext
109 , rootDir :: FilePath
110 , messageChan :: Chan SessionMessage
111 , requestMap :: MVar RequestMap
112 , initRsp :: MVar InitializeResponse
113 , config :: SessionConfig
114 , sessionCapabilities :: ClientCapabilities
117 class Monad m => HasReader r m where
119 asks :: (r -> b) -> m b
-- | Reads the environment from the 'ReaderT' layer of a 'ParserStateReader'.
-- This is done by lifting 'Reader.ask' first through the 'StateT' layer and
-- then through the parser layer.
instance Monad m => HasReader r (ParserStateReader a s r m) where
  ask = lift (lift Reader.ask)
-- | Reads the 'SessionContext' from inside a plain conduit stage running over
-- @StateT s (ReaderT SessionContext m)@. 'Reader.ask' is lifted through the
-- 'StateT' layer and then into the 'ConduitM' layer.
instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
  ask = lift (lift Reader.ask)
128 data SessionState = SessionState
132 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
133 , curTimeoutId :: Int
134 , overridingTimeout :: Bool
135 -- ^ The last received message from the server.
136 -- Used for providing exception information
137 , lastReceivedMessage :: Maybe FromServerMessage
140 class Monad m => HasState s m where
145 modify :: (s -> s) -> m ()
146 modify f = get >>= put . f
-- | Effectful counterpart of 'modify'.
--
-- Reads the current state with 'get' and hands it to the supplied function,
-- which may run effects in @m@. The state it returns is then stored with 'put'.
modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
modifyM f = do
  current <- get
  updated <- f current
  put updated
151 instance Monad m => HasState s (ParserStateReader a s r m) where
153 put = lift . State.put
155 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
158 put = lift . State.put
-- | A 'ConduitParser' consuming values of type @a@. Its underlying monad
-- stacks a 'StateT' over @s@ on top of a 'ReaderT' over @r@, which in turn
-- wraps the base monad @m@.
type ParserStateReader a s r m =
  ConduitParser a (StateT s (ReaderT r m))
162 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
163 runSession context state session = runReaderT (runStateT conduit state) context
165 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
167 handler (Unexpected "ConduitParser.empty") = do
168 lastMsg <- fromJust . lastReceivedMessage <$> get
169 name <- getParserName
170 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
175 msg <- liftIO $ readChan (messageChan context)
179 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
180 watchdog = Conduit.awaitForever $ \msg -> do
181 curId <- curTimeoutId <$> get
183 ServerMessage sMsg -> yield sMsg
184 TimeoutMessage tId -> when (curId == tId) $ throw Timeout
186 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
187 -- It also does not automatically send initialize and exit messages.
188 runSessionWithHandles :: Handle -- ^ Server in
189 -> Handle -- ^ Server out
190 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
192 -> ClientCapabilities
193 -> FilePath -- ^ Root directory
194 -> Session () -- ^ To exit Server
197 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir exitServer session = do
199 absRootDir <- canonicalizePath rootDir
201 hSetBuffering serverIn NoBuffering
202 hSetBuffering serverOut NoBuffering
203 -- This is required to make sure that we don’t get any
204 -- newline conversion or weird encoding issues.
205 hSetBinaryMode serverIn True
206 hSetBinaryMode serverOut True
208 reqMap <- newMVar newRequestMap
209 messageChan <- newChan
210 initRsp <- newEmptyMVar
212 mainThreadId <- myThreadId
214 let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
215 initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
216 runSession' = runSession context initState
218 errorHandler = throwTo mainThreadId :: SessionException -> IO()
219 serverLauncher = forkIO $ catch (serverHandler serverOut context) errorHandler
220 serverFinalizer tid = finally (timeout 60000000 (runSession' exitServer))
223 (result, _) <- bracket serverLauncher serverFinalizer (const $ runSession' session)
226 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
227 updateStateC = awaitForever $ \msg -> do
231 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
232 updateState (NotPublishDiagnostics n) = do
233 let List diags = n ^. params . diagnostics
234 doc = n ^. params . uri
236 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
237 in s { curDiagnostics = newDiags })
239 updateState (ReqApplyWorkspaceEdit r) = do
241 allChangeParams <- case r ^. params . edit . documentChanges of
243 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
244 return $ map getParams cs
245 Nothing -> case r ^. params . edit . changes of
247 mapM_ checkIfNeedsOpened (HashMap.keys cs)
248 return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
249 Nothing -> error "No changes!"
252 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
253 return $ s { vfs = newVFS }
255 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
256 mergedParams = map mergeParams groupedParams
258 -- TODO: Don't do this when replaying a session
259 forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
261 -- Update VFS to new document versions
262 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
263 latestVersions = map ((^. textDocument) . last) sortedVersions
264 bumpedVersions = map (version . _Just +~ 1) latestVersions
266 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
269 update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
270 newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
271 in s { vfs = newVFS }
273 where checkIfNeedsOpened uri = do
274 oldVFS <- vfs <$> get
277 -- if its not open, open it
278 unless (toNormalizedUri uri `Map.member` oldVFS) $ do
279 let fp = fromJust $ uriToFilePath uri
280 contents <- liftIO $ T.readFile fp
281 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
282 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
283 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
286 newVFS <- liftIO $ openVFS (vfs s) msg
287 return $ s { vfs = newVFS }
289 getParams (TextDocumentEdit docId (List edits)) =
290 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
291 in DidChangeTextDocumentParams docId (List changeEvents)
293 textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
295 textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
297 getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
299 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
300 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
301 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
302 updateState _ = return ()
304 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
306 h <- serverIn <$> ask
308 liftIO $ B.hPut h (addHeader $ encode msg)
310 -- | Execute a block f that will throw a 'Timeout' exception
311 -- after duration seconds. This will override the global timeout
312 -- for waiting for messages to arrive defined in 'SessionConfig'.
313 withTimeout :: Int -> Session a -> Session a
314 withTimeout duration f = do
315 chan <- asks messageChan
316 timeoutId <- curTimeoutId <$> get
317 modify $ \s -> s { overridingTimeout = True }
319 threadDelay (duration * 1000000)
320 writeChan chan (TimeoutMessage timeoutId)
322 modify $ \s -> s { curTimeoutId = timeoutId + 1,
323 overridingTimeout = False
327 data LogMsgType = LogServer | LogClient
330 -- | Logs the message if the config specified it
331 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
332 => LogMsgType -> a -> m ()
334 shouldLog <- asks $ logMessages . config
335 shouldColor <- asks $ logColor . config
336 liftIO $ when shouldLog $ do
337 when shouldColor $ setSGR [SetColor Foreground Dull color]
338 putStrLn $ arrow ++ showPretty msg
339 when shouldColor $ setSGR [Reset]
342 | t == LogServer = "<-- "
345 | t == LogServer = Magenta
348 showPretty = B.unpack . encodePretty