{-# LANGUAGE CPP #-}
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
import Data.Default
import Data.Foldable
import Data.List
-import qualified Data.Map as Map
+import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as T
import qualified Data.Text.IO as T
import System.Process (waitForProcess)
#endif
import System.Timeout
+import Data.IORef
-- | A session representing one instance of launching and connecting to a server.
--
, rootDir :: FilePath
, messageChan :: Chan SessionMessage -- ^ Where all messages come through
-- Keep curTimeoutId in SessionContext, as its tied to messageChan
- , curTimeoutId :: MVar Int -- ^ The current timeout we are waiting on
+ , curTimeoutId :: IORef Int -- ^ The current timeout we are waiting on
, requestMap :: MVar RequestMap
, initRsp :: MVar (ResponseMessage Initialize)
, config :: SessionConfig
ask = lift $ lift Reader.ask
getCurTimeoutId :: (HasReader SessionContext m, MonadIO m) => m Int
-getCurTimeoutId = asks curTimeoutId >>= liftIO . readMVar
+getCurTimeoutId = asks curTimeoutId >>= liftIO . readIORef
-- Pass this the timeoutid you *were* waiting on
bumpTimeoutId :: (HasReader SessionContext m, MonadIO m) => Int -> m ()
v <- asks curTimeoutId
-- when updating the curtimeoutid, account for the fact that something else
-- might have bumped the timeoutid in the meantime
- liftIO $ modifyMVar_ v (\x -> pure (max x (prev + 1)))
+ liftIO $ atomicModifyIORef' v (\x -> (max x (prev + 1), ()))
data SessionState = SessionState
{
- curReqId :: Int
- , vfs :: VFS
- , curDiagnostics :: Map.Map NormalizedUri [Diagnostic]
- , overridingTimeout :: Bool
+ curReqId :: !Int
+ , vfs :: !VFS
+ , curDiagnostics :: !(Map.Map NormalizedUri [Diagnostic])
+ , overridingTimeout :: !Bool
-- ^ The last received message from the server.
-- Used for providing exception information
- , lastReceivedMessage :: Maybe FromServerMessage
- , curDynCaps :: Map.Map T.Text SomeRegistration
+ , lastReceivedMessage :: !(Maybe FromServerMessage)
+ , curDynCaps :: !(Map.Map T.Text SomeRegistration)
-- ^ The capabilities that the server has dynamically registered with us so
-- far
- , curProgressSessions :: Set.Set ProgressToken
+ , curProgressSessions :: !(Set.Set ProgressToken)
}
class Monad m => HasState s m where
reqMap <- newMVar newRequestMap
messageChan <- newChan
- timeoutIdVar <- newMVar 0
+ timeoutIdVar <- newIORef 0
initRsp <- newEmptyMVar
mainThreadId <- myThreadId
updateStateC :: ConduitM FromServerMessage FromServerMessage (StateT SessionState (ReaderT SessionContext IO)) ()
updateStateC = awaitForever $ \msg -> do
updateState msg
+ respond msg
yield msg
+ where
+ respond :: (MonadIO m, HasReader SessionContext m) => FromServerMessage -> m ()
+ respond (FromServerMess SWindowWorkDoneProgressCreate req) =
+ sendMessage $ ResponseMessage "2.0" (Just $ req ^. LSP.id) (Right ())
+ respond (FromServerMess SWorkspaceApplyEdit r) = do
+ sendMessage $ ResponseMessage "2.0" (Just $ r ^. LSP.id) (Right $ ApplyWorkspaceEditResponseBody True Nothing)
+ respond _ = pure ()
+
-- extract Uri out from DocumentChange
-- didn't put this in `lsp-types` because TH was getting in the way
updateState :: (MonadIO m, HasReader SessionContext m, HasState SessionState m)
=> FromServerMessage -> m ()
-updateState (FromServerMess SWindowWorkDoneProgressCreate req) =
- sendMessage $ ResponseMessage "2.0" (Just $ req ^. LSP.id) (Right ())
updateState (FromServerMess SProgress req) = case req ^. params . value of
Begin _ ->
modify $ \s -> s { curProgressSessions = Set.insert (req ^. params . token) $ curProgressSessions s }
-- TODO: Don't do this when replaying a session
forM_ mergedParams (sendMessage . NotificationMessage "2.0" STextDocumentDidChange)
- sendMessage $ ResponseMessage "2.0" (Just $ r ^. LSP.id) (Right $ ApplyWorkspaceEditResponseBody True Nothing)
-
-- Update VFS to new document versions
let sortedVersions = map (sortBy (compare `on` (^. textDocument . version))) groupedParams
latestVersions = map ((^. textDocument) . last) sortedVersions
chan <- asks messageChan
timeoutId <- getCurTimeoutId
modify $ \s -> s { overridingTimeout = True }
- liftIO $ forkIO $ do
+ tid <- liftIO $ forkIO $ do
threadDelay (duration * 1000000)
writeChan chan (TimeoutMessage timeoutId)
res <- f
+ liftIO $ killThread tid
bumpTimeoutId timeoutId
modify $ \s -> s { overridingTimeout = False }
return res