Handle [un]registerCapability and workspace/didChangeWatchedFiles
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
8
9 module Language.Haskell.LSP.Test.Session
10   ( Session(..)
11   , SessionConfig(..)
12   , defaultConfig
13   , SessionMessage(..)
14   , SessionContext(..)
15   , SessionState(..)
16   , runSessionWithHandles
17   , get
18   , put
19   , modify
20   , modifyM
21   , ask
22   , asks
23   , sendMessage
24   , updateState
25   , withTimeout
26   , getCurTimeoutId
27   , bumpTimeoutId
28   , logMsg
29   , LogMsgType(..)
30   )
31
32 where
33
34 import Control.Applicative
35 import Control.Concurrent hiding (yield)
36 import Control.Exception
37 import Control.Lens hiding (List)
38 import Control.Monad
39 import Control.Monad.IO.Class
40 import Control.Monad.Except
41 #if __GLASGOW_HASKELL__ == 806
42 import Control.Monad.Fail
43 #endif
44 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
45 import qualified Control.Monad.Trans.Reader as Reader (ask)
46 import Control.Monad.Trans.State (StateT, runStateT)
47 import qualified Control.Monad.Trans.State as State
48 import qualified Data.ByteString.Lazy.Char8 as B
49 import Data.Aeson
50 import Data.Aeson.Encode.Pretty
51 import Data.Conduit as Conduit
52 import Data.Conduit.Parser as Parser
53 import Data.Default
54 import Data.Foldable
55 import Data.List
56 import qualified Data.Map as Map
57 import qualified Data.Text as T
58 import qualified Data.Text.IO as T
59 import qualified Data.HashMap.Strict as HashMap
60 import Data.Maybe
61 import Data.Function
62 import Language.Haskell.LSP.Messages
63 import Language.Haskell.LSP.Types.Capabilities
64 import Language.Haskell.LSP.Types
65 import Language.Haskell.LSP.Types.Lens
66 import qualified Language.Haskell.LSP.Types.Lens as LSP
67 import Language.Haskell.LSP.VFS
68 import Language.Haskell.LSP.Test.Compat
69 import Language.Haskell.LSP.Test.Decoding
70 import Language.Haskell.LSP.Test.Exceptions
71 import System.Console.ANSI
72 import System.Directory
73 import System.IO
74 import System.Process (ProcessHandle())
75 import System.Timeout
76
77 -- | A session representing one instance of launching and connecting to a server.
78 --
79 -- You can send and receive messages to the server within 'Session' via
80 -- 'Language.Haskell.LSP.Test.message',
81 -- 'Language.Haskell.LSP.Test.sendRequest' and
82 -- 'Language.Haskell.LSP.Test.sendNotification'.
83
84 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
85   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
86
87 #if __GLASGOW_HASKELL__ >= 806
88 instance MonadFail Session where
89   fail s = do
90     lastMsg <- fromJust . lastReceivedMessage <$> get
91     liftIO $ throw (UnexpectedMessage s lastMsg)
92 #endif
93
94 -- | Stuff you can configure for a 'Session'.
95 data SessionConfig = SessionConfig
96   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
97   , logStdErr      :: Bool
98   -- ^ Redirect the server's stderr to this stdout, defaults to False.
99   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
100   , logMessages    :: Bool
101   -- ^ Trace the messages sent and received to stdout, defaults to False.
102   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
103   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
104   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
105   , ignoreLogNotifications :: Bool
106   -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
107   -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
108   --
109   -- @since 0.9.0.0
110   }
111
112 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
113 defaultConfig :: SessionConfig
114 defaultConfig = SessionConfig 60 False False True Nothing False
115
116 instance Default SessionConfig where
117   def = defaultConfig
118
119 data SessionMessage = ServerMessage FromServerMessage
120                     | TimeoutMessage Int
121   deriving Show
122
123 data SessionContext = SessionContext
124   {
125     serverIn :: Handle
126   , rootDir :: FilePath
127   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
128   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
129   , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
130   , requestMap :: MVar RequestMap
131   , initRsp :: MVar InitializeResponse
132   , config :: SessionConfig
133   , sessionCapabilities :: ClientCapabilities
134   }
135
136 class Monad m => HasReader r m where
137   ask :: m r
138   asks :: (r -> b) -> m b
139   asks f = f <$> ask
140
141 instance HasReader SessionContext Session where
142   ask  = Session (lift $ lift Reader.ask)
143
144 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
145   ask = lift $ lift Reader.ask
146
147 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
148 getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
149
150 -- Pass this the timeoutid you *were* waiting on
151 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
152 bumpTimeoutId prev = do
153   v <- asks curTimeoutId
154   -- when updating the curtimeoutid, account for the fact that something else
155   -- might have bumped the timeoutid in the meantime
156   liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
157
158 data SessionState = SessionState
159   {
160     curReqId :: LspId
161   , vfs :: VFS
162   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
163   , overridingTimeout :: Bool
164   -- ^ The last received message from the server.
165   -- Used for providing exception information
166   , lastReceivedMessage :: Maybe FromServerMessage
167   , curDynCaps :: Map.Map T.Text Registration
168   -- ^ The capabilities that the server has dynamically registered with us so
169   -- far
170   }
171
172 class Monad m => HasState s m where
173   get :: m s
174
175   put :: s -> m ()
176
177   modify :: (s -> s) -> m ()
178   modify f = get >>= put . f
179
180   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
181   modifyM f = get >>= f >>= put
182
183 instance HasState SessionState Session where
184   get = Session (lift State.get)
185   put = Session . lift . State.put
186
187 instance Monad m => HasState s (StateT s m) where
188   get = State.get
189   put = State.put
190
191 instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
192  where
193   get = lift get
194   put = lift . put
195
196 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
197  where
198   get = lift get
199   put = lift . put
200
201 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
202 runSession context state (Session session) = runReaderT (runStateT conduit state) context
203   where
204     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
205
206     handler (Unexpected "ConduitParser.empty") = do
207       lastMsg <- fromJust . lastReceivedMessage <$> get
208       name <- getParserName
209       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
210
211     handler e = throw e
212
213     chanSource = do
214       msg <- liftIO $ readChan (messageChan context)
215       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
216         yield msg
217       chanSource
218
219     isLogNotification (ServerMessage (NotShowMessage _)) = True
220     isLogNotification (ServerMessage (NotLogMessage _)) = True
221     isLogNotification _ = False
222
223     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
224     watchdog = Conduit.awaitForever $ \msg -> do
225       curId <- getCurTimeoutId
226       case msg of
227         ServerMessage sMsg -> yield sMsg
228         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
229
230 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
231 -- It also does not automatically send initialize and exit messages.
232 runSessionWithHandles :: Handle -- ^ Server in
233                       -> Handle -- ^ Server out
234                       -> ProcessHandle -- ^ Server process
235                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
236                       -> SessionConfig
237                       -> ClientCapabilities
238                       -> FilePath -- ^ Root directory
239                       -> Session () -- ^ To exit the Server properly
240                       -> Session a
241                       -> IO a
242 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
243   absRootDir <- canonicalizePath rootDir
244
245   hSetBuffering serverIn  NoBuffering
246   hSetBuffering serverOut NoBuffering
247   -- This is required to make sure that we don’t get any
248   -- newline conversion or weird encoding issues.
249   hSetBinaryMode serverIn True
250   hSetBinaryMode serverOut True
251
252   reqMap <- newMVar newRequestMap
253   messageChan <- newChan
254   timeoutIdVar <- newMVar 0
255   initRsp <- newEmptyMVar
256
257   mainThreadId <- myThreadId
258
259   let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
260       initState vfs = SessionState (IdInt 0) vfs mempty False Nothing mempty
261       runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
262
263       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
264       serverListenerLauncher =
265         forkIO $ catch (serverHandler serverOut context) errorHandler
266       server = (Just serverIn, Just serverOut, Nothing, serverProc)
267       serverAndListenerFinalizer tid = do
268         finally (timeout (messageTimeout config * 1^6)
269                          (runSession' exitServer))
270                 (cleanupProcess server >> killThread tid)
271
272   (result, _) <- bracket serverListenerLauncher
273                          serverAndListenerFinalizer
274                          (const $ runSession' session)
275   return result
276
277 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
278 updateStateC = awaitForever $ \msg -> do
279   updateState msg
280   yield msg
281
282 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
283             => FromServerMessage -> m ()
284
285 -- Keep track of dynamic capability registration
286 updateState (ReqRegisterCapability req) = do
287   let List newRegs = (\r -> (r ^. LSP.id, r)) <$> req ^. params . registrations
288   modify $ \s ->
289     s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
290
291 updateState (ReqUnregisterCapability req) = do
292   let List unRegs = (^. LSP.id) <$> req ^. params . unregistrations
293   modify $ \s ->
294     let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
295     in s { curDynCaps = newCurDynCaps }
296
297 updateState (NotPublishDiagnostics n) = do
298   let List diags = n ^. params . diagnostics
299       doc = n ^. params . uri
300   modify $ \s ->
301     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
302       in s { curDiagnostics = newDiags }
303
304 updateState (ReqApplyWorkspaceEdit r) = do
305
306   allChangeParams <- case r ^. params . edit . documentChanges of
307     Just (List cs) -> do
308       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
309       return $ map getParams cs
310     Nothing -> case r ^. params . edit . changes of
311       Just cs -> do
312         mapM_ checkIfNeedsOpened (HashMap.keys cs)
313         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
314       Nothing -> error "No changes!"
315
316   modifyM $ \s -> do
317     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
318     return $ s { vfs = newVFS }
319
320   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
321       mergedParams = map mergeParams groupedParams
322
323   -- TODO: Don't do this when replaying a session
324   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
325
326   -- Update VFS to new document versions
327   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
328       latestVersions = map ((^. textDocument) . last) sortedVersions
329       bumpedVersions = map (version . _Just +~ 1) latestVersions
330
331   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
332     modify $ \s ->
333       let oldVFS = vfs s
334           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
335           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
336       in s { vfs = newVFS }
337
338   where checkIfNeedsOpened uri = do
339           oldVFS <- vfs <$> get
340           ctx <- ask
341
342           -- if its not open, open it
343           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
344             let fp = fromJust $ uriToFilePath uri
345             contents <- liftIO $ T.readFile fp
346             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
347                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
348             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
349
350             modifyM $ \s -> do
351               let (newVFS,_) = openVFS (vfs s) msg
352               return $ s { vfs = newVFS }
353
354         getParams (TextDocumentEdit docId (List edits)) =
355           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
356             in DidChangeTextDocumentParams docId (List changeEvents)
357
358         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
359
360         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
361
362         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
363
364         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
365         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
366                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
367 updateState _ = return ()
368
369 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
370 sendMessage msg = do
371   h <- serverIn <$> ask
372   logMsg LogClient msg
373   liftIO $ B.hPut h (addHeader $ encode msg)
374
375 -- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception
376 -- after duration seconds. This will override the global timeout
377 -- for waiting for messages to arrive defined in 'SessionConfig'.
378 withTimeout :: Int -> Session a -> Session a
379 withTimeout duration f = do
380   chan <- asks messageChan
381   timeoutId <- getCurTimeoutId
382   modify $ \s -> s { overridingTimeout = True }
383   liftIO $ forkIO $ do
384     threadDelay (duration * 1000000)
385     writeChan chan (TimeoutMessage timeoutId)
386   res <- f
387   bumpTimeoutId timeoutId
388   modify $ \s -> s { overridingTimeout = False }
389   return res
390
391 data LogMsgType = LogServer | LogClient
392   deriving Eq
393
394 -- | Logs the message if the config specified it
395 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
396        => LogMsgType -> a -> m ()
397 logMsg t msg = do
398   shouldLog <- asks $ logMessages . config
399   shouldColor <- asks $ logColor . config
400   liftIO $ when shouldLog $ do
401     when shouldColor $ setSGR [SetColor Foreground Dull color]
402     putStrLn $ arrow ++ showPretty msg
403     when shouldColor $ setSGR [Reset]
404
405   where arrow
406           | t == LogServer  = "<-- "
407           | otherwise       = "--> "
408         color
409           | t == LogServer  = Magenta
410           | otherwise       = Cyan
411
412         showPretty = B.unpack . encodePretty