Fix curtimeoutid being reset in the server exit handler
[lsp-test.git] / src / Language / Haskell / LSP / Test / Session.hs
index a4532b8f1bc2351243e51b9eee3c5d6cf8f94720..ddd07a5da6693cc3ccc27bd0845df8cd7b58734d 100644 (file)
@@ -23,6 +23,8 @@ module Language.Haskell.LSP.Test.Session
   , sendMessage
   , updateState
   , withTimeout
+  , getCurTimeoutId
+  , bumpTimeoutId
   , logMsg
   , LogMsgType(..)
   )
@@ -60,7 +62,7 @@ import Data.Function
 import Language.Haskell.LSP.Messages
 import Language.Haskell.LSP.Types.Capabilities
 import Language.Haskell.LSP.Types
-import Language.Haskell.LSP.Types.Lens hiding (error)
+import Language.Haskell.LSP.Types.Lens
 import Language.Haskell.LSP.VFS
 import Language.Haskell.LSP.Test.Compat
 import Language.Haskell.LSP.Test.Decoding
@@ -70,7 +72,6 @@ import System.Directory
 import System.IO
 import System.Process (ProcessHandle())
 import System.Timeout
-import System.IO.Temp
 
 -- | A session representing one instance of launching and connecting to a server.
 --
@@ -92,15 +93,24 @@ instance MonadFail Session where
 -- | Stuff you can configure for a 'Session'.
 data SessionConfig = SessionConfig
   { messageTimeout :: Int  -- ^ Maximum time to wait for a message in seconds, defaults to 60.
-  , logStdErr      :: Bool -- ^ Redirect the server's stderr to this stdout, defaults to False.
-  , logMessages    :: Bool -- ^ Trace the messages sent and received to stdout, defaults to False.
+  , logStdErr      :: Bool
+  -- ^ Redirect the server's stderr to this stdout, defaults to False.
+  -- Can be overriden with @LSP_TEST_LOG_STDERR@.
+  , logMessages    :: Bool
+  -- ^ Trace the messages sent and received to stdout, defaults to False.
+  -- Can be overriden with the environment variable @LSP_TEST_LOG_MESSAGES@.
   , logColor       :: Bool -- ^ Add ANSI color to the logged messages, defaults to True.
   , lspConfig      :: Maybe Value -- ^ The initial LSP config as JSON value, defaults to Nothing.
+  , ignoreLogNotifications :: Bool
+  -- ^ Whether or not to ignore 'Language.Haskell.LSP.Types.ShowMessageNotification' and
+  -- 'Language.Haskell.LSP.Types.LogMessageNotification', defaults to False.
+  --
+  -- @since 0.9.0.0
   }
 
 -- | The configuration used in 'Language.Haskell.LSP.Test.runSession'.
 defaultConfig :: SessionConfig
-defaultConfig = SessionConfig 60 False False True Nothing
+defaultConfig = SessionConfig 60 False False True Nothing False
 
 instance Default SessionConfig where
   def = defaultConfig
