2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
8 module Language.Haskell.LSP.Test.Session
15 , runSessionWithHandles
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
35 import Control.Monad.IO.Class
36 import Control.Monad.Except
37 #if __GLASGOW_HASKELL__ >= 806
38 import Control.Monad.Fail
40 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
41 import qualified Control.Monad.Trans.Reader as Reader (ask)
42 import Control.Monad.Trans.State (StateT, runStateT)
43 import qualified Control.Monad.Trans.State as State (get, put)
44 import qualified Data.ByteString.Lazy.Char8 as B
46 import Data.Aeson.Encode.Pretty
47 import Data.Conduit as Conduit
48 import Data.Conduit.Parser as Parser
52 import qualified Data.Map as Map
53 import qualified Data.Text as T
54 import qualified Data.Text.IO as T
55 import qualified Data.HashMap.Strict as HashMap
58 import Language.Haskell.LSP.Messages
59 import Language.Haskell.LSP.Types.Capabilities
60 import Language.Haskell.LSP.Types
61 import Language.Haskell.LSP.Types.Lens hiding (error)
62 import Language.Haskell.LSP.VFS
63 import Language.Haskell.LSP.Test.Decoding
64 import Language.Haskell.LSP.Test.Exceptions
65 import System.Console.ANSI
66 import System.Directory
-- | A session representing one instance of launching and connecting to a server.
--
-- You can send and receive messages to the server within 'Session' via
-- 'Language.Haskell.LSP.Test.message',
-- 'Language.Haskell.LSP.Test.sendRequest' and
-- 'Language.Haskell.LSP.Test.sendNotification'.
--
-- Structurally, a 'Session' is a parser over the stream of
-- 'FromServerMessage's coming from the server. It threads a mutable
-- 'SessionState' and a read-only 'SessionContext', with 'IO' at the base.
type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
78 #if __GLASGOW_HASKELL__ >= 806
79 instance MonadFail Session where
81 lastMsg <- fromJust . lastReceivedMessage <$> get
82 liftIO $ throw (UnexpectedMessage s lastMsg)
85 -- | Stuff you can configure for a 'Session'.
86 data SessionConfig = SessionConfig
87 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
88 , logStdErr :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
89 , logMessages :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
90 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
91 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
-- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
--
-- Each message is awaited for at most 60 seconds. The server's stderr is
-- not forwarded, and message tracing is off. If tracing is switched on,
-- the log is coloured. No initial LSP configuration is sent.
defaultConfig :: SessionConfig
defaultConfig =
  SessionConfig
    { messageTimeout = 60
    , logStdErr      = False
    , logMessages    = False
    , logColor       = True
    , lspConfig      = Nothing
    }
98 instance Default SessionConfig where
101 data SessionMessage = ServerMessage FromServerMessage
105 data SessionContext = SessionContext
108 , rootDir :: FilePath
109 , messageChan :: Chan SessionMessage
110 , requestMap :: MVar RequestMap
111 , initRsp :: MVar InitializeResponse
112 , config :: SessionConfig
113 , sessionCapabilities :: ClientCapabilities
116 class Monad m => HasReader r m where
118 asks :: (r -> b) -> m b
-- | Inside the parser stack, the environment lives in the 'ReaderT' layer.
-- That layer sits two levels down, beneath the parser and the 'StateT'
-- layer, so it is reached by lifting twice.
instance Monad m => HasReader env (ParserStateReader i st env m) where
  ask = (lift . lift) Reader.ask
-- | Plain conduit stages whose base monad is @StateT st (ReaderT SessionContext m)@
-- can also read the 'SessionContext'. The context is fetched through the
-- 'ConduitM' and 'StateT' layers with two lifts.
instance Monad m => HasReader SessionContext (ConduitM i o (StateT st (ReaderT SessionContext m))) where
  ask = (lift . lift) Reader.ask
