Watch files to send didChangeWatchedFiles notifications
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
index bbfdf386ac167bd0f8ab5b9a277b754920aacf57..a82651af11b5617203417505db7bb1268cc3478d 100644 (file)
@@ -1,12 +1,13 @@
 {-# LANGUAGE CPP               #-}
 {-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE RankNTypes #-}
 
 module Language.Haskell.LSP.Test.Session
-  ( Session
+  ( Session(..)
   , SessionConfig(..)
   , defaultConfig
   , SessionMessage(..)
@@ -22,25 +23,28 @@ module Language.Haskell.LSP.Test.Session
   , sendMessage
   , updateState
   , withTimeout
+  , getCurTimeoutId
+  , bumpTimeoutId
   , logMsg
   , LogMsgType(..)
   )
 
 where
 
+import Control.Applicative
 import Control.Concurrent hiding (yield)
 import Control.Exception
 import Control.Lens hiding (List)
 import Control.Monad
 import Control.Monad.IO.Class
 import Control.Monad.Except
-#if __GLASGOW_HASKELL__ >= 806
+#if __GLASGOW_HASKELL__ == 806
 import Control.Monad.Fail
 #endif
 import Control.Monad.Trans.Reader (ReaderT, runReaderT)
 import qualified Control.Monad.Trans.Reader as Reader (ask)
 import Control.Monad.Trans.State (StateT, runStateT)
-import qualified Control.Monad.Trans.State as State (get, put)
+import qualified Control.Monad.Trans.State as State
 import qualified Data.ByteString.Lazy.Char8 as B
 import Data.Aeson
 import Data.Aeson.Encode.Pretty
@@ -58,16 +62,20 @@ import Data.Function
 import Language.Haskell.LSP.Messages
 import Language.Haskell.LSP.Types.Capabilities
 import Language.Haskell.LSP.Types
-import Language.Haskell.LSP.Types.Lens hiding (error)
+import Language.Haskell.LSP.Types.Lens
+import qualified Language.Haskell.LSP.Types.Lens as LSP
 import Language.Haskell.LSP.VFS
 import Language.Haskell.LSP.Test.Compat
 import Language.Haskell.LSP.Test.Decoding
 import Language.Haskell.LSP.Test.Exceptions
 import System.Console.ANSI
 import System.Directory
+import System.FSNotify (watchTree, eventPath, withManager, WatchManager)
+import qualified System.FSNotify as FS
 import System.IO
 import System.Process (ProcessHandle())
 import System.Timeout
+import System.FilePath.Glob (match, commonDirectory, compile)
 
 -- | A session representing one instance of launching and connecting to a server.
 --
@@ -76,7 +84,8 @@ import System.Timeout
 -- 'Language.Haskell.LSP.Test.sendRequest' and
 -- 'Language.Haskell.LSP.Test.sendNotification'.
 
-type Session = ParserStateReader FromServerMessage SessionState SessionContext IO
+newtype Session a = Session (ConduitParser FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) a)
+  deriving (Functor, Applicative, Monad, MonadIO, Alternative)
 
 #if __GLASGOW_HASKELL__ >= 806
 instance MonadFail Session where
@@ -88,15 +97,24 @@ instance MonadFail Session where
 -- | Stuff you can configure for a 'Session'.
 data SessionConfig = SessionConfig
   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
-  , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
-  , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
+  , logStdErr      :: Bool
+  -- ^ Redirect the server's stderr to this stdout, defaults to False.
+  -- Can be overriden with @LSP_TEST_LOG_STDERR@.
+  , logMessages    :: Bool
+  -- ^ Trace the messages sent and received to stdout, defaults to False.
+  -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
+  , ignoreLogNotifications :: Bool
+  -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
+  -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
+  --
+  -- @since 0.9.0.0
   }
 
 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
 defaultConfig :: SessionConfig
-defaultConfig = SessionConfig 60 False False True Nothing
+defaultConfig = SessionConfig 60 False False True Nothing False
 
 instance Default SessionConfig where
   def = defaultConfig
@@ -109,11 +127,14 @@ data SessionContext = SessionContext
   {
     serverIn :: Handle
   , rootDir :: FilePath
-  , messageChan :: Chan SessionMessage
+  , messageChan :: Chan SessionMessage -- ^ Where all messages come through
+  -- Keep curTimeoutId in SessionContext, as its tied to messageChan
+  , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
   , requestMap :: MVar RequestMap
   , initRsp :: MVar InitializeResponse
   , config :: SessionConfig
   , sessionCapabilities :: ClientCapabilities
+  , watchManager :: WatchManager
   }
 
 class Monad m => HasReader r m where
@@ -121,22 +142,36 @@ class Monad m => HasReader r m where
   asks :: (r -> b) -> m b
   asks f = f <$> ask
 
-instance Monad m => HasReader r (ParserStateReader a s r m) where
-  ask = lift $ lift Reader.ask
+instance HasReader SessionContext Session where
+  ask  = Session (lift $ lift Reader.ask)
 
-instance Monad m => HasReader SessionContext (ConduitM a b (StateT s (ReaderT SessionContext m))) where
+instance Monad m => HasReader r (ConduitT a b (StateT s (ReaderT r m))) where
   ask = lift $ lift Reader.ask
 
+getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
+getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
+
+-- Pass this the timeoutid you *were* waiting on
+bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
+bumpTimeoutId prev = do
+  v <- asks curTimeoutId
+  -- when updating the curtimeoutid, account for the fact that something else
+  -- might have bumped the timeoutid in the meantime
+  liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
+
 data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
-  , curTimeoutId :: Int
   , overridingTimeout :: Bool
   -- ^ The last received message from the server.
   -- Used for providing exception information
   , lastReceivedMessage :: Maybe FromServerMessage
+  , curDynCaps :: Map.Map T.Text Registration
+  -- ^ The capabilities that the server has dynamically registered with us so
+  -- far
+  , unwatchers :: Map.Map T.Text [IO ()]
   }
 
 class Monad m => HasState s m where
