2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
9 module Language.Haskell.LSP.Test.Session
16 , runSessionWithHandles
32 import Control.Applicative
33 import Control.Concurrent hiding (yield)
34 import Control.Exception
35 import Control.Lens hiding (List)
37 import Control.Monad.IO.Class
38 import Control.Monad.Except
39 #if __GLASGOW_HASKELL__ == 806
40 import Control.Monad.Fail
42 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
43 import qualified Control.Monad.Trans.Reader as Reader (ask)
44 import Control.Monad.Trans.State (StateT, runStateT)
45 import qualified Control.Monad.Trans.State as State
46 import qualified Data.ByteString.Lazy.Char8 as B
48 import Data.Aeson.Encode.Pretty
49 import Data.Conduit as Conduit
50 import Data.Conduit.Parser as Parser
54 import qualified Data.Map as Map
55 import qualified Data.Text as T
56 import qualified Data.Text.IO as T
57 import qualified Data.HashMap.Strict as HashMap
60 import Language.Haskell.LSP.Messages
61 import Language.Haskell.LSP.Types.Capabilities
62 import Language.Haskell.LSP.Types
63 import Language.Haskell.LSP.Types.Lens hiding (error)
64 import Language.Haskell.LSP.VFS
65 import Language.Haskell.LSP.Test.Compat
66 import Language.Haskell.LSP.Test.Decoding
67 import Language.Haskell.LSP.Test.Exceptions
68 import System.Console.ANSI
69 import System.Directory
71 import System.Process (ProcessHandle())
-- | A 'Session' is one launch of, and connection to, a language server.
--
-- Messages are exchanged with the server inside a 'Session' through
-- 'Language.Haskell.LSP.Test.message',
-- 'Language.Haskell.LSP.Test.sendRequest' and
-- 'Language.Haskell.LSP.Test.sendNotification'.
--
-- Internally it is a conduit parser that consumes 'FromServerMessage's,
-- running over a 'StateT' of 'SessionState' on top of a 'ReaderT' of
-- 'SessionContext' in 'IO'.
newtype Session a = Session
  ( ConduitParser
      FromServerMessage
      (StateT SessionState (ReaderT SessionContext IO))
      a
  )
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , Alternative
    )
-- | A failing pattern match inside a 'Session' (for example a refuted
-- pattern on the left of a do-notation bind) is reported as an
-- 'UnexpectedMessage' exception carrying the failure text and the last
-- message received from the server.
84 #if __GLASGOW_HASKELL__ >= 806
85 instance MonadFail Session where
-- NOTE(review): fromJust is partial; if no message has been received yet this
-- dies with an uninformative fromJust error instead of UnexpectedMessage.
87 lastMsg <- fromJust . lastReceivedMessage <$> get
-- NOTE(review): liftIO . throwIO would be the idiomatic way to raise from IO.
88 liftIO $ throw (UnexpectedMessage s lastMsg)
91 -- | Options controlling how a 'Session' runs and what it logs.
92 data SessionConfig = SessionConfig
93 { messageTimeout :: Int -- ^ Maximum time, in seconds, to wait for a message; defaults to 60.
94 , logStdErr :: Bool -- ^ Redirect the server's stderr to the stdout of the test process; defaults to False.
95 , logMessages :: Bool -- ^ Print every message sent to and received from the server on stdout; defaults to False.
96 , logColor :: Bool -- ^ Colour the logged messages with ANSI escape codes; defaults to True.
97 , lspConfig :: Maybe Value -- ^ The initial LSP configuration, as a JSON value; defaults to Nothing.
-- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
--
-- A 60 second message timeout, no stderr redirection, no message tracing,
-- coloured output (when tracing is enabled) and no initial LSP configuration.
--
-- Fields are set by name rather than by position, so adding or reordering
-- fields of 'SessionConfig' cannot silently swap its three 'Bool' flags.
defaultConfig :: SessionConfig
defaultConfig = SessionConfig
  { messageTimeout = 60
  , logStdErr      = False
  , logMessages    = False
  , logColor       = True
  , lspConfig      = Nothing
  }
-- | 'Default' instance for 'SessionConfig'.
-- NOTE(review): the method body is not visible in this excerpt; presumably it
-- is defined as 'defaultConfig'. Confirm.
104 instance Default SessionConfig where
-- | An item delivered on the session's message channel ('messageChan').
-- 'ServerMessage' wraps a message received from the server; 'withTimeout'
-- also posts 'TimeoutMessage' items, which 'watchdog' consumes.
107 data SessionMessage = ServerMessage FromServerMessage
-- | Environment shared by a running 'Session' (through its 'ReaderT' layer)
-- and by the server listener started in 'runSessionWithHandles'.
111 data SessionContext = SessionContext
114 , rootDir :: FilePath -- ^ Root directory, canonicalized by 'runSessionWithHandles'.
115 , messageChan :: Chan SessionMessage -- ^ Channel the session reads server messages and timeout notifications from.
116 , requestMap :: MVar RequestMap -- ^ Shared request map (presumably outstanding request ids; confirm).
117 , initRsp :: MVar InitializeResponse -- ^ The server's initialize response; starts out empty.
118 , config :: SessionConfig -- ^ Configuration the session was started with.
119 , sessionCapabilities :: ClientCapabilities -- ^ Client capabilities the session was started with.
-- | A minimal reader-style class: monads that can read an environment of
-- type @r@. There is no functional dependency from @m@ to @r@.
122 class Monad m => HasReader r m where
-- | Read a value derived from the environment.
124 asks :: (r -> b) -> m b
-- | A 'Session' reads its 'SessionContext' from the 'ReaderT' layer, which
-- sits beneath both the 'StateT' and the conduit-parser layers.
instance HasReader SessionContext Session where
  -- Two lifts: ReaderT -> StateT, then StateT -> ConduitParser.
  ask = Session . lift . lift $ Reader.ask
-- | Conduit stages running over a 'StateT'-on-'ReaderT' stack can read the
-- reader environment directly.
instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
  -- Inner lift leaves ReaderT for StateT; outer lift leaves StateT for ConduitM.
  ask = lift (lift Reader.ask)
-- | State threaded through a 'Session' by its 'StateT' layer.
133 data SessionState = SessionState
137 , curDiagnostics :: Map.Map NormalizedUri [Diagnostic] -- ^ Latest diagnostics published by the server, per document.
138 , curTimeoutId :: Int -- ^ Id of the armed timeout; 'watchdog' only raises 'Timeout' for a 'TimeoutMessage' carrying this id.
139 , overridingTimeout :: Bool -- ^ True while 'withTimeout' is overriding the global timeout.
140 -- | The last received message from the server.
141 -- Used to add context to 'UnexpectedMessage' exceptions.
142 , lastReceivedMessage :: Maybe FromServerMessage
-- | A minimal state-style class: monads carrying a state of type @s@ that can
-- be read with @get@ and replaced with @put@ (declarations not visible here).
145 class Monad m => HasState s m where
-- | Apply a pure function to the state.
150 modify :: (s -> s) -> m ()
151 modify f = get >>= put . f
-- | Apply an effectful function to the state and store its result.
-- NOTE(review): the Monad m constraint is already implied by the HasState
-- superclass and is redundant.
153 modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
154 modifyM f = get >>= f >>= put
-- | A 'Session' keeps its 'SessionState' in the 'StateT' layer directly
-- beneath the conduit parser.
instance HasState SessionState Session where
  get = Session $ lift State.get
  put newState = Session (lift (State.put newState))
-- | Conduit stages such as 'updateStateC' access the state by lifting through
-- the conduit into the 'StateT' layer.
160 instance Monad m => HasState s (ConduitM a b (StateT s m))
163 put = lift . State.put
-- | Conduit parsers over 'StateT' access the state by lifting into the
-- 'StateT' layer.
165 instance Monad m => HasState s (ConduitParser a (StateT s m))
168 put = lift . State.put
-- | Run a 'Session' with the given context and initial state, returning its
-- result together with the final state.
--
-- Messages flow from the context's 'messageChan' through 'watchdog' (timeout
-- handling) and 'updateStateC' (state bookkeeping) into the session's parser.
170 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
171 runSession context state (Session session) = runReaderT (runStateT conduit state) context
173 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
-- When no parser alternative matched, report UnexpectedMessage with the name
-- of the current parser and the last message received.
175 handler (Unexpected "ConduitParser.empty") = do
-- NOTE(review): fromJust is partial if no message has been received yet.
176 lastMsg <- fromJust . lastReceivedMessage <$> get
177 name <- getParserName
178 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
-- Block until the next item arrives on the session channel.
183 msg <- liftIO $ readChan (messageChan context)
-- | Forward wrapped server messages downstream, and throw 'Timeout' when a
-- 'TimeoutMessage' arrives whose id equals 'curTimeoutId'.
--
-- A 'TimeoutMessage' with any other id is stale (for example from a
-- 'withTimeout' that already finished) and is silently dropped.
187 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
188 watchdog = Conduit.awaitForever $ \msg -> do
189 curId <- curTimeoutId <$> get
191 ServerMessage sMsg -> yield sMsg
-- NOTE(review): liftIO (throwIO Timeout) would make the raise point explicit.
192 TimeoutMessage tId -> when (curId == tId) $ throw Timeout
194 -- | An internal variant of 'Language.Haskell.LSP.Test.runSession' that lets the caller supply the listener reading from the server.
195 -- It does not send initialize or exit messages itself; the supplied exit action is run during cleanup.
196 runSessionWithHandles :: Handle -- ^ Server in
197 -> Handle -- ^ Server out
198 -> ProcessHandle -- ^ Server process
199 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener, run on a separate thread with the server's output handle
201 -> ClientCapabilities -- ^ Client capabilities for the session
202 -> FilePath -- ^ Root directory (canonicalized before use)
203 -> Session () -- ^ Action that shuts the server down; run during cleanup, bounded by 'messageTimeout'
206 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
207 absRootDir <- canonicalizePath rootDir
209 hSetBuffering serverIn NoBuffering
210 hSetBuffering serverOut NoBuffering
211 -- This is required to make sure that we do not get any
212 -- newline conversion or weird encoding issues.
213 hSetBinaryMode serverIn True
214 hSetBinaryMode serverOut True
216 reqMap <- newMVar newRequestMap
217 messageChan <- newChan
218 initRsp <- newEmptyMVar
-- Listener failures are rethrown to this thread (see errorHandler).
220 mainThreadId <- myThreadId
222 let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
223 initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
224 runSession' = runSession context initState
-- NOTE(review): only SessionException is forwarded; any other exception
-- kills the listener thread without being reported to the main thread.
226 errorHandler = throwTo mainThreadId :: SessionException -> IO()
227 serverListenerLauncher =
228 forkIO $ catch (serverHandler serverOut context) errorHandler
229 server = (Just serverIn, Just serverOut, Nothing, serverProc)
-- Give the exit action at most messageTimeout seconds (timeout takes
-- microseconds), then always clean up the process and stop the listener.
230 serverAndListenerFinalizer tid =
231 finally (timeout (messageTimeout config * 1000000)
232 (runSession' exitServer))
233 (cleanupProcess server >> killThread tid)
-- bracket guarantees the finalizer runs even if the session throws.
235 (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer
236 (const $ runSession' session)
-- | Conduit stage between 'watchdog' and the session parser that processes
-- every server message.
-- NOTE(review): the body is only partly visible here; presumably it calls
-- 'updateState' on each message and re-yields it. Confirm.
239 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
240 updateStateC = awaitForever $ \msg -> do
-- | Update the session state in response to a message from the server.
--
-- * A publishDiagnostics notification replaces the stored diagnostics for
--   that document in 'curDiagnostics'.
-- * A workspace/applyEdit request does several things. It opens any affected
--   document that is not yet in the VFS, applies the edit to the VFS, sends
--   didChange notifications to the server, and bumps the stored document
--   versions.
-- * Every other message is ignored.
244 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
245 => FromServerMessage -> m ()
246 updateState (NotPublishDiagnostics n) = do
247 let List diags = n ^. params . diagnostics
248 doc = n ^. params . uri
-- Map.insert overwrites, so only the latest published diagnostics are kept.
250 let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
251 in s { curDiagnostics = newDiags })
253 updateState (ReqApplyWorkspaceEdit r) = do
-- Prefer documentChanges; fall back to the changes map.
255 allChangeParams <- case r ^. params . edit . documentChanges of
257 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
258 return $ map getParams cs
259 Nothing -> case r ^. params . edit . changes of
261 mapM_ checkIfNeedsOpened (HashMap.keys cs)
262 return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
-- NOTE(review): an edit with neither documentChanges nor changes crashes the
-- session via error instead of being ignored or reported.
263 Nothing -> error "No changes!"
266 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
267 return $ s { vfs = newVFS }
-- NOTE(review): groupBy only groups adjacent elements, and it compares the
-- whole versioned identifier (URI and version). getChangeParams gives each
-- edit a distinct version, so edits from the changes map appear never to
-- be merged.
269 let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
270 mergedParams = map mergeParams groupedParams
272 -- TODO: Don't do this when replaying a session
273 forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
275 -- Update VFS to new document versions
-- For each group, take the identifier with the highest version and add one.
276 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
277 latestVersions = map ((^. textDocument) . last) sortedVersions
278 bumpedVersions = map (version . _Just +~ 1) latestVersions
280 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
-- A Nothing version leaves the stored VFS version unchanged.
283 update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
284 newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
285 in s { vfs = newVFS }
-- Open a document the edit touches if the VFS does not have it yet: read it
-- from disk, send didOpen to the server and add it to the VFS.
287 where checkIfNeedsOpened uri = do
288 oldVFS <- vfs <$> get
291 -- if it is not open yet, open it
292 unless (toNormalizedUri uri `Map.member` oldVFS) $ do
-- NOTE(review): fromJust fails for URIs that do not map to a file path.
293 let fp = fromJust $ uriToFilePath uri
294 contents <- liftIO $ T.readFile fp
-- Opened with an empty languageId and version 0.
295 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
296 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
-- Written straight to the server input handle.
297 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
300 newVFS <- liftIO $ openVFS (vfs s) msg
301 return $ s { vfs = newVFS }
-- Turn a TextDocumentEdit into didChange params: one ranged change event per edit.
303 getParams (TextDocumentEdit docId (List edits)) =
304 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
305 in DidChangeTextDocumentParams docId (List changeEvents)
-- Infinite supply of identifiers for uri, with versions 0, 1, 2, ...
307 textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
-- Wrap each edit in its own TextDocumentEdit, numbered by textDocumentVersions.
309 textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
-- NOTE(review): edits are reversed before numbering; presumably so that later
-- ranges are applied first and earlier positions stay valid. Confirm.
311 getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
-- Concatenate the content changes of a group into one notification, keeping
-- the first element's identifier. head is safe here because this is only
-- applied to groupBy output, which never contains an empty group.
313 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
314 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
315 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
316 updateState _ = return ()
-- | Encode a message as JSON, frame it with 'addHeader' and write it to the
-- server's input handle taken from the 'SessionContext'.
318 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
320 h <- serverIn <$> ask
322 liftIO $ B.hPut h (addHeader $ encode msg)
324 -- | Run the session action @f@, throwing a 'Timeout' exception if @duration@ seconds
325 -- pass while it is still waiting for messages. This overrides the global timeout
326 -- for waiting for messages defined in 'SessionConfig'.
327 withTimeout :: Int -> Session a -> Session a
328 withTimeout duration f = do
329 chan <- asks messageChan
-- Remember which timeout id this call arms; watchdog compares against it.
330 timeoutId <- curTimeoutId <$> get
331 modify $ \s -> s { overridingTimeout = True }
-- After duration seconds, post a TimeoutMessage for this id (the enclosing
-- line is not visible here; presumably it runs on a forked thread. Confirm).
333 threadDelay (duration * 1000000)
334 writeChan chan (TimeoutMessage timeoutId)
-- Bump the id so a late TimeoutMessage from this call is ignored by watchdog.
336 modify $ \s -> s { curTimeoutId = timeoutId + 1,
337 overridingTimeout = False
-- | Direction of a logged message: 'LogServer' for messages from the server
-- (shown with a @<--@ arrow by 'logMsg'), 'LogClient' for messages to it.
341 data LogMsgType = LogServer | LogClient
344 -- | Pretty-print a message to stdout when 'logMessages' is enabled, prefixed with a direction arrow and coloured when 'logColor' is set.
345 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
346 => LogMsgType -> a -> m ()
348 shouldLog <- asks $ logMessages . config
349 shouldColor <- asks $ logColor . config
350 liftIO $ when shouldLog $ do
351 when shouldColor $ setSGR [SetColor Foreground Dull color]
352 putStrLn $ arrow ++ showPretty msg
-- Reset the terminal colour so later output is unaffected.
353 when shouldColor $ setSGR [Reset]
356 | t == LogServer = "<-- "
359 | t == LogServer = Magenta
-- NOTE(review): B is Data.ByteString.Lazy.Char8, so unpack maps each byte of
-- the UTF-8 JSON to one Char. Non-ASCII text is therefore printed garbled.
362 showPretty = B.unpack . encodePretty