127 data SessionState = SessionState
131 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
132 , curTimeoutId :: Int
133 , overridingTimeout :: Bool
-- | The last received message from the server.
-- Used for providing exception information.
136 , lastReceivedMessage :: Maybe FromServerMessage
139 class Monad m => HasState s m where
144 modify :: (s -> s) -> m ()
145 modify f = get >>= put . f
147 modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
148 modifyM f = get >>= f >>= put
150 instance Monad m => HasState s (ParserStateReader a s r m) where
152 put = lift . State.put
154 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
157 put = lift . State.put
-- | The transformer stack underlying 'Session'. It has four layers:
--
-- * a 'ConduitParser' consuming items of type @i@;
-- * a 'StateT' layer holding @st@;
-- * a 'ReaderT' layer providing @env@;
-- * the base monad @m@.
type ParserStateReader i st env m = ConduitParser i (StateT st (ReaderT env m))
161 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
162 runSession context state session = runReaderT (runStateT conduit state) context
164 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
166 handler (Unexpected "ConduitParser.empty") = do
167 lastMsg <- fromJust . lastReceivedMessage <$> get
168 name <- getParserName
169 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
174 msg <- liftIO $ readChan (messageChan context)
178 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
179 watchdog = Conduit.awaitForever $ \msg -> do
180 curId <- curTimeoutId <$> get
182 ServerMessage sMsg -> yield sMsg
183 TimeoutMessage tId -> when (curId == tId) $ throw Timeout
185 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
186 -- It also does not automatically send initialize and exit messages.
187 runSessionWithHandles :: Handle -- ^ Server in
188 -> Handle -- ^ Server out
189 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
191 -> ClientCapabilities
192 -> FilePath -- ^ Root directory
193 -> Session () -- ^ To exit Server
196 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir exitServer session = do
198 absRootDir <- canonicalizePath rootDir
200 hSetBuffering serverIn NoBuffering
201 hSetBuffering serverOut NoBuffering
202 -- This is required to make sure that we don’t get any
203 -- newline conversion or weird encoding issues.
204 hSetBinaryMode serverIn True
205 hSetBinaryMode serverOut True
207 reqMap <- newMVar newRequestMap
208 messageChan <- newChan
209 initRsp <- newEmptyMVar
211 mainThreadId <- myThreadId
213 let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
214 initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
215 runSession' = runSession context initState
217 errorHandler = throwTo mainThreadId :: SessionException -> IO()
218 serverLauncher = forkIO $ catch (serverHandler serverOut context) errorHandler
219 serverFinalizer tid = runSession' exitServer >> killThread tid
221 (result, _) <- bracket serverLauncher serverFinalizer (const $ runSession' session)
224 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
225 updateStateC = awaitForever $ \msg -> do
229 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
230 updateState (NotPublishDiagnostics n) = do
231 let List diags = n ^. params . diagnostics
232 doc = n ^. params . uri
234 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
235 in s { curDiagnostics = newDiags })
237 updateState (ReqApplyWorkspaceEdit r) = do
239 allChangeParams <- case r ^. params . edit . documentChanges of
241 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
242 return $ map getParams cs
243 Nothing -> case r ^. params . edit . changes of
245 mapM_ checkIfNeedsOpened (HashMap.keys cs)
246 return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
247 Nothing -> error "No changes!"
250 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
251 return $ s { vfs = newVFS }
253 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
254 mergedParams = map mergeParams groupedParams
256 -- TODO: Don't do this when replaying a session
257 forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
259 -- Update VFS to new document versions
260 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
261 latestVersions = map ((^. textDocument) . last) sortedVersions
262 bumpedVersions = map (version . _Just +~ 1) latestVersions
264 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
267 update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
268 newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
269 in s { vfs = newVFS }
271 where checkIfNeedsOpened uri = do
272 oldVFS <- vfs <$> get
-- if it's not open, open it
276 unless (toNormalizedUri uri `Map.member` oldVFS) $ do
277 let fp = fromJust $ uriToFilePath uri
278 contents <- liftIO $ T.readFile fp
279 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
280 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
281 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
284 newVFS <- liftIO $ openVFS (vfs s) msg
285 return $ s { vfs = newVFS }
287 getParams (TextDocumentEdit docId (List edits)) =
288 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
289 in DidChangeTextDocumentParams docId (List changeEvents)
291 textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
293 textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
295 getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
297 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
298 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
299 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
300 updateState _ = return ()
302 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
304 h <- serverIn <$> ask
306 liftIO $ B.hPut h (addHeader $ encode msg)
-- | Execute the block @f@, throwing a 'Timeout' exception if it is still
-- running after @duration@ seconds. This overrides the global timeout
-- for waiting for messages to arrive, as defined in 'SessionConfig'.
311 withTimeout :: Int -> Session a -> Session a
312 withTimeout duration f = do
313 chan <- asks messageChan
314 timeoutId <- curTimeoutId <$> get
315 modify $ \s -> s { overridingTimeout = True }
317 threadDelay (duration * 1000000)
318 writeChan chan (TimeoutMessage timeoutId)
320 modify $ \s -> s { curTimeoutId = timeoutId + 1,
321 overridingTimeout = False
325 data LogMsgType = LogServer | LogClient
328 -- | Logs the message if the config specified it
329 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
330 => LogMsgType -> a -> m ()
332 shouldLog <- asks $ logMessages . config
333 shouldColor <- asks $ logColor . config
334 liftIO $ when shouldLog $ do
335 when shouldColor $ setSGR [SetColor Foreground Dull color]
336 putStrLn $ arrow ++ showPretty msg
337 when shouldColor $ setSGR [Reset]
340 | t == LogServer = "<-- "
343 | t == LogServer = Magenta
346 showPretty = B.unpack . encodePretty