Kill timeout thread
[lsp-test.git] / src / Language / LSP / Test / Session.hs
index 36007391c35965676be1a5f87a337e33947a3afb..f3d6f8c28e83c7250bd093bde204e79964e08204 100644 (file)
@@ -1,4 +1,5 @@
 {-# LANGUAGE CPP               #-}
+{-# LANGUAGE BangPatterns      #-}
 {-# LANGUAGE GADTs             #-}
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
@@ -29,6 +30,7 @@ module Language.LSP.Test.Session
   , bumpTimeoutId
   , logMsg
   , LogMsgType(..)
+  , documentChangeUri
   )
 
 where
@@ -55,7 +57,8 @@ import Data.Conduit.Parser as Parser
 import Data.Default
 import Data.Foldable
 import Data.List
-import qualified Data.Map as Map
+import qualified Data.Map.Strict as Map
+import qualified Data.Set as Set
 import qualified Data.Text as T
 import qualified Data.Text.IO as T
 import qualified Data.HashMap.Strict as HashMap
@@ -77,6 +80,7 @@ import System.Process (ProcessHandle())
 import System.Process (waitForProcess)
 #endif
 import System.Timeout
+import Data.IORef
 
 -- | A session representing one instance of launching and connecting to a server.
 --
@@ -133,7 +137,7 @@ data SessionContext = SessionContext
   , rootDir :: FilePath
   , messageChan :: Chan SessionMessage -- ^ Where all messages come through
   -- Keep curTimeoutId in SessionContext, as its tied to messageChan
-  , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
+  , curTimeoutId :: IORef Int -- ^ The current timeout we are waiting on
   , requestMap :: MVar RequestMap
   , initRsp :: MVar (ResponseMessage Initialize)
   , config :: SessionConfig
@@ -152,7 +156,7 @@ instance Monad m => HasReader r (ConduitM a b (StateT s (ReaderT r m))) where
   ask = lift $ lift Reader.ask
 
 getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
-getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
+getCurTimeoutId = asks curTimeoutId >>= liftIO . readIORef
 
 -- Pass this the timeoutid you *were* waiting on
 bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
@@ -160,20 +164,21 @@ bumpTimeoutId prev = do
   v <- asks curTimeoutId
   -- when updating the curtimeoutid, account for the fact that something else
   -- might have bumped the timeoutid in the meantime
-  liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
+  liftIO $ atomicModifyIORef' v (\x -> (max x (prev + 1), ()))
 
 data SessionState = SessionState
   {
-    curReqId :: Int
-  , vfs :: VFS
-  , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
-  , overridingTimeout :: Bool
+    curReqId :: !Int
+  , vfs :: !VFS
+  , curDiagnostics :: !(Map.Map NormalizedUri [Diagnostic])
+  , overridingTimeout :: !Bool
   -- ^ The last received message from the server.
   -- Used for providing exception information
-  , lastReceivedMessage :: Maybe FromServerMessage
-  , curDynCaps :: Map.Map T.Text SomeRegistration
+  , lastReceivedMessage :: !(Maybe FromServerMessage)
+  , curDynCaps :: !(Map.Map T.Text SomeRegistration)
   -- ^ The capabilities that the server has dynamically registered with us so
   -- far
+  , curProgressSessions :: !(Set.Set ProgressToken)
   }
 
 class Monad m => HasState s m where
@@ -258,13 +263,13 @@ runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exi
 
   reqMap <- newMVar newRequestMap
   messageChan <- newChan
-  timeoutIdVar <- newMVar 0
+  timeoutIdVar <- newIORef 0
   initRsp <- newEmptyMVar
 
   mainThreadId <- myThreadId
 
   let context = SessionContext serverIn absRootDir messageChan timeoutIdVar reqMap initRsp config caps
-      initState vfs = SessionState 0 vfs mempty False Nothing mempty
+      initState vfs = SessionState 0 vfs mempty False Nothing mempty mempty
       runSession' ses = initVFS $ \vfs -> runSessionMonad context (initState vfs) ses
 
       errorHandler = throwTo mainThreadId :: SessionException -> IO ()
@@ -294,10 +299,19 @@ runSession' serverIn serverOut mServerProc serverHandler config caps rootDir exi
 updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
 updateStateC = awaitForever $ \msg -> do
   updateState msg
+  respond msg
   yield msg
+  where
+    respond :: (MonadIO m, HasReader SessionContext m) => FromServerMessage -> m ()
+    respond (FromServerMess SWindowWorkDoneProgressCreate req) =
+      sendMessage $ ResponseMessage "2.0" (Just $ req ^. LSP.id) (Right ())
+    respond (FromServerMess SWorkspaceApplyEdit r) = do
+      sendMessage $ ResponseMessage "2.0" (Just $ r ^. LSP.id) (Right $ ApplyWorkspaceEditResponseBody True Nothing)
+    respond _ = pure ()
 
 
 -- extract Uri out from DocumentChange
+-- didn't put this in `lsp-types` because TH was getting in the way
 documentChangeUri :: DocumentChange -> Uri
 documentChangeUri (InL x) = x ^. textDocument . uri
 documentChangeUri (InR (InL x)) = x ^. uri
@@ -306,6 +320,12 @@ documentChangeUri (InR (InR (InR x))) = x ^. uri
 
 updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
             => FromServerMessage -> m ()
+updateState (FromServerMess SProgress req) = case req ^. params . value of
+  Begin _ ->
+    modify $ \s -> s { curProgressSessions = Set.insert (req ^. params . token) $ curProgressSessions s }
+  End _ ->
+    modify $ \s -> s { curProgressSessions = Set.delete (req ^. params . token) $ curProgressSessions s }
+  _ -> pure ()
 
 -- Keep track of dynamic capability registration
 updateState (FromServerMess SClientRegisterCapability req) = do
@@ -423,10 +443,11 @@ withTimeout duration f = do
   chan <- asks messageChan
   timeoutId <- getCurTimeoutId
   modify $ \s -> s { overridingTimeout = True }
-  liftIO $ forkIO $ do
+  tid <- liftIO $ forkIO $ do
     threadDelay (duration * 1000000)
     writeChan chan (TimeoutMessage timeoutId)
   res <- f
+  liftIO $ killThread tid
   bumpTimeoutId timeoutId
   modify $ \s -> s { overridingTimeout = False }
   return res