Watch files to send didChangeWatchedFiles notifications
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
1 {-# LANGUAGE CPP               #-}
2 {-# LANGUAGE OverloadedStrings #-}
3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
4 {-# LANGUAGE FlexibleInstances #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 {-# LANGUAGE FlexibleContexts #-}
7 {-# LANGUAGE RankNTypes #-}
8
9 module Language.Haskell.LSP.Test.Session
10   ( Session(..)
11   , SessionConfig(..)
12   , defaultConfig
13   , SessionMessage(..)
14   , SessionContext(..)
15   , SessionState(..)
16   , runSessionWithHandles
17   , get
18   , put
19   , modify
20   , modifyM
21   , ask
22   , asks
23   , sendMessage
24   , updateState
25   , withTimeout
26   , getCurTimeoutId
27   , bumpTimeoutId
28   , logMsg
29   , LogMsgType(..)
30   )
31
32 where
33
34 import Control.Applicative
35 import Control.Concurrent hiding (yield)
36 import Control.Exception
37 import Control.Lens hiding (List)
38 import Control.Monad
39 import Control.Monad.IO.Class
40 import Control.Monad.Except
41 #if __GLASGOW_HASKELL__ == 806
42 import Control.Monad.Fail
43 #endif
44 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
45 import qualified Control.Monad.Trans.Reader as Reader (ask)
46 import Control.Monad.Trans.State (StateT, runStateT)
47 import qualified Control.Monad.Trans.State as State
48 import qualified Data.ByteString.Lazy.Char8 as B
49 import Data.Aeson
50 import Data.Aeson.Encode.Pretty
51 import Data.Conduit as Conduit
52 import Data.Conduit.Parser as Parser
53 import Data.Default
54 import Data.Foldable
55 import Data.List
56 import qualified Data.Map as Map
57 import qualified Data.Text as T
58 import qualified Data.Text.IO as T
59 import qualified Data.HashMap.Strict as HashMap
60 import Data.Maybe
61 import Data.Function
62 import Language.Haskell.LSP.Messages
63 import Language.Haskell.LSP.Types.Capabilities
64 import Language.Haskell.LSP.Types
65 import Language.Haskell.LSP.Types.Lens
66 import qualified Language.Haskell.LSP.Types.Lens as LSP
67 import Language.Haskell.LSP.VFS
68 import Language.Haskell.LSP.Test.Compat
69 import Language.Haskell.LSP.Test.Decoding
70 import Language.Haskell.LSP.Test.Exceptions
71 import System.Console.ANSI
72 import System.Directory
73 import System.FSNotify (watchTree, eventPath, withManager, WatchManager)
74 import qualified System.FSNotify as FS
75 import System.IO
76 import System.Process (ProcessHandle())
77 import System.Timeout
78 import System.FilePath.Glob (match, commonDirectory, compile)
79
80 -- | A session representing one instance of launching and connecting to a server.
81 --
82 -- You can send and receive messages to the server within 'Session' via
83 -- 'Language.Haskell.LSP.Test.message',
84 -- 'Language.Haskell.LSP.Test.sendRequest' and
85 -- 'Language.Haskell.LSP.Test.sendNotification'.
86
87 newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
88   deriving (Functor, Applicative, Monad, MonadIO, Alternative)
89
90 #if __GLASGOW_HASKELL__ >= 806
91 instance MonadFail Session where
92   fail s = do
93     lastMsg <- fromJust . lastReceivedMessage <$> get
94     liftIO $ throw (UnexpectedMessage s lastMsg)
95 #endif
96
97 -- | Stuff you can configure for a 'Session'.
98 data SessionConfig = SessionConfig
99   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
100   , logStdErr      :: Bool
101   -- ^ Redirect the server's stderr to this stdout, defaults to False.
102   -- Can be overriden with @LSP_TEST_LOG_STDERR@.
103   , logMessages    :: Bool
104   -- ^ Trace the messages sent and received to stdout, defaults to False.
105   -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
106   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
107   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
108   , ignoreLogNotifications :: Bool
109   -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
110   -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
111   --
112   -- @since 0.9.0.0
113   }
114
115 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
116 defaultConfig :: SessionConfig
117 defaultConfig = SessionConfig 60 False False True Nothing False
118
119 instance Default SessionConfig where
120   def = defaultConfig
121
122 data SessionMessage = ServerMessage FromServerMessage
123                     | TimeoutMessage Int
124   deriving Show
125
126 data SessionContext = SessionContext
127   {
128     serverIn :: Handle
129   , rootDir :: FilePath
130   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
131   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
132   , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
133   , requestMap :: MVar RequestMap
134   , initRsp :: MVar InitializeResponse
135   , config :: SessionConfig
136   , sessionCapabilities :: ClientCapabilities
137   , watchManager :: WatchManager
138   }
139
140 class Monad m => HasReader r m where
141   ask :: m r
142   asks :: (r -> b) -> m b
143   asks f = f <$> ask
144
145 instance HasReader SessionContext Session where
146   ask  = Session (lift $ lift Reader.ask)
147
148 instance Monad m => HasReader r (ConduitT a b (StateT s (ReaderT r m))) where
149   ask = lift $ lift Reader.ask
150
151 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
152 getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
153
154 -- Pass this the timeoutid you *were* waiting on
155 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
156 bumpTimeoutId prev = do
157   v <- asks curTimeoutId
158   -- when updating the curtimeoutid, account for the fact that something else
159   -- might have bumped the timeoutid in the meantime
160   liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
161
162 data SessionState = SessionState
163   {
164     curReqId :: LspId
165   , vfs :: VFS
166   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
167   , overridingTimeout :: Bool
168   -- ^ The last received message from the server.
169   -- Used for providing exception information
170   , lastReceivedMessage :: Maybe FromServerMessage
171   , curDynCaps :: Map.Map T.Text Registration
172   -- ^ The capabilities that the server has dynamically registered with us so
173   -- far
174   , unwatchers :: Map.Map T.Text [IO ()]
175   }
176
177 class Monad m => HasState s m where
178   get :: m s
179
180   put :: s -> m ()
181
182   modify :: (s -> s) -> m ()
183   modify f = get >>= put . f
184
185   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
186   modifyM f = get >>= f >>= put
187
188 instance HasState SessionState Session where
189   get = Session (lift State.get)
190   put = Session . lift . State.put
191
192 instance Monad m => HasState s (StateT s m) where
193   get = State.get
194   put = State.put
195
196 instance (Monad m, (HasState s m)) => HasState s (ConduitT a b m)
197  where
198   get = lift get
199   put = lift . put
200
201 instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
202  where
203   get = lift get
204   put = lift . put
205
206 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
207 runSession context state (Session session) = runReaderT (runStateT conduit state) context
208   where
209     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
210
211     handler (Unexpected "ConduitParser.empty") = do
212       lastMsg <- fromJust . lastReceivedMessage <$> get
213       name <- getParserName
214       liftIO $ throw (UnexpectedMessage (T.unpack name) lastMsg)
215
216     handler e = throw e
217
218     chanSource = do
219       msg <- liftIO $ readChan (messageChan context)
220       unless (ignoreLogNotifications (config context) && isLogNotification msg) $
221         yield msg
222       chanSource
223
224     isLogNotification (ServerMessage (NotShowMessage _)) = True
225     isLogNotification (ServerMessage (NotLogMessage _)) = True
226     isLogNotification _ = False
227
228     watchdog :: ConduitT SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
229     watchdog = Conduit.awaitForever $ \msg -> do
230       curId <- getCurTimeoutId
231       case msg of
232         ServerMessage sMsg -> yield sMsg
233         TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
234
235 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
236 -- It also does not automatically send initialize and exit messages.
237 runSessionWithHandles :: Handle -- ^ Server in
238                       -> Handle -- ^ Server out
239                       -> ProcessHandle -- ^ Server process
240                       -> (Handle -> SessionContext -> IO ()) -- ^ Server listener
241                       -> SessionConfig
242                       -> ClientCapabilities
243                       -> FilePath -- ^ Root directory
244                       -> Session () -- ^ To exit the Server properly
245                       -> Session a
246                       -> IO a
247 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
248   absRootDir <- canonicalizePath rootDir
249
250   hSetBuffering serverIn  NoBuffering
251   hSetBuffering serverOut NoBuffering
252   -- This is required to make sure that we don’t get any
253   -- newline conversion or weird encoding issues.
254   hSetBinaryMode serverIn True
255   hSetBinaryMode serverOut True
256
257   reqMap <- newMVar newRequestMap
258   messageChan <- newChan
259   timeoutIdVar <- newMVar 0
260   initRsp <- newEmptyMVar
261
262   mainThreadId <- myThreadId
263
264   withManager $ \watchManager -> do
265     let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps watchManager
266         initState vfs = SessionState (IdInt 0) vfs mempty False Nothing mempty mempty
267         -- Interesting note: turning on TypeFamilies causes type inference to
268         -- infer the type runSession' :: Session () -> IO ((), SessionState)
269         -- instead of     runSession' :: Session a  -> IO (a , SessionState)
270         runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
271
272         errorHandler = throwTo mainThreadId :: SessionException -> IO ()
273         serverListenerLauncher =
274           forkIO $ catch (serverHandler serverOut context) errorHandler
275         server = (Just serverIn, Just serverOut, Nothing, serverProc)
276         serverAndListenerFinalizer tid = do
277           finally (timeout (messageTimeout config * 1^6)
278                           (runSession' exitServer))
279                   -- Make sure to kill the listener first, before closing
280                   -- handles etc via cleanupProcess
281                   (killThread tid >> cleanupProcess server)
282
283     (result, _) <- bracket serverListenerLauncher
284                           serverAndListenerFinalizer
285                           (const $ runSession' session)
286     return result
287
288 updateStateC :: ConduitT FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
289 updateStateC = awaitForever $ \msg -> do
290   updateState msg
291   yield msg
292
293 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
294             => FromServerMessage -> m ()
295
296 -- Keep track of dynamic capability registration
297 updateState (ReqRegisterCapability req) = do
298   let List newRegs = (\r -> (r ^. LSP.id, r)) <$> req ^. params . registrations
299   modify $ \s ->
300     s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
301
302   -- Process the new registrations
303   forM_ newRegs $ \(regId, reg) -> do
304     when (reg ^. method == WorkspaceDidChangeWatchedFiles) $ do
305       processFileWatchRegistration regId reg
306
307 updateState (ReqUnregisterCapability req) = do
308   let List unRegs = (^. LSP.id) <$> req ^. params . unregistrations
309   modify $ \s ->
310     let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
311     in s { curDynCaps = newCurDynCaps }
312
313   -- Process the unregistrations
314   processFileWatchUnregistrations unRegs
315
316 updateState (NotPublishDiagnostics n) = do
317   let List diags = n ^. params . diagnostics
318       doc = n ^. params . uri
319   modify $ \s ->
320     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
321       in s { curDiagnostics = newDiags }
322
323 updateState (ReqApplyWorkspaceEdit r) = do
324
325   allChangeParams <- case r ^. params . edit . documentChanges of
326     Just (List cs) -> do
327       mapM_ (checkIfNeedsOpened . (^. textDocument . uri)) cs
328       return $ map getParams cs
329     Nothing -> case r ^. params . edit . changes of
330       Just cs -> do
331         mapM_ checkIfNeedsOpened (HashMap.keys cs)
332         return $ concatMap (uncurry getChangeParams) (HashMap.toList cs)
333       Nothing -> error "No changes!"
334
335   modifyM $ \s -> do
336     newVFS <- liftIO $ changeFromServerVFS (vfs s) r
337     return $ s { vfs = newVFS }
338
339   let groupedParams = groupBy (\a b -> a ^. textDocument == b ^. textDocument) allChangeParams
340       mergedParams = map mergeParams groupedParams
341
342   -- TODO: Don't do this when replaying a session
343   forM_ mergedParams (sendMessage . NotificationMessage "2.0" TextDocumentDidChange)
344
345   -- Update VFS to new document versions
346   let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
347       latestVersions = map ((^. textDocument) . last) sortedVersions
348       bumpedVersions = map (version . _Just +~ 1) latestVersions
349
350   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
351     modify $ \s ->
352       let oldVFS = vfs s
353           update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
354           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
355       in s { vfs = newVFS }
356
357   where checkIfNeedsOpened uri = do
358           oldVFS <- vfs <$> get
359           ctx <- ask
360
361           -- if its not open, open it
362           unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
363             let fp = fromJust $ uriToFilePath uri
364             contents <- liftIO $ T.readFile fp
365             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
366                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
367             -- TODO: use 'sendMessage'?
368             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
369
370             modifyM $ \s -> do
371               let (newVFS,_) = openVFS (vfs s) msg
372               return $ s { vfs = newVFS }
373
374         getParams (TextDocumentEdit docId (List edits)) =
375           let changeEvents = map (\e -> TextDocumentContentChangeEvent (Just (e ^. range)) Nothing (e ^. newText)) edits
376             in DidChangeTextDocumentParams docId (List changeEvents)
377
378         textDocumentVersions uri = map (VersionedTextDocumentIdentifier uri . Just) [0..]
379
380         textDocumentEdits uri edits = map (\(v, e) -> TextDocumentEdit v (List [e])) $ zip (textDocumentVersions uri) edits
381
382         getChangeParams uri (List edits) = map getParams (textDocumentEdits uri (reverse edits))
383
384         mergeParams :: [DidChangeTextDocumentParams] -> DidChangeTextDocumentParams
385         mergeParams params = let events = concat (toList (map (toList . (^. contentChanges)) params))
386                               in DidChangeTextDocumentParams (head params ^. textDocument) (List events)
387 updateState _ = return ()
388
389 sendMessage :: (MonadIO m, HasReader SessionContext m, ToJSON a) => a -> m ()
390 sendMessage msg = do
391   h <- serverIn <$> ask
392   logMsg LogClient msg
393   liftIO $ B.hPut h (addHeader $ encode msg)
394
395 -- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception
396 -- after duration seconds. This will override the global timeout
397 -- for waiting for messages to arrive defined in 'SessionConfig'.
398 withTimeout :: Int -> Session a -> Session a
399 withTimeout duration f = do
400   chan <- asks messageChan
401   timeoutId <- getCurTimeoutId
402   modify $ \s -> s { overridingTimeout = True }
403   liftIO $ forkIO $ do
404     threadDelay (duration * 1000000)
405     writeChan chan (TimeoutMessage timeoutId)
406   res <- f
407   bumpTimeoutId timeoutId
408   modify $ \s -> s { overridingTimeout = False }
409   return res
410
411 -- TODO: add a shouldTimeout helper. need to add exceptions within Session
412 data LogMsgType = LogServer | LogClient
413   deriving Eq
414
415 -- | Logs the message if the config specified it
416 logMsg :: (ToJSON a, MonadIO m, HasReader SessionContext m)
417        => LogMsgType -> a -> m ()
418 logMsg t msg = do
419   shouldLog <- asks $ logMessages . config
420   shouldColor <- asks $ logColor . config
421   liftIO $ when shouldLog $ do
422     when shouldColor $ setSGR [SetColor Foreground Dull color]
423     putStrLn $ arrow ++ showPretty msg
424     when shouldColor $ setSGR [Reset]
425
426   where arrow
427           | t == LogServer  = "<-- "
428           | otherwise       = "--> "
429         color
430           | t == LogServer  = Magenta
431           | otherwise       = Cyan
432
433         showPretty = B.unpack . encodePretty
434
435 -- File watching
436
437 processFileWatchRegistration :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
438                              => T.Text -> Registration -> m ()
439 processFileWatchRegistration regId reg = do
440   mgr <- asks watchManager
441   let mOpts = do
442         regOpts <- reg ^. registerOptions
443         case fromJSON regOpts  of
444           Error _ -> Nothing
445           Success x -> Just x
446   case mOpts of
447     Nothing -> pure ()
448     Just (DidChangeWatchedFilesRegistrationOptions (List ws)) ->
449       forM_ ws $ \(FileSystemWatcher pat' watchKind) -> do
450         pat <- liftIO $ canonicalizePath pat'
451         let glob = compile pat
452             -- the root-most dir before any globbing stuff happens
453             dir = fst $ commonDirectory glob
454             pred = match glob . eventPath
455             -- If no watchKind specified, spec defaults to all true
456             WatchKind wkC wkM wkD = fromMaybe (WatchKind True True True) watchKind
457         handle <- asks serverIn
458         unwatch <- liftIO $ watchTree mgr dir pred $ \event -> do
459           let fe = FileEvent (filePathToUri (eventPath event)) typ
460               typ = case event of
461                 FS.Added _ _ _ -> FcCreated
462                 FS.Modified _ _ _ -> FcChanged
463                 FS.Removed _ _ _ -> FcDeleted
464                 -- This is a bit of a guess
465                 FS.Unknown _ _ _ -> FcChanged
466               matches = case typ of
467                 FcCreated -> wkC
468                 FcChanged -> wkM
469                 FcDeleted -> wkD
470               params = DidChangeWatchedFilesParams (List [fe])
471               msg = fmClientDidChangeWatchedFilesNotification params
472           liftIO $ when matches $ B.hPut handle (addHeader $ encode msg)
473         modify $ \s ->
474           s { unwatchers = Map.insertWith (++) regId [unwatch] (unwatchers s) }
475
476 processFileWatchUnregistrations :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
477                                 => [T.Text] -> m ()
478 processFileWatchUnregistrations regIds =
479   forM_ regIds $ \regId -> modifyM $ \s -> do
480     let fs = fromMaybe [] (Map.lookup regId (unwatchers s))
481     liftIO $ sequence fs
482     return $ s { unwatchers = Map.delete regId (unwatchers s) }