2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE MultiParamTypeClasses #-}
5 {-# LANGUAGE FlexibleContexts #-}
6 {-# LANGUAGE RankNTypes #-}
8 module Language.Haskell.LSP.Test.Session
15 , runSessionWithHandles
31 import Control.Concurrent hiding (yield)
32 import Control.Exception
33 import Control.Lens hiding (List)
35 import Control.Monad.Fail
36 import Control.Monad.IO.Class
37 import Control.Monad.Except
38 #if __GLASGOW_HASKELL__ >= 806
39 import qualified Control.Monad.Fail as Fail
41 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
42 import qualified Control.Monad.Trans.Reader as Reader (ask)
43 import Control.Monad.Trans.State (StateT, runStateT)
44 import qualified Control.Monad.Trans.State as State (get, put)
45 import qualified Data.ByteString.Lazy.Char8 as B
47 import Data.Aeson.Encode.Pretty
48 import Data.Conduit as Conduit
49 import Data.Conduit.Parser as Parser
53 import qualified Data.Map as Map
54 import qualified Data.Text as T
55 import qualified Data.Text.IO as T
56 import qualified Data.HashMap.Strict as HashMap
59 import Language.Haskell.LSP.Messages
60 import Language.Haskell.LSP.Types.Capabilities
61 import Language.Haskell.LSP.Types
62 import Language.Haskell.LSP.Types.Lens hiding (error)
63 import Language.Haskell.LSP.VFS
64 import Language.Haskell.LSP.Test.Decoding
65 import Language.Haskell.LSP.Test.Exceptions
66 import System.Console.ANSI
67 import System.Directory
-- | A session representing one instance of launching and connecting to a
-- server.
--
-- You can send and receive messages to the server within 'Session' via
-- 'getMessage', 'sendRequest' and 'sendNotification'.
--
-- Concretely, a 'Session' is a 'ParserStateReader' that consumes
-- 'FromServerMessage's, threads a 'SessionState', reads a 'SessionContext',
-- and runs over 'IO'.
type Session =
  ParserStateReader
    FromServerMessage
    SessionState
    SessionContext
    IO
78 #if __GLASGOW_HASKELL__ >= 806
-- | A failed pattern match (or explicit 'fail') inside a 'Session' is
-- reported by throwing an 'UnexpectedMessage' exception. The exception carries
-- the failure text and the last message received from the server.
--
-- NOTE(review): 'fromJust' crashes with an unhelpful error if this is reached
-- before any message has been received. 'lastReceivedMessage' starts out as
-- 'Nothing' in the initial session state.
79 instance MonadFail Session where
81 lastMsg <- fromJust . lastReceivedMessage <$> get
82 liftIO $ throw (UnexpectedMessage s lastMsg)
85 -- | Options that can be configured for a 'Session'.
86 data SessionConfig = SessionConfig
87 { messageTimeout :: Int -- ^ Maximum time to wait for a message in seconds, defaults to 60.
88 , logStdErr :: Bool -- ^ Redirect the server's stderr to this process's stdout, defaults to False.
89 , logMessages :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
90 , logColor :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
91 , lspConfig :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
-- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
--
-- Each value matches the default documented on the corresponding
-- 'SessionConfig' field:
--
-- * a 60 second message timeout;
-- * no stderr redirection;
-- * no message logging;
-- * coloured log output;
-- * no initial LSP configuration.
--
-- Record syntax is used so that each default is visibly tied to its field.
-- The previous positional form silently enabled 'logMessages', contrary to
-- its documentation.
defaultConfig :: SessionConfig
defaultConfig = SessionConfig
  { messageTimeout = 60
  , logStdErr      = False
  , logMessages    = False
  , logColor       = True
  , lspConfig      = Nothing
  }
-- | Lets a 'SessionConfig' be obtained through the 'Default' class.
-- NOTE(review): the method body is outside this excerpt; presumably
-- @def = defaultConfig@ — confirm.
98 instance Default SessionConfig where
-- | An item delivered on a session's 'messageChan'. It is either a message
-- from the server ('ServerMessage') or a timer firing ('TimeoutMessage',
-- which is posted by 'withTimeout' and interpreted by 'watchdog').
101 data SessionMessage = ServerMessage FromServerMessage
-- | Environment shared by everything running inside a 'Session'. It is made
-- available through the 'ReaderT' layer of the session monad.
105 data SessionContext = SessionContext
108 , rootDir :: FilePath -- ^ Root directory, canonicalized by 'runSessionWithHandles'.
109 , messageChan :: Chan SessionMessage -- ^ Incoming server messages and timeout signals consumed by the session pipeline.
110 , requestMap :: MVar RequestMap -- ^ Shared request map, initialised with 'newRequestMap'.
111 , initRsp :: MVar InitializeResponse -- ^ Starts empty; presumably filled once the initialize response arrives (confirm).
112 , config :: SessionConfig -- ^ The configuration this session was started with.
113 , sessionCapabilities :: ClientCapabilities -- ^ The client capabilities passed to 'runSessionWithHandles'.
-- | Monads that can read an environment of type @r@.
116 class Monad m => HasReader r m where
-- | Read a projection of the environment.
118 asks :: (r -> b) -> m b
-- | The environment of a 'ParserStateReader' lives in its innermost
-- 'ReaderT' layer. It is reached by lifting through the parser and the
-- 'StateT' layer.
instance Monad m => HasReader r (ParserStateReader a s r m) where
  ask = lift (lift Reader.ask)
-- | A conduit stage running over @'StateT' s ('ReaderT' 'SessionContext' m)@
-- reads the 'SessionContext' by lifting twice: past the conduit, then past
-- the 'StateT' layer.
instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
  ask = (lift . lift) Reader.ask
-- | Mutable bookkeeping threaded through a 'Session' via its 'StateT' layer.
127 data SessionState = SessionState
131 , curDiagnostics :: Map.Map Uri [Diagnostic] -- ^ Latest diagnostics published by the server, per document.
132 , curTimeoutId :: Int -- ^ Id of the currently active timeout; 'TimeoutMessage's with any other id are ignored.
133 , overridingTimeout :: Bool -- ^ 'True' while inside 'withTimeout'.
134 -- | The last received message from the server.
135 -- Used for providing exception information
136 , lastReceivedMessage :: Maybe FromServerMessage
-- | Monads that carry a state of type @s@ that can be read and replaced.
139 class Monad m => HasState s m where
-- | Apply a pure update function to the state.
144 modify :: (s -> s) -> m ()
145 modify f = get >>= put . f
-- | Replace the state with the result of running a monadic update on the
-- current state.
modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
modifyM f = do
  current <- get
  updated <- f current
  put updated
-- | State access for a 'ParserStateReader' is lifted into its 'StateT' layer.
150 instance Monad m => HasState s (ParserStateReader a s r m) where
152 put = lift . State.put
-- | State access for a conduit stage running directly over
-- @'StateT' 'SessionState' m@ is lifted past the conduit.
154 instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
157 put = lift . State.put
-- | A 'ConduitParser' consuming elements of type @a@. Its underlying monad
-- threads state @s@ ('StateT') over an environment @r@ ('ReaderT') on top of
-- the base monad @m@.
type ParserStateReader a s r m =
  ConduitParser a (StateT s (ReaderT r m))
-- | Run a 'Session' with the given context and initial state. Returns the
-- session's result together with the final 'SessionState'.
--
-- Messages flow from the context's 'messageChan' through 'watchdog' (timeout
-- handling) and 'updateStateC' (state bookkeeping) into the session's parser.
161 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
162 runSession context state session = runReaderT (runStateT conduit state) context
164 conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
-- A parser that found no acceptable input is reported as an 'UnexpectedMessage'
-- naming the parser and the last message received.
-- NOTE(review): 'fromJust' fails if no message has been received yet.
166 handler (Unexpected "ConduitParser.empty") = do
167 lastMsg <- fromJust . lastReceivedMessage <$> get
168 name <- getParserName
169 liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
-- Blocks until the next item is available on 'messageChan'.
174 msg <- liftIO $ readChan (messageChan context)
-- | Forward server messages downstream unchanged.
--
-- A 'TimeoutMessage' whose id equals the current 'curTimeoutId' aborts the
-- session with 'Timeout'. Timeouts carrying any other id are stale and are
-- dropped.
--
-- NOTE(review): 'throw' is the pure variant. @liftIO (throwIO Timeout)@
-- would make the point at which the exception is raised explicit.
178 watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
179 watchdog = Conduit.awaitForever $ \msg -> do
180 curId <- curTimeoutId <$> get
182 ServerMessage sMsg -> yield sMsg
183 TimeoutMessage tId -> when (curId == tId) $ throw Timeout
185 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
186 -- It also does not automatically send initialize and exit messages.
--
-- The server listener is run on a separate thread. It receives the server's
-- output handle and the session context.
187 runSessionWithHandles :: Handle -- ^ Server in
188 -> Handle -- ^ Server out
189 -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
191 -> ClientCapabilities -- ^ Stored in the context as 'sessionCapabilities'
192 -> FilePath -- ^ Root directory
195 runSessionWithHandles serverIn serverOut serverHandler config caps rootDir session = do
196 absRootDir <- canonicalizePath rootDir
-- Disable buffering on both handles so messages are not held back in a buffer.
198 hSetBuffering serverIn NoBuffering
199 hSetBuffering serverOut NoBuffering
201 reqMap <- newMVar newRequestMap
202 messageChan <- newChan
203 initRsp <- newEmptyMVar
205 let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
-- Initial state: no timeout override active and no message received yet.
206 initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
-- NOTE(review): the listener is started with a bare 'forkIO', so an exception
-- in it only kills that thread. Confirm the thread is stopped (e.g. with
-- 'killThread' under 'finally'/'bracket') even when the session throws.
208 threadId <- forkIO $ void $ serverHandler serverOut context
209 (result, _) <- runSession context initState session
-- | Conduit stage over 'FromServerMessage's that keeps the 'SessionState' up
-- to date. Each message is presumably passed to 'updateState' and re-yielded
-- downstream; those lines are outside this excerpt, so confirm.
215 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
216 updateStateC = awaitForever $ \msg -> do
-- | Update the session state in response to a message from the server.
--
-- * A diagnostics notification replaces the stored diagnostics for that
--   document.
-- * A workspace-edit request does four things:
--     1. opens any affected document that is not yet in the VFS;
--     2. applies the edit to the VFS;
--     3. sends a @textDocument/didChange@ notification for each group of
--        changes;
--     4. records the bumped document versions.
-- * Every other message is ignored.
220 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
221 updateState (NotPublishDiagnostics n) = do
222 let List diags = n ^. params . diagnostics
223 doc = n ^. params . uri
225 let newDiags = Map.insert doc diags (curDiagnostics s)
226 in s { curDiagnostics = newDiags })
228 updateState (ReqApplyWorkspaceEdit r) = do
230 allChangeParams <- case r ^. params . edit . documentChanges of
232 mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
233 return $ map getParams cs
234 Nothing -> case r ^. params . edit . changes of
236 mapM_ checkIfNeedsOpened (HashMap.keys cs)
237 return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
-- NOTE(review): an edit with neither documentChanges nor changes crashes the
-- client via 'error'. Consider throwing a typed session exception instead.
238 Nothing -> error "No changes!"
241 newVFS <- liftIO $ changeFromServerVFS (vfs s) r
242 return $ s { vfs = newVFS }
-- NOTE(review): 'groupBy' merges only *adjacent* params, and it compares the
-- whole 'VersionedTextDocumentIdentifier' (URI and version). Edits built from
-- 'changes' get distinct versions (0, 1, ...) from 'textDocumentVersions', so
-- each one lands in its own group and is sent as a separate didChange.
-- Confirm this is intended; grouping by URI alone may have been meant.
244 let groupedParams = groupBy (\a b -> (a ^. textDocument == b ^. textDocument)) allChangeParams
245 mergedParams = map mergeParams groupedParams
247 -- TODO: Don't do this when replaying a session
248 forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
250 -- Update VFS to new document versions
-- NOTE(review): every member of a group shares the same identifier (the
-- grouping above compares identifiers), so this sort is a no-op and 'last'
-- simply picks that identifier. 'last' is safe here because 'groupBy' never
-- produces an empty group.
251 let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
252 latestVersions = map ((^. textDocument) . last) sortedVersions
253 bumpedVersions = map (version . _Just +~ 1) latestVersions
255 forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
258 update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
259 newVFS = Map.adjust update uri oldVFS
260 in s { vfs = newVFS }
262 where checkIfNeedsOpened uri = do
263 oldVFS <- vfs <$> get
266 -- if it's not open, open it
267 unless (uri `Map.member` oldVFS) $ do
-- NOTE(review): 'fromJust' crashes for URIs that do not map to a file path.
268 let fp = fromJust $ uriToFilePath uri
269 contents <- liftIO $ T.readFile fp
-- Opened with an empty language id and version 0.
270 let item = TextDocumentItem (filePathToUri fp) "" 0 contents
271 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
-- Written straight to the server handle, unlike the didChange notifications
-- above, which go through 'sendMessage'.
272 liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
275 newVFS <- liftIO $ openVFS (vfs s) msg
276 return $ s { vfs = newVFS }
-- Turn each text edit into a ranged content-change event for the same document.
278 getParams (TextDocumentEdit docId (List edits)) =
279 let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
280 in DidChangeTextDocumentParams docId (List changeEvents)
-- Infinite supply of identifiers for the given URI, with versions 0, 1, 2, ...
282 textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
284 textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
-- Each edit in a 'changes' entry becomes its own single-edit change. The
-- edits are taken in reverse order and numbered with increasing versions.
286 getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
-- Concatenate the content changes of a group into one notification.
-- 'head' is safe: it is only applied to groups produced by 'groupBy', which
-- are never empty. Note that the argument name shadows the 'params' lens.
288 mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
289 mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
290 in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
291 updateState _ = return ()
-- | Encode a message as JSON, prefix it using 'addHeader', and write it to
-- the server's input handle taken from the 'SessionContext'.
293 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
295 h <- serverIn <$> ask
297 liftIO $ B.hPut h (addHeader $ encode msg)
299 -- | Execute a block f that will throw a 'Timeout' exception
300 -- after duration seconds. This will override the global timeout
301 -- for waiting for messages to arrive defined in 'SessionConfig'.
302 withTimeout :: Int -> Session a -> Session a
303 withTimeout duration f = do
304 chan <- asks messageChan
-- The id identifying this timeout; 'watchdog' only honours a 'TimeoutMessage'
-- that carries the current id.
305 timeoutId <- curTimeoutId <$> get
306 modify $ \s -> s { overridingTimeout = True }
-- After @duration@ seconds, post a 'TimeoutMessage' tagged with this timeout's id.
308 threadDelay (duration * 1000000)
309 writeChan chan (TimeoutMessage timeoutId)
-- Invalidate this timeout by advancing to the next id. If the pending
-- 'TimeoutMessage' arrives late, 'watchdog' ignores it.
311 modify $ \s -> s { curTimeoutId = timeoutId + 1,
312 overridingTimeout = False
-- | Direction of a logged message: received from the server ('LogServer') or
-- sent by the client ('LogClient').
316 data LogMsgType = LogServer | LogClient
319 -- | Pretty-print the message to stdout when 'logMessages' is enabled in the config.
-- The output is prefixed with an arrow showing its direction, and coloured
-- when 'logColor' is enabled.
320 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
321 => LogMsgType -> a -> m ()
323 shouldLog <- asks $ logMessages . config
324 shouldColor <- asks $ logColor . config
325 liftIO $ when shouldLog $ do
326 when shouldColor $ setSGR [SetColor Foreground Dull color]
327 putStrLn $ arrow ++ showPretty msg
-- Reset the terminal colour so later output is unaffected.
328 when shouldColor $ setSGR [Reset]
331 | t == LogServer = "<-- "
334 | t == LogServer = Magenta
-- NOTE(review): 'B.unpack' comes from the Char8 module and maps each byte to
-- one 'Char'. Non-ASCII text in the UTF-8 JSON is therefore printed as mojibake.
337 showPretty = B.unpack . encodePretty