@@ -150,19 +185,26 @@ class Monad m => HasState s m where
   modifyM :: (HasState s m, Monad m) => (s -> m s) -> m ()
   modifyM f = get >>= f >>= put
 
-instance Monad m => HasState s (ParserStateReader a s r m) where
-  get = lift State.get
-  put = lift . State.put
+instance HasState SessionState Session where
+  get = Session (lift State.get)
+  put = Session . lift . State.put
 
-instance Monad m => HasState SessionState (ConduitM a b (StateT SessionState m))
+instance Monad m => HasState s (StateT s m) where
+  get = State.get
+  put = State.put
+
+instance (Monad m, (HasState s m)) => HasState s (ConduitT a b m)
  where
-  get = lift State.get
-  put = lift . State.put
+  get = lift get
+  put = lift . put
 
-type ParserStateReader a s r m = ConduitParser a (StateT s (ReaderT r m))
+instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
+ where
+  get = lift get
+  put = lift . put
 
 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
-runSession context state session = runReaderT (runStateT conduit state) context
+runSession context state (Session session) = runReaderT (runStateT conduit state) context
   where
     conduit = runConduit $ chanSource .| watchdog .| updateStateC .| runConduitParser (catchError session handler)
 
@@ -175,15 +217,20 @@ runSession context state session = runReaderT (runStateT conduit state) context
 
     chanSource = do
       msg <- liftIO $ readChan (messageChan context)
+      unless (ignoreLogNotifications (config context) && isLogNotification msg) $
         yield msg
       chanSource
 
-    watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
+    isLogNotification (ServerMessage (NotShowMessage _)) = True
+    isLogNotification (ServerMessage (NotLogMessage _)) = True
+    isLogNotification _ = False
+
+    watchdog :: ConduitT SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
     watchdog = Conduit.awaitForever $ \msg -> do
-      curId <- curTimeoutId <$> get
+      curId <- getCurTimeoutId
       case msg of
         ServerMessage sMsg -> yield sMsg
-        TimeoutMessage tId -> when (curId == tId) $ throw Timeout
+        TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
 
 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
 -- It also does not automatically send initialize and exit messages.
@@ -198,7 +245,6 @@ runSessionWithHandles :: Handle -- ^ Server in
                       -> Session a
                       -> IO a
 runSessionWithHandles serverIn serverOut serverProc serverHandler config caps rootDir exitServer session = do