@@ -113,7 +123,9 @@ data SessionContext = SessionContext
   {
     serverIn :: Handle
   , rootDir :: FilePath
-  , messageChan :: Chan SessionMessage
+  , messageChan :: Chan SessionMessage -- ^ Where all messages come through
+  -- Keep curTimeoutId in SessionContext, as its tied to messageChan
+  , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
   , requestMap :: MVar RequestMap
   , initRsp :: MVar InitializeResponse
   , config :: SessionConfig
@@ -131,12 +143,22 @@ instance HasReader SessionContext Session where
 instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
   ask = lift $ lift Reader.ask
 
+getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
+getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
+
+-- Pass this the timeoutid you *were* waiting on
+bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
+bumpTimeoutId prev = do
+  v <- asks curTimeoutId
+  -- when updating the curtimeoutid, account for the fact that something else
+  -- might have bumped the timeoutid in the meantime
+  liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
+
 data SessionState = SessionState
   {
     curReqId :: LspId
   , vfs :: VFS
   , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
-  , curTimeoutId :: Int
   , overridingTimeout :: Bool
   -- ^ The last received message from the server.
   -- Used for providing exception information
@@ -158,15 +180,19 @@ instance HasState SessionState Session where
   get = Session (lift State.get)
   put = Session . lift . State.put
 
-instance Monad m => HasState s (ConduitM a b (StateT s m))
+instance Monad m => HasState s (StateT s m) where
+  get = State.get
+  put = State.put
+
+instance (Monad m, (HasState s m)) => HasState s (ConduitM a b m)
  where
-  get = lift State.get
-  put = lift . State.put
+  get = lift get
+  put = lift . put
 
-instance Monad m => HasState s (ConduitParser a (StateT s m))
+instance (Monad m, (HasState s m)) => HasState s (ConduitParser a m)
  where
-  get = lift State.get
-  put = lift . State.put
+  get = lift get
+  put = lift . put
 
 runSession :: SessionContext -> SessionState -> Session a -> IO (a, SessionState)
 runSession context state (Session session) = runReaderT (runStateT conduit state) context
@@ -182,15 +208,20 @@ runSession context state (Session session) = runReaderT (runStateT conduit state
 
     chanSource = do
       msg <- liftIO $ readChan (messageChan context)
+      unless (ignoreLogNotifications (config context) && isLogNotification msg) $
         yield msg
       chanSource
 
+    isLogNotification (ServerMessage (NotShowMessage _)) = True
+    isLogNotification (ServerMessage (NotLogMessage _)) = True
+    isLogNotification _ = False
+
     watchdog :: ConduitM SessionMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
     watchdog = Conduit.awaitForever $ \msg -> do
-      curId <- curTimeoutId <$> get
+      curId <- getCurTimeoutId
       case msg of
         ServerMessage sMsg -> yield sMsg
-        TimeoutMessage tId -> when (curId == tId) $ throw Timeout
+        TimeoutMessage tId -> when (curId == tId) $ lastReceivedMessage <$> get >>= throw . Timeout
 
 -- | An internal version of 'runSession' that allows for a custom handler to listen to the server.
 -- It also does not automatically send initialize and exit messages.
@@ -216,26 +247,26 @@ runSessionWithHandles serverIn serverOut serverProc serverHandler config caps ro
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
+  timeoutIdVar <- newMVar 0
   initRsp <- newEmptyMVar
 
   mainThreadId <- myThreadId
 
-  let context = SessionContext serverIn absRootDir messageChan reqMap initRsp config caps
-      initState tmp_dir = SessionState (IdInt 0) (VFS mempty tmp_dir)
-                                       mempty 0 False Nothing
-      runSession' ses = withSystemTempDirectory "lsp-test" $ \tmp_dir ->
-                      runSession context (initState tmp_dir) ses
+  let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
+      initState vfs = SessionState (IdInt 0) vfs mempty False Nothing
+      runSession' ses = initVFS $ \vfs -> runSession context (initState vfs) ses
 
       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
       serverListenerLauncher =
         forkIO $ catch (serverHandler serverOut context) errorHandler
       server = (Just serverIn, Just serverOut, Nothing, serverProc)
-      serverAndListenerFinalizer tid =
-        finally (timeout (messageTimeout config * 1000000)
+      serverAndListenerFinalizer tid = do
+        finally (timeout (messageTimeout config * 1^6)
                          (runSession' exitServer))
                 (cleanupProcess server >> killThread tid)
 
-  (result, _) <- bracket serverListenerLauncher serverAndListenerFinalizer
+  (result, _) <- bracket serverListenerLauncher
+                         serverAndListenerFinalizer
                          (const $ runSession' session)
   return result
 
@@ -283,7 +314,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
   forM_ bumpedVersions $ \(VersionedTextDocumentIdentifier uri v) ->
     modify $ \s ->
       let oldVFS = vfs s
-          update (VirtualFile oldV t) = VirtualFile (fromMaybe oldV v) t
+          update (VirtualFile oldV file_ver t) = VirtualFile (fromMaybe oldV v) (file_ver + 1) t
           newVFS = updateVFS (Map.adjust update (toNormalizedUri uri)) oldVFS
       in s { vfs = newVFS }
 
@@ -292,7 +323,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
           ctx <- ask
 
           -- if its not open, open it
-          unless (toNormalizedUri uri `Map.member` (vfsMap oldVFS)) $ do
+          unless (toNormalizedUri uri `Map.member` vfsMap oldVFS) $ do
             let fp = fromJust $ uriToFilePath uri
             contents <- liftIO $ T.readFile fp
             let item = TextDocumentItem (filePathToUri fp) "" 0 contents
@@ -300,7 +331,7 @@ updateState (ReqApplyWorkspaceEdit r) = do
             liftIO $ B.hPut (serverIn ctx) $ addHeader (encode msg)
 
             modifyM $ \s -> do
-              newVFS <- liftIO $ openVFS (vfs s) msg
+              let (newVFS,_) = openVFS (vfs s) msg
               return $ s { vfs = newVFS }
 
         getParams (TextDocumentEdit docId (List edits)) =
@@ -330,15 +361,14 @@ sendMessage msg = do
 withTimeout :: Int -> Session a -> Session a
 withTimeout duration f = do
   chan <- asks messageChan
-  timeoutId <- curTimeoutId <$> get
+  timeoutId <- getCurTimeoutId
   modify $ \s -> s { overridingTimeout = True }
   liftIO $ forkIO $ do
     threadDelay (duration * 1000000)
     writeChan chan (TimeoutMessage timeoutId)
   res <- f
-  modify $ \s -> s { curTimeoutId = timeoutId + 1,
-                     overridingTimeout = False
-                   }
+  bumpTimeoutId timeoutId
+  modify $ \s -> s { overridingTimeout = False }
   return res
 
 data LogMsgType = LogServer | LogClient
@@ -364,3 +394,4 @@ logMsg t msg = do
 
         showPretty = B.unpack . encodePretty
 
+