-  
   absRootDir <- canonicalizePath rootDir
 
   hSetBuffering serverIn  NoBuffering
@@ -210,36 +256,69 @@ runSessionWithHandles serverIn serverOut serverProc serverHandler config caps ro
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
+  timeoutIdVar <- newMVar 0
   initRsp <- newEmptyMVar
 
   mainThreadId <- myThreadId
 
-  let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
-      initState = SessionState (IdInt 0) mempty mempty 0 False Nothing
-      runSession' = runSession context initState
+  withManager $ \watchManager -> do
+    let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps watchManager
+        initState vfs = SessionState (IdInt 0) vfs mempty False Nothing mempty mempty
+        -- Interesting note: turning on TypeFamilies causes type inference to
+        -- infer the type runSession' :: Session () -> IO ((), SessionState)
+        -- instead of     runSession' :: Session a  -> IO (a , SessionState)
+        runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
 
         errorHandler = throwTo mainThreadId :: SessionException -> IO ()
-      serverLauncher = forkIO $ catch (serverHandler serverOut context) errorHandler
+        serverListenerLauncher =
+          forkIO $ catch (serverHandler serverOut context) errorHandler
         server = (Just serverIn, Just serverOut, Nothing, serverProc)
-      serverFinalizer tid = finally (timeout (messageTimeout config * 1000000)
+        serverAndListenerFinalizer tid = do
+          finally (timeout (messageTimeout config * 1^6)
                           (runSession' exitServer))
-                                    (cleanupRunningProcess server >> killThread tid)
+                  -- Make sure to kill the listener first, before closing
+                  -- handles etc via cleanupProcess
+                  (killThread tid >> cleanupProcess server)
 
-  (result, _) <- bracket serverLauncher serverFinalizer (const $ runSession' session)
+    (result, _) <- bracket serverListenerLauncher
+                          serverAndListenerFinalizer
+                          (const $ runSession' session)
     return result
 
-updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
+updateStateC :: ConduitT FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
 updateStateC = awaitForever $ \msg -> do
   updateState msg
   yield msg
 
-updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m) => FromServerMessage -> m ()
+updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
+            => FromServerMessage -> m ()
+
+-- Keep track of dynamic capability registration
+updateState (ReqRegisterCapability req) = do
+  let List newRegs = (\r -> (r ^. LSP.id, r)) <$> req ^. params . registrations
+  modify $ \s ->
+    s { curDynCaps = Map.union (Map.fromList newRegs) (curDynCaps s) }
+
+  -- Process the new registrations
+  forM_ newRegs $ \(regId, reg) -> do
+    when (reg ^. method == WorkspaceDidChangeWatchedFiles) $ do
+      processFileWatchRegistration regId reg
+
+updateState (ReqUnregisterCapability req) = do
+  let List unRegs = (^. LSP.id) <$> req ^. params . unregistrations
+  modify $ \s ->
+    let newCurDynCaps = foldr' Map.delete (curDynCaps s) unRegs
+    in s { curDynCaps = newCurDynCaps }
+
+  -- Process the unregistrations
+  processFileWatchUnregistrations unRegs
+
 updateState (NotPublishDiagnostics n) = do
   let List diags = n ^. params . diagnostics
       doc = n ^. params . uri
-  modify (\s ->
+  modify \s ->
     let newDiags = Map.insert (toNormalizedUri doc) diags (curDiagnostics s)
-      in s { curDiagnostics = newDiags })
+      in s { curDiagnostics = newDiags }
 
 updateState (ReqApplyWorkspaceEdit r) = do
 
@@ -271,8 +350,8 @@ updateState (ReqApplyWorkspaceEdit r) = do
   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
     modify $ \s ->
       let oldVFS = vfs s
-          update (VirtualFile oldV t mf) = VirtualFile (fromMaybe oldV v) t mf
-          newVFS = Map.adjust update (toNormalizedUri uri) oldVFS
+          update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
+          newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
       in s { vfs = newVFS }
 
   where checkIfNeedsOpened uri = do
@@ -280,15 +359,16 @@ updateState (ReqApplyWorkspaceEdit r) = do
           ctx <- ask
 
           -- if its not open, open it
-          unless (toNormalizedUri uri `Map.member` oldVFS) $ do
+          unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
             let fp = fromJust $ uriToFilePath uri
             contents <- liftIO $ T.readFile fp
             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
                 msg = NotificationMessage "2.0" TextDocumentDidOpen (DidOpenTextDocumentParams item)
+            -- TODO: use 'sendMessage'?
             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
 
             modifyM $ \s -> do
-              newVFS <- liftIO $ openVFS (vfs s) msg
+              let (newVFS,_) = openVFS (vfs s) msg
               return $ s { vfs = newVFS }
 
         getParams (TextDocumentEdit docId (List edits)) =
@@ -312,23 +392,23 @@ sendMessage msg = do
   logMsg LogClient msg
   liftIO $ B.hPut h (addHeader $ encode msg)
 
--- | Execute a block f that will throw a 'Timeout' exception
+-- | Execute a block f that will throw a 'Language.Haskell.LSP.Test.Exception.Timeout' exception
 -- after duration seconds. This will override the global timeout
 -- for waiting for messages to arrive defined in 'SessionConfig'.
 withTimeout :: Int -> Session a -> Session a
 withTimeout duration f = do
   chan <- asks messageChan
-  timeoutId <- curTimeoutId <$> get
+  timeoutId <- getCurTimeoutId
   modify $ \s -> s { overridingTimeout = True }
   liftIO $ forkIO $ do
     threadDelay (duration * 1000000)
     writeChan chan (TimeoutMessage timeoutId)
   res <- f
-  modify $ \s -> s { curTimeoutId = timeoutId + 1,
-                     overridingTimeout = False
-                   }
+  bumpTimeoutId timeoutId
+  modify $ \s -> s { overridingTimeout = False }
   return res
 
+-- TODO: add a shouldTimeout helper. need to add exceptions within Session
 data LogMsgType = LogServer | LogClient
   deriving Eq
 
@@ -352,3 +432,51 @@ logMsg t msg = do
 
         showPretty = B.unpack . encodePretty
 
+-- File watching
+
+processFileWatchRegistration :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
+                             => T.Text -> Registration -> m ()
+processFileWatchRegistration regId reg = do
+  mgr <- asks watchManager
+  let mOpts = do
+        regOpts <- reg ^. registerOptions
+        case fromJSON regOpts  of
+          Error _ -> Nothing
+          Success x -> Just x
+  case mOpts of
+    Nothing -> pure ()
+    Just (DidChangeWatchedFilesRegistrationOptions (List ws)) ->
+      forM_ ws $ \(FileSystemWatcher pat' watchKind) -> do
+        pat <- liftIO $ canonicalizePath pat'
+        let glob = compile pat
+            -- the root-most dir before any globbing stuff happens
+            dir = fst $ commonDirectory glob
+            pred = match glob . eventPath
+            -- If no watchKind specified, spec defaults to all true
+            WatchKind wkC wkM wkD = fromMaybe (WatchKind True True True) watchKind
+        handle <- asks serverIn
+        unwatch <- liftIO $ watchTree mgr dir pred $ \event -> do
+          let fe = FileEvent (filePathToUri (eventPath event)) typ
+              typ = case event of
+                FS.Added _ _ _ -> FcCreated
+                FS.Modified _ _ _ -> FcChanged
+                FS.Removed _ _ _ -> FcDeleted
+                -- This is a bit of a guess
+                FS.Unknown _ _ _ -> FcChanged
+              matches = case typ of
+                FcCreated -> wkC
+                FcChanged -> wkM
+                FcDeleted -> wkD
+              params = DidChangeWatchedFilesParams (List [fe])
+              msg = fmClientDidChangeWatchedFilesNotification params
+          liftIO $ when matches $ B.hPut handle (addHeader $ encode msg)
+        modify $ \s ->
+          s { unwatchers = Map.insertWith (++) regId [unwatch] (unwatchers s) }
+
+processFileWatchUnregistrations :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
+                                => [T.Text] -> m ()
+processFileWatchUnregistrations regIds =
+  forM_ regIds $ \regId -> modifyM $ \s -> do
+    let fs = fromMaybe [] (Map.lookup regId (unwatchers s))
+    liftIO $ sequence fs
+    return $ s { unwatchers = Map.delete regId (unwatchers